Skip to main content

radiate_utils/stats/
statistics.rs

1use core::f32;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::{fmt::Debug, hash::Hash};
5
6use crate::{Float, Primitive};
7
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[derive(PartialEq, Clone)]
10pub struct Adder<F: Float = f32> {
11    compensation: F,
12    simple_sum: F,
13    sum: F,
14}
15
16impl<F: Float> Adder<F> {
17    pub fn value(&self) -> F {
18        let result = self.sum + self.compensation;
19        if result.is_nan() {
20            self.simple_sum
21        } else {
22            result
23        }
24    }
25
26    pub fn add(&mut self, value: F) {
27        let y = value - self.compensation;
28        let t = self.sum + y;
29
30        self.compensation = (t - self.sum) - y;
31        self.sum = t;
32        self.simple_sum = self.simple_sum + value;
33    }
34}
35
36impl<F: Float> Default for Adder<F> {
37    fn default() -> Self {
38        Adder {
39            compensation: F::ZERO,
40            simple_sum: F::ZERO,
41            sum: F::ZERO,
42        }
43    }
44}
45
46#[derive(PartialEq, Clone)]
47#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
48pub struct Statistic<F: Float = f32> {
49    m1: Adder<F>,
50    m2: Adder<F>,
51    m3: Adder<F>,
52    m4: Adder<F>,
53    sum: Adder<F>,
54    count: u32,
55    last_value: F,
56    max: F,
57    min: F,
58}
59
60impl<F: Float> Statistic<F> {
61    pub fn new(initial_val: F) -> Self {
62        let mut result = Statistic::default();
63        result.add(initial_val);
64        result
65    }
66
67    pub fn last_value(&self) -> F {
68        self.last_value
69    }
70
71    pub fn count(&self) -> u32 {
72        self.count
73    }
74
75    pub fn min(&self) -> F {
76        self.min
77    }
78
79    pub fn max(&self) -> F {
80        self.max
81    }
82
83    pub fn mean(&self) -> F {
84        if self.count == 0 {
85            F::ZERO
86        } else {
87            self.m1.value()
88        }
89    }
90
91    pub fn sum(&self) -> F {
92        self.sum.value()
93    }
94
95    #[inline(always)]
96    pub fn variance(&self) -> Option<F> {
97        let mut value = F::MIN;
98        if self.count == 1 {
99            value = self.m2.value();
100        } else if self.count > 1 {
101            value = self.m2.value() / (F::from(self.count)? - F::ONE);
102        } else if self.count == 0 {
103            return None;
104        }
105
106        Some(value)
107    }
108
109    #[inline(always)]
110    pub fn std_dev(&self) -> Option<F> {
111        Some(self.variance()?.sqrt())
112    }
113
114    #[inline(always)]
115    pub fn skewness(&self) -> Option<F> {
116        let mut value = F::NAN;
117        let count = F::from(self.count)?;
118        if self.count >= 3 {
119            let temp = self.m2.value() / count - F::ONE;
120            if temp < F::EPS {
121                value = F::ZERO;
122            } else {
123                value = count * self.m3.value()
124                    / ((count - F::ONE) * (count - F::TWO) * temp.sqrt() * temp)
125            }
126        }
127
128        Some(value)
129    }
130
131    #[inline(always)]
132    pub fn kurtosis(&self) -> Option<F> {
133        let mut value = F::NAN;
134        let count = F::from(self.count)?;
135
136        if self.count >= 4 {
137            let temp = self.m2.value() / count - F::ONE;
138            if temp < F::EPS {
139                value = F::ZERO;
140            } else {
141                value = count * (count + F::ONE) * self.m4.value()
142                    / ((count - F::ONE) * (count - F::TWO) * (count - F::THREE) * temp * temp)
143            }
144        }
145
146        Some(value)
147    }
148
149    #[inline(always)]
150    pub fn add(&mut self, value: F) -> Option<()> {
151        self.count += 1;
152
153        let n = F::from(self.count)?;
154        let d = value - self.m1.value();
155        let dn = d / n;
156        let dn2 = dn * dn;
157        let t1 = d * dn * (n - F::ONE);
158
159        self.m1.add(dn);
160
161        self.m4.add(t1 * dn2 * (n * n - F::THREE * n + F::THREE));
162        self.m4
163            .add(F::SIX * dn2 * self.m2.value() - F::FOUR * dn * self.m3.value());
164
165        self.m3
166            .add(t1 * dn * (n - F::TWO) - F::THREE * dn * self.m2.value());
167        self.m2.add(t1);
168
169        self.last_value = value;
170        self.max = if value > self.max { value } else { self.max };
171        self.min = if value < self.min { value } else { self.min };
172        self.sum.add(value);
173
174        Some(())
175    }
176
177    pub fn clear(&mut self) {
178        self.m1 = Adder::default();
179        self.m2 = Adder::default();
180        self.m3 = Adder::default();
181        self.m4 = Adder::default();
182        self.sum = Adder::default();
183        self.count = 0;
184        self.last_value = F::ZERO;
185        self.max = F::MIN;
186        self.min = F::MAX;
187    }
188
189    pub fn merge(&mut self, other: &Statistic<F>) {
190        if other.count == 0 {
191            return;
192        }
193
194        if self.count == 0 {
195            *self = other.clone();
196            return;
197        }
198
199        if other.count == 1 {
200            self.add(other.last_value);
201            return;
202        }
203
204        if self.count == 1 {
205            let last_value = self.last_value;
206            *self = other.clone();
207            self.add(last_value);
208            return;
209        }
210
211        // Use f64 for more accurate intermediate math
212        let n1 = F::from(self.count).unwrap_or(F::ZERO);
213        let n2 = F::from(other.count).unwrap_or(F::ZERO);
214
215        let mean1 = self.m1.value();
216        let mean2 = other.m1.value();
217
218        let m21 = self.m2.value();
219        let m22 = other.m2.value();
220        let m31 = self.m3.value();
221        let m32 = other.m3.value();
222        let m41 = self.m4.value();
223        let m42 = other.m4.value();
224
225        let n = n1 + n2;
226        let delta = mean2 - mean1;
227        let delta2 = delta * delta;
228        let delta3 = delta2 * delta;
229        let delta4 = delta3 * delta;
230        let n1n2 = n1 * n2;
231
232        // Combined mean and moments (Pébay formulas)
233        let mean = (n1 * mean1 + n2 * mean2) / n;
234
235        let m2 = m21 + m22 + delta2 * n1n2 / n;
236
237        let m3 = m31
238            + m32
239            + delta3 * n1n2 * (n1 - n2) / (n * n)
240            + F::THREE * delta * (n1 * m22 - n2 * m21) / n;
241
242        let m4 = m41
243            + m42
244            + delta4 * n1n2 * (n1 * n1 - n1 * n2 + n2 * n2) / (n * n * n)
245            + F::SIX * delta2 * (n1 * n1 * m22 + n2 * n2 * m21) / (n * n)
246            + F::FOUR * delta * (n1 * m32 - n2 * m31) / n;
247
248        // Write back into Kahan adders.
249        // Using `Adder::default()` + single `add` is fine:
250        self.m1 = Adder::default();
251        self.m1.add(mean);
252
253        self.m2 = Adder::default();
254        self.m2.add(m2);
255
256        self.m3 = Adder::default();
257        self.m3.add(m3);
258
259        self.m4 = Adder::default();
260        self.m4.add(m4);
261
262        // Merge auxiliary stats
263        self.sum.add(other.sum()); // preserves Kahan accuracy for the total sum
264        self.count += other.count;
265        self.max = self.max.max(other.max);
266        self.min = self.min.min(other.min);
267
268        // "last_value" is a bit semantic; assuming `other` is later in time:
269        self.last_value = other.last_value;
270    }
271
272    /// Convenience: return a merged copy instead of mutating in-place
273    pub fn merged(mut self, other: &Statistic<F>) -> Statistic<F> {
274        self.merge(other);
275        self
276    }
277}
278
279impl<T: Primitive, F: Float> FromIterator<T> for Statistic<F> {
280    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
281        let mut statistic = Statistic::<F>::default();
282        for item in iter {
283            if let Some(value) = item.extract::<F>() {
284                statistic.add(value);
285            }
286        }
287        statistic
288    }
289}
290
291impl From<f32> for Statistic {
292    fn from(value: f32) -> Self {
293        Statistic::new(value)
294    }
295}
296
297impl From<i32> for Statistic {
298    fn from(value: i32) -> Self {
299        Statistic::new(value as f32)
300    }
301}
302
303impl From<usize> for Statistic {
304    fn from(value: usize) -> Self {
305        Statistic::new(value as f32)
306    }
307}
308
309impl<F: Float> Default for Statistic<F> {
310    fn default() -> Self {
311        Statistic {
312            m1: Adder::default(),
313            m2: Adder::default(),
314            m3: Adder::default(),
315            m4: Adder::default(),
316            sum: Adder::default(),
317            count: 0,
318            last_value: F::ZERO,
319            max: F::MIN,
320            min: F::MAX,
321        }
322    }
323}
324
325impl<F: Float> Hash for Statistic<F> {
326    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
327        self.count.hash(state);
328        self.last_value.num_hash(state);
329        self.max.num_hash(state);
330        self.min.num_hash(state);
331        self.sum.value().num_hash(state);
332        self.m1.value().num_hash(state);
333        self.m2.value().num_hash(state);
334        self.m3.value().num_hash(state);
335        self.m4.value().num_hash(state);
336    }
337}
338
339impl<F: Debug + Float> Debug for Statistic<F> {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.debug_struct("Statistic")
342            .field("count", &self.count)
343            .field("last_value", &self.last_value)
344            .field("max", &self.max)
345            .field("min", &self.min)
346            .field("sum", &self.sum.value())
347            .field("mean", &self.mean())
348            .field("variance", &self.variance())
349            .field("std_dev", &self.std_dev())
350            .field("skewness", &self.skewness())
351            .field("kurtosis", &self.kurtosis())
352            .finish()
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds every sample, in order, to a fresh `Statistic<f32>`.
    fn statistic_of(samples: &[f32]) -> Statistic<f32> {
        let mut statistic = Statistic::default();
        for &sample in samples {
            statistic.add(sample);
        }
        statistic
    }

    #[test]
    fn test_adder() {
        let mut adder = Adder::default();
        for term in [1_f32, 2_f32, 3_f32, 4_f32, 5_f32] {
            adder.add(term);
        }

        assert_eq!(adder.value(), 15_f32);
    }

    #[test]
    fn test_statistic() {
        let statistic = statistic_of(&[1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);

        assert_eq!(statistic.mean(), 3_f32);
        assert_eq!(statistic.variance().unwrap(), 2.5_f32);
        assert_eq!(statistic.std_dev().unwrap(), 1.5811388_f32);
        assert_eq!(statistic.skewness().unwrap(), 0_f32);
    }

    #[test]
    fn test_statistic_merge() {
        let lower = statistic_of(&[1_f32, 2_f32, 3_f32]);
        let upper = statistic_of(&[4_f32, 5_f32, 6_f32]);

        let merged = lower.merged(&upper);
        assert_eq!(merged.mean(), 3.5_f32);
        assert_eq!(merged.variance().unwrap(), 3.5_f32);
        assert_eq!(merged.std_dev().unwrap(), 1.8708287_f32);
        assert_eq!(merged.skewness().unwrap(), 0_f32);
        assert_eq!(merged.count(), 6);
        assert_eq!(merged.min(), 1_f32);
        assert_eq!(merged.max(), 6_f32);
        assert_eq!(merged.sum(), 21_f32);
    }
}