Skip to main content

deepseek_agent/
lib.rs

1use std::collections::HashMap;
2
3use deepseek_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
/// A single catalog entry describing one model as exposed by one provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier for this provider (e.g. `"deepseek-v4-pro"`).
    pub id: String,
    /// Provider this entry belongs to.
    pub provider: ProviderKind,
    /// Alternative names that resolve to this entry; matched after trimming
    /// and ASCII-lowercasing, the same way as `id`.
    pub aliases: Vec<String>,
    /// Capability flag; not read by the resolution logic in this file.
    // NOTE(review): presumably whether tool/function calling is available — confirm with consumers.
    pub supports_tools: bool,
    /// Capability flag; not read by the resolution logic in this file.
    // NOTE(review): presumably whether reasoning output is available — confirm with consumers.
    pub supports_reasoning: bool,
}
14
/// Outcome of [`ModelRegistry::resolve`]: the chosen model plus a trace of how it was chosen.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResolution {
    /// The model name exactly as the caller passed it (not trimmed), or `None`
    /// when no name was requested.
    pub requested: Option<String>,
    /// The model that was selected.
    pub resolved: ModelInfo,
    /// `true` when the result came from a provider default or the global
    /// default rather than from the requested name.
    pub used_fallback: bool,
    /// Ordered trace of the resolution steps that were attempted, as
    /// `"requested:…"`, `"provider_default:…"` and `"global_default:…"` entries.
    pub fallback_chain: Vec<String>,
}
22
/// Catalog of known models with a lookup index over their ids and aliases.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Catalog entries in registration order. Order is significant: the first
    /// entry for a provider acts as that provider's default in `resolve`.
    models: Vec<ModelInfo>,
    /// Normalized (trimmed, ASCII-lowercased) id or alias -> index into
    /// `models`. When several entries share a key, the earliest entry owns it.
    alias_map: HashMap<String, usize>,
}
28
29impl Default for ModelRegistry {
30    fn default() -> Self {
31        let models = vec![
32            ModelInfo {
33                id: "deepseek-v4-pro".to_string(),
34                provider: ProviderKind::Deepseek,
35                aliases: vec![],
36                supports_tools: true,
37                supports_reasoning: true,
38            },
39            ModelInfo {
40                id: "deepseek-v4-flash".to_string(),
41                provider: ProviderKind::Deepseek,
42                aliases: vec![
43                    "deepseek-chat".to_string(),
44                    "deepseek-reasoner".to_string(),
45                    "deepseek-r1".to_string(),
46                    "deepseek-v3".to_string(),
47                    "deepseek-v3.2".to_string(),
48                ],
49                supports_tools: true,
50                supports_reasoning: true,
51            },
52            ModelInfo {
53                id: "deepseek-ai/deepseek-v4-pro".to_string(),
54                provider: ProviderKind::NvidiaNim,
55                aliases: vec![
56                    "deepseek-v4-pro".to_string(),
57                    "nvidia-deepseek-v4-pro".to_string(),
58                    "nim-deepseek-v4-pro".to_string(),
59                ],
60                supports_tools: true,
61                supports_reasoning: true,
62            },
63            ModelInfo {
64                id: "deepseek-ai/deepseek-v4-flash".to_string(),
65                provider: ProviderKind::NvidiaNim,
66                aliases: vec![
67                    "deepseek-v4-flash".to_string(),
68                    "deepseek-chat".to_string(),
69                    "deepseek-reasoner".to_string(),
70                    "nvidia-deepseek-v4-flash".to_string(),
71                    "nim-deepseek-v4-flash".to_string(),
72                ],
73                supports_tools: true,
74                supports_reasoning: true,
75            },
76            ModelInfo {
77                id: "gpt-4.1".to_string(),
78                provider: ProviderKind::Openai,
79                aliases: vec!["gpt4.1".to_string(), "gpt-4o".to_string()],
80                supports_tools: true,
81                supports_reasoning: true,
82            },
83            ModelInfo {
84                id: "gpt-4.1-mini".to_string(),
85                provider: ProviderKind::Openai,
86                aliases: vec!["gpt-4o-mini".to_string()],
87                supports_tools: true,
88                supports_reasoning: false,
89            },
90            ModelInfo {
91                id: "deepseek/deepseek-v4-pro".to_string(),
92                provider: ProviderKind::Openrouter,
93                aliases: vec![
94                    "deepseek-v4-pro".to_string(),
95                    "openrouter-deepseek-v4-pro".to_string(),
96                ],
97                supports_tools: true,
98                supports_reasoning: true,
99            },
100            ModelInfo {
101                id: "deepseek/deepseek-v4-flash".to_string(),
102                provider: ProviderKind::Openrouter,
103                aliases: vec![
104                    "deepseek-v4-flash".to_string(),
105                    "deepseek-chat".to_string(),
106                    "deepseek-reasoner".to_string(),
107                    "openrouter-deepseek-v4-flash".to_string(),
108                ],
109                supports_tools: true,
110                supports_reasoning: true,
111            },
112            ModelInfo {
113                id: "deepseek/deepseek-v4-pro".to_string(),
114                provider: ProviderKind::Novita,
115                aliases: vec![
116                    "deepseek-v4-pro".to_string(),
117                    "novita-deepseek-v4-pro".to_string(),
118                ],
119                supports_tools: true,
120                supports_reasoning: true,
121            },
122            ModelInfo {
123                id: "deepseek/deepseek-v4-flash".to_string(),
124                provider: ProviderKind::Novita,
125                aliases: vec![
126                    "deepseek-v4-flash".to_string(),
127                    "deepseek-chat".to_string(),
128                    "deepseek-reasoner".to_string(),
129                    "novita-deepseek-v4-flash".to_string(),
130                ],
131                supports_tools: true,
132                supports_reasoning: true,
133            },
134            ModelInfo {
135                id: "accounts/fireworks/models/deepseek-v4-pro".to_string(),
136                provider: ProviderKind::Fireworks,
137                aliases: vec![
138                    "deepseek-v4-pro".to_string(),
139                    "fireworks-deepseek-v4-pro".to_string(),
140                ],
141                supports_tools: true,
142                supports_reasoning: true,
143            },
144            ModelInfo {
145                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
146                provider: ProviderKind::Sglang,
147                aliases: vec![
148                    "deepseek-v4-pro".to_string(),
149                    "sglang-deepseek-v4-pro".to_string(),
150                ],
151                supports_tools: true,
152                supports_reasoning: true,
153            },
154            ModelInfo {
155                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
156                provider: ProviderKind::Sglang,
157                aliases: vec![
158                    "deepseek-v4-flash".to_string(),
159                    "deepseek-chat".to_string(),
160                    "deepseek-reasoner".to_string(),
161                    "sglang-deepseek-v4-flash".to_string(),
162                ],
163                supports_tools: true,
164                supports_reasoning: true,
165            },
166            ModelInfo {
167                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
168                provider: ProviderKind::Vllm,
169                aliases: vec![
170                    "deepseek-v4-pro".to_string(),
171                    "vllm-deepseek-v4-pro".to_string(),
172                ],
173                supports_tools: true,
174                supports_reasoning: true,
175            },
176            ModelInfo {
177                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
178                provider: ProviderKind::Vllm,
179                aliases: vec![
180                    "deepseek-v4-flash".to_string(),
181                    "deepseek-chat".to_string(),
182                    "deepseek-reasoner".to_string(),
183                    "vllm-deepseek-v4-flash".to_string(),
184                ],
185                supports_tools: true,
186                supports_reasoning: true,
187            },
188            ModelInfo {
189                id: "deepseek-coder:1.3b".to_string(),
190                provider: ProviderKind::Ollama,
191                aliases: vec![],
192                supports_tools: true,
193                supports_reasoning: false,
194            },
195        ];
196        Self::new(models)
197    }
198}
199
200impl ModelRegistry {
201    #[must_use]
202    pub fn new(models: Vec<ModelInfo>) -> Self {
203        let mut alias_map = HashMap::new();
204        for (idx, model) in models.iter().enumerate() {
205            alias_map.entry(normalize(&model.id)).or_insert(idx);
206            for alias in &model.aliases {
207                alias_map.entry(normalize(alias)).or_insert(idx);
208            }
209        }
210        Self { models, alias_map }
211    }
212
213    #[must_use]
214    pub fn list(&self) -> Vec<ModelInfo> {
215        self.models.clone()
216    }
217
218    #[must_use]
219    pub fn resolve(
220        &self,
221        requested: Option<&str>,
222        provider_hint: Option<ProviderKind>,
223    ) -> ModelResolution {
224        let mut fallback_chain = Vec::new();
225
226        if let Some(name) = requested {
227            fallback_chain.push(format!("requested:{name}"));
228            if provider_hint == Some(ProviderKind::Ollama) {
229                return ModelResolution {
230                    requested: Some(name.to_string()),
231                    resolved: ModelInfo {
232                        id: name.trim().to_string(),
233                        provider: ProviderKind::Ollama,
234                        aliases: Vec::new(),
235                        supports_tools: true,
236                        supports_reasoning: false,
237                    },
238                    used_fallback: false,
239                    fallback_chain,
240                };
241            }
242            if let Some(provider) = provider_hint
243                && let Some(model) = self
244                    .models
245                    .iter()
246                    .find(|m| m.provider == provider && model_matches(m, name))
247                    .cloned()
248            {
249                return ModelResolution {
250                    requested: Some(name.to_string()),
251                    resolved: preserve_requested_model_id_case(model, name),
252                    used_fallback: false,
253                    fallback_chain,
254                };
255            }
256            if let Some(idx) = self.alias_map.get(&normalize(name)) {
257                return ModelResolution {
258                    requested: Some(name.to_string()),
259                    resolved: preserve_requested_model_id_case(self.models[*idx].clone(), name),
260                    used_fallback: false,
261                    fallback_chain,
262                };
263            }
264        }
265
266        let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
267        fallback_chain.push(format!("provider_default:{}", provider.as_str()));
268        if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
269            return ModelResolution {
270                requested: requested.map(ToOwned::to_owned),
271                resolved: model,
272                used_fallback: true,
273                fallback_chain,
274            };
275        }
276
277        let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
278            id: "deepseek-v4-pro".to_string(),
279            provider: ProviderKind::Deepseek,
280            aliases: Vec::new(),
281            supports_tools: true,
282            supports_reasoning: true,
283        });
284        fallback_chain.push("global_default:deepseek-v4-pro".to_string());
285        ModelResolution {
286            requested: requested.map(ToOwned::to_owned),
287            resolved: final_fallback,
288            used_fallback: true,
289            fallback_chain,
290        }
291    }
292}
293
/// Canonical lookup key for a model name: surrounding whitespace removed and
/// ASCII letters lowercased. Non-ASCII characters and inner whitespace are
/// left untouched.
fn normalize(value: &str) -> String {
    let mut key = value.trim().to_owned();
    key.make_ascii_lowercase();
    key
}
297
298fn model_matches(model: &ModelInfo, requested: &str) -> bool {
299    let requested = normalize(requested);
300    normalize(&model.id) == requested
301        || model
302            .aliases
303            .iter()
304            .any(|alias| normalize(alias) == requested)
305}
306
307fn preserve_requested_model_id_case(mut model: ModelInfo, requested: &str) -> ModelInfo {
308    let requested = requested.trim();
309    if model.id.eq_ignore_ascii_case(requested) {
310        model.id = requested.to_string();
311    }
312    model
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves `requested` / `hint` against the default registry, asserts the
    /// resulting provider and id, and hands the resolution back for any extra
    /// checks. `#[track_caller]` makes a failure point at the calling test.
    #[track_caller]
    fn assert_resolves(
        requested: Option<&str>,
        hint: Option<ProviderKind>,
        expected_provider: ProviderKind,
        expected_id: &str,
    ) -> ModelResolution {
        let resolution = ModelRegistry::default().resolve(requested, hint);
        assert_eq!(resolution.resolved.provider, expected_provider);
        assert_eq!(resolution.resolved.id, expected_id);
        resolution
    }

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        assert_resolves(
            Some("deepseek-v4-pro"),
            None,
            ProviderKind::Deepseek,
            "deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-pro"),
            Some(ProviderKind::NvidiaNim),
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::NvidiaNim),
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::NvidiaNim),
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-flash",
        );
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Openrouter),
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Novita),
            ProviderKind::Novita,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn fireworks_default_uses_canonical_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Fireworks),
            ProviderKind::Fireworks,
            "accounts/fireworks/models/deepseek-v4-pro",
        );
    }

    #[test]
    fn sglang_default_uses_canonical_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Sglang),
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Openrouter),
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Novita),
            ProviderKind::Novita,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_sglang_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Sglang),
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn vllm_default_uses_canonical_model_id() {
        assert_resolves(
            None,
            Some(ProviderKind::Vllm),
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn ollama_default_uses_small_local_model_id() {
        let resolved = assert_resolves(
            None,
            Some(ProviderKind::Ollama),
            ProviderKind::Ollama,
            "deepseek-coder:1.3b",
        );
        assert!(!resolved.resolved.supports_reasoning);
    }

    #[test]
    fn ollama_requested_model_tag_is_preserved() {
        let resolved = assert_resolves(
            Some("qwen2.5-coder:7b"),
            Some(ProviderKind::Ollama),
            ProviderKind::Ollama,
            "qwen2.5-coder:7b",
        );
        assert!(!resolved.used_fallback);
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_vllm_when_provider_hinted() {
        assert_resolves(
            Some("deepseek-v4-flash"),
            Some(ProviderKind::Vllm),
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn preserves_requested_model_casing_for_third_party_providers() {
        assert_resolves(
            Some("DeepSeek-V4-Pro"),
            None,
            ProviderKind::Deepseek,
            "DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn preserves_requested_model_casing_with_provider_hint() {
        assert_resolves(
            Some("DeepSeek-V4-Pro"),
            Some(ProviderKind::Deepseek),
            ProviderKind::Deepseek,
            "DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn preserves_requested_model_casing_without_surrounding_whitespace() {
        assert_resolves(
            Some("  DeepSeek-V4-Pro  "),
            None,
            ProviderKind::Deepseek,
            "DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn alias_match_does_not_override_requested_casing() {
        assert_resolves(
            Some("deepseek-reasoner"),
            None,
            ProviderKind::Deepseek,
            "deepseek-v4-flash",
        );
        // A mixed-case alias is what actually exercises the casing rule: the
        // catalog id must come back untouched, not the caller's spelling.
        assert_resolves(
            Some("DeepSeek-Reasoner"),
            None,
            ProviderKind::Deepseek,
            "deepseek-v4-flash",
        );
    }
}
494}