1use std::collections::HashMap;
2
3use codewhale_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8 pub id: String,
9 pub provider: ProviderKind,
10 pub aliases: Vec<String>,
11 pub supports_tools: bool,
12 pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17 pub requested: Option<String>,
18 pub resolved: ModelInfo,
19 pub used_fallback: bool,
20 pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25 models: Vec<ModelInfo>,
26 alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30 fn default() -> Self {
31 let models = vec![
32 ModelInfo {
33 id: "deepseek-v4-pro".to_string(),
34 provider: ProviderKind::Deepseek,
35 aliases: vec![],
36 supports_tools: true,
37 supports_reasoning: true,
38 },
39 ModelInfo {
40 id: "deepseek-v4-flash".to_string(),
41 provider: ProviderKind::Deepseek,
42 aliases: vec![
43 "deepseek-chat".to_string(),
44 "deepseek-reasoner".to_string(),
45 "deepseek-r1".to_string(),
46 "deepseek-v3".to_string(),
47 "deepseek-v3.2".to_string(),
48 ],
49 supports_tools: true,
50 supports_reasoning: true,
51 },
52 ModelInfo {
53 id: "deepseek-ai/deepseek-v4-pro".to_string(),
54 provider: ProviderKind::NvidiaNim,
55 aliases: vec![
56 "deepseek-v4-pro".to_string(),
57 "nvidia-deepseek-v4-pro".to_string(),
58 "nim-deepseek-v4-pro".to_string(),
59 ],
60 supports_tools: true,
61 supports_reasoning: true,
62 },
63 ModelInfo {
64 id: "deepseek-ai/deepseek-v4-flash".to_string(),
65 provider: ProviderKind::NvidiaNim,
66 aliases: vec![
67 "deepseek-v4-flash".to_string(),
68 "deepseek-chat".to_string(),
69 "deepseek-reasoner".to_string(),
70 "nvidia-deepseek-v4-flash".to_string(),
71 "nim-deepseek-v4-flash".to_string(),
72 ],
73 supports_tools: true,
74 supports_reasoning: true,
75 },
76 ModelInfo {
77 id: "deepseek-v4-pro".to_string(),
78 provider: ProviderKind::Openai,
79 aliases: vec!["openai-compatible-deepseek-v4-pro".to_string()],
80 supports_tools: true,
81 supports_reasoning: true,
82 },
83 ModelInfo {
84 id: "deepseek-v4-flash".to_string(),
85 provider: ProviderKind::Openai,
86 aliases: vec!["openai-compatible-deepseek-v4-flash".to_string()],
87 supports_tools: true,
88 supports_reasoning: true,
89 },
90 ModelInfo {
91 id: "deepseek-reasoner".to_string(),
92 provider: ProviderKind::WanjieArk,
93 aliases: vec![
94 "wanjie-deepseek-reasoner".to_string(),
95 "ark-wanjie-deepseek-reasoner".to_string(),
96 ],
97 supports_tools: true,
98 supports_reasoning: true,
99 },
100 ModelInfo {
101 id: "deepseek/deepseek-v4-pro".to_string(),
102 provider: ProviderKind::Openrouter,
103 aliases: vec![
104 "deepseek-v4-pro".to_string(),
105 "openrouter-deepseek-v4-pro".to_string(),
106 ],
107 supports_tools: true,
108 supports_reasoning: true,
109 },
110 ModelInfo {
111 id: "deepseek/deepseek-v4-flash".to_string(),
112 provider: ProviderKind::Openrouter,
113 aliases: vec![
114 "deepseek-v4-flash".to_string(),
115 "deepseek-chat".to_string(),
116 "deepseek-reasoner".to_string(),
117 "openrouter-deepseek-v4-flash".to_string(),
118 ],
119 supports_tools: true,
120 supports_reasoning: true,
121 },
122 ModelInfo {
123 id: "deepseek/deepseek-v4-pro".to_string(),
124 provider: ProviderKind::Novita,
125 aliases: vec![
126 "deepseek-v4-pro".to_string(),
127 "novita-deepseek-v4-pro".to_string(),
128 ],
129 supports_tools: true,
130 supports_reasoning: true,
131 },
132 ModelInfo {
133 id: "deepseek/deepseek-v4-flash".to_string(),
134 provider: ProviderKind::Novita,
135 aliases: vec![
136 "deepseek-v4-flash".to_string(),
137 "deepseek-chat".to_string(),
138 "deepseek-reasoner".to_string(),
139 "novita-deepseek-v4-flash".to_string(),
140 ],
141 supports_tools: true,
142 supports_reasoning: true,
143 },
144 ModelInfo {
145 id: "accounts/fireworks/models/deepseek-v4-pro".to_string(),
146 provider: ProviderKind::Fireworks,
147 aliases: vec![
148 "deepseek-v4-pro".to_string(),
149 "fireworks-deepseek-v4-pro".to_string(),
150 ],
151 supports_tools: true,
152 supports_reasoning: true,
153 },
154 ModelInfo {
155 id: "kimi-k2.6".to_string(),
156 provider: ProviderKind::Moonshot,
157 aliases: vec![
158 "kimi".to_string(),
159 "kimi-k2".to_string(),
160 "moonshot-kimi-k2.6".to_string(),
161 ],
162 supports_tools: true,
163 supports_reasoning: true,
164 },
165 ModelInfo {
166 id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
167 provider: ProviderKind::Sglang,
168 aliases: vec![
169 "deepseek-v4-pro".to_string(),
170 "sglang-deepseek-v4-pro".to_string(),
171 ],
172 supports_tools: true,
173 supports_reasoning: true,
174 },
175 ModelInfo {
176 id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
177 provider: ProviderKind::Sglang,
178 aliases: vec![
179 "deepseek-v4-flash".to_string(),
180 "deepseek-chat".to_string(),
181 "deepseek-reasoner".to_string(),
182 "sglang-deepseek-v4-flash".to_string(),
183 ],
184 supports_tools: true,
185 supports_reasoning: true,
186 },
187 ModelInfo {
188 id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
189 provider: ProviderKind::Vllm,
190 aliases: vec![
191 "deepseek-v4-pro".to_string(),
192 "vllm-deepseek-v4-pro".to_string(),
193 ],
194 supports_tools: true,
195 supports_reasoning: true,
196 },
197 ModelInfo {
198 id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
199 provider: ProviderKind::Vllm,
200 aliases: vec![
201 "deepseek-v4-flash".to_string(),
202 "deepseek-chat".to_string(),
203 "deepseek-reasoner".to_string(),
204 "vllm-deepseek-v4-flash".to_string(),
205 ],
206 supports_tools: true,
207 supports_reasoning: true,
208 },
209 ModelInfo {
210 id: "deepseek-coder:1.3b".to_string(),
211 provider: ProviderKind::Ollama,
212 aliases: vec![],
213 supports_tools: true,
214 supports_reasoning: false,
215 },
216 ];
217 Self::new(models)
218 }
219}
220
221impl ModelRegistry {
222 #[must_use]
223 pub fn new(models: Vec<ModelInfo>) -> Self {
224 let mut alias_map = HashMap::new();
225 for (idx, model) in models.iter().enumerate() {
226 alias_map.entry(normalize(&model.id)).or_insert(idx);
227 for alias in &model.aliases {
228 alias_map.entry(normalize(alias)).or_insert(idx);
229 }
230 }
231 Self { models, alias_map }
232 }
233
234 #[must_use]
235 pub fn list(&self) -> Vec<ModelInfo> {
236 self.models.clone()
237 }
238
239 #[must_use]
240 pub fn resolve(
241 &self,
242 requested: Option<&str>,
243 provider_hint: Option<ProviderKind>,
244 ) -> ModelResolution {
245 let mut fallback_chain = Vec::new();
246
247 if let Some(name) = requested {
248 fallback_chain.push(format!("requested:{name}"));
249 if provider_hint == Some(ProviderKind::Ollama) {
250 return ModelResolution {
251 requested: Some(name.to_string()),
252 resolved: ModelInfo {
253 id: name.trim().to_string(),
254 provider: ProviderKind::Ollama,
255 aliases: Vec::new(),
256 supports_tools: true,
257 supports_reasoning: false,
258 },
259 used_fallback: false,
260 fallback_chain,
261 };
262 }
263 if let Some(provider) = provider_hint
264 && let Some(model) = self
265 .models
266 .iter()
267 .find(|m| m.provider == provider && model_matches(m, name))
268 .cloned()
269 {
270 return ModelResolution {
271 requested: Some(name.to_string()),
272 resolved: preserve_requested_model_id_case(model, name),
273 used_fallback: false,
274 fallback_chain,
275 };
276 }
277 if let Some(idx) = self.alias_map.get(&normalize(name)) {
278 return ModelResolution {
279 requested: Some(name.to_string()),
280 resolved: preserve_requested_model_id_case(self.models[*idx].clone(), name),
281 used_fallback: false,
282 fallback_chain,
283 };
284 }
285 }
286
287 let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
288 fallback_chain.push(format!("provider_default:{}", provider.as_str()));
289 if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
290 return ModelResolution {
291 requested: requested.map(ToOwned::to_owned),
292 resolved: model,
293 used_fallback: true,
294 fallback_chain,
295 };
296 }
297
298 let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
299 id: "deepseek-v4-pro".to_string(),
300 provider: ProviderKind::Deepseek,
301 aliases: Vec::new(),
302 supports_tools: true,
303 supports_reasoning: true,
304 });
305 fallback_chain.push("global_default:deepseek-v4-pro".to_string());
306 ModelResolution {
307 requested: requested.map(ToOwned::to_owned),
308 resolved: final_fallback,
309 used_fallback: true,
310 fallback_chain,
311 }
312 }
313}
314
/// Canonical lookup form of a model name: surrounding whitespace removed and
/// ASCII letters lowercased (non-ASCII characters are left untouched).
fn normalize(value: &str) -> String {
    let mut canonical = String::from(value.trim());
    canonical.make_ascii_lowercase();
    canonical
}
318
319fn model_matches(model: &ModelInfo, requested: &str) -> bool {
320 let requested = normalize(requested);
321 normalize(&model.id) == requested
322 || model
323 .aliases
324 .iter()
325 .any(|alias| normalize(alias) == requested)
326}
327
328fn preserve_requested_model_id_case(mut model: ModelInfo, requested: &str) -> ModelInfo {
329 let requested = requested.trim();
330 if model.id.eq_ignore_ascii_case(requested) {
331 model.id = requested.to_string();
332 }
333 model
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves against the built-in catalog, asserts the selected provider
    /// and id, and hands back the full resolution for any further checks.
    #[track_caller]
    fn assert_resolves(
        requested: Option<&str>,
        hint: Option<ProviderKind>,
        expected_provider: ProviderKind,
        expected_id: &str,
    ) -> ModelResolution {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(requested, hint);

        assert_eq!(resolved.resolved.provider, expected_provider);
        assert_eq!(resolved.resolved.id, expected_id);
        resolved
    }

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        assert_resolves(
            Some("deepseek-v4-pro"),
            None,
            ProviderKind::Deepseek,
            "deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-pro"),
            Some(ProviderKind::NvidiaNim),
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::NvidiaNim),
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::NvidiaNim),
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-flash",
        );
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Openrouter),
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn wanjie_ark_default_uses_reasoner_model_id() {
        let resolved = assert_resolves(
            None,
            Some(ProviderKind::WanjieArk),
            ProviderKind::WanjieArk,
            "deepseek-reasoner",
        );
        assert!(resolved.resolved.supports_reasoning);
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Novita),
            ProviderKind::Novita,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn fireworks_default_uses_canonical_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Fireworks),
            ProviderKind::Fireworks,
            "accounts/fireworks/models/deepseek-v4-pro",
        );
    }

    #[test]
    fn sglang_default_uses_canonical_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Sglang),
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Openrouter),
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Novita),
            ProviderKind::Novita,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_sglang_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Sglang),
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn vllm_default_uses_canonical_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Vllm),
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn ollama_default_uses_small_local_model_id() {
        let resolved = assert_resolves(
            None,
            Some(ProviderKind::Ollama),
            ProviderKind::Ollama,
            "deepseek-coder:1.3b",
        );
        assert!(!resolved.resolved.supports_reasoning);
    }

    #[test]
    fn ollama_requested_model_tag_is_preserved() {
        let resolved = assert_resolves(
            Some("qwen2.5-coder:7b"),
            Some(ProviderKind::Ollama),
            ProviderKind::Ollama,
            "qwen2.5-coder:7b",
        );
        assert!(!resolved.used_fallback);
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_vllm_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Vllm),
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn preserves_requested_model_casing_for_third_party_providers() {
        assert_resolves(
            Some("DeepSeek-V4-Pro"),
            None,
            ProviderKind::Deepseek,
            "DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn preserves_requested_model_casing_with_provider_hint() {
        assert_resolves(
            Some("DeepSeek-V4-Pro"),
            Some(ProviderKind::Deepseek),
            ProviderKind::Deepseek,
            "DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn preserves_requested_model_casing_without_surrounding_whitespace() {
        // The surrounding spaces must not survive into the resolved id.
        assert_resolves(
            Some(" DeepSeek-V4-Pro "),
            None,
            ProviderKind::Deepseek,
            "DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn alias_match_does_not_override_requested_casing() {
        // `deepseek-reasoner` is an alias here, so the catalog id is returned.
        assert_resolves(
            Some("deepseek-reasoner"),
            None,
            ProviderKind::Deepseek,
            "deepseek-v4-flash",
        );
    }
}