1use crate::secret::SecretRef;
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeMap;
18use std::path::{Path, PathBuf};
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub struct ModelEntry {
22 pub provider: String,
23 pub model: String,
24 #[serde(default, skip_serializing_if = "Option::is_none")]
25 pub base_url: Option<String>,
26 #[serde(default, skip_serializing_if = "Option::is_none")]
27 pub secret: Option<SecretRef>,
28 #[serde(default, skip_serializing_if = "Vec::is_empty")]
29 pub capabilities: Vec<String>,
30 #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
31 pub params: serde_json::Value,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
35pub struct RoleEntry {
36 pub primary: String,
38 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub fallback: Option<String>,
41 #[serde(default, skip_serializing_if = "Option::is_none")]
43 pub cost_budget_per_day_usd: Option<f64>,
44 #[serde(default)]
46 pub privacy_local_only: bool,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
50pub struct ModelRegistry {
51 pub schema_version: u32,
52 #[serde(default)]
53 pub models: BTreeMap<String, ModelEntry>,
54 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
55 pub roles: BTreeMap<String, RoleEntry>,
56}
57
58impl Default for ModelRegistry {
59 fn default() -> Self {
60 Self {
61 schema_version: 1,
62 models: BTreeMap::new(),
63 roles: BTreeMap::new(),
64 }
65 }
66}
67
68impl ModelRegistry {
69 pub fn load_from(path: &Path) -> anyhow::Result<Self> {
70 if !path.exists() {
71 return Ok(Self::default());
72 }
73 let body = std::fs::read_to_string(path)?;
74 if body.trim().is_empty() {
75 return Ok(Self::default());
76 }
77 Ok(serde_yaml_ng::from_str(&body)?)
78 }
79
80 pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
81 if let Some(parent) = path.parent() {
82 std::fs::create_dir_all(parent)?;
83 }
84 let body = serde_yaml_ng::to_string(self)?;
85 let tmp = path.with_extension("yaml.tmp");
86 std::fs::write(&tmp, body)?;
87 std::fs::rename(&tmp, path)?;
88 Ok(())
89 }
90
91 pub fn default_path() -> anyhow::Result<PathBuf> {
92 if let Ok(p) = std::env::var("MUR_HOME")
95 && !p.is_empty()
96 {
97 return Ok(PathBuf::from(p).join("models.yaml"));
98 }
99 let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("no home dir"))?;
100 Ok(home.join(".mur/models.yaml"))
101 }
102
103 pub fn resolve_role(&self, role: &str) -> Option<&str> {
106 let entry = self.roles.get(role)?;
107 if self.models.contains_key(&entry.primary) {
108 return Some(&entry.primary);
109 }
110 if let Some(fb) = &entry.fallback
112 && self.models.contains_key(fb)
113 {
114 return Some(fb);
115 }
116 None
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal model entry shared by the role-resolution tests.
    fn haiku_entry() -> ModelEntry {
        ModelEntry {
            provider: "anthropic".into(),
            model: "claude-haiku-4-5".into(),
            base_url: None,
            secret: None,
            capabilities: vec![],
            params: serde_json::Value::Null,
        }
    }

    /// A document with two models parses, including an `env:` secret reference
    /// and an entry without any secret.
    #[test]
    fn parses_full_registry() {
        let yaml = r#"
schema_version: 1
models:
  anthropic_opus_4_7:
    provider: anthropic
    model: claude-opus-4-7
    secret: env:ANTHROPIC_API_KEY
    capabilities: [chat, tools]
  ollama_llama3:
    provider: ollama
    model: llama3.2:3b
    base_url: http://127.0.0.1:11434
"#;
        let r: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(r.schema_version, 1);
        assert_eq!(r.models.len(), 2);
        let opus = r.models.get("anthropic_opus_4_7").unwrap();
        assert_eq!(opus.provider, "anthropic");
        assert_eq!(
            opus.secret,
            Some(SecretRef::Env("ANTHROPIC_API_KEY".into()))
        );
        assert!(r.models["ollama_llama3"].secret.is_none());
    }

    /// Serializing and re-parsing a registry yields an equal value.
    #[test]
    fn round_trip_preserves_shape() {
        let mut r = ModelRegistry::default();
        r.models.insert(
            "foo".into(),
            ModelEntry {
                provider: "anthropic".into(),
                model: "claude-opus-4-7".into(),
                base_url: None,
                secret: Some(SecretRef::Keychain {
                    service: "mur".into(),
                    account: "anthropic".into(),
                }),
                capabilities: vec!["chat".into()],
                params: serde_json::Value::Null,
            },
        );
        let s = serde_yaml_ng::to_string(&r).unwrap();
        let parsed: ModelRegistry = serde_yaml_ng::from_str(&s).unwrap();
        assert_eq!(r, parsed);
    }

    /// A secret with an unrecognised scheme prefix fails deserialization.
    #[test]
    fn rejects_unknown_secret_scheme() {
        let yaml = r#"
schema_version: 1
models:
  bad:
    provider: x
    model: y
    secret: bogus:value
"#;
        let r: Result<ModelRegistry, _> = serde_yaml_ng::from_str(yaml);
        assert!(r.is_err(), "should reject unknown scheme");
    }

    /// A document with a `roles` table survives a serialize/parse round trip.
    #[test]
    fn test_registry_roundtrip_with_roles() {
        let yaml = r#"
schema_version: 1
models:
  haiku:
    provider: anthropic
    model: claude-haiku-4-5
roles:
  reflector:
    primary: haiku
    fallback: null
    cost_budget_per_day_usd: 0.5
"#;
        let reg: ModelRegistry = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(reg.roles["reflector"].primary, "haiku");
        let back = serde_yaml_ng::to_string(&reg).unwrap();
        let reg2: ModelRegistry = serde_yaml_ng::from_str(&back).unwrap();
        assert_eq!(reg, reg2);
    }

    /// The primary model wins when it is registered.
    #[test]
    fn test_resolve_role_primary() {
        let mut reg = ModelRegistry::default();
        reg.models.insert("haiku".into(), haiku_entry());
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "haiku".into(),
                fallback: None,
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    /// The fallback is used when the primary key is not registered.
    #[test]
    fn test_resolve_role_fallback() {
        let mut reg = ModelRegistry::default();
        reg.models.insert("haiku".into(), haiku_entry());
        reg.roles.insert(
            "reflector".into(),
            RoleEntry {
                primary: "nonexistent".into(),
                fallback: Some("haiku".into()),
                ..Default::default()
            },
        );
        assert_eq!(reg.resolve_role("reflector"), Some("haiku"));
    }

    /// An unknown role resolves to nothing.
    #[test]
    fn test_resolve_role_none() {
        let reg = ModelRegistry::default();
        assert_eq!(reg.resolve_role("reflector"), None);
    }
}
265
#[cfg(test)]
mod io_tests {
    use super::*;
    use tempfile::tempdir;

    /// Loading from a path that does not exist yields the default registry.
    #[test]
    fn load_returns_empty_when_file_missing() {
        let scratch = tempdir().unwrap();
        let absent = scratch.path().join("nope.yaml");
        let loaded = ModelRegistry::load_from(&absent).unwrap();
        assert!(loaded.models.is_empty());
        assert_eq!(loaded.schema_version, 1);
    }

    /// A registry written with `save_to` is read back unchanged by `load_from`.
    #[test]
    fn save_then_load_round_trips() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("models.yaml");
        let entry = ModelEntry {
            provider: String::from("ollama"),
            model: String::from("llama3.2:3b"),
            base_url: None,
            secret: None,
            capabilities: Vec::new(),
            params: serde_json::Value::Null,
        };
        let mut saved = ModelRegistry::default();
        saved.models.insert(String::from("x"), entry);
        saved.save_to(&target).unwrap();
        let reloaded = ModelRegistry::load_from(&target).unwrap();
        assert_eq!(saved, reloaded);
    }

    /// After a successful save the temporary sibling file is gone.
    #[test]
    fn save_uses_atomic_rename() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("models.yaml");
        ModelRegistry::default().save_to(&target).unwrap();
        let leftover = scratch.path().join("models.yaml.tmp");
        assert!(!leftover.exists(), "atomic temp left behind");
    }
}