Skip to main content

apr_cli/federation/
catalog.rs

1//! Model Catalog - Registry of available models across the federation
2//!
3//! The catalog tracks which models are available, where they're deployed,
4//! and what capabilities they support.
5
6use super::traits::*;
7use std::collections::HashMap;
8use std::sync::RwLock;
9
10// ============================================================================
11// Model Entry
12// ============================================================================
13
14/// Entry for a registered model in the catalog
15#[derive(Debug, Clone)]
16pub struct ModelEntry {
17    pub model_id: ModelId,
18    pub metadata: ModelMetadata,
19    pub deployments: Vec<ModelDeployment>,
20}
21
22/// A specific deployment of a model
23#[derive(Debug, Clone)]
24pub struct ModelDeployment {
25    pub node_id: NodeId,
26    pub region_id: RegionId,
27    pub endpoint: String,
28    pub status: DeploymentStatus,
29}
30
/// Lifecycle state of a [`ModelDeployment`].
///
/// The type is `Copy` and comparable with `==`, so it can be passed and
/// matched by value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeploymentStatus {
    /// The model is still being loaded.
    Loading,
    /// The model is loaded and ready for inference.
    Ready,
    /// The model is draining and takes no new requests.
    Draining,
    /// The model has been removed.
    Removed,
}
43
44// ============================================================================
45// In-Memory Catalog Implementation
46// ============================================================================
47
48/// In-memory model catalog (production would use etcd, Redis, etc.)
49pub struct ModelCatalog {
50    entries: RwLock<HashMap<ModelId, ModelEntry>>,
51    by_capability: RwLock<HashMap<String, Vec<ModelId>>>,
52}
53
54impl ModelCatalog {
55    pub fn new() -> Self {
56        Self {
57            entries: RwLock::new(HashMap::new()),
58            by_capability: RwLock::new(HashMap::new()),
59        }
60    }
61
62    /// Get all entries (for debugging/admin)
63    pub fn all_entries(&self) -> Vec<ModelEntry> {
64        self.entries
65            .read()
66            .expect("catalog lock poisoned")
67            .values()
68            .cloned()
69            .collect()
70    }
71
72    /// Get entry by ID
73    pub fn get(&self, model_id: &ModelId) -> Option<ModelEntry> {
74        self.entries
75            .read()
76            .expect("catalog lock poisoned")
77            .get(model_id)
78            .cloned()
79    }
80
81    fn capability_key(cap: &Capability) -> String {
82        match cap {
83            Capability::Transcribe => "transcribe".to_string(),
84            Capability::Synthesize => "synthesize".to_string(),
85            Capability::Generate => "generate".to_string(),
86            Capability::Code => "code".to_string(),
87            Capability::Embed => "embed".to_string(),
88            Capability::ImageGen => "image_gen".to_string(),
89            Capability::Custom(s) => format!("custom:{}", s),
90        }
91    }
92}
93
94impl Default for ModelCatalog {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl ModelCatalogTrait for ModelCatalog {
101    fn register(
102        &self,
103        model_id: ModelId,
104        node_id: NodeId,
105        region_id: RegionId,
106        capabilities: Vec<Capability>,
107    ) -> BoxFuture<'_, FederationResult<()>> {
108        Box::pin(async move {
109            let mut entries = self.entries.write().expect("catalog lock poisoned");
110            let mut by_cap = self.by_capability.write().expect("catalog lock poisoned");
111
112            let deployment = ModelDeployment {
113                node_id,
114                region_id,
115                endpoint: String::new(), // Would be set by registration protocol
116                status: DeploymentStatus::Ready,
117            };
118
119            if let Some(entry) = entries.get_mut(&model_id) {
120                // Add deployment to existing model
121                entry.deployments.push(deployment);
122            } else {
123                // New model registration
124                let metadata = ModelMetadata {
125                    model_id: model_id.clone(),
126                    name: model_id.0.clone(),
127                    version: "1.0.0".to_string(),
128                    capabilities: capabilities.clone(),
129                    parameters: 0,
130                    quantization: None,
131                };
132
133                let entry = ModelEntry {
134                    model_id: model_id.clone(),
135                    metadata,
136                    deployments: vec![deployment],
137                };
138
139                entries.insert(model_id.clone(), entry);
140
141                // Index by capability
142                for cap in &capabilities {
143                    let key = Self::capability_key(cap);
144                    by_cap.entry(key).or_default().push(model_id.clone());
145                }
146            }
147
148            Ok(())
149        })
150    }
151
152    fn deregister(
153        &self,
154        model_id: ModelId,
155        node_id: NodeId,
156    ) -> BoxFuture<'_, FederationResult<()>> {
157        Box::pin(async move {
158            let mut entries = self.entries.write().expect("catalog lock poisoned");
159
160            if let Some(entry) = entries.get_mut(&model_id) {
161                entry.deployments.retain(|d| d.node_id != node_id);
162
163                // Remove entry entirely if no deployments remain
164                if entry.deployments.is_empty() {
165                    entries.remove(&model_id);
166                }
167            }
168
169            Ok(())
170        })
171    }
172
173    fn find_by_capability(
174        &self,
175        capability: &Capability,
176    ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>> {
177        let key = Self::capability_key(capability);
178
179        Box::pin(async move {
180            let entries = self.entries.read().expect("catalog lock poisoned");
181            let by_cap = self.by_capability.read().expect("catalog lock poisoned");
182
183            let mut results = Vec::new();
184
185            if let Some(model_ids) = by_cap.get(&key) {
186                for model_id in model_ids {
187                    if let Some(entry) = entries.get(model_id) {
188                        for deployment in &entry.deployments {
189                            if deployment.status == DeploymentStatus::Ready {
190                                results.push((
191                                    deployment.node_id.clone(),
192                                    deployment.region_id.clone(),
193                                ));
194                            }
195                        }
196                    }
197                }
198            }
199
200            Ok(results)
201        })
202    }
203
204    fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>> {
205        Box::pin(async move {
206            let entries = self.entries.read().expect("catalog lock poisoned");
207            Ok(entries.keys().cloned().collect())
208        })
209    }
210
211    fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>> {
212        let model_id = model_id.clone();
213
214        Box::pin(async move {
215            let entries = self.entries.read().expect("catalog lock poisoned");
216
217            entries
218                .get(&model_id)
219                .map(|e| e.metadata.clone())
220                .ok_or_else(|| {
221                    FederationError::Internal(format!("Model not found: {:?}", model_id))
222                })
223        })
224    }
225}
226
227// ============================================================================
228// Tests
229// ============================================================================
230
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // Helpers
    // =========================================================================

    /// Builds a `ModelId` from a string slice.
    fn model(id: &str) -> ModelId {
        ModelId(id.to_string())
    }

    /// Builds a `NodeId` from a string slice.
    fn node(id: &str) -> NodeId {
        NodeId(id.to_string())
    }

    /// Builds a `RegionId` from a string slice.
    fn region(id: &str) -> RegionId {
        RegionId(id.to_string())
    }

    /// Registers `model_id` on `node_id`/`region_id` and hands the result
    /// back so each test can attach its own `expect` message.
    async fn register_on(
        catalog: &ModelCatalog,
        model_id: &str,
        node_id: &str,
        region_id: &str,
        capabilities: Vec<Capability>,
    ) -> FederationResult<()> {
        catalog
            .register(
                model(model_id),
                node(node_id),
                region(region_id),
                capabilities,
            )
            .await
    }

    // =========================================================================
    // Basic registration / lookup
    // =========================================================================

    #[tokio::test]
    async fn test_register_and_find() {
        let catalog = ModelCatalog::new();

        register_on(
            &catalog,
            "whisper-large",
            "node-1",
            "us-west",
            vec![Capability::Transcribe],
        )
        .await
        .expect("registration failed");

        let found = catalog
            .find_by_capability(&Capability::Transcribe)
            .await
            .expect("find failed");

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, node("node-1"));
    }

    #[tokio::test]
    async fn test_deregister() {
        let catalog = ModelCatalog::new();

        register_on(
            &catalog,
            "llama-7b",
            "node-1",
            "eu-west",
            vec![Capability::Generate],
        )
        .await
        .expect("registration failed");

        catalog
            .deregister(model("llama-7b"), node("node-1"))
            .await
            .expect("deregistration failed");

        let remaining = catalog.list_all().await.expect("list failed");
        assert!(remaining.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_deployments() {
        let catalog = ModelCatalog::new();

        // Same model on two nodes
        for (node_id, region_id) in [("node-1", "us-east"), ("node-2", "us-west")] {
            register_on(
                &catalog,
                "whisper-base",
                node_id,
                region_id,
                vec![Capability::Transcribe],
            )
            .await
            .expect("registration failed");
        }

        let found = catalog
            .find_by_capability(&Capability::Transcribe)
            .await
            .expect("find failed");

        assert_eq!(found.len(), 2);
    }

    #[tokio::test]
    async fn test_custom_capability() {
        let catalog = ModelCatalog::new();

        register_on(
            &catalog,
            "sentiment-bert",
            "node-1",
            "ap-south",
            vec![Capability::Custom("sentiment".to_string())],
        )
        .await
        .expect("registration failed");

        let found = catalog
            .find_by_capability(&Capability::Custom("sentiment".to_string()))
            .await
            .expect("find failed");

        assert_eq!(found.len(), 1);

        // Different custom capability should return empty
        let unrelated = catalog
            .find_by_capability(&Capability::Custom("other".to_string()))
            .await
            .expect("find failed");

        assert!(unrelated.is_empty());
    }

    // =========================================================================
    // ModelCatalog::get tests
    // =========================================================================

    #[tokio::test]
    async fn test_get_existing_model() {
        let catalog = ModelCatalog::new();

        register_on(
            &catalog,
            "whisper",
            "n1",
            "us-west",
            vec![Capability::Transcribe],
        )
        .await
        .expect("registration failed");

        let looked_up = catalog.get(&model("whisper"));
        assert!(looked_up.is_some());
        let entry = looked_up.expect("entry should exist");
        assert_eq!(entry.model_id, model("whisper"));
        assert_eq!(entry.deployments.len(), 1);
    }

    #[test]
    fn test_get_nonexistent_model() {
        let catalog = ModelCatalog::new();
        let looked_up = catalog.get(&model("nonexistent"));
        assert!(looked_up.is_none());
    }

    // =========================================================================
    // get_metadata tests
    // =========================================================================

    #[tokio::test]
    async fn test_get_metadata_existing() {
        let catalog = ModelCatalog::new();

        register_on(
            &catalog,
            "llama",
            "n1",
            "us-east",
            vec![Capability::Generate, Capability::Code],
        )
        .await
        .expect("registration failed");

        let meta = catalog
            .get_metadata(&model("llama"))
            .await
            .expect("metadata failed");

        assert_eq!(meta.model_id, model("llama"));
        assert_eq!(meta.name, "llama");
        assert_eq!(meta.version, "1.0.0");
        assert_eq!(meta.capabilities.len(), 2);
    }

    #[tokio::test]
    async fn test_get_metadata_nonexistent() {
        let catalog = ModelCatalog::new();

        let outcome = catalog.get_metadata(&model("missing")).await;

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), FederationError::Internal(_)));
    }

    // =========================================================================
    // all_entries tests
    // =========================================================================

    #[test]
    fn test_all_entries_empty() {
        let catalog = ModelCatalog::new();
        assert!(catalog.all_entries().is_empty());
    }

    #[tokio::test]
    async fn test_all_entries_multiple() {
        let catalog = ModelCatalog::new();

        register_on(&catalog, "m1", "n1", "r1", vec![Capability::Generate])
            .await
            .expect("failed");

        register_on(&catalog, "m2", "n2", "r2", vec![Capability::Embed])
            .await
            .expect("failed");

        let snapshot = catalog.all_entries();
        assert_eq!(snapshot.len(), 2);
    }

    // =========================================================================
    // deregister edge cases
    // =========================================================================

    #[tokio::test]
    async fn test_deregister_nonexistent_model() {
        let catalog = ModelCatalog::new();

        // Deregistering a non-existent model should succeed (no-op)
        let outcome = catalog.deregister(model("missing"), node("n1")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_deregister_nonexistent_node() {
        let catalog = ModelCatalog::new();

        register_on(&catalog, "m1", "n1", "r1", vec![Capability::Generate])
            .await
            .expect("failed");

        // Deregister a different node -> model should still exist
        catalog
            .deregister(model("m1"), node("n2"))
            .await
            .expect("deregister failed");

        let remaining = catalog.list_all().await.expect("list failed");
        assert_eq!(remaining.len(), 1);
    }

    #[tokio::test]
    async fn test_deregister_partial_keeps_remaining() {
        let catalog = ModelCatalog::new();

        // Same model on two nodes
        register_on(&catalog, "m1", "n1", "r1", vec![Capability::Generate])
            .await
            .expect("failed");

        register_on(&catalog, "m1", "n2", "r2", vec![Capability::Generate])
            .await
            .expect("failed");

        // Deregister one node
        catalog
            .deregister(model("m1"), node("n1"))
            .await
            .expect("deregister failed");

        // Model should still exist with 1 deployment
        let looked_up = catalog.get(&model("m1"));
        assert!(looked_up.is_some());
        assert_eq!(looked_up.expect("should exist").deployments.len(), 1);
    }

    // =========================================================================
    // capability_key coverage
    // =========================================================================

    #[tokio::test]
    async fn test_all_capability_keys_via_registration() {
        let catalog = ModelCatalog::new();

        // Register one model for each capability variant
        let cases = vec![
            ("t1", Capability::Transcribe),
            ("t2", Capability::Synthesize),
            ("t3", Capability::Generate),
            ("t4", Capability::Code),
            ("t5", Capability::Embed),
            ("t6", Capability::ImageGen),
            ("t7", Capability::Custom("custom_task".to_string())),
        ];

        for (id, cap) in &cases {
            register_on(&catalog, id, "n1", "r1", vec![cap.clone()])
                .await
                .expect("registration failed");
        }

        // Verify each can be found
        for (_, cap) in &cases {
            let found = catalog.find_by_capability(cap).await.expect("find failed");
            assert_eq!(found.len(), 1, "Should find 1 node for {:?}", cap);
        }
    }

    // =========================================================================
    // DeploymentStatus tests
    // =========================================================================

    #[test]
    fn test_deployment_status_equality() {
        assert_eq!(DeploymentStatus::Ready, DeploymentStatus::Ready);
        assert_ne!(DeploymentStatus::Ready, DeploymentStatus::Loading);
        assert_ne!(DeploymentStatus::Draining, DeploymentStatus::Removed);
    }

    #[test]
    fn test_deployment_status_all_variants() {
        let variants = [
            DeploymentStatus::Loading,
            DeploymentStatus::Ready,
            DeploymentStatus::Draining,
            DeploymentStatus::Removed,
        ];
        // All distinct: a variant equals itself and nothing else.
        for (row, lhs) in variants.iter().enumerate() {
            for (col, rhs) in variants.iter().enumerate() {
                if row == col {
                    assert_eq!(lhs, rhs);
                } else {
                    assert_ne!(lhs, rhs);
                }
            }
        }
    }

    #[test]
    fn test_deployment_status_copy() {
        let original = DeploymentStatus::Draining;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    // =========================================================================
    // ModelEntry/ModelDeployment construction tests
    // =========================================================================

    #[test]
    fn test_model_entry_clone() {
        let metadata = ModelMetadata {
            model_id: model("test"),
            name: "Test Model".to_string(),
            version: "1.0".to_string(),
            capabilities: vec![Capability::Generate],
            parameters: 7_000_000_000,
            quantization: Some("Q4_K".to_string()),
        };
        let deployment = ModelDeployment {
            node_id: node("n1"),
            region_id: region("us-west"),
            endpoint: "http://n1:8080".to_string(),
            status: DeploymentStatus::Ready,
        };
        let entry = ModelEntry {
            model_id: model("test"),
            metadata,
            deployments: vec![deployment],
        };

        let cloned = entry.clone();
        assert_eq!(cloned.model_id, model("test"));
        assert_eq!(cloned.deployments.len(), 1);
    }

    #[test]
    fn test_model_deployment_construction() {
        let dep = ModelDeployment {
            node_id: node("gpu-node"),
            region_id: region("eu-west"),
            endpoint: "https://gpu-node.eu-west:443".to_string(),
            status: DeploymentStatus::Loading,
        };
        assert_eq!(dep.node_id, node("gpu-node"));
        assert_eq!(dep.status, DeploymentStatus::Loading);
    }

    // =========================================================================
    // ModelCatalog::default tests
    // =========================================================================

    #[test]
    fn test_model_catalog_default() {
        let catalog = ModelCatalog::default();
        assert!(catalog.all_entries().is_empty());
    }

    // =========================================================================
    // find_by_capability with no matching models
    // =========================================================================

    #[tokio::test]
    async fn test_find_by_capability_empty() {
        let catalog = ModelCatalog::new();
        let found = catalog
            .find_by_capability(&Capability::Generate)
            .await
            .expect("find failed");
        assert!(found.is_empty());
    }

    #[tokio::test]
    async fn test_find_by_capability_no_match() {
        let catalog = ModelCatalog::new();

        register_on(&catalog, "whisper", "n1", "r1", vec![Capability::Transcribe])
            .await
            .expect("failed");

        // Search for different capability
        let found = catalog
            .find_by_capability(&Capability::Generate)
            .await
            .expect("find failed");
        assert!(found.is_empty());
    }

    // =========================================================================
    // list_all tests
    // =========================================================================

    #[tokio::test]
    async fn test_list_all_empty() {
        let catalog = ModelCatalog::new();
        let listed = catalog.list_all().await.expect("list failed");
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn test_list_all_after_deregister_all() {
        let catalog = ModelCatalog::new();

        register_on(&catalog, "m1", "n1", "r1", vec![Capability::Generate])
            .await
            .expect("failed");

        catalog
            .deregister(model("m1"), node("n1"))
            .await
            .expect("failed");

        let listed = catalog.list_all().await.expect("list failed");
        assert!(listed.is_empty());
    }
}
731}