Skip to main content

entrenar/storage/registry/
comparison.rs

1//! Version comparison and metric comparison types
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// Comparison between two model versions
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct VersionComparison {
9    /// First version
10    pub v1: u32,
11    /// Second version
12    pub v2: u32,
13    /// Metric differences (positive = v2 is better for maximizing metrics)
14    pub metric_diffs: HashMap<String, f64>,
15    /// Whether v2 is better overall
16    pub v2_is_better: bool,
17    /// Summary of changes
18    pub summary: String,
19}
20
21/// Metric requirement for promotion policy
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct MetricRequirement {
24    /// Metric name
25    pub name: String,
26    /// Comparison operator
27    pub comparison: Comparison,
28    /// Threshold value
29    pub threshold: f64,
30}
31
32/// Comparison operators
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34pub enum Comparison {
35    Gt,
36    Gte,
37    Lt,
38    Lte,
39    Eq,
40}
41
42impl Comparison {
43    /// Check if value satisfies comparison with threshold
44    pub fn check(&self, value: f64, threshold: f64) -> bool {
45        match self {
46            Comparison::Gt => value > threshold,
47            Comparison::Gte => value >= threshold,
48            Comparison::Lt => value < threshold,
49            Comparison::Lte => value <= threshold,
50            Comparison::Eq => (value - threshold).abs() < f64::EPSILON,
51        }
52    }
53
54    /// Get string representation
55    pub fn as_str(&self) -> &'static str {
56        match self {
57            Comparison::Gt => ">",
58            Comparison::Gte => ">=",
59            Comparison::Lt => "<",
60            Comparison::Lte => "<=",
61            Comparison::Eq => "==",
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_comparison_gt() {
        assert!(Comparison::Gt.check(0.96, 0.95));
        assert!(!Comparison::Gt.check(0.95, 0.95));
        assert!(!Comparison::Gt.check(0.94, 0.95));
    }

    #[test]
    fn test_comparison_gte() {
        assert!(Comparison::Gte.check(0.95, 0.95));
        assert!(Comparison::Gte.check(0.96, 0.95));
        assert!(!Comparison::Gte.check(0.94, 0.95));
    }

    #[test]
    fn test_comparison_lt() {
        assert!(Comparison::Lt.check(0.5, 1.0));
        assert!(!Comparison::Lt.check(1.0, 1.0));
        assert!(!Comparison::Lt.check(1.5, 1.0));
    }

    // `Lte` previously had no coverage.
    #[test]
    fn test_comparison_lte() {
        assert!(Comparison::Lte.check(0.5, 1.0));
        assert!(Comparison::Lte.check(1.0, 1.0));
        assert!(!Comparison::Lte.check(1.5, 1.0));
    }

    #[test]
    fn test_comparison_eq() {
        assert!(Comparison::Eq.check(0.95, 0.95));
        assert!(!Comparison::Eq.check(0.95, 0.96));
    }

    // A NaN metric (or threshold) must never satisfy a requirement, whatever
    // the operator.
    #[test]
    fn test_comparison_nan_never_satisfies() {
        let ops = [
            Comparison::Gt,
            Comparison::Gte,
            Comparison::Lt,
            Comparison::Lte,
            Comparison::Eq,
        ];
        for op in ops.iter() {
            assert!(!op.check(f64::NAN, 0.0), "{:?} accepted NaN value", op);
            assert!(!op.check(0.0, f64::NAN), "{:?} accepted NaN threshold", op);
        }
    }

    #[test]
    fn test_comparison_as_str() {
        assert_eq!(Comparison::Gt.as_str(), ">");
        assert_eq!(Comparison::Gte.as_str(), ">=");
        assert_eq!(Comparison::Lt.as_str(), "<");
        assert_eq!(Comparison::Lte.as_str(), "<=");
        assert_eq!(Comparison::Eq.as_str(), "==");
    }
}
94
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // For non-NaN inputs the operator pairs Gt/Lte and Lt/Gte are exact
        // complements: exactly one of each pair holds, including when
        // `value == threshold`, so no equality escape hatch is needed.
        #[test]
        fn prop_comparison_consistent(value in -1000.0f64..1000.0, threshold in -1000.0f64..1000.0) {
            prop_assert_ne!(
                Comparison::Gt.check(value, threshold),
                Comparison::Lte.check(value, threshold)
            );
            prop_assert_ne!(
                Comparison::Lt.check(value, threshold),
                Comparison::Gte.check(value, threshold)
            );
        }

        // Eq is reflexive for finite values and symmetric in its operands.
        #[test]
        fn prop_eq_reflexive_and_symmetric(value in -1000.0f64..1000.0, threshold in -1000.0f64..1000.0) {
            prop_assert!(Comparison::Eq.check(value, value));
            prop_assert_eq!(
                Comparison::Eq.check(value, threshold),
                Comparison::Eq.check(threshold, value)
            );
        }
    }
}