1use std::collections::HashMap;
10use std::sync::RwLock;
11
12use crate::error::{M2MError, Result};
13use crate::models::card::{Encoding, ModelCard, Provider};
14use crate::models::embedded::get_embedded_models;
15
16pub struct ModelRegistry {
45 by_id: HashMap<String, ModelCard>,
47
48 abbrev_to_id: HashMap<String, String>,
50
51 dynamic: RwLock<HashMap<String, ModelCard>>,
53
54 dynamic_abbrevs: RwLock<HashMap<String, String>>,
56}
57
58impl Default for ModelRegistry {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl ModelRegistry {
65 pub fn new() -> Self {
67 let mut registry = Self {
68 by_id: HashMap::new(),
69 abbrev_to_id: HashMap::new(),
70 dynamic: RwLock::new(HashMap::new()),
71 dynamic_abbrevs: RwLock::new(HashMap::new()),
72 };
73
74 registry.load_embedded();
75 registry
76 }
77
78 fn load_embedded(&mut self) {
80 for card in get_embedded_models() {
81 self.abbrev_to_id
82 .insert(card.abbrev.clone(), card.id.clone());
83 self.by_id.insert(card.id.clone(), card);
84 }
85 }
86
87 pub fn get(&self, id_or_abbrev: &str) -> Option<ModelCard> {
94 if let Some(card) = self.by_id.get(id_or_abbrev) {
96 return Some(card.clone());
97 }
98
99 if let Some(full_id) = self.abbrev_to_id.get(id_or_abbrev) {
101 if let Some(card) = self.by_id.get(full_id) {
102 return Some(card.clone());
103 }
104 }
105
106 if let Ok(dynamic) = self.dynamic.read() {
108 if let Some(card) = dynamic.get(id_or_abbrev) {
109 return Some(card.clone());
110 }
111 }
112
113 if let Ok(abbrevs) = self.dynamic_abbrevs.read() {
115 if let Some(full_id) = abbrevs.get(id_or_abbrev) {
116 if let Ok(dynamic) = self.dynamic.read() {
117 if let Some(card) = dynamic.get(full_id) {
118 return Some(card.clone());
119 }
120 }
121 }
122 }
123
124 None
125 }
126
127 pub fn contains(&self, id_or_abbrev: &str) -> bool {
129 self.get(id_or_abbrev).is_some()
130 }
131
132 pub fn get_encoding(&self, model: &str) -> Encoding {
136 self.get(model)
137 .map(|c| c.encoding)
138 .unwrap_or_else(|| Encoding::infer_from_id(model))
139 }
140
141 pub fn get_context_length(&self, model: &str) -> u32 {
143 self.get(model).map(|c| c.context_length).unwrap_or(128000) }
145
146 pub fn abbreviate(&self, model_id: &str) -> String {
151 if let Some(card) = self.by_id.get(model_id) {
153 return card.abbrev.clone();
154 }
155
156 if let Ok(dynamic) = self.dynamic.read() {
158 if let Some(card) = dynamic.get(model_id) {
159 return card.abbrev.clone();
160 }
161 }
162
163 let provider = Provider::from_model_id(model_id);
165 ModelCard::generate_abbrev(model_id, provider)
166 }
167
168 pub fn expand(&self, abbrev: &str) -> Option<String> {
172 if let Some(id) = self.abbrev_to_id.get(abbrev) {
174 return Some(id.clone());
175 }
176
177 if let Ok(abbrevs) = self.dynamic_abbrevs.read() {
179 if let Some(id) = abbrevs.get(abbrev) {
180 return Some(id.clone());
181 }
182 }
183
184 None
185 }
186
187 pub fn list_ids(&self) -> Vec<&str> {
189 self.by_id.keys().map(|s| s.as_str()).collect()
190 }
191
192 pub fn list_abbrevs(&self) -> Vec<&str> {
194 self.abbrev_to_id.keys().map(|s| s.as_str()).collect()
195 }
196
197 pub fn len(&self) -> usize {
199 let dynamic_count = self.dynamic.read().map(|d| d.len()).unwrap_or(0);
200 self.by_id.len() + dynamic_count
201 }
202
203 pub fn is_empty(&self) -> bool {
205 self.len() == 0
206 }
207
208 pub fn embedded_count(&self) -> usize {
210 self.by_id.len()
211 }
212
213 pub fn dynamic_count(&self) -> usize {
215 self.dynamic.read().map(|d| d.len()).unwrap_or(0)
216 }
217
218 pub fn add_dynamic(&self, card: ModelCard) -> Result<()> {
220 let mut dynamic = self
221 .dynamic
222 .write()
223 .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
224
225 let mut abbrevs = self
226 .dynamic_abbrevs
227 .write()
228 .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
229
230 abbrevs.insert(card.abbrev.clone(), card.id.clone());
231 dynamic.insert(card.id.clone(), card);
232
233 Ok(())
234 }
235
236 pub fn clear_dynamic(&self) -> Result<()> {
238 let mut dynamic = self
239 .dynamic
240 .write()
241 .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
242
243 let mut abbrevs = self
244 .dynamic_abbrevs
245 .write()
246 .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
247
248 dynamic.clear();
249 abbrevs.clear();
250
251 Ok(())
252 }
253
254 pub fn get_by_provider(&self, provider: Provider) -> Vec<ModelCard> {
256 self.by_id
257 .values()
258 .filter(|card| card.provider == provider)
259 .cloned()
260 .collect()
261 }
262
263 pub fn search(&self, query: &str) -> Vec<ModelCard> {
265 let query_lower = query.to_lowercase();
266
267 self.by_id
268 .values()
269 .filter(|card| {
270 card.id.to_lowercase().contains(&query_lower)
271 || card.abbrev.to_lowercase().contains(&query_lower)
272 })
273 .cloned()
274 .collect()
275 }
276
277 pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
279 self.by_id.values()
280 }
281}
282
283#[derive(Debug, serde::Deserialize)]
285pub struct OpenRouterModel {
286 pub id: String,
287 pub name: Option<String>,
288 pub context_length: Option<u32>,
289 pub pricing: Option<OpenRouterPricing>,
290}
291
292#[derive(Debug, serde::Deserialize)]
293pub struct OpenRouterPricing {
294 pub prompt: Option<String>,
295 pub completion: Option<String>,
296}
297
298#[derive(Debug, serde::Deserialize)]
305pub struct OpenRouterModelsResponse {
306 pub data: Vec<OpenRouterModel>,
308}
309
310#[allow(dead_code)]
314impl OpenRouterModelsResponse {
315 pub fn models(&self) -> &[OpenRouterModel] {
317 &self.data
318 }
319
320 pub fn len(&self) -> usize {
322 self.data.len()
323 }
324
325 pub fn is_empty(&self) -> bool {
327 self.data.is_empty()
328 }
329}
330
331impl ModelCard {
332 pub fn from_openrouter(model: OpenRouterModel) -> Self {
334 let provider = Provider::from_model_id(&model.id);
335 let encoding = Encoding::infer_from_id(&model.id);
336 let abbrev = Self::generate_abbrev(&model.id, provider);
337
338 Self {
339 id: model.id,
340 abbrev,
341 provider,
342 encoding,
343 context_length: model.context_length.unwrap_or(128000),
344 defaults: crate::models::card::default_params(),
345 supported_params: crate::models::card::common_params(),
346 pricing: model.pricing.and_then(|p| {
347 let prompt: f64 = p.prompt?.parse().ok()?;
348 let completion: f64 = p.completion?.parse().ok()?;
349 Some(crate::models::card::Pricing::new(prompt, completion))
350 }),
351 supports_streaming: true,
352 supports_tools: false,
353 supports_vision: false,
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh registry with only the embedded models loaded.
    fn registry() -> ModelRegistry {
        ModelRegistry::new()
    }

    #[test]
    fn test_registry_creation() {
        let reg = registry();
        assert!(reg.embedded_count() >= 35);
    }

    #[test]
    fn test_get_by_id() {
        let gpt4o = registry()
            .get("openai/gpt-4o")
            .expect("Should find gpt-4o");
        assert_eq!(gpt4o.abbrev, "og4o");
        assert_eq!(gpt4o.encoding, Encoding::O200kBase);
    }

    #[test]
    fn test_get_by_abbrev() {
        let llama = registry()
            .get("ml3170i")
            .expect("Should find by abbrev");
        assert_eq!(llama.id, "meta-llama/llama-3.1-70b-instruct");
    }

    #[test]
    fn test_abbreviate() {
        let reg = registry();

        // Known model: the stored abbreviation comes back verbatim.
        assert_eq!(reg.abbreviate("openai/gpt-4o"), "og4o");

        // Unknown model: a generated abbreviation; only its first letter is pinned.
        let generated = reg.abbreviate("openai/gpt-5-super");
        assert!(generated.starts_with('o'));
    }

    #[test]
    fn test_expand() {
        let reg = registry();
        let cases: [(&str, Option<&str>); 3] = [
            ("og4o", Some("openai/gpt-4o")),
            ("ml3170i", Some("meta-llama/llama-3.1-70b-instruct")),
            ("unknown", None),
        ];

        for &(abbrev, expected) in &cases {
            assert_eq!(
                reg.expand(abbrev),
                expected.map(String::from),
                "expand({:?})",
                abbrev
            );
        }
    }

    #[test]
    fn test_get_encoding() {
        let reg = registry();

        // Registered model.
        assert_eq!(reg.get_encoding("openai/gpt-4o"), Encoding::O200kBase);

        // Unregistered models fall back to inference from the id.
        assert_eq!(
            reg.get_encoding("openai/gpt-4o-future"),
            Encoding::O200kBase
        );
        assert_eq!(reg.get_encoding("some-random-model"), Encoding::Heuristic);
    }

    #[test]
    fn test_contains() {
        let reg = registry();

        for present in ["openai/gpt-4o", "og4o"] {
            assert!(reg.contains(present), "expected {:?} to be registered", present);
        }
        assert!(!reg.contains("nonexistent-model"));
    }

    #[test]
    fn test_get_by_provider() {
        let reg = registry();

        for provider in [Provider::OpenAI, Provider::Meta] {
            let models = reg.get_by_provider(provider);
            assert!(!models.is_empty());
            assert!(models.iter().all(|m| m.provider == provider));
        }
    }

    #[test]
    fn test_search() {
        let reg = registry();

        for needle in ["gpt-4", "llama"] {
            let hits = reg.search(needle);
            assert!(!hits.is_empty());
            assert!(hits.iter().all(|m| m.id.contains(needle)));
        }
    }

    #[test]
    fn test_dynamic_models() {
        let reg = registry();
        let before = reg.len();

        reg.add_dynamic(ModelCard::new("test/custom-model")).unwrap();

        assert_eq!(reg.len(), before + 1);
        assert_eq!(reg.dynamic_count(), 1);
        assert!(reg.get("test/custom-model").is_some());

        reg.clear_dynamic().unwrap();
        assert_eq!(reg.dynamic_count(), 0);
    }

    #[test]
    fn test_openrouter_response_parsing() {
        let json = r#"{
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    }
                },
                {
                    "id": "anthropic/claude-3-opus",
                    "name": "Claude 3 Opus",
                    "context_length": 200000
                }
            ]
        }"#;

        let response: OpenRouterModelsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.len(), 2);
        assert!(!response.is_empty());

        let models = response.models();
        let (first, second) = (&models[0], &models[1]);

        assert_eq!(first.id, "openai/gpt-4o");
        assert_eq!(first.context_length, Some(128000));
        assert!(first.pricing.is_some());

        assert_eq!(second.id, "anthropic/claude-3-opus");
        assert!(second.pricing.is_none());
    }

    #[test]
    fn test_model_card_from_openrouter() {
        let pricing = OpenRouterPricing {
            prompt: Some(String::from("0.000005")),
            completion: Some(String::from("0.000015")),
        };
        let model = OpenRouterModel {
            id: String::from("openai/gpt-4o-test"),
            name: Some(String::from("GPT-4o Test")),
            context_length: Some(128000),
            pricing: Some(pricing),
        };

        let card = ModelCard::from_openrouter(model);

        assert_eq!(card.id, "openai/gpt-4o-test");
        assert_eq!(card.provider, Provider::OpenAI);
        assert_eq!(card.context_length, 128000);
        assert!(card.pricing.is_some());
    }
}