entrenar/storage/registry/
policy.rs1use serde::{Deserialize, Serialize};
4
5use super::comparison::{Comparison, MetricRequirement};
6use super::stage::ModelStage;
7use super::version::ModelVersion;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct PromotionPolicy {
12 pub required_metrics: Vec<MetricRequirement>,
14 pub min_test_coverage: Option<f64>,
16 pub required_approvals: u32,
18 pub auto_promote_on_pass: bool,
20 pub target_stage: ModelStage,
22}
23
24impl PromotionPolicy {
25 pub fn new(target_stage: ModelStage) -> Self {
27 Self {
28 required_metrics: Vec::new(),
29 min_test_coverage: None,
30 required_approvals: 0,
31 auto_promote_on_pass: false,
32 target_stage,
33 }
34 }
35
36 pub fn require_metric(mut self, name: &str, comparison: Comparison, threshold: f64) -> Self {
38 self.required_metrics.push(MetricRequirement {
39 name: name.to_string(),
40 comparison,
41 threshold,
42 });
43 self
44 }
45
46 pub fn require_coverage(mut self, coverage: f64) -> Self {
48 self.min_test_coverage = Some(coverage);
49 self
50 }
51
52 pub fn require_approvals(mut self, count: u32) -> Self {
54 self.required_approvals = count;
55 self
56 }
57
58 pub fn auto_promote(mut self) -> Self {
60 self.auto_promote_on_pass = true;
61 self
62 }
63
64 pub fn check(&self, model: &ModelVersion, approvals: u32) -> PolicyCheckResult {
66 let mut failed_requirements = Vec::new();
67
68 for req in &self.required_metrics {
70 if let Some(&value) = model.metrics.get(&req.name) {
71 if !req.comparison.check(value, req.threshold) {
72 failed_requirements.push(format!(
73 "Metric '{}' = {} does not satisfy {} {}",
74 req.name,
75 value,
76 req.comparison.as_str(),
77 req.threshold
78 ));
79 }
80 } else {
81 failed_requirements.push(format!("Missing required metric '{}'", req.name));
82 }
83 }
84
85 if let Some(min_coverage) = self.min_test_coverage {
87 if let Some(&coverage) = model.metrics.get("test_coverage") {
88 if coverage < min_coverage {
89 failed_requirements
90 .push(format!("Test coverage {coverage} < required {min_coverage}"));
91 }
92 } else {
93 failed_requirements.push("Missing test_coverage metric".to_string());
94 }
95 }
96
97 if approvals < self.required_approvals {
99 failed_requirements
100 .push(format!("Approvals {} < required {}", approvals, self.required_approvals));
101 }
102
103 PolicyCheckResult { passed: failed_requirements.is_empty(), failed_requirements }
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct PolicyCheckResult {
110 pub passed: bool,
112 pub failed_requirements: Vec<String>,
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_promotion_policy_new() {
        let policy = PromotionPolicy::new(ModelStage::Production);
        assert_eq!(policy.target_stage, ModelStage::Production);
        assert!(policy.required_metrics.is_empty());
        assert!(policy.min_test_coverage.is_none());
        assert_eq!(policy.required_approvals, 0);
        assert!(!policy.auto_promote_on_pass);
    }

    #[test]
    fn test_promotion_policy_require_metric() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
            "accuracy",
            Comparison::Gte,
            0.95,
        );

        assert_eq!(policy.required_metrics.len(), 1);
        assert_eq!(policy.required_metrics[0].name, "accuracy");
    }

    #[test]
    fn test_promotion_policy_auto_promote() {
        let policy = PromotionPolicy::new(ModelStage::Production).auto_promote();
        assert!(policy.auto_promote_on_pass);
    }

    #[test]
    fn test_promotion_policy_empty_policy_passes() {
        let policy = PromotionPolicy::new(ModelStage::Production);
        let model = ModelVersion::new("test", 1, "/path");

        let result = policy.check(&model, 0);
        assert!(result.passed);
        assert!(result.failed_requirements.is_empty());
    }

    #[test]
    fn test_promotion_policy_check_pass() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
            "accuracy",
            Comparison::Gte,
            0.95,
        );

        let model = ModelVersion::new("test", 1, "/path").with_metric("accuracy", 0.96);

        let result = policy.check(&model, 0);
        assert!(result.passed);
    }

    #[test]
    fn test_promotion_policy_check_fail_metric() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
            "accuracy",
            Comparison::Gte,
            0.95,
        );

        let model = ModelVersion::new("test", 1, "/path").with_metric("accuracy", 0.90);

        let result = policy.check(&model, 0);
        assert!(!result.passed);
        assert!(!result.failed_requirements.is_empty());
    }

    #[test]
    fn test_promotion_policy_check_fail_missing_metric() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
            "accuracy",
            Comparison::Gte,
            0.95,
        );

        let model = ModelVersion::new("test", 1, "/path");

        let result = policy.check(&model, 0);
        assert!(!result.passed);
        assert!(result.failed_requirements[0].contains("Missing"));
    }

    #[test]
    fn test_promotion_policy_check_approvals() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_approvals(2);

        let model = ModelVersion::new("test", 1, "/path");

        // One approval short of the requirement.
        let result = policy.check(&model, 1);
        assert!(!result.passed);

        // Exactly the required number of approvals.
        let result = policy.check(&model, 2);
        assert!(result.passed);
    }

    #[test]
    fn test_promotion_policy_check_coverage() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_coverage(0.90);

        let model = ModelVersion::new("test", 1, "/path").with_metric("test_coverage", 0.85);

        let result = policy.check(&model, 0);
        assert!(!result.passed);
        assert!(result.failed_requirements[0].contains("coverage"));
    }

    #[test]
    fn test_promotion_policy_check_coverage_pass() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_coverage(0.90);

        let model = ModelVersion::new("test", 1, "/path").with_metric("test_coverage", 0.92);

        let result = policy.check(&model, 0);
        assert!(result.passed);
    }

    #[test]
    fn test_promotion_policy_check_missing_coverage() {
        let policy = PromotionPolicy::new(ModelStage::Production).require_coverage(0.90);

        let model = ModelVersion::new("test", 1, "/path");

        let result = policy.check(&model, 0);
        assert!(!result.passed);
        assert!(result.failed_requirements[0].contains("Missing test_coverage"));
    }

    #[test]
    fn test_promotion_policy_check_reports_every_failure() {
        let policy = PromotionPolicy::new(ModelStage::Production)
            .require_metric("accuracy", Comparison::Gte, 0.95)
            .require_coverage(0.90)
            .require_approvals(1);

        let model = ModelVersion::new("test", 1, "/path")
            .with_metric("accuracy", 0.50)
            .with_metric("test_coverage", 0.50);

        let result = policy.check(&model, 0);
        assert!(!result.passed);
        // Every unmet requirement is reported, in evaluation order.
        assert_eq!(result.failed_requirements.len(), 3);
        assert!(result.failed_requirements[0].contains("accuracy"));
        assert!(result.failed_requirements[1].contains("coverage"));
        assert!(result.failed_requirements[2].contains("Approvals"));
    }
}

#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        /// Evaluating the same policy twice yields the same verdict and the same
        /// failure messages.
        #[test]
        fn prop_policy_check_deterministic(
            accuracy in 0.0f64..1.0,
            threshold in 0.0f64..1.0,
            approvals in 0u32..10,
            required_approvals in 0u32..10
        ) {
            let policy = PromotionPolicy::new(ModelStage::Production)
                .require_metric("accuracy", Comparison::Gte, threshold)
                .require_approvals(required_approvals);

            let model = ModelVersion::new("test", 1, "/path")
                .with_metric("accuracy", accuracy);

            let result1 = policy.check(&model, approvals);
            let result2 = policy.check(&model, approvals);

            prop_assert_eq!(result1.passed, result2.passed);
            prop_assert_eq!(result1.failed_requirements, result2.failed_requirements);
        }

        /// The verdict and the number of reported failures match the
        /// requirements that are actually unmet.
        #[test]
        fn prop_policy_check_passes_iff_requirements_met(
            accuracy in 0.0f64..1.0,
            threshold in 0.0f64..1.0,
            approvals in 0u32..10,
            required_approvals in 0u32..10
        ) {
            let policy = PromotionPolicy::new(ModelStage::Production)
                .require_metric("accuracy", Comparison::Gte, threshold)
                .require_approvals(required_approvals);

            let model = ModelVersion::new("test", 1, "/path")
                .with_metric("accuracy", accuracy);

            let result = policy.check(&model, approvals);

            let metric_ok = accuracy >= threshold;
            let approvals_ok = approvals >= required_approvals;

            prop_assert_eq!(result.passed, metric_ok && approvals_ok);
            prop_assert_eq!(result.passed, result.failed_requirements.is_empty());
            prop_assert_eq!(
                result.failed_requirements.len(),
                usize::from(!metric_ok) + usize::from(!approvals_ok)
            );
        }
    }
}