1use crate::route::{RoutePolicy, RouteTier};
16use crate::secret::SecretRef;
17use serde::{Deserialize, Serialize};
18use std::collections::BTreeMap;
19use std::path::{Path, PathBuf};
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22pub struct ModelEntry {
23 pub provider: String,
24 pub model: String,
25 #[serde(default, skip_serializing_if = "Option::is_none")]
26 pub base_url: Option<String>,
27 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub secret: Option<SecretRef>,
29 #[serde(default, skip_serializing_if = "Vec::is_empty")]
30 pub capabilities: Vec<String>,
31 #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
32 pub params: serde_json::Value,
33 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub tier: Option<RouteTier>,
37 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub cost_per_1k_tokens: Option<f64>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
44pub struct RoleEntry {
45 pub primary: String,
47 #[serde(default, skip_serializing_if = "Option::is_none")]
49 pub fallback: Option<String>,
50 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub cost_budget_per_day_usd: Option<f64>,
53 #[serde(default)]
55 pub privacy_local_only: bool,
56 #[serde(default, skip_serializing_if = "Option::is_none")]
59 pub route_policy: Option<RoutePolicy>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
63pub struct ModelRegistry {
64 pub schema_version: u32,
65 #[serde(default)]
66 pub models: BTreeMap<String, ModelEntry>,
67 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
68 pub roles: BTreeMap<String, RoleEntry>,
69}
70
71impl Default for ModelRegistry {
72 fn default() -> Self {
73 Self {
74 schema_version: 1,
75 models: BTreeMap::new(),
76 roles: BTreeMap::new(),
77 }
78 }
79}
80
81impl ModelRegistry {
82 pub fn load_from(path: &Path) -> anyhow::Result<Self> {
83 if !path.exists() {
84 return Ok(Self::default());
85 }
86 let body = std::fs::read_to_string(path)?;
87 if body.trim().is_empty() {
88 return Ok(Self::default());
89 }
90 Ok(serde_yaml_ng::from_str(&body)?)
91 }
92
93 pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
94 if let Some(parent) = path.parent() {
95 std::fs::create_dir_all(parent)?;
96 }
97 let body = serde_yaml_ng::to_string(self)?;
98 let tmp = path.with_extension("yaml.tmp");
99 std::fs::write(&tmp, body)?;
100 std::fs::rename(&tmp, path)?;
101 Ok(())
102 }
103
104 pub fn default_path() -> anyhow::Result<PathBuf> {
105 if let Ok(p) = std::env::var("MUR_HOME")
108 && !p.is_empty()
109 {
110 return Ok(PathBuf::from(p).join("models.yaml"));
111 }
112 let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("no home dir"))?;
113 Ok(home.join(".mur/models.yaml"))
114 }
115
116 pub fn resolve_role(&self, role: &str) -> Option<&str> {
119 let entry = self.roles.get(role)?;
120 if self.models.contains_key(&entry.primary) {
121 return Some(&entry.primary);
122 }
123 if let Some(fb) = &entry.fallback
125 && self.models.contains_key(fb)
126 {
127 return Some(fb);
128 }
129 None
131 }
132}
133
#[cfg(test)]
mod tests {
    // Unit tests for (de)serialization of the registry types and for
    // `ModelRegistry::resolve_role`. None of these tests touch the filesystem.
    use super::*;

    /// Builds a `ModelEntry` with only `provider` and `model` set; every
    /// optional field is left at its empty/default value.
    fn bare_entry(provider: &str, model: &str) -> ModelEntry {
        ModelEntry {
            provider: provider.into(),
            model: model.into(),
            base_url: None,
            secret: None,
            capabilities: Vec::new(),
            params: serde_json::Value::Null,
            tier: None,
            cost_per_1k_tokens: None,
        }
    }

    /// Builds a registry holding a single `haiku` model entry.
    fn registry_with_haiku() -> ModelRegistry {
        let mut reg = ModelRegistry::default();
        reg.models
            .insert("haiku".into(), bare_entry("anthropic", "claude-haiku-4-5"));
        reg
    }

    #[test]
    fn parses_full_registry() {
        let yaml = r#"
schema_version: 1
models:
  anthropic_opus_4_7:
    provider: anthropic
    model: claude-opus-4-7
    secret: env:ANTHROPIC_API_KEY
    capabilities: [chat, tools]
  ollama_llama3:
    provider: ollama
    model: llama3.2:3b
    base_url: http://127.0.0.1:11434
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.schema_version, 1);
        assert_eq!(r.models.len(), 2);
        let opus = r.models.get("anthropic_opus_4_7").unwrap();
        assert_eq!(opus.provider, "anthropic");
        assert_eq!(
            opus.secret,
            Some(SecretRef::Env("ANTHROPIC_API_KEY".into()))
        );
        assert!(r.models["ollama_llama3"].secret.is_none());
    }

    #[test]
    fn round_trip_preserves_shape() {
        let mut r = ModelRegistry::default();
        r.models.insert(
            "foo".into(),
            ModelEntry {
                secret: Some(SecretRef::Keychain {
                    service: "mur".into(),
                    account: "anthropic".into(),
                }),
                capabilities: vec!["chat".into()],
                ..bare_entry("anthropic", "claude-opus-4-7")
            },
        );
        let s = serde_yaml_ng::to_string(&r).unwrap();
        let parsed: ModelRegistry = serde_yaml_ng::from_str(&s).unwrap();
        assert_eq!(r, parsed);
    }

    #[test]
    fn rejects_unknown_secret_scheme() {
        let yaml = r#"
schema_version: 1
models:
  bad:
    provider: x
    model: y
    secret: bogus:value
"#;
        let r: Result<ModelRegistry, _> = serde_yaml_ng::from_str(yaml);
        assert!(r.is_err(), "should reject unknown scheme");
    }

    #[test]
    fn test_registry_roundtrip_with_roles() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
roles:
  reflector:
    primary: haiku
    fallback: null
    cost_budget_per_day_usd: 0.5
"#;
        let reg: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(reg.roles["reflector"].primary, "haiku");
        // Serialize and re-parse: the result must equal the first parse.
        let back = serde_yaml_ng::to_string(&reg).unwrap();
        let reg2: ModelRegistry = serde_yaml_ng::from_str(&back).unwrap();
        assert_eq!(reg, reg2);
    }

    #[test]
    fn test_resolve_role_primary() {
        let mut reg = registry_with_haiku();
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "haiku".into(),
                fallback: None,
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_fallback() {
        let mut reg = registry_with_haiku();
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                // The primary is deliberately not registered so the fallback is used.
                primary: "nonexistent".into(),
                fallback: Some("haiku".into()),
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    #[test]
    fn test_resolve_role_none() {
        let reg = ModelRegistry::default();
        assert_eq!(reg.resolve_role("reflector"), None);
    }

    #[test]
    fn model_entry_parses_tier_field() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
    tier: local
  opus:
    provider: anthropic
    model: claude-opus-4-7
    tier: frontier
    cost_per_1k_tokens: 0.015
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.models["haiku"].tier, Some(RouteTier::Local));
        assert_eq!(r.models["opus"].tier, Some(RouteTier::Frontier));
        assert_eq!(r.models["opus"].cost_per_1k_tokens, Some(0.015));

        // An entry without a tier must not emit a `tier:` key at all.
        let mut r2 = ModelRegistry::default();
        r2.models.insert("x".into(), bare_entry("ollama", "llama3"));
        let yaml = serde_yaml_ng::to_string(&r2).unwrap();
        assert!(
            !yaml.contains("tier:"),
            "absent tier should not be serialized: {yaml}"
        );
    }

    #[test]
    fn role_entry_parses_route_policy() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
  opus:
    provider: anthropic
    model: claude-opus-4-7
roles:
  dev:
    primary: opus
    route_policy: !force_frontier
      model_id: opus
  reflector:
    primary: haiku
    route_policy: prefer_local
  curator:
    primary: haiku
    route_policy: force_local
  chat:
    primary: haiku
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(
            r.roles["dev"].route_policy,
            Some(RoutePolicy::ForceFrontier {
                model_id: "opus".into()
            })
        );
        assert_eq!(
            r.roles["reflector"].route_policy,
            Some(RoutePolicy::PreferLocal)
        );
        assert_eq!(
            r.roles["curator"].route_policy,
            Some(RoutePolicy::ForceLocal)
        );
        assert_eq!(r.roles["chat"].route_policy, None);
    }
}
368
#[cfg(test)]
mod io_tests {
    // Filesystem round-trip tests for `ModelRegistry::load_from` / `save_to`,
    // each run inside its own temporary directory.
    use super::*;
    use tempfile::tempdir;

    /// Registry containing one minimal `ollama` entry under the key `x`.
    fn sample_registry() -> ModelRegistry {
        let entry = ModelEntry {
            provider: "ollama".into(),
            model: "llama3.2:3b".into(),
            base_url: None,
            secret: None,
            capabilities: vec![],
            params: serde_json::Value::Null,
            tier: None,
            cost_per_1k_tokens: None,
        };
        let mut registry = ModelRegistry::default();
        registry.models.insert("x".into(), entry);
        registry
    }

    #[test]
    fn load_returns_empty_when_file_missing() {
        let scratch = tempdir().unwrap();
        let missing = scratch.path().join("nope.yaml");
        let loaded = ModelRegistry::load_from(&missing).unwrap();
        assert_eq!(loaded.models.len(), 0);
        assert_eq!(loaded.schema_version, 1);
    }

    #[test]
    fn save_then_load_round_trips() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("models.yaml");
        let written = sample_registry();
        written.save_to(&target).unwrap();
        let reloaded = ModelRegistry::load_from(&target).unwrap();
        assert_eq!(written, reloaded);
    }

    #[test]
    fn save_uses_atomic_rename() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("models.yaml");
        ModelRegistry::default().save_to(&target).unwrap();
        // After a successful save the staging file must be gone.
        let leftover = scratch.path().join("models.yaml.tmp");
        assert!(!leftover.exists(), "atomic temp left behind");
    }
}
413}