Skip to main content

radiate_core/stats/
statistics.rs

1use core::f32;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
/// Compensated (Kahan) `f32` accumulator.
///
/// Besides the compensated running total it keeps a plain running sum, which
/// [`Adder::value`] falls back to when the compensated result is NaN (e.g. after
/// an infinite input turned the compensation term into `inf - inf`).
#[derive(Debug, Default, PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Adder {
    /// Rounding error of the last addition: the amount by which `sum`
    /// overshoots the exact running total.
    compensation: f32,
    /// Naive (uncompensated) running sum, used only as a fallback.
    simple_sum: f32,
    /// Compensated running sum (before applying `compensation`).
    sum: f32,
}

impl Adder {
    /// Returns the compensated total of all values added so far.
    ///
    /// Because `compensation` records how far `sum` overshoots the exact total,
    /// it is *subtracted* here. If that yields NaN, the naive running sum is
    /// returned instead.
    pub fn value(&self) -> f32 {
        let result = self.sum - self.compensation;
        if result.is_nan() {
            self.simple_sum
        } else {
            result
        }
    }

    /// Adds `value` to the running total using Kahan summation.
    pub fn add(&mut self, value: f32) {
        // Re-inject the error lost by the previous addition.
        let y = value - self.compensation;
        let t = self.sum + y;

        // Algebraically zero; in floating point it is the rounding error of
        // `sum + y` (positive when `t` rounded up, negative when it rounded down).
        self.compensation = (t - self.sum) - y;
        self.sum = t;
        self.simple_sum += value;
    }
}
42
43#[derive(PartialEq, Clone)]
44#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
45pub struct Statistic {
46    m1: Adder,
47    m2: Adder,
48    m3: Adder,
49    m4: Adder,
50    sum: Adder,
51    count: i32,
52    last_value: f32,
53    max: f32,
54    min: f32,
55}
56
57impl Statistic {
58    pub fn new(initial_val: f32) -> Self {
59        let mut result = Statistic::default();
60        result.add(initial_val);
61        result
62    }
63
64    pub fn last_value(&self) -> f32 {
65        self.last_value
66    }
67
68    pub fn count(&self) -> i32 {
69        self.count
70    }
71
72    pub fn min(&self) -> f32 {
73        self.min
74    }
75
76    pub fn max(&self) -> f32 {
77        self.max
78    }
79
80    pub fn mean(&self) -> f32 {
81        if self.count == 0 {
82            0_f32
83        } else {
84            self.m1.value()
85        }
86    }
87
88    pub fn sum(&self) -> f32 {
89        self.sum.value()
90    }
91
92    #[inline(always)]
93    pub fn variance(&self) -> f32 {
94        let mut value = f32::NAN;
95        if self.count == 1 {
96            value = self.m2.value();
97        } else if self.count > 1 {
98            value = self.m2.value() / (self.count - 1) as f32;
99        }
100
101        value
102    }
103
104    #[inline(always)]
105    pub fn std_dev(&self) -> f32 {
106        self.variance().sqrt()
107    }
108
109    #[inline(always)]
110    pub fn skewness(&self) -> f32 {
111        let mut value = f32::NAN;
112        if self.count >= 3 {
113            let temp = self.m2.value() / self.count as f32 - 1_f32;
114            if temp < 10e-10_f32 {
115                value = 0_f32;
116            } else {
117                value = self.count as f32 * self.m3.value()
118                    / ((self.count as f32 - 1_f32)
119                        * (self.count as f32 - 2_f32)
120                        * temp.sqrt()
121                        * temp)
122            }
123        }
124
125        value
126    }
127
128    #[inline(always)]
129    pub fn kurtosis(&self) -> f32 {
130        let mut value = f32::NAN;
131        if self.count >= 4 {
132            let temp = self.m2.value() / self.count as f32 - 1_f32;
133            if temp < 10e-10_f32 {
134                value = 0_f32;
135            } else {
136                value = self.count as f32 * (self.count as f32 + 1_f32) * self.m4.value()
137                    / ((self.count as f32 - 1_f32)
138                        * (self.count as f32 - 2_f32)
139                        * (self.count as f32 - 3_f32)
140                        * temp
141                        * temp)
142            }
143        }
144
145        value
146    }
147
148    #[inline(always)]
149    pub fn add(&mut self, value: f32) {
150        self.count += 1;
151
152        let n = self.count as f32;
153        let d = value - self.m1.value();
154        let dn = d / n;
155        let dn2 = dn * dn;
156        let t1 = d * dn * (n - 1_f32);
157
158        self.m1.add(dn);
159
160        self.m4.add(t1 * dn2 * (n * n - 3_f32 * n + 3_f32));
161        self.m4
162            .add(6_f32 * dn2 * self.m2.value() - 4_f32 * dn * self.m3.value());
163
164        self.m3
165            .add(t1 * dn * (n - 2_f32) - 3_f32 * dn * self.m2.value());
166        self.m2.add(t1);
167
168        self.last_value = value;
169        self.max = if value > self.max { value } else { self.max };
170        self.min = if value < self.min { value } else { self.min };
171        self.sum.add(value);
172    }
173
174    pub fn clear(&mut self) {
175        self.m1 = Adder::default();
176        self.m2 = Adder::default();
177        self.m3 = Adder::default();
178        self.m4 = Adder::default();
179        self.sum = Adder::default();
180        self.count = 0;
181        self.last_value = 0_f32;
182        self.max = f32::MIN;
183        self.min = f32::MAX;
184    }
185
186    pub fn merge(&mut self, other: &Statistic) {
187        if other.count == 0 {
188            return;
189        }
190
191        if self.count == 0 {
192            *self = other.clone();
193            return;
194        }
195
196        if other.count == 1 {
197            self.add(other.last_value);
198            return;
199        }
200
201        if self.count == 1 {
202            let last_value = self.last_value;
203            *self = other.clone();
204            self.add(last_value);
205            return;
206        }
207
208        // Use f64 for more accurate intermediate math
209        let n1 = self.count as f64;
210        let n2 = other.count as f64;
211
212        let mean1 = self.m1.value() as f64;
213        let mean2 = other.m1.value() as f64;
214
215        let m21 = self.m2.value() as f64;
216        let m22 = other.m2.value() as f64;
217        let m31 = self.m3.value() as f64;
218        let m32 = other.m3.value() as f64;
219        let m41 = self.m4.value() as f64;
220        let m42 = other.m4.value() as f64;
221
222        let n = n1 + n2;
223        let delta = mean2 - mean1;
224        let delta2 = delta * delta;
225        let delta3 = delta2 * delta;
226        let delta4 = delta3 * delta;
227        let n1n2 = n1 * n2;
228
229        // Combined mean and moments (Pébay formulas)
230        let mean = (n1 * mean1 + n2 * mean2) / n;
231
232        let m2 = m21 + m22 + delta2 * n1n2 / n;
233
234        let m3 = m31
235            + m32
236            + delta3 * n1n2 * (n1 - n2) / (n * n)
237            + 3.0 * delta * (n1 * m22 - n2 * m21) / n;
238
239        let m4 = m41
240            + m42
241            + delta4 * n1n2 * (n1 * n1 - n1 * n2 + n2 * n2) / (n * n * n)
242            + 6.0 * delta2 * (n1 * n1 * m22 + n2 * n2 * m21) / (n * n)
243            + 4.0 * delta * (n1 * m32 - n2 * m31) / n;
244
245        // Write back into Kahan adders.
246        // Using `Adder::default()` + single `add` is fine:
247        self.m1 = Adder::default();
248        self.m1.add(mean as f32);
249
250        self.m2 = Adder::default();
251        self.m2.add(m2 as f32);
252
253        self.m3 = Adder::default();
254        self.m3.add(m3 as f32);
255
256        self.m4 = Adder::default();
257        self.m4.add(m4 as f32);
258
259        // Merge auxiliary stats
260        self.sum.add(other.sum()); // preserves Kahan accuracy for the total sum
261        self.count += other.count;
262        self.max = self.max.max(other.max);
263        self.min = self.min.min(other.min);
264
265        // "last_value" is a bit semantic; assuming `other` is later in time:
266        self.last_value = other.last_value;
267    }
268
269    /// Convenience: return a merged copy instead of mutating in-place
270    pub fn merged(mut self, other: &Statistic) -> Statistic {
271        self.merge(other);
272        self
273    }
274}
275
276impl Default for Statistic {
277    fn default() -> Self {
278        Statistic {
279            m1: Adder::default(),
280            m2: Adder::default(),
281            m3: Adder::default(),
282            m4: Adder::default(),
283            sum: Adder::default(),
284            count: 0,
285            last_value: 0_f32,
286            max: f32::MIN,
287            min: f32::MAX,
288        }
289    }
290}
291
292impl From<f32> for Statistic {
293    fn from(value: f32) -> Self {
294        Statistic::new(value)
295    }
296}
297
298impl From<i32> for Statistic {
299    fn from(value: i32) -> Self {
300        Statistic::new(value as f32)
301    }
302}
303
304impl From<usize> for Statistic {
305    fn from(value: usize) -> Self {
306        Statistic::new(value as f32)
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a statistic by adding `samples` in order to an empty one.
    fn stat_of(samples: &[f32]) -> Statistic {
        let mut stat = Statistic::default();
        samples.iter().for_each(|&sample| stat.add(sample));
        stat
    }

    #[test]
    fn test_adder() {
        let mut adder = Adder::default();
        (1..=5).for_each(|i| adder.add(i as f32));

        assert_eq!(adder.value(), 15_f32);
    }

    #[test]
    fn test_statistic() {
        let stat = stat_of(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(stat.mean(), 3_f32);
        assert_eq!(stat.variance(), 2.5_f32);
        assert_eq!(stat.std_dev(), 1.5811388_f32);
        assert_eq!(stat.skewness(), 0_f32);
    }

    #[test]
    fn test_statistic_merge() {
        let lower = stat_of(&[1.0, 2.0, 3.0]);
        let upper = stat_of(&[4.0, 5.0, 6.0]);

        let combined = lower.merged(&upper);

        assert_eq!(combined.mean(), 3.5_f32);
        assert_eq!(combined.variance(), 3.5_f32);
        assert_eq!(combined.std_dev(), 1.8708287_f32);
        assert_eq!(combined.skewness(), 0_f32);
        assert_eq!(combined.count(), 6);
        assert_eq!(combined.min(), 1_f32);
        assert_eq!(combined.max(), 6_f32);
        assert_eq!(combined.sum(), 21_f32);
    }
}