// codewhale_agent/lib.rs

1use std::collections::HashMap;
2
3use codewhale_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8    pub id: String,
9    pub provider: ProviderKind,
10    pub aliases: Vec<String>,
11    pub supports_tools: bool,
12    pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17    pub requested: Option<String>,
18    pub resolved: ModelInfo,
19    pub used_fallback: bool,
20    pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25    models: Vec<ModelInfo>,
26    alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30    fn default() -> Self {
31        let models = vec![
32            ModelInfo {
33                id: "deepseek-v4-pro".to_string(),
34                provider: ProviderKind::Deepseek,
35                aliases: vec![],
36                supports_tools: true,
37                supports_reasoning: true,
38            },
39            ModelInfo {
40                id: "deepseek-v4-flash".to_string(),
41                provider: ProviderKind::Deepseek,
42                aliases: vec![
43                    "deepseek-chat".to_string(),
44                    "deepseek-reasoner".to_string(),
45                    "deepseek-r1".to_string(),
46                    "deepseek-v3".to_string(),
47                    "deepseek-v3.2".to_string(),
48                ],
49                supports_tools: true,
50                supports_reasoning: true,
51            },
52            ModelInfo {
53                id: "deepseek-ai/deepseek-v4-pro".to_string(),
54                provider: ProviderKind::NvidiaNim,
55                aliases: vec![
56                    "deepseek-v4-pro".to_string(),
57                    "nvidia-deepseek-v4-pro".to_string(),
58                    "nim-deepseek-v4-pro".to_string(),
59                ],
60                supports_tools: true,
61                supports_reasoning: true,
62            },
63            ModelInfo {
64                id: "deepseek-ai/deepseek-v4-flash".to_string(),
65                provider: ProviderKind::NvidiaNim,
66                aliases: vec![
67                    "deepseek-v4-flash".to_string(),
68                    "deepseek-chat".to_string(),
69                    "deepseek-reasoner".to_string(),
70                    "nvidia-deepseek-v4-flash".to_string(),
71                    "nim-deepseek-v4-flash".to_string(),
72                ],
73                supports_tools: true,
74                supports_reasoning: true,
75            },
76            ModelInfo {
77                id: "deepseek-v4-pro".to_string(),
78                provider: ProviderKind::Openai,
79                aliases: vec!["openai-compatible-deepseek-v4-pro".to_string()],
80                supports_tools: true,
81                supports_reasoning: true,
82            },
83            ModelInfo {
84                id: "deepseek-v4-flash".to_string(),
85                provider: ProviderKind::Openai,
86                aliases: vec!["openai-compatible-deepseek-v4-flash".to_string()],
87                supports_tools: true,
88                supports_reasoning: true,
89            },
90            ModelInfo {
91                id: "deepseek-reasoner".to_string(),
92                provider: ProviderKind::WanjieArk,
93                aliases: vec![
94                    "wanjie-deepseek-reasoner".to_string(),
95                    "ark-wanjie-deepseek-reasoner".to_string(),
96                ],
97                supports_tools: true,
98                supports_reasoning: true,
99            },
100            ModelInfo {
101                id: "deepseek/deepseek-v4-pro".to_string(),
102                provider: ProviderKind::Openrouter,
103                aliases: vec![
104                    "deepseek-v4-pro".to_string(),
105                    "openrouter-deepseek-v4-pro".to_string(),
106                ],
107                supports_tools: true,
108                supports_reasoning: true,
109            },
110            ModelInfo {
111                id: "deepseek/deepseek-v4-flash".to_string(),
112                provider: ProviderKind::Openrouter,
113                aliases: vec![
114                    "deepseek-v4-flash".to_string(),
115                    "deepseek-chat".to_string(),
116                    "deepseek-reasoner".to_string(),
117                    "openrouter-deepseek-v4-flash".to_string(),
118                ],
119                supports_tools: true,
120                supports_reasoning: true,
121            },
122            ModelInfo {
123                id: "deepseek/deepseek-v4-pro".to_string(),
124                provider: ProviderKind::Novita,
125                aliases: vec![
126                    "deepseek-v4-pro".to_string(),
127                    "novita-deepseek-v4-pro".to_string(),
128                ],
129                supports_tools: true,
130                supports_reasoning: true,
131            },
132            ModelInfo {
133                id: "deepseek/deepseek-v4-flash".to_string(),
134                provider: ProviderKind::Novita,
135                aliases: vec![
136                    "deepseek-v4-flash".to_string(),
137                    "deepseek-chat".to_string(),
138                    "deepseek-reasoner".to_string(),
139                    "novita-deepseek-v4-flash".to_string(),
140                ],
141                supports_tools: true,
142                supports_reasoning: true,
143            },
144            ModelInfo {
145                id: "accounts/fireworks/models/deepseek-v4-pro".to_string(),
146                provider: ProviderKind::Fireworks,
147                aliases: vec![
148                    "deepseek-v4-pro".to_string(),
149                    "fireworks-deepseek-v4-pro".to_string(),
150                ],
151                supports_tools: true,
152                supports_reasoning: true,
153            },
154            ModelInfo {
155                id: "kimi-k2.6".to_string(),
156                provider: ProviderKind::Moonshot,
157                aliases: vec![
158                    "kimi".to_string(),
159                    "kimi-k2".to_string(),
160                    "moonshot-kimi-k2.6".to_string(),
161                ],
162                supports_tools: true,
163                supports_reasoning: true,
164            },
165            ModelInfo {
166                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
167                provider: ProviderKind::Sglang,
168                aliases: vec![
169                    "deepseek-v4-pro".to_string(),
170                    "sglang-deepseek-v4-pro".to_string(),
171                ],
172                supports_tools: true,
173                supports_reasoning: true,
174            },
175            ModelInfo {
176                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
177                provider: ProviderKind::Sglang,
178                aliases: vec![
179                    "deepseek-v4-flash".to_string(),
180                    "deepseek-chat".to_string(),
181                    "deepseek-reasoner".to_string(),
182                    "sglang-deepseek-v4-flash".to_string(),
183                ],
184                supports_tools: true,
185                supports_reasoning: true,
186            },
187            ModelInfo {
188                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
189                provider: ProviderKind::Vllm,
190                aliases: vec![
191                    "deepseek-v4-pro".to_string(),
192                    "vllm-deepseek-v4-pro".to_string(),
193                ],
194                supports_tools: true,
195                supports_reasoning: true,
196            },
197            ModelInfo {
198                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
199                provider: ProviderKind::Vllm,
200                aliases: vec![
201                    "deepseek-v4-flash".to_string(),
202                    "deepseek-chat".to_string(),
203                    "deepseek-reasoner".to_string(),
204                    "vllm-deepseek-v4-flash".to_string(),
205                ],
206                supports_tools: true,
207                supports_reasoning: true,
208            },
209            ModelInfo {
210                id: "deepseek-coder:1.3b".to_string(),
211                provider: ProviderKind::Ollama,
212                aliases: vec![],
213                supports_tools: true,
214                supports_reasoning: false,
215            },
216        ];
217        Self::new(models)
218    }
219}
220
221impl ModelRegistry {
222    #[must_use]
223    pub fn new(models: Vec<ModelInfo>) -> Self {
224        let mut alias_map = HashMap::new();
225        for (idx, model) in models.iter().enumerate() {
226            alias_map.entry(normalize(&model.id)).or_insert(idx);
227            for alias in &model.aliases {
228                alias_map.entry(normalize(alias)).or_insert(idx);
229            }
230        }
231        Self { models, alias_map }
232    }
233
234    #[must_use]
235    pub fn list(&self) -> Vec<ModelInfo> {
236        self.models.clone()
237    }
238
239    #[must_use]
240    pub fn resolve(
241        &self,
242        requested: Option<&str>,
243        provider_hint: Option<ProviderKind>,
244    ) -> ModelResolution {
245        let mut fallback_chain = Vec::new();
246
247        if let Some(name) = requested {
248            fallback_chain.push(format!("requested:{name}"));
249            if provider_hint == Some(ProviderKind::Ollama) {
250                return ModelResolution {
251                    requested: Some(name.to_string()),
252                    resolved: ModelInfo {
253                        id: name.trim().to_string(),
254                        provider: ProviderKind::Ollama,
255                        aliases: Vec::new(),
256                        supports_tools: true,
257                        supports_reasoning: false,
258                    },
259                    used_fallback: false,
260                    fallback_chain,
261                };
262            }
263            if let Some(provider) = provider_hint
264                && let Some(model) = self
265                    .models
266                    .iter()
267                    .find(|m| m.provider == provider && model_matches(m, name))
268                    .cloned()
269            {
270                return ModelResolution {
271                    requested: Some(name.to_string()),
272                    resolved: preserve_requested_model_id_case(model, name),
273                    used_fallback: false,
274                    fallback_chain,
275                };
276            }
277            if let Some(idx) = self.alias_map.get(&normalize(name)) {
278                return ModelResolution {
279                    requested: Some(name.to_string()),
280                    resolved: preserve_requested_model_id_case(self.models[*idx].clone(), name),
281                    used_fallback: false,
282                    fallback_chain,
283                };
284            }
285        }
286
287        let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
288        fallback_chain.push(format!("provider_default:{}", provider.as_str()));
289        if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
290            return ModelResolution {
291                requested: requested.map(ToOwned::to_owned),
292                resolved: model,
293                used_fallback: true,
294                fallback_chain,
295            };
296        }
297
298        let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
299            id: "deepseek-v4-pro".to_string(),
300            provider: ProviderKind::Deepseek,
301            aliases: Vec::new(),
302            supports_tools: true,
303            supports_reasoning: true,
304        });
305        fallback_chain.push("global_default:deepseek-v4-pro".to_string());
306        ModelResolution {
307            requested: requested.map(ToOwned::to_owned),
308            resolved: final_fallback,
309            used_fallback: true,
310            fallback_chain,
311        }
312    }
313}
314
/// Returns the lookup form of a model name: surrounding whitespace removed and
/// ASCII letters lowercased (non-ASCII characters are left untouched).
fn normalize(value: &str) -> String {
    value.trim().to_ascii_lowercase()
}
318
319fn model_matches(model: &ModelInfo, requested: &str) -> bool {
320    let requested = normalize(requested);
321    normalize(&model.id) == requested
322        || model
323            .aliases
324            .iter()
325            .any(|alias| normalize(alias) == requested)
326}
327
328fn preserve_requested_model_id_case(mut model: ModelInfo, requested: &str) -> ModelInfo {
329    let requested = requested.trim();
330    if model.id.eq_ignore_ascii_case(requested) {
331        model.id = requested.to_string();
332    }
333    model
334}
335
/// Resolution tests against the built-in catalog from `ModelRegistry::default`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-pro"), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-pro");
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-flash");
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Openrouter));

        assert_eq!(resolved.resolved.provider, ProviderKind::Openrouter);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn wanjie_ark_default_uses_reasoner_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::WanjieArk));

        assert_eq!(resolved.resolved.provider, ProviderKind::WanjieArk);
        assert_eq!(resolved.resolved.id, "deepseek-reasoner");
        assert!(resolved.resolved.supports_reasoning);
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Novita));

        assert_eq!(resolved.resolved.provider, ProviderKind::Novita);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn fireworks_default_uses_canonical_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Fireworks));

        assert_eq!(resolved.resolved.provider, ProviderKind::Fireworks);
        assert_eq!(
            resolved.resolved.id,
            "accounts/fireworks/models/deepseek-v4-pro"
        );
    }

    #[test]
    fn sglang_default_uses_canonical_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Sglang));

        assert_eq!(resolved.resolved.provider, ProviderKind::Sglang);
        assert_eq!(resolved.resolved.id, "deepseek-ai/DeepSeek-V4-Pro");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Openrouter));

        assert_eq!(resolved.resolved.provider, ProviderKind::Openrouter);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-flash");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Novita));

        assert_eq!(resolved.resolved.provider, ProviderKind::Novita);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-flash");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_sglang_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Sglang));

        assert_eq!(resolved.resolved.provider, ProviderKind::Sglang);
        assert_eq!(resolved.resolved.id, "deepseek-ai/DeepSeek-V4-Flash");
    }

    #[test]
    fn vllm_default_uses_canonical_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Vllm));

        assert_eq!(resolved.resolved.provider, ProviderKind::Vllm);
        assert_eq!(resolved.resolved.id, "deepseek-ai/DeepSeek-V4-Pro");
    }

    #[test]
    fn ollama_default_uses_small_local_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Ollama));

        assert_eq!(resolved.resolved.provider, ProviderKind::Ollama);
        assert_eq!(resolved.resolved.id, "deepseek-coder:1.3b");
        assert!(!resolved.resolved.supports_reasoning);
    }

    #[test]
    fn ollama_requested_model_tag_is_preserved() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("qwen2.5-coder:7b"), Some(ProviderKind::Ollama));

        assert_eq!(resolved.resolved.provider, ProviderKind::Ollama);
        assert_eq!(resolved.resolved.id, "qwen2.5-coder:7b");
        assert!(!resolved.used_fallback);
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_vllm_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Vllm));

        assert_eq!(resolved.resolved.provider, ProviderKind::Vllm);
        assert_eq!(resolved.resolved.id, "deepseek-ai/DeepSeek-V4-Flash");
    }

    #[test]
    fn preserves_requested_model_casing_for_third_party_providers() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("DeepSeek-V4-Pro"), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_with_provider_hint() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("DeepSeek-V4-Pro"), Some(ProviderKind::Deepseek));

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_without_surrounding_whitespace() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("  DeepSeek-V4-Pro  "), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "DeepSeek-V4-Pro");
    }

    #[test]
    fn alias_match_does_not_override_requested_casing() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-reasoner"), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-flash");
    }

    #[test]
    fn unknown_model_falls_back_to_hinted_provider_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("no-such-model"), Some(ProviderKind::Openrouter));

        assert!(resolved.used_fallback);
        assert_eq!(resolved.requested.as_deref(), Some("no-such-model"));
        assert_eq!(resolved.resolved.provider, ProviderKind::Openrouter);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-pro");
        assert_eq!(
            resolved.fallback_chain.first().map(String::as_str),
            Some("requested:no-such-model")
        );
    }

    #[test]
    fn no_request_and_no_hint_defaults_to_deepseek() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, None);

        assert!(resolved.used_fallback);
        assert_eq!(resolved.requested, None);
        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
    }
}