1use std::collections::HashMap;
2
3use codewhale_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8 pub id: String,
9 pub provider: ProviderKind,
10 pub aliases: Vec<String>,
11 pub supports_tools: bool,
12 pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17 pub requested: Option<String>,
18 pub resolved: ModelInfo,
19 pub used_fallback: bool,
20 pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25 models: Vec<ModelInfo>,
26 alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30 fn default() -> Self {
31 let models = vec![
32 ModelInfo {
33 id: "deepseek-v4-pro".to_string(),
34 provider: ProviderKind::Deepseek,
35 aliases: vec![],
36 supports_tools: true,
37 supports_reasoning: true,
38 },
39 ModelInfo {
40 id: "deepseek-v4-flash".to_string(),
41 provider: ProviderKind::Deepseek,
42 aliases: vec![
43 "deepseek-chat".to_string(),
44 "deepseek-reasoner".to_string(),
45 "deepseek-r1".to_string(),
46 "deepseek-v3".to_string(),
47 "deepseek-v3.2".to_string(),
48 ],
49 supports_tools: true,
50 supports_reasoning: true,
51 },
52 ModelInfo {
53 id: "deepseek-ai/deepseek-v4-pro".to_string(),
54 provider: ProviderKind::NvidiaNim,
55 aliases: vec![
56 "deepseek-v4-pro".to_string(),
57 "nvidia-deepseek-v4-pro".to_string(),
58 "nim-deepseek-v4-pro".to_string(),
59 ],
60 supports_tools: true,
61 supports_reasoning: true,
62 },
63 ModelInfo {
64 id: "deepseek-ai/deepseek-v4-flash".to_string(),
65 provider: ProviderKind::NvidiaNim,
66 aliases: vec![
67 "deepseek-v4-flash".to_string(),
68 "deepseek-chat".to_string(),
69 "deepseek-reasoner".to_string(),
70 "nvidia-deepseek-v4-flash".to_string(),
71 "nim-deepseek-v4-flash".to_string(),
72 ],
73 supports_tools: true,
74 supports_reasoning: true,
75 },
76 ModelInfo {
77 id: "deepseek-v4-pro".to_string(),
78 provider: ProviderKind::Openai,
79 aliases: vec!["openai-compatible-deepseek-v4-pro".to_string()],
80 supports_tools: true,
81 supports_reasoning: true,
82 },
83 ModelInfo {
84 id: "deepseek-v4-flash".to_string(),
85 provider: ProviderKind::Openai,
86 aliases: vec!["openai-compatible-deepseek-v4-flash".to_string()],
87 supports_tools: true,
88 supports_reasoning: true,
89 },
90 ModelInfo {
91 id: "deepseek-reasoner".to_string(),
92 provider: ProviderKind::WanjieArk,
93 aliases: vec![
94 "wanjie-deepseek-reasoner".to_string(),
95 "ark-wanjie-deepseek-reasoner".to_string(),
96 ],
97 supports_tools: true,
98 supports_reasoning: true,
99 },
100 ModelInfo {
101 id: "deepseek/deepseek-v4-pro".to_string(),
102 provider: ProviderKind::Openrouter,
103 aliases: vec![
104 "deepseek-v4-pro".to_string(),
105 "openrouter-deepseek-v4-pro".to_string(),
106 ],
107 supports_tools: true,
108 supports_reasoning: true,
109 },
110 ModelInfo {
111 id: "deepseek/deepseek-v4-flash".to_string(),
112 provider: ProviderKind::Openrouter,
113 aliases: vec![
114 "deepseek-v4-flash".to_string(),
115 "deepseek-chat".to_string(),
116 "deepseek-reasoner".to_string(),
117 "openrouter-deepseek-v4-flash".to_string(),
118 ],
119 supports_tools: true,
120 supports_reasoning: true,
121 },
122 ModelInfo {
123 id: "deepseek/deepseek-v4-pro".to_string(),
124 provider: ProviderKind::Novita,
125 aliases: vec![
126 "deepseek-v4-pro".to_string(),
127 "novita-deepseek-v4-pro".to_string(),
128 ],
129 supports_tools: true,
130 supports_reasoning: true,
131 },
132 ModelInfo {
133 id: "deepseek/deepseek-v4-flash".to_string(),
134 provider: ProviderKind::Novita,
135 aliases: vec![
136 "deepseek-v4-flash".to_string(),
137 "deepseek-chat".to_string(),
138 "deepseek-reasoner".to_string(),
139 "novita-deepseek-v4-flash".to_string(),
140 ],
141 supports_tools: true,
142 supports_reasoning: true,
143 },
144 ModelInfo {
145 id: "accounts/fireworks/models/deepseek-v4-pro".to_string(),
146 provider: ProviderKind::Fireworks,
147 aliases: vec![
148 "deepseek-v4-pro".to_string(),
149 "fireworks-deepseek-v4-pro".to_string(),
150 ],
151 supports_tools: true,
152 supports_reasoning: true,
153 },
154 ModelInfo {
155 id: "kimi-k2.6".to_string(),
156 provider: ProviderKind::Moonshot,
157 aliases: vec![
158 "kimi".to_string(),
159 "kimi-k2".to_string(),
160 "moonshot-kimi-k2.6".to_string(),
161 ],
162 supports_tools: true,
163 supports_reasoning: true,
164 },
165 ModelInfo {
166 id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
167 provider: ProviderKind::Sglang,
168 aliases: vec![
169 "deepseek-v4-pro".to_string(),
170 "sglang-deepseek-v4-pro".to_string(),
171 ],
172 supports_tools: true,
173 supports_reasoning: true,
174 },
175 ModelInfo {
176 id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
177 provider: ProviderKind::Sglang,
178 aliases: vec![
179 "deepseek-v4-flash".to_string(),
180 "deepseek-chat".to_string(),
181 "deepseek-reasoner".to_string(),
182 "sglang-deepseek-v4-flash".to_string(),
183 ],
184 supports_tools: true,
185 supports_reasoning: true,
186 },
187 ModelInfo {
188 id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
189 provider: ProviderKind::Vllm,
190 aliases: vec![
191 "deepseek-v4-pro".to_string(),
192 "vllm-deepseek-v4-pro".to_string(),
193 ],
194 supports_tools: true,
195 supports_reasoning: true,
196 },
197 ModelInfo {
198 id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
199 provider: ProviderKind::Vllm,
200 aliases: vec![
201 "deepseek-v4-flash".to_string(),
202 "deepseek-chat".to_string(),
203 "deepseek-reasoner".to_string(),
204 "vllm-deepseek-v4-flash".to_string(),
205 ],
206 supports_tools: true,
207 supports_reasoning: true,
208 },
209 ModelInfo {
210 id: "deepseek-coder:1.3b".to_string(),
211 provider: ProviderKind::Ollama,
212 aliases: vec![],
213 supports_tools: true,
214 supports_reasoning: false,
215 },
216 ModelInfo {
217 id: "mimo-v2.5-pro".to_string(),
218 provider: ProviderKind::Xiaomi,
219 aliases: vec!["mimo-pro".to_string(), "mimo-v2-pro".to_string()],
220 supports_tools: true,
221 supports_reasoning: true,
222 },
223 ModelInfo {
224 id: "mimo-v2.5".to_string(),
225 provider: ProviderKind::Xiaomi,
226 aliases: vec!["mimo-omni".to_string(), "mimo-v2-omni".to_string()],
227 supports_tools: true,
228 supports_reasoning: true,
229 },
230 ];
231 Self::new(models)
232 }
233}
234
235impl ModelRegistry {
236 #[must_use]
237 pub fn new(models: Vec<ModelInfo>) -> Self {
238 let mut alias_map = HashMap::new();
239 for (idx, model) in models.iter().enumerate() {
240 alias_map.entry(normalize(&model.id)).or_insert(idx);
241 for alias in &model.aliases {
242 alias_map.entry(normalize(alias)).or_insert(idx);
243 }
244 }
245 Self { models, alias_map }
246 }
247
248 #[must_use]
249 pub fn list(&self) -> Vec<ModelInfo> {
250 self.models.clone()
251 }
252
253 #[must_use]
254 pub fn resolve(
255 &self,
256 requested: Option<&str>,
257 provider_hint: Option<ProviderKind>,
258 ) -> ModelResolution {
259 let mut fallback_chain = Vec::new();
260
261 if let Some(name) = requested {
262 fallback_chain.push(format!("requested:{name}"));
263 if provider_hint == Some(ProviderKind::Ollama) {
264 return ModelResolution {
265 requested: Some(name.to_string()),
266 resolved: ModelInfo {
267 id: name.trim().to_string(),
268 provider: ProviderKind::Ollama,
269 aliases: Vec::new(),
270 supports_tools: true,
271 supports_reasoning: false,
272 },
273 used_fallback: false,
274 fallback_chain,
275 };
276 }
277 if let Some(provider) = provider_hint
278 && let Some(model) = self
279 .models
280 .iter()
281 .find(|m| m.provider == provider && model_matches(m, name))
282 .cloned()
283 {
284 return ModelResolution {
285 requested: Some(name.to_string()),
286 resolved: preserve_requested_model_id_case(model, name),
287 used_fallback: false,
288 fallback_chain,
289 };
290 }
291 if let Some(idx) = self.alias_map.get(&normalize(name)) {
292 return ModelResolution {
293 requested: Some(name.to_string()),
294 resolved: preserve_requested_model_id_case(self.models[*idx].clone(), name),
295 used_fallback: false,
296 fallback_chain,
297 };
298 }
299 }
300
301 let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
302 fallback_chain.push(format!("provider_default:{}", provider.as_str()));
303 if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
304 return ModelResolution {
305 requested: requested.map(ToOwned::to_owned),
306 resolved: model,
307 used_fallback: true,
308 fallback_chain,
309 };
310 }
311
312 let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
313 id: "deepseek-v4-pro".to_string(),
314 provider: ProviderKind::Deepseek,
315 aliases: Vec::new(),
316 supports_tools: true,
317 supports_reasoning: true,
318 });
319 fallback_chain.push("global_default:deepseek-v4-pro".to_string());
320 ModelResolution {
321 requested: requested.map(ToOwned::to_owned),
322 resolved: final_fallback,
323 used_fallback: true,
324 fallback_chain,
325 }
326 }
327}
328
/// Canonical form used for name comparison: surrounding whitespace removed
/// and ASCII letters lowercased (non-ASCII characters are left untouched).
fn normalize(value: &str) -> String {
    let mut canonical = value.trim().to_owned();
    canonical.make_ascii_lowercase();
    canonical
}
332
333fn model_matches(model: &ModelInfo, requested: &str) -> bool {
334 let requested = normalize(requested);
335 normalize(&model.id) == requested
336 || model
337 .aliases
338 .iter()
339 .any(|alias| normalize(alias) == requested)
340}
341
342fn preserve_requested_model_id_case(mut model: ModelInfo, requested: &str) -> ModelInfo {
343 let requested = requested.trim();
344 if model.id.eq_ignore_ascii_case(requested) {
345 model.id = requested.to_string();
346 }
347 model
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves `requested` / `hint` against the built-in catalog.
    fn resolve_builtin(requested: Option<&str>, hint: Option<ProviderKind>) -> ModelResolution {
        ModelRegistry::default().resolve(requested, hint)
    }

    /// Asserts the provider and id of the resolved model.
    #[track_caller]
    fn assert_resolved(outcome: &ModelResolution, provider: ProviderKind, id: &str) {
        assert_eq!(outcome.resolved.provider, provider);
        assert_eq!(outcome.resolved.id, id);
    }

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let outcome = resolve_builtin(Some("deepseek-v4-pro"), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let outcome = resolve_builtin(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim));
        assert_resolved(
            &outcome,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::NvidiaNim));
        assert_resolved(
            &outcome,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let outcome = resolve_builtin(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim));
        assert_resolved(
            &outcome,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-flash",
        );
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::Openrouter));
        assert_resolved(
            &outcome,
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn wanjie_ark_default_uses_reasoner_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::WanjieArk));
        assert_resolved(&outcome, ProviderKind::WanjieArk, "deepseek-reasoner");
        assert!(outcome.resolved.supports_reasoning);
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::Novita));
        assert_resolved(&outcome, ProviderKind::Novita, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn fireworks_default_uses_canonical_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::Fireworks));
        assert_resolved(
            &outcome,
            ProviderKind::Fireworks,
            "accounts/fireworks/models/deepseek-v4-pro",
        );
    }

    #[test]
    fn sglang_default_uses_canonical_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::Sglang));
        assert_resolved(
            &outcome,
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        let outcome = resolve_builtin(Some("deepseek-v4-flash"), Some(ProviderKind::Openrouter));
        assert_resolved(
            &outcome,
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        let outcome = resolve_builtin(Some("deepseek-v4-flash"), Some(ProviderKind::Novita));
        assert_resolved(
            &outcome,
            ProviderKind::Novita,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_sglang_when_provider_hinted() {
        let outcome = resolve_builtin(Some("deepseek-v4-flash"), Some(ProviderKind::Sglang));
        assert_resolved(
            &outcome,
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn vllm_default_uses_canonical_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::Vllm));
        assert_resolved(&outcome, ProviderKind::Vllm, "deepseek-ai/DeepSeek-V4-Pro");
    }

    #[test]
    fn ollama_default_uses_small_local_model_id() {
        let outcome = resolve_builtin(None, Some(ProviderKind::Ollama));
        assert_resolved(&outcome, ProviderKind::Ollama, "deepseek-coder:1.3b");
        assert!(!outcome.resolved.supports_reasoning);
    }

    #[test]
    fn ollama_requested_model_tag_is_preserved() {
        let outcome = resolve_builtin(Some("qwen2.5-coder:7b"), Some(ProviderKind::Ollama));
        assert_resolved(&outcome, ProviderKind::Ollama, "qwen2.5-coder:7b");
        assert!(!outcome.used_fallback);
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_vllm_when_provider_hinted() {
        let outcome = resolve_builtin(Some("deepseek-v4-flash"), Some(ProviderKind::Vllm));
        assert_resolved(
            &outcome,
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn preserves_requested_model_casing_for_third_party_providers() {
        let outcome = resolve_builtin(Some("DeepSeek-V4-Pro"), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_with_provider_hint() {
        let outcome = resolve_builtin(Some("DeepSeek-V4-Pro"), Some(ProviderKind::Deepseek));
        assert_resolved(&outcome, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_without_surrounding_whitespace() {
        let outcome = resolve_builtin(Some(" DeepSeek-V4-Pro "), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn alias_match_does_not_override_requested_casing() {
        // An alias hit must report the catalog id, not the alias text.
        let outcome = resolve_builtin(Some("deepseek-reasoner"), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "deepseek-v4-flash");
    }
}