1use super::traits::*;
7use std::collections::HashMap;
8use std::sync::RwLock;
9
/// Catalog record for one model: its identifier, its metadata, and the
/// deployments currently recorded for it.
#[derive(Debug, Clone)]
pub struct ModelEntry {
    /// Identifier of the model; `register` stores the entry in the catalog
    /// under this same id.
    pub model_id: ModelId,
    /// Metadata describing the model (name, version, capabilities, ...).
    pub metadata: ModelMetadata,
    /// Deployments recorded for this model, one element per registration.
    pub deployments: Vec<ModelDeployment>,
}
21
/// One deployment of a model on a particular node.
#[derive(Debug, Clone)]
pub struct ModelDeployment {
    /// Node the deployment belongs to; `deregister` removes deployments by
    /// matching on this field.
    pub node_id: NodeId,
    /// Region reported together with the node at registration time.
    pub region_id: RegionId,
    /// Endpoint string for the deployment. `register` currently stores an
    /// empty string here; presumably a network address — TODO confirm.
    pub endpoint: String,
    /// Lifecycle state; `find_by_capability` only returns deployments whose
    /// status is `DeploymentStatus::Ready`.
    pub status: DeploymentStatus,
}
30
/// Lifecycle state of a [`ModelDeployment`].
///
/// In the code visible in this file only `Ready` is ever assigned (by
/// `register`) and inspected (by `find_by_capability`). The meanings given
/// for the other variants are inferred from their names — confirm against
/// the code that sets them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeploymentStatus {
    /// Presumably: the model is still being loaded on the node — TODO confirm.
    Loading,
    /// The deployment is eligible to be returned by capability lookups.
    Ready,
    /// Presumably: the deployment is winding down and should not receive new
    /// work — TODO confirm.
    Draining,
    /// Presumably: the deployment has been taken down — TODO confirm.
    Removed,
}
43
/// In-memory model catalog.
///
/// Holds the primary map of entries plus a secondary index from capability
/// key to model ids. Each map sits behind its own `RwLock`.
pub struct ModelCatalog {
    /// Catalog entries keyed by model id.
    entries: RwLock<HashMap<ModelId, ModelEntry>>,
    /// Capability key (as produced by `capability_key`) mapped to the ids of
    /// models that declared that capability when they were first registered.
    by_capability: RwLock<HashMap<String, Vec<ModelId>>>,
}
53
54impl ModelCatalog {
55 pub fn new() -> Self {
56 Self {
57 entries: RwLock::new(HashMap::new()),
58 by_capability: RwLock::new(HashMap::new()),
59 }
60 }
61
62 pub fn all_entries(&self) -> Vec<ModelEntry> {
64 self.entries
65 .read()
66 .expect("catalog lock poisoned")
67 .values()
68 .cloned()
69 .collect()
70 }
71
72 pub fn get(&self, model_id: &ModelId) -> Option<ModelEntry> {
74 self.entries
75 .read()
76 .expect("catalog lock poisoned")
77 .get(model_id)
78 .cloned()
79 }
80
81 fn capability_key(cap: &Capability) -> String {
82 match cap {
83 Capability::Transcribe => "transcribe".to_string(),
84 Capability::Synthesize => "synthesize".to_string(),
85 Capability::Generate => "generate".to_string(),
86 Capability::Code => "code".to_string(),
87 Capability::Embed => "embed".to_string(),
88 Capability::ImageGen => "image_gen".to_string(),
89 Capability::Custom(s) => format!("custom:{}", s),
90 }
91 }
92}
93
94impl Default for ModelCatalog {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl ModelCatalogTrait for ModelCatalog {
101 fn register(
102 &self,
103 model_id: ModelId,
104 node_id: NodeId,
105 region_id: RegionId,
106 capabilities: Vec<Capability>,
107 ) -> BoxFuture<'_, FederationResult<()>> {
108 Box::pin(async move {
109 let mut entries = self.entries.write().expect("catalog lock poisoned");
110 let mut by_cap = self.by_capability.write().expect("catalog lock poisoned");
111
112 let deployment = ModelDeployment {
113 node_id,
114 region_id,
115 endpoint: String::new(), status: DeploymentStatus::Ready,
117 };
118
119 if let Some(entry) = entries.get_mut(&model_id) {
120 entry.deployments.push(deployment);
122 } else {
123 let metadata = ModelMetadata {
125 model_id: model_id.clone(),
126 name: model_id.0.clone(),
127 version: "1.0.0".to_string(),
128 capabilities: capabilities.clone(),
129 parameters: 0,
130 quantization: None,
131 };
132
133 let entry = ModelEntry {
134 model_id: model_id.clone(),
135 metadata,
136 deployments: vec![deployment],
137 };
138
139 entries.insert(model_id.clone(), entry);
140
141 for cap in &capabilities {
143 let key = Self::capability_key(cap);
144 by_cap.entry(key).or_default().push(model_id.clone());
145 }
146 }
147
148 Ok(())
149 })
150 }
151
152 fn deregister(
153 &self,
154 model_id: ModelId,
155 node_id: NodeId,
156 ) -> BoxFuture<'_, FederationResult<()>> {
157 Box::pin(async move {
158 let mut entries = self.entries.write().expect("catalog lock poisoned");
159
160 if let Some(entry) = entries.get_mut(&model_id) {
161 entry.deployments.retain(|d| d.node_id != node_id);
162
163 if entry.deployments.is_empty() {
165 entries.remove(&model_id);
166 }
167 }
168
169 Ok(())
170 })
171 }
172
173 fn find_by_capability(
174 &self,
175 capability: &Capability,
176 ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>> {
177 let key = Self::capability_key(capability);
178
179 Box::pin(async move {
180 let entries = self.entries.read().expect("catalog lock poisoned");
181 let by_cap = self.by_capability.read().expect("catalog lock poisoned");
182
183 let mut results = Vec::new();
184
185 if let Some(model_ids) = by_cap.get(&key) {
186 for model_id in model_ids {
187 if let Some(entry) = entries.get(model_id) {
188 for deployment in &entry.deployments {
189 if deployment.status == DeploymentStatus::Ready {
190 results.push((
191 deployment.node_id.clone(),
192 deployment.region_id.clone(),
193 ));
194 }
195 }
196 }
197 }
198 }
199
200 Ok(results)
201 })
202 }
203
204 fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>> {
205 Box::pin(async move {
206 let entries = self.entries.read().expect("catalog lock poisoned");
207 Ok(entries.keys().cloned().collect())
208 })
209 }
210
211 fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>> {
212 let model_id = model_id.clone();
213
214 Box::pin(async move {
215 let entries = self.entries.read().expect("catalog lock poisoned");
216
217 entries
218 .get(&model_id)
219 .map(|e| e.metadata.clone())
220 .ok_or_else(|| {
221 FederationError::Internal(format!("Model not found: {:?}", model_id))
222 })
223 })
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers `model` on `node` in `region` with `caps`, panicking if the
    /// catalog reports an error.
    async fn register_ok(
        catalog: &ModelCatalog,
        model: &str,
        node: &str,
        region: &str,
        caps: Vec<Capability>,
    ) {
        catalog
            .register(
                ModelId(model.to_string()),
                NodeId(node.to_string()),
                RegionId(region.to_string()),
                caps,
            )
            .await
            .expect("registration failed");
    }

    /// Runs a capability lookup, panicking if the catalog reports an error.
    async fn find_ok(catalog: &ModelCatalog, cap: Capability) -> Vec<(NodeId, RegionId)> {
        catalog
            .find_by_capability(&cap)
            .await
            .expect("find failed")
    }

    /// A single registration is discoverable through its capability.
    #[tokio::test]
    async fn test_register_and_find() {
        let catalog = ModelCatalog::new();
        register_ok(
            &catalog,
            "whisper-large",
            "node-1",
            "us-west",
            vec![Capability::Transcribe],
        )
        .await;

        let nodes = find_ok(&catalog, Capability::Transcribe).await;

        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].0, NodeId("node-1".to_string()));
    }

    /// Removing the only deployment of a model removes the model itself.
    #[tokio::test]
    async fn test_deregister() {
        let catalog = ModelCatalog::new();
        register_ok(
            &catalog,
            "llama-7b",
            "node-1",
            "eu-west",
            vec![Capability::Generate],
        )
        .await;

        catalog
            .deregister(
                ModelId("llama-7b".to_string()),
                NodeId("node-1".to_string()),
            )
            .await
            .expect("deregistration failed");

        let models = catalog.list_all().await.expect("list failed");
        assert!(models.is_empty());
    }

    /// Two nodes serving the same model both appear in lookup results.
    #[tokio::test]
    async fn test_multiple_deployments() {
        let catalog = ModelCatalog::new();
        let placements = [("node-1", "us-east"), ("node-2", "us-west")];
        for (node, region) in placements {
            register_ok(
                &catalog,
                "whisper-base",
                node,
                region,
                vec![Capability::Transcribe],
            )
            .await;
        }

        let nodes = find_ok(&catalog, Capability::Transcribe).await;

        assert_eq!(nodes.len(), 2);
    }

    /// Custom capabilities are matched by their name.
    #[tokio::test]
    async fn test_custom_capability() {
        let catalog = ModelCatalog::new();
        register_ok(
            &catalog,
            "sentiment-bert",
            "node-1",
            "ap-south",
            vec![Capability::Custom("sentiment".to_string())],
        )
        .await;

        let nodes = find_ok(&catalog, Capability::Custom("sentiment".to_string())).await;
        assert_eq!(nodes.len(), 1);

        let empty = find_ok(&catalog, Capability::Custom("other".to_string())).await;
        assert!(empty.is_empty());
    }
}