Skip to main content

deepseek_agent/
lib.rs

1use std::collections::HashMap;
2
3use deepseek_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8    pub id: String,
9    pub provider: ProviderKind,
10    pub aliases: Vec<String>,
11    pub supports_tools: bool,
12    pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17    pub requested: Option<String>,
18    pub resolved: ModelInfo,
19    pub used_fallback: bool,
20    pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25    models: Vec<ModelInfo>,
26    alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30    fn default() -> Self {
31        let models = vec![
32            ModelInfo {
33                id: "deepseek-v4-pro".to_string(),
34                provider: ProviderKind::Deepseek,
35                aliases: vec![],
36                supports_tools: true,
37                supports_reasoning: true,
38            },
39            ModelInfo {
40                id: "deepseek-v4-flash".to_string(),
41                provider: ProviderKind::Deepseek,
42                aliases: vec![
43                    "deepseek-chat".to_string(),
44                    "deepseek-reasoner".to_string(),
45                    "deepseek-r1".to_string(),
46                    "deepseek-v3".to_string(),
47                    "deepseek-v3.2".to_string(),
48                ],
49                supports_tools: true,
50                supports_reasoning: true,
51            },
52            ModelInfo {
53                id: "deepseek-ai/deepseek-v4-pro".to_string(),
54                provider: ProviderKind::NvidiaNim,
55                aliases: vec![
56                    "deepseek-v4-pro".to_string(),
57                    "nvidia-deepseek-v4-pro".to_string(),
58                    "nim-deepseek-v4-pro".to_string(),
59                ],
60                supports_tools: true,
61                supports_reasoning: true,
62            },
63            ModelInfo {
64                id: "deepseek-ai/deepseek-v4-flash".to_string(),
65                provider: ProviderKind::NvidiaNim,
66                aliases: vec![
67                    "deepseek-v4-flash".to_string(),
68                    "deepseek-chat".to_string(),
69                    "deepseek-reasoner".to_string(),
70                    "nvidia-deepseek-v4-flash".to_string(),
71                    "nim-deepseek-v4-flash".to_string(),
72                ],
73                supports_tools: true,
74                supports_reasoning: true,
75            },
76            ModelInfo {
77                id: "gpt-4.1".to_string(),
78                provider: ProviderKind::Openai,
79                aliases: vec!["gpt4.1".to_string(), "gpt-4o".to_string()],
80                supports_tools: true,
81                supports_reasoning: true,
82            },
83            ModelInfo {
84                id: "gpt-4.1-mini".to_string(),
85                provider: ProviderKind::Openai,
86                aliases: vec!["gpt-4o-mini".to_string()],
87                supports_tools: true,
88                supports_reasoning: false,
89            },
90            ModelInfo {
91                id: "deepseek-reasoner".to_string(),
92                provider: ProviderKind::WanjieArk,
93                aliases: vec![
94                    "wanjie-deepseek-reasoner".to_string(),
95                    "ark-wanjie-deepseek-reasoner".to_string(),
96                ],
97                supports_tools: true,
98                supports_reasoning: true,
99            },
100            ModelInfo {
101                id: "deepseek/deepseek-v4-pro".to_string(),
102                provider: ProviderKind::Openrouter,
103                aliases: vec![
104                    "deepseek-v4-pro".to_string(),
105                    "openrouter-deepseek-v4-pro".to_string(),
106                ],
107                supports_tools: true,
108                supports_reasoning: true,
109            },
110            ModelInfo {
111                id: "deepseek/deepseek-v4-flash".to_string(),
112                provider: ProviderKind::Openrouter,
113                aliases: vec![
114                    "deepseek-v4-flash".to_string(),
115                    "deepseek-chat".to_string(),
116                    "deepseek-reasoner".to_string(),
117                    "openrouter-deepseek-v4-flash".to_string(),
118                ],
119                supports_tools: true,
120                supports_reasoning: true,
121            },
122            ModelInfo {
123                id: "deepseek/deepseek-v4-pro".to_string(),
124                provider: ProviderKind::Novita,
125                aliases: vec![
126                    "deepseek-v4-pro".to_string(),
127                    "novita-deepseek-v4-pro".to_string(),
128                ],
129                supports_tools: true,
130                supports_reasoning: true,
131            },
132            ModelInfo {
133                id: "deepseek/deepseek-v4-flash".to_string(),
134                provider: ProviderKind::Novita,
135                aliases: vec![
136                    "deepseek-v4-flash".to_string(),
137                    "deepseek-chat".to_string(),
138                    "deepseek-reasoner".to_string(),
139                    "novita-deepseek-v4-flash".to_string(),
140                ],
141                supports_tools: true,
142                supports_reasoning: true,
143            },
144            ModelInfo {
145                id: "accounts/fireworks/models/deepseek-v4-pro".to_string(),
146                provider: ProviderKind::Fireworks,
147                aliases: vec![
148                    "deepseek-v4-pro".to_string(),
149                    "fireworks-deepseek-v4-pro".to_string(),
150                ],
151                supports_tools: true,
152                supports_reasoning: true,
153            },
154            ModelInfo {
155                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
156                provider: ProviderKind::Sglang,
157                aliases: vec![
158                    "deepseek-v4-pro".to_string(),
159                    "sglang-deepseek-v4-pro".to_string(),
160                ],
161                supports_tools: true,
162                supports_reasoning: true,
163            },
164            ModelInfo {
165                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
166                provider: ProviderKind::Sglang,
167                aliases: vec![
168                    "deepseek-v4-flash".to_string(),
169                    "deepseek-chat".to_string(),
170                    "deepseek-reasoner".to_string(),
171                    "sglang-deepseek-v4-flash".to_string(),
172                ],
173                supports_tools: true,
174                supports_reasoning: true,
175            },
176            ModelInfo {
177                id: "deepseek-ai/DeepSeek-V4-Pro".to_string(),
178                provider: ProviderKind::Vllm,
179                aliases: vec![
180                    "deepseek-v4-pro".to_string(),
181                    "vllm-deepseek-v4-pro".to_string(),
182                ],
183                supports_tools: true,
184                supports_reasoning: true,
185            },
186            ModelInfo {
187                id: "deepseek-ai/DeepSeek-V4-Flash".to_string(),
188                provider: ProviderKind::Vllm,
189                aliases: vec![
190                    "deepseek-v4-flash".to_string(),
191                    "deepseek-chat".to_string(),
192                    "deepseek-reasoner".to_string(),
193                    "vllm-deepseek-v4-flash".to_string(),
194                ],
195                supports_tools: true,
196                supports_reasoning: true,
197            },
198            ModelInfo {
199                id: "deepseek-coder:1.3b".to_string(),
200                provider: ProviderKind::Ollama,
201                aliases: vec![],
202                supports_tools: true,
203                supports_reasoning: false,
204            },
205        ];
206        Self::new(models)
207    }
208}
209
210impl ModelRegistry {
211    #[must_use]
212    pub fn new(models: Vec<ModelInfo>) -> Self {
213        let mut alias_map = HashMap::new();
214        for (idx, model) in models.iter().enumerate() {
215            alias_map.entry(normalize(&model.id)).or_insert(idx);
216            for alias in &model.aliases {
217                alias_map.entry(normalize(alias)).or_insert(idx);
218            }
219        }
220        Self { models, alias_map }
221    }
222
223    #[must_use]
224    pub fn list(&self) -> Vec<ModelInfo> {
225        self.models.clone()
226    }
227
228    #[must_use]
229    pub fn resolve(
230        &self,
231        requested: Option<&str>,
232        provider_hint: Option<ProviderKind>,
233    ) -> ModelResolution {
234        let mut fallback_chain = Vec::new();
235
236        if let Some(name) = requested {
237            fallback_chain.push(format!("requested:{name}"));
238            if provider_hint == Some(ProviderKind::Ollama) {
239                return ModelResolution {
240                    requested: Some(name.to_string()),
241                    resolved: ModelInfo {
242                        id: name.trim().to_string(),
243                        provider: ProviderKind::Ollama,
244                        aliases: Vec::new(),
245                        supports_tools: true,
246                        supports_reasoning: false,
247                    },
248                    used_fallback: false,
249                    fallback_chain,
250                };
251            }
252            if let Some(provider) = provider_hint
253                && let Some(model) = self
254                    .models
255                    .iter()
256                    .find(|m| m.provider == provider && model_matches(m, name))
257                    .cloned()
258            {
259                return ModelResolution {
260                    requested: Some(name.to_string()),
261                    resolved: preserve_requested_model_id_case(model, name),
262                    used_fallback: false,
263                    fallback_chain,
264                };
265            }
266            if let Some(idx) = self.alias_map.get(&normalize(name)) {
267                return ModelResolution {
268                    requested: Some(name.to_string()),
269                    resolved: preserve_requested_model_id_case(self.models[*idx].clone(), name),
270                    used_fallback: false,
271                    fallback_chain,
272                };
273            }
274        }
275
276        let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
277        fallback_chain.push(format!("provider_default:{}", provider.as_str()));
278        if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
279            return ModelResolution {
280                requested: requested.map(ToOwned::to_owned),
281                resolved: model,
282                used_fallback: true,
283                fallback_chain,
284            };
285        }
286
287        let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
288            id: "deepseek-v4-pro".to_string(),
289            provider: ProviderKind::Deepseek,
290            aliases: Vec::new(),
291            supports_tools: true,
292            supports_reasoning: true,
293        });
294        fallback_chain.push("global_default:deepseek-v4-pro".to_string());
295        ModelResolution {
296            requested: requested.map(ToOwned::to_owned),
297            resolved: final_fallback,
298            used_fallback: true,
299            fallback_chain,
300        }
301    }
302}
303
/// Produces the lookup key for a model name: surrounding whitespace is
/// removed and ASCII letters are lowercased (non-ASCII characters are kept).
fn normalize(value: &str) -> String {
    let mut key = value.trim().to_owned();
    key.make_ascii_lowercase();
    key
}
307
308fn model_matches(model: &ModelInfo, requested: &str) -> bool {
309    let requested = normalize(requested);
310    normalize(&model.id) == requested
311        || model
312            .aliases
313            .iter()
314            .any(|alias| normalize(alias) == requested)
315}
316
317fn preserve_requested_model_id_case(mut model: ModelInfo, requested: &str) -> ModelInfo {
318    let requested = requested.trim();
319    if model.id.eq_ignore_ascii_case(requested) {
320        model.id = requested.to_string();
321    }
322    model
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    // Runs `resolve` against the built-in catalog.
    fn resolve_default(
        requested: Option<&str>,
        provider_hint: Option<ProviderKind>,
    ) -> ModelResolution {
        ModelRegistry::default().resolve(requested, provider_hint)
    }

    // Checks the provider and id of the selected model.
    #[track_caller]
    fn assert_selected(resolution: &ModelResolution, provider: ProviderKind, id: &str) {
        assert_eq!(resolution.resolved.provider, provider);
        assert_eq!(resolution.resolved.id, id);
    }

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let resolution = resolve_default(Some("deepseek-v4-pro"), None);
        assert_selected(&resolution, ProviderKind::Deepseek, "deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let resolution = resolve_default(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim));
        assert_selected(
            &resolution,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::NvidiaNim));
        assert_selected(
            &resolution,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let resolution = resolve_default(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim));
        assert_selected(
            &resolution,
            ProviderKind::NvidiaNim,
            "deepseek-ai/deepseek-v4-flash",
        );
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::Openrouter));
        assert_selected(
            &resolution,
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-pro",
        );
    }

    #[test]
    fn wanjie_ark_default_uses_reasoner_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::WanjieArk));
        assert_selected(&resolution, ProviderKind::WanjieArk, "deepseek-reasoner");
        assert!(resolution.resolved.supports_reasoning);
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::Novita));
        assert_selected(&resolution, ProviderKind::Novita, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn fireworks_default_uses_canonical_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::Fireworks));
        assert_selected(
            &resolution,
            ProviderKind::Fireworks,
            "accounts/fireworks/models/deepseek-v4-pro",
        );
    }

    #[test]
    fn sglang_default_uses_canonical_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::Sglang));
        assert_selected(
            &resolution,
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        let resolution = resolve_default(Some("deepseek-v4-flash"), Some(ProviderKind::Openrouter));
        assert_selected(
            &resolution,
            ProviderKind::Openrouter,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        let resolution = resolve_default(Some("deepseek-v4-flash"), Some(ProviderKind::Novita));
        assert_selected(
            &resolution,
            ProviderKind::Novita,
            "deepseek/deepseek-v4-flash",
        );
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_sglang_when_provider_hinted() {
        let resolution = resolve_default(Some("deepseek-v4-flash"), Some(ProviderKind::Sglang));
        assert_selected(
            &resolution,
            ProviderKind::Sglang,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn vllm_default_uses_canonical_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::Vllm));
        assert_selected(
            &resolution,
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Pro",
        );
    }

    #[test]
    fn ollama_default_uses_small_local_model_id() {
        let resolution = resolve_default(None, Some(ProviderKind::Ollama));
        assert_selected(&resolution, ProviderKind::Ollama, "deepseek-coder:1.3b");
        assert!(!resolution.resolved.supports_reasoning);
    }

    #[test]
    fn ollama_requested_model_tag_is_preserved() {
        // The tag is not in the catalog; with the Ollama hint it is passed through.
        let resolution = resolve_default(Some("qwen2.5-coder:7b"), Some(ProviderKind::Ollama));
        assert_selected(&resolution, ProviderKind::Ollama, "qwen2.5-coder:7b");
        assert!(!resolution.used_fallback);
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_vllm_when_provider_hinted() {
        let resolution = resolve_default(Some("deepseek-v4-flash"), Some(ProviderKind::Vllm));
        assert_selected(
            &resolution,
            ProviderKind::Vllm,
            "deepseek-ai/DeepSeek-V4-Flash",
        );
    }

    #[test]
    fn preserves_requested_model_casing_for_third_party_providers() {
        // No provider hint is given here, so the lookup lands on the DeepSeek entry.
        let resolution = resolve_default(Some("DeepSeek-V4-Pro"), None);
        assert_selected(&resolution, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_with_provider_hint() {
        let resolution = resolve_default(Some("DeepSeek-V4-Pro"), Some(ProviderKind::Deepseek));
        assert_selected(&resolution, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn preserves_requested_model_casing_without_surrounding_whitespace() {
        let resolution = resolve_default(Some("  DeepSeek-V4-Pro  "), None);
        assert_selected(&resolution, ProviderKind::Deepseek, "DeepSeek-V4-Pro");
    }

    #[test]
    fn alias_match_does_not_override_requested_casing() {
        // An alias hit keeps the catalog id instead of echoing the alias back.
        let resolution = resolve_default(Some("deepseek-reasoner"), None);
        assert_selected(&resolution, ProviderKind::Deepseek, "deepseek-v4-flash");
    }
}