use std::collections::HashMap;
use super::super::policy::PromotionPolicy;
use super::super::stage::ModelStage;
use super::super::transition::StageTransition;
use super::super::version::ModelVersion;
/// In-memory registry state for models, their versions, stage transitions,
/// per-stage promotion policies and per-model auto-rollback settings.
///
/// The derived `Default` produces a registry with every collection empty.
#[derive(Debug, Default)]
pub struct InMemoryRegistry {
    /// Registered versions, keyed first by model name and then by version number.
    pub(crate) models: HashMap<String, HashMap<u32, ModelVersion>>,
    /// Recorded stage transitions (ordering is whatever writers push; not enforced here).
    pub(crate) transitions: Vec<StageTransition>,
    /// Promotion policies keyed by stage.
    pub(crate) policies: HashMap<ModelStage, PromotionPolicy>,
    /// Auto-rollback configuration: model name -> (metric name, threshold).
    pub(crate) rollback_enabled: HashMap<String, (String, f64)>,
}
impl InMemoryRegistry {
    /// Creates an empty registry (equivalent to `InMemoryRegistry::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables auto-rollback for `model`: a later
    /// [`check_rollback`](Self::check_rollback) reports `true` when the observed
    /// value falls below `threshold`.
    ///
    /// Calling this again for the same model replaces the previous setting.
    /// The `metric` name is stored alongside the threshold but is not consulted
    /// by `check_rollback`. Callers must pass the value of that metric themselves.
    pub fn enable_auto_rollback(&mut self, model: &str, metric: &str, threshold: f64) {
        self.rollback_enabled
            .insert(model.to_string(), (metric.to_string(), threshold));
    }

    /// Returns `true` if auto-rollback is enabled for `model` and
    /// `current_metric` is strictly below the configured threshold.
    ///
    /// Returns `false` when no rollback is configured for `model`. Because the
    /// comparison is `<`, a value equal to the threshold never triggers a
    /// rollback, and neither does a NaN on either side.
    pub fn check_rollback(&self, model: &str, current_metric: f64) -> bool {
        match self.rollback_enabled.get(model) {
            // The stored metric name is intentionally not checked here; see
            // `enable_auto_rollback`.
            Some((_metric, threshold)) => current_metric < *threshold,
            None => false,
        }
    }

    /// Returns the version number to assign to the next registration of
    /// `name`: one past the highest existing version. Returns `1` if the model
    /// is unknown or has no versions.
    ///
    /// # Panics
    ///
    /// Panics if the highest existing version is already `u32::MAX`. A plain
    /// `+ 1` would panic only in debug builds. In release builds it would wrap
    /// to `0`, silently handing out a number below every existing version (and
    /// the same `0` on every later call).
    pub(crate) fn next_version(&self, name: &str) -> u32 {
        let latest = self
            .models
            .get(name)
            .and_then(|versions| versions.keys().max().copied())
            .unwrap_or(0);
        latest
            .checked_add(1)
            .expect("version counter overflow: latest version is already u32::MAX")
    }
}