// radiate_utils/stats/statistics.rs

1use core::f32;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::hash::Hash;
5
6use crate::{Float, Primitive};
7
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[derive(PartialEq, Clone)]
10pub struct Adder<F: Float = f32> {
11    compensation: F,
12    simple_sum: F,
13    sum: F,
14}
15
16impl<F: Float> Adder<F> {
17    pub fn value(&self) -> F {
18        let result = self.sum + self.compensation;
19        if result.is_nan() {
20            self.simple_sum
21        } else {
22            result
23        }
24    }
25
26    pub fn add(&mut self, value: F) {
27        let y = value - self.compensation;
28        let t = self.sum + y;
29
30        self.compensation = (t - self.sum) - y;
31        self.sum = t;
32        self.simple_sum = self.simple_sum + value;
33    }
34}
35
36impl<F: Float> Default for Adder<F> {
37    fn default() -> Self {
38        Adder {
39            compensation: F::ZERO,
40            simple_sum: F::ZERO,
41            sum: F::ZERO,
42        }
43    }
44}
45
46#[derive(PartialEq, Clone)]
47#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
48pub struct Statistic<F: Float = f32> {
49    m1: Adder<F>,
50    m2: Adder<F>,
51    m3: Adder<F>,
52    m4: Adder<F>,
53    sum: Adder<F>,
54    count: i32,
55    last_value: F,
56    max: F,
57    min: F,
58}
59
60impl<F: Float> Statistic<F> {
61    pub fn new(initial_val: F) -> Self {
62        let mut result = Statistic::default();
63        result.add(initial_val);
64        result
65    }
66
67    pub fn last_value(&self) -> F {
68        self.last_value
69    }
70
71    pub fn count(&self) -> i32 {
72        self.count
73    }
74
75    pub fn min(&self) -> F {
76        self.min
77    }
78
79    pub fn max(&self) -> F {
80        self.max
81    }
82
83    pub fn mean(&self) -> F {
84        if self.count == 0 {
85            F::ZERO
86        } else {
87            self.m1.value()
88        }
89    }
90
91    pub fn sum(&self) -> F {
92        self.sum.value()
93    }
94
95    #[inline(always)]
96    pub fn variance(&self) -> Option<F> {
97        let mut value = F::MIN;
98        if self.count == 1 {
99            value = self.m2.value();
100        } else if self.count > 1 {
101            value = self.m2.value() / (F::from(self.count)? - F::ONE);
102        }
103
104        Some(value)
105    }
106
107    #[inline(always)]
108    pub fn std_dev(&self) -> Option<F> {
109        Some(self.variance()?.sqrt())
110    }
111
112    #[inline(always)]
113    pub fn skewness(&self) -> Option<F> {
114        let mut value = F::NAN;
115        let count = F::from(self.count)?;
116        if self.count >= 3 {
117            let temp = self.m2.value() / count - F::ONE;
118            if temp < F::EPS {
119                value = F::ZERO;
120            } else {
121                value = count * self.m3.value()
122                    / ((count - F::ONE) * (count - F::TWO) * temp.sqrt() * temp)
123            }
124        }
125
126        Some(value)
127    }
128
129    #[inline(always)]
130    pub fn kurtosis(&self) -> Option<F> {
131        let mut value = F::NAN;
132        let count = F::from(self.count)?;
133
134        if self.count >= 4 {
135            let temp = self.m2.value() / count - F::ONE;
136            if temp < F::EPS {
137                value = F::ZERO;
138            } else {
139                value = count * (count + F::ONE) * self.m4.value()
140                    / ((count - F::ONE) * (count - F::TWO) * (count - F::THREE) * temp * temp)
141            }
142        }
143
144        Some(value)
145    }
146
147    #[inline(always)]
148    pub fn add(&mut self, value: F) -> Option<()> {
149        self.count += 1;
150
151        let n = F::from(self.count)?;
152        let d = value - self.m1.value();
153        let dn = d / n;
154        let dn2 = dn * dn;
155        let t1 = d * dn * (n - F::ONE);
156
157        self.m1.add(dn);
158
159        self.m4.add(t1 * dn2 * (n * n - F::THREE * n + F::THREE));
160        self.m4
161            .add(F::SIX * dn2 * self.m2.value() - F::FOUR * dn * self.m3.value());
162
163        self.m3
164            .add(t1 * dn * (n - F::TWO) - F::THREE * dn * self.m2.value());
165        self.m2.add(t1);
166
167        self.last_value = value;
168        self.max = if value > self.max { value } else { self.max };
169        self.min = if value < self.min { value } else { self.min };
170        self.sum.add(value);
171
172        Some(())
173    }
174
175    pub fn clear(&mut self) {
176        self.m1 = Adder::default();
177        self.m2 = Adder::default();
178        self.m3 = Adder::default();
179        self.m4 = Adder::default();
180        self.sum = Adder::default();
181        self.count = 0;
182        self.last_value = F::ZERO;
183        self.max = F::MIN;
184        self.min = F::MAX;
185    }
186
187    pub fn merge(&mut self, other: &Statistic<F>) {
188        if other.count == 0 {
189            return;
190        }
191
192        if self.count == 0 {
193            *self = other.clone();
194            return;
195        }
196
197        if other.count == 1 {
198            self.add(other.last_value);
199            return;
200        }
201
202        if self.count == 1 {
203            let last_value = self.last_value;
204            *self = other.clone();
205            self.add(last_value);
206            return;
207        }
208
209        // Use f64 for more accurate intermediate math
210        let n1 = F::from(self.count).unwrap_or(F::ZERO);
211        let n2 = F::from(other.count).unwrap_or(F::ZERO);
212
213        let mean1 = self.m1.value();
214        let mean2 = other.m1.value();
215
216        let m21 = self.m2.value();
217        let m22 = other.m2.value();
218        let m31 = self.m3.value();
219        let m32 = other.m3.value();
220        let m41 = self.m4.value();
221        let m42 = other.m4.value();
222
223        let n = n1 + n2;
224        let delta = mean2 - mean1;
225        let delta2 = delta * delta;
226        let delta3 = delta2 * delta;
227        let delta4 = delta3 * delta;
228        let n1n2 = n1 * n2;
229
230        // Combined mean and moments (Pébay formulas)
231        let mean = (n1 * mean1 + n2 * mean2) / n;
232
233        let m2 = m21 + m22 + delta2 * n1n2 / n;
234
235        let m3 = m31
236            + m32
237            + delta3 * n1n2 * (n1 - n2) / (n * n)
238            + F::THREE * delta * (n1 * m22 - n2 * m21) / n;
239
240        let m4 = m41
241            + m42
242            + delta4 * n1n2 * (n1 * n1 - n1 * n2 + n2 * n2) / (n * n * n)
243            + F::SIX * delta2 * (n1 * n1 * m22 + n2 * n2 * m21) / (n * n)
244            + F::FOUR * delta * (n1 * m32 - n2 * m31) / n;
245
246        // Write back into Kahan adders.
247        // Using `Adder::default()` + single `add` is fine:
248        self.m1 = Adder::default();
249        self.m1.add(mean);
250
251        self.m2 = Adder::default();
252        self.m2.add(m2);
253
254        self.m3 = Adder::default();
255        self.m3.add(m3);
256
257        self.m4 = Adder::default();
258        self.m4.add(m4);
259
260        // Merge auxiliary stats
261        self.sum.add(other.sum()); // preserves Kahan accuracy for the total sum
262        self.count += other.count;
263        self.max = self.max.max(other.max);
264        self.min = self.min.min(other.min);
265
266        // "last_value" is a bit semantic; assuming `other` is later in time:
267        self.last_value = other.last_value;
268    }
269
270    /// Convenience: return a merged copy instead of mutating in-place
271    pub fn merged(mut self, other: &Statistic<F>) -> Statistic<F> {
272        self.merge(other);
273        self
274    }
275}
276
277impl<T: Primitive, F: Float> FromIterator<T> for Statistic<F> {
278    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
279        let mut statistic = Statistic::<F>::default();
280        for item in iter {
281            if let Some(value) = item.extract::<F>() {
282                statistic.add(value);
283            }
284        }
285        statistic
286    }
287}
288
289impl From<f32> for Statistic {
290    fn from(value: f32) -> Self {
291        Statistic::new(value)
292    }
293}
294
295impl From<i32> for Statistic {
296    fn from(value: i32) -> Self {
297        Statistic::new(value as f32)
298    }
299}
300
301impl From<usize> for Statistic {
302    fn from(value: usize) -> Self {
303        Statistic::new(value as f32)
304    }
305}
306
307impl<F: Float> Default for Statistic<F> {
308    fn default() -> Self {
309        Statistic {
310            m1: Adder::default(),
311            m2: Adder::default(),
312            m3: Adder::default(),
313            m4: Adder::default(),
314            sum: Adder::default(),
315            count: 0,
316            last_value: F::ZERO,
317            max: F::MIN,
318            min: F::MAX,
319        }
320    }
321}
322
323impl<F: Float> Hash for Statistic<F> {
324    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
325        self.count.hash(state);
326        self.last_value.num_hash(state);
327        self.max.num_hash(state);
328        self.min.num_hash(state);
329        self.sum.value().num_hash(state);
330        self.m1.value().num_hash(state);
331        self.m2.value().num_hash(state);
332        self.m3.value().num_hash(state);
333        self.m4.value().num_hash(state);
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f32` statistic by adding `values` in order.
    fn stat_of(values: &[f32]) -> Statistic<f32> {
        let mut stat = Statistic::default();
        for sample in values.iter() {
            stat.add(*sample);
        }
        stat
    }

    /// Adding 1 through 5 to an `Adder` yields exactly 15.
    #[test]
    fn test_adder() {
        let mut adder = Adder::default();
        for sample in [1_f32, 2_f32, 3_f32, 4_f32, 5_f32].iter() {
            adder.add(*sample);
        }

        assert_eq!(adder.value(), 15_f32);
    }

    /// Mean, variance, standard deviation and skewness of 1 through 5.
    #[test]
    fn test_statistic() {
        let mut statistic = Statistic::<f32>::default();
        for sample in [1_f32, 2_f32, 3_f32, 4_f32, 5_f32].iter() {
            statistic.add(*sample);
        }

        let mean = statistic.mean();
        let variance = statistic.variance().unwrap();
        let std_dev = statistic.std_dev().unwrap();
        let skewness = statistic.skewness().unwrap();

        assert_eq!(mean, 3_f32);
        assert_eq!(variance, 2.5_f32);
        assert_eq!(std_dev, 1.5811388_f32);
        assert_eq!(skewness, 0_f32);
    }

    /// Merging {1, 2, 3} with {4, 5, 6} matches the statistics of 1 through 6.
    #[test]
    fn test_statistic_merge() {
        let stat1 = stat_of(&[1_f32, 2_f32, 3_f32]);
        let stat2 = stat_of(&[4_f32, 5_f32, 6_f32]);

        let merged = stat1.merged(&stat2);

        assert_eq!(merged.mean(), 3.5_f32);
        assert_eq!(merged.variance().unwrap(), 3.5_f32);
        assert_eq!(merged.std_dev().unwrap(), 1.8708287_f32);
        assert_eq!(merged.skewness().unwrap(), 0_f32);
        assert_eq!(merged.count(), 6);
        assert_eq!(merged.min(), 1_f32);
        assert_eq!(merged.max(), 6_f32);
        assert_eq!(merged.sum(), 21_f32);
    }
}