Skip to main content

entrenar/storage/registry/memory/
traits_impl.rs

1//! ModelRegistry trait implementation for InMemoryRegistry
2
3use chrono::Utc;
4use std::collections::HashMap;
5
6use super::super::comparison::VersionComparison;
7use super::super::error::{RegistryError, Result};
8use super::super::policy::PolicyCheckResult;
9use super::super::stage::ModelStage;
10use super::super::traits::ModelRegistry;
11use super::super::transition::StageTransition;
12use super::super::version::ModelVersion;
13use super::registry::InMemoryRegistry;
14
15impl ModelRegistry for InMemoryRegistry {
16    fn register_model(&mut self, name: &str, artifact_uri: &str) -> Result<ModelVersion> {
17        let version = self.next_version(name);
18        let model = ModelVersion::new(name, version, artifact_uri);
19
20        self.models.entry(name.to_string()).or_default().insert(version, model.clone());
21
22        Ok(model)
23    }
24
25    fn get_model(&self, name: &str, version: u32) -> Result<ModelVersion> {
26        self.models
27            .get(name)
28            .and_then(|versions| versions.get(&version))
29            .cloned()
30            .ok_or_else(|| RegistryError::VersionNotFound(name.to_string(), version))
31    }
32
33    fn get_latest(&self, name: &str) -> Result<ModelVersion> {
34        self.models
35            .get(name)
36            .and_then(|versions| {
37                let max_version = versions.keys().max()?;
38                versions.get(max_version)
39            })
40            .cloned()
41            .ok_or_else(|| RegistryError::ModelNotFound(name.to_string()))
42    }
43
44    fn get_latest_by_stage(&self, name: &str, stage: ModelStage) -> Option<ModelVersion> {
45        self.models.get(name).and_then(|versions| {
46            versions.values().filter(|m| m.stage == stage).max_by_key(|m| m.version).cloned()
47        })
48    }
49
50    fn list_versions(&self, name: &str) -> Result<Vec<ModelVersion>> {
51        self.models
52            .get(name)
53            .map(|versions| {
54                let mut v: Vec<_> = versions.values().cloned().collect();
55                v.sort_by_key(|m| m.version);
56                v
57            })
58            .ok_or_else(|| RegistryError::ModelNotFound(name.to_string()))
59    }
60
61    fn transition_stage(
62        &mut self,
63        name: &str,
64        version: u32,
65        target_stage: ModelStage,
66        user: Option<&str>,
67    ) -> Result<()> {
68        let model = self
69            .models
70            .get_mut(name)
71            .and_then(|versions| versions.get_mut(&version))
72            .ok_or_else(|| RegistryError::VersionNotFound(name.to_string(), version))?;
73
74        if !model.stage.can_transition_to(target_stage) {
75            return Err(RegistryError::InvalidTransition(model.stage, target_stage));
76        }
77
78        let from_stage = model.stage;
79        model.stage = target_stage;
80        model.promoted_at = Some(Utc::now());
81        model.promoted_by = user.map(ToString::to_string);
82
83        // Record transition
84        self.transitions.push(StageTransition {
85            model_name: name.to_string(),
86            version,
87            from_stage,
88            to_stage: target_stage,
89            timestamp: Utc::now(),
90            user: user.map(ToString::to_string),
91            reason: None,
92        });
93
94        Ok(())
95    }
96
97    fn compare_versions(&self, name: &str, v1: u32, v2: u32) -> Result<VersionComparison> {
98        let m1 = self.get_model(name, v1)?;
99        let m2 = self.get_model(name, v2)?;
100
101        let mut metric_diffs = HashMap::new();
102        let mut v2_better_count = 0;
103        let mut total_comparisons = 0;
104
105        // Compare all metrics from both versions
106        let all_metrics: std::collections::HashSet<_> =
107            m1.metrics.keys().chain(m2.metrics.keys()).collect();
108
109        for metric in all_metrics {
110            let val1 = m1.metrics.get(metric).copied().unwrap_or(0.0);
111            let val2 = m2.metrics.get(metric).copied().unwrap_or(0.0);
112            let diff = val2 - val1;
113            metric_diffs.insert(metric.clone(), diff);
114
115            // Assume higher is better for most metrics
116            if diff > 0.0 {
117                v2_better_count += 1;
118            }
119            total_comparisons += 1;
120        }
121
122        let v2_is_better = total_comparisons > 0 && v2_better_count > total_comparisons / 2;
123
124        let summary = if v2_is_better {
125            format!(
126                "Version {v2} is better than {v1} on {v2_better_count}/{total_comparisons} metrics"
127            )
128        } else {
129            format!("Version {v2} is not definitively better than {v1}")
130        };
131
132        Ok(VersionComparison { v1, v2, metric_diffs, v2_is_better, summary })
133    }
134
135    fn log_metrics(
136        &mut self,
137        name: &str,
138        version: u32,
139        metrics: HashMap<String, f64>,
140    ) -> Result<()> {
141        let model = self
142            .models
143            .get_mut(name)
144            .and_then(|versions| versions.get_mut(&version))
145            .ok_or_else(|| RegistryError::VersionNotFound(name.to_string(), version))?;
146
147        model.metrics.extend(metrics);
148        Ok(())
149    }
150
151    fn get_transition_history(&self, name: &str) -> Result<Vec<StageTransition>> {
152        let history: Vec<_> =
153            self.transitions.iter().filter(|t| t.model_name == name).cloned().collect();
154
155        if history.is_empty() && !self.models.contains_key(name) {
156            return Err(RegistryError::ModelNotFound(name.to_string()));
157        }
158
159        Ok(history)
160    }
161
162    fn set_policy(&mut self, policy: super::super::policy::PromotionPolicy) {
163        self.policies.insert(policy.target_stage, policy);
164    }
165
166    fn get_policy(&self, stage: ModelStage) -> Option<&super::super::policy::PromotionPolicy> {
167        self.policies.get(&stage)
168    }
169
170    fn can_promote(
171        &self,
172        name: &str,
173        version: u32,
174        target_stage: ModelStage,
175        approvals: u32,
176    ) -> Result<PolicyCheckResult> {
177        let model = self.get_model(name, version)?;
178
179        // Check stage transition validity
180        if !model.stage.can_transition_to(target_stage) {
181            return Ok(PolicyCheckResult {
182                passed: false,
183                failed_requirements: vec![format!(
184                    "Cannot transition from {} to {}",
185                    model.stage, target_stage
186                )],
187            });
188        }
189
190        // Check policy if exists
191        if let Some(policy) = self.policies.get(&target_stage) {
192            Ok(policy.check(&model, approvals))
193        } else {
194            // No policy = always allowed
195            Ok(PolicyCheckResult { passed: true, failed_requirements: Vec::new() })
196        }
197    }
198}