1use std::collections::HashMap;
2
3use deepseek_config::ProviderKind;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelInfo {
8 pub id: String,
9 pub provider: ProviderKind,
10 pub aliases: Vec<String>,
11 pub supports_tools: bool,
12 pub supports_reasoning: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelResolution {
17 pub requested: Option<String>,
18 pub resolved: ModelInfo,
19 pub used_fallback: bool,
20 pub fallback_chain: Vec<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct ModelRegistry {
25 models: Vec<ModelInfo>,
26 alias_map: HashMap<String, usize>,
27}
28
29impl Default for ModelRegistry {
30 fn default() -> Self {
31 let models = vec![
32 ModelInfo {
33 id: "deepseek-v4-pro".to_string(),
34 provider: ProviderKind::Deepseek,
35 aliases: vec![],
36 supports_tools: true,
37 supports_reasoning: true,
38 },
39 ModelInfo {
40 id: "deepseek-v4-flash".to_string(),
41 provider: ProviderKind::Deepseek,
42 aliases: vec![
43 "deepseek-chat".to_string(),
44 "deepseek-reasoner".to_string(),
45 "deepseek-r1".to_string(),
46 "deepseek-v3".to_string(),
47 "deepseek-v3.2".to_string(),
48 ],
49 supports_tools: true,
50 supports_reasoning: true,
51 },
52 ModelInfo {
53 id: "deepseek-ai/deepseek-v4-pro".to_string(),
54 provider: ProviderKind::NvidiaNim,
55 aliases: vec![
56 "deepseek-v4-pro".to_string(),
57 "nvidia-deepseek-v4-pro".to_string(),
58 "nim-deepseek-v4-pro".to_string(),
59 ],
60 supports_tools: true,
61 supports_reasoning: true,
62 },
63 ModelInfo {
64 id: "deepseek-ai/deepseek-v4-flash".to_string(),
65 provider: ProviderKind::NvidiaNim,
66 aliases: vec![
67 "deepseek-v4-flash".to_string(),
68 "deepseek-chat".to_string(),
69 "deepseek-reasoner".to_string(),
70 "nvidia-deepseek-v4-flash".to_string(),
71 "nim-deepseek-v4-flash".to_string(),
72 ],
73 supports_tools: true,
74 supports_reasoning: true,
75 },
76 ModelInfo {
77 id: "gpt-4.1".to_string(),
78 provider: ProviderKind::Openai,
79 aliases: vec!["gpt4.1".to_string(), "gpt-4o".to_string()],
80 supports_tools: true,
81 supports_reasoning: true,
82 },
83 ModelInfo {
84 id: "gpt-4.1-mini".to_string(),
85 provider: ProviderKind::Openai,
86 aliases: vec!["gpt-4o-mini".to_string()],
87 supports_tools: true,
88 supports_reasoning: false,
89 },
90 ];
91 Self::new(models)
92 }
93}
94
95impl ModelRegistry {
96 #[must_use]
97 pub fn new(models: Vec<ModelInfo>) -> Self {
98 let mut alias_map = HashMap::new();
99 for (idx, model) in models.iter().enumerate() {
100 alias_map.entry(normalize(&model.id)).or_insert(idx);
101 for alias in &model.aliases {
102 alias_map.entry(normalize(alias)).or_insert(idx);
103 }
104 }
105 Self { models, alias_map }
106 }
107
108 #[must_use]
109 pub fn list(&self) -> Vec<ModelInfo> {
110 self.models.clone()
111 }
112
113 #[must_use]
114 pub fn resolve(
115 &self,
116 requested: Option<&str>,
117 provider_hint: Option<ProviderKind>,
118 ) -> ModelResolution {
119 let mut fallback_chain = Vec::new();
120
121 if let Some(name) = requested {
122 fallback_chain.push(format!("requested:{name}"));
123 if let Some(provider) = provider_hint
124 && let Some(model) = self
125 .models
126 .iter()
127 .find(|m| m.provider == provider && model_matches(m, name))
128 .cloned()
129 {
130 return ModelResolution {
131 requested: Some(name.to_string()),
132 resolved: model,
133 used_fallback: false,
134 fallback_chain,
135 };
136 }
137 if let Some(idx) = self.alias_map.get(&normalize(name)) {
138 return ModelResolution {
139 requested: Some(name.to_string()),
140 resolved: self.models[*idx].clone(),
141 used_fallback: false,
142 fallback_chain,
143 };
144 }
145 }
146
147 let provider = provider_hint.unwrap_or(ProviderKind::Deepseek);
148 fallback_chain.push(format!("provider_default:{}", provider.as_str()));
149 if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() {
150 return ModelResolution {
151 requested: requested.map(ToOwned::to_owned),
152 resolved: model,
153 used_fallback: true,
154 fallback_chain,
155 };
156 }
157
158 let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo {
159 id: "deepseek-v4-pro".to_string(),
160 provider: ProviderKind::Deepseek,
161 aliases: Vec::new(),
162 supports_tools: true,
163 supports_reasoning: true,
164 });
165 fallback_chain.push("global_default:deepseek-v4-pro".to_string());
166 ModelResolution {
167 requested: requested.map(ToOwned::to_owned),
168 resolved: final_fallback,
169 used_fallback: true,
170 fallback_chain,
171 }
172 }
173}
174
/// Canonical form used for every name comparison: surrounding whitespace is
/// removed and ASCII letters are lowercased. Non-ASCII characters and interior
/// whitespace are left untouched.
fn normalize(value: &str) -> String {
    let mut canonical = value.trim().to_owned();
    canonical.make_ascii_lowercase();
    canonical
}
178
179fn model_matches(model: &ModelInfo, requested: &str) -> bool {
180 let requested = normalize(requested);
181 normalize(&model.id) == requested
182 || model
183 .aliases
184 .iter()
185 .any(|alias| normalize(alias) == requested)
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a provider hint, the bare name goes to the native DeepSeek
    /// entry even though the NVIDIA NIM entry lists it as an alias.
    #[test]
    fn deepseek_v4_pro_alias_stays_deepseek_by_default() {
        let model = ModelRegistry::default()
            .resolve(Some("deepseek-v4-pro"), None)
            .resolved;

        assert_eq!(model.provider, ProviderKind::Deepseek);
        assert_eq!(model.id, "deepseek-v4-pro");
    }

    /// With an NVIDIA NIM hint, the same bare name maps to the NIM catalog id.
    #[test]
    fn deepseek_v4_pro_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let model = ModelRegistry::default()
            .resolve(Some("deepseek-v4-pro"), Some(ProviderKind::NvidiaNim))
            .resolved;

        assert_eq!(model.provider, ProviderKind::NvidiaNim);
        assert_eq!(model.id, "deepseek-ai/deepseek-v4-pro");
    }

    /// With no requested name, the provider default is the first NIM entry.
    #[test]
    fn nvidia_nim_default_uses_catalog_model_id() {
        let model = ModelRegistry::default()
            .resolve(None, Some(ProviderKind::NvidiaNim))
            .resolved;

        assert_eq!(model.provider, ProviderKind::NvidiaNim);
        assert_eq!(model.id, "deepseek-ai/deepseek-v4-pro");
    }

    /// The flash name follows the same provider-hint rule as the pro name.
    #[test]
    fn deepseek_v4_flash_alias_resolves_to_nvidia_nim_when_provider_hinted() {
        let model = ModelRegistry::default()
            .resolve(Some("deepseek-v4-flash"), Some(ProviderKind::NvidiaNim))
            .resolved;

        assert_eq!(model.provider, ProviderKind::NvidiaNim);
        assert_eq!(model.id, "deepseek-ai/deepseek-v4-flash");
    }
}