1use std::fmt;
2
3use num_traits::ToPrimitive;
4use serde::{Deserialize, Serialize};
5
6use crate::Commute;
7
8#[inline]
10pub fn stddev<I, T>(x: I) -> f64
11where
12 I: IntoIterator<Item = T>,
13 T: ToPrimitive,
14{
15 x.into_iter().collect::<OnlineStats>().stddev()
16}
17
18#[inline]
20pub fn variance<I, T>(x: I) -> f64
21where
22 I: IntoIterator<Item = T>,
23 T: ToPrimitive,
24{
25 x.into_iter().collect::<OnlineStats>().variance()
26}
27
28#[inline]
30pub fn mean<I, T>(x: I) -> f64
31where
32 I: IntoIterator<Item = T>,
33 T: ToPrimitive,
34{
35 x.into_iter().collect::<OnlineStats>().mean()
36}
37
38#[allow(clippy::unsafe_derive_deserialize)]
43#[derive(Clone, Copy, Serialize, Deserialize, PartialEq)]
44pub struct OnlineStats {
45 size: u64, mean: f64, q: f64, hg_sums: bool, harmonic_sum: f64, geometric_sum: f64, n_positive: u64, n_zero: u64, n_negative: u64, }
60
61impl OnlineStats {
62 #[must_use]
66 pub fn new() -> OnlineStats {
67 Default::default()
68 }
69
70 #[must_use]
72 pub fn from_slice<T: ToPrimitive>(samples: &[T]) -> OnlineStats {
73 samples
75 .iter()
76 .map(|n| unsafe { n.to_f64().unwrap_unchecked() })
77 .collect()
78 }
79
80 #[must_use]
82 pub const fn mean(&self) -> f64 {
83 if self.is_empty() { f64::NAN } else { self.mean }
84 }
85
86 #[must_use]
88 pub fn stddev(&self) -> f64 {
89 self.variance().sqrt()
90 }
91
92 #[must_use]
95 pub const fn variance(&self) -> f64 {
96 self.q / (self.size as f64)
97 }
98
99 #[must_use]
101 pub fn harmonic_mean(&self) -> f64 {
102 if self.is_empty() || self.n_zero > 0 || self.n_negative > 0 {
103 f64::NAN
104 } else {
105 (self.size as f64) / self.harmonic_sum
106 }
107 }
108
109 #[must_use]
111 pub fn geometric_mean(&self) -> f64 {
112 if self.is_empty()
113 || self.n_negative > 0
114 || self.geometric_sum.is_infinite()
115 || self.geometric_sum.is_nan()
116 {
117 f64::NAN
118 } else if self.n_zero > 0 {
119 0.0
120 } else {
121 (self.geometric_sum / (self.size as f64)).exp()
122 }
123 }
124
125 #[must_use]
148 pub const fn n_counts(&self) -> (u64, u64, u64) {
149 (self.n_negative, self.n_zero, self.n_positive)
150 }
151
152 #[inline]
157 pub fn add<T: ToPrimitive>(&mut self, sample: &T) {
158 let sample = unsafe { sample.to_f64().unwrap_unchecked() };
160
161 self.size += 1;
164 let delta = sample - self.mean;
165
166 self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
168
169 self.q = delta.mul_add(sample - self.mean, self.q);
171
172 if sample > 0.0 {
174 if self.hg_sums {
175 self.harmonic_sum += 1.0 / sample;
177 self.geometric_sum += sample.ln();
178 }
179 self.n_positive += 1;
180 } else {
181 if sample.is_sign_negative() {
183 self.n_negative += 1;
184 } else {
185 self.n_zero += 1;
186 }
187 self.hg_sums = false;
188 }
189 }
190
191 #[inline]
194 pub fn add_f64(&mut self, sample: f64) {
195 self.size += 1;
196 let delta = sample - self.mean;
197
198 self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
199 self.q = delta.mul_add(sample - self.mean, self.q);
200
201 if sample > 0.0 {
203 if self.hg_sums {
204 self.harmonic_sum += 1.0 / sample;
205 self.geometric_sum += sample.ln();
206 }
207 self.n_positive += 1;
208 } else {
209 if sample.is_sign_negative() {
211 self.n_negative += 1;
212 } else {
213 self.n_zero += 1;
214 }
215 self.hg_sums = false;
216 }
217 }
218
219 #[inline]
222 pub fn add_null(&mut self) {
223 self.add_f64(0.0);
224 }
225
226 #[inline]
228 #[must_use]
229 pub const fn len(&self) -> usize {
230 self.size as usize
231 }
232
233 #[inline]
235 #[must_use]
236 pub const fn is_empty(&self) -> bool {
237 self.size == 0
238 }
239}
240
241impl Commute for OnlineStats {
242 #[inline]
243 fn merge(&mut self, v: OnlineStats) {
244 if v.is_empty() {
245 return;
246 }
247
248 let (s1, s2) = (self.size as f64, v.size as f64);
250 let total = s1 + s2;
251 let meandiffsq = (self.mean - v.mean).powi(2);
252
253 self.size += v.size;
254
255 self.mean = s1.mul_add(self.mean, s2 * v.mean) / total;
259
260 self.q += v.q + f64::mul_add(meandiffsq, s1 * s2 / total, 0.0);
263
264 self.harmonic_sum += v.harmonic_sum;
265 self.geometric_sum += v.geometric_sum;
266
267 self.n_zero += v.n_zero;
268 self.n_negative += v.n_negative;
269 self.n_positive += v.n_positive;
270 }
271}
272
273impl Default for OnlineStats {
274 fn default() -> OnlineStats {
275 OnlineStats {
276 size: 0,
277 mean: 0.0,
278 q: 0.0,
279 harmonic_sum: 0.0,
280 geometric_sum: 0.0,
281 n_zero: 0,
282 n_negative: 0,
283 n_positive: 0,
284 hg_sums: true,
285 }
286 }
287}
288
289impl fmt::Debug for OnlineStats {
290 #[inline]
291 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292 write!(f, "{:.10} +/- {:.10}", self.mean(), self.stddev())
293 }
294}
295
296impl<T: ToPrimitive> FromIterator<T> for OnlineStats {
297 #[inline]
298 fn from_iter<I: IntoIterator<Item = T>>(it: I) -> OnlineStats {
299 let mut v = OnlineStats::new();
300 v.extend(it);
301 v
302 }
303}
304
305impl<T: ToPrimitive> Extend<T> for OnlineStats {
306 #[inline]
307 fn extend<I: IntoIterator<Item = T>>(&mut self, it: I) {
308 for sample in it {
309 self.add(&sample);
310 }
311 }
312}
313
#[cfg(test)]
mod test {
    use super::{OnlineStats, mean, stddev, variance};
    use {crate::Commute, crate::merge_all};

    /// Absolute tolerance used by the `f64` comparisons in this module.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `actual` is strictly within `TOL` of `expected`.
    fn near(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    // ---- merging -------------------------------------------------------

    /// Merging two halves must match accumulating the concatenation.
    #[test]
    fn online() {
        let expected = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6]);

        let mut got = OnlineStats::from_slice(&[1usize, 2, 3]);
        got.merge(OnlineStats::from_slice(&[2usize, 4, 6]));

        assert_eq!(expected.stddev(), got.stddev());
        assert_eq!(expected.mean(), got.mean());
        assert_eq!(expected.variance(), got.variance());
    }

    #[test]
    fn online_empty() {
        let expected = OnlineStats::new();
        assert!(expected.is_empty());
    }

    /// Merging three parts must match accumulating the concatenation.
    #[test]
    fn online_many() {
        let expected = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6, 3, 6, 9]);

        let vars = vec![
            OnlineStats::from_slice(&[1usize, 2, 3]),
            OnlineStats::from_slice(&[2usize, 4, 6]),
            OnlineStats::from_slice(&[3usize, 6, 9]),
        ];
        let merged = merge_all(vars.into_iter()).unwrap();

        assert_eq!(expected.stddev(), merged.stddev());
        assert_eq!(expected.mean(), merged.mean());
        assert_eq!(expected.variance(), merged.variance());
    }

    // ---- arithmetic / harmonic / geometric means --------------------------

    #[test]
    fn test_means() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![2.0f64, 4.0, 8.0]);

        assert!(near(stats.mean(), 4.666666666667));
        assert_eq!("3.42857143", format!("{:.8}", stats.harmonic_mean()));
        assert!(near(stats.geometric_mean(), 4.0));
    }

    #[test]
    fn test_means_with_negative() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-2.0f64, 2.0]);

        assert!(near(stats.mean(), 0.0));
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_with_zero() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![0.0f64, 4.0, 8.0]);

        assert!(near(stats.mean(), 4.0));
        // A zero sample forces the geometric mean to zero.
        assert!(near(stats.geometric_mean(), 0.0));
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_with_zero_and_negative_values() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-10i32, -5, 0, 5, 10]);

        assert!(near(stats.mean(), 0.0));
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_single_value() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![5.0f64]);

        assert!(near(stats.mean(), 5.0));
        assert!(near(stats.geometric_mean(), 5.0));
        assert!(near(stats.harmonic_mean(), 5.0));
    }

    #[test]
    fn test_means_empty() {
        let stats = OnlineStats::new();

        assert!(stats.mean().is_nan());
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    // ---- `mean` wrapper ---------------------------------------------------

    #[test]
    fn test_mean_wrapper_basic() {
        assert!(near(mean(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]), 3.0));
        assert!(near(mean(vec![1i32, 2, 3, 4, 5]), 3.0));
        assert!(near(mean(vec![10u32, 20, 30]), 20.0));
    }

    #[test]
    fn test_mean_wrapper_empty() {
        let result = mean(Vec::<f64>::new());
        assert!(result.is_nan());
    }

    #[test]
    fn test_mean_wrapper_single_element() {
        assert!(near(mean(vec![42.0f64]), 42.0));
        assert!(near(mean(vec![100i32]), 100.0));
        assert!(near(mean(vec![0u8]), 0.0));
    }

    #[test]
    fn test_mean_wrapper_negative_values() {
        let balanced = mean(vec![-5.0f64, 5.0]);
        assert!(near(balanced, 0.0));

        let all_negative = mean(vec![-10i32, -20, -30]);
        assert!(near(all_negative, -20.0));
    }

    #[test]
    fn test_mean_wrapper_various_numeric_types() {
        assert!(near(mean(vec![1u8, 2, 3]), 2.0));
        assert!(near(mean(vec![1u16, 2, 3]), 2.0));
        assert!(near(mean(vec![1u64, 2, 3]), 2.0));
        assert!(near(mean(vec![1i8, 2, 3]), 2.0));
        assert!(near(mean(vec![1i16, 2, 3]), 2.0));
        assert!(near(mean(vec![1i64, 2, 3]), 2.0));
        // `f32` input is checked against a looser tolerance.
        assert!((mean(vec![1.0f32, 2.0, 3.0]) - 2.0).abs() < 1e-6);
        assert!(near(mean(vec![1usize, 2, 3]), 2.0));
        assert!(near(mean(vec![1isize, 2, 3]), 2.0));
    }

    // ---- `variance` wrapper -----------------------------------------------

    #[test]
    fn test_variance_wrapper_basic() {
        assert!(near(variance(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]), 2.0));
        assert!(near(variance(vec![1i32, 2, 3, 4, 5]), 2.0));
    }

    #[test]
    fn test_variance_wrapper_empty() {
        let result = variance(Vec::<f64>::new());
        assert!(result.is_nan());
    }

    #[test]
    fn test_variance_wrapper_single_element() {
        assert!(near(variance(vec![42.0f64]), 0.0));
        assert!(near(variance(vec![100i32]), 0.0));
    }

    #[test]
    fn test_variance_wrapper_identical_values() {
        assert!(near(variance(vec![5.0f64, 5.0, 5.0, 5.0]), 0.0));
    }

    #[test]
    fn test_variance_wrapper_various_numeric_types() {
        let expected = 2.0 / 3.0;

        assert!(near(variance(vec![1u8, 2, 3]), expected));
        assert!(near(variance(vec![1u16, 2, 3]), expected));
        assert!(near(variance(vec![1i32, 2, 3]), expected));
        assert!(near(variance(vec![1i64, 2, 3]), expected));
        assert!(near(variance(vec![1usize, 2, 3]), expected));
    }

    // ---- `stddev` wrapper -------------------------------------------------

    #[test]
    fn test_stddev_wrapper_basic() {
        let expected = 2.0f64.sqrt();

        assert!(near(stddev(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]), expected));
        assert!(near(stddev(vec![1i32, 2, 3, 4, 5]), expected));
    }

    #[test]
    fn test_stddev_wrapper_empty() {
        let result = stddev(Vec::<f64>::new());
        assert!(result.is_nan());
    }

    #[test]
    fn test_stddev_wrapper_single_element() {
        assert!(near(stddev(vec![42.0f64]), 0.0));
        assert!(near(stddev(vec![100i32]), 0.0));
    }

    #[test]
    fn test_stddev_wrapper_identical_values() {
        assert!(near(stddev(vec![5.0f64, 5.0, 5.0, 5.0]), 0.0));
    }

    #[test]
    fn test_stddev_wrapper_various_numeric_types() {
        let expected = (2.0f64 / 3.0).sqrt();

        assert!(near(stddev(vec![1u8, 2, 3]), expected));
        assert!(near(stddev(vec![1u16, 2, 3]), expected));
        assert!(near(stddev(vec![1i32, 2, 3]), expected));
        assert!(near(stddev(vec![1i64, 2, 3]), expected));
        assert!(near(stddev(vec![1usize, 2, 3]), expected));
    }

    // ---- wrappers vs. the accumulator ---------------------------------------

    #[test]
    fn test_wrapper_functions_consistency() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = OnlineStats::from_slice(&data);

        assert!(near(mean(data.clone()), stats.mean()));
        assert!(near(variance(data.clone()), stats.variance()));
        assert!(near(stddev(data), stats.stddev()));
    }

    #[test]
    fn test_wrapper_functions_with_iterators() {
        // An array passed by value.
        let arr = [1, 2, 3, 4, 5];
        assert!(near(mean(arr), 3.0));

        // A range.
        assert!(near(mean(1..=5), 3.0));

        // An iterator adaptor chain.
        let doubled = mean((1..=5).map(|x| x * 2));
        assert!(near(doubled, 6.0));
    }

    // ---- sign counts --------------------------------------------------------

    #[test]
    fn test_n_counts_basic() {
        let stats: OnlineStats = vec![-5, -3, 0, 0, 2, 4, 6].into_iter().collect();

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values");
        assert_eq!(zero, 2, "Should have 2 zero values");
        assert_eq!(pos, 3, "Should have 3 positive values");
    }

    #[test]
    fn test_n_counts_all_positive() {
        let stats: OnlineStats = vec![1.0, 2.0, 3.0, 4.0].into_iter().collect();

        assert_eq!(stats.n_counts(), (0, 0, 4));
    }

    #[test]
    fn test_n_counts_all_negative() {
        let stats: OnlineStats = vec![-1.0, -2.0, -3.0].into_iter().collect();

        assert_eq!(stats.n_counts(), (3, 0, 0));
    }

    #[test]
    fn test_n_counts_all_zeros() {
        let stats: OnlineStats = vec![0.0, 0.0, 0.0].into_iter().collect();

        assert_eq!(stats.n_counts(), (0, 3, 0));
    }

    #[test]
    fn test_n_counts_with_merge() {
        let mut left: OnlineStats = vec![-2, 0, 3].into_iter().collect();
        let right: OnlineStats = vec![-1, 5, 7].into_iter().collect();

        left.merge(right);

        let (neg, zero, pos) = left.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values after merge");
        assert_eq!(zero, 1, "Should have 1 zero value after merge");
        assert_eq!(pos, 3, "Should have 3 positive values after merge");
    }

    #[test]
    fn test_n_counts_empty() {
        let stats = OnlineStats::new();

        assert_eq!(stats.n_counts(), (0, 0, 0));
    }

    /// Non-positive samples are classified by sign bit, so `-0.0` and `+0.0`
    /// land in different buckets.
    #[test]
    fn test_n_counts_negative_zero() {
        let stats: OnlineStats = vec![-0.0f64, 0.0].into_iter().collect();

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 1, "-0.0 has negative sign bit");
        assert_eq!(zero, 1, "+0.0 is zero");
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_floats_boundary() {
        let stats: OnlineStats = vec![-0.0001f64, 0.0, 0.0001].into_iter().collect();

        assert_eq!(stats.n_counts(), (1, 1, 1));
    }
}