1use std::collections::HashMap;
16use std::sync::{Arc, Mutex, PoisonError};
17
18use super::jobs::now_ms;
19use super::persist::{key, ns, MlPersistence};
20use crate::json::{Map, Value as JsonValue};
21
/// Errors returned by [`ModelRegistry`] operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ModelRegistryError {
    /// No model with this name is present in the registry.
    UnknownModel(String),
    /// The model exists but has no version with this number.
    UnknownVersion { model: String, version: u32 },
    /// The version exists but is archived, so it cannot be made active.
    VersionArchived { model: String, version: u32 },
    /// The registry's internal mutex was poisoned by a thread that panicked while holding it.
    LockPoisoned,
    /// The persistence backend reported an error; carries its message text.
    Backend(String),
}
37
38impl std::fmt::Display for ModelRegistryError {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match self {
41 ModelRegistryError::UnknownModel(name) => {
42 write!(f, "unknown model '{name}'")
43 }
44 ModelRegistryError::UnknownVersion { model, version } => {
45 write!(f, "unknown version {version} for model '{model}'")
46 }
47 ModelRegistryError::VersionArchived { model, version } => {
48 write!(
49 f,
50 "version {version} of model '{model}' is archived; unarchive before use"
51 )
52 }
53 ModelRegistryError::LockPoisoned => write!(f, "ml registry lock poisoned"),
54 ModelRegistryError::Backend(msg) => write!(f, "ml registry backend: {msg}"),
55 }
56 }
57}
58
59impl std::error::Error for ModelRegistryError {}
60
61impl<T> From<PoisonError<T>> for ModelRegistryError {
62 fn from(_: PoisonError<T>) -> Self {
63 ModelRegistryError::LockPoisoned
64 }
65}
66
/// One stored version of a model: its weights plus the metadata recorded with it.
#[derive(Clone, Debug)]
pub struct ModelVersion {
    /// Owning model's name; `ModelRegistry::register_version` overwrites it.
    pub model: String,
    /// Version number within the model; assigned on registration, starting at 1.
    pub version: u32,
    /// Serialized weights; the registry treats these bytes as opaque.
    pub weights_blob: Vec<u8>,
    /// Hyperparameters as JSON text; stored verbatim, never parsed by the registry.
    pub hyperparams_json: String,
    /// Metrics as JSON text; stored verbatim, never parsed by the registry.
    pub metrics_json: String,
    /// Optional hash recorded for the training data (format not checked here).
    pub training_data_hash: Option<String>,
    /// Optional training-data query text, if one was recorded.
    pub training_sql: Option<String>,
    /// Optional earlier version number; presumably the version this one derives from.
    pub parent_version: Option<u32>,
    /// Creation time in milliseconds; `0` asks registration to stamp the current time.
    pub created_at_ms: u64,
    /// Optional identifier of whoever created this version.
    pub created_by: Option<String>,
    /// Archived versions are retained but cannot be made active.
    pub archived: bool,
}
92
/// Per-model overview returned by `ModelRegistry::summaries`.
#[derive(Clone, Debug)]
pub struct ModelSummary {
    /// Model name.
    pub name: String,
    /// Currently active version, if one is set.
    pub active_version: Option<u32>,
    /// Number of stored versions, archived ones included.
    pub total_versions: usize,
    /// How many of those versions are archived.
    pub archived_versions: usize,
}
101
102#[derive(Debug)]
103struct ModelState {
104 versions: Vec<ModelVersion>,
105 active_version: Option<u32>,
106}
107
108#[derive(Clone)]
115pub struct ModelRegistry {
116 inner: Arc<Mutex<HashMap<String, ModelState>>>,
117 backend: Option<Arc<dyn MlPersistence>>,
118}
119
120impl Default for ModelRegistry {
121 fn default() -> Self {
122 Self {
123 inner: Arc::new(Mutex::new(HashMap::new())),
124 backend: None,
125 }
126 }
127}
128
129impl std::fmt::Debug for ModelRegistry {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("ModelRegistry")
132 .field("has_backend", &self.backend.is_some())
133 .finish()
134 }
135}
136
137impl ModelRegistry {
138 pub fn new() -> Self {
139 Self::default()
140 }
141
142 pub fn with_backend(backend: Arc<dyn MlPersistence>) -> Self {
146 let registry = Self {
147 inner: Arc::new(Mutex::new(HashMap::new())),
148 backend: Some(backend),
149 };
150 let _ = registry.load_from_backend();
154 registry
155 }
156
157 pub fn load_from_backend(&self) -> Result<(), ModelRegistryError> {
161 let Some(backend) = self.backend.as_ref() else {
162 return Ok(());
163 };
164 let model_rows = backend
165 .list(ns::MODELS)
166 .map_err(|e| ModelRegistryError::Backend(e.to_string()))?;
167 let version_rows = backend
168 .list(ns::MODEL_VERSIONS)
169 .map_err(|e| ModelRegistryError::Backend(e.to_string()))?;
170 let mut guard = self.inner.lock()?;
171 guard.clear();
172 for (key, raw) in model_rows {
173 let active = decode_model_active(&raw);
174 guard.insert(
175 key,
176 ModelState {
177 versions: Vec::new(),
178 active_version: active,
179 },
180 );
181 }
182 for (k, raw) in version_rows {
183 let Some((model, _)) = key::parse_model_version(&k) else {
184 continue;
185 };
186 let Some(version) = ModelVersion::from_json(&raw) else {
187 continue;
188 };
189 let state = guard.entry(model).or_insert_with(|| ModelState {
190 versions: Vec::new(),
191 active_version: None,
192 });
193 state.versions.push(version);
194 }
195 Ok(())
196 }
197
198 fn persist_model(&self, name: &str, active: Option<u32>) {
199 if let Some(backend) = self.backend.as_ref() {
200 let raw = encode_model_active(active);
201 let _ = backend.put(ns::MODELS, &key::model(name), &raw);
202 }
203 }
204
205 fn persist_version(&self, version: &ModelVersion) {
206 if let Some(backend) = self.backend.as_ref() {
207 let raw = version.to_json();
208 let _ = backend.put(
209 ns::MODEL_VERSIONS,
210 &key::model_version(&version.model, version.version),
211 &raw,
212 );
213 }
214 }
215
216 pub fn register_version(
224 &self,
225 model: impl Into<String>,
226 mut version: ModelVersion,
227 make_active: bool,
228 ) -> Result<u32, ModelRegistryError> {
229 let name = model.into();
230 let mut guard = self.inner.lock()?;
231 let state = guard.entry(name.clone()).or_insert_with(|| ModelState {
232 versions: Vec::new(),
233 active_version: None,
234 });
235 let next_version = state
236 .versions
237 .iter()
238 .map(|v| v.version)
239 .max()
240 .unwrap_or(0)
241 .saturating_add(1);
242 version.model = name.clone();
243 version.version = next_version;
244 version.archived = false;
245 if version.created_at_ms == 0 {
246 version.created_at_ms = now_ms();
247 }
248 state.versions.push(version.clone());
249 if make_active {
250 state.active_version = Some(next_version);
251 }
252 let active_snapshot = state.active_version;
253 drop(guard);
254 self.persist_version(&version);
255 self.persist_model(&name, active_snapshot);
256 Ok(next_version)
257 }
258
259 pub fn set_active_version(&self, model: &str, version: u32) -> Result<(), ModelRegistryError> {
262 let mut guard = self.inner.lock()?;
263 let state = guard
264 .get_mut(model)
265 .ok_or_else(|| ModelRegistryError::UnknownModel(model.to_string()))?;
266 let found = state.versions.iter().find(|v| v.version == version).ok_or(
267 ModelRegistryError::UnknownVersion {
268 model: model.to_string(),
269 version,
270 },
271 )?;
272 if found.archived {
273 return Err(ModelRegistryError::VersionArchived {
274 model: model.to_string(),
275 version,
276 });
277 }
278 state.active_version = Some(version);
279 drop(guard);
280 self.persist_model(model, Some(version));
281 Ok(())
282 }
283
284 pub fn archive_version(&self, model: &str, version: u32) -> Result<(), ModelRegistryError> {
289 let mut guard = self.inner.lock()?;
290 let state = guard
291 .get_mut(model)
292 .ok_or_else(|| ModelRegistryError::UnknownModel(model.to_string()))?;
293 let entry = state
294 .versions
295 .iter_mut()
296 .find(|v| v.version == version)
297 .ok_or(ModelRegistryError::UnknownVersion {
298 model: model.to_string(),
299 version,
300 })?;
301 entry.archived = true;
302 let entry_snapshot = entry.clone();
303 if state.active_version == Some(version) {
304 state.active_version = None;
305 }
306 let active_snapshot = state.active_version;
307 drop(guard);
308 self.persist_version(&entry_snapshot);
309 self.persist_model(model, active_snapshot);
310 Ok(())
311 }
312
313 pub fn get_version(
316 &self,
317 model: &str,
318 version: u32,
319 ) -> Result<ModelVersion, ModelRegistryError> {
320 let guard = self.inner.lock()?;
321 let state = guard
322 .get(model)
323 .ok_or_else(|| ModelRegistryError::UnknownModel(model.to_string()))?;
324 state
325 .versions
326 .iter()
327 .find(|v| v.version == version)
328 .cloned()
329 .ok_or(ModelRegistryError::UnknownVersion {
330 model: model.to_string(),
331 version,
332 })
333 }
334
335 pub fn get_active(&self, model: &str) -> Result<Option<ModelVersion>, ModelRegistryError> {
337 let guard = self.inner.lock()?;
338 let Some(state) = guard.get(model) else {
339 return Err(ModelRegistryError::UnknownModel(model.to_string()));
340 };
341 let Some(active) = state.active_version else {
342 return Ok(None);
343 };
344 Ok(state.versions.iter().find(|v| v.version == active).cloned())
345 }
346
347 pub fn list_versions(&self, model: &str) -> Result<Vec<ModelVersion>, ModelRegistryError> {
349 let guard = self.inner.lock()?;
350 let state = guard
351 .get(model)
352 .ok_or_else(|| ModelRegistryError::UnknownModel(model.to_string()))?;
353 let mut out = state.versions.clone();
354 out.sort_by_key(|v| v.version);
355 Ok(out)
356 }
357
358 pub fn summaries(&self) -> Result<Vec<ModelSummary>, ModelRegistryError> {
360 let guard = self.inner.lock()?;
361 let mut out: Vec<ModelSummary> = guard
362 .iter()
363 .map(|(name, state)| ModelSummary {
364 name: name.clone(),
365 active_version: state.active_version,
366 total_versions: state.versions.len(),
367 archived_versions: state.versions.iter().filter(|v| v.archived).count(),
368 })
369 .collect();
370 out.sort_by(|a, b| a.name.cmp(&b.name));
371 Ok(out)
372 }
373}
374
375impl ModelVersion {
383 pub fn to_json(&self) -> String {
384 let mut obj = Map::new();
385 obj.insert("model".to_string(), JsonValue::String(self.model.clone()));
386 obj.insert(
387 "version".to_string(),
388 JsonValue::Number(self.version as f64),
389 );
390 obj.insert(
391 "weights_hex".to_string(),
392 JsonValue::String(hex_encode(&self.weights_blob)),
393 );
394 obj.insert(
395 "hyperparams".to_string(),
396 JsonValue::String(self.hyperparams_json.clone()),
397 );
398 obj.insert(
399 "metrics".to_string(),
400 JsonValue::String(self.metrics_json.clone()),
401 );
402 obj.insert(
403 "training_data_hash".to_string(),
404 self.training_data_hash
405 .as_ref()
406 .map(|s| JsonValue::String(s.clone()))
407 .unwrap_or(JsonValue::Null),
408 );
409 obj.insert(
410 "training_sql".to_string(),
411 self.training_sql
412 .as_ref()
413 .map(|s| JsonValue::String(s.clone()))
414 .unwrap_or(JsonValue::Null),
415 );
416 obj.insert(
417 "parent_version".to_string(),
418 self.parent_version
419 .map(|v| JsonValue::Number(v as f64))
420 .unwrap_or(JsonValue::Null),
421 );
422 obj.insert(
423 "created_at".to_string(),
424 JsonValue::Number(self.created_at_ms as f64),
425 );
426 obj.insert(
427 "created_by".to_string(),
428 self.created_by
429 .as_ref()
430 .map(|s| JsonValue::String(s.clone()))
431 .unwrap_or(JsonValue::Null),
432 );
433 obj.insert("archived".to_string(), JsonValue::Bool(self.archived));
434 JsonValue::Object(obj).to_string_compact()
435 }
436
437 pub fn from_json(raw: &str) -> Option<Self> {
438 let parsed = crate::json::parse_json(raw).ok()?;
439 let value = JsonValue::from(parsed);
440 let obj = value.as_object()?;
441 let model = obj.get("model")?.as_str()?.to_string();
442 let version = obj.get("version")?.as_i64()? as u32;
443 let weights_blob = hex_decode(obj.get("weights_hex")?.as_str()?)?;
444 let hyperparams_json = obj.get("hyperparams")?.as_str()?.to_string();
445 let metrics_json = obj.get("metrics")?.as_str()?.to_string();
446 let training_data_hash = match obj.get("training_data_hash") {
447 Some(JsonValue::String(s)) => Some(s.clone()),
448 _ => None,
449 };
450 let training_sql = match obj.get("training_sql") {
451 Some(JsonValue::String(s)) => Some(s.clone()),
452 _ => None,
453 };
454 let parent_version = match obj.get("parent_version") {
455 Some(JsonValue::Number(n)) => Some(*n as u32),
456 _ => None,
457 };
458 let created_at_ms = obj.get("created_at")?.as_i64()? as u64;
459 let created_by = match obj.get("created_by") {
460 Some(JsonValue::String(s)) => Some(s.clone()),
461 _ => None,
462 };
463 let archived = match obj.get("archived") {
464 Some(JsonValue::Bool(b)) => *b,
465 _ => false,
466 };
467 Some(ModelVersion {
468 model,
469 version,
470 weights_blob,
471 hyperparams_json,
472 metrics_json,
473 training_data_hash,
474 training_sql,
475 parent_version,
476 created_at_ms,
477 created_by,
478 archived,
479 })
480 }
481}
482
483fn encode_model_active(active: Option<u32>) -> String {
484 let mut obj = Map::new();
485 obj.insert(
486 "active".to_string(),
487 active
488 .map(|v| JsonValue::Number(v as f64))
489 .unwrap_or(JsonValue::Null),
490 );
491 JsonValue::Object(obj).to_string_compact()
492}
493
494fn decode_model_active(raw: &str) -> Option<u32> {
495 let parsed = crate::json::parse_json(raw).ok()?;
496 let value = JsonValue::from(parsed);
497 match value.as_object()?.get("active") {
498 Some(JsonValue::Number(n)) => Some(*n as u32),
499 _ => None,
500 }
501}
502
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
///
/// Digits come from a lookup table rather than `format!` per byte. Weight
/// blobs can be large, and `format!` would allocate a temporary `String` for
/// every byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
510
511fn hex_decode(s: &str) -> Option<Vec<u8>> {
512 if !s.len().is_multiple_of(2) {
513 return None;
514 }
515 let mut out = Vec::with_capacity(s.len() / 2);
516 for chunk in s.as_bytes().chunks(2) {
517 let hi = hex_nibble(chunk[0])?;
518 let lo = hex_nibble(chunk[1])?;
519 out.push((hi << 4) | lo);
520 }
521 Some(out)
522}
523
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// Returns `None` for any other byte.
fn hex_nibble(c: u8) -> Option<u8> {
    // Bytes >= 0x80 map to non-ASCII chars, which `to_digit` rejects; a
    // successful result is < 16, so the narrowing cast is lossless.
    char::from(c).to_digit(16).map(|d| d as u8)
}
532
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal version template; `register_version` overwrites model/version/archived.
    fn fresh_version() -> ModelVersion {
        ModelVersion {
            model: String::new(),
            version: 0,
            weights_blob: vec![1, 2, 3],
            hyperparams_json: "{}".into(),
            metrics_json: "{\"f1\":0.9}".into(),
            training_data_hash: None,
            training_sql: None,
            parent_version: None,
            created_at_ms: 0,
            created_by: None,
            archived: false,
        }
    }

    #[test]
    fn register_assigns_monotonic_versions() {
        let reg = ModelRegistry::new();
        let v1 = reg.register_version("m", fresh_version(), true).unwrap();
        let v2 = reg.register_version("m", fresh_version(), true).unwrap();
        let v3 = reg.register_version("m", fresh_version(), true).unwrap();
        assert_eq!((v1, v2, v3), (1, 2, 3));
    }

    #[test]
    fn new_version_becomes_active_by_default() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), true).unwrap();
        let active = reg.get_active("m").unwrap().unwrap();
        assert_eq!(active.version, 2);
    }

    #[test]
    fn unpublished_training_keeps_old_active_version() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), false).unwrap();
        assert_eq!(reg.get_active("m").unwrap().unwrap().version, 1);
    }

    #[test]
    fn register_overwrites_identity_but_keeps_explicit_timestamp() {
        let reg = ModelRegistry::new();
        let mut v = fresh_version();
        v.model = "other".to_string();
        v.version = 42;
        v.archived = true;
        v.created_at_ms = 123;
        assert_eq!(reg.register_version("m", v, false).unwrap(), 1);
        let stored = reg.get_version("m", 1).unwrap();
        assert_eq!(stored.model, "m");
        assert_eq!(stored.version, 1);
        assert!(!stored.archived);
        assert_eq!(stored.created_at_ms, 123);
        assert!(reg.get_active("m").unwrap().is_none());
    }

    #[test]
    fn set_active_version_rolls_back() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.set_active_version("m", 1).unwrap();
        assert_eq!(reg.get_active("m").unwrap().unwrap().version, 1);
    }

    #[test]
    fn set_active_rejects_unknown_version() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        let err = reg.set_active_version("m", 99).unwrap_err();
        assert!(matches!(err, ModelRegistryError::UnknownVersion { .. }));
    }

    #[test]
    fn archived_version_cannot_become_active() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), false).unwrap();
        reg.archive_version("m", 1).unwrap();
        let err = reg.set_active_version("m", 1).unwrap_err();
        assert!(matches!(err, ModelRegistryError::VersionArchived { .. }));
    }

    #[test]
    fn archiving_active_version_clears_pointer() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.archive_version("m", 1).unwrap();
        assert!(reg.get_active("m").unwrap().is_none());
    }

    #[test]
    fn archive_rejects_unknown_model_and_version() {
        let reg = ModelRegistry::new();
        assert!(matches!(
            reg.archive_version("nope", 1).unwrap_err(),
            ModelRegistryError::UnknownModel(_)
        ));
        reg.register_version("m", fresh_version(), true).unwrap();
        assert!(matches!(
            reg.archive_version("m", 9).unwrap_err(),
            ModelRegistryError::UnknownVersion { .. }
        ));
    }

    #[test]
    fn list_versions_returns_in_order() {
        let reg = ModelRegistry::new();
        for _ in 0..5 {
            reg.register_version("m", fresh_version(), true).unwrap();
        }
        let versions: Vec<u32> = reg
            .list_versions("m")
            .unwrap()
            .into_iter()
            .map(|v| v.version)
            .collect();
        assert_eq!(versions, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn summaries_count_archived_separately() {
        let reg = ModelRegistry::new();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.archive_version("m", 1).unwrap();
        let s = &reg.summaries().unwrap()[0];
        assert_eq!(s.total_versions, 3);
        assert_eq!(s.archived_versions, 1);
        assert_eq!(s.active_version, Some(3));
    }

    #[test]
    fn unknown_model_lookups_error_cleanly() {
        let reg = ModelRegistry::new();
        assert!(matches!(
            reg.get_active("nope").unwrap_err(),
            ModelRegistryError::UnknownModel(_)
        ));
        assert!(matches!(
            reg.list_versions("nope").unwrap_err(),
            ModelRegistryError::UnknownModel(_)
        ));
    }

    #[test]
    fn model_version_json_round_trips() {
        let v = ModelVersion {
            model: "spam".to_string(),
            version: 7,
            weights_blob: vec![0xde, 0xad, 0xbe, 0xef],
            hyperparams_json: "{\"lr\":0.01}".to_string(),
            metrics_json: "{\"f1\":0.93}".to_string(),
            training_data_hash: Some("abcd".to_string()),
            training_sql: Some("SELECT * FROM t".to_string()),
            parent_version: Some(6),
            created_at_ms: 1_700_000_000_000,
            created_by: Some("alice".to_string()),
            archived: false,
        };
        let round = ModelVersion::from_json(&v.to_json()).unwrap();
        assert_eq!(round.model, v.model);
        assert_eq!(round.version, v.version);
        assert_eq!(round.weights_blob, v.weights_blob);
        assert_eq!(round.hyperparams_json, v.hyperparams_json);
        assert_eq!(round.metrics_json, v.metrics_json);
        assert_eq!(round.training_data_hash, v.training_data_hash);
        assert_eq!(round.training_sql, v.training_sql);
        assert_eq!(round.parent_version, v.parent_version);
        assert_eq!(round.created_at_ms, v.created_at_ms);
        assert_eq!(round.created_by, v.created_by);
        assert_eq!(round.archived, v.archived);
    }

    #[test]
    fn backend_persists_versions_and_active_pointer() {
        use super::super::persist::InMemoryMlPersistence;
        let backend = Arc::new(InMemoryMlPersistence::new());
        let reg = ModelRegistry::with_backend(backend.clone());
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), true).unwrap();

        // A second registry over the same backend must rehydrate identical state.
        let reg2 = ModelRegistry::with_backend(backend);
        let active = reg2.get_active("m").unwrap().unwrap();
        assert_eq!(active.version, 2);
        let versions: Vec<u32> = reg2
            .list_versions("m")
            .unwrap()
            .into_iter()
            .map(|v| v.version)
            .collect();
        assert_eq!(versions, vec![1, 2]);
    }

    #[test]
    fn backend_rehydrate_survives_archive_then_rollback() {
        use super::super::persist::InMemoryMlPersistence;
        let backend = Arc::new(InMemoryMlPersistence::new());
        let reg = ModelRegistry::with_backend(backend.clone());
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.register_version("m", fresh_version(), true).unwrap();
        reg.archive_version("m", 1).unwrap();
        reg.set_active_version("m", 2).unwrap();

        let reg2 = ModelRegistry::with_backend(backend);
        let versions = reg2.list_versions("m").unwrap();
        assert_eq!(versions.len(), 2);
        assert!(versions.iter().find(|v| v.version == 1).unwrap().archived);
        assert_eq!(reg2.get_active("m").unwrap().unwrap().version, 2);
    }

    #[test]
    fn hex_helpers_round_trip() {
        let bytes = vec![0u8, 1, 2, 3, 255, 128, 64];
        assert_eq!(hex_decode(&hex_encode(&bytes)).unwrap(), bytes);
    }

    #[test]
    fn hex_decode_rejects_odd_length_or_non_hex() {
        assert!(hex_decode("abc").is_none());
        assert!(hex_decode("zz").is_none());
    }
}