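//! deepseek_agent — lib.rs
//!
//! Built-in model catalog (`ModelRegistry`) and resolution of requested model
//! names to provider-specific models.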

use std::collections::HashMap;

use deepseek_config::ProviderKind;
use serde::{Deserialize, Serialize};

6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8    pub id: String,
9    pub provider: ProviderKind,
10    pub aliases: Vec<String>,
11    pub supports_tools: bool,
12    pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17    pub requested: Option<String>,
18    pub resolved: ModelInfo,
19    pub used_fallback: bool,
20    pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25    models: Vec<ModelInfo>,
26    alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30    fn default() -> Self {
31        let models = vec![
32            ModelInfo {
33                id: "deepseek-v4-pro".to_string(),
34                provider: ProviderKind::Deepseek,
35                aliases: vec![],
36                supports_tools: true,
37                supports_reasoning: true,
38            },
39            ModelInfo {
40                id: "deepseek-v4-flash".to_string(),
41                provider: ProviderKind::Deepseek,
42                aliases: vec![
43                    "deepseek-chat".to_string(),
44                    "deepseek-reasoner".to_string(),
45                    "deepseek-r1".to_string(),
46                    "deepseek-v3".to_string(),
47                    "deepseek-v3.2".to_string(),
48                ],
49                supports_tools: true,
50                supports_reasoning: true,
51            },
52            ModelInfo {
53                id: "deepseek-ai/deepseek-v4-pro".to_string(),
54                provider: ProviderKind::NvidiaNim,
55                aliases: vec![
56                    "deepseek-v4-pro".to_string(),
57                    "nvidia-deepseek-v4-pro".to_string(),
58                    "nim-deepseek-v4-pro".to_string(),
59                ],
60                supports_tools: true,
61                supports_reasoning: true,
62            },
63            ModelInfo {
64                id: "deepseek-ai/deepseek-v4-flash".to_string(),
65                provider: ProviderKind::NvidiaNim,
66                aliases: vec![
67                    "deepseek-v4-flash".to_string(),
68                    "deepseek-chat".to_string(),
69                    "deepseek-reasoner".to_string(),
70                    "nvidia-deepseek-v4-flash".to_string(),
71                    "nim-deepseek-v4-flash".to_string(),
72                ],
73                supports_tools: true,
74                supports_reasoning: true,
75            },
76            ModelInfo {
77                id: "gpt-4.1".to_string(),
78                provider: ProviderKind::Openai,
79                aliases: vec!["gpt4.1".to_string(), "gpt-4o".to_string()],
80                supports_tools: true,
81                supports_reasoning: true,
82            },
83            ModelInfo {
84                id: "gpt-4.1-mini".to_string(),
85                provider: ProviderKind::Openai,
86                aliases: vec!["gpt-4o-mini".to_string()],
87                supports_tools: true,
88                supports_reasoning: false,
89            },
90            ModelInfo {
91                id: "deepseek/deepseek-v4-pro".to_string(),
92                provider: ProviderKind::Openrouter,
93                aliases: vec![
94                    "deepseek-v4-pro".to_string(),
95                    "openrouter-deepseek-v4-pro".to_string(),
96                ],
97                supports_tools: true,
98                supports_reasoning: true,
99            },
100            ModelInfo {
101                id: "deepseek/deepseek-v4-flash".to_string(),
102                provider: ProviderKind::Openrouter,
103                aliases: vec![
104                    "deepseek-v4-flash".to_string(),
105                    "deepseek-chat".to_string(),
106                    "deepseek-reasoner".to_string(),
107                    "openrouter-deepseek-v4-flash".to_string(),
108                ],
109                supports_tools: true,
110                supports_reasoning: true,
111            },
112            ModelInfo {
113                id: "deepseek/deepseek-v4-pro".to_string(),
114                provider: ProviderKind::Novita,
115                aliases: vec![
116                    "deepseek-v4-pro".to_string(),
117                    "novita-deepseek-v4-pro".to_string(),
118                ],
119                supports_tools: true,
120                supports_reasoning: true,
121            },
122            ModelInfo {
123                id: "deepseek/deepseek-v4-flash".to_string(),
124                provider: ProviderKind::Novita,
125                aliases: vec![
126                    "deepseek-v4-flash".to_string(),
127                    "deepseek-chat".to_string(),
128                    "deepseek-reasoner".to_string(),
129                    "novita-deepseek-v4-flash".to_string(),
130                ],
131                supports_tools: true,
132                supports_reasoning: true,
133            },
134        ];
135        Self::new(models)
136    }
137}
138
139impl ModelRegistry {
140    #[must_use]
141    pub fn new(models: Vec<ModelInfo>) -> Self {
142        let mut alias_map = HashMap::new();
143        for (idx, model) in models.iter().enumerate() {
144            alias_map.entry(normalize(&model.id)).or_insert(idx);
145            for alias in &model.aliases {
146                alias_map.entry(normalize(alias)).or_insert(idx);
147            }
148        }
149        Self { models, alias_map }
150    }
151
152    #[must_use]
153    pub fn list(&self) -> Vec<ModelInfo> {
154        self.models.clone()
155    }
156
157    #[must_use]
158    pub fn resolve(
159        &self,
160        requested: Option<&str>,
161        provider_hint: Option<ProviderKind>,
162    ) -> ModelResolution {
163        let mut fallback_chain = Vec::new();
164
165        if let Some(name) = requested {
166            fallback_chain.push(format!("requested:{name}"));
167            if let Some(provider) = provider_hint
168                && let Some(model) = self
169                    .models
170                    .iter()
171                    .find(|m| m.provider == provider && model_matches(m, name))
172                    .cloned()
173            {
174                return ModelResolution {
175                    requested: Some(name.to_string()),
176                    resolved: model,
177                    used_fallback: false,
178                    fallback_chain,
179                };
180            }
181            if let Some(idx) = self.alias_map.get(&normalize(name)) {
182                return ModelResolution {
183                    requested: Some(name.to_string()),
184                    resolved: self.models[*idx].clone(),
185                    used_fallback: false,
186                    fallback_chain,
187                };
188            }
189        }
190
191        let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
192        fallback_chain.push(format!("provider_default:{}", provider.as_str()));
193        if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
194            return ModelResolution {
195                requested: requested.map(ToOwned::to_owned),
196                resolved: model,
197                used_fallback: true,
198                fallback_chain,
199            };
200        }
201
202        let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
203            id: "deepseek-v4-pro".to_string(),
204            provider: ProviderKind::Deepseek,
205            aliases: Vec::new(),
206            supports_tools: true,
207            supports_reasoning: true,
208        });
209        fallback_chain.push("global_default:deepseek-v4-pro".to_string());
210        ModelResolution {
211            requested: requested.map(ToOwned::to_owned),
212            resolved: final_fallback,
213            used_fallback: true,
214            fallback_chain,
215        }
216    }
217}
218
/// Canonicalizes a model name for lookup.
///
/// Strips surrounding whitespace and lowercases ASCII letters. Non-ASCII
/// characters are left untouched.
fn normalize(value: &str) -> String {
    value
        .trim()
        .chars()
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}

223fn model_matches(model: &ModelInfo, requested: &str) -> bool {
224    let requested = normalize(requested);
225    normalize(&model.id) == requested
226        || model
227            .aliases
228            .iter()
229            .any(|alias| normalize(alias) == requested)
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-pro"), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-pro");
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-flash");
    }

    #[test]
    fn openrouter_default_uses_namespaced_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Openrouter));

        assert_eq!(resolved.resolved.provider, ProviderKind::Openrouter);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn novita_default_uses_namespaced_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::Novita));

        assert_eq!(resolved.resolved.provider, ProviderKind::Novita);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_openrouter_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Openrouter));

        assert_eq!(resolved.resolved.provider, ProviderKind::Openrouter);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-flash");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_novita_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::Novita));

        assert_eq!(resolved.resolved.provider, ProviderKind::Novita);
        assert_eq!(resolved.resolved.id, "deepseek/deepseek-v4-flash");
    }

    #[test]
    fn lookup_is_case_insensitive_and_trims_whitespace() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("  DeepSeek-V4-Flash "), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-flash");
        assert!(!resolved.used_fallback);
        // The caller's original spelling is kept for diagnostics.
        assert_eq!(resolved.requested.as_deref(), Some("  DeepSeek-V4-Flash "));
    }

    #[test]
    fn shared_alias_prefers_first_registered_model_without_hint() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-chat"), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-flash");
        assert!(!resolved.used_fallback);
    }

    #[test]
    fn unknown_model_falls_back_to_provider_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("no-such-model"), None);

        assert!(resolved.used_fallback);
        assert_eq!(resolved.requested.as_deref(), Some("no-such-model"));
        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
        assert_eq!(resolved.fallback_chain.len(), 2);
        assert_eq!(resolved.fallback_chain[0], "requested:no-such-model");
        assert!(resolved.fallback_chain[1].starts_with("provider_default:"));
    }

    #[test]
    fn empty_registry_uses_builtin_global_default() {
        let registry = ModelRegistry::new(Vec::new());
        let resolved = registry.resolve(None, None);

        assert!(resolved.used_fallback);
        assert_eq!(resolved.requested, None);
        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
        assert_eq!(
            resolved.fallback_chain.last().map(String::as_str),
            Some("global_default:deepseek-v4-pro")
        );
    }
}