1use crate::{Api, CompatSettings, Cost, InputModality, MaxTokensField, Model, ThinkingFormat};
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use std::collections::HashMap;
11
/// Strips the leading provider segment from a catalogue key, returning the
/// provider-local model id.
///
/// Only the *first* `/` separates provider from model. Nested ids therefore keep their
/// vendor namespace:
/// - `"openai/gpt-4o"` → `"gpt-4o"`
/// - `"openrouter/anthropic/claude-3.5-sonnet"` → `"anthropic/claude-3.5-sonnet"`, the
///   form OpenRouter expects.
///
/// This keeps `format!("{provider}/{id}")` equal to the key the model is stored under.
/// Keys without a `/` are returned unchanged.
fn extract_model_name(id: &str) -> &str {
    id.split_once('/').map_or(id, |(_, name)| name)
}
16
17fn default_compat_for_provider(provider: &str) -> Option<CompatSettings> {
22 match provider {
23 "openai" | "openai-responses" | "openai-completions" => Some(CompatSettings {
24 thinking_format: Some(ThinkingFormat::OpenAI),
25 max_tokens_field: Some(MaxTokensField::MaxCompletionTokens),
26 ..CompatSettings::default()
27 }),
28 "openrouter" => Some(CompatSettings {
29 thinking_format: Some(ThinkingFormat::OpenRouter),
30 requires_tool_result_name: true,
31 ..CompatSettings::default()
32 }),
33 "deepseek" => Some(CompatSettings {
34 thinking_format: Some(ThinkingFormat::DeepSeek),
35 max_tokens_field: Some(MaxTokensField::MaxTokens),
36 ..CompatSettings::default()
37 }),
38 "zai" => Some(CompatSettings {
39 thinking_format: Some(ThinkingFormat::Zai),
40 ..CompatSettings::default()
41 }),
42 _ => None,
45 }
46}
47
48static STATIC_MODELS: Lazy<HashMap<String, Model>> = Lazy::new(|| {
50 let mut map = HashMap::new();
51
52 add_openai_models(&mut map);
54
55 add_anthropic_models(&mut map);
57
58 add_google_models(&mut map);
60
61 add_deepseek_models(&mut map);
63
64 add_mistral_models(&mut map);
66
67 add_groq_models(&mut map);
69
70 add_cerebras_models(&mut map);
72
73 add_xai_models(&mut map);
75
76 add_openrouter_models(&mut map);
78
79 add_azure_models(&mut map);
81
82 add_zai_models(&mut map);
84
85 add_minimax_models(&mut map);
87
88 map
89});
90
91fn add_openai_models(map: &mut HashMap<String, Model>) {
92 let models = [
93 ("openai/gpt-4o", "GPT-4o", true, 2.5, 10.0),
94 ("openai/gpt-4o-mini", "GPT-4o Mini", true, 0.15, 0.60),
95 ("openai/gpt-4-turbo", "GPT-4 Turbo", true, 10.0, 30.0),
96 ("openai/gpt-4", "GPT-4", false, 30.0, 60.0),
97 ("openai/gpt-3.5-turbo", "GPT-3.5 Turbo", false, 0.5, 1.5),
98 ("openai/o1-preview", "OpenAI o1 Preview", true, 15.0, 60.0),
99 ("openai/o1-mini", "OpenAI o1 Mini", true, 15.0, 60.0),
100 ("openai/o1", "OpenAI o1", true, 15.0, 60.0),
101 ("openai/o3", "OpenAI o3", true, 15.0, 60.0),
102 ("openai/o3-mini", "OpenAI o3 Mini", true, 15.0, 60.0),
103 ];
104
105 for (id, name, reasoning, input_cost, output_cost) in models {
106 map.insert(
107 id.to_string(),
108 Model {
109 id: extract_model_name(id).to_string(),
110 name: name.to_string(),
111 api: Api::OpenAiCompletions,
112 provider: "openai".to_string(),
113 base_url: "https://api.openai.com/v1".to_string(),
114 reasoning,
115 input: if reasoning {
116 vec![InputModality::Text]
117 } else {
118 vec![InputModality::Text, InputModality::Image]
119 },
120 cost: Cost {
121 input: input_cost,
122 output: output_cost,
123 cache_read: input_cost * 0.5,
124 cache_write: input_cost * 7.5,
125 },
126 context_window: 128_000,
127 max_tokens: 32_000,
128 headers: Default::default(),
129 compat: default_compat_for_provider("openai"),
130 },
131 );
132 }
133}
134
135fn add_anthropic_models(map: &mut HashMap<String, Model>) {
136 let models = [
137 (
138 "anthropic/claude-sonnet-4-20250514",
139 "Claude Sonnet 4",
140 true,
141 3.0,
142 15.0,
143 ),
144 (
145 "anthropic/claude-opus-4-20250514",
146 "Claude Opus 4",
147 true,
148 15.0,
149 75.0,
150 ),
151 (
152 "anthropic/claude-3-5-sonnet-20241022",
153 "Claude 3.5 Sonnet",
154 true,
155 3.0,
156 15.0,
157 ),
158 (
159 "anthropic/claude-3-5-haiku-20241022",
160 "Claude 3.5 Haiku",
161 false,
162 0.8,
163 4.0,
164 ),
165 (
166 "anthropic/claude-3-opus",
167 "Claude 3 Opus",
168 false,
169 15.0,
170 75.0,
171 ),
172 (
173 "anthropic/claude-3-sonnet",
174 "Claude 3 Sonnet",
175 false,
176 3.0,
177 15.0,
178 ),
179 (
180 "anthropic/claude-3-haiku",
181 "Claude 3 Haiku",
182 false,
183 0.25,
184 1.25,
185 ),
186 ];
187
188 for (id, name, reasoning, input_cost, output_cost) in models {
189 map.insert(
190 id.to_string(),
191 Model {
192 id: extract_model_name(id).to_string(),
193 name: name.to_string(),
194 api: Api::AnthropicMessages,
195 provider: "anthropic".to_string(),
196 base_url: "https://api.anthropic.com".to_string(),
197 reasoning,
198 input: vec![InputModality::Text, InputModality::Image],
199 cost: Cost {
200 input: input_cost,
201 output: output_cost,
202 cache_read: input_cost * 0.1,
203 cache_write: input_cost * 1.25,
204 },
205 context_window: 200_000,
206 max_tokens: 8192,
207 headers: Default::default(),
208 compat: default_compat_for_provider("anthropic"),
209 },
210 );
211 }
212}
213
214fn add_google_models(map: &mut HashMap<String, Model>) {
215 let models = [
216 (
217 "google/gemini-2.0-flash",
218 "Gemini 2.0 Flash",
219 0.0,
220 0.0,
221 1_000_000,
222 ),
223 (
224 "google/gemini-2.5-flash",
225 "Gemini 2.5 Flash",
226 0.0,
227 0.0,
228 1_000_000,
229 ),
230 (
231 "google/gemini-2.5-pro",
232 "Gemini 2.5 Pro",
233 1.25,
234 5.0,
235 2_000_000,
236 ),
237 (
238 "google/gemini-1.5-flash",
239 "Gemini 1.5 Flash",
240 0.0,
241 0.0,
242 1_000_000,
243 ),
244 (
245 "google/gemini-1.5-pro",
246 "Gemini 1.5 Pro",
247 1.25,
248 5.0,
249 2_000_000,
250 ),
251 ("google/gemini-pro", "Gemini Pro", 0.125, 0.5, 32_000),
252 ];
253
254 for (id, name, input_cost, output_cost, ctx) in models {
255 map.insert(
256 id.to_string(),
257 Model {
258 id: extract_model_name(id).to_string(),
259 name: name.to_string(),
260 api: Api::GoogleGenerativeAi,
261 provider: "google".to_string(),
262 base_url: "https://generativelanguage.googleapis.com".to_string(),
263 reasoning: false,
264 input: vec![InputModality::Text, InputModality::Image],
265 cost: Cost {
266 input: input_cost,
267 output: output_cost,
268 cache_read: 0.0,
269 cache_write: 0.0,
270 },
271 context_window: ctx,
272 max_tokens: 8192,
273 headers: Default::default(),
274 compat: default_compat_for_provider("google"),
275 },
276 );
277 }
278}
279
280fn add_deepseek_models(map: &mut HashMap<String, Model>) {
281 let models = [
282 ("deepseek/deepseek-chat", "DeepSeek Chat", false, 0.27, 1.1),
283 (
284 "deepseek/deepseek-chat-v3",
285 "DeepSeek Chat V3",
286 false,
287 0.27,
288 1.1,
289 ),
290 (
291 "deepseek/deepseek-reasoner",
292 "DeepSeek Reasoner",
293 true,
294 0.55,
295 2.19,
296 ),
297 (
298 "deepseek/deepseek-coder",
299 "DeepSeek Coder",
300 false,
301 0.27,
302 1.1,
303 ),
304 ];
305
306 for (id, name, reasoning, input_cost, output_cost) in models {
307 map.insert(
308 id.to_string(),
309 Model {
310 id: extract_model_name(id).to_string(),
311 name: name.to_string(),
312 api: Api::OpenAiCompletions,
313 provider: "deepseek".to_string(),
314 base_url: "https://api.deepseek.com".to_string(),
315 reasoning,
316 input: vec![InputModality::Text],
317 cost: Cost {
318 input: input_cost,
319 output: output_cost,
320 cache_read: 0.1,
321 cache_write: 1.0,
322 },
323 context_window: 64_000,
324 max_tokens: 8192,
325 headers: Default::default(),
326 compat: default_compat_for_provider("deepseek"),
327 },
328 );
329 }
330}
331
332fn add_mistral_models(map: &mut HashMap<String, Model>) {
333 let models = [
334 (
335 "mistral/mistral-large-latest",
336 "Mistral Large",
337 false,
338 2.0,
339 6.0,
340 ),
341 (
342 "mistral/mistral-medium-latest",
343 "Mistral Medium",
344 false,
345 0.5,
346 1.5,
347 ),
348 (
349 "mistral/mistral-small-latest",
350 "Mistral Small",
351 false,
352 0.2,
353 0.6,
354 ),
355 ("mistral/mistral-nemo", "Mistral Nemo", false, 0.15, 0.15),
356 ("mistral/codestral", "Codestral", false, 0.3, 0.9),
357 (
358 "mistral/codestral-mamba",
359 "Codestral Mamba",
360 false,
361 0.25,
362 0.25,
363 ),
364 (
365 "mistral/open-mixtral-8x22b",
366 "Mixtral 8x22B",
367 false,
368 0.45,
369 1.4,
370 ),
371 (
372 "mistral/open-mixtral-8x7b",
373 "Mixtral 8x7B",
374 false,
375 0.24,
376 0.24,
377 ),
378 ];
379
380 for (id, name, reasoning, input_cost, output_cost) in models {
381 map.insert(
382 id.to_string(),
383 Model {
384 id: extract_model_name(id).to_string(),
385 name: name.to_string(),
386 api: Api::OpenAiCompletions,
387 provider: "mistral".to_string(),
388 base_url: "https://api.mistral.ai".to_string(),
389 reasoning,
390 input: vec![InputModality::Text],
391 cost: Cost {
392 input: input_cost,
393 output: output_cost,
394 cache_read: 0.0,
395 cache_write: 0.0,
396 },
397 context_window: 128_000,
398 max_tokens: 32_000,
399 headers: Default::default(),
400 compat: default_compat_for_provider("mistral"),
401 },
402 );
403 }
404}
405
406fn add_groq_models(map: &mut HashMap<String, Model>) {
407 let models = [
408 (
409 "groq/llama-3.3-70b-versatile",
410 "Llama 3.3 70B Versatile",
411 false,
412 0.0,
413 0.0,
414 ),
415 (
416 "groq/llama-3.1-70b-versatile",
417 "Llama 3.1 70B Versatile",
418 false,
419 0.0,
420 0.0,
421 ),
422 (
423 "groq/llama-3.1-8b-instant",
424 "Llama 3.1 8B Instant",
425 false,
426 0.0,
427 0.0,
428 ),
429 (
430 "groq/llama-3-70b-versatile",
431 "Llama 3 70B Versatile",
432 false,
433 0.0,
434 0.0,
435 ),
436 (
437 "groq/llama-3-8b-versatile",
438 "Llama 3 8B Versatile",
439 false,
440 0.0,
441 0.0,
442 ),
443 ("groq/mixtral-8x7b-32768", "Mixtral 8x7B", false, 0.0, 0.0),
444 ("groq/gemma2-9b-it", "Gemma 2 9B", false, 0.0, 0.0),
445 ("groq/gemma-7b-it", "Gemma 7B", false, 0.0, 0.0),
446 ];
447
448 for (id, name, reasoning, input_cost, output_cost) in models {
449 map.insert(
450 id.to_string(),
451 Model {
452 id: extract_model_name(id).to_string(),
453 name: name.to_string(),
454 api: Api::OpenAiCompletions,
455 provider: "groq".to_string(),
456 base_url: "https://api.groq.com/openai/v1".to_string(),
457 reasoning,
458 input: vec![InputModality::Text],
459 cost: Cost {
460 input: input_cost,
461 output: output_cost,
462 cache_read: 0.0,
463 cache_write: 0.0,
464 },
465 context_window: 128_000,
466 max_tokens: 8192,
467 headers: Default::default(),
468 compat: default_compat_for_provider("groq"),
469 },
470 );
471 }
472}
473
474fn add_cerebras_models(map: &mut HashMap<String, Model>) {
475 let models = [
476 ("cerebras/llama-3.3-70b", "Llama 3.3 70B", false, 0.0, 0.0),
477 ("cerebras/llama-3.1-8b", "Llama 3.1 8B", false, 0.0, 0.0),
478 ("cerebras/qwen-2.5-32b", "Qwen 2.5 32B", false, 0.0, 0.0),
479 ("cerebras/qwen-2.5-7b", "Qwen 2.5 7B", false, 0.0, 0.0),
480 ];
481
482 for (id, name, reasoning, input_cost, output_cost) in models {
483 map.insert(
484 id.to_string(),
485 Model {
486 id: extract_model_name(id).to_string(),
487 name: name.to_string(),
488 api: Api::OpenAiCompletions,
489 provider: "cerebras".to_string(),
490 base_url: "https://api.cerebras.ai".to_string(),
491 reasoning,
492 input: vec![InputModality::Text],
493 cost: Cost {
494 input: input_cost,
495 output: output_cost,
496 cache_read: 0.0,
497 cache_write: 0.0,
498 },
499 context_window: 128_000,
500 max_tokens: 8192,
501 headers: Default::default(),
502 compat: default_compat_for_provider("cerebras"),
503 },
504 );
505 }
506}
507
508fn add_xai_models(map: &mut HashMap<String, Model>) {
509 let models = [
510 ("xai/grok-2", "Grok 2", false, 5.0, 15.0),
511 ("xai/grok-2-mini", "Grok 2 Mini", false, 0.3, 0.5),
512 ("xai/grok-1", "Grok 1", false, 5.0, 15.0),
513 ("xai/grok-1.5", "Grok 1.5", false, 5.0, 15.0),
514 ];
515
516 for (id, name, reasoning, input_cost, output_cost) in models {
517 map.insert(
518 id.to_string(),
519 Model {
520 id: extract_model_name(id).to_string(),
521 name: name.to_string(),
522 api: Api::OpenAiCompletions,
523 provider: "xai".to_string(),
524 base_url: "https://api.x.ai/v1".to_string(),
525 reasoning,
526 input: vec![InputModality::Text],
527 cost: Cost {
528 input: input_cost,
529 output: output_cost,
530 cache_read: 0.0,
531 cache_write: 0.0,
532 },
533 context_window: 131_072,
534 max_tokens: 8192,
535 headers: Default::default(),
536 compat: default_compat_for_provider("xai"),
537 },
538 );
539 }
540}
541
542fn add_openrouter_models(map: &mut HashMap<String, Model>) {
543 let models = [
544 (
545 "openrouter/anthropic/claude-3.5-sonnet",
546 "Claude 3.5 Sonnet",
547 false,
548 3.0,
549 15.0,
550 ),
551 (
552 "openrouter/anthropic/claude-3-opus",
553 "Claude 3 Opus",
554 false,
555 15.0,
556 75.0,
557 ),
558 (
559 "openrouter/google/gemini-pro-1.5",
560 "Gemini Pro 1.5",
561 false,
562 1.25,
563 5.0,
564 ),
565 (
566 "openrouter/meta-llama/llama-3-70b",
567 "Llama 3 70B",
568 false,
569 0.65,
570 2.75,
571 ),
572 (
573 "openrouter/meta-llama/llama-3-8b",
574 "Llama 3 8B",
575 false,
576 0.2,
577 0.2,
578 ),
579 (
580 "openrouter/mistralai/mistral-large",
581 "Mistral Large",
582 false,
583 2.0,
584 6.0,
585 ),
586 (
587 "openrouter/deepseek/deepseek-chat",
588 "DeepSeek Chat",
589 false,
590 0.27,
591 1.1,
592 ),
593 ("openrouter/qwen/qwen-2-72b", "Qwen 2 72B", false, 0.9, 0.9),
594 (
595 "openrouter/nousresearch/hermes-3-llama-3-70b",
596 "Hermes 3 70B",
597 false,
598 0.5,
599 1.5,
600 ),
601 ];
602
603 for (id, name, reasoning, input_cost, output_cost) in models {
604 map.insert(
605 id.to_string(),
606 Model {
607 id: extract_model_name(id).to_string(),
608 name: name.to_string(),
609 api: Api::OpenAiCompletions,
610 provider: "openrouter".to_string(),
611 base_url: "https://openrouter.ai/api/v1".to_string(),
612 reasoning,
613 input: vec![InputModality::Text],
614 cost: Cost {
615 input: input_cost,
616 output: output_cost,
617 cache_read: 0.0,
618 cache_write: 0.0,
619 },
620 context_window: 128_000,
621 max_tokens: 32_000,
622 headers: [
623 ("HTTP-Referer".to_string(), "https://oxi-ai".to_string()),
624 ("X-Title".to_string(), "oxi-ai".to_string()),
625 ]
626 .into_iter()
627 .collect(),
628 compat: default_compat_for_provider("openrouter"),
629 },
630 );
631 }
632}
633
634fn add_azure_models(map: &mut HashMap<String, Model>) {
635 let models = [
636 ("azure-openai/gpt-4o", "GPT-4o", false, 2.5, 10.0),
637 ("azure-openai/gpt-4o-mini", "GPT-4o Mini", false, 0.15, 0.60),
638 ("azure-openai/gpt-4-turbo", "GPT-4 Turbo", false, 10.0, 30.0),
639 ];
640
641 for (id, name, reasoning, input_cost, output_cost) in models {
642 map.insert(
643 id.to_string(),
644 Model {
645 id: extract_model_name(id).to_string(),
646 name: name.to_string(),
647 api: Api::AzureOpenAiResponses,
648 provider: "azure-openai".to_string(),
649 base_url: "https://{your-resource-name}.openai.azure.com".to_string(),
650 reasoning,
651 input: vec![InputModality::Text, InputModality::Image],
652 cost: Cost {
653 input: input_cost,
654 output: output_cost,
655 cache_read: 0.0,
656 cache_write: 0.0,
657 },
658 context_window: 128_000,
659 max_tokens: 32_000,
660 headers: Default::default(),
661 compat: Some(crate::CompatSettings {
662 supports_store: false,
663 supports_developer_role: false,
664 supports_reasoning_effort: false,
665 supports_usage_in_streaming: false,
666 max_tokens_field: Some(crate::MaxTokensField::MaxCompletionTokens),
667 requires_tool_result_name: true,
668 requires_assistant_after_tool_result: false,
669 requires_thinking_as_text: false,
670 thinking_format: None,
671 }),
672 },
673 );
674 }
675}
676
677fn add_zai_models(map: &mut HashMap<String, Model>) {
678 let models = [
679 ("zai/glm-4.7", "GLM-4.7", true, 0.0, 0.0),
680 ("zai/glm-5-turbo", "GLM-5-Turbo", true, 0.0, 0.0),
681 ("zai/glm-5.1", "GLM-5.1", true, 0.0, 0.0),
682 ("zai/glm-5v-turbo", "GLM-5V-Turbo", true, 0.0, 0.0),
683 ("zai/glm-4.5-air", "GLM-4.5-Air", true, 0.0, 0.0),
684 ];
685
686 for (id, name, reasoning, input_cost, output_cost) in models {
687 map.insert(
688 id.to_string(),
689 Model {
690 id: extract_model_name(id).to_string(),
691 name: name.to_string(),
692 api: Api::OpenAiCompletions,
693 provider: "zai".to_string(),
694 base_url: "https://api.z.ai/api/coding/paas/v4".to_string(),
695 reasoning,
696 input: vec![InputModality::Text],
697 cost: Cost {
698 input: input_cost,
699 output: output_cost,
700 cache_read: 0.0,
701 cache_write: 0.0,
702 },
703 context_window: 200_000,
704 max_tokens: 131_072,
705 headers: Default::default(),
706 compat: default_compat_for_provider("zai"),
707 },
708 );
709 }
710}
711
712fn add_minimax_models(map: &mut HashMap<String, Model>) {
713 let models = [
714 ("minimax/MiniMax-M2.7", "MiniMax-M2.7", true, 0.0, 0.0),
715 (
716 "minimax/MiniMax-M2.7-highspeed",
717 "MiniMax-M2.7-highspeed",
718 true,
719 0.0,
720 0.0,
721 ),
722 ];
723
724 for (id, name, reasoning, input_cost, output_cost) in models {
725 map.insert(
726 id.to_string(),
727 Model {
728 id: extract_model_name(id).to_string(),
729 name: name.to_string(),
730 api: Api::AnthropicMessages,
731 provider: "minimax".to_string(),
732 base_url: "https://api.minimax.io".to_string(),
733 reasoning,
734 input: vec![InputModality::Text],
735 cost: Cost {
736 input: input_cost,
737 output: output_cost,
738 cache_read: 0.06,
739 cache_write: 0.375,
740 },
741 context_window: 204_800,
742 max_tokens: 131_072,
743 headers: Default::default(),
744 compat: default_compat_for_provider("minimax"),
745 },
746 );
747 }
748}
749
750#[derive(Default)]
756pub struct ModelRegistry {
757 static_models: HashMap<String, Model>,
758 dynamic_models: parking_lot::RwLock<HashMap<String, Model>>,
759}
760
761impl ModelRegistry {
762 pub fn new() -> Self {
764 Self {
765 static_models: HashMap::new(),
766 dynamic_models: RwLock::new(HashMap::new()),
767 }
768 }
769
770 pub fn from_static() -> Self {
774 Self {
775 static_models: STATIC_MODELS.clone(),
776 dynamic_models: RwLock::new(HashMap::new()),
777 }
778 }
779
780 pub fn register(&self, model: Model) {
785 let key = format!("{}/{}", model.provider, model.id);
786 self.dynamic_models.write().insert(key, model);
787 }
788
789 pub fn unregister(&self, provider: &str, model_id: &str) {
791 let key = format!("{}/{}", provider, model_id);
792 self.dynamic_models.write().remove(&key);
793 }
794
795 pub fn lookup(&self, provider: &str, model_id: &str) -> Option<Model> {
799 let key = format!("{}/{}", provider, model_id);
800 if let Some(m) = self.dynamic_models.read().get(&key) {
802 return Some(m.clone());
803 }
804 self.static_models.get(&key).cloned()
806 }
807
808 pub fn get(provider: &str, model_id: &str) -> Option<&'static Model> {
810 let key = format!("{}/{}", provider, model_id);
811 STATIC_MODELS.get(&key)
812 }
813
814 pub fn get_by_provider(provider: &str) -> Vec<&'static Model> {
816 STATIC_MODELS
817 .values()
818 .filter(|m| m.provider == provider)
819 .collect()
820 }
821
822 pub fn all() -> Vec<&'static Model> {
824 STATIC_MODELS.values().collect()
825 }
826
827 pub fn dynamic_models(&self) -> Vec<Model> {
829 self.dynamic_models.read().values().cloned().collect()
830 }
831
832 pub fn model_ids(&self) -> Vec<String> {
834 let static_ids: Vec<String> = self.static_models.keys().cloned().collect();
835 let dynamic_ids: Vec<String> = self.dynamic_models.read().keys().cloned().collect();
836 static_ids.into_iter().chain(dynamic_ids).collect()
837 }
838
839 pub fn search(pattern: &str) -> Vec<&'static Model> {
841 let pattern_lower = pattern.to_lowercase();
842 STATIC_MODELS
843 .values()
844 .filter(|m| {
845 m.id.to_lowercase().contains(&pattern_lower)
846 || m.name.to_lowercase().contains(&pattern_lower)
847 })
848 .collect()
849 }
850}
851
852static GLOBAL_REGISTRY: Lazy<ModelRegistry> = Lazy::new(ModelRegistry::from_static);
856
857pub fn register_model(model: Model) {
865 GLOBAL_REGISTRY.register(model);
866}
867
868pub fn unregister_model(provider: &str, model_id: &str) {
870 GLOBAL_REGISTRY.unregister(provider, model_id);
871}
872
873pub fn lookup_model(provider: &str, model_id: &str) -> Option<Model> {
877 GLOBAL_REGISTRY.lookup(provider, model_id)
878}
879
880pub fn get_model(provider: &str, model_id: &str) -> Option<&'static Model> {
882 ModelRegistry::get(provider, model_id)
883}
884
885pub fn get_providers() -> Vec<&'static str> {
887 let mut providers: Vec<&'static str> = STATIC_MODELS
888 .values()
889 .map(|m| m.provider.as_str())
890 .collect();
891 providers.sort();
892 providers.dedup();
893 providers
894}
895
896pub fn get_models(provider: &str) -> Vec<&'static Model> {
898 ModelRegistry::get_by_provider(provider)
899}
900
901pub fn dynamic_models() -> Vec<Model> {
903 GLOBAL_REGISTRY.dynamic_models()
904}
905
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal runtime model for registry tests.
    fn custom_model(provider: &str, id: &str) -> Model {
        Model {
            id: id.to_string(),
            name: "Custom Model".to_string(),
            api: Api::OpenAiCompletions,
            provider: provider.to_string(),
            base_url: "https://custom.example.com".to_string(),
            reasoning: false,
            input: vec![InputModality::Text],
            cost: Cost {
                input: 1.0,
                output: 2.0,
                cache_read: 0.5,
                cache_write: 5.0,
            },
            context_window: 100_000,
            max_tokens: 8192,
            headers: Default::default(),
            compat: None,
        }
    }

    #[test]
    fn test_get_model() {
        let model = get_model("openai", "gpt-4o").expect("gpt-4o is in the built-in catalogue");
        assert_eq!(model.provider, "openai");
    }

    #[test]
    fn test_get_providers() {
        let providers = get_providers();
        for expected in ["openai", "anthropic", "google", "deepseek", "mistral", "groq"] {
            assert!(providers.contains(&expected), "missing provider {expected}");
        }
    }

    #[test]
    fn test_deepseek_model() {
        let model = get_model("deepseek", "deepseek-chat")
            .expect("deepseek-chat is in the built-in catalogue");
        assert_eq!(model.provider, "deepseek");
        assert_eq!(model.base_url, "https://api.deepseek.com");
    }

    #[test]
    fn test_search_models() {
        let results = ModelRegistry::search("gpt");
        assert!(!results.is_empty());
        // `search` matches on id OR display name, so assert exactly that contract.
        assert!(results.iter().all(|m| {
            m.id.to_lowercase().contains("gpt") || m.name.to_lowercase().contains("gpt")
        }));
    }

    #[test]
    fn test_model_registry_instance() {
        let registry = ModelRegistry::from_static();
        assert!(registry.lookup("openai", "gpt-4o").is_some());
        assert!(registry.lookup("fake", "fake-model").is_none());
    }

    #[test]
    fn test_model_registry_register_dynamic() {
        let registry = ModelRegistry::new();
        registry.register(custom_model("custom", "custom-model"));
        assert!(registry.lookup("custom", "custom-model").is_some());

        registry.unregister("custom", "custom-model");
        assert!(registry.lookup("custom", "custom-model").is_none());
    }

    #[test]
    fn test_dynamic_model_shadows_static() {
        let registry = ModelRegistry::from_static();
        registry.register(custom_model("openai", "gpt-4o"));
        let found = registry
            .lookup("openai", "gpt-4o")
            .expect("registered model is visible");
        assert_eq!(found.name, "Custom Model");

        // Removing the runtime entry makes the built-in model visible again.
        registry.unregister("openai", "gpt-4o");
        let restored = registry
            .lookup("openai", "gpt-4o")
            .expect("static model is still present");
        assert_eq!(restored.name, "GPT-4o");
    }
}