radiate_utils/stats/
statistics.rs1use core::f32;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::{fmt::Debug, hash::Hash};
5
6use crate::{Float, Primitive};
7
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[derive(PartialEq, Clone)]
10pub struct Adder<F: Float = f32> {
11 compensation: F,
12 simple_sum: F,
13 sum: F,
14}
15
16impl<F: Float> Adder<F> {
17 pub fn value(&self) -> F {
18 let result = self.sum + self.compensation;
19 if result.is_nan() {
20 self.simple_sum
21 } else {
22 result
23 }
24 }
25
26 pub fn add(&mut self, value: F) {
27 let y = value - self.compensation;
28 let t = self.sum + y;
29
30 self.compensation = (t - self.sum) - y;
31 self.sum = t;
32 self.simple_sum = self.simple_sum + value;
33 }
34}
35
36impl<F: Float> Default for Adder<F> {
37 fn default() -> Self {
38 Adder {
39 compensation: F::ZERO,
40 simple_sum: F::ZERO,
41 sum: F::ZERO,
42 }
43 }
44}
45
46#[derive(PartialEq, Clone)]
47#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
48pub struct Statistic<F: Float = f32> {
49 m1: Adder<F>,
50 m2: Adder<F>,
51 m3: Adder<F>,
52 m4: Adder<F>,
53 sum: Adder<F>,
54 count: u32,
55 last_value: F,
56 max: F,
57 min: F,
58}
59
60impl<F: Float> Statistic<F> {
61 pub fn new(initial_val: F) -> Self {
62 let mut result = Statistic::default();
63 result.add(initial_val);
64 result
65 }
66
67 pub fn last_value(&self) -> F {
68 self.last_value
69 }
70
71 pub fn count(&self) -> u32 {
72 self.count
73 }
74
75 pub fn min(&self) -> F {
76 self.min
77 }
78
79 pub fn max(&self) -> F {
80 self.max
81 }
82
83 pub fn mean(&self) -> F {
84 if self.count == 0 {
85 F::ZERO
86 } else {
87 self.m1.value()
88 }
89 }
90
91 pub fn sum(&self) -> F {
92 self.sum.value()
93 }
94
95 #[inline(always)]
96 pub fn variance(&self) -> Option<F> {
97 let mut value = F::MIN;
98 if self.count == 1 {
99 value = self.m2.value();
100 } else if self.count > 1 {
101 value = self.m2.value() / (F::from(self.count)? - F::ONE);
102 } else if self.count == 0 {
103 return None;
104 }
105
106 Some(value)
107 }
108
109 #[inline(always)]
110 pub fn std_dev(&self) -> Option<F> {
111 Some(self.variance()?.sqrt())
112 }
113
114 #[inline(always)]
115 pub fn skewness(&self) -> Option<F> {
116 let mut value = F::NAN;
117 let count = F::from(self.count)?;
118 if self.count >= 3 {
119 let temp = self.m2.value() / count - F::ONE;
120 if temp < F::EPS {
121 value = F::ZERO;
122 } else {
123 value = count * self.m3.value()
124 / ((count - F::ONE) * (count - F::TWO) * temp.sqrt() * temp)
125 }
126 }
127
128 Some(value)
129 }
130
131 #[inline(always)]
132 pub fn kurtosis(&self) -> Option<F> {
133 let mut value = F::NAN;
134 let count = F::from(self.count)?;
135
136 if self.count >= 4 {
137 let temp = self.m2.value() / count - F::ONE;
138 if temp < F::EPS {
139 value = F::ZERO;
140 } else {
141 value = count * (count + F::ONE) * self.m4.value()
142 / ((count - F::ONE) * (count - F::TWO) * (count - F::THREE) * temp * temp)
143 }
144 }
145
146 Some(value)
147 }
148
149 #[inline(always)]
150 pub fn add(&mut self, value: F) -> Option<()> {
151 self.count += 1;
152
153 let n = F::from(self.count)?;
154 let d = value - self.m1.value();
155 let dn = d / n;
156 let dn2 = dn * dn;
157 let t1 = d * dn * (n - F::ONE);
158
159 self.m1.add(dn);
160
161 self.m4.add(t1 * dn2 * (n * n - F::THREE * n + F::THREE));
162 self.m4
163 .add(F::SIX * dn2 * self.m2.value() - F::FOUR * dn * self.m3.value());
164
165 self.m3
166 .add(t1 * dn * (n - F::TWO) - F::THREE * dn * self.m2.value());
167 self.m2.add(t1);
168
169 self.last_value = value;
170 self.max = if value > self.max { value } else { self.max };
171 self.min = if value < self.min { value } else { self.min };
172 self.sum.add(value);
173
174 Some(())
175 }
176
177 pub fn clear(&mut self) {
178 self.m1 = Adder::default();
179 self.m2 = Adder::default();
180 self.m3 = Adder::default();
181 self.m4 = Adder::default();
182 self.sum = Adder::default();
183 self.count = 0;
184 self.last_value = F::ZERO;
185 self.max = F::MIN;
186 self.min = F::MAX;
187 }
188
189 pub fn merge(&mut self, other: &Statistic<F>) {
190 if other.count == 0 {
191 return;
192 }
193
194 if self.count == 0 {
195 *self = other.clone();
196 return;
197 }
198
199 if other.count == 1 {
200 self.add(other.last_value);
201 return;
202 }
203
204 if self.count == 1 {
205 let last_value = self.last_value;
206 *self = other.clone();
207 self.add(last_value);
208 return;
209 }
210
211 let n1 = F::from(self.count).unwrap_or(F::ZERO);
213 let n2 = F::from(other.count).unwrap_or(F::ZERO);
214
215 let mean1 = self.m1.value();
216 let mean2 = other.m1.value();
217
218 let m21 = self.m2.value();
219 let m22 = other.m2.value();
220 let m31 = self.m3.value();
221 let m32 = other.m3.value();
222 let m41 = self.m4.value();
223 let m42 = other.m4.value();
224
225 let n = n1 + n2;
226 let delta = mean2 - mean1;
227 let delta2 = delta * delta;
228 let delta3 = delta2 * delta;
229 let delta4 = delta3 * delta;
230 let n1n2 = n1 * n2;
231
232 let mean = (n1 * mean1 + n2 * mean2) / n;
234
235 let m2 = m21 + m22 + delta2 * n1n2 / n;
236
237 let m3 = m31
238 + m32
239 + delta3 * n1n2 * (n1 - n2) / (n * n)
240 + F::THREE * delta * (n1 * m22 - n2 * m21) / n;
241
242 let m4 = m41
243 + m42
244 + delta4 * n1n2 * (n1 * n1 - n1 * n2 + n2 * n2) / (n * n * n)
245 + F::SIX * delta2 * (n1 * n1 * m22 + n2 * n2 * m21) / (n * n)
246 + F::FOUR * delta * (n1 * m32 - n2 * m31) / n;
247
248 self.m1 = Adder::default();
251 self.m1.add(mean);
252
253 self.m2 = Adder::default();
254 self.m2.add(m2);
255
256 self.m3 = Adder::default();
257 self.m3.add(m3);
258
259 self.m4 = Adder::default();
260 self.m4.add(m4);
261
262 self.sum.add(other.sum()); self.count += other.count;
265 self.max = self.max.max(other.max);
266 self.min = self.min.min(other.min);
267
268 self.last_value = other.last_value;
270 }
271
272 pub fn merged(mut self, other: &Statistic<F>) -> Statistic<F> {
274 self.merge(other);
275 self
276 }
277}
278
279impl<T: Primitive, F: Float> FromIterator<T> for Statistic<F> {
280 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
281 let mut statistic = Statistic::<F>::default();
282 for item in iter {
283 if let Some(value) = item.extract::<F>() {
284 statistic.add(value);
285 }
286 }
287 statistic
288 }
289}
290
291impl From<f32> for Statistic {
292 fn from(value: f32) -> Self {
293 Statistic::new(value)
294 }
295}
296
297impl From<i32> for Statistic {
298 fn from(value: i32) -> Self {
299 Statistic::new(value as f32)
300 }
301}
302
303impl From<usize> for Statistic {
304 fn from(value: usize) -> Self {
305 Statistic::new(value as f32)
306 }
307}
308
309impl<F: Float> Default for Statistic<F> {
310 fn default() -> Self {
311 Statistic {
312 m1: Adder::default(),
313 m2: Adder::default(),
314 m3: Adder::default(),
315 m4: Adder::default(),
316 sum: Adder::default(),
317 count: 0,
318 last_value: F::ZERO,
319 max: F::MIN,
320 min: F::MAX,
321 }
322 }
323}
324
325impl<F: Float> Hash for Statistic<F> {
326 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
327 self.count.hash(state);
328 self.last_value.num_hash(state);
329 self.max.num_hash(state);
330 self.min.num_hash(state);
331 self.sum.value().num_hash(state);
332 self.m1.value().num_hash(state);
333 self.m2.value().num_hash(state);
334 self.m3.value().num_hash(state);
335 self.m4.value().num_hash(state);
336 }
337}
338
339impl<F: Debug + Float> Debug for Statistic<F> {
340 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.debug_struct("Statistic")
342 .field("count", &self.count)
343 .field("last_value", &self.last_value)
344 .field("max", &self.max)
345 .field("min", &self.min)
346 .field("sum", &self.sum.value())
347 .field("mean", &self.mean())
348 .field("variance", &self.variance())
349 .field("std_dev", &self.std_dev())
350 .field("skewness", &self.skewness())
351 .field("kurtosis", &self.kurtosis())
352 .finish()
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Summing 1..=5 through the adder yields exactly 15.
    #[test]
    fn test_adder() {
        let mut adder = Adder::<f32>::default();
        for term in [1_f32, 2_f32, 3_f32, 4_f32, 5_f32] {
            adder.add(term);
        }

        assert_eq!(adder.value(), 15_f32);
    }

    /// Moments of the samples 1..=5 added one at a time.
    #[test]
    fn test_statistic() {
        let mut statistic = Statistic::<f32>::default();
        for sample in [1_f32, 2_f32, 3_f32, 4_f32, 5_f32] {
            statistic.add(sample);
        }

        assert_eq!(statistic.mean(), 3_f32);
        assert_eq!(statistic.variance().unwrap(), 2.5_f32);
        assert_eq!(statistic.std_dev().unwrap(), 1.5811388_f32);
        assert_eq!(statistic.skewness().unwrap(), 0_f32);
    }

    /// Merging {1, 2, 3} with {4, 5, 6} matches the statistics of 1..=6.
    #[test]
    fn test_statistic_merge() {
        let mut left = Statistic::<f32>::default();
        for sample in [1_f32, 2_f32, 3_f32] {
            left.add(sample);
        }

        let mut right = Statistic::<f32>::default();
        for sample in [4_f32, 5_f32, 6_f32] {
            right.add(sample);
        }

        let merged = left.merged(&right);
        assert_eq!(merged.mean(), 3.5_f32);
        assert_eq!(merged.variance().unwrap(), 3.5_f32);
        assert_eq!(merged.std_dev().unwrap(), 1.8708287_f32);
        assert_eq!(merged.skewness().unwrap(), 0_f32);
        assert_eq!(merged.count(), 6);
        assert_eq!(merged.min(), 1_f32);
        assert_eq!(merged.max(), 6_f32);
        assert_eq!(merged.sum(), 21_f32);
    }
}