// mur_common/model.rs

//! Named model registry shared by all agents.
//!
//! On disk: `~/.mur/models.yaml`, or `$MUR_HOME/models.yaml` when `MUR_HOME`
//! is set. Schema:
//!
//! ```yaml
//! schema_version: 1
//! models:
//!   anthropic_opus_4_7:
//!     provider: anthropic
//!     model: claude-opus-4-7
//!     secret: env:ANTHROPIC_API_KEY
//!     capabilities: [chat, tools]
//! ```

use crate::route::{RoutePolicy, RouteTier};
use crate::secret::SecretRef;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};

21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
22pub struct ModelEntry {
23    #[serde(default)]
24    pub provider: String,
25    #[serde(default)]
26    pub model: String,
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub base_url: Option<String>,
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub secret: Option<SecretRef>,
31    #[serde(default, skip_serializing_if = "Vec::is_empty")]
32    pub capabilities: Vec<String>,
33    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
34    pub params: serde_json::Value,
35    /// Routing tier: cheap/local vs frontier/expensive.
36    /// When absent, the router infers based on provider.
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub tier: Option<RouteTier>,
39    /// Estimated USD cost per 1000 output tokens.
40    /// Used for ledger cost estimates.
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub cost_per_1k_tokens: Option<f64>,
43    /// Estimated USD cost per 1000 input tokens.
44    /// New field for split input/output cost tracking.
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    pub input_cost_per_1k: Option<f64>,
47    /// Estimated USD cost per 1000 output tokens.
48    /// New field for split input/output cost tracking.
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub output_cost_per_1k: Option<f64>,
51    /// Model context window size in tokens.
52    #[serde(default, skip_serializing_if = "Option::is_none")]
53    pub context_window: Option<u64>,
54}
55
56impl ModelEntry {
57    /// Resolve effective per-1k rates as `(input, output)`.
58    ///
59    /// The deprecated `cost_per_1k_tokens` is treated as the output rate and
60    /// also as the input fallback, so legacy single-rate entries keep working.
61    pub fn effective_costs(&self) -> (Option<f64>, Option<f64>) {
62        let output = self.output_cost_per_1k.or(self.cost_per_1k_tokens);
63        let input = self.input_cost_per_1k.or(self.cost_per_1k_tokens);
64        (input, output)
65    }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
69pub struct RoleEntry {
70    /// Registry model ID (key in `models:`) to use as primary.
71    pub primary: String,
72    /// Fallback model ID if primary is unavailable.
73    #[serde(default, skip_serializing_if = "Option::is_none")]
74    pub fallback: Option<String>,
75    /// Optional daily cost cap in USD.
76    #[serde(default, skip_serializing_if = "Option::is_none")]
77    pub cost_budget_per_day_usd: Option<f64>,
78    /// If true, only use local models when handling sensitive data.
79    #[serde(default)]
80    pub privacy_local_only: bool,
81    /// Per-role routing policy override.
82    /// When absent, the router uses the default heuristic.
83    #[serde(default, skip_serializing_if = "Option::is_none")]
84    pub route_policy: Option<RoutePolicy>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
88pub struct ModelRegistry {
89    pub schema_version: u32,
90    #[serde(default)]
91    pub models: BTreeMap<String, ModelEntry>,
92    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
93    pub roles: BTreeMap<String, RoleEntry>,
94}
95
96impl Default for ModelRegistry {
97    fn default() -> Self {
98        Self {
99            schema_version: 1,
100            models: BTreeMap::new(),
101            roles: BTreeMap::new(),
102        }
103    }
104}
105
106impl ModelRegistry {
107    pub fn load_from(path: &Path) -> anyhow::Result<Self> {
108        if !path.exists() {
109            return Ok(Self::default());
110        }
111        let body = std::fs::read_to_string(path)?;
112        if body.trim().is_empty() {
113            return Ok(Self::default());
114        }
115        Ok(serde_yaml_ng::from_str(&body)?)
116    }
117
118    pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
119        if let Some(parent) = path.parent() {
120            std::fs::create_dir_all(parent)?;
121        }
122        let body = serde_yaml_ng::to_string(self)?;
123        let tmp = path.with_extension("yaml.tmp");
124        std::fs::write(&tmp, body)?;
125        std::fs::rename(&tmp, path)?;
126        Ok(())
127    }
128
129    pub fn default_path() -> anyhow::Result<PathBuf> {
130        // Honor MUR_HOME (used by test harnesses and Windows CI, where
131        // `dirs::home_dir()` reads SHGetKnownFolderPath and ignores HOME).
132        if let Ok(p) = std::env::var("MUR_HOME")
133            && !p.is_empty()
134        {
135            return Ok(PathBuf::from(p).join("models.yaml"));
136        }
137        let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("no home dir"))?;
138        Ok(home.join(".mur/models.yaml"))
139    }
140
141    /// Return the primary model ID for `role`, or the fallback if the primary
142    /// is not in the `models` map, or `None` if the role is not configured.
143    pub fn resolve_role(&self, role: &str) -> Option<&str> {
144        let entry = self.roles.get(role)?;
145        if self.models.contains_key(&entry.primary) {
146            return Some(&entry.primary);
147        }
148        // primary not in registry — try fallback
149        if let Some(fb) = &entry.fallback
150            && self.models.contains_key(fb)
151        {
152            return Some(fb);
153        }
154        // role configured but no available model
155        None
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry with only `provider` and `model` set. Every other field keeps its
    /// `Default` value, so adding a field to `ModelEntry` does not require
    /// touching each test.
    fn entry(provider: &str, model: &str) -> ModelEntry {
        ModelEntry {
            provider: provider.into(),
            model: model.into(),
            ..Default::default()
        }
    }

    #[test]
    fn parses_full_registry() {
        let yaml = r#"
schema_version: 1
models:
  anthropic_opus_4_7:
    provider: anthropic
    model: claude-opus-4-7
    secret: env:ANTHROPIC_API_KEY
    capabilities: [chat, tools]
  ollama_llama3:
    provider: ollama
    model: llama3.2:3b
    base_url: http://127.0.0.1:11434
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.schema_version, 1);
        assert_eq!(r.models.len(), 2);
        let opus = r.models.get("anthropic_opus_4_7").unwrap();
        assert_eq!(opus.provider, "anthropic");
        assert_eq!(
            opus.secret,
            Some(SecretRef::Env("ANTHROPIC_API_KEY".into()))
        );
        assert!(r.models["ollama_llama3"].secret.is_none());
    }

    #[test]
    fn round_trip_preserves_shape() {
        let mut r = ModelRegistry::default();
        r.models.insert(
            "foo".into(),
            ModelEntry {
                secret: Some(SecretRef::Keychain {
                    service: "mur".into(),
                    account: "anthropic".into(),
                }),
                capabilities: vec!["chat".into()],
                ..entry("anthropic", "claude-opus-4-7")
            },
        );
        let s = serde_yaml_ng::to_string(&r).unwrap();
        let parsed: ModelRegistry = serde_yaml_ng::from_str(&s).unwrap();
        assert_eq!(r, parsed);
    }

    #[test]
    fn rejects_unknown_secret_scheme() {
        let yaml = r#"
schema_version: 1
models:
  bad:
    provider: x
    model: y
    secret: bogus:value
"#;
        let r: Result<ModelRegistry, _> = serde_yaml_ng::from_str(yaml);
        assert!(r.is_err(), "should reject unknown scheme");
    }

    #[test]
    fn test_registry_roundtrip_with_roles() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
roles:
  reflector:
    primary: haiku
    fallback: null
    cost_budget_per_day_usd: 0.5
"#;
        let reg: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(reg.roles["reflector"].primary, "haiku");
        let back = serde_yaml_ng::to_string(&reg).unwrap();
        let reg2: ModelRegistry = serde_yaml_ng::from_str(&back).unwrap();
        assert_eq!(reg, reg2);
    }

    #[test]
    fn test_resolve_role_primary() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), entry("anthropic", "claude-haiku-4-5"));
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "haiku".into(),
                fallback: None,
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_fallback() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), entry("anthropic", "claude-haiku-4-5"));
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "nonexistent".into(),
                fallback: Some("haiku".into()),
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_none() {
        let reg = ModelRegistry::default();
        assert_eq!(reg.resolve_role("reflector"), None);
    }

    #[test]
    fn model_entry_parses_tier_field() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
    tier: local
  opus:
    provider: anthropic
    model: claude-opus-4-7
    tier: frontier
    cost_per_1k_tokens: 0.015
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.models["haiku"].tier, Some(RouteTier::Local));
        assert_eq!(r.models["opus"].tier, Some(RouteTier::Frontier));
        assert_eq!(r.models["opus"].cost_per_1k_tokens, Some(0.015));
        // Missing tier is None.
        let mut r2 = ModelRegistry::default();
        r2.models.insert("x".into(), entry("ollama", "llama3"));
        let yaml = serde_yaml_ng::to_string(&r2).unwrap();
        assert!(
            !yaml.contains("tier:"),
            "absent tier should not be serialized: {yaml}"
        );
    }

    #[test]
    fn role_entry_parses_route_policy() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
  opus:
    provider: anthropic
    model: claude-opus-4-7
roles:
  dev:
    primary: opus
    route_policy: !force_frontier
      model_id: opus
  reflector:
    primary: haiku
    route_policy: prefer_local
  curator:
    primary: haiku
    route_policy: force_local
  chat:
    primary: haiku
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(
            r.roles["dev"].route_policy,
            Some(RoutePolicy::ForceFrontier {
                model_id: "opus".into()
            })
        );
        assert_eq!(
            r.roles["reflector"].route_policy,
            Some(RoutePolicy::PreferLocal)
        );
        assert_eq!(
            r.roles["curator"].route_policy,
            Some(RoutePolicy::ForceLocal)
        );
        assert_eq!(r.roles["chat"].route_policy, None);
    }

    #[test]
    fn parses_split_cost_fields() {
        let yaml = r#"
schema_version: 1
models:
  opus:
    provider: anthropic
    model: claude-opus-4-8
    input_cost_per_1k: 0.005
    output_cost_per_1k: 0.025
    context_window: 200000
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        let e = r.models.get("opus").unwrap();
        assert_eq!(e.input_cost_per_1k, Some(0.005));
        assert_eq!(e.output_cost_per_1k, Some(0.025));
        assert_eq!(e.context_window, Some(200_000));
    }

    #[test]
    fn default_model_entry_is_empty() {
        let e = ModelEntry::default();
        assert!(e.provider.is_empty());
        assert_eq!(e.input_cost_per_1k, None);
        assert_eq!(e.output_cost_per_1k, None);
        assert_eq!(e.context_window, None);
    }

    #[test]
    fn effective_costs_fallback_matrix() {
        // legacy only → both fall back to the blended rate
        let mut e = ModelEntry {
            cost_per_1k_tokens: Some(0.01),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.01), Some(0.01)));

        // split only → split wins, legacy ignored
        e = ModelEntry {
            input_cost_per_1k: Some(0.005),
            output_cost_per_1k: Some(0.025),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.005), Some(0.025)));

        // both → split wins
        e = ModelEntry {
            cost_per_1k_tokens: Some(0.01),
            input_cost_per_1k: Some(0.005),
            output_cost_per_1k: Some(0.025),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.005), Some(0.025)));

        // none → none
        e = ModelEntry::default();
        assert_eq!(e.effective_costs(), (None, None));
    }
}

#[cfg(test)]
mod io_tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn load_returns_empty_when_file_missing() {
        let dir = tempdir().unwrap();
        let r = ModelRegistry::load_from(&dir.path().join("nope.yaml")).unwrap();
        assert_eq!(r.models.len(), 0);
        assert_eq!(r.schema_version, 1);
    }

    #[test]
    fn load_returns_empty_when_file_blank() {
        // A whitespace-only file is treated like a missing one.
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        std::fs::write(&p, "  \n").unwrap();
        let r = ModelRegistry::load_from(&p).unwrap();
        assert_eq!(r, ModelRegistry::default());
    }

    #[test]
    fn save_then_load_round_trips() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        let mut r = ModelRegistry::default();
        r.models.insert(
            "x".into(),
            ModelEntry {
                provider: "ollama".into(),
                model: "llama3.2:3b".into(),
                ..Default::default()
            },
        );
        r.save_to(&p).unwrap();
        let r2 = ModelRegistry::load_from(&p).unwrap();
        assert_eq!(r, r2);
    }

    #[test]
    fn save_uses_atomic_rename() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        ModelRegistry::default().save_to(&p).unwrap();
        let temp = dir.path().join("models.yaml.tmp");
        assert!(!temp.exists(), "atomic temp left behind");
    }
}