Skip to main content

mur_common/
model.rs

1//! Named model registry shared by all agents.
2//!
3//! On disk: `~/.mur/models.yaml`. Schema:
4//!
5//! ```yaml
6//! schema_version: 1
7//! models:
8//!   anthropic_opus_4_7:
9//!     provider: anthropic
10//!     model: claude-opus-4-7
11//!     secret: env:ANTHROPIC_API_KEY
12//!     capabilities: [chat, tools]
13//! ```
14
15use crate::secret::SecretRef;
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeMap;
18use std::path::{Path, PathBuf};
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub struct ModelEntry {
22    pub provider: String,
23    pub model: String,
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub base_url: Option<String>,
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub secret: Option<SecretRef>,
28    #[serde(default, skip_serializing_if = "Vec::is_empty")]
29    pub capabilities: Vec<String>,
30    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
31    pub params: serde_json::Value,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
35pub struct RoleEntry {
36    /// Registry model ID (key in `models:`) to use as primary.
37    pub primary: String,
38    /// Fallback model ID if primary is unavailable.
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub fallback: Option<String>,
41    /// Optional daily cost cap in USD.
42    #[serde(default, skip_serializing_if = "Option::is_none")]
43    pub cost_budget_per_day_usd: Option<f64>,
44    /// If true, only use local models when handling sensitive data.
45    #[serde(default)]
46    pub privacy_local_only: bool,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
50pub struct ModelRegistry {
51    pub schema_version: u32,
52    #[serde(default)]
53    pub models: BTreeMap<String, ModelEntry>,
54    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
55    pub roles: BTreeMap<String, RoleEntry>,
56}
57
58impl Default for ModelRegistry {
59    fn default() -> Self {
60        Self {
61            schema_version: 1,
62            models: BTreeMap::new(),
63            roles: BTreeMap::new(),
64        }
65    }
66}
67
68impl ModelRegistry {
69    pub fn load_from(path: &Path) -> anyhow::Result<Self> {
70        if !path.exists() {
71            return Ok(Self::default());
72        }
73        let body = std::fs::read_to_string(path)?;
74        if body.trim().is_empty() {
75            return Ok(Self::default());
76        }
77        Ok(serde_yaml_ng::from_str(&body)?)
78    }
79
80    pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
81        if let Some(parent) = path.parent() {
82            std::fs::create_dir_all(parent)?;
83        }
84        let body = serde_yaml_ng::to_string(self)?;
85        let tmp = path.with_extension("yaml.tmp");
86        std::fs::write(&tmp, body)?;
87        std::fs::rename(&tmp, path)?;
88        Ok(())
89    }
90
91    pub fn default_path() -> anyhow::Result<PathBuf> {
92        // Honor MUR_HOME (used by test harnesses and Windows CI, where
93        // `dirs::home_dir()` reads SHGetKnownFolderPath and ignores HOME).
94        if let Ok(p) = std::env::var("MUR_HOME")
95            && !p.is_empty()
96        {
97            return Ok(PathBuf::from(p).join("models.yaml"));
98        }
99        let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("no home dir"))?;
100        Ok(home.join(".mur/models.yaml"))
101    }
102
103    /// Return the primary model ID for `role`, or the fallback if the primary
104    /// is not in the `models` map, or `None` if the role is not configured.
105    pub fn resolve_role(&self, role: &str) -> Option<&str> {
106        let entry = self.roles.get(role)?;
107        if self.models.contains_key(&entry.primary) {
108            return Some(&entry.primary);
109        }
110        // primary not in registry — try fallback
111        if let Some(fb) = &entry.fallback
112            && self.models.contains_key(fb)
113        {
114            return Some(fb);
115        }
116        // role configured but no available model
117        None
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ModelEntry` with only `provider` and `model` set.
    fn bare_entry(provider: &str, model: &str) -> ModelEntry {
        ModelEntry {
            provider: provider.into(),
            model: model.into(),
            base_url: None,
            secret: None,
            capabilities: Vec::new(),
            params: serde_json::Value::Null,
        }
    }

    /// A default registry containing the single model `haiku`.
    fn registry_with_haiku() -> ModelRegistry {
        let mut registry = ModelRegistry::default();
        registry
            .models
            .insert("haiku".into(), bare_entry("anthropic", "claude-haiku-4-5"));
        registry
    }

    /// Deserializing a two-model document fills in both entries.
    #[test]
    fn parses_full_registry() {
        let yaml = r#"
schema_version: 1
models:
  anthropic_opus_4_7:
    provider: anthropic
    model: claude-opus-4-7
    secret: env:ANTHROPIC_API_KEY
    capabilities: [chat, tools]
  ollama_llama3:
    provider: ollama
    model: llama3.2:3b
    base_url: http://127.0.0.1:11434
"#;
        let registry: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(registry.schema_version, 1);
        assert_eq!(registry.models.len(), 2);

        let opus = &registry.models["anthropic_opus_4_7"];
        assert_eq!(opus.provider, "anthropic");
        let expected_secret = Some(SecretRef::Env("ANTHROPIC_API_KEY".into()));
        assert_eq!(opus.secret, expected_secret);

        assert_eq!(registry.models["ollama_llama3"].secret, None);
    }

    /// Serializing and re-parsing a registry yields an equal value.
    #[test]
    fn round_trip_preserves_shape() {
        let keychain = SecretRef::Keychain {
            service: "mur".into(),
            account: "anthropic".into(),
        };
        let entry = ModelEntry {
            secret: Some(keychain),
            capabilities: vec!["chat".into()],
            ..bare_entry("anthropic", "claude-opus-4-7")
        };
        let mut original = ModelRegistry::default();
        original.models.insert("foo".into(), entry);

        let text = serde_yaml_ng::to_string(&original).unwrap();
        let reparsed: ModelRegistry = serde_yaml_ng::from_str(&text).unwrap();
        assert_eq!(original, reparsed);
    }

    /// A secret with an unrecognized scheme fails deserialization.
    #[test]
    fn rejects_unknown_secret_scheme() {
        let yaml = r#"
schema_version: 1
models:
  bad:
    provider: x
    model: y
    secret: bogus:value
"#;
        let outcome = serde_yaml_ng::from_str::<ModelRegistry>(yaml);
        assert!(outcome.is_err(), "should reject unknown scheme");
    }

    /// A document with a `roles:` section survives a serialize/parse cycle.
    #[test]
    fn test_registry_roundtrip_with_roles() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
roles:
  reflector:
    primary: haiku
    fallback: null
    cost_budget_per_day_usd: 0.5
"#;
        let first: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(first.roles["reflector"].primary, "haiku");

        let text = serde_yaml_ng::to_string(&first).unwrap();
        let second: ModelRegistry = serde_yaml_ng::from_str(&text).unwrap();
        assert_eq!(first, second);
    }

    /// The primary is returned when it is a registered model.
    #[test]
    fn test_resolve_role_primary() {
        let mut registry = registry_with_haiku();
        let role = RoleEntry {
            primary: "haiku".into(),
            fallback: None,
            ..Default::default()
        };
        registry.roles.insert("reflector".into(), role);
        assert_eq!(registry.resolve_role("reflector"), Some("haiku"));
    }

    /// The fallback is returned when the primary is not registered.
    #[test]
    fn test_resolve_role_fallback() {
        let mut registry = registry_with_haiku();
        let role = RoleEntry {
            primary: "nonexistent".into(),
            fallback: Some("haiku".into()),
            ..Default::default()
        };
        registry.roles.insert("reflector".into(), role);
        assert_eq!(registry.resolve_role("reflector"), Some("haiku"));
    }

    /// An unconfigured role resolves to nothing.
    #[test]
    fn test_resolve_role_none() {
        let registry = ModelRegistry::default();
        assert_eq!(registry.resolve_role("reflector"), None);
    }
}
265
#[cfg(test)]
mod io_tests {
    use super::*;
    use tempfile::tempdir;

    /// Loading a path that does not exist yields the default registry.
    #[test]
    fn load_returns_empty_when_file_missing() {
        let scratch = tempdir().unwrap();
        let missing = scratch.path().join("nope.yaml");
        let loaded = ModelRegistry::load_from(&missing).unwrap();
        assert!(loaded.models.is_empty());
        assert_eq!(loaded.schema_version, 1);
    }

    /// What `save_to` writes, `load_from` reads back unchanged.
    #[test]
    fn save_then_load_round_trips() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("models.yaml");

        let entry = ModelEntry {
            provider: "ollama".into(),
            model: "llama3.2:3b".into(),
            base_url: None,
            secret: None,
            capabilities: Vec::new(),
            params: serde_json::Value::Null,
        };
        let mut saved = ModelRegistry::default();
        saved.models.insert("x".into(), entry);

        saved.save_to(&target).unwrap();
        let reloaded = ModelRegistry::load_from(&target).unwrap();
        assert_eq!(saved, reloaded);
    }

    /// After a successful save the temporary file is gone.
    #[test]
    fn save_uses_atomic_rename() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("models.yaml");
        ModelRegistry::default().save_to(&target).unwrap();

        let leftover = scratch.path().join("models.yaml.tmp");
        assert!(!leftover.exists(), "atomic temp left behind");
    }
}