// entrenar/train/tui/reference.rs
use super::sparkline::sparkline_range;
6
/// A curve of per-epoch reference values that live values are compared against.
///
/// Stores one reference value per epoch index plus the relative tolerance
/// used when deciding whether a live value has drifted from the reference.
#[derive(Debug, Clone, PartialEq)]
pub struct ReferenceCurve {
    /// Reference value for each epoch; the vector index is the epoch number.
    values: Vec<f32>,
    /// Largest relative deviation still considered acceptable (`0.1` == 10%).
    tolerance: f32,
}
15
16impl ReferenceCurve {
17 pub fn new(values: Vec<f32>, tolerance: f32) -> Self {
19 Self { values, tolerance }
20 }
21
22 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
24 let values: Vec<f32> = serde_json::from_str(json)?;
25 Ok(Self::new(values, 0.1))
26 }
27
28 pub fn get(&self, epoch: usize) -> Option<f32> {
30 self.values.get(epoch).copied()
31 }
32
33 pub fn check_deviation(&self, epoch: usize, current: f32) -> Option<f32> {
35 if let Some(reference) = self.get(epoch) {
36 let deviation = (current - reference).abs() / reference.abs().max(f32::EPSILON);
37 if deviation > self.tolerance {
38 return Some(deviation);
39 }
40 }
41 None
42 }
43
44 pub fn comparison_sparkline(&self, current: &[f32], width: usize) -> String {
46 let len = current.len().min(self.values.len());
47 if len == 0 {
48 return String::new();
49 }
50
51 let deviations: Vec<f32> = current
53 .iter()
54 .zip(self.values.iter())
55 .map(|(c, r)| (c - r) / r.abs().max(f32::EPSILON))
56 .collect();
57
58 sparkline_range(&deviations, width, -0.5, 0.5)
60 }
61}
62
#[cfg(test)]
mod tests {
    //! Unit tests for `ReferenceCurve`: construction, JSON parsing,
    //! deviation checks and sparkline rendering.

    use super::*;

    #[test]
    fn test_reference_curve_new() {
        let curve = ReferenceCurve::new(vec![1.0, 0.8, 0.6], 0.1);
        assert_eq!(curve.get(0), Some(1.0));
        assert_eq!(curve.get(1), Some(0.8));
        assert_eq!(curve.get(2), Some(0.6));
        // One past the last epoch has no reference value.
        assert_eq!(curve.get(3), None);
    }

    #[test]
    fn test_reference_curve_from_json() {
        let json = "[1.0, 0.8, 0.6, 0.4]";
        let curve = ReferenceCurve::from_json(json).expect("operation should succeed");
        assert_eq!(curve.get(0), Some(1.0));
        assert_eq!(curve.get(3), Some(0.4));
    }

    #[test]
    fn test_reference_curve_from_json_invalid() {
        let json = "not valid json";
        assert!(ReferenceCurve::from_json(json).is_err());
    }

    #[test]
    fn test_reference_curve_check_deviation_within_tolerance() {
        let curve = ReferenceCurve::new(vec![1.0], 0.1);
        // 5% off with a 10% tolerance: nothing to report.
        assert!(curve.check_deviation(0, 1.05).is_none());
    }

    #[test]
    fn test_reference_curve_check_deviation_exceeds_tolerance() {
        let curve = ReferenceCurve::new(vec![1.0], 0.1);
        let deviation = curve.check_deviation(0, 1.15);
        // 15% off with a 10% tolerance: the deviation itself is returned.
        assert!(deviation.is_some());
        assert!((deviation.expect("operation should succeed") - 0.15).abs() < 0.01);
    }

    #[test]
    fn test_reference_curve_check_deviation_no_reference() {
        let curve = ReferenceCurve::new(vec![1.0], 0.1);
        // Epoch 5 is beyond the curve, so there is nothing to compare with.
        assert!(curve.check_deviation(5, 1.0).is_none());
    }

    #[test]
    fn test_reference_curve_comparison_sparkline() {
        let curve = ReferenceCurve::new(vec![1.0, 0.8, 0.6, 0.4], 0.1);
        let current = vec![1.0, 0.9, 0.5, 0.4];
        let sparkline = curve.comparison_sparkline(&current, 4);
        assert_eq!(sparkline.chars().count(), 4);
    }

    #[test]
    fn test_reference_curve_comparison_sparkline_empty() {
        let curve = ReferenceCurve::new(vec![], 0.1);
        let current = vec![1.0];
        // An empty reference curve leaves nothing to draw.
        assert_eq!(curve.comparison_sparkline(&current, 4), "");
    }
}