Skip to main content

codewhale_agent/
lib.rs

1use std::collections::HashMap;
2
3use codewhale_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8    pub id: String,
9    pub provider: ProviderKind,
10    pub aliases: Vec<String>,
11    pub supports_tools: bool,
12    pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17    pub requested: Option<String>,
18    pub resolved: ModelInfo,
19    pub used_fallback: bool,
20    pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25    models: Vec<ModelInfo>,
26    alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30    fn default() -> Self {
31        let models = vec![
32            ModelInfo {
33                id: "deepseek-v4-pro".to_string(),
34                provider: ProviderKind::Deepseek,
35                aliases: vec![],
36                supports_tools: true,
37                supports_reasoning: true,
38            },
39            ModelInfo {
40                id: "deepseek-v4-flash".to_string(),
41                provider: ProviderKind::Deepseek,
42                aliases: vec![
43                    "deepseek-chat".to_string(),
44                    "deepseek-reasoner".to_string(),
45                    "deepseek-r1".to_string(),
46                    "deepseek-v3".to_string(),
47                    "deepseek-v3.2".to_string(),
48                ],
49                supports_tools: true,
50                supports_reasoning: true,
51            },
52            ModelInfo {
53                id: "deepseek-ai/deepseek-v4-pro".to_string(),
54                provider: ProviderKind::NvidiaNim,
55                aliases: vec![
56                    "deepseek-v4-pro".to_string(),
57                    "nvidia-deepseek-v4-pro".to_string(),
58                    "nim-deepseek-v4-pro".to_string(),
59                ],
60                supports_tools: true,
61                supports_reasoning: true,
62            },
63            ModelInfo {
64                id: "deepseek-ai/deepseek-v4-flash".to_string(),
65                provider: ProviderKind::NvidiaNim,
66                aliases: vec![
67                    "deepseek-v4-flash".to_string(),
68                    "deepseek-chat".to_string(),
69                    "deepseek-reasoner".to_string(),
70                    "nvidia-deepseek-v4-flash".to_string(),
71                    "nim-deepseek-v4-flash".to_string(),
72                ],
73                supports_tools: true,
74                supports_reasoning: true,
75            },
76            ModelInfo {
77                id: "deepseek-v4-pro".to_string(),
78                provider: ProviderKind::Openai,
79                aliases: vec!["openai-compatible-deepseek-v4-pro".to_string()],
80                supports_tools: true,
81                supports_reasoning: true,
82            },
83            ModelInfo {
84                id: "deepseek-v4-flash".to_string(),
85                provider: ProviderKind::Openai,
86                aliases: vec!["openai-compatible-deepseek-v4-flash".to_string()],
87                supports_tools: true,
88                supports_reasoning: true,
89            },
90            ModelInfo {
91                id: "deepseek-reasoner".to_string(),
92                provider: ProviderKind::WanjieArk,
93                aliases: vec![
94                    "wanjie-deepseek-reasoner".to_string(),
95                    "ark-wanjie-deepseek-reasoner".to_string(),
96                ],
97                supports_tools: true,
98                supports_reasoning: true,
99            },
100            ModelInfo {
101                id: "deepseek/deepseek-v4-pro".to_string(),
102                provider: ProviderKind::Openrouter,
103                aliases: vec![
104                    "deepseek-v4-pro".to_string(),
105                    "openrouter-deepseek-v4-pro".to_string(),
106                ],
107                supports_tools: true,
108                supports_reasoning: true,
109            },
110            ModelInfo {
111                id: "deepseek/deepseek-v4-flash".to_string(),
112                provider: ProviderKind::Openrouter,
113                aliases: vec![
114                    "deepseek-v4-flash".to_string(),
115                    "deepseek-chat".to_string(),
116                    "deepseek-reasoner".to_string(),
117                    "openrouter-deepseek-v4-flash".to_string(),
118                ],
119                supports_tools: true,
120                supports_reasoning: true,
121            },
122            ModelInfo {
123                id: "deepseek/deepseek-v4-pro".to_string(),
124                provider: ProviderKind::Novita,
125                aliases: vec![
126                    "deepseek-v4-pro".to_string(),
127                    "novita-deepseek-v4-pro".to_string(),
128                ],
129                supports_tools: true,
130                supports_reasoning: true,
131            },
132            ModelInfo {
133                id: "deepseek/deepseek-v4-flash".to_string(),
134                provider: ProviderKind::Novita,
135                aliases: vec![
136                    "deepseek-v4-flash".to_string(),
137                    "deepseek-chat".to_string(),
138                    "deepseek-reasoner".to_string(),
139                    "novita-deepseek-v4-flash".to_string(),
140                ],
141                supports_tools: true,
142                supports_reasoning: true,
143            },
144            ModelInfo {
145                id: "accounts/fireworks/models/deepseek-v4-pro".to_string(),
146                provider: ProviderKind::Fireworks,
147                aliases: vec![
148                    "deepseek-v4-pro".to_string(),
149                    "fireworks-deepseek-v4-pro".to_string(),
150                ],
151                supports_tools: true,
152                supports_reasoning: true,
153            },
154            ModelInfo {
155                id: "kimi-k2.6".to_string(),
156                provider: ProviderKind::Moonshot,
157                aliases: vec![
158                    "kimi".to_string(),
159                    "kimi-k2".to_string(),
160                    "moonshot-kimi-k2.6".to_string(),
161                ],
162                supports_tools: true,
163                supports_reasoning: true,
164            },
165            ModelInfo {
166                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
167                provider: ProviderKind::Sglang,
168                aliases: vec![
169                    "deepseek-v4-pro".to_string(),
170                    "sglang-deepseek-v4-pro".to_string(),
171                ],
172                supports_tools: true,
173                supports_reasoning: true,
174            },
175            ModelInfo {
176                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
177                provider: ProviderKind::Sglang,
178                aliases: vec![
179                    "deepseek-v4-flash".to_string(),
180                    "deepseek-chat".to_string(),
181                    "deepseek-reasoner".to_string(),
182                    "sglang-deepseek-v4-flash".to_string(),
183                ],
184                supports_tools: true,
185                supports_reasoning: true,
186            },
187            ModelInfo {
188                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
189                provider: ProviderKind::Vllm,
190                aliases: vec![
191                    "deepseek-v4-pro".to_string(),
192                    "vllm-deepseek-v4-pro".to_string(),
193                ],
194                supports_tools: true,
195                supports_reasoning: true,
196            },
197            ModelInfo {
198                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
199                provider: ProviderKind::Vllm,
200                aliases: vec![
201                    "deepseek-v4-flash".to_string(),
202                    "deepseek-chat".to_string(),
203                    "deepseek-reasoner".to_string(),
204                    "vllm-deepseek-v4-flash".to_string(),
205                ],
206                supports_tools: true,
207                supports_reasoning: true,
208            },
209            ModelInfo {
210                id: "deepseek-coder:1.3b".to_string(),
211                provider: ProviderKind::Ollama,
212                aliases: vec![],
213                supports_tools: true,
214                supports_reasoning: false,
215            },
216            ModelInfo {
217                id: "mimo-v2.5-pro".to_string(),
218                provider: ProviderKind::Xiaomi,
219                aliases: vec!["mimo-pro".to_string(), "mimo-v2-pro".to_string()],
220                supports_tools: true,
221                supports_reasoning: true,
222            },
223            ModelInfo {
224                id: "mimo-v2.5".to_string(),
225                provider: ProviderKind::Xiaomi,
226                aliases: vec!["mimo-omni".to_string(), "mimo-v2-omni".to_string()],
227                supports_tools: true,
228                supports_reasoning: true,
229            },
230        ];
231        Self::new(models)
232    }
233}
234
235impl ModelRegistry {
236    #[must_use]
237    pub fn new(models: Vec<ModelInfo>) -> Self {
238        let mut alias_map = HashMap::new();
239        for (idx, model) in models.iter().enumerate() {
240            alias_map.entry(normalize(&model.id)).or_insert(idx);
241            for alias in &model.aliases {
242                alias_map.entry(normalize(alias)).or_insert(idx);
243            }
244        }
245        Self { models, alias_map }
246    }
247
248    #[must_use]
249    pub fn list(&self) -> Vec<ModelInfo> {
250        self.models.clone()
251    }
252
253    #[must_use]
254    pub fn resolve(
255        &self,
256        requested: Option<&str>,
257        provider_hint: Option<ProviderKind>,
258    ) -> ModelResolution {
259        let mut fallback_chain = Vec::new();
260
261        if let Some(name) = requested {
262            fallback_chain.push(format!("requested:{name}"));
263            if provider_hint == Some(ProviderKind::Ollama) {
264                return ModelResolution {
265                    requested: Some(name.to_string()),
266                    resolved: ModelInfo {
267                        id: name.trim().to_string(),
268                        provider: ProviderKind::Ollama,
269                        aliases: Vec::new(),
270                        supports_tools: true,
271                        supports_reasoning: false,
272                    },
273                    used_fallback: false,
274                    fallback_chain,
275                };
276            }
277            if let Some(provider) = provider_hint
278                && let Some(model) = self
279                    .models
280                    .iter()
281                    .find(|m| m.provider == provider && model_matches(m, name))
282                    .cloned()
283            {
284                return ModelResolution {
285                    requested: Some(name.to_string()),
286                    resolved: preserve_requested_model_id_case(model, name),
287                    used_fallback: false,
288                    fallback_chain,
289                };
290            }
291            if let Some(idx) = self.alias_map.get(&normalize(name)) {
292                return ModelResolution {
293                    requested: Some(name.to_string()),
294                    resolved: preserve_requested_model_id_case(self.models[*idx].clone(), name),
295                    used_fallback: false,
296                    fallback_chain,
297                };
298            }
299        }
300
301        let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
302        fallback_chain.push(format!("provider_default:{}", provider.as_str()));
303        if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
304            return ModelResolution {
305                requested: requested.map(ToOwned::to_owned),
306                resolved: model,
307                used_fallback: true,
308                fallback_chain,
309            };
310        }
311
312        let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
313            id: "deepseek-v4-pro".to_string(),
314            provider: ProviderKind::Deepseek,
315            aliases: Vec::new(),
316            supports_tools: true,
317            supports_reasoning: true,
318        });
319        fallback_chain.push("global_default:deepseek-v4-pro".to_string());
320        ModelResolution {
321            requested: requested.map(ToOwned::to_owned),
322            resolved: final_fallback,
323            used_fallback: true,
324            fallback_chain,
325        }
326    }
327}
328
/// Produces the lookup key for a model name: surrounding whitespace is
/// removed and ASCII letters are lowercased. Non-ASCII characters are kept
/// unchanged.
fn normalize(value: &str) -> String {
    let mut key = value.trim().to_owned();
    key.make_ascii_lowercase();
    key
}
332
333fn model_matches(model: &ModelInfo, requested: &str) -> bool {
334    let requested = normalize(requested);
335    normalize(&model.id) == requested
336        || model
337            .aliases
338            .iter()
339            .any(|alias| normalize(alias) == requested)
340}
341
342fn preserve_requested_model_id_case(mut model: ModelInfo, requested: &str) -> ModelInfo {
343    let requested = requested.trim();
344    if model.id.eq_ignore_ascii_case(requested) {
345        model.id = requested.to_string();
346    }
347    model
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves a request against a freshly built default registry.
    fn resolve(requested: Option<&str>, provider_hint: Option<ProviderKind>) -> ModelResolution {
        ModelRegistry::default().resolve(requested, provider_hint)
    }

    /// Checks the provider and id of the resolved model. `track_caller`
    /// makes a failure point at the calling test rather than at this helper.
    #[track_caller]
    fn assert_resolved(resolution: &ModelResolution, provider: ProviderKind, id: &str) {
        assert_eq!(resolution.resolved.provider, provider);
        assert_eq!(resolution.resolved.id, id);
    }

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let outcome = resolve(Some("deepseek-v4-pro"), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let outcome = resolve(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim));
        assert_resolved(
            &outcome,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let outcome = resolve(None, Some(ProviderKind::NvidiaNim));
        assert_resolved(
            &outcome,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let outcome = resolve(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim));
        assert_resolved(
            &outcome,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-flash",
        );
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        let outcome = resolve(None, Some(ProviderKind::Openrouter));
        assert_resolved(
            &outcome,
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn wanjie_ark_default_uses_reasoner_model_id() {
        let outcome = resolve(None, Some(ProviderKind::WanjieArk));
        assert_resolved(&outcome, ProviderKind::WanjieArk, "deepseek-reasoner");
        assert!(outcome.resolved.supports_reasoning);
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        let outcome = resolve(None, Some(ProviderKind::Novita));
        assert_resolved(&outcome, ProviderKind::Novita, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn fireworks_default_uses_canonical_model_id() {
        let outcome = resolve(None, Some(ProviderKind::Fireworks));
        assert_resolved(
            &outcome,
            ProviderKind::Fireworks,
            "accounts/fireworks/models/deepseek-v4-pro",
        );
    }

    #[test]
    fn sglang_default_uses_canonical_model_id() {
        let outcome = resolve(None, Some(ProviderKind::Sglang));
        assert_resolved(
            &outcome,
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        let outcome = resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Openrouter));
        assert_resolved(
            &outcome,
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        let outcome = resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Novita));
        assert_resolved(
            &outcome,
            ProviderKind::Novita,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_sglang_when_provider_hinted() {
        let outcome = resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Sglang));
        assert_resolved(
            &outcome,
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn vllm_default_uses_canonical_model_id() {
        let outcome = resolve(None, Some(ProviderKind::Vllm));
        assert_resolved(&outcome, ProviderKind::Vllm, "deepseek-ai/DeepSeek-V4-Pro");
    }

    #[test]
    fn ollama_default_uses_small_local_model_id() {
        let outcome = resolve(None, Some(ProviderKind::Ollama));
        assert_resolved(&outcome, ProviderKind::Ollama, "deepseek-coder:1.3b");
        assert!(!outcome.resolved.supports_reasoning);
    }

    #[test]
    fn ollama_requested_model_tag_is_preserved() {
        let outcome = resolve(Some("qwen2.5-coder:7b"), Some(ProviderKind::Ollama));
        assert_resolved(&outcome, ProviderKind::Ollama, "qwen2.5-coder:7b");
        assert!(!outcome.used_fallback);
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_vllm_when_provider_hinted() {
        let outcome = resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Vllm));
        assert_resolved(
            &outcome,
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn preserves_requested_model_casing_for_third_party_providers() {
        let outcome = resolve(Some("DeepSeek-V4-Pro"), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_with_provider_hint() {
        let outcome = resolve(Some("DeepSeek-V4-Pro"), Some(ProviderKind::Deepseek));
        assert_resolved(&outcome, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_without_surrounding_whitespace() {
        let outcome = resolve(Some("  DeepSeek-V4-Pro  "), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn alias_match_does_not_override_requested_casing() {
        // The request matches through an alias, so the catalog id is kept.
        let outcome = resolve(Some("deepseek-reasoner"), None);
        assert_resolved(&outcome, ProviderKind::Deepseek, "deepseek-v4-flash");
    }
}