1use std::fmt;
2
3use num_traits::ToPrimitive;
4use serde::{Deserialize, Serialize};
5
6use crate::Commute;
7
8#[inline]
10pub fn stddev<I, T>(x: I) -> f64
11where
12 I: IntoIterator<Item = T>,
13 T: ToPrimitive,
14{
15 x.into_iter().collect::<OnlineStats>().stddev()
16}
17
18#[inline]
20pub fn variance<I, T>(x: I) -> f64
21where
22 I: IntoIterator<Item = T>,
23 T: ToPrimitive,
24{
25 x.into_iter().collect::<OnlineStats>().variance()
26}
27
28#[inline]
30pub fn mean<I, T>(x: I) -> f64
31where
32 I: IntoIterator<Item = T>,
33 T: ToPrimitive,
34{
35 x.into_iter().collect::<OnlineStats>().mean()
36}
37
/// Streaming accumulator for count, mean, variance, harmonic/geometric mean
/// and sign counts, updated one sample at a time and mergeable with other
/// accumulators.
///
/// NOTE(review): keep the field declaration order stable. serde derives
/// (de)serialize struct fields in declaration order, so reordering could break
/// compatibility with previously serialized data in positional formats.
// Silences `clippy::unsafe_derive_deserialize`, which fires when a type that
// derives `Deserialize` also has methods containing `unsafe` code.
// NOTE(review): remove this allow if the `impl` no longer contains `unsafe`.
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Clone, Copy, Serialize, Deserialize, PartialEq)]
pub struct OnlineStats {
    /// Number of recorded samples (NaN samples are skipped and not counted).
    size: u64,
    /// Running arithmetic mean of the recorded samples.
    mean: f64,
    /// Running sum of squared deviations from the mean; `variance()` divides
    /// it by `size`.
    q: f64,
    /// `true` while every recorded sample has been strictly positive; cleared
    /// by the first zero or negative sample, after which `harmonic_sum` and
    /// `geometric_sum` stop accumulating.
    hg_sums: bool,
    /// Sum of `1 / x` over positive samples recorded while `hg_sums` was set.
    harmonic_sum: f64,
    /// Sum of `ln(x)` over positive samples recorded while `hg_sums` was set.
    geometric_sum: f64,
    /// Count of samples `> 0.0`.
    n_positive: u64,
    /// Count of `+0.0` samples.
    n_zero: u64,
    /// Count of samples with the sign bit set, including `-0.0`.
    n_negative: u64,
}
60
61impl OnlineStats {
62 #[must_use]
66 pub fn new() -> OnlineStats {
67 Default::default()
68 }
69
70 #[must_use]
72 pub fn from_slice<T: ToPrimitive>(samples: &[T]) -> OnlineStats {
73 samples
75 .iter()
76 .map(|n| unsafe { n.to_f64().unwrap_unchecked() })
77 .collect()
78 }
79
80 #[must_use]
82 pub const fn mean(&self) -> f64 {
83 if self.is_empty() { f64::NAN } else { self.mean }
84 }
85
86 #[must_use]
88 pub fn stddev(&self) -> f64 {
89 self.variance().sqrt()
90 }
91
92 #[must_use]
95 pub const fn variance(&self) -> f64 {
96 if self.is_empty() { f64::NAN } else { self.q / (self.size as f64) }
97 }
98
99 #[must_use]
101 pub fn harmonic_mean(&self) -> f64 {
102 if self.is_empty() || self.n_zero > 0 || self.n_negative > 0 {
103 f64::NAN
104 } else {
105 (self.size as f64) / self.harmonic_sum
106 }
107 }
108
109 #[must_use]
111 pub fn geometric_mean(&self) -> f64 {
112 if self.is_empty()
113 || self.n_negative > 0
114 || self.geometric_sum.is_nan()
115 || self.geometric_sum == f64::INFINITY
116 {
117 f64::NAN
118 } else if self.n_zero > 0 || self.geometric_sum == f64::NEG_INFINITY {
119 0.0
122 } else {
123 (self.geometric_sum / (self.size as f64)).exp()
124 }
125 }
126
127 #[must_use]
150 pub const fn n_counts(&self) -> (u64, u64, u64) {
151 (self.n_negative, self.n_zero, self.n_positive)
152 }
153
154 #[inline]
161 pub fn add<T: ToPrimitive>(&mut self, sample: &T) {
162 let sample = unsafe { sample.to_f64().unwrap_unchecked() };
164
165 if sample.is_nan() {
166 return;
167 }
168
169 self.size += 1;
172 let delta = sample - self.mean;
173
174 self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
176
177 self.q = delta.mul_add(sample - self.mean, self.q);
179
180 if sample > 0.0 {
182 if self.hg_sums {
183 self.harmonic_sum += 1.0 / sample;
185 self.geometric_sum += sample.ln();
186 }
187 self.n_positive += 1;
188 } else {
189 if sample.is_sign_negative() {
191 self.n_negative += 1;
192 } else {
193 self.n_zero += 1;
194 }
195 self.hg_sums = false;
196 }
197 }
198
199 #[inline]
204 pub fn add_f64(&mut self, sample: f64) {
205 if sample.is_nan() {
206 return;
207 }
208
209 self.size += 1;
210 let delta = sample - self.mean;
211
212 self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
213 self.q = delta.mul_add(sample - self.mean, self.q);
214
215 if sample > 0.0 {
217 if self.hg_sums {
218 self.harmonic_sum += 1.0 / sample;
219 self.geometric_sum += sample.ln();
220 }
221 self.n_positive += 1;
222 } else {
223 if sample.is_sign_negative() {
225 self.n_negative += 1;
226 } else {
227 self.n_zero += 1;
228 }
229 self.hg_sums = false;
230 }
231 }
232
233 #[inline]
237 pub fn add_null(&mut self) {
238 self.add_f64(0.0);
239 }
240
241 #[inline]
243 #[must_use]
244 pub const fn len(&self) -> usize {
245 self.size as usize
246 }
247
248 #[inline]
250 #[must_use]
251 pub const fn is_empty(&self) -> bool {
252 self.size == 0
253 }
254}
255
256impl Commute for OnlineStats {
257 #[inline]
258 fn merge(&mut self, v: OnlineStats) {
259 if v.is_empty() {
260 return;
261 }
262
263 let (s1, s2) = (self.size as f64, v.size as f64);
265 let total = s1 + s2;
266 let meandiffsq = (self.mean - v.mean).powi(2);
267
268 self.size += v.size;
269
270 self.mean = (v.mean - self.mean).mul_add(s2 / total, self.mean);
273
274 self.q += meandiffsq.mul_add(s1 * s2 / total, v.q);
277
278 self.hg_sums = self.hg_sums && v.hg_sums;
279 self.harmonic_sum += v.harmonic_sum;
280 self.geometric_sum += v.geometric_sum;
281
282 self.n_zero += v.n_zero;
283 self.n_negative += v.n_negative;
284 self.n_positive += v.n_positive;
285 }
286}
287
288impl Default for OnlineStats {
289 fn default() -> OnlineStats {
290 OnlineStats {
291 size: 0,
292 mean: 0.0,
293 q: 0.0,
294 harmonic_sum: 0.0,
295 geometric_sum: 0.0,
296 n_zero: 0,
297 n_negative: 0,
298 n_positive: 0,
299 hg_sums: true,
300 }
301 }
302}
303
304impl fmt::Debug for OnlineStats {
305 #[inline]
306 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
307 write!(f, "{:.10} +/- {:.10}", self.mean(), self.stddev())
308 }
309}
310
311impl<T: ToPrimitive> FromIterator<T> for OnlineStats {
312 #[inline]
313 fn from_iter<I: IntoIterator<Item = T>>(it: I) -> OnlineStats {
314 let mut v = OnlineStats::new();
315 v.extend(it);
316 v
317 }
318}
319
320impl<T: ToPrimitive> Extend<T> for OnlineStats {
321 #[inline]
322 fn extend<I: IntoIterator<Item = T>>(&mut self, it: I) {
323 for sample in it {
324 self.add(&sample);
325 }
326 }
327}
328
#[cfg(test)]
mod test {
    // Unit tests for `OnlineStats` and the free `mean`/`variance`/`stddev`
    // wrappers. Several tests compare floats with `assert_eq!`, which requires
    // merged and sequentially-built accumulators to agree exactly.
    use super::{OnlineStats, mean, stddev, variance};
    use crate::{Commute, merge_all};

    #[test]
    fn online() {
        let expected = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6]);

        let var1 = OnlineStats::from_slice(&[1usize, 2, 3]);
        let var2 = OnlineStats::from_slice(&[2usize, 4, 6]);
        let mut got = var1;
        got.merge(var2);
        assert_eq!(expected.stddev(), got.stddev());
        assert_eq!(expected.mean(), got.mean());
        assert_eq!(expected.variance(), got.variance());
    }

    #[test]
    fn online_empty() {
        let expected = OnlineStats::new();
        assert!(expected.is_empty());
    }

    #[test]
    fn online_many() {
        let expected = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6, 3, 6, 9]);

        let vars = vec![
            OnlineStats::from_slice(&[1usize, 2, 3]),
            OnlineStats::from_slice(&[2usize, 4, 6]),
            OnlineStats::from_slice(&[3usize, 6, 9]),
        ];
        // `OnlineStats` is `Copy`, so iterate by copy instead of cloning the Vec.
        assert_eq!(
            expected.stddev(),
            merge_all(vars.iter().copied()).unwrap().stddev()
        );
        assert_eq!(
            expected.mean(),
            merge_all(vars.iter().copied()).unwrap().mean()
        );
        assert_eq!(
            expected.variance(),
            merge_all(vars.into_iter()).unwrap().variance()
        );
    }

    #[test]
    fn online_merge_into_empty() {
        let src = OnlineStats::from_slice(&[1.0f64, 2.0, 4.0]);
        let mut acc = OnlineStats::new();
        acc.merge(src);
        assert_eq!(acc, src);
    }

    #[test]
    fn nan_samples_are_ignored() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![1.0f64, f64::NAN, 3.0]);

        assert_eq!(stats.len(), 2);
        assert!((stats.mean() - 2.0).abs() < 1e-12);
        assert_eq!(stats.n_counts(), (0, 0, 2));
    }

    #[test]
    fn add_null_counts_as_zero() {
        let mut stats = OnlineStats::new();
        stats.add_null();

        assert_eq!(stats.len(), 1);
        assert_eq!(stats.n_counts(), (0, 1, 0));
        assert!(stats.mean().abs() < 1e-12);
    }

    #[test]
    fn from_slice_matches_extend() {
        let from_slice = OnlineStats::from_slice(&[1i32, 5, 2]);
        let mut extended = OnlineStats::new();
        extended.extend(vec![1i32, 5, 2]);
        assert_eq!(from_slice, extended);
    }

    #[test]
    fn test_means() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![2.0f64, 4.0, 8.0]);

        assert!((stats.mean() - 4.666666666667).abs() < 1e-10);

        assert_eq!("3.42857143", format!("{:.8}", stats.harmonic_mean()));

        assert!((stats.geometric_mean() - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_means_with_negative() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-2.0f64, 2.0]);

        assert!(stats.mean().abs() < 1e-10);

        assert!(stats.geometric_mean().is_nan());

        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_with_zero() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![0.0f64, 4.0, 8.0]);

        assert!((stats.mean() - 4.0).abs() < 1e-10);

        assert!(stats.geometric_mean().abs() < 1e-10);

        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_with_zero_and_negative_values() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-10i32, -5, 0, 5, 10]);

        assert!(stats.mean().abs() < 1e-10);

        assert!(stats.geometric_mean().is_nan());

        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_single_value() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![5.0f64]);

        assert!((stats.mean() - 5.0).abs() < 1e-10);
        assert!((stats.geometric_mean() - 5.0).abs() < 1e-10);
        assert!((stats.harmonic_mean() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_means_empty() {
        let stats = OnlineStats::new();

        assert!(stats.mean().is_nan());
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_mean_wrapper_basic() {
        let result = mean(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]);
        assert!((result - 3.0).abs() < 1e-10);

        let result = mean(vec![1i32, 2, 3, 4, 5]);
        assert!((result - 3.0).abs() < 1e-10);

        let result = mean(vec![10u32, 20, 30]);
        assert!((result - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean_wrapper_empty() {
        let result = mean(Vec::<f64>::new());
        assert!(result.is_nan());
    }

    #[test]
    fn test_mean_wrapper_single_element() {
        assert!((mean(vec![42.0f64]) - 42.0).abs() < 1e-10);
        assert!((mean(vec![100i32]) - 100.0).abs() < 1e-10);
        assert!((mean(vec![0u8]) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean_wrapper_negative_values() {
        let result = mean(vec![-5.0f64, 5.0]);
        assert!(result.abs() < 1e-10);

        let result = mean(vec![-10i32, -20, -30]);
        assert!((result - (-20.0)).abs() < 1e-10);
    }

    #[test]
    fn test_mean_wrapper_various_numeric_types() {
        assert!((mean(vec![1u8, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1u16, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1u64, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1i8, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1i16, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1i64, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1.0f32, 2.0, 3.0]) - 2.0).abs() < 1e-6);
        assert!((mean(vec![1usize, 2, 3]) - 2.0).abs() < 1e-10);
        assert!((mean(vec![1isize, 2, 3]) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_wrapper_basic() {
        let result = variance(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]);
        assert!((result - 2.0).abs() < 1e-10);

        let result = variance(vec![1i32, 2, 3, 4, 5]);
        assert!((result - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_wrapper_empty() {
        let result = variance(Vec::<f64>::new());
        assert!(result.is_nan());
    }

    #[test]
    fn test_variance_wrapper_single_element() {
        assert!(variance(vec![42.0f64]).abs() < 1e-10);
        assert!(variance(vec![100i32]).abs() < 1e-10);
    }

    #[test]
    fn test_variance_wrapper_identical_values() {
        let result = variance(vec![5.0f64, 5.0, 5.0, 5.0]);
        assert!(result.abs() < 1e-10);
    }

    #[test]
    fn test_variance_wrapper_various_numeric_types() {
        let expected = 2.0 / 3.0;
        assert!((variance(vec![1u8, 2, 3]) - expected).abs() < 1e-10);
        assert!((variance(vec![1u16, 2, 3]) - expected).abs() < 1e-10);
        assert!((variance(vec![1i32, 2, 3]) - expected).abs() < 1e-10);
        assert!((variance(vec![1i64, 2, 3]) - expected).abs() < 1e-10);
        assert!((variance(vec![1usize, 2, 3]) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_stddev_wrapper_basic() {
        let result = stddev(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]);
        assert!((result - 2.0f64.sqrt()).abs() < 1e-10);

        let result = stddev(vec![1i32, 2, 3, 4, 5]);
        assert!((result - 2.0f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_stddev_wrapper_empty() {
        let result = stddev(Vec::<f64>::new());
        assert!(result.is_nan());
    }

    #[test]
    fn test_stddev_wrapper_single_element() {
        assert!(stddev(vec![42.0f64]).abs() < 1e-10);
        assert!(stddev(vec![100i32]).abs() < 1e-10);
    }

    #[test]
    fn test_stddev_wrapper_identical_values() {
        let result = stddev(vec![5.0f64, 5.0, 5.0, 5.0]);
        assert!(result.abs() < 1e-10);
    }

    #[test]
    fn test_stddev_wrapper_various_numeric_types() {
        let expected = (2.0f64 / 3.0).sqrt();
        assert!((stddev(vec![1u8, 2, 3]) - expected).abs() < 1e-10);
        assert!((stddev(vec![1u16, 2, 3]) - expected).abs() < 1e-10);
        assert!((stddev(vec![1i32, 2, 3]) - expected).abs() < 1e-10);
        assert!((stddev(vec![1i64, 2, 3]) - expected).abs() < 1e-10);
        assert!((stddev(vec![1usize, 2, 3]) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_wrapper_functions_consistency() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = OnlineStats::from_slice(&data);

        assert!((mean(data.clone()) - stats.mean()).abs() < 1e-10);
        assert!((variance(data.clone()) - stats.variance()).abs() < 1e-10);
        assert!((stddev(data) - stats.stddev()).abs() < 1e-10);
    }

    #[test]
    fn test_wrapper_functions_with_iterators() {
        let arr = [1, 2, 3, 4, 5];

        assert!((mean(arr) - 3.0).abs() < 1e-10);

        assert!((mean(1..=5) - 3.0).abs() < 1e-10);

        let result = mean((1..=5).map(|x| x * 2));
        assert!((result - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_n_counts_basic() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-5, -3, 0, 0, 2, 4, 6]);

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values");
        assert_eq!(zero, 2, "Should have 2 zero values");
        assert_eq!(pos, 3, "Should have 3 positive values");
    }

    #[test]
    fn test_n_counts_all_positive() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![1.0, 2.0, 3.0, 4.0]);

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 0);
        assert_eq!(zero, 0);
        assert_eq!(pos, 4);
    }

    #[test]
    fn test_n_counts_all_negative() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-1.0, -2.0, -3.0]);

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 3);
        assert_eq!(zero, 0);
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_all_zeros() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![0.0, 0.0, 0.0]);

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 0);
        assert_eq!(zero, 3);
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_with_merge() {
        let mut stats1 = OnlineStats::new();
        stats1.extend(vec![-2, 0, 3]);

        let mut stats2 = OnlineStats::new();
        stats2.extend(vec![-1, 5, 7]);

        stats1.merge(stats2);

        let (neg, zero, pos) = stats1.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values after merge");
        assert_eq!(zero, 1, "Should have 1 zero value after merge");
        assert_eq!(pos, 3, "Should have 3 positive values after merge");
    }

    #[test]
    fn test_n_counts_empty() {
        let stats = OnlineStats::new();

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 0);
        assert_eq!(zero, 0);
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_negative_zero() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-0.0f64, 0.0]);

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 1, "-0.0 has negative sign bit");
        assert_eq!(zero, 1, "+0.0 is zero");
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_floats_boundary() {
        let mut stats = OnlineStats::new();
        stats.extend(vec![-0.0001f64, 0.0, 0.0001]);

        let (neg, zero, pos) = stats.n_counts();
        assert_eq!(neg, 1);
        assert_eq!(zero, 1);
        assert_eq!(pos, 1);
    }
}