Skip to main content

radiate_utils/stats/
slope.rs

1use crate::{Float, stats::statistics::Adder};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
5#[derive(PartialEq, Clone)]
6#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7pub struct Slope<F: Float> {
8    sum_y: Adder<F>,
9    sum_xy: Adder<F>,
10    count: u32,
11}
12
13impl<F: Float> Slope<F> {
14    pub fn new() -> Self {
15        Self {
16            sum_y: Adder::<F>::default(),
17            sum_xy: Adder::<F>::default(),
18            count: 0,
19        }
20    }
21
22    pub fn count(&self) -> u32 {
23        self.count
24    }
25
26    pub fn add(&mut self, value: F) {
27        let x = F::from(self.count).unwrap_or(F::ZERO);
28        self.sum_y.add(value);
29        self.sum_xy.add(x * value);
30        self.count += 1;
31    }
32
33    pub fn value(&self) -> Option<F> {
34        if self.count < 2 {
35            return None;
36        }
37
38        let n = F::from(self.count)?;
39        let one = F::ONE;
40        let two = F::from(2.0)?;
41        let six = F::from(6.0)?;
42
43        let sum_x = n * (n - one) / two;
44        let sum_x2 = n * (n - one) * (two * n - one) / six;
45
46        let sum_y = self.sum_y.value();
47        let sum_xy = self.sum_xy.value();
48
49        let numerator = n * sum_xy - sum_x * sum_y;
50        let denominator = n * sum_x2 - sum_x * sum_x;
51
52        if denominator.abs() < F::EPS {
53            None
54        } else {
55            Some(numerator / denominator)
56        }
57    }
58
59    pub fn clear(&mut self) {
60        self.sum_y = Adder::default();
61        self.sum_xy = Adder::default();
62        self.count = 0;
63    }
64}
65
66impl<F: Float> Extend<F> for Slope<F> {
67    fn extend<T: IntoIterator<Item = F>>(&mut self, iter: T) {
68        for value in iter {
69            self.add(value);
70        }
71    }
72}
73
74impl<F: Float> FromIterator<F> for Slope<F> {
75    fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
76        let mut slope = Slope::new();
77        slope.extend(iter);
78        slope
79    }
80}
81
82impl<F: Float> Default for Slope<F> {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl<F: Float> std::fmt::Debug for Slope<F> {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("Slope")
91            .field("sum_y", &self.sum_y.value())
92            .field("sum_xy", &self.sum_xy.value())
93            .field("count", &self.count)
94            .finish()
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slope_increasing_line() {
        let mut slope = Slope::<f32>::new();
        slope.add(1.0);
        slope.add(2.0);
        slope.add(3.0);
        slope.add(4.0);

        let value = slope.value().unwrap();
        assert!((value - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_slope_flat_line() {
        let mut slope = Slope::<f32>::new();
        slope.add(5.0);
        slope.add(5.0);
        slope.add(5.0);
        slope.add(5.0);

        let value = slope.value().unwrap();
        assert!(value.abs() < 1e-6);
    }

    #[test]
    fn test_slope_decreasing_line() {
        let mut slope = Slope::<f32>::new();
        slope.add(4.0);
        slope.add(3.0);
        slope.add(2.0);
        slope.add(1.0);

        let value = slope.value().unwrap();
        assert!((value + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_slope_two_points() {
        let mut slope = Slope::<f32>::new();
        slope.add(2.0);
        slope.add(6.0);

        let value = slope.value().unwrap();
        assert!((value - 4.0).abs() < 1e-6);
    }

    /// A slope is undefined until at least two values have been recorded.
    #[test]
    fn test_slope_requires_two_points() {
        let mut slope = Slope::<f32>::new();
        assert_eq!(slope.count(), 0);
        assert!(slope.value().is_none());

        slope.add(3.0);
        assert_eq!(slope.count(), 1);
        assert!(slope.value().is_none());
    }

    /// Non-collinear data: y = [1, 3, 2, 5] at x = [0, 1, 2, 3] has
    /// least-squares slope 5.5 / 5 = 1.1.
    #[test]
    fn test_slope_least_squares_fit() {
        let slope: Slope<f32> = [1.0f32, 3.0, 2.0, 5.0].iter().copied().collect();
        assert_eq!(slope.count(), 4);

        let value = slope.value().unwrap();
        assert!((value - 1.1).abs() < 1e-5);
    }

    /// `clear` must forget earlier values so the next value is back at x = 0.
    #[test]
    fn test_slope_clear_resets_state() {
        let mut slope = Slope::<f32>::new();
        slope.extend(vec![10.0, 20.0, 30.0]);
        slope.clear();
        assert_eq!(slope.count(), 0);
        assert!(slope.value().is_none());

        slope.add(3.0);
        slope.add(1.0);
        let value = slope.value().unwrap();
        assert!((value + 2.0).abs() < 1e-6);
    }

    /// Collecting from an iterator records the same state as repeated `add`.
    #[test]
    fn test_slope_collect_matches_add() {
        let samples: Vec<f32> = (0..5).map(|i| 0.5 * i as f32 + 3.0).collect();

        let collected: Slope<f32> = samples.iter().copied().collect();
        let mut added = Slope::<f32>::new();
        for &sample in &samples {
            added.add(sample);
        }

        assert!(collected == added);
        assert_eq!(collected.count(), 5);
        let value = collected.value().unwrap();
        assert!((value - 0.5).abs() < 1e-6);
    }
}