1use crate::route::{RoutePolicy, RouteTier};
16use crate::secret::SecretRef;
17use serde::{Deserialize, Serialize};
18use std::collections::BTreeMap;
19use std::path::{Path, PathBuf};
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
22pub struct ModelEntry {
23 #[serde(default)]
24 pub provider: String,
25 #[serde(default)]
26 pub model: String,
27 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub base_url: Option<String>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub secret: Option<SecretRef>,
31 #[serde(default, skip_serializing_if = "Vec::is_empty")]
32 pub capabilities: Vec<String>,
33 #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
34 pub params: serde_json::Value,
35 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub tier: Option<RouteTier>,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub cost_per_1k_tokens: Option<f64>,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub input_cost_per_1k: Option<f64>,
47 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub output_cost_per_1k: Option<f64>,
51 #[serde(default, skip_serializing_if = "Option::is_none")]
53 pub context_window: Option<u64>,
54}
55
56impl ModelEntry {
57 pub fn effective_costs(&self) -> (Option<f64>, Option<f64>) {
62 let output = self.output_cost_per_1k.or(self.cost_per_1k_tokens);
63 let input = self.input_cost_per_1k.or(self.cost_per_1k_tokens);
64 (input, output)
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
69pub struct RoleEntry {
70 pub primary: String,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub fallback: Option<String>,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
77 pub cost_budget_per_day_usd: Option<f64>,
78 #[serde(default)]
80 pub privacy_local_only: bool,
81 #[serde(default, skip_serializing_if = "Option::is_none")]
84 pub route_policy: Option<RoutePolicy>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
88pub struct ModelRegistry {
89 pub schema_version: u32,
90 #[serde(default)]
91 pub models: BTreeMap<String, ModelEntry>,
92 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
93 pub roles: BTreeMap<String, RoleEntry>,
94}
95
96impl Default for ModelRegistry {
97 fn default() -> Self {
98 Self {
99 schema_version: 1,
100 models: BTreeMap::new(),
101 roles: BTreeMap::new(),
102 }
103 }
104}
105
106impl ModelRegistry {
107 pub fn load_from(path: &Path) -> anyhow::Result<Self> {
108 if !path.exists() {
109 return Ok(Self::default());
110 }
111 let body = std::fs::read_to_string(path)?;
112 if body.trim().is_empty() {
113 return Ok(Self::default());
114 }
115 Ok(serde_yaml_ng::from_str(&body)?)
116 }
117
118 pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
119 if let Some(parent) = path.parent() {
120 std::fs::create_dir_all(parent)?;
121 }
122 let body = serde_yaml_ng::to_string(self)?;
123 let tmp = path.with_extension("yaml.tmp");
124 std::fs::write(&tmp, body)?;
125 std::fs::rename(&tmp, path)?;
126 Ok(())
127 }
128
129 pub fn default_path() -> anyhow::Result<PathBuf> {
130 if let Ok(p) = std::env::var("MUR_HOME")
133 && !p.is_empty()
134 {
135 return Ok(PathBuf::from(p).join("models.yaml"));
136 }
137 let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("no home dir"))?;
138 Ok(home.join(".mur/models.yaml"))
139 }
140
141 pub fn resolve_role(&self, role: &str) -> Option<&str> {
144 let entry = self.roles.get(role)?;
145 if self.models.contains_key(&entry.primary) {
146 return Some(&entry.primary);
147 }
148 if let Some(fb) = &entry.fallback
150 && self.models.contains_key(fb)
151 {
152 return Some(fb);
153 }
154 None
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModelEntry` with only `provider` and `model` set; every
    /// other field takes its default.
    fn entry(provider: &str, model: &str) -> ModelEntry {
        ModelEntry {
            provider: provider.into(),
            model: model.into(),
            ..Default::default()
        }
    }

    #[test]
    fn parses_full_registry() {
        let yaml = r#"
schema_version: 1
models:
  anthropic_opus_4_7:
    provider: anthropic
    model: claude-opus-4-7
    secret: env:ANTHROPIC_API_KEY
    capabilities: [chat, tools]
  ollama_llama3:
    provider: ollama
    model: llama3.2:3b
    base_url: http://127.0.0.1:11434
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.schema_version, 1);
        assert_eq!(r.models.len(), 2);
        let opus = r.models.get("anthropic_opus_4_7").unwrap();
        assert_eq!(opus.provider, "anthropic");
        assert_eq!(
            opus.secret,
            Some(SecretRef::Env("ANTHROPIC_API_KEY".into()))
        );
        assert!(r.models["ollama_llama3"].secret.is_none());
    }

    #[test]
    fn round_trip_preserves_shape() {
        let mut r = ModelRegistry::default();
        r.models.insert(
            "foo".into(),
            ModelEntry {
                secret: Some(SecretRef::Keychain {
                    service: "mur".into(),
                    account: "anthropic".into(),
                }),
                capabilities: vec!["chat".into()],
                ..entry("anthropic", "claude-opus-4-7")
            },
        );
        let s = serde_yaml_ng::to_string(&r).unwrap();
        let parsed: ModelRegistry = serde_yaml_ng::from_str(&s).unwrap();
        assert_eq!(r, parsed);
    }

    #[test]
    fn rejects_unknown_secret_scheme() {
        let yaml = r#"
schema_version: 1
models:
  bad:
    provider: x
    model: y
    secret: bogus:value
"#;
        let r: Result<ModelRegistry, _> = serde_yaml_ng::from_str(yaml);
        assert!(r.is_err(), "should reject unknown scheme");
    }

    #[test]
    fn test_registry_roundtrip_with_roles() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
roles:
  reflector:
    primary: haiku
    fallback: null
    cost_budget_per_day_usd: 0.5
"#;
        let reg: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(reg.roles["reflector"].primary, "haiku");
        let back = serde_yaml_ng::to_string(&reg).unwrap();
        let reg2: ModelRegistry = serde_yaml_ng::from_str(&back).unwrap();
        assert_eq!(reg, reg2);
    }

    #[test]
    fn test_resolve_role_primary() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), entry("anthropic", "claude-haiku-4-5"));
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "haiku".into(),
                fallback: None,
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_fallback() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), entry("anthropic", "claude-haiku-4-5"));
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "nonexistent".into(),
                fallback: Some("haiku".into()),
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_none() {
        let reg = ModelRegistry::default();
        assert_eq!(reg.resolve_role("reflector"), None);
    }

    #[test]
    fn test_resolve_role_neither_model_registered() {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), entry("anthropic", "claude-haiku-4-5"));
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "missing".into(),
                fallback: Some("also_missing".into()),
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), None);
    }

    #[test]
    fn model_entry_parses_tier_field() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
    tier: local
  opus:
    provider: anthropic
    model: claude-opus-4-7
    tier: frontier
    cost_per_1k_tokens: 0.015
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.models["haiku"].tier, Some(RouteTier::Local));
        assert_eq!(r.models["opus"].tier, Some(RouteTier::Frontier));
        assert_eq!(r.models["opus"].cost_per_1k_tokens, Some(0.015));

        // An absent tier must not be written back out.
        let mut r2 = ModelRegistry::default();
        r2.models.insert("x".into(), entry("ollama", "llama3"));
        let yaml = serde_yaml_ng::to_string(&r2).unwrap();
        assert!(
            !yaml.contains("tier:"),
            "absent tier should not be serialized: {yaml}"
        );
    }

    #[test]
    fn role_entry_parses_route_policy() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
  opus:
    provider: anthropic
    model: claude-opus-4-7
roles:
  dev:
    primary: opus
    route_policy: !force_frontier
      model_id: opus
  reflector:
    primary: haiku
    route_policy: prefer_local
  curator:
    primary: haiku
    route_policy: force_local
  chat:
    primary: haiku
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(
            r.roles["dev"].route_policy,
            Some(RoutePolicy::ForceFrontier {
                model_id: "opus".into()
            })
        );
        assert_eq!(
            r.roles["reflector"].route_policy,
            Some(RoutePolicy::PreferLocal)
        );
        assert_eq!(
            r.roles["curator"].route_policy,
            Some(RoutePolicy::ForceLocal)
        );
        assert_eq!(r.roles["chat"].route_policy, None);
    }

    #[test]
    fn parses_split_cost_fields() {
        let yaml = r#"
schema_version: 1
models:
  opus:
    provider: anthropic
    model: claude-opus-4-8
    input_cost_per_1k: 0.005
    output_cost_per_1k: 0.025
    context_window: 200000
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        let e = r.models.get("opus").unwrap();
        assert_eq!(e.input_cost_per_1k, Some(0.005));
        assert_eq!(e.output_cost_per_1k, Some(0.025));
        assert_eq!(e.context_window, Some(200_000));
    }

    #[test]
    fn default_model_entry_is_empty() {
        let e = ModelEntry::default();
        assert!(e.provider.is_empty());
        assert_eq!(e.input_cost_per_1k, None);
        assert_eq!(e.output_cost_per_1k, None);
        assert_eq!(e.context_window, None);
    }

    #[test]
    fn effective_costs_fallback_matrix() {
        // Blended price only: used for both sides.
        let e = ModelEntry {
            cost_per_1k_tokens: Some(0.01),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.01), Some(0.01)));

        // Split prices only.
        let e = ModelEntry {
            input_cost_per_1k: Some(0.005),
            output_cost_per_1k: Some(0.025),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.005), Some(0.025)));

        // Split prices win over the blended price.
        let e = ModelEntry {
            cost_per_1k_tokens: Some(0.01),
            input_cost_per_1k: Some(0.005),
            output_cost_per_1k: Some(0.025),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.005), Some(0.025)));

        // Partial split: the missing side falls back to the blended price.
        let e = ModelEntry {
            cost_per_1k_tokens: Some(0.01),
            input_cost_per_1k: Some(0.005),
            ..Default::default()
        };
        assert_eq!(e.effective_costs(), (Some(0.005), Some(0.01)));

        // Nothing set.
        assert_eq!(ModelEntry::default().effective_costs(), (None, None));
    }
}
464
#[cfg(test)]
mod io_tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn load_returns_empty_when_file_missing() {
        let dir = tempdir().unwrap();
        let r = ModelRegistry::load_from(&dir.path().join("nope.yaml")).unwrap();
        assert_eq!(r.models.len(), 0);
        assert_eq!(r.schema_version, 1);
    }

    #[test]
    fn load_returns_default_for_blank_file() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        std::fs::write(&p, "  \n\t\n").unwrap();
        assert_eq!(ModelRegistry::load_from(&p).unwrap(), ModelRegistry::default());
    }

    #[test]
    fn load_rejects_malformed_document() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        std::fs::write(&p, "schema_version: not-a-number\n").unwrap();
        assert!(ModelRegistry::load_from(&p).is_err());
    }

    #[test]
    fn load_errors_when_path_is_a_directory() {
        // An existing-but-unreadable path must surface an error, not an empty registry.
        let dir = tempdir().unwrap();
        assert!(ModelRegistry::load_from(dir.path()).is_err());
    }

    #[test]
    fn save_then_load_round_trips() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        let mut r = ModelRegistry::default();
        r.models.insert(
            "x".into(),
            ModelEntry {
                provider: "ollama".into(),
                model: "llama3.2:3b".into(),
                ..Default::default()
            },
        );
        r.save_to(&p).unwrap();
        let r2 = ModelRegistry::load_from(&p).unwrap();
        assert_eq!(r, r2);
    }

    #[test]
    fn save_creates_missing_parent_dirs() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("nested").join("deeper").join("models.yaml");
        ModelRegistry::default().save_to(&p).unwrap();
        assert_eq!(ModelRegistry::load_from(&p).unwrap(), ModelRegistry::default());
    }

    #[test]
    fn save_overwrites_existing_file() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        ModelRegistry::default().save_to(&p).unwrap();

        let mut r = ModelRegistry::default();
        r.models.insert(
            "y".into(),
            ModelEntry {
                provider: "anthropic".into(),
                model: "claude-haiku-4-5".into(),
                ..Default::default()
            },
        );
        r.save_to(&p).unwrap();
        assert_eq!(ModelRegistry::load_from(&p).unwrap(), r);
    }

    #[test]
    fn save_uses_atomic_rename() {
        let dir = tempdir().unwrap();
        let p = dir.path().join("models.yaml");
        ModelRegistry::default().save_to(&p).unwrap();
        let temp = dir.path().join("models.yaml.tmp");
        assert!(!temp.exists(), "atomic temp left behind");
    }
}