// radiate_utils/stats/slope.rs
use crate::{Float, stats::statistics::Adder};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
5#[derive(PartialEq, Clone)]
6#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7pub struct Slope<F: Float> {
8 sum_y: Adder<F>,
9 sum_xy: Adder<F>,
10 count: u32,
11}
12
13impl<F: Float> Slope<F> {
14 pub fn new() -> Self {
15 Self {
16 sum_y: Adder::<F>::default(),
17 sum_xy: Adder::<F>::default(),
18 count: 0,
19 }
20 }
21
22 pub fn count(&self) -> u32 {
23 self.count
24 }
25
26 pub fn add(&mut self, value: F) {
27 let x = F::from(self.count).unwrap_or(F::ZERO);
28 self.sum_y.add(value);
29 self.sum_xy.add(x * value);
30 self.count += 1;
31 }
32
33 pub fn value(&self) -> Option<F> {
34 if self.count < 2 {
35 return None;
36 }
37
38 let n = F::from(self.count)?;
39 let one = F::ONE;
40 let two = F::from(2.0)?;
41 let six = F::from(6.0)?;
42
43 let sum_x = n * (n - one) / two;
44 let sum_x2 = n * (n - one) * (two * n - one) / six;
45
46 let sum_y = self.sum_y.value();
47 let sum_xy = self.sum_xy.value();
48
49 let numerator = n * sum_xy - sum_x * sum_y;
50 let denominator = n * sum_x2 - sum_x * sum_x;
51
52 if denominator.abs() < F::EPS {
53 None
54 } else {
55 Some(numerator / denominator)
56 }
57 }
58
59 pub fn clear(&mut self) {
60 self.sum_y = Adder::default();
61 self.sum_xy = Adder::default();
62 self.count = 0;
63 }
64}
65
66impl<F: Float> Extend<F> for Slope<F> {
67 fn extend<T: IntoIterator<Item = F>>(&mut self, iter: T) {
68 for value in iter {
69 self.add(value);
70 }
71 }
72}
73
74impl<F: Float> FromIterator<F> for Slope<F> {
75 fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
76 let mut slope = Slope::new();
77 slope.extend(iter);
78 slope
79 }
80}
81
82impl<F: Float> Default for Slope<F> {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl<F: Float> std::fmt::Debug for Slope<F> {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("Slope")
91 .field("sum_y", &self.sum_y.value())
92 .field("sum_xy", &self.sum_xy.value())
93 .field("count", &self.count)
94 .finish()
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed slopes.
    const TOLERANCE: f32 = 1e-6;

    /// Adds `samples` one by one to a fresh estimator and returns its slope,
    /// panicking if no slope is available.
    fn fitted_slope(samples: &[f32]) -> f32 {
        let mut slope = Slope::<f32>::new();
        for &sample in samples {
            slope.add(sample);
        }
        slope.value().unwrap()
    }

    #[test]
    fn test_slope_increasing_line() {
        let value = fitted_slope(&[1.0, 2.0, 3.0, 4.0]);
        assert!((value - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_slope_flat_line() {
        let value = fitted_slope(&[5.0, 5.0, 5.0, 5.0]);
        assert!(value.abs() < TOLERANCE);
    }

    #[test]
    fn test_slope_decreasing_line() {
        let value = fitted_slope(&[4.0, 3.0, 2.0, 1.0]);
        assert!((value + 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_slope_two_points() {
        let value = fitted_slope(&[2.0, 6.0]);
        assert!((value - 4.0).abs() < TOLERANCE);
    }
}
147}