1use super::traits::*;
7use std::collections::HashMap;
8use std::sync::RwLock;
9
/// A model known to the catalog, together with the deployments registered for it.
#[derive(Debug, Clone)]
pub struct ModelEntry {
    /// Identifier of the model; the catalog also uses it as this entry's map key.
    pub model_id: ModelId,
    /// Descriptive metadata. Entries created by `ModelCatalog::register` carry
    /// placeholder values (version `"1.0.0"`, zero parameters, no quantization).
    pub metadata: ModelMetadata,
    /// Deployments of this model on individual nodes.
    pub deployments: Vec<ModelDeployment>,
}
21
/// A single placement of a model on a specific node.
#[derive(Debug, Clone)]
pub struct ModelDeployment {
    /// Node hosting this deployment.
    pub node_id: NodeId,
    /// Region supplied when the deployment was registered.
    pub region_id: RegionId,
    /// Endpoint string for reaching the deployment (presumably a URL such as
    /// `http://n1:8080` — TODO confirm). `ModelCatalog::register` leaves it empty.
    pub endpoint: String,
    /// Lifecycle state; capability lookups only report `Ready` deployments.
    pub status: DeploymentStatus,
}
30
/// Lifecycle state of a [`ModelDeployment`].
///
/// `ModelCatalog::find_by_capability` reports only `Ready` deployments; every
/// other state excludes the deployment from capability lookups.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeploymentStatus {
    /// Presumably the model is still being loaded on the node — not yet served.
    Loading,
    /// Serving; the only state returned by capability lookups.
    Ready,
    /// Presumably being taken out of service; excluded from lookups.
    Draining,
    /// Presumably no longer deployed; excluded from lookups.
    Removed,
}
43
/// In-memory registry of models and their deployments, guarded by `RwLock`s.
///
/// Alongside the primary map it keeps a secondary index from capability key to
/// model ids, so capability lookups need not scan every entry. Methods that take
/// both locks acquire `entries` before `by_capability`.
pub struct ModelCatalog {
    /// All known models, keyed by model id.
    entries: RwLock<HashMap<ModelId, ModelEntry>>,
    /// Capability key (as produced by `capability_key`) -> ids of models indexed under it.
    by_capability: RwLock<HashMap<String, Vec<ModelId>>>,
}
53
54impl ModelCatalog {
55 pub fn new() -> Self {
56 Self {
57 entries: RwLock::new(HashMap::new()),
58 by_capability: RwLock::new(HashMap::new()),
59 }
60 }
61
62 pub fn all_entries(&self) -> Vec<ModelEntry> {
64 self.entries
65 .read()
66 .expect("catalog lock poisoned")
67 .values()
68 .cloned()
69 .collect()
70 }
71
72 pub fn get(&self, model_id: &ModelId) -> Option<ModelEntry> {
74 self.entries
75 .read()
76 .expect("catalog lock poisoned")
77 .get(model_id)
78 .cloned()
79 }
80
81 fn capability_key(cap: &Capability) -> String {
82 match cap {
83 Capability::Transcribe => "transcribe".to_string(),
84 Capability::Synthesize => "synthesize".to_string(),
85 Capability::Generate => "generate".to_string(),
86 Capability::Code => "code".to_string(),
87 Capability::Embed => "embed".to_string(),
88 Capability::ImageGen => "image_gen".to_string(),
89 Capability::Custom(s) => format!("custom:{}", s),
90 }
91 }
92}
93
94impl Default for ModelCatalog {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl ModelCatalogTrait for ModelCatalog {
101 fn register(
102 &self,
103 model_id: ModelId,
104 node_id: NodeId,
105 region_id: RegionId,
106 capabilities: Vec<Capability>,
107 ) -> BoxFuture<'_, FederationResult<()>> {
108 Box::pin(async move {
109 let mut entries = self.entries.write().expect("catalog lock poisoned");
110 let mut by_cap = self.by_capability.write().expect("catalog lock poisoned");
111
112 let deployment = ModelDeployment {
113 node_id,
114 region_id,
115 endpoint: String::new(), status: DeploymentStatus::Ready,
117 };
118
119 if let Some(entry) = entries.get_mut(&model_id) {
120 entry.deployments.push(deployment);
122 } else {
123 let metadata = ModelMetadata {
125 model_id: model_id.clone(),
126 name: model_id.0.clone(),
127 version: "1.0.0".to_string(),
128 capabilities: capabilities.clone(),
129 parameters: 0,
130 quantization: None,
131 };
132
133 let entry = ModelEntry {
134 model_id: model_id.clone(),
135 metadata,
136 deployments: vec![deployment],
137 };
138
139 entries.insert(model_id.clone(), entry);
140
141 for cap in &capabilities {
143 let key = Self::capability_key(cap);
144 by_cap.entry(key).or_default().push(model_id.clone());
145 }
146 }
147
148 Ok(())
149 })
150 }
151
152 fn deregister(
153 &self,
154 model_id: ModelId,
155 node_id: NodeId,
156 ) -> BoxFuture<'_, FederationResult<()>> {
157 Box::pin(async move {
158 let mut entries = self.entries.write().expect("catalog lock poisoned");
159
160 if let Some(entry) = entries.get_mut(&model_id) {
161 entry.deployments.retain(|d| d.node_id != node_id);
162
163 if entry.deployments.is_empty() {
165 entries.remove(&model_id);
166 }
167 }
168
169 Ok(())
170 })
171 }
172
173 fn find_by_capability(
174 &self,
175 capability: &Capability,
176 ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>> {
177 let key = Self::capability_key(capability);
178
179 Box::pin(async move {
180 let entries = self.entries.read().expect("catalog lock poisoned");
181 let by_cap = self.by_capability.read().expect("catalog lock poisoned");
182
183 let mut results = Vec::new();
184
185 if let Some(model_ids) = by_cap.get(&key) {
186 for model_id in model_ids {
187 if let Some(entry) = entries.get(model_id) {
188 for deployment in &entry.deployments {
189 if deployment.status == DeploymentStatus::Ready {
190 results.push((
191 deployment.node_id.clone(),
192 deployment.region_id.clone(),
193 ));
194 }
195 }
196 }
197 }
198 }
199
200 Ok(results)
201 })
202 }
203
204 fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>> {
205 Box::pin(async move {
206 let entries = self.entries.read().expect("catalog lock poisoned");
207 Ok(entries.keys().cloned().collect())
208 })
209 }
210
211 fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>> {
212 let model_id = model_id.clone();
213
214 Box::pin(async move {
215 let entries = self.entries.read().expect("catalog lock poisoned");
216
217 entries
218 .get(&model_id)
219 .map(|e| e.metadata.clone())
220 .ok_or_else(|| {
221 FederationError::Internal(format!("Model not found: {:?}", model_id))
222 })
223 })
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    // Terse constructors for the identifier newtypes used throughout the tests.
    fn model(id: &str) -> ModelId {
        ModelId(id.to_string())
    }

    fn node(id: &str) -> NodeId {
        NodeId(id.to_string())
    }

    fn region(id: &str) -> RegionId {
        RegionId(id.to_string())
    }

    #[tokio::test]
    async fn test_register_and_find() {
        let catalog = ModelCatalog::new();
        catalog
            .register(
                model("whisper-large"),
                node("node-1"),
                region("us-west"),
                vec![Capability::Transcribe],
            )
            .await
            .expect("registration failed");

        let found = catalog
            .find_by_capability(&Capability::Transcribe)
            .await
            .expect("find failed");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, node("node-1"));
    }

    #[tokio::test]
    async fn test_deregister() {
        let catalog = ModelCatalog::new();
        catalog
            .register(
                model("llama-7b"),
                node("node-1"),
                region("eu-west"),
                vec![Capability::Generate],
            )
            .await
            .expect("registration failed");
        catalog
            .deregister(model("llama-7b"), node("node-1"))
            .await
            .expect("deregistration failed");

        let remaining = catalog.list_all().await.expect("list failed");
        assert!(remaining.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_deployments() {
        let catalog = ModelCatalog::new();
        // The same model on two nodes in different regions.
        for &(node_id, region_id) in &[("node-1", "us-east"), ("node-2", "us-west")] {
            catalog
                .register(
                    model("whisper-base"),
                    node(node_id),
                    region(region_id),
                    vec![Capability::Transcribe],
                )
                .await
                .expect("registration failed");
        }

        let found = catalog
            .find_by_capability(&Capability::Transcribe)
            .await
            .expect("find failed");
        assert_eq!(found.len(), 2);
    }

    #[tokio::test]
    async fn test_custom_capability() {
        let catalog = ModelCatalog::new();
        let sentiment = Capability::Custom("sentiment".to_string());
        catalog
            .register(
                model("sentiment-bert"),
                node("node-1"),
                region("ap-south"),
                vec![sentiment.clone()],
            )
            .await
            .expect("registration failed");

        let hits = catalog
            .find_by_capability(&sentiment)
            .await
            .expect("find failed");
        assert_eq!(hits.len(), 1);

        let misses = catalog
            .find_by_capability(&Capability::Custom("other".to_string()))
            .await
            .expect("find failed");
        assert!(misses.is_empty());
    }

    #[tokio::test]
    async fn test_get_existing_model() {
        let catalog = ModelCatalog::new();
        catalog
            .register(
                model("whisper"),
                node("n1"),
                region("us-west"),
                vec![Capability::Transcribe],
            )
            .await
            .expect("registration failed");

        let looked_up = catalog.get(&model("whisper"));
        assert!(looked_up.is_some());
        let entry = looked_up.expect("entry should exist");
        assert_eq!(entry.model_id, model("whisper"));
        assert_eq!(entry.deployments.len(), 1);
    }

    #[test]
    fn test_get_nonexistent_model() {
        let catalog = ModelCatalog::new();
        assert!(catalog.get(&model("nonexistent")).is_none());
    }

    #[tokio::test]
    async fn test_get_metadata_existing() {
        let catalog = ModelCatalog::new();
        catalog
            .register(
                model("llama"),
                node("n1"),
                region("us-east"),
                vec![Capability::Generate, Capability::Code],
            )
            .await
            .expect("registration failed");

        let meta = catalog
            .get_metadata(&model("llama"))
            .await
            .expect("metadata failed");
        assert_eq!(meta.model_id, model("llama"));
        assert_eq!(meta.name, "llama");
        assert_eq!(meta.version, "1.0.0");
        assert_eq!(meta.capabilities.len(), 2);
    }

    #[tokio::test]
    async fn test_get_metadata_nonexistent() {
        let catalog = ModelCatalog::new();
        let result = catalog.get_metadata(&model("missing")).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), FederationError::Internal(_)));
    }

    #[test]
    fn test_all_entries_empty() {
        let catalog = ModelCatalog::new();
        let snapshot = catalog.all_entries();
        assert!(snapshot.is_empty());
    }

    #[tokio::test]
    async fn test_all_entries_multiple() {
        let catalog = ModelCatalog::new();
        let registrations = [
            ("m1", "n1", "r1", Capability::Generate),
            ("m2", "n2", "r2", Capability::Embed),
        ];
        for (model_id, node_id, region_id, cap) in registrations.iter() {
            catalog
                .register(model(model_id), node(node_id), region(region_id), vec![cap.clone()])
                .await
                .expect("failed");
        }

        assert_eq!(catalog.all_entries().len(), 2);
    }

    #[tokio::test]
    async fn test_deregister_nonexistent_model() {
        let catalog = ModelCatalog::new();
        let outcome = catalog.deregister(model("missing"), node("n1")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_deregister_nonexistent_node() {
        let catalog = ModelCatalog::new();
        catalog
            .register(model("m1"), node("n1"), region("r1"), vec![Capability::Generate])
            .await
            .expect("failed");

        // Removing a node that never hosted the model must leave the model in place.
        catalog
            .deregister(model("m1"), node("n2"))
            .await
            .expect("deregister failed");

        let ids = catalog.list_all().await.expect("list failed");
        assert_eq!(ids.len(), 1);
    }

    #[tokio::test]
    async fn test_deregister_partial_keeps_remaining() {
        let catalog = ModelCatalog::new();
        for &(node_id, region_id) in &[("n1", "r1"), ("n2", "r2")] {
            catalog
                .register(model("m1"), node(node_id), region(region_id), vec![Capability::Generate])
                .await
                .expect("failed");
        }

        catalog
            .deregister(model("m1"), node("n1"))
            .await
            .expect("deregister failed");

        let survivor = catalog.get(&model("m1"));
        assert!(survivor.is_some());
        assert_eq!(survivor.expect("should exist").deployments.len(), 1);
    }

    #[tokio::test]
    async fn test_all_capability_keys_via_registration() {
        let catalog = ModelCatalog::new();
        let capabilities = vec![
            Capability::Transcribe,
            Capability::Synthesize,
            Capability::Generate,
            Capability::Code,
            Capability::Embed,
            Capability::ImageGen,
            Capability::Custom("custom_task".to_string()),
        ];

        // Model ids t1..t7, one per capability, all hosted on the same node.
        for (index, cap) in capabilities.iter().enumerate() {
            catalog
                .register(
                    model(&format!("t{}", index + 1)),
                    node("n1"),
                    region("r1"),
                    vec![cap.clone()],
                )
                .await
                .expect("registration failed");
        }

        for cap in &capabilities {
            let nodes = catalog.find_by_capability(cap).await.expect("find failed");
            assert_eq!(nodes.len(), 1, "Should find 1 node for {:?}", cap);
        }
    }

    #[test]
    fn test_deployment_status_equality() {
        let ready = DeploymentStatus::Ready;
        assert_eq!(ready, DeploymentStatus::Ready);
        assert_ne!(ready, DeploymentStatus::Loading);
        assert_ne!(DeploymentStatus::Draining, DeploymentStatus::Removed);
    }

    #[test]
    fn test_deployment_status_all_variants() {
        let statuses = [
            DeploymentStatus::Loading,
            DeploymentStatus::Ready,
            DeploymentStatus::Draining,
            DeploymentStatus::Removed,
        ];
        // Every variant equals itself and differs from every other variant.
        for (row, lhs) in statuses.iter().enumerate() {
            for (col, rhs) in statuses.iter().enumerate() {
                assert_eq!(lhs == rhs, row == col);
            }
        }
    }

    #[test]
    fn test_deployment_status_copy() {
        let original = DeploymentStatus::Draining;
        let duplicate = original;
        // `original` is still usable after the move because the type is `Copy`.
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_model_entry_clone() {
        let metadata = ModelMetadata {
            model_id: model("test"),
            name: "Test Model".to_string(),
            version: "1.0".to_string(),
            capabilities: vec![Capability::Generate],
            parameters: 7_000_000_000,
            quantization: Some("Q4_K".to_string()),
        };
        let deployment = ModelDeployment {
            node_id: node("n1"),
            region_id: region("us-west"),
            endpoint: "http://n1:8080".to_string(),
            status: DeploymentStatus::Ready,
        };
        let entry = ModelEntry {
            model_id: model("test"),
            metadata,
            deployments: vec![deployment],
        };

        let cloned = entry.clone();
        assert_eq!(cloned.model_id, model("test"));
        assert_eq!(cloned.deployments.len(), 1);
    }

    #[test]
    fn test_model_deployment_construction() {
        let dep = ModelDeployment {
            node_id: node("gpu-node"),
            region_id: region("eu-west"),
            endpoint: "https://gpu-node.eu-west:443".to_string(),
            status: DeploymentStatus::Loading,
        };
        assert_eq!(dep.node_id, node("gpu-node"));
        assert_eq!(dep.status, DeploymentStatus::Loading);
    }

    #[test]
    fn test_model_catalog_default() {
        let catalog = ModelCatalog::default();
        assert!(catalog.all_entries().is_empty());
    }

    #[tokio::test]
    async fn test_find_by_capability_empty() {
        let catalog = ModelCatalog::new();
        let found = catalog
            .find_by_capability(&Capability::Generate)
            .await
            .expect("find failed");
        assert!(found.is_empty());
    }

    #[tokio::test]
    async fn test_find_by_capability_no_match() {
        let catalog = ModelCatalog::new();
        catalog
            .register(model("whisper"), node("n1"), region("r1"), vec![Capability::Transcribe])
            .await
            .expect("failed");

        let found = catalog
            .find_by_capability(&Capability::Generate)
            .await
            .expect("find failed");
        assert!(found.is_empty());
    }

    #[tokio::test]
    async fn test_list_all_empty() {
        let catalog = ModelCatalog::new();
        let ids = catalog.list_all().await.expect("list failed");
        assert!(ids.is_empty());
    }

    #[tokio::test]
    async fn test_list_all_after_deregister_all() {
        let catalog = ModelCatalog::new();
        catalog
            .register(model("m1"), node("n1"), region("r1"), vec![Capability::Generate])
            .await
            .expect("failed");
        catalog
            .deregister(model("m1"), node("n1"))
            .await
            .expect("failed");

        let ids = catalog.list_all().await.expect("list failed");
        assert!(ids.is_empty());
    }
}
731}