// m2m/models/registry.rs

1//! Model registry for model lookups and management.
2//!
3//! The registry provides:
4//! - Fast model lookup by ID or abbreviation
5//! - Encoding inference for token counting
6//! - Abbreviation expansion for decompression
7//! - Optional dynamic model fetching from OpenRouter
8
9use std::collections::HashMap;
10use std::sync::RwLock;
11
12use crate::error::{M2MError, Result};
13use crate::models::card::{Encoding, ModelCard, Provider};
14use crate::models::embedded::get_embedded_models;
15
16/// Model registry with embedded + dynamic models
17///
18/// The registry maintains two sets of models:
19/// 1. Embedded models: Compiled into the binary, always available
20/// 2. Dynamic models: Fetched at runtime (optional), stored in RwLock
21///
22/// # Example
23/// ```
24/// use m2m::models::ModelRegistry;
25///
26/// let registry = ModelRegistry::new();
27///
28/// // Lookup by ID
29/// let card = registry.get("openai/gpt-4o").unwrap();
30/// assert_eq!(card.abbrev, "og4o");
31///
32/// // Lookup by abbreviation
33/// let card = registry.get("ml3170i").unwrap();
34/// assert_eq!(card.id, "meta-llama/llama-3.1-70b-instruct");
35///
36/// // Abbreviate a model name
37/// let abbrev = registry.abbreviate("openai/gpt-4o");
38/// assert_eq!(abbrev, "og4o");
39///
40/// // Expand an abbreviation
41/// let id = registry.expand("og4o").unwrap();
42/// assert_eq!(id, "openai/gpt-4o");
43/// ```
44pub struct ModelRegistry {
45    /// ID -> ModelCard
46    by_id: HashMap<String, ModelCard>,
47
48    /// Abbreviation -> ID
49    abbrev_to_id: HashMap<String, String>,
50
51    /// Dynamic models (fetched at runtime)
52    dynamic: RwLock<HashMap<String, ModelCard>>,
53
54    /// Dynamic abbreviations
55    dynamic_abbrevs: RwLock<HashMap<String, String>>,
56}
57
58impl Default for ModelRegistry {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl ModelRegistry {
65    /// Create a new registry with embedded models loaded
66    pub fn new() -> Self {
67        let mut registry = Self {
68            by_id: HashMap::new(),
69            abbrev_to_id: HashMap::new(),
70            dynamic: RwLock::new(HashMap::new()),
71            dynamic_abbrevs: RwLock::new(HashMap::new()),
72        };
73
74        registry.load_embedded();
75        registry
76    }
77
78    /// Load embedded models into the registry
79    fn load_embedded(&mut self) {
80        for card in get_embedded_models() {
81            self.abbrev_to_id
82                .insert(card.abbrev.clone(), card.id.clone());
83            self.by_id.insert(card.id.clone(), card);
84        }
85    }
86
87    /// Get a model by ID or abbreviation
88    ///
89    /// Tries lookups in order:
90    /// 1. Direct ID match in embedded models
91    /// 2. Abbreviation match
92    /// 3. Dynamic models (if any)
93    pub fn get(&self, id_or_abbrev: &str) -> Option<ModelCard> {
94        // Try direct ID lookup in embedded
95        if let Some(card) = self.by_id.get(id_or_abbrev) {
96            return Some(card.clone());
97        }
98
99        // Try abbreviation lookup
100        if let Some(full_id) = self.abbrev_to_id.get(id_or_abbrev) {
101            if let Some(card) = self.by_id.get(full_id) {
102                return Some(card.clone());
103            }
104        }
105
106        // Try dynamic models
107        if let Ok(dynamic) = self.dynamic.read() {
108            if let Some(card) = dynamic.get(id_or_abbrev) {
109                return Some(card.clone());
110            }
111        }
112
113        // Try dynamic abbreviations
114        if let Ok(abbrevs) = self.dynamic_abbrevs.read() {
115            if let Some(full_id) = abbrevs.get(id_or_abbrev) {
116                if let Ok(dynamic) = self.dynamic.read() {
117                    if let Some(card) = dynamic.get(full_id) {
118                        return Some(card.clone());
119                    }
120                }
121            }
122        }
123
124        None
125    }
126
127    /// Check if a model exists in the registry
128    pub fn contains(&self, id_or_abbrev: &str) -> bool {
129        self.get(id_or_abbrev).is_some()
130    }
131
132    /// Get the encoding for a model (with fallback inference)
133    ///
134    /// If the model is not in the registry, infers encoding from the model ID.
135    pub fn get_encoding(&self, model: &str) -> Encoding {
136        self.get(model)
137            .map(|c| c.encoding)
138            .unwrap_or_else(|| Encoding::infer_from_id(model))
139    }
140
141    /// Get the context length for a model (with safe default)
142    pub fn get_context_length(&self, model: &str) -> u32 {
143        self.get(model).map(|c| c.context_length).unwrap_or(128000) // Safe default
144    }
145
146    /// Abbreviate a model ID
147    ///
148    /// Returns the abbreviation from the registry if available,
149    /// otherwise generates one using the standard algorithm.
150    pub fn abbreviate(&self, model_id: &str) -> String {
151        // Check embedded models
152        if let Some(card) = self.by_id.get(model_id) {
153            return card.abbrev.clone();
154        }
155
156        // Check dynamic models
157        if let Ok(dynamic) = self.dynamic.read() {
158            if let Some(card) = dynamic.get(model_id) {
159                return card.abbrev.clone();
160            }
161        }
162
163        // Generate abbreviation
164        let provider = Provider::from_model_id(model_id);
165        ModelCard::generate_abbrev(model_id, provider)
166    }
167
168    /// Expand an abbreviation to full model ID
169    ///
170    /// Returns None if the abbreviation is not recognized.
171    pub fn expand(&self, abbrev: &str) -> Option<String> {
172        // Check embedded abbreviations
173        if let Some(id) = self.abbrev_to_id.get(abbrev) {
174            return Some(id.clone());
175        }
176
177        // Check dynamic abbreviations
178        if let Ok(abbrevs) = self.dynamic_abbrevs.read() {
179            if let Some(id) = abbrevs.get(abbrev) {
180                return Some(id.clone());
181            }
182        }
183
184        None
185    }
186
187    /// List all known model IDs (embedded only, not dynamic)
188    pub fn list_ids(&self) -> Vec<&str> {
189        self.by_id.keys().map(|s| s.as_str()).collect()
190    }
191
192    /// List all known abbreviations
193    pub fn list_abbrevs(&self) -> Vec<&str> {
194        self.abbrev_to_id.keys().map(|s| s.as_str()).collect()
195    }
196
197    /// Get total count of models (embedded + dynamic)
198    pub fn len(&self) -> usize {
199        let dynamic_count = self.dynamic.read().map(|d| d.len()).unwrap_or(0);
200        self.by_id.len() + dynamic_count
201    }
202
203    /// Check if registry is empty
204    pub fn is_empty(&self) -> bool {
205        self.len() == 0
206    }
207
208    /// Get count of embedded models
209    pub fn embedded_count(&self) -> usize {
210        self.by_id.len()
211    }
212
213    /// Get count of dynamic models
214    pub fn dynamic_count(&self) -> usize {
215        self.dynamic.read().map(|d| d.len()).unwrap_or(0)
216    }
217
218    /// Add a model to the dynamic registry
219    pub fn add_dynamic(&self, card: ModelCard) -> Result<()> {
220        let mut dynamic = self
221            .dynamic
222            .write()
223            .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
224
225        let mut abbrevs = self
226            .dynamic_abbrevs
227            .write()
228            .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
229
230        abbrevs.insert(card.abbrev.clone(), card.id.clone());
231        dynamic.insert(card.id.clone(), card);
232
233        Ok(())
234    }
235
236    /// Clear dynamic models
237    pub fn clear_dynamic(&self) -> Result<()> {
238        let mut dynamic = self
239            .dynamic
240            .write()
241            .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
242
243        let mut abbrevs = self
244            .dynamic_abbrevs
245            .write()
246            .map_err(|_| M2MError::Compression("Lock poisoned".into()))?;
247
248        dynamic.clear();
249        abbrevs.clear();
250
251        Ok(())
252    }
253
254    /// Get models filtered by provider
255    pub fn get_by_provider(&self, provider: Provider) -> Vec<ModelCard> {
256        self.by_id
257            .values()
258            .filter(|card| card.provider == provider)
259            .cloned()
260            .collect()
261    }
262
263    /// Search models by ID substring
264    pub fn search(&self, query: &str) -> Vec<ModelCard> {
265        let query_lower = query.to_lowercase();
266
267        self.by_id
268            .values()
269            .filter(|card| {
270                card.id.to_lowercase().contains(&query_lower)
271                    || card.abbrev.to_lowercase().contains(&query_lower)
272            })
273            .cloned()
274            .collect()
275    }
276
277    /// Iterate over all embedded models
278    pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
279        self.by_id.values()
280    }
281}
282
283/// OpenRouter API model response (for future dynamic fetching)
284#[derive(Debug, serde::Deserialize)]
285pub struct OpenRouterModel {
286    pub id: String,
287    pub name: Option<String>,
288    pub context_length: Option<u32>,
289    pub pricing: Option<OpenRouterPricing>,
290}
291
292#[derive(Debug, serde::Deserialize)]
293pub struct OpenRouterPricing {
294    pub prompt: Option<String>,
295    pub completion: Option<String>,
296}
297
298/// Response from OpenRouter /models API
299///
300/// Used for dynamic model registry updates. This struct is prepared for
301/// future `fetch_openrouter_models` implementation (requires `reqwest` feature).
302/// The struct and methods are intentionally public for API consumers who
303/// want to implement their own fetching logic.
304#[derive(Debug, serde::Deserialize)]
305pub struct OpenRouterModelsResponse {
306    /// List of available models
307    pub data: Vec<OpenRouterModel>,
308}
309
310// Note: These methods are intentionally public for API consumers implementing
311// their own OpenRouter model fetching. Clippy flags them as dead code because
312// the built-in fetch function isn't implemented yet.
313#[allow(dead_code)]
314impl OpenRouterModelsResponse {
315    /// Get the list of models
316    pub fn models(&self) -> &[OpenRouterModel] {
317        &self.data
318    }
319
320    /// Get the number of models
321    pub fn len(&self) -> usize {
322        self.data.len()
323    }
324
325    /// Check if empty
326    pub fn is_empty(&self) -> bool {
327        self.data.is_empty()
328    }
329}
330
331impl ModelCard {
332    /// Create ModelCard from OpenRouter API model
333    pub fn from_openrouter(model: OpenRouterModel) -> Self {
334        let provider = Provider::from_model_id(&model.id);
335        let encoding = Encoding::infer_from_id(&model.id);
336        let abbrev = Self::generate_abbrev(&model.id, provider);
337
338        Self {
339            id: model.id,
340            abbrev,
341            provider,
342            encoding,
343            context_length: model.context_length.unwrap_or(128000),
344            defaults: crate::models::card::default_params(),
345            supported_params: crate::models::card::common_params(),
346            pricing: model.pricing.and_then(|p| {
347                let prompt: f64 = p.prompt?.parse().ok()?;
348                let completion: f64 = p.completion?.parse().ok()?;
349                Some(crate::models::card::Pricing::new(prompt, completion))
350            }),
351            supports_streaming: true,
352            supports_tools: false,
353            supports_vision: false,
354        }
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_creation() {
        let registry = ModelRegistry::new();
        assert!(registry.embedded_count() >= 35);
    }

    #[test]
    fn test_get_by_id() {
        let registry = ModelRegistry::new();
        let card = registry.get("openai/gpt-4o").expect("Should find gpt-4o");
        assert_eq!(card.abbrev, "og4o");
        assert_eq!(card.encoding, Encoding::O200kBase);
    }

    #[test]
    fn test_get_by_abbrev() {
        let registry = ModelRegistry::new();
        let card = registry.get("ml3170i").expect("Should find by abbrev");
        assert_eq!(card.id, "meta-llama/llama-3.1-70b-instruct");
    }

    #[test]
    fn test_abbreviate() {
        let registry = ModelRegistry::new();

        // Known model
        assert_eq!(registry.abbreviate("openai/gpt-4o"), "og4o");

        // Unknown model (generates abbreviation)
        let abbrev = registry.abbreviate("openai/gpt-5-super");
        assert!(abbrev.starts_with("o")); // OpenAI prefix
    }

    #[test]
    fn test_expand() {
        let registry = ModelRegistry::new();

        assert_eq!(registry.expand("og4o"), Some("openai/gpt-4o".to_string()));
        assert_eq!(
            registry.expand("ml3170i"),
            Some("meta-llama/llama-3.1-70b-instruct".to_string())
        );
        assert_eq!(registry.expand("unknown"), None);
    }

    #[test]
    fn test_get_encoding() {
        let registry = ModelRegistry::new();

        // Known model
        assert_eq!(registry.get_encoding("openai/gpt-4o"), Encoding::O200kBase);

        // Unknown model (infers encoding)
        assert_eq!(
            registry.get_encoding("openai/gpt-4o-future"),
            Encoding::O200kBase
        );
        assert_eq!(
            registry.get_encoding("some-random-model"),
            Encoding::Heuristic
        );
    }

    #[test]
    fn test_contains() {
        let registry = ModelRegistry::new();

        assert!(registry.contains("openai/gpt-4o"));
        assert!(registry.contains("og4o"));
        assert!(!registry.contains("nonexistent-model"));
    }

    #[test]
    fn test_get_by_provider() {
        let registry = ModelRegistry::new();

        let openai_models = registry.get_by_provider(Provider::OpenAI);
        assert!(!openai_models.is_empty());
        assert!(openai_models.iter().all(|m| m.provider == Provider::OpenAI));

        let meta_models = registry.get_by_provider(Provider::Meta);
        assert!(!meta_models.is_empty());
        assert!(meta_models.iter().all(|m| m.provider == Provider::Meta));
    }

    #[test]
    fn test_search() {
        let registry = ModelRegistry::new();

        let results = registry.search("gpt-4");
        assert!(!results.is_empty());
        assert!(results.iter().all(|m| m.id.contains("gpt-4")));

        let results = registry.search("llama");
        assert!(!results.is_empty());
        assert!(results.iter().all(|m| m.id.contains("llama")));
    }

    #[test]
    fn test_dynamic_models() {
        let registry = ModelRegistry::new();
        let initial_count = registry.len();

        // Add a dynamic model
        let card = ModelCard::new("test/custom-model");
        registry.add_dynamic(card).unwrap();

        assert_eq!(registry.len(), initial_count + 1);
        assert_eq!(registry.dynamic_count(), 1);

        // Should be findable
        let found = registry.get("test/custom-model");
        assert!(found.is_some());

        // Clear dynamic
        registry.clear_dynamic().unwrap();
        assert_eq!(registry.dynamic_count(), 0);
    }

    #[test]
    fn test_dynamic_abbrev_lookup() {
        let registry = ModelRegistry::new();
        let card = ModelCard::new("zz-test-vendor/zz-dynamic-abbrev-model");
        let id = card.id.clone();
        let abbrev = card.abbrev.clone();

        // Embedded abbreviations win over dynamic ones; if the generated
        // abbreviation happens to collide with one, there is nothing to check.
        if registry.expand(&abbrev).is_some() {
            return;
        }

        registry.add_dynamic(card).unwrap();

        // The dynamic-abbreviation path of `get`, `expand` and `abbreviate`.
        assert_eq!(registry.expand(&abbrev), Some(id.clone()));
        let found = registry
            .get(&abbrev)
            .expect("Should find dynamic model by abbrev");
        assert_eq!(found.id, id);
        assert_eq!(registry.abbreviate(&id), abbrev);

        // Clearing removes both the model and its abbreviation.
        registry.clear_dynamic().unwrap();
        assert_eq!(registry.expand(&abbrev), None);
        assert!(registry.get(&id).is_none());
    }

    #[test]
    fn test_openrouter_response_parsing() {
        // Test that OpenRouterModelsResponse can deserialize API responses
        let json = r#"{
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    }
                },
                {
                    "id": "anthropic/claude-3-opus",
                    "name": "Claude 3 Opus",
                    "context_length": 200000
                }
            ]
        }"#;

        let response: OpenRouterModelsResponse = serde_json::from_str(json).unwrap();
        // Test the accessor methods
        assert_eq!(response.len(), 2);
        assert!(!response.is_empty());

        let models = response.models();
        assert_eq!(models[0].id, "openai/gpt-4o");
        assert_eq!(models[0].context_length, Some(128000));
        assert!(models[0].pricing.is_some());
        assert_eq!(models[1].id, "anthropic/claude-3-opus");
        assert!(models[1].pricing.is_none());
    }

    #[test]
    fn test_model_card_from_openrouter() {
        let model = OpenRouterModel {
            id: "openai/gpt-4o-test".to_string(),
            name: Some("GPT-4o Test".to_string()),
            context_length: Some(128000),
            pricing: Some(OpenRouterPricing {
                prompt: Some("0.000005".to_string()),
                completion: Some("0.000015".to_string()),
            }),
        };

        let card = ModelCard::from_openrouter(model);
        assert_eq!(card.id, "openai/gpt-4o-test");
        assert_eq!(card.provider, Provider::OpenAI);
        assert_eq!(card.context_length, 128000);
        assert!(card.pricing.is_some());
    }
}