Skip to main content

entrenar/train/tui/
reference.rs

1//! Reference Curve Overlay (ENT-067)
2//!
3//! Compare current training with a "golden" reference run.
4
5use super::sparkline::sparkline_range;
6
/// A "golden" training curve that the current run is measured against.
///
/// Stores the reference run's values together with the threshold used to
/// decide whether the current run has drifted away from it.
#[derive(Clone, Debug)]
pub struct ReferenceCurve {
    /// Values recorded from the reference ("golden") run; looked up by epoch index.
    values: Vec<f32>,
    /// Threshold a deviation must exceed before it is reported.
    tolerance: f32,
}
15
16impl ReferenceCurve {
17    /// Create from a vector of reference values.
18    pub fn new(values: Vec<f32>, tolerance: f32) -> Self {
19        Self { values, tolerance }
20    }
21
22    /// Load from JSON file.
23    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
24        let values: Vec<f32> = serde_json::from_str(json)?;
25        Ok(Self::new(values, 0.1))
26    }
27
28    /// Get reference value at epoch.
29    pub fn get(&self, epoch: usize) -> Option<f32> {
30        self.values.get(epoch).copied()
31    }
32
33    /// Check if current value deviates from reference.
34    pub fn check_deviation(&self, epoch: usize, current: f32) -> Option<f32> {
35        if let Some(reference) = self.get(epoch) {
36            let deviation = (current - reference).abs() / reference.abs().max(f32::EPSILON);
37            if deviation > self.tolerance {
38                return Some(deviation);
39            }
40        }
41        None
42    }
43
44    /// Generate comparison sparkline.
45    pub fn comparison_sparkline(&self, current: &[f32], width: usize) -> String {
46        let len = current.len().min(self.values.len());
47        if len == 0 {
48            return String::new();
49        }
50
51        // Show deviation from reference
52        let deviations: Vec<f32> = current
53            .iter()
54            .zip(self.values.iter())
55            .map(|(c, r)| (c - r) / r.abs().max(f32::EPSILON))
56            .collect();
57
58        // Use signed sparkline (negative = better, positive = worse for loss)
59        sparkline_range(&deviations, width, -0.5, 0.5)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// `get` returns each stored value by epoch and `None` past the end.
    #[test]
    fn test_reference_curve_new() {
        let reference = ReferenceCurve::new(vec![1.0, 0.8, 0.6], 0.1);
        let expected: [(usize, Option<f32>); 4] =
            [(0, Some(1.0)), (1, Some(0.8)), (2, Some(0.6)), (3, None)];
        for &(epoch, value) in expected.iter() {
            assert_eq!(reference.get(epoch), value);
        }
    }

    /// A JSON array of numbers is parsed into the curve's values.
    #[test]
    fn test_reference_curve_from_json() {
        let reference =
            ReferenceCurve::from_json("[1.0, 0.8, 0.6, 0.4]").expect("operation should succeed");
        assert_eq!(reference.get(0), Some(1.0));
        assert_eq!(reference.get(3), Some(0.4));
    }

    /// Input that is not JSON is rejected with an error.
    #[test]
    fn test_reference_curve_from_json_invalid() {
        let parsed = ReferenceCurve::from_json("not valid json");
        assert!(parsed.is_err());
    }

    /// 1.05 is 5% away from 1.0, which is inside the 10% tolerance.
    #[test]
    fn test_reference_curve_check_deviation_within_tolerance() {
        let reference = ReferenceCurve::new(vec![1.0], 0.1);
        assert_eq!(reference.check_deviation(0, 1.05), None);
    }

    /// 1.15 is 15% away from 1.0, which is outside the 10% tolerance.
    #[test]
    fn test_reference_curve_check_deviation_exceeds_tolerance() {
        let reference = ReferenceCurve::new(vec![1.0], 0.1);
        let reported = reference
            .check_deviation(0, 1.15)
            .expect("operation should succeed");
        assert!((reported - 0.15).abs() < 0.01);
    }

    /// An epoch with no reference value never reports a deviation.
    #[test]
    fn test_reference_curve_check_deviation_no_reference() {
        let reference = ReferenceCurve::new(vec![1.0], 0.1);
        assert_eq!(reference.check_deviation(5, 1.0), None);
    }

    /// The rendered sparkline has exactly the requested number of characters.
    #[test]
    fn test_reference_curve_comparison_sparkline() {
        let reference = ReferenceCurve::new(vec![1.0, 0.8, 0.6, 0.4], 0.1);
        let observed: [f32; 4] = [1.0, 0.9, 0.5, 0.4];
        let rendered = reference.comparison_sparkline(&observed, 4);
        assert_eq!(rendered.chars().count(), 4);
    }

    /// An empty reference curve produces an empty sparkline.
    #[test]
    fn test_reference_curve_comparison_sparkline_empty() {
        let reference = ReferenceCurve::new(vec![], 0.1);
        let observed: [f32; 1] = [1.0];
        assert!(reference.comparison_sparkline(&observed, 4).is_empty());
    }
}