1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::config::AgentKind;
6
7#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
12pub struct ModelEntry {
13 #[serde(default)]
14 pub description: String,
15 #[serde(default)]
16 pub provider: String,
17 #[serde(default)]
19 pub claude: Option<String>,
20 #[serde(default)]
22 pub codex: Option<String>,
23 #[serde(default)]
25 pub opencode: Option<String>,
26 #[serde(default)]
28 pub cursor: Option<String>,
29}
30
31impl ModelEntry {
32 pub fn agent_model(&self, kind: AgentKind) -> Option<&str> {
34 match kind {
35 AgentKind::Claude => self.claude.as_deref(),
36 AgentKind::Codex => self.codex.as_deref(),
37 AgentKind::OpenCode => self.opencode.as_deref(),
38 AgentKind::Cursor => self.cursor.as_deref(),
39 }
40 }
41
42 pub fn supported_agents(&self) -> Vec<AgentKind> {
44 let mut agents = Vec::new();
45 if self.claude.is_some() {
46 agents.push(AgentKind::Claude);
47 }
48 if self.codex.is_some() {
49 agents.push(AgentKind::Codex);
50 }
51 if self.opencode.is_some() {
52 agents.push(AgentKind::OpenCode);
53 }
54 if self.cursor.is_some() {
55 agents.push(AgentKind::Cursor);
56 }
57 agents
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
63pub struct ModelRegistry {
64 #[serde(default)]
65 pub models: HashMap<String, ModelEntry>,
66}
67
/// Result of looking up a user-supplied model name for one particular agent.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelResolution {
    /// The name is a registry key and its entry maps the requested agent.
    Resolved {
        /// The registry key that matched.
        canonical_name: String,
        /// The agent-specific model identifier stored in that entry.
        agent_id: String,
    },
    /// The name is a registry key, but its entry has no mapping for the agent.
    NoAgentMapping {
        /// The registry key that matched.
        canonical_name: String,
    },
    /// The name is not a registry key; it is carried through unchanged.
    Passthrough {
        /// The name exactly as it was supplied.
        raw: String,
    },
}
81
82impl ModelResolution {
83 pub fn model_id(&self) -> &str {
85 match self {
86 ModelResolution::Resolved { agent_id, .. } => agent_id,
87 ModelResolution::NoAgentMapping { canonical_name } => canonical_name,
88 ModelResolution::Passthrough { raw } => raw,
89 }
90 }
91}
92
93impl ModelRegistry {
94 pub fn builtin() -> Self {
96 let content = include_str!("../models.toml");
97 match toml::from_str(content) {
98 Ok(reg) => reg,
99 Err(e) => {
100 tracing::warn!("builtin models.toml is malformed: {e}");
101 Self::default()
102 }
103 }
104 }
105
106 pub fn from_toml(content: &str) -> Result<Self, String> {
108 toml::from_str(content).map_err(|e| e.to_string())
109 }
110
111 pub fn merge(&self, overrides: &ModelRegistry) -> ModelRegistry {
113 let mut merged = self.clone();
114 for (name, entry) in &overrides.models {
115 merged.models.insert(name.clone(), entry.clone());
116 }
117 merged
118 }
119
120 pub fn resolve(&self, name: &str, agent: AgentKind) -> ModelResolution {
122 if let Some(entry) = self.models.get(name) {
123 if let Some(agent_id) = entry.agent_model(agent) {
124 ModelResolution::Resolved {
125 canonical_name: name.to_string(),
126 agent_id: agent_id.to_string(),
127 }
128 } else {
129 ModelResolution::NoAgentMapping {
130 canonical_name: name.to_string(),
131 }
132 }
133 } else {
134 ModelResolution::Passthrough {
135 raw: name.to_string(),
136 }
137 }
138 }
139
140 pub fn names(&self) -> Vec<&str> {
142 let mut names: Vec<&str> = self.models.keys().map(|s| s.as_str()).collect();
143 names.sort();
144 names
145 }
146
147 pub fn models_for_agent(&self, agent: AgentKind) -> Vec<(&str, &str)> {
149 let mut result: Vec<(&str, &str)> = self
150 .models
151 .iter()
152 .filter_map(|(name, entry)| {
153 entry
154 .agent_model(agent)
155 .map(|id| (name.as_str(), id))
156 })
157 .collect();
158 result.sort_by_key(|(name, _)| *name);
159 result
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline TOML fixture, failing the test if it does not parse.
    fn parse(src: &str) -> ModelRegistry {
        ModelRegistry::from_toml(src).unwrap()
    }

    /// Loads the builtin registry and returns an owned copy of its `opus` entry.
    fn builtin_opus() -> ModelEntry {
        ModelRegistry::builtin().models.get("opus").cloned().unwrap()
    }

    #[test]
    fn builtin_registry_parses() {
        let registry = ModelRegistry::builtin();
        assert!(!registry.models.is_empty());
        assert!(registry.models.contains_key("opus"));
    }

    #[test]
    fn builtin_opus_has_claude_mapping() {
        let opus = builtin_opus();
        assert_eq!(opus.agent_model(AgentKind::Claude), Some("claude-opus-4-6"));
        assert!(opus.agent_model(AgentKind::Codex).is_none());
    }

    #[test]
    fn builtin_opus_has_multi_agent_mapping() {
        let opus = builtin_opus();
        let expected = [
            (AgentKind::Claude, Some("claude-opus-4-6")),
            (AgentKind::OpenCode, Some("anthropic/claude-opus-4-6")),
            (AgentKind::Cursor, Some("claude-opus-4-6")),
            (AgentKind::Codex, None),
        ];
        for &(agent, want) in expected.iter() {
            assert_eq!(opus.agent_model(agent), want);
        }
    }

    #[test]
    fn resolve_known_model_with_agent() {
        let resolution = ModelRegistry::builtin().resolve("opus", AgentKind::Claude);
        let expected = ModelResolution::Resolved {
            canonical_name: "opus".into(),
            agent_id: "claude-opus-4-6".into(),
        };
        assert_eq!(resolution, expected);
        assert_eq!(resolution.model_id(), "claude-opus-4-6");
    }

    #[test]
    fn resolve_known_model_no_agent_mapping() {
        let resolution = ModelRegistry::builtin().resolve("opus", AgentKind::Codex);
        let expected = ModelResolution::NoAgentMapping {
            canonical_name: "opus".into(),
        };
        assert_eq!(resolution, expected);
        assert_eq!(resolution.model_id(), "opus");
    }

    #[test]
    fn resolve_unknown_model_passthrough() {
        let resolution = ModelRegistry::builtin().resolve("my-custom-model", AgentKind::Claude);
        let expected = ModelResolution::Passthrough {
            raw: "my-custom-model".into(),
        };
        assert_eq!(resolution, expected);
        assert_eq!(resolution.model_id(), "my-custom-model");
    }

    #[test]
    fn from_toml_valid() {
        let registry = parse(
            r#"
[models.test-model]
description = "Test"
provider = "test"
claude = "test-id"
"#,
        );
        assert!(registry.models.contains_key("test-model"));
    }

    #[test]
    fn from_toml_empty() {
        assert!(parse("").models.is_empty());
    }

    #[test]
    fn from_toml_invalid() {
        assert!(ModelRegistry::from_toml("not valid toml {{{{").is_err());
    }

    #[test]
    fn merge_disjoint() {
        let left = parse(
            r#"
[models.a]
description = "A"
provider = "test"
claude = "a-claude"
"#,
        );
        let right = parse(
            r#"
[models.b]
description = "B"
provider = "test"
codex = "b-codex"
"#,
        );
        let merged = left.merge(&right);
        for key in ["a", "b"].iter() {
            assert!(merged.models.contains_key(*key));
        }
    }

    #[test]
    fn merge_override() {
        let base = parse(
            r#"
[models.sonnet]
description = "Original"
provider = "anthropic"
claude = "original-id"
"#,
        );
        let overrides = parse(
            r#"
[models.sonnet]
description = "Custom"
provider = "anthropic"
claude = "custom-id"
"#,
        );
        let merged = base.merge(&overrides);
        let sonnet = merged.models.get("sonnet").unwrap();
        assert_eq!(sonnet.claude.as_deref(), Some("custom-id"));
        assert_eq!(sonnet.description, "Custom");
    }

    #[test]
    fn names_sorted() {
        let registry = ModelRegistry::builtin();
        let names = registry.names();
        assert!(names.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn models_for_agent_filters() {
        let registry = ModelRegistry::builtin();
        let claude_models = registry.models_for_agent(AgentKind::Claude);
        assert!(claude_models.iter().any(|&(name, _)| name == "opus"));
        assert!(registry.models_for_agent(AgentKind::Codex).is_empty());
    }

    #[test]
    fn supported_agents() {
        let entry = ModelEntry {
            description: "test".into(),
            provider: "test".into(),
            claude: Some("c".into()),
            opencode: Some("o".into()),
            ..ModelEntry::default()
        };
        assert_eq!(
            entry.supported_agents(),
            vec![AgentKind::Claude, AgentKind::OpenCode]
        );
    }

    #[test]
    fn model_entry_default() {
        let entry = ModelEntry::default();
        assert!(entry.description.is_empty());
        let slots = [&entry.claude, &entry.codex, &entry.opencode, &entry.cursor];
        assert!(slots.iter().all(|slot| slot.is_none()));
    }
}