Skip to main content

oxi_ai/
model_registry.rs

1//! Model registry for oxi-ai
2//!
3//! Provides a centralized registry of available LLM models.
4//! Supports both static built-in models and dynamic runtime registration
5//! for custom OpenAI-compatible providers.
6
7use crate::{Api, CompatSettings, Cost, InputModality, MaxTokensField, Model, ThinkingFormat};
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use std::collections::HashMap;
11
/// Return the portion of `id` that follows its final `/`.
///
/// Ids without any `/` are returned unchanged; an id ending in `/` yields
/// an empty string.
fn extract_model_name(id: &str) -> &str {
    match id.rfind('/') {
        // '/' is a single byte, so `slash + 1` is always a valid char boundary.
        Some(slash) => &id[slash + 1..],
        None => id,
    }
}
16
17/// Return provider-specific compatibility defaults.
18///
19/// Internal helper used by `add_*_models()` functions so that every model
20/// from the same provider gets the same `compat` baseline.
21fn default_compat_for_provider(provider: &str) -> Option<CompatSettings> {
22    match provider {
23        "openai" | "openai-responses" | "openai-completions" => Some(CompatSettings {
24            thinking_format: Some(ThinkingFormat::OpenAI),
25            max_tokens_field: Some(MaxTokensField::MaxCompletionTokens),
26            ..CompatSettings::default()
27        }),
28        "openrouter" => Some(CompatSettings {
29            thinking_format: Some(ThinkingFormat::OpenRouter),
30            requires_tool_result_name: true,
31            ..CompatSettings::default()
32        }),
33        "deepseek" => Some(CompatSettings {
34            thinking_format: Some(ThinkingFormat::DeepSeek),
35            max_tokens_field: Some(MaxTokensField::MaxTokens),
36            ..CompatSettings::default()
37        }),
38        "zai" => Some(CompatSettings {
39            thinking_format: Some(ThinkingFormat::Zai),
40            ..CompatSettings::default()
41        }),
42        // azure-openai already has explicit CompatSettings in add_azure_models()
43        // All other providers: use defaults (return None)
44        _ => None,
45    }
46}
47
48/// Global model registry (static built-in models)
49static STATIC_MODELS: Lazy<HashMap<String, Model>> = Lazy::new(|| {
50    let mut map = HashMap::new();
51
52    // OpenAI models
53    add_openai_models(&mut map);
54
55    // Anthropic models
56    add_anthropic_models(&mut map);
57
58    // Google models
59    add_google_models(&mut map);
60
61    // DeepSeek models
62    add_deepseek_models(&mut map);
63
64    // Mistral models
65    add_mistral_models(&mut map);
66
67    // Groq models
68    add_groq_models(&mut map);
69
70    // Cerebras models
71    add_cerebras_models(&mut map);
72
73    // xAI models
74    add_xai_models(&mut map);
75
76    // OpenRouter models
77    add_openrouter_models(&mut map);
78
79    // Azure OpenAI models
80    add_azure_models(&mut map);
81
82    // ZAI models
83    add_zai_models(&mut map);
84
85    // MiniMax models
86    add_minimax_models(&mut map);
87
88    map
89});
90
91fn add_openai_models(map: &mut HashMap<String, Model>) {
92    let models = [
93        ("openai/gpt-4o", "GPT-4o", true, 2.5, 10.0),
94        ("openai/gpt-4o-mini", "GPT-4o Mini", true, 0.15, 0.60),
95        ("openai/gpt-4-turbo", "GPT-4 Turbo", true, 10.0, 30.0),
96        ("openai/gpt-4", "GPT-4", false, 30.0, 60.0),
97        ("openai/gpt-3.5-turbo", "GPT-3.5 Turbo", false, 0.5, 1.5),
98        ("openai/o1-preview", "OpenAI o1 Preview", true, 15.0, 60.0),
99        ("openai/o1-mini", "OpenAI o1 Mini", true, 15.0, 60.0),
100        ("openai/o1", "OpenAI o1", true, 15.0, 60.0),
101        ("openai/o3", "OpenAI o3", true, 15.0, 60.0),
102        ("openai/o3-mini", "OpenAI o3 Mini", true, 15.0, 60.0),
103    ];
104
105    for (id, name, reasoning, input_cost, output_cost) in models {
106        map.insert(
107            id.to_string(),
108            Model {
109                id: extract_model_name(id).to_string(),
110                name: name.to_string(),
111                api: Api::OpenAiCompletions,
112                provider: "openai".to_string(),
113                base_url: "https://api.openai.com/v1".to_string(),
114                reasoning,
115                input: if reasoning {
116                    vec![InputModality::Text]
117                } else {
118                    vec![InputModality::Text, InputModality::Image]
119                },
120                cost: Cost {
121                    input: input_cost,
122                    output: output_cost,
123                    cache_read: input_cost * 0.5,
124                    cache_write: input_cost * 7.5,
125                },
126                context_window: 128_000,
127                max_tokens: 32_000,
128                headers: Default::default(),
129                compat: default_compat_for_provider("openai"),
130            },
131        );
132    }
133}
134
135fn add_anthropic_models(map: &mut HashMap<String, Model>) {
136    let models = [
137        (
138            "anthropic/claude-sonnet-4-20250514",
139            "Claude Sonnet 4",
140            true,
141            3.0,
142            15.0,
143        ),
144        (
145            "anthropic/claude-opus-4-20250514",
146            "Claude Opus 4",
147            true,
148            15.0,
149            75.0,
150        ),
151        (
152            "anthropic/claude-3-5-sonnet-20241022",
153            "Claude 3.5 Sonnet",
154            true,
155            3.0,
156            15.0,
157        ),
158        (
159            "anthropic/claude-3-5-haiku-20241022",
160            "Claude 3.5 Haiku",
161            false,
162            0.8,
163            4.0,
164        ),
165        (
166            "anthropic/claude-3-opus",
167            "Claude 3 Opus",
168            false,
169            15.0,
170            75.0,
171        ),
172        (
173            "anthropic/claude-3-sonnet",
174            "Claude 3 Sonnet",
175            false,
176            3.0,
177            15.0,
178        ),
179        (
180            "anthropic/claude-3-haiku",
181            "Claude 3 Haiku",
182            false,
183            0.25,
184            1.25,
185        ),
186    ];
187
188    for (id, name, reasoning, input_cost, output_cost) in models {
189        map.insert(
190            id.to_string(),
191            Model {
192                id: extract_model_name(id).to_string(),
193                name: name.to_string(),
194                api: Api::AnthropicMessages,
195                provider: "anthropic".to_string(),
196                base_url: "https://api.anthropic.com".to_string(),
197                reasoning,
198                input: vec![InputModality::Text, InputModality::Image],
199                cost: Cost {
200                    input: input_cost,
201                    output: output_cost,
202                    cache_read: input_cost * 0.1,
203                    cache_write: input_cost * 1.25,
204                },
205                context_window: 200_000,
206                max_tokens: 8192,
207                headers: Default::default(),
208                compat: default_compat_for_provider("anthropic"),
209            },
210        );
211    }
212}
213
214fn add_google_models(map: &mut HashMap<String, Model>) {
215    let models = [
216        (
217            "google/gemini-2.0-flash",
218            "Gemini 2.0 Flash",
219            0.0,
220            0.0,
221            1_000_000,
222        ),
223        (
224            "google/gemini-2.5-flash",
225            "Gemini 2.5 Flash",
226            0.0,
227            0.0,
228            1_000_000,
229        ),
230        (
231            "google/gemini-2.5-pro",
232            "Gemini 2.5 Pro",
233            1.25,
234            5.0,
235            2_000_000,
236        ),
237        (
238            "google/gemini-1.5-flash",
239            "Gemini 1.5 Flash",
240            0.0,
241            0.0,
242            1_000_000,
243        ),
244        (
245            "google/gemini-1.5-pro",
246            "Gemini 1.5 Pro",
247            1.25,
248            5.0,
249            2_000_000,
250        ),
251        ("google/gemini-pro", "Gemini Pro", 0.125, 0.5, 32_000),
252    ];
253
254    for (id, name, input_cost, output_cost, ctx) in models {
255        map.insert(
256            id.to_string(),
257            Model {
258                id: extract_model_name(id).to_string(),
259                name: name.to_string(),
260                api: Api::GoogleGenerativeAi,
261                provider: "google".to_string(),
262                base_url: "https://generativelanguage.googleapis.com".to_string(),
263                reasoning: false,
264                input: vec![InputModality::Text, InputModality::Image],
265                cost: Cost {
266                    input: input_cost,
267                    output: output_cost,
268                    cache_read: 0.0,
269                    cache_write: 0.0,
270                },
271                context_window: ctx,
272                max_tokens: 8192,
273                headers: Default::default(),
274                compat: default_compat_for_provider("google"),
275            },
276        );
277    }
278}
279
280fn add_deepseek_models(map: &mut HashMap<String, Model>) {
281    let models = [
282        ("deepseek/deepseek-chat", "DeepSeek Chat", false, 0.27, 1.1),
283        (
284            "deepseek/deepseek-chat-v3",
285            "DeepSeek Chat V3",
286            false,
287            0.27,
288            1.1,
289        ),
290        (
291            "deepseek/deepseek-reasoner",
292            "DeepSeek Reasoner",
293            true,
294            0.55,
295            2.19,
296        ),
297        (
298            "deepseek/deepseek-coder",
299            "DeepSeek Coder",
300            false,
301            0.27,
302            1.1,
303        ),
304    ];
305
306    for (id, name, reasoning, input_cost, output_cost) in models {
307        map.insert(
308            id.to_string(),
309            Model {
310                id: extract_model_name(id).to_string(),
311                name: name.to_string(),
312                api: Api::OpenAiCompletions,
313                provider: "deepseek".to_string(),
314                base_url: "https://api.deepseek.com".to_string(),
315                reasoning,
316                input: vec![InputModality::Text],
317                cost: Cost {
318                    input: input_cost,
319                    output: output_cost,
320                    cache_read: 0.1,
321                    cache_write: 1.0,
322                },
323                context_window: 64_000,
324                max_tokens: 8192,
325                headers: Default::default(),
326                compat: default_compat_for_provider("deepseek"),
327            },
328        );
329    }
330}
331
332fn add_mistral_models(map: &mut HashMap<String, Model>) {
333    let models = [
334        (
335            "mistral/mistral-large-latest",
336            "Mistral Large",
337            false,
338            2.0,
339            6.0,
340        ),
341        (
342            "mistral/mistral-medium-latest",
343            "Mistral Medium",
344            false,
345            0.5,
346            1.5,
347        ),
348        (
349            "mistral/mistral-small-latest",
350            "Mistral Small",
351            false,
352            0.2,
353            0.6,
354        ),
355        ("mistral/mistral-nemo", "Mistral Nemo", false, 0.15, 0.15),
356        ("mistral/codestral", "Codestral", false, 0.3, 0.9),
357        (
358            "mistral/codestral-mamba",
359            "Codestral Mamba",
360            false,
361            0.25,
362            0.25,
363        ),
364        (
365            "mistral/open-mixtral-8x22b",
366            "Mixtral 8x22B",
367            false,
368            0.45,
369            1.4,
370        ),
371        (
372            "mistral/open-mixtral-8x7b",
373            "Mixtral 8x7B",
374            false,
375            0.24,
376            0.24,
377        ),
378    ];
379
380    for (id, name, reasoning, input_cost, output_cost) in models {
381        map.insert(
382            id.to_string(),
383            Model {
384                id: extract_model_name(id).to_string(),
385                name: name.to_string(),
386                api: Api::OpenAiCompletions,
387                provider: "mistral".to_string(),
388                base_url: "https://api.mistral.ai".to_string(),
389                reasoning,
390                input: vec![InputModality::Text],
391                cost: Cost {
392                    input: input_cost,
393                    output: output_cost,
394                    cache_read: 0.0,
395                    cache_write: 0.0,
396                },
397                context_window: 128_000,
398                max_tokens: 32_000,
399                headers: Default::default(),
400                compat: default_compat_for_provider("mistral"),
401            },
402        );
403    }
404}
405
406fn add_groq_models(map: &mut HashMap<String, Model>) {
407    let models = [
408        (
409            "groq/llama-3.3-70b-versatile",
410            "Llama 3.3 70B Versatile",
411            false,
412            0.0,
413            0.0,
414        ),
415        (
416            "groq/llama-3.1-70b-versatile",
417            "Llama 3.1 70B Versatile",
418            false,
419            0.0,
420            0.0,
421        ),
422        (
423            "groq/llama-3.1-8b-instant",
424            "Llama 3.1 8B Instant",
425            false,
426            0.0,
427            0.0,
428        ),
429        (
430            "groq/llama-3-70b-versatile",
431            "Llama 3 70B Versatile",
432            false,
433            0.0,
434            0.0,
435        ),
436        (
437            "groq/llama-3-8b-versatile",
438            "Llama 3 8B Versatile",
439            false,
440            0.0,
441            0.0,
442        ),
443        ("groq/mixtral-8x7b-32768", "Mixtral 8x7B", false, 0.0, 0.0),
444        ("groq/gemma2-9b-it", "Gemma 2 9B", false, 0.0, 0.0),
445        ("groq/gemma-7b-it", "Gemma 7B", false, 0.0, 0.0),
446    ];
447
448    for (id, name, reasoning, input_cost, output_cost) in models {
449        map.insert(
450            id.to_string(),
451            Model {
452                id: extract_model_name(id).to_string(),
453                name: name.to_string(),
454                api: Api::OpenAiCompletions,
455                provider: "groq".to_string(),
456                base_url: "https://api.groq.com/openai/v1".to_string(),
457                reasoning,
458                input: vec![InputModality::Text],
459                cost: Cost {
460                    input: input_cost,
461                    output: output_cost,
462                    cache_read: 0.0,
463                    cache_write: 0.0,
464                },
465                context_window: 128_000,
466                max_tokens: 8192,
467                headers: Default::default(),
468                compat: default_compat_for_provider("groq"),
469            },
470        );
471    }
472}
473
474fn add_cerebras_models(map: &mut HashMap<String, Model>) {
475    let models = [
476        ("cerebras/llama-3.3-70b", "Llama 3.3 70B", false, 0.0, 0.0),
477        ("cerebras/llama-3.1-8b", "Llama 3.1 8B", false, 0.0, 0.0),
478        ("cerebras/qwen-2.5-32b", "Qwen 2.5 32B", false, 0.0, 0.0),
479        ("cerebras/qwen-2.5-7b", "Qwen 2.5 7B", false, 0.0, 0.0),
480    ];
481
482    for (id, name, reasoning, input_cost, output_cost) in models {
483        map.insert(
484            id.to_string(),
485            Model {
486                id: extract_model_name(id).to_string(),
487                name: name.to_string(),
488                api: Api::OpenAiCompletions,
489                provider: "cerebras".to_string(),
490                base_url: "https://api.cerebras.ai".to_string(),
491                reasoning,
492                input: vec![InputModality::Text],
493                cost: Cost {
494                    input: input_cost,
495                    output: output_cost,
496                    cache_read: 0.0,
497                    cache_write: 0.0,
498                },
499                context_window: 128_000,
500                max_tokens: 8192,
501                headers: Default::default(),
502                compat: default_compat_for_provider("cerebras"),
503            },
504        );
505    }
506}
507
508fn add_xai_models(map: &mut HashMap<String, Model>) {
509    let models = [
510        ("xai/grok-2", "Grok 2", false, 5.0, 15.0),
511        ("xai/grok-2-mini", "Grok 2 Mini", false, 0.3, 0.5),
512        ("xai/grok-1", "Grok 1", false, 5.0, 15.0),
513        ("xai/grok-1.5", "Grok 1.5", false, 5.0, 15.0),
514    ];
515
516    for (id, name, reasoning, input_cost, output_cost) in models {
517        map.insert(
518            id.to_string(),
519            Model {
520                id: extract_model_name(id).to_string(),
521                name: name.to_string(),
522                api: Api::OpenAiCompletions,
523                provider: "xai".to_string(),
524                base_url: "https://api.x.ai/v1".to_string(),
525                reasoning,
526                input: vec![InputModality::Text],
527                cost: Cost {
528                    input: input_cost,
529                    output: output_cost,
530                    cache_read: 0.0,
531                    cache_write: 0.0,
532                },
533                context_window: 131_072,
534                max_tokens: 8192,
535                headers: Default::default(),
536                compat: default_compat_for_provider("xai"),
537            },
538        );
539    }
540}
541
542fn add_openrouter_models(map: &mut HashMap<String, Model>) {
543    let models = [
544        (
545            "openrouter/anthropic/claude-3.5-sonnet",
546            "Claude 3.5 Sonnet",
547            false,
548            3.0,
549            15.0,
550        ),
551        (
552            "openrouter/anthropic/claude-3-opus",
553            "Claude 3 Opus",
554            false,
555            15.0,
556            75.0,
557        ),
558        (
559            "openrouter/google/gemini-pro-1.5",
560            "Gemini Pro 1.5",
561            false,
562            1.25,
563            5.0,
564        ),
565        (
566            "openrouter/meta-llama/llama-3-70b",
567            "Llama 3 70B",
568            false,
569            0.65,
570            2.75,
571        ),
572        (
573            "openrouter/meta-llama/llama-3-8b",
574            "Llama 3 8B",
575            false,
576            0.2,
577            0.2,
578        ),
579        (
580            "openrouter/mistralai/mistral-large",
581            "Mistral Large",
582            false,
583            2.0,
584            6.0,
585        ),
586        (
587            "openrouter/deepseek/deepseek-chat",
588            "DeepSeek Chat",
589            false,
590            0.27,
591            1.1,
592        ),
593        ("openrouter/qwen/qwen-2-72b", "Qwen 2 72B", false, 0.9, 0.9),
594        (
595            "openrouter/nousresearch/hermes-3-llama-3-70b",
596            "Hermes 3 70B",
597            false,
598            0.5,
599            1.5,
600        ),
601    ];
602
603    for (id, name, reasoning, input_cost, output_cost) in models {
604        map.insert(
605            id.to_string(),
606            Model {
607                id: extract_model_name(id).to_string(),
608                name: name.to_string(),
609                api: Api::OpenAiCompletions,
610                provider: "openrouter".to_string(),
611                base_url: "https://openrouter.ai/api/v1".to_string(),
612                reasoning,
613                input: vec![InputModality::Text],
614                cost: Cost {
615                    input: input_cost,
616                    output: output_cost,
617                    cache_read: 0.0,
618                    cache_write: 0.0,
619                },
620                context_window: 128_000,
621                max_tokens: 32_000,
622                headers: [
623                    ("HTTP-Referer".to_string(), "https://oxi-ai".to_string()),
624                    ("X-Title".to_string(), "oxi-ai".to_string()),
625                ]
626                .into_iter()
627                .collect(),
628                compat: default_compat_for_provider("openrouter"),
629            },
630        );
631    }
632}
633
634fn add_azure_models(map: &mut HashMap<String, Model>) {
635    let models = [
636        ("azure-openai/gpt-4o", "GPT-4o", false, 2.5, 10.0),
637        ("azure-openai/gpt-4o-mini", "GPT-4o Mini", false, 0.15, 0.60),
638        ("azure-openai/gpt-4-turbo", "GPT-4 Turbo", false, 10.0, 30.0),
639    ];
640
641    for (id, name, reasoning, input_cost, output_cost) in models {
642        map.insert(
643            id.to_string(),
644            Model {
645                id: extract_model_name(id).to_string(),
646                name: name.to_string(),
647                api: Api::AzureOpenAiResponses,
648                provider: "azure-openai".to_string(),
649                base_url: "https://{your-resource-name}.openai.azure.com".to_string(),
650                reasoning,
651                input: vec![InputModality::Text, InputModality::Image],
652                cost: Cost {
653                    input: input_cost,
654                    output: output_cost,
655                    cache_read: 0.0,
656                    cache_write: 0.0,
657                },
658                context_window: 128_000,
659                max_tokens: 32_000,
660                headers: Default::default(),
661                compat: Some(crate::CompatSettings {
662                    supports_store: false,
663                    supports_developer_role: false,
664                    supports_reasoning_effort: false,
665                    supports_usage_in_streaming: false,
666                    max_tokens_field: Some(crate::MaxTokensField::MaxCompletionTokens),
667                    requires_tool_result_name: true,
668                    requires_assistant_after_tool_result: false,
669                    requires_thinking_as_text: false,
670                    thinking_format: None,
671                }),
672            },
673        );
674    }
675}
676
677fn add_zai_models(map: &mut HashMap<String, Model>) {
678    let models = [
679        ("zai/glm-4.7", "GLM-4.7", true, 0.0, 0.0),
680        ("zai/glm-5-turbo", "GLM-5-Turbo", true, 0.0, 0.0),
681        ("zai/glm-5.1", "GLM-5.1", true, 0.0, 0.0),
682        ("zai/glm-5v-turbo", "GLM-5V-Turbo", true, 0.0, 0.0),
683        ("zai/glm-4.5-air", "GLM-4.5-Air", true, 0.0, 0.0),
684    ];
685
686    for (id, name, reasoning, input_cost, output_cost) in models {
687        map.insert(
688            id.to_string(),
689            Model {
690                id: extract_model_name(id).to_string(),
691                name: name.to_string(),
692                api: Api::OpenAiCompletions,
693                provider: "zai".to_string(),
694                base_url: "https://api.z.ai/api/coding/paas/v4".to_string(),
695                reasoning,
696                input: vec![InputModality::Text],
697                cost: Cost {
698                    input: input_cost,
699                    output: output_cost,
700                    cache_read: 0.0,
701                    cache_write: 0.0,
702                },
703                context_window: 200_000,
704                max_tokens: 131_072,
705                headers: Default::default(),
706                compat: default_compat_for_provider("zai"),
707            },
708        );
709    }
710}
711
712fn add_minimax_models(map: &mut HashMap<String, Model>) {
713    let models = [
714        ("minimax/MiniMax-M2.7", "MiniMax-M2.7", true, 0.0, 0.0),
715        (
716            "minimax/MiniMax-M2.7-highspeed",
717            "MiniMax-M2.7-highspeed",
718            true,
719            0.0,
720            0.0,
721        ),
722    ];
723
724    for (id, name, reasoning, input_cost, output_cost) in models {
725        map.insert(
726            id.to_string(),
727            Model {
728                id: extract_model_name(id).to_string(),
729                name: name.to_string(),
730                api: Api::AnthropicMessages,
731                provider: "minimax".to_string(),
732                base_url: "https://api.minimax.io".to_string(),
733                reasoning,
734                input: vec![InputModality::Text],
735                cost: Cost {
736                    input: input_cost,
737                    output: output_cost,
738                    cache_read: 0.06,
739                    cache_write: 0.375,
740                },
741                context_window: 204_800,
742                max_tokens: 131_072,
743                headers: Default::default(),
744                compat: default_compat_for_provider("minimax"),
745            },
746        );
747    }
748}
749
750/// Lightweight model registry for SDK/engine usage.
751///
752/// Stores model metadata (provider, base_url, API type, costs) without
753/// authentication details. For CLI usage with auth integration, see
754/// `oxi_store::CliModelRegistry`.
755#[derive(Default)]
756pub struct ModelRegistry {
757    static_models: HashMap<String, Model>,
758    dynamic_models: parking_lot::RwLock<HashMap<String, Model>>,
759}
760
761impl ModelRegistry {
762    /// Create a new empty registry.
763    pub fn new() -> Self {
764        Self {
765            static_models: HashMap::new(),
766            dynamic_models: RwLock::new(HashMap::new()),
767        }
768    }
769
770    /// Create a registry pre-populated with all built-in static models.
771    ///
772    /// This loads models from the embedded static database.
773    pub fn from_static() -> Self {
774        Self {
775            static_models: STATIC_MODELS.clone(),
776            dynamic_models: RwLock::new(HashMap::new()),
777        }
778    }
779
780    /// Register a model at runtime.
781    ///
782    /// If a model with the same `provider/model_id` key already exists,
783    /// the new one replaces it.
784    pub fn register(&self, model: Model) {
785        let key = format!("{}/{}", model.provider, model.id);
786        self.dynamic_models.write().insert(key, model);
787    }
788
789    /// Unregister a previously registered dynamic model.
790    pub fn unregister(&self, provider: &str, model_id: &str) {
791        let key = format!("{}/{}", provider, model_id);
792        self.dynamic_models.write().remove(&key);
793    }
794
795    /// Look up a model by provider and model ID.
796    ///
797    /// Dynamic models take priority over static ones.
798    pub fn lookup(&self, provider: &str, model_id: &str) -> Option<Model> {
799        let key = format!("{}/{}", provider, model_id);
800        // Dynamic models take priority
801        if let Some(m) = self.dynamic_models.read().get(&key) {
802            return Some(m.clone());
803        }
804        // Then static models
805        self.static_models.get(&key).cloned()
806    }
807
808    /// Get a model by provider/model ID (static models only).
809    pub fn get(provider: &str, model_id: &str) -> Option<&'static Model> {
810        let key = format!("{}/{}", provider, model_id);
811        STATIC_MODELS.get(&key)
812    }
813
814    /// Get all models from a provider (static only).
815    pub fn get_by_provider(provider: &str) -> Vec<&'static Model> {
816        STATIC_MODELS
817            .values()
818            .filter(|m| m.provider == provider)
819            .collect()
820    }
821
822    /// Get all available models (static only).
823    pub fn all() -> Vec<&'static Model> {
824        STATIC_MODELS.values().collect()
825    }
826
827    /// Get all dynamically registered models.
828    pub fn dynamic_models(&self) -> Vec<Model> {
829        self.dynamic_models.read().values().cloned().collect()
830    }
831
832    /// Get all registered model IDs as `provider/model` strings.
833    pub fn model_ids(&self) -> Vec<String> {
834        let static_ids: Vec<String> = self.static_models.keys().cloned().collect();
835        let dynamic_ids: Vec<String> = self.dynamic_models.read().keys().cloned().collect();
836        static_ids.into_iter().chain(dynamic_ids).collect()
837    }
838
839    /// Search models by pattern (static only).
840    pub fn search(pattern: &str) -> Vec<&'static Model> {
841        let pattern_lower = pattern.to_lowercase();
842        STATIC_MODELS
843            .values()
844            .filter(|m| {
845                m.id.to_lowercase().contains(&pattern_lower)
846                    || m.name.to_lowercase().contains(&pattern_lower)
847            })
848            .collect()
849    }
850}
851
852// ── Global registry instance ────────────────────────────────────────
853
854/// Global model registry instance (for convenience functions).
855static GLOBAL_REGISTRY: Lazy<ModelRegistry> = Lazy::new(ModelRegistry::from_static);
856
857// ── Convenience functions using global registry ─────────────────────
858
859/// Register a model at runtime.
860///
861/// Call this during startup for each custom provider's model.
862/// If a model with the same `provider/model_id` key already exists,
863/// the new one replaces it.
864pub fn register_model(model: Model) {
865    GLOBAL_REGISTRY.register(model);
866}
867
868/// Unregister a previously registered dynamic model.
869pub fn unregister_model(provider: &str, model_id: &str) {
870    GLOBAL_REGISTRY.unregister(provider, model_id);
871}
872
873/// Look up a model by provider and model ID, checking both dynamic and static registries.
874///
875/// Dynamic models take priority over static ones.
876pub fn lookup_model(provider: &str, model_id: &str) -> Option<Model> {
877    GLOBAL_REGISTRY.lookup(provider, model_id)
878}
879
880/// Convenience function to get a model (static registry only – use [`lookup_model`] for dynamic too)
881pub fn get_model(provider: &str, model_id: &str) -> Option<&'static Model> {
882    ModelRegistry::get(provider, model_id)
883}
884
885/// Get all available providers
886pub fn get_providers() -> Vec<&'static str> {
887    let mut providers: Vec<&'static str> = STATIC_MODELS
888        .values()
889        .map(|m| m.provider.as_str())
890        .collect();
891    providers.sort();
892    providers.dedup();
893    providers
894}
895
896/// Get all models from a provider
897pub fn get_models(provider: &str) -> Vec<&'static Model> {
898    ModelRegistry::get_by_provider(provider)
899}
900
901/// Get all dynamically registered models.
902pub fn dynamic_models() -> Vec<Model> {
903    GLOBAL_REGISTRY.dynamic_models()
904}
905
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal text-only model for exercising dynamic registration.
    fn sample_model(provider: &str, id: &str, name: &str) -> Model {
        Model {
            id: id.to_string(),
            name: name.to_string(),
            api: Api::OpenAiCompletions,
            provider: provider.to_string(),
            base_url: "https://custom.example.com".to_string(),
            reasoning: false,
            input: vec![InputModality::Text],
            cost: Cost {
                input: 1.0,
                output: 2.0,
                cache_read: 0.5,
                cache_write: 5.0,
            },
            context_window: 100_000,
            max_tokens: 8192,
            headers: Default::default(),
            compat: None,
        }
    }

    #[test]
    fn test_get_model() {
        let model = get_model("openai", "gpt-4o").expect("gpt-4o is a built-in model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.id, "gpt-4o");
        // Note: the built-in table lists gpt-4o with reasoning enabled.
    }

    #[test]
    fn test_get_providers() {
        let providers = get_providers();
        for expected in ["openai", "anthropic", "google", "deepseek", "mistral", "groq"] {
            assert!(
                providers.contains(&expected),
                "missing provider {expected:?}"
            );
        }
    }

    #[test]
    fn test_deepseek_model() {
        let model =
            get_model("deepseek", "deepseek-chat").expect("deepseek-chat is a built-in model");
        assert_eq!(model.provider, "deepseek");
        assert_eq!(model.base_url, "https://api.deepseek.com");
    }

    #[test]
    fn test_search_models() {
        let results = ModelRegistry::search("gpt");
        assert!(!results.is_empty());
        assert!(results
            .iter()
            .all(|m| m.name.to_lowercase().contains("gpt")));
    }

    #[test]
    fn test_model_registry_instance() {
        let registry = ModelRegistry::from_static();
        assert!(registry.lookup("openai", "gpt-4o").is_some());
        assert!(registry.lookup("fake", "fake-model").is_none());
    }

    #[test]
    fn test_model_registry_register_dynamic() {
        let registry = ModelRegistry::new();
        registry.register(sample_model("custom", "custom-model", "Custom Model"));

        let found = registry
            .lookup("custom", "custom-model")
            .expect("registered model should be found");
        assert_eq!(found.name, "Custom Model");
        assert_eq!(registry.dynamic_models().len(), 1);
    }

    #[test]
    fn test_model_registry_unregister_dynamic() {
        let registry = ModelRegistry::new();
        registry.register(sample_model("custom", "custom-model", "Custom Model"));
        registry.unregister("custom", "custom-model");

        assert!(registry.lookup("custom", "custom-model").is_none());
        assert!(registry.dynamic_models().is_empty());
    }

    #[test]
    fn test_dynamic_model_shadows_static() {
        let registry = ModelRegistry::from_static();
        registry.register(sample_model("openai", "gpt-4o", "Override"));

        let shadowed = registry
            .lookup("openai", "gpt-4o")
            .expect("dynamic override should be found");
        assert_eq!(shadowed.name, "Override");

        // Removing the override exposes the built-in entry again.
        registry.unregister("openai", "gpt-4o");
        let builtin = registry
            .lookup("openai", "gpt-4o")
            .expect("built-in model should remain");
        assert_eq!(builtin.name, "GPT-4o");
    }
}