entrenar/storage/registry/memory/
registry.rs1use std::collections::HashMap;
4
5use super::super::policy::PromotionPolicy;
6use super::super::stage::ModelStage;
7use super::super::transition::StageTransition;
8use super::super::version::ModelVersion;
9
10#[derive(Debug, Default)]
12pub struct InMemoryRegistry {
13 pub(crate) models: HashMap<String, HashMap<u32, ModelVersion>>,
15 pub(crate) transitions: Vec<StageTransition>,
17 pub(crate) policies: HashMap<ModelStage, PromotionPolicy>,
19 pub(crate) rollback_enabled: HashMap<String, (String, f64)>, }
22
23impl InMemoryRegistry {
24 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn enable_auto_rollback(&mut self, model: &str, metric: &str, threshold: f64) {
31 self.rollback_enabled.insert(model.to_string(), (metric.to_string(), threshold));
32 }
33
34 pub fn check_rollback(&self, model: &str, current_metric: f64) -> bool {
36 if let Some((_, threshold)) = self.rollback_enabled.get(model) {
37 current_metric < *threshold
38 } else {
39 false
40 }
41 }
42
43 pub(crate) fn next_version(&self, name: &str) -> u32 {
45 self.models.get(name).map_or(1, |versions| versions.keys().max().copied().unwrap_or(0) + 1)
46 }
47}