// entrenar/storage/registry/memory/traits_impl.rs
use std::collections::{HashMap, HashSet};

use chrono::Utc;

use super::super::comparison::VersionComparison;
use super::super::error::{RegistryError, Result};
use super::super::policy::{PolicyCheckResult, PromotionPolicy};
use super::super::stage::ModelStage;
use super::super::traits::ModelRegistry;
use super::super::transition::StageTransition;
use super::super::version::ModelVersion;
use super::registry::InMemoryRegistry;

15impl ModelRegistry for InMemoryRegistry {
16 fn register_model(&mut self, name: &str, artifact_uri: &str) -> Result<ModelVersion> {
17 let version = self.next_version(name);
18 let model = ModelVersion::new(name, version, artifact_uri);
19
20 self.models.entry(name.to_string()).or_default().insert(version, model.clone());
21
22 Ok(model)
23 }
24
25 fn get_model(&self, name: &str, version: u32) -> Result<ModelVersion> {
26 self.models
27 .get(name)
28 .and_then(|versions| versions.get(&version))
29 .cloned()
30 .ok_or_else(|| RegistryError::VersionNotFound(name.to_string(), version))
31 }
32
33 fn get_latest(&self, name: &str) -> Result<ModelVersion> {
34 self.models
35 .get(name)
36 .and_then(|versions| {
37 let max_version = versions.keys().max()?;
38 versions.get(max_version)
39 })
40 .cloned()
41 .ok_or_else(|| RegistryError::ModelNotFound(name.to_string()))
42 }
43
44 fn get_latest_by_stage(&self, name: &str, stage: ModelStage) -> Option<ModelVersion> {
45 self.models.get(name).and_then(|versions| {
46 versions.values().filter(|m| m.stage == stage).max_by_key(|m| m.version).cloned()
47 })
48 }
49
50 fn list_versions(&self, name: &str) -> Result<Vec<ModelVersion>> {
51 self.models
52 .get(name)
53 .map(|versions| {
54 let mut v: Vec<_> = versions.values().cloned().collect();
55 v.sort_by_key(|m| m.version);
56 v
57 })
58 .ok_or_else(|| RegistryError::ModelNotFound(name.to_string()))
59 }
60
61 fn transition_stage(
62 &mut self,
63 name: &str,
64 version: u32,
65 target_stage: ModelStage,
66 user: Option<&str>,
67 ) -> Result<()> {
68 let model = self
69 .models
70 .get_mut(name)
71 .and_then(|versions| versions.get_mut(&version))
72 .ok_or_else(|| RegistryError::VersionNotFound(name.to_string(), version))?;
73
74 if !model.stage.can_transition_to(target_stage) {
75 return Err(RegistryError::InvalidTransition(model.stage, target_stage));
76 }
77
78 let from_stage = model.stage;
79 model.stage = target_stage;
80 model.promoted_at = Some(Utc::now());
81 model.promoted_by = user.map(ToString::to_string);
82
83 self.transitions.push(StageTransition {
85 model_name: name.to_string(),
86 version,
87 from_stage,
88 to_stage: target_stage,
89 timestamp: Utc::now(),
90 user: user.map(ToString::to_string),
91 reason: None,
92 });
93
94 Ok(())
95 }
96
97 fn compare_versions(&self, name: &str, v1: u32, v2: u32) -> Result<VersionComparison> {
98 let m1 = self.get_model(name, v1)?;
99 let m2 = self.get_model(name, v2)?;
100
101 let mut metric_diffs = HashMap::new();
102 let mut v2_better_count = 0;
103 let mut total_comparisons = 0;
104
105 let all_metrics: std::collections::HashSet<_> =
107 m1.metrics.keys().chain(m2.metrics.keys()).collect();
108
109 for metric in all_metrics {
110 let val1 = m1.metrics.get(metric).copied().unwrap_or(0.0);
111 let val2 = m2.metrics.get(metric).copied().unwrap_or(0.0);
112 let diff = val2 - val1;
113 metric_diffs.insert(metric.clone(), diff);
114
115 if diff > 0.0 {
117 v2_better_count += 1;
118 }
119 total_comparisons += 1;
120 }
121
122 let v2_is_better = total_comparisons > 0 && v2_better_count > total_comparisons / 2;
123
124 let summary = if v2_is_better {
125 format!(
126 "Version {v2} is better than {v1} on {v2_better_count}/{total_comparisons} metrics"
127 )
128 } else {
129 format!("Version {v2} is not definitively better than {v1}")
130 };
131
132 Ok(VersionComparison { v1, v2, metric_diffs, v2_is_better, summary })
133 }
134
135 fn log_metrics(
136 &mut self,
137 name: &str,
138 version: u32,
139 metrics: HashMap<String, f64>,
140 ) -> Result<()> {
141 let model = self
142 .models
143 .get_mut(name)
144 .and_then(|versions| versions.get_mut(&version))
145 .ok_or_else(|| RegistryError::VersionNotFound(name.to_string(), version))?;
146
147 model.metrics.extend(metrics);
148 Ok(())
149 }
150
151 fn get_transition_history(&self, name: &str) -> Result<Vec<StageTransition>> {
152 let history: Vec<_> =
153 self.transitions.iter().filter(|t| t.model_name == name).cloned().collect();
154
155 if history.is_empty() && !self.models.contains_key(name) {
156 return Err(RegistryError::ModelNotFound(name.to_string()));
157 }
158
159 Ok(history)
160 }
161
162 fn set_policy(&mut self, policy: super::super::policy::PromotionPolicy) {
163 self.policies.insert(policy.target_stage, policy);
164 }
165
166 fn get_policy(&self, stage: ModelStage) -> Option<&super::super::policy::PromotionPolicy> {
167 self.policies.get(&stage)
168 }
169
170 fn can_promote(
171 &self,
172 name: &str,
173 version: u32,
174 target_stage: ModelStage,
175 approvals: u32,
176 ) -> Result<PolicyCheckResult> {
177 let model = self.get_model(name, version)?;
178
179 if !model.stage.can_transition_to(target_stage) {
181 return Ok(PolicyCheckResult {
182 passed: false,
183 failed_requirements: vec![format!(
184 "Cannot transition from {} to {}",
185 model.stage, target_stage
186 )],
187 });
188 }
189
190 if let Some(policy) = self.policies.get(&target_stage) {
192 Ok(policy.check(&model, approvals))
193 } else {
194 Ok(PolicyCheckResult { passed: true, failed_requirements: Vec::new() })
196 }
197 }
198}