ra_ap_test_utils/
assert_linear.rs

//! Checks that a set of measurements looks like a linear function rather than
//! like a quadratic function. Algorithm:
//!
//! 1. Linearly scale the input to be in [0; 1].
//! 2. Using linear regression, compute the best linear function approximating
//!    the input.
//! 3. Compute the RMSE and the maximal absolute error.
//! 4. Check that the errors are within tolerances and that the constant term
//!    is not too negative.
//!
//! Ideally, we should use a proper "model selection" procedure to directly
//! compare quadratic and linear models, but that sounds rather complicated:
//!
//! > https://stats.stackexchange.com/questions/21844/selecting-best-model-based-on-linear-quadratic-and-cubic-fit-of-data
//!
//! We might get false positives (spurious failures) on a noisy machine such as
//! a VM, but never false negatives. So, if the first round fails, we repeat
//! the check up to three more times and fail only if every round fails.
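//!
//! A usage sketch (`measure` is a hypothetical function, not part of this
//! crate, standing in for whatever operation is being benchmarked):
//!
//! ```ignore
//! let mut al = AssertLinear::default();
//! while al.next_round() {
//!     for n in [100, 200, 400, 800, 1600] {
//!         al.sample(n as f64, measure(n)); // `measure(n)`: e.g. time to process `n` items
//!     }
//! }
//! // Dropping `al` panics unless at least one round looked linear.
//! ```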
use stdx::format_to;

#[derive(Default)]
pub struct AssertLinear {
    rounds: Vec<Round>,
}

#[derive(Default)]
struct Round {
    /// Raw `(x, y)` measurements collected for this round.
    samples: Vec<(f64, f64)>,
    /// Human-readable table of the fit, printed when the check fails.
    plot: String,
    /// Whether this round's measurements passed the linearity check.
    linear: bool,
}

impl AssertLinear {
    /// Finishes the previous round, if any, and reports whether another round
    /// is needed: returns `false` once some round looks linear or after four
    /// rounds.
    pub fn next_round(&mut self) -> bool {
        if let Some(round) = self.rounds.last_mut() {
            round.finish();
        }
        if self.rounds.iter().any(|it| it.linear) || self.rounds.len() == 4 {
            return false;
        }
        self.rounds.push(Round::default());
        true
    }

    /// Records a single measurement for the current round.
    pub fn sample(&mut self, x: f64, y: f64) {
        self.rounds.last_mut().unwrap().samples.push((x, y));
    }
}

impl Drop for AssertLinear {
    fn drop(&mut self) {
        assert!(!self.rounds.is_empty());
        if self.rounds.iter().all(|it| !it.linear) {
            for round in &self.rounds {
                eprintln!("\n{}", round.plot);
            }
            panic!("Doesn't look linear!");
        }
    }
}

impl Round {
    fn finish(&mut self) {
        let (mut xs, mut ys): (Vec<_>, Vec<_>) = self.samples.iter().copied().unzip();
        normalize(&mut xs);
        normalize(&mut ys);
        let xy = xs.iter().copied().zip(ys.iter().copied());

        // Linear regression: finding a and b to fit y = a + b * x.
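        //
        // Ordinary least squares in closed form, which is what the code below
        // computes:
        //
        //   b = Σ (x_i - x̄)(y_i - ȳ) / Σ (x_i - x̄)²
        //   a = ȳ - b * x̄
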
        let mean_x = mean(&xs);
        let mean_y = mean(&ys);

        let b = {
            let mut num = 0.0;
            let mut denom = 0.0;
            for (x, y) in xy.clone() {
                num += (x - mean_x) * (y - mean_y);
                denom += (x - mean_x).powi(2);
            }
            num / denom
        };

        let a = mean_y - b * mean_x;

        self.plot = format!("y_pred = {a:.3} + {b:.3} * x\n\nx     y     y_pred\n");

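        // Residuals of the fit: RMSE = sqrt(Σ (y - y_pred)² / n) and the
        // worst-case absolute error, both on the normalized data.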
        let mut se = 0.0;
        let mut max_error = 0.0f64;
        for (x, y) in xy {
            let y_pred = a + b * x;
            se += (y - y_pred).powi(2);
            max_error = max_error.max((y_pred - y).abs());

            format_to!(self.plot, "{:.3} {:.3} {:.3}\n", x, y, y_pred);
        }

        let rmse = (se / xs.len() as f64).sqrt();
        format_to!(self.plot, "\nrmse = {:.3} max error = {:.3}", rmse, max_error);
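
        // Heuristic acceptance thresholds on the normalized data: a genuinely
        // quadratic trend fit by a straight line shows up as large residuals
        // and a clearly negative intercept (for y = x² on [0; 1], a ≈ -1/6).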
        self.linear = rmse < 0.05 && max_error < 0.1 && a > -0.1;

        // Scales `xs` so that its maximum element becomes 1.0 (the
        // measurements are assumed to be positive).
        fn normalize(xs: &mut [f64]) {
            let max = xs.iter().copied().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
            xs.iter_mut().for_each(|it| *it /= max);
        }

        fn mean(xs: &[f64]) -> f64 {
            xs.iter().copied().sum::<f64>() / (xs.len() as f64)
        }
    }
}