Skip to main content

mur_common/
model.rs

//! Named model registry shared by all agents.
//!
//! On disk: `~/.mur/models.yaml`, or `$MUR_HOME/models.yaml` when `MUR_HOME`
//! is set to a non-empty value. Schema:
//!
//! ```yaml
//! schema_version: 1
//! models:
//!   anthropic_opus_4_7:
//!     provider: anthropic
//!     model: claude-opus-4-7
//!     secret: env:ANTHROPIC_API_KEY
//!     capabilities: [chat, tools]
//! roles:
//!   reflector:
//!     primary: anthropic_opus_4_7
//! ```
14
15use crate::route::{RoutePolicy, RouteTier};
16use crate::secret::SecretRef;
17use serde::{Deserialize, Serialize};
18use std::collections::BTreeMap;
19use std::path::{Path, PathBuf};
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22pub struct ModelEntry {
23    pub provider: String,
24    pub model: String,
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub base_url: Option<String>,
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub secret: Option<SecretRef>,
29    #[serde(default, skip_serializing_if = "Vec::is_empty")]
30    pub capabilities: Vec<String>,
31    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
32    pub params: serde_json::Value,
33    /// Routing tier: cheap/local vs frontier/expensive.
34    /// When absent, the router infers based on provider.
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub tier: Option<RouteTier>,
37    /// Estimated USD cost per 1000 output tokens.
38    /// Used for ledger cost estimates.
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub cost_per_1k_tokens: Option<f64>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
44pub struct RoleEntry {
45    /// Registry model ID (key in `models:`) to use as primary.
46    pub primary: String,
47    /// Fallback model ID if primary is unavailable.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub fallback: Option<String>,
50    /// Optional daily cost cap in USD.
51    #[serde(default, skip_serializing_if = "Option::is_none")]
52    pub cost_budget_per_day_usd: Option<f64>,
53    /// If true, only use local models when handling sensitive data.
54    #[serde(default)]
55    pub privacy_local_only: bool,
56    /// Per-role routing policy override.
57    /// When absent, the router uses the default heuristic.
58    #[serde(default, skip_serializing_if = "Option::is_none")]
59    pub route_policy: Option<RoutePolicy>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
63pub struct ModelRegistry {
64    pub schema_version: u32,
65    #[serde(default)]
66    pub models: BTreeMap<String, ModelEntry>,
67    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
68    pub roles: BTreeMap<String, RoleEntry>,
69}
70
71impl Default for ModelRegistry {
72    fn default() -> Self {
73        Self {
74            schema_version: 1,
75            models: BTreeMap::new(),
76            roles: BTreeMap::new(),
77        }
78    }
79}
80
81impl ModelRegistry {
82    pub fn load_from(path: &Path) -> anyhow::Result<Self> {
83        if !path.exists() {
84            return Ok(Self::default());
85        }
86        let body = std::fs::read_to_string(path)?;
87        if body.trim().is_empty() {
88            return Ok(Self::default());
89        }
90        Ok(serde_yaml_ng::from_str(&body)?)
91    }
92
93    pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
94        if let Some(parent) = path.parent() {
95            std::fs::create_dir_all(parent)?;
96        }
97        let body = serde_yaml_ng::to_string(self)?;
98        let tmp = path.with_extension("yaml.tmp");
99        std::fs::write(&tmp, body)?;
100        std::fs::rename(&tmp, path)?;
101        Ok(())
102    }
103
104    pub fn default_path() -> anyhow::Result<PathBuf> {
105        // Honor MUR_HOME (used by test harnesses and Windows CI, where
106        // `dirs::home_dir()` reads SHGetKnownFolderPath and ignores HOME).
107        if let Ok(p) = std::env::var("MUR_HOME")
108            && !p.is_empty()
109        {
110            return Ok(PathBuf::from(p).join("models.yaml"));
111        }
112        let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("no home dir"))?;
113        Ok(home.join(".mur/models.yaml"))
114    }
115
116    /// Return the primary model ID for `role`, or the fallback if the primary
117    /// is not in the `models` map, or `None` if the role is not configured.
118    pub fn resolve_role(&self, role: &str) -> Option<&str> {
119        let entry = self.roles.get(role)?;
120        if self.models.contains_key(&entry.primary) {
121            return Some(&entry.primary);
122        }
123        // primary not in registry — try fallback
124        if let Some(fb) = &entry.fallback
125            && self.models.contains_key(fb)
126        {
127            return Some(fb);
128        }
129        // role configured but no available model
130        None
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry with only `provider`/`model` set and every optional field empty.
    fn bare_entry(provider: &str, model: &str) -> ModelEntry {
        ModelEntry {
            provider: provider.into(),
            model: model.into(),
            base_url: None,
            secret: None,
            capabilities: vec![],
            params: serde_json::Value::Null,
            tier: None,
            cost_per_1k_tokens: None,
        }
    }

    /// Role naming `primary` and optionally `fallback`; other fields defaulted.
    fn role(primary: &str, fallback: Option<&str>) -> RoleEntry {
        RoleEntry {
            primary: primary.into(),
            fallback: fallback.map(str::to_owned),
            ..Default::default()
        }
    }

    #[test]
    fn parses_full_registry() {
        let yaml = r#"
schema_version: 1
models:
  anthropic_opus_4_7:
    provider: anthropic
    model: claude-opus-4-7
    secret: env:ANTHROPIC_API_KEY
    capabilities: [chat, tools]
  ollama_llama3:
    provider: ollama
    model: llama3.2:3b
    base_url: http://127.0.0.1:11434
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.schema_version, 1);
        assert_eq!(r.models.len(), 2);
        let opus = r.models.get("anthropic_opus_4_7").unwrap();
        assert_eq!(opus.provider, "anthropic");
        assert_eq!(
            opus.secret,
            Some(SecretRef::Env("ANTHROPIC_API_KEY".into()))
        );
        assert!(r.models["ollama_llama3"].secret.is_none());
    }

    #[test]
    fn round_trip_preserves_shape() {
        let mut r = ModelRegistry::default();
        r.models.insert(
            "foo".into(),
            ModelEntry {
                secret: Some(SecretRef::Keychain {
                    service: "mur".into(),
                    account: "anthropic".into(),
                }),
                capabilities: vec!["chat".into()],
                ..bare_entry("anthropic", "claude-opus-4-7")
            },
        );
        let s = serde_yaml_ng::to_string(&r).unwrap();
        let parsed: ModelRegistry = serde_yaml_ng::from_str(&s).unwrap();
        assert_eq!(r, parsed);
    }

    #[test]
    fn rejects_unknown_secret_scheme() {
        let yaml = r#"
schema_version: 1
models:
  bad:
    provider: x
    model: y
    secret: bogus:value
"#;
        let r: Result<ModelRegistry, _> = serde_yaml_ng::from_str(yaml);
        assert!(r.is_err(), "should reject unknown scheme");
    }

    #[test]
    fn test_registry_roundtrip_with_roles() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
roles:
  reflector:
    primary: haiku
    fallback: null
    cost_budget_per_day_usd: 0.5
"#;
        let reg: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(reg.roles["reflector"].primary, "haiku");
        let back = serde_yaml_ng::to_string(&reg).unwrap();
        let reg2: ModelRegistry = serde_yaml_ng::from_str(&back).unwrap();
        assert_eq!(reg, reg2);
    }

    #[test]
    fn test_resolve_role_primary() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), bare_entry("anthropic", "claude-haiku-4-5"));
        reg.roles.insert("reflector".into(), role("haiku", None));
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_primary_wins_over_fallback() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), bare_entry("anthropic", "claude-haiku-4-5"));
        reg.models
            .insert("opus".into(), bare_entry("anthropic", "claude-opus-4-7"));
        reg.roles
            .insert("reflector".into(), role("haiku", Some("opus")));
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_fallback() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), bare_entry("anthropic", "claude-haiku-4-5"));
        reg.roles
            .insert("reflector".into(), role("nonexistent", Some("haiku")));
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_unavailable() {
        // Role is configured, but neither primary nor fallback is registered.
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), bare_entry("anthropic", "claude-haiku-4-5"));
        reg.roles
            .insert("reflector".into(), role("missing", Some("also_missing")));
        assert_eq!(reg.resolve_role("reflector"), None);
    }

    #[test]
    fn test_resolve_role_none() {
        let reg = ModelRegistry::default();
        assert_eq!(reg.resolve_role("reflector"), None);
    }

    #[test]
    fn model_entry_parses_tier_field() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
    tier: local
  opus:
    provider: anthropic
    model: claude-opus-4-7
    tier: frontier
    cost_per_1k_tokens: 0.015
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.models["haiku"].tier, Some(RouteTier::Local));
        assert_eq!(r.models["opus"].tier, Some(RouteTier::Frontier));
        assert_eq!(r.models["opus"].cost_per_1k_tokens, Some(0.015));
        // Missing tier is None and is not serialized.
        let mut r2 = ModelRegistry::default();
        r2.models.insert("x".into(), bare_entry("ollama", "llama3"));
        let yaml = serde_yaml_ng::to_string(&r2).unwrap();
        assert!(
            !yaml.contains("tier:"),
            "absent tier should not be serialized: {yaml}"
        );
    }

    #[test]
    fn role_entry_parses_route_policy() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
  opus:
    provider: anthropic
    model: claude-opus-4-7
roles:
  dev:
    primary: opus
    route_policy: !force_frontier
      model_id: opus
  reflector:
    primary: haiku
    route_policy: prefer_local
  curator:
    primary: haiku
    route_policy: force_local
  chat:
    primary: haiku
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(
            r.roles["dev"].route_policy,
            Some(RoutePolicy::ForceFrontier {
                model_id: "opus".into()
            })
        );
        assert_eq!(
            r.roles["reflector"].route_policy,
            Some(RoutePolicy::PreferLocal)
        );
        assert_eq!(
            r.roles["curator"].route_policy,
            Some(RoutePolicy::ForceLocal)
        );
        assert_eq!(r.roles["chat"].route_policy, None);
    }
}
368
#[cfg(test)]
mod io_tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn load_returns_empty_when_file_missing() {
        let dir = tempdir().unwrap();
        let r = ModelRegistry::load_from(&dir.path().join("nope.yaml")).unwrap();
        assert_eq!(r.models.len(), 0);
        assert_eq!(r.schema_version, 1);
    }

    #[test]
    fn load_returns_empty_when_file_blank() {
        // Whitespace-only content is treated like a missing file.
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        std::fs::write(&p, "  \n\t\n").unwrap();
        let r = ModelRegistry::load_from(&p).unwrap();
        assert_eq!(r, ModelRegistry::default());
    }

    #[test]
    fn load_rejects_malformed_yaml() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        std::fs::write(&p, "schema_version: [unterminated\n").unwrap();
        assert!(ModelRegistry::load_from(&p).is_err());
    }

    #[test]
    fn save_then_load_round_trips() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        let mut r = ModelRegistry::default();
        r.models.insert(
            "x".into(),
            ModelEntry {
                provider: "ollama".into(),
                model: "llama3.2:3b".into(),
                base_url: None,
                secret: None,
                capabilities: vec![],
                params: serde_json::Value::Null,
                tier: None,
                cost_per_1k_tokens: None,
            },
        );
        r.save_to(&p).unwrap();
        let r2 = ModelRegistry::load_from(&p).unwrap();
        assert_eq!(r, r2);
    }

    #[test]
    fn save_creates_missing_parent_dirs() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("nested").join("deeper").join("models.yaml");
        ModelRegistry::default().save_to(&p).unwrap();
        assert_eq!(
            ModelRegistry::load_from(&p).unwrap(),
            ModelRegistry::default()
        );
    }

    #[test]
    fn save_uses_atomic_rename() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        ModelRegistry::default().save_to(&p).unwrap();
        let temp = dir.path().join("models.yaml.tmp");
        assert!(!temp.exists(), "atomic temp left behind");
    }
}