Skip to main content

harness/
models.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::config::AgentKind;
6
7/// A single model entry in the registry.
8///
9/// Contains metadata (`description`, `provider`) and per-agent model ID mappings.
10/// Not every agent needs a mapping — only those that support the model.
11#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
12pub struct ModelEntry {
13    #[serde(default)]
14    pub description: String,
15    #[serde(default)]
16    pub provider: String,
17    /// Model ID for Claude Code (e.g. `claude-sonnet-4-5-20250929`).
18    #[serde(default)]
19    pub claude: Option<String>,
20    /// Model ID for Codex CLI.
21    #[serde(default)]
22    pub codex: Option<String>,
23    /// Model ID for OpenCode CLI.
24    #[serde(default)]
25    pub opencode: Option<String>,
26    /// Model ID for Cursor CLI.
27    #[serde(default)]
28    pub cursor: Option<String>,
29}
30
31impl ModelEntry {
32    /// Get the model ID string for the given agent, if this model supports that agent.
33    pub fn agent_model(&self, kind: AgentKind) -> Option<&str> {
34        match kind {
35            AgentKind::Claude => self.claude.as_deref(),
36            AgentKind::Codex => self.codex.as_deref(),
37            AgentKind::OpenCode => self.opencode.as_deref(),
38            AgentKind::Cursor => self.cursor.as_deref(),
39        }
40    }
41
42    /// Return all agent kinds that have a mapping in this entry.
43    pub fn supported_agents(&self) -> Vec<AgentKind> {
44        let mut agents = Vec::new();
45        if self.claude.is_some() {
46            agents.push(AgentKind::Claude);
47        }
48        if self.codex.is_some() {
49            agents.push(AgentKind::Codex);
50        }
51        if self.opencode.is_some() {
52            agents.push(AgentKind::OpenCode);
53        }
54        if self.cursor.is_some() {
55            agents.push(AgentKind::Cursor);
56        }
57        agents
58    }
59}
60
61/// The model registry — a map from canonical names to model entries.
62#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
63pub struct ModelRegistry {
64    #[serde(default)]
65    pub models: HashMap<String, ModelEntry>,
66}
67
/// Result of looking a model name up in the registry for one particular agent.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelResolution {
    /// The name is a registry entry, and that entry has an ID for the requested agent.
    Resolved {
        /// Registry key that the lookup matched.
        canonical_name: String,
        /// Agent-specific model ID taken from the entry.
        agent_id: String,
    },
    /// The name is a registry entry, but the entry has no ID for the requested agent.
    NoAgentMapping {
        /// Registry key that the lookup matched.
        canonical_name: String,
    },
    /// The name is not in the registry and is carried through unchanged.
    Passthrough {
        /// The name exactly as the caller supplied it.
        raw: String,
    },
}
81
82impl ModelResolution {
83    /// Return the model ID string to pass to the agent CLI.
84    pub fn model_id(&self) -> &str {
85        match self {
86            ModelResolution::Resolved { agent_id, .. } => agent_id,
87            ModelResolution::NoAgentMapping { canonical_name } => canonical_name,
88            ModelResolution::Passthrough { raw } => raw,
89        }
90    }
91}
92
93impl ModelRegistry {
94    /// Parse the builtin models.toml compiled into the binary.
95    pub fn builtin() -> Self {
96        let content = include_str!("../models.toml");
97        match toml::from_str(content) {
98            Ok(reg) => reg,
99            Err(e) => {
100                tracing::warn!("builtin models.toml is malformed: {e}");
101                Self::default()
102            }
103        }
104    }
105
106    /// Parse a TOML string into a registry.
107    pub fn from_toml(content: &str) -> Result<Self, String> {
108        toml::from_str(content).map_err(|e| e.to_string())
109    }
110
111    /// Merge another registry into this one. `overrides` wins on conflicts.
112    pub fn merge(&self, overrides: &ModelRegistry) -> ModelRegistry {
113        let mut merged = self.clone();
114        for (name, entry) in &overrides.models {
115            merged.models.insert(name.clone(), entry.clone());
116        }
117        merged
118    }
119
120    /// Look up a model name and resolve it for the given agent.
121    pub fn resolve(&self, name: &str, agent: AgentKind) -> ModelResolution {
122        if let Some(entry) = self.models.get(name) {
123            if let Some(agent_id) = entry.agent_model(agent) {
124                ModelResolution::Resolved {
125                    canonical_name: name.to_string(),
126                    agent_id: agent_id.to_string(),
127                }
128            } else {
129                ModelResolution::NoAgentMapping {
130                    canonical_name: name.to_string(),
131                }
132            }
133        } else {
134            ModelResolution::Passthrough {
135                raw: name.to_string(),
136            }
137        }
138    }
139
140    /// Return all canonical model names, sorted.
141    pub fn names(&self) -> Vec<&str> {
142        let mut names: Vec<&str> = self.models.keys().map(|s| s.as_str()).collect();
143        names.sort();
144        names
145    }
146
147    /// Return models that have a mapping for the given agent.
148    pub fn models_for_agent(&self, agent: AgentKind) -> Vec<(&str, &str)> {
149        let mut result: Vec<(&str, &str)> = self
150            .models
151            .iter()
152            .filter_map(|(name, entry)| {
153                entry
154                    .agent_model(agent)
155                    .map(|id| (name.as_str(), id))
156            })
157            .collect();
158        result.sort_by_key(|(name, _)| *name);
159        result
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an inline TOML fixture, failing the test if it does not parse.
    fn parse(src: &str) -> ModelRegistry {
        ModelRegistry::from_toml(src).unwrap()
    }

    // ---- builtin registry -------------------------------------------------

    #[test]
    fn builtin_registry_parses() {
        let builtin = ModelRegistry::builtin();
        assert!(!builtin.models.is_empty());
        assert!(builtin.models.contains_key("opus"));
    }

    #[test]
    fn builtin_opus_has_claude_mapping() {
        let builtin = ModelRegistry::builtin();
        let opus = &builtin.models["opus"];
        assert_eq!(opus.agent_model(AgentKind::Claude), Some("claude-opus-4-6"));
        assert_eq!(opus.agent_model(AgentKind::Codex), None);
    }

    #[test]
    fn builtin_opus_has_multi_agent_mapping() {
        let builtin = ModelRegistry::builtin();
        let opus = &builtin.models["opus"];
        let expected = [
            (AgentKind::Claude, Some("claude-opus-4-6")),
            (AgentKind::OpenCode, Some("anthropic/claude-opus-4-6")),
            (AgentKind::Cursor, Some("claude-opus-4-6")),
            (AgentKind::Codex, None),
        ];
        for (kind, id) in expected.iter().copied() {
            assert_eq!(opus.agent_model(kind), id, "mapping for {:?}", kind);
        }
    }

    // ---- resolution -------------------------------------------------------

    #[test]
    fn resolve_known_model_with_agent() {
        let resolution = ModelRegistry::builtin().resolve("opus", AgentKind::Claude);
        let expected = ModelResolution::Resolved {
            canonical_name: "opus".to_string(),
            agent_id: "claude-opus-4-6".to_string(),
        };
        assert_eq!(resolution, expected);
        assert_eq!(resolution.model_id(), "claude-opus-4-6");
    }

    #[test]
    fn resolve_known_model_no_agent_mapping() {
        let resolution = ModelRegistry::builtin().resolve("opus", AgentKind::Codex);
        let expected = ModelResolution::NoAgentMapping {
            canonical_name: "opus".to_string(),
        };
        assert_eq!(resolution, expected);
        assert_eq!(resolution.model_id(), "opus");
    }

    #[test]
    fn resolve_unknown_model_passthrough() {
        let resolution = ModelRegistry::builtin().resolve("my-custom-model", AgentKind::Claude);
        let expected = ModelResolution::Passthrough {
            raw: "my-custom-model".to_string(),
        };
        assert_eq!(resolution, expected);
        assert_eq!(resolution.model_id(), "my-custom-model");
    }

    // ---- parsing ----------------------------------------------------------

    #[test]
    fn from_toml_valid() {
        let fixture = r#"
[models.test-model]
description = "Test"
provider = "test"
claude = "test-id"
"#;
        assert!(parse(fixture).models.contains_key("test-model"));
    }

    #[test]
    fn from_toml_empty() {
        assert!(parse("").models.is_empty());
    }

    #[test]
    fn from_toml_invalid() {
        assert!(ModelRegistry::from_toml("not valid toml {{{{").is_err());
    }

    // ---- merging ----------------------------------------------------------

    #[test]
    fn merge_disjoint() {
        let left = parse(
            r#"
[models.a]
description = "A"
provider = "test"
claude = "a-claude"
"#,
        );
        let right = parse(
            r#"
[models.b]
description = "B"
provider = "test"
codex = "b-codex"
"#,
        );
        let combined = left.merge(&right);
        for key in ["a", "b"].iter() {
            assert!(combined.models.contains_key(*key), "missing {}", key);
        }
    }

    #[test]
    fn merge_override() {
        let base = parse(
            r#"
[models.sonnet]
description = "Original"
provider = "anthropic"
claude = "original-id"
"#,
        );
        let overrides = parse(
            r#"
[models.sonnet]
description = "Custom"
provider = "anthropic"
claude = "custom-id"
"#,
        );
        let combined = base.merge(&overrides);
        let sonnet = &combined.models["sonnet"];
        assert_eq!(sonnet.claude.as_deref(), Some("custom-id"));
        assert_eq!(sonnet.description, "Custom");
    }

    // ---- listing ----------------------------------------------------------

    #[test]
    fn names_sorted() {
        let builtin = ModelRegistry::builtin();
        let names = builtin.names();
        assert!(names.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn models_for_agent_filters() {
        let builtin = ModelRegistry::builtin();
        // Claude should list opus.
        let for_claude = builtin.models_for_agent(AgentKind::Claude);
        assert!(for_claude.iter().any(|&(name, _)| name == "opus"));
        // Codex should list nothing (opus has no codex mapping).
        assert!(builtin.models_for_agent(AgentKind::Codex).is_empty());
    }

    // ---- ModelEntry -------------------------------------------------------

    #[test]
    fn supported_agents() {
        let entry = ModelEntry {
            description: "test".into(),
            provider: "test".into(),
            claude: Some("c".into()),
            opencode: Some("o".into()),
            ..ModelEntry::default()
        };
        assert_eq!(
            entry.supported_agents(),
            vec![AgentKind::Claude, AgentKind::OpenCode]
        );
    }

    #[test]
    fn model_entry_default() {
        let blank = ModelEntry::default();
        assert!(blank.description.is_empty());
        assert!(blank.claude.is_none());
        assert!(blank.codex.is_none());
        assert!(blank.opencode.is_none());
        assert!(blank.cursor.is_none());
    }
}