Skip to main content

entrenar/storage/registry/
policy.rs

1//! Promotion policies for stage transitions (Poka-yoke)
2
3use serde::{Deserialize, Serialize};
4
5use super::comparison::{Comparison, MetricRequirement};
6use super::stage::ModelStage;
7use super::version::ModelVersion;
8
9/// Promotion policy for stage transitions (Poka-yoke)
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct PromotionPolicy {
12    /// Required metrics with thresholds
13    pub required_metrics: Vec<MetricRequirement>,
14    /// Minimum test coverage
15    pub min_test_coverage: Option<f64>,
16    /// Required number of approvals
17    pub required_approvals: u32,
18    /// Auto-promote if all requirements pass
19    pub auto_promote_on_pass: bool,
20    /// Target stage this policy applies to
21    pub target_stage: ModelStage,
22}
23
24impl PromotionPolicy {
25    /// Create a new promotion policy for a target stage
26    pub fn new(target_stage: ModelStage) -> Self {
27        Self {
28            required_metrics: Vec::new(),
29            min_test_coverage: None,
30            required_approvals: 0,
31            auto_promote_on_pass: false,
32            target_stage,
33        }
34    }
35
36    /// Add a metric requirement
37    pub fn require_metric(mut self, name: &str, comparison: Comparison, threshold: f64) -> Self {
38        self.required_metrics.push(MetricRequirement {
39            name: name.to_string(),
40            comparison,
41            threshold,
42        });
43        self
44    }
45
46    /// Set minimum test coverage
47    pub fn require_coverage(mut self, coverage: f64) -> Self {
48        self.min_test_coverage = Some(coverage);
49        self
50    }
51
52    /// Set required approvals
53    pub fn require_approvals(mut self, count: u32) -> Self {
54        self.required_approvals = count;
55        self
56    }
57
58    /// Enable auto-promotion
59    pub fn auto_promote(mut self) -> Self {
60        self.auto_promote_on_pass = true;
61        self
62    }
63
64    /// Check if a model version meets the policy requirements
65    pub fn check(&self, model: &ModelVersion, approvals: u32) -> PolicyCheckResult {
66        let mut failed_requirements = Vec::new();
67
68        // Check metrics
69        for req in &self.required_metrics {
70            if let Some(&value) = model.metrics.get(&req.name) {
71                if !req.comparison.check(value, req.threshold) {
72                    failed_requirements.push(format!(
73                        "Metric '{}' = {} does not satisfy {} {}",
74                        req.name,
75                        value,
76                        req.comparison.as_str(),
77                        req.threshold
78                    ));
79                }
80            } else {
81                failed_requirements.push(format!("Missing required metric '{}'", req.name));
82            }
83        }
84
85        // Check test coverage
86        if let Some(min_coverage) = self.min_test_coverage {
87            if let Some(&coverage) = model.metrics.get("test_coverage") {
88                if coverage < min_coverage {
89                    failed_requirements
90                        .push(format!("Test coverage {coverage} < required {min_coverage}"));
91                }
92            } else {
93                failed_requirements.push("Missing test_coverage metric".to_string());
94            }
95        }
96
97        // Check approvals
98        if approvals < self.required_approvals {
99            failed_requirements
100                .push(format!("Approvals {} < required {}", approvals, self.required_approvals));
101        }
102
103        PolicyCheckResult { passed: failed_requirements.is_empty(), failed_requirements }
104    }
105}
106
107/// Result of policy check
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct PolicyCheckResult {
110    /// Whether all requirements passed
111    pub passed: bool,
112    /// List of failed requirements
113    pub failed_requirements: Vec<String>,
114}
115
116#[cfg(test)]
117mod tests {
118    use super::*;
119
120    #[test]
121    fn test_promotion_policy_new() {
122        let policy = PromotionPolicy::new(ModelStage::Production);
123        assert_eq!(policy.target_stage, ModelStage::Production);
124        assert!(policy.required_metrics.is_empty());
125    }
126
127    #[test]
128    fn test_promotion_policy_require_metric() {
129        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
130            "accuracy",
131            Comparison::Gte,
132            0.95,
133        );
134
135        assert_eq!(policy.required_metrics.len(), 1);
136        assert_eq!(policy.required_metrics[0].name, "accuracy");
137    }
138
139    #[test]
140    fn test_promotion_policy_check_pass() {
141        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
142            "accuracy",
143            Comparison::Gte,
144            0.95,
145        );
146
147        let model = ModelVersion::new("test", 1, "/path").with_metric("accuracy", 0.96);
148
149        let result = policy.check(&model, 0);
150        assert!(result.passed);
151    }
152
153    #[test]
154    fn test_promotion_policy_check_fail_metric() {
155        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
156            "accuracy",
157            Comparison::Gte,
158            0.95,
159        );
160
161        let model = ModelVersion::new("test", 1, "/path").with_metric("accuracy", 0.90);
162
163        let result = policy.check(&model, 0);
164        assert!(!result.passed);
165        assert!(!result.failed_requirements.is_empty());
166    }
167
168    #[test]
169    fn test_promotion_policy_check_fail_missing_metric() {
170        let policy = PromotionPolicy::new(ModelStage::Production).require_metric(
171            "accuracy",
172            Comparison::Gte,
173            0.95,
174        );
175
176        let model = ModelVersion::new("test", 1, "/path");
177
178        let result = policy.check(&model, 0);
179        assert!(!result.passed);
180        assert!(result.failed_requirements[0].contains("Missing"));
181    }
182
183    #[test]
184    fn test_promotion_policy_check_approvals() {
185        let policy = PromotionPolicy::new(ModelStage::Production).require_approvals(2);
186
187        let model = ModelVersion::new("test", 1, "/path");
188
189        // Not enough approvals
190        let result = policy.check(&model, 1);
191        assert!(!result.passed);
192
193        // Enough approvals
194        let result = policy.check(&model, 2);
195        assert!(result.passed);
196    }
197
198    #[test]
199    fn test_promotion_policy_check_coverage() {
200        let policy = PromotionPolicy::new(ModelStage::Production).require_coverage(0.90);
201
202        let model = ModelVersion::new("test", 1, "/path").with_metric("test_coverage", 0.85);
203
204        let result = policy.check(&model, 0);
205        assert!(!result.passed);
206        assert!(result.failed_requirements[0].contains("coverage"));
207    }
208}
209
210#[cfg(test)]
211mod property_tests {
212    use super::*;
213    use proptest::prelude::*;
214
215    proptest! {
216        #![proptest_config(ProptestConfig::with_cases(200))]
217
218        #[test]
219        fn prop_policy_check_deterministic(
220            accuracy in 0.0f64..1.0,
221            threshold in 0.0f64..1.0,
222            approvals in 0u32..10,
223            required_approvals in 0u32..10
224        ) {
225            let policy = PromotionPolicy::new(ModelStage::Production)
226                .require_metric("accuracy", Comparison::Gte, threshold)
227                .require_approvals(required_approvals);
228
229            let model = ModelVersion::new("test", 1, "/path")
230                .with_metric("accuracy", accuracy);
231
232            let result1 = policy.check(&model, approvals);
233            let result2 = policy.check(&model, approvals);
234
235            prop_assert_eq!(result1.passed, result2.passed);
236        }
237    }
238}