Skip to main content

entrenar/storage/registry/memory/
registry.rs

1//! InMemoryRegistry struct and core methods
2
3use std::collections::HashMap;
4
5use super::super::policy::PromotionPolicy;
6use super::super::stage::ModelStage;
7use super::super::transition::StageTransition;
8use super::super::version::ModelVersion;
9
10/// In-memory model registry for testing
11#[derive(Debug, Default)]
12pub struct InMemoryRegistry {
13    /// Models by name -> version -> ModelVersion
14    pub(crate) models: HashMap<String, HashMap<u32, ModelVersion>>,
15    /// Stage transition history
16    pub(crate) transitions: Vec<StageTransition>,
17    /// Promotion policies by stage
18    pub(crate) policies: HashMap<ModelStage, PromotionPolicy>,
19    /// Auto-rollback configuration
20    pub(crate) rollback_enabled: HashMap<String, (String, f64)>, // model -> (metric, threshold)
21}
22
23impl InMemoryRegistry {
24    /// Create a new in-memory registry
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Enable auto-rollback for a model
30    pub fn enable_auto_rollback(&mut self, model: &str, metric: &str, threshold: f64) {
31        self.rollback_enabled.insert(model.to_string(), (metric.to_string(), threshold));
32    }
33
34    /// Check if rollback is needed based on metrics
35    pub fn check_rollback(&self, model: &str, current_metric: f64) -> bool {
36        if let Some((_, threshold)) = self.rollback_enabled.get(model) {
37            current_metric < *threshold
38        } else {
39            false
40        }
41    }
42
43    /// Get next version number for a model
44    pub(crate) fn next_version(&self, name: &str) -> u32 {
45        self.models.get(name).map_or(1, |versions| versions.keys().max().copied().unwrap_or(0) + 1)
46    }
47}