Skip to main content

deepseek_agent/
lib.rs

1use std::collections::HashMap;
2
3use deepseek_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8    pub id: String,
9    pub provider: ProviderKind,
10    pub aliases: Vec<String>,
11    pub supports_tools: bool,
12    pub supports_reasoning: bool,
13}
14
/// Outcome of [`ModelRegistry::resolve`]: the chosen model plus a record of
/// how it was reached.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResolution {
    /// The name the caller asked for, verbatim (not trimmed or lowercased), if any.
    pub requested: Option<String>,
    /// The model that was selected.
    pub resolved: ModelInfo,
    /// `false` only when `requested` matched a model by id or alias; `true`
    /// when a provider default or the global default was used instead.
    pub used_fallback: bool,
    /// Ordered trace of the resolution steps attempted, using the prefixes
    /// `requested:`, `provider_default:` and `global_default:`.
    pub fallback_chain: Vec<String>,
}
22
/// Catalog of known models, resolvable by canonical id or alias with an
/// optional provider hint.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Models in registration order. Order decides which model keeps a shared
    /// alias and which model is a provider's default.
    models: Vec<ModelInfo>,
    /// Normalized (trimmed, ASCII-lowercased) id or alias -> index into `models`.
    alias_map: HashMap<String, usize>,
}
28
29impl Default for ModelRegistry {
30    fn default() -> Self {
31        let models = vec![
32            ModelInfo {
33                id: "deepseek-v4-pro".to_string(),
34                provider: ProviderKind::Deepseek,
35                aliases: vec![],
36                supports_tools: true,
37                supports_reasoning: true,
38            },
39            ModelInfo {
40                id: "deepseek-v4-flash".to_string(),
41                provider: ProviderKind::Deepseek,
42                aliases: vec![
43                    "deepseek-chat".to_string(),
44                    "deepseek-reasoner".to_string(),
45                    "deepseek-r1".to_string(),
46                    "deepseek-v3".to_string(),
47                    "deepseek-v3.2".to_string(),
48                ],
49                supports_tools: true,
50                supports_reasoning: true,
51            },
52            ModelInfo {
53                id: "deepseek-ai/deepseek-v4-pro".to_string(),
54                provider: ProviderKind::NvidiaNim,
55                aliases: vec![
56                    "deepseek-v4-pro".to_string(),
57                    "nvidia-deepseek-v4-pro".to_string(),
58                    "nim-deepseek-v4-pro".to_string(),
59                ],
60                supports_tools: true,
61                supports_reasoning: true,
62            },
63            ModelInfo {
64                id: "deepseek-ai/deepseek-v4-flash".to_string(),
65                provider: ProviderKind::NvidiaNim,
66                aliases: vec![
67                    "deepseek-v4-flash".to_string(),
68                    "deepseek-chat".to_string(),
69                    "deepseek-reasoner".to_string(),
70                    "nvidia-deepseek-v4-flash".to_string(),
71                    "nim-deepseek-v4-flash".to_string(),
72                ],
73                supports_tools: true,
74                supports_reasoning: true,
75            },
76            ModelInfo {
77                id: "gpt-4.1".to_string(),
78                provider: ProviderKind::Openai,
79                aliases: vec!["gpt4.1".to_string(), "gpt-4o".to_string()],
80                supports_tools: true,
81                supports_reasoning: true,
82            },
83            ModelInfo {
84                id: "gpt-4.1-mini".to_string(),
85                provider: ProviderKind::Openai,
86                aliases: vec!["gpt-4o-mini".to_string()],
87                supports_tools: true,
88                supports_reasoning: false,
89            },
90        ];
91        Self::new(models)
92    }
93}
94
95impl ModelRegistry {
96    #[must_use]
97    pub fn new(models: Vec<ModelInfo>) -> Self {
98        let mut alias_map = HashMap::new();
99        for (idx, model) in models.iter().enumerate() {
100            alias_map.entry(normalize(&model.id)).or_insert(idx);
101            for alias in &model.aliases {
102                alias_map.entry(normalize(alias)).or_insert(idx);
103            }
104        }
105        Self { models, alias_map }
106    }
107
108    #[must_use]
109    pub fn list(&self) -> Vec<ModelInfo> {
110        self.models.clone()
111    }
112
113    #[must_use]
114    pub fn resolve(
115        &self,
116        requested: Option<&str>,
117        provider_hint: Option<ProviderKind>,
118    ) -> ModelResolution {
119        let mut fallback_chain = Vec::new();
120
121        if let Some(name) = requested {
122            fallback_chain.push(format!("requested:{name}"));
123            if let Some(provider) = provider_hint
124                && let Some(model) = self
125                    .models
126                    .iter()
127                    .find(|m| m.provider == provider && model_matches(m, name))
128                    .cloned()
129            {
130                return ModelResolution {
131                    requested: Some(name.to_string()),
132                    resolved: model,
133                    used_fallback: false,
134                    fallback_chain,
135                };
136            }
137            if let Some(idx) = self.alias_map.get(&normalize(name)) {
138                return ModelResolution {
139                    requested: Some(name.to_string()),
140                    resolved: self.models[*idx].clone(),
141                    used_fallback: false,
142                    fallback_chain,
143                };
144            }
145        }
146
147        let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
148        fallback_chain.push(format!("provider_default:{}", provider.as_str()));
149        if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
150            return ModelResolution {
151                requested: requested.map(ToOwned::to_owned),
152                resolved: model,
153                used_fallback: true,
154                fallback_chain,
155            };
156        }
157
158        let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
159            id: "deepseek-v4-pro".to_string(),
160            provider: ProviderKind::Deepseek,
161            aliases: Vec::new(),
162            supports_tools: true,
163            supports_reasoning: true,
164        });
165        fallback_chain.push("global_default:deepseek-v4-pro".to_string());
166        ModelResolution {
167            requested: requested.map(ToOwned::to_owned),
168            resolved: final_fallback,
169            used_fallback: true,
170            fallback_chain,
171        }
172    }
173}
174
/// Canonical lookup key for a model name: surrounding whitespace removed and
/// ASCII letters lowercased (non-ASCII characters are left as-is).
fn normalize(value: &str) -> String {
    let mut key = value.trim().to_owned();
    key.make_ascii_lowercase();
    key
}
178
179fn model_matches(model: &ModelInfo, requested: &str) -> bool {
180    let requested = normalize(requested);
181    normalize(&model.id) == requested
182        || model
183            .aliases
184            .iter()
185            .any(|alias| normalize(alias) == requested)
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-pro"), None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-pro");
    }

    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-pro");
    }

    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim));

        assert_eq!(resolved.resolved.provider, ProviderKind::NvidiaNim);
        assert_eq!(resolved.resolved.id, "deepseek-ai/deepseek-v4-flash");
    }

    // Names shared with the NIM catalog must stay on native DeepSeek unhinted.
    #[test]
    fn other_deepseek_names_resolve_to_native_flash_by_default() {
        let registry = ModelRegistry::default();
        for name in ["deepseek-chat", "deepseek-reasoner", "deepseek-r1", "deepseek-v3"] {
            let resolved = registry.resolve(Some(name), None);

            assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek, "{name}");
            assert_eq!(resolved.resolved.id, "deepseek-v4-flash", "{name}");
            assert!(!resolved.used_fallback, "{name}");
        }
    }

    #[test]
    fn lookup_is_trimmed_and_case_insensitive_but_request_is_kept_verbatim() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("  GPT-4o  "), None);

        assert_eq!(resolved.resolved.id, "gpt-4.1");
        assert!(!resolved.used_fallback);
        assert_eq!(resolved.requested.as_deref(), Some("  GPT-4o  "));
    }

    #[test]
    fn unknown_model_falls_back_to_hinted_provider_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(Some("no-such-model"), Some(ProviderKind::Openai));

        assert_eq!(resolved.resolved.id, "gpt-4.1");
        assert!(resolved.used_fallback);
        assert_eq!(resolved.requested.as_deref(), Some("no-such-model"));
        assert_eq!(
            resolved.fallback_chain.first().map(String::as_str),
            Some("requested:no-such-model")
        );
        assert_eq!(resolved.fallback_chain.len(), 2);
    }

    #[test]
    fn no_request_and_no_hint_uses_deepseek_default() {
        let registry = ModelRegistry::default();
        let resolved = registry.resolve(None, None);

        assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
        assert_eq!(resolved.resolved.id, "deepseek-v4-pro");
        assert!(resolved.used_fallback);
        assert!(resolved.requested.is_none());
        assert_eq!(resolved.fallback_chain.len(), 1);
    }
}