radiate_utils/stats/
statistics.rs1use core::f32;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::hash::Hash;
5
6use crate::{Float, Primitive};
7
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[derive(PartialEq, Clone)]
10pub struct Adder<F: Float = f32> {
11 compensation: F,
12 simple_sum: F,
13 sum: F,
14}
15
16impl<F: Float> Adder<F> {
17 pub fn value(&self) -> F {
18 let result = self.sum + self.compensation;
19 if result.is_nan() {
20 self.simple_sum
21 } else {
22 result
23 }
24 }
25
26 pub fn add(&mut self, value: F) {
27 let y = value - self.compensation;
28 let t = self.sum + y;
29
30 self.compensation = (t - self.sum) - y;
31 self.sum = t;
32 self.simple_sum = self.simple_sum + value;
33 }
34}
35
36impl<F: Float> Default for Adder<F> {
37 fn default() -> Self {
38 Adder {
39 compensation: F::ZERO,
40 simple_sum: F::ZERO,
41 sum: F::ZERO,
42 }
43 }
44}
45
46#[derive(PartialEq, Clone)]
47#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
48pub struct Statistic<F: Float = f32> {
49 m1: Adder<F>,
50 m2: Adder<F>,
51 m3: Adder<F>,
52 m4: Adder<F>,
53 sum: Adder<F>,
54 count: i32,
55 last_value: F,
56 max: F,
57 min: F,
58}
59
60impl<F: Float> Statistic<F> {
61 pub fn new(initial_val: F) -> Self {
62 let mut result = Statistic::default();
63 result.add(initial_val);
64 result
65 }
66
67 pub fn last_value(&self) -> F {
68 self.last_value
69 }
70
71 pub fn count(&self) -> i32 {
72 self.count
73 }
74
75 pub fn min(&self) -> F {
76 self.min
77 }
78
79 pub fn max(&self) -> F {
80 self.max
81 }
82
83 pub fn mean(&self) -> F {
84 if self.count == 0 {
85 F::ZERO
86 } else {
87 self.m1.value()
88 }
89 }
90
91 pub fn sum(&self) -> F {
92 self.sum.value()
93 }
94
95 #[inline(always)]
96 pub fn variance(&self) -> Option<F> {
97 let mut value = F::MIN;
98 if self.count == 1 {
99 value = self.m2.value();
100 } else if self.count > 1 {
101 value = self.m2.value() / (F::from(self.count)? - F::ONE);
102 }
103
104 Some(value)
105 }
106
107 #[inline(always)]
108 pub fn std_dev(&self) -> Option<F> {
109 Some(self.variance()?.sqrt())
110 }
111
112 #[inline(always)]
113 pub fn skewness(&self) -> Option<F> {
114 let mut value = F::NAN;
115 let count = F::from(self.count)?;
116 if self.count >= 3 {
117 let temp = self.m2.value() / count - F::ONE;
118 if temp < F::EPS {
119 value = F::ZERO;
120 } else {
121 value = count * self.m3.value()
122 / ((count - F::ONE) * (count - F::TWO) * temp.sqrt() * temp)
123 }
124 }
125
126 Some(value)
127 }
128
129 #[inline(always)]
130 pub fn kurtosis(&self) -> Option<F> {
131 let mut value = F::NAN;
132 let count = F::from(self.count)?;
133
134 if self.count >= 4 {
135 let temp = self.m2.value() / count - F::ONE;
136 if temp < F::EPS {
137 value = F::ZERO;
138 } else {
139 value = count * (count + F::ONE) * self.m4.value()
140 / ((count - F::ONE) * (count - F::TWO) * (count - F::THREE) * temp * temp)
141 }
142 }
143
144 Some(value)
145 }
146
147 #[inline(always)]
148 pub fn add(&mut self, value: F) -> Option<()> {
149 self.count += 1;
150
151 let n = F::from(self.count)?;
152 let d = value - self.m1.value();
153 let dn = d / n;
154 let dn2 = dn * dn;
155 let t1 = d * dn * (n - F::ONE);
156
157 self.m1.add(dn);
158
159 self.m4.add(t1 * dn2 * (n * n - F::THREE * n + F::THREE));
160 self.m4
161 .add(F::SIX * dn2 * self.m2.value() - F::FOUR * dn * self.m3.value());
162
163 self.m3
164 .add(t1 * dn * (n - F::TWO) - F::THREE * dn * self.m2.value());
165 self.m2.add(t1);
166
167 self.last_value = value;
168 self.max = if value > self.max { value } else { self.max };
169 self.min = if value < self.min { value } else { self.min };
170 self.sum.add(value);
171
172 Some(())
173 }
174
175 pub fn clear(&mut self) {
176 self.m1 = Adder::default();
177 self.m2 = Adder::default();
178 self.m3 = Adder::default();
179 self.m4 = Adder::default();
180 self.sum = Adder::default();
181 self.count = 0;
182 self.last_value = F::ZERO;
183 self.max = F::MIN;
184 self.min = F::MAX;
185 }
186
187 pub fn merge(&mut self, other: &Statistic<F>) {
188 if other.count == 0 {
189 return;
190 }
191
192 if self.count == 0 {
193 *self = other.clone();
194 return;
195 }
196
197 if other.count == 1 {
198 self.add(other.last_value);
199 return;
200 }
201
202 if self.count == 1 {
203 let last_value = self.last_value;
204 *self = other.clone();
205 self.add(last_value);
206 return;
207 }
208
209 let n1 = F::from(self.count).unwrap_or(F::ZERO);
211 let n2 = F::from(other.count).unwrap_or(F::ZERO);
212
213 let mean1 = self.m1.value();
214 let mean2 = other.m1.value();
215
216 let m21 = self.m2.value();
217 let m22 = other.m2.value();
218 let m31 = self.m3.value();
219 let m32 = other.m3.value();
220 let m41 = self.m4.value();
221 let m42 = other.m4.value();
222
223 let n = n1 + n2;
224 let delta = mean2 - mean1;
225 let delta2 = delta * delta;
226 let delta3 = delta2 * delta;
227 let delta4 = delta3 * delta;
228 let n1n2 = n1 * n2;
229
230 let mean = (n1 * mean1 + n2 * mean2) / n;
232
233 let m2 = m21 + m22 + delta2 * n1n2 / n;
234
235 let m3 = m31
236 + m32
237 + delta3 * n1n2 * (n1 - n2) / (n * n)
238 + F::THREE * delta * (n1 * m22 - n2 * m21) / n;
239
240 let m4 = m41
241 + m42
242 + delta4 * n1n2 * (n1 * n1 - n1 * n2 + n2 * n2) / (n * n * n)
243 + F::SIX * delta2 * (n1 * n1 * m22 + n2 * n2 * m21) / (n * n)
244 + F::FOUR * delta * (n1 * m32 - n2 * m31) / n;
245
246 self.m1 = Adder::default();
249 self.m1.add(mean);
250
251 self.m2 = Adder::default();
252 self.m2.add(m2);
253
254 self.m3 = Adder::default();
255 self.m3.add(m3);
256
257 self.m4 = Adder::default();
258 self.m4.add(m4);
259
260 self.sum.add(other.sum()); self.count += other.count;
263 self.max = self.max.max(other.max);
264 self.min = self.min.min(other.min);
265
266 self.last_value = other.last_value;
268 }
269
270 pub fn merged(mut self, other: &Statistic<F>) -> Statistic<F> {
272 self.merge(other);
273 self
274 }
275}
276
277impl<T: Primitive, F: Float> FromIterator<T> for Statistic<F> {
278 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
279 let mut statistic = Statistic::<F>::default();
280 for item in iter {
281 if let Some(value) = item.extract::<F>() {
282 statistic.add(value);
283 }
284 }
285 statistic
286 }
287}
288
289impl From<f32> for Statistic {
290 fn from(value: f32) -> Self {
291 Statistic::new(value)
292 }
293}
294
295impl From<i32> for Statistic {
296 fn from(value: i32) -> Self {
297 Statistic::new(value as f32)
298 }
299}
300
301impl From<usize> for Statistic {
302 fn from(value: usize) -> Self {
303 Statistic::new(value as f32)
304 }
305}
306
307impl<F: Float> Default for Statistic<F> {
308 fn default() -> Self {
309 Statistic {
310 m1: Adder::default(),
311 m2: Adder::default(),
312 m3: Adder::default(),
313 m4: Adder::default(),
314 sum: Adder::default(),
315 count: 0,
316 last_value: F::ZERO,
317 max: F::MIN,
318 min: F::MAX,
319 }
320 }
321}
322
323impl<F: Float> Hash for Statistic<F> {
324 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
325 self.count.hash(state);
326 self.last_value.num_hash(state);
327 self.max.num_hash(state);
328 self.min.num_hash(state);
329 self.sum.value().num_hash(state);
330 self.m1.value().num_hash(state);
331 self.m2.value().num_hash(state);
332 self.m3.value().num_hash(state);
333 self.m4.value().num_hash(state);
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adder() {
        let mut adder = Adder::default();
        for value in [1_f32, 2_f32, 3_f32, 4_f32, 5_f32] {
            adder.add(value);
        }

        assert_eq!(adder.value(), 15_f32);
    }

    #[test]
    fn test_statistic() {
        let mut statistic = Statistic::<f32>::default();
        for value in 1..=5 {
            statistic.add(value as f32);
        }

        assert_eq!(statistic.mean(), 3_f32);
        assert_eq!(statistic.variance().unwrap(), 2.5_f32);
        assert_eq!(statistic.std_dev().unwrap(), 1.5811388_f32);
        assert_eq!(statistic.skewness().unwrap(), 0_f32);
    }

    #[test]
    fn test_statistic_merge() {
        // Builds a statistic from three values, added in order.
        let build = |values: [f32; 3]| {
            let mut stat = Statistic::<f32>::default();
            for value in values {
                stat.add(value);
            }
            stat
        };

        let stat1 = build([1_f32, 2_f32, 3_f32]);
        let stat2 = build([4_f32, 5_f32, 6_f32]);

        let merged = stat1.merged(&stat2);
        assert_eq!(merged.mean(), 3.5_f32);
        assert_eq!(merged.variance().unwrap(), 3.5_f32);
        assert_eq!(merged.std_dev().unwrap(), 1.8708287_f32);
        assert_eq!(merged.skewness().unwrap(), 0_f32);
        assert_eq!(merged.count(), 6);
        assert_eq!(merged.min(), 1_f32);
        assert_eq!(merged.max(), 6_f32);
        assert_eq!(merged.sum(), 21_f32);
    }
}
390}