apr_cli/federation/
catalog.rs

1//! Model Catalog - Registry of available models across the federation
2//!
3//! The catalog tracks which models are available, where they're deployed,
4//! and what capabilities they support.
5
6use super::traits::*;
7use std::collections::HashMap;
8use std::sync::RwLock;
9
10// ============================================================================
11// Model Entry
12// ============================================================================
13
14/// Entry for a registered model in the catalog
15#[derive(Debug, Clone)]
16pub struct ModelEntry {
17    pub model_id: ModelId,
18    pub metadata: ModelMetadata,
19    pub deployments: Vec<ModelDeployment>,
20}
21
22/// A specific deployment of a model
23#[derive(Debug, Clone)]
24pub struct ModelDeployment {
25    pub node_id: NodeId,
26    pub region_id: RegionId,
27    pub endpoint: String,
28    pub status: DeploymentStatus,
29}
30
/// Lifecycle state of a single model deployment.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeploymentStatus {
    /// The model is still being loaded onto the node.
    Loading,
    /// The model can serve inference requests.
    Ready,
    /// The model is finishing in-flight work and accepts no new requests.
    Draining,
    /// The deployment has been taken down.
    Removed,
}
43
44// ============================================================================
45// In-Memory Catalog Implementation
46// ============================================================================
47
48/// In-memory model catalog (production would use etcd, Redis, etc.)
49pub struct ModelCatalog {
50    entries: RwLock<HashMap<ModelId, ModelEntry>>,
51    by_capability: RwLock<HashMap<String, Vec<ModelId>>>,
52}
53
54impl ModelCatalog {
55    pub fn new() -> Self {
56        Self {
57            entries: RwLock::new(HashMap::new()),
58            by_capability: RwLock::new(HashMap::new()),
59        }
60    }
61
62    /// Get all entries (for debugging/admin)
63    pub fn all_entries(&self) -> Vec<ModelEntry> {
64        self.entries
65            .read()
66            .expect("catalog lock poisoned")
67            .values()
68            .cloned()
69            .collect()
70    }
71
72    /// Get entry by ID
73    pub fn get(&self, model_id: &ModelId) -> Option<ModelEntry> {
74        self.entries
75            .read()
76            .expect("catalog lock poisoned")
77            .get(model_id)
78            .cloned()
79    }
80
81    fn capability_key(cap: &Capability) -> String {
82        match cap {
83            Capability::Transcribe => "transcribe".to_string(),
84            Capability::Synthesize => "synthesize".to_string(),
85            Capability::Generate => "generate".to_string(),
86            Capability::Code => "code".to_string(),
87            Capability::Embed => "embed".to_string(),
88            Capability::ImageGen => "image_gen".to_string(),
89            Capability::Custom(s) => format!("custom:{}", s),
90        }
91    }
92}
93
94impl Default for ModelCatalog {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl ModelCatalogTrait for ModelCatalog {
101    fn register(
102        &self,
103        model_id: ModelId,
104        node_id: NodeId,
105        region_id: RegionId,
106        capabilities: Vec<Capability>,
107    ) -> BoxFuture<'_, FederationResult<()>> {
108        Box::pin(async move {
109            let mut entries = self.entries.write().expect("catalog lock poisoned");
110            let mut by_cap = self.by_capability.write().expect("catalog lock poisoned");
111
112            let deployment = ModelDeployment {
113                node_id,
114                region_id,
115                endpoint: String::new(), // Would be set by registration protocol
116                status: DeploymentStatus::Ready,
117            };
118
119            if let Some(entry) = entries.get_mut(&model_id) {
120                // Add deployment to existing model
121                entry.deployments.push(deployment);
122            } else {
123                // New model registration
124                let metadata = ModelMetadata {
125                    model_id: model_id.clone(),
126                    name: model_id.0.clone(),
127                    version: "1.0.0".to_string(),
128                    capabilities: capabilities.clone(),
129                    parameters: 0,
130                    quantization: None,
131                };
132
133                let entry = ModelEntry {
134                    model_id: model_id.clone(),
135                    metadata,
136                    deployments: vec![deployment],
137                };
138
139                entries.insert(model_id.clone(), entry);
140
141                // Index by capability
142                for cap in &capabilities {
143                    let key = Self::capability_key(cap);
144                    by_cap.entry(key).or_default().push(model_id.clone());
145                }
146            }
147
148            Ok(())
149        })
150    }
151
152    fn deregister(
153        &self,
154        model_id: ModelId,
155        node_id: NodeId,
156    ) -> BoxFuture<'_, FederationResult<()>> {
157        Box::pin(async move {
158            let mut entries = self.entries.write().expect("catalog lock poisoned");
159
160            if let Some(entry) = entries.get_mut(&model_id) {
161                entry.deployments.retain(|d| d.node_id != node_id);
162
163                // Remove entry entirely if no deployments remain
164                if entry.deployments.is_empty() {
165                    entries.remove(&model_id);
166                }
167            }
168
169            Ok(())
170        })
171    }
172
173    fn find_by_capability(
174        &self,
175        capability: &Capability,
176    ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>> {
177        let key = Self::capability_key(capability);
178
179        Box::pin(async move {
180            let entries = self.entries.read().expect("catalog lock poisoned");
181            let by_cap = self.by_capability.read().expect("catalog lock poisoned");
182
183            let mut results = Vec::new();
184
185            if let Some(model_ids) = by_cap.get(&key) {
186                for model_id in model_ids {
187                    if let Some(entry) = entries.get(model_id) {
188                        for deployment in &entry.deployments {
189                            if deployment.status == DeploymentStatus::Ready {
190                                results.push((
191                                    deployment.node_id.clone(),
192                                    deployment.region_id.clone(),
193                                ));
194                            }
195                        }
196                    }
197                }
198            }
199
200            Ok(results)
201        })
202    }
203
204    fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>> {
205        Box::pin(async move {
206            let entries = self.entries.read().expect("catalog lock poisoned");
207            Ok(entries.keys().cloned().collect())
208        })
209    }
210
211    fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>> {
212        let model_id = model_id.clone();
213
214        Box::pin(async move {
215            let entries = self.entries.read().expect("catalog lock poisoned");
216
217            entries
218                .get(&model_id)
219                .map(|e| e.metadata.clone())
220                .ok_or_else(|| {
221                    FederationError::Internal(format!("Model not found: {:?}", model_id))
222                })
223        })
224    }
225}
226
227// ============================================================================
228// Tests
229// ============================================================================
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers `model` on `node` in `region`, failing the test on error.
    async fn register_model(
        catalog: &ModelCatalog,
        model: &str,
        node: &str,
        region: &str,
        capabilities: Vec<Capability>,
    ) {
        catalog
            .register(
                ModelId(model.to_string()),
                NodeId(node.to_string()),
                RegionId(region.to_string()),
                capabilities,
            )
            .await
            .expect("registration failed");
    }

    /// Looks up every ready deployment for `capability`, failing the test on error.
    async fn find(catalog: &ModelCatalog, capability: Capability) -> Vec<(NodeId, RegionId)> {
        catalog
            .find_by_capability(&capability)
            .await
            .expect("find failed")
    }

    #[tokio::test]
    async fn test_register_and_find() {
        let catalog = ModelCatalog::new();
        register_model(
            &catalog,
            "whisper-large",
            "node-1",
            "us-west",
            vec![Capability::Transcribe],
        )
        .await;

        let nodes = find(&catalog, Capability::Transcribe).await;

        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].0, NodeId("node-1".to_string()));
    }

    #[tokio::test]
    async fn test_deregister() {
        let catalog = ModelCatalog::new();
        register_model(
            &catalog,
            "llama-7b",
            "node-1",
            "eu-west",
            vec![Capability::Generate],
        )
        .await;

        catalog
            .deregister(
                ModelId("llama-7b".to_string()),
                NodeId("node-1".to_string()),
            )
            .await
            .expect("deregistration failed");

        let models = catalog.list_all().await.expect("list failed");
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_deployments() {
        let catalog = ModelCatalog::new();

        // Same model served from two nodes in different regions.
        for (node, region) in [("node-1", "us-east"), ("node-2", "us-west")] {
            register_model(
                &catalog,
                "whisper-base",
                node,
                region,
                vec![Capability::Transcribe],
            )
            .await;
        }

        let nodes = find(&catalog, Capability::Transcribe).await;

        assert_eq!(nodes.len(), 2);
    }

    #[tokio::test]
    async fn test_custom_capability() {
        let catalog = ModelCatalog::new();
        register_model(
            &catalog,
            "sentiment-bert",
            "node-1",
            "ap-south",
            vec![Capability::Custom("sentiment".to_string())],
        )
        .await;

        let nodes = find(&catalog, Capability::Custom("sentiment".to_string())).await;
        assert_eq!(nodes.len(), 1);

        // A different custom capability must not match.
        let empty = find(&catalog, Capability::Custom("other".to_string())).await;
        assert!(empty.is_empty());
    }
}
346}