1use std::fmt;
2
3use num_traits::ToPrimitive;
4use serde::{Deserialize, Serialize};
5
6use crate::Commute;
7
8#[inline]
10pub fn stddev<I, T>(x: I) -> f64
11where
12 I: IntoIterator<Item = T>,
13 T: ToPrimitive,
14{
15 x.into_iter().collect::<OnlineStats>().stddev()
16}
17
18#[inline]
20pub fn variance<I, T>(x: I) -> f64
21where
22 I: IntoIterator<Item = T>,
23 T: ToPrimitive,
24{
25 x.into_iter().collect::<OnlineStats>().variance()
26}
27
28#[inline]
30pub fn mean<I, T>(x: I) -> f64
31where
32 I: IntoIterator<Item = T>,
33 T: ToPrimitive,
34{
35 x.into_iter().collect::<OnlineStats>().mean()
36}
37
38#[allow(clippy::unsafe_derive_deserialize)]
43#[derive(Clone, Copy, Serialize, Deserialize, PartialEq)]
44pub struct OnlineStats {
45 size: u64, mean: f64, q: f64, hg_sums: bool, harmonic_sum: f64, geometric_sum: f64, n_positive: u64, n_zero: u64, n_negative: u64, }
60
61impl OnlineStats {
62 #[must_use]
66 pub fn new() -> OnlineStats {
67 Default::default()
68 }
69
70 #[must_use]
72 pub fn from_slice<T: ToPrimitive>(samples: &[T]) -> OnlineStats {
73 samples
75 .iter()
76 .map(|n| unsafe { n.to_f64().unwrap_unchecked() })
77 .collect()
78 }
79
80 #[must_use]
82 pub const fn mean(&self) -> f64 {
83 if self.is_empty() { f64::NAN } else { self.mean }
84 }
85
86 #[must_use]
88 pub fn stddev(&self) -> f64 {
89 self.variance().sqrt()
90 }
91
92 #[must_use]
95 pub fn variance(&self) -> f64 {
96 self.q / (self.size as f64)
97 }
98
99 #[must_use]
101 pub fn harmonic_mean(&self) -> f64 {
102 if self.is_empty() || self.n_zero > 0 || self.n_negative > 0 {
103 f64::NAN
104 } else {
105 (self.size as f64) / self.harmonic_sum
106 }
107 }
108
109 #[must_use]
111 pub fn geometric_mean(&self) -> f64 {
112 if self.is_empty()
113 || self.n_negative > 0
114 || self.geometric_sum.is_infinite()
115 || self.geometric_sum.is_nan()
116 {
117 f64::NAN
118 } else if self.n_zero > 0 {
119 0.0
120 } else {
121 (self.geometric_sum / (self.size as f64)).exp()
122 }
123 }
124
125 #[must_use]
148 pub const fn n_counts(&self) -> (u64, u64, u64) {
149 (self.n_negative, self.n_zero, self.n_positive)
150 }
151
152 #[inline]
157 pub fn add<T: ToPrimitive>(&mut self, sample: &T) {
158 let sample = unsafe { sample.to_f64().unwrap_unchecked() };
160
161 self.size += 1;
164 let delta = sample - self.mean;
165
166 self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
168
169 self.q = delta.mul_add(sample - self.mean, self.q);
171
172 if sample > 0.0 {
174 if self.hg_sums {
175 self.harmonic_sum = (1.0 / sample).mul_add(1.0, self.harmonic_sum);
178 self.geometric_sum = sample.ln().mul_add(1.0, self.geometric_sum);
180 }
181 self.n_positive += 1;
182 } else {
183 if sample.is_sign_negative() {
185 self.n_negative += 1;
186 } else {
187 self.n_zero += 1;
188 }
189 self.hg_sums = self.n_negative == 0 && self.n_zero == 0;
190 }
191 }
192
193 #[inline]
196 pub fn add_f64(&mut self, sample: f64) {
197 self.size += 1;
198 let delta = sample - self.mean;
199
200 self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
201 self.q = delta.mul_add(sample - self.mean, self.q);
202
203 if sample > 0.0 {
205 if self.hg_sums {
206 self.harmonic_sum = (1.0 / sample).mul_add(1.0, self.harmonic_sum);
207 self.geometric_sum = sample.ln().mul_add(1.0, self.geometric_sum);
208 }
209 self.n_positive += 1;
210 } else {
211 if sample.is_sign_negative() {
213 self.n_negative += 1;
214 } else {
215 self.n_zero += 1;
216 }
217 self.hg_sums = self.n_negative == 0 && self.n_zero == 0;
218 }
219 }
220
221 #[inline]
224 pub fn add_null(&mut self) {
225 self.add_f64(0.0);
226 }
227
228 #[inline]
230 #[must_use]
231 pub const fn len(&self) -> usize {
232 self.size as usize
233 }
234
235 #[inline]
237 #[must_use]
238 pub const fn is_empty(&self) -> bool {
239 self.size == 0
240 }
241}
242
243impl Commute for OnlineStats {
244 #[inline]
245 fn merge(&mut self, v: OnlineStats) {
246 if v.is_empty() {
247 return;
248 }
249
250 let (s1, s2) = (self.size as f64, v.size as f64);
252 let total = s1 + s2;
253 let meandiffsq = (self.mean - v.mean).powi(2);
254
255 self.size += v.size;
256
257 self.mean = s1.mul_add(self.mean, s2 * v.mean) / total;
261
262 self.q += v.q + f64::mul_add(meandiffsq, s1 * s2 / total, 0.0);
265
266 self.harmonic_sum += v.harmonic_sum;
267 self.geometric_sum += v.geometric_sum;
268
269 self.n_zero += v.n_zero;
270 self.n_negative += v.n_negative;
271 self.n_positive += v.n_positive;
272 }
273}
274
275impl Default for OnlineStats {
276 fn default() -> OnlineStats {
277 OnlineStats {
278 size: 0,
279 mean: 0.0,
280 q: 0.0,
281 harmonic_sum: 0.0,
282 geometric_sum: 0.0,
283 n_zero: 0,
284 n_negative: 0,
285 n_positive: 0,
286 hg_sums: true,
287 }
288 }
289}
290
291impl fmt::Debug for OnlineStats {
292 #[inline]
293 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
294 write!(f, "{:.10} +/- {:.10}", self.mean(), self.stddev())
295 }
296}
297
298impl<T: ToPrimitive> FromIterator<T> for OnlineStats {
299 #[inline]
300 fn from_iter<I: IntoIterator<Item = T>>(it: I) -> OnlineStats {
301 let mut v = OnlineStats::new();
302 v.extend(it);
303 v
304 }
305}
306
307impl<T: ToPrimitive> Extend<T> for OnlineStats {
308 #[inline]
309 fn extend<I: IntoIterator<Item = T>>(&mut self, it: I) {
310 for sample in it {
311 self.add(&sample);
312 }
313 }
314}
315
#[cfg(test)]
mod test {
    use super::{mean, stddev, variance, OnlineStats};
    use crate::{merge_all, Commute};

    /// Absolute tolerance used by the approximate `f64` comparisons below.
    const EPS: f64 = 1e-10;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn near(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    // ---- merging ---------------------------------------------------------

    /// Merging two partial accumulators matches adding everything to one.
    #[test]
    fn online() {
        let sequential = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6]);

        let mut merged = OnlineStats::from_slice(&[1usize, 2, 3]);
        merged.merge(OnlineStats::from_slice(&[2usize, 4, 6]));

        assert_eq!(sequential.stddev(), merged.stddev());
        assert_eq!(sequential.mean(), merged.mean());
        assert_eq!(sequential.variance(), merged.variance());
    }

    /// A new accumulator reports itself as empty.
    #[test]
    fn online_empty() {
        let fresh = OnlineStats::new();
        assert!(fresh.is_empty());
    }

    /// `merge_all` over three partial accumulators matches the sequential one.
    #[test]
    fn online_many() {
        let sequential = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6, 3, 6, 9]);

        let parts = vec![
            OnlineStats::from_slice(&[1usize, 2, 3]),
            OnlineStats::from_slice(&[2usize, 4, 6]),
            OnlineStats::from_slice(&[3usize, 6, 9]),
        ];
        // The merge is deterministic, so one merged value serves all checks.
        let merged = merge_all(parts.into_iter()).unwrap();

        assert_eq!(sequential.stddev(), merged.stddev());
        assert_eq!(sequential.mean(), merged.mean());
        assert_eq!(sequential.variance(), merged.variance());
    }

    // ---- arithmetic / harmonic / geometric means --------------------------

    /// All three means for a small all-positive sample set.
    #[test]
    fn test_means() {
        let mut acc = OnlineStats::new();
        acc.extend([2.0f64, 4.0, 8.0]);

        assert!(near(acc.mean(), 4.666666666667));
        assert_eq!("3.42857143", format!("{:.8}", acc.harmonic_mean()));
        assert!(near(acc.geometric_mean(), 4.0));
    }

    /// A negative sample makes the geometric and harmonic means undefined.
    #[test]
    fn test_means_with_negative() {
        let mut acc = OnlineStats::new();
        acc.extend([-2.0f64, 2.0]);

        assert!(near(acc.mean(), 0.0));
        assert!(acc.geometric_mean().is_nan());
        assert!(acc.harmonic_mean().is_nan());
    }

    /// A zero sample forces the geometric mean to zero and the harmonic mean
    /// to `NaN`.
    #[test]
    fn test_means_with_zero() {
        let mut acc = OnlineStats::new();
        acc.extend([0.0f64, 4.0, 8.0]);

        assert!(near(acc.mean(), 4.0));
        assert!(near(acc.geometric_mean(), 0.0));
        assert!(acc.harmonic_mean().is_nan());
    }

    /// Mixed negative, zero and positive integers.
    #[test]
    fn test_means_with_zero_and_negative_values() {
        let mut acc = OnlineStats::new();
        acc.extend([-10i32, -5, 0, 5, 10]);

        assert!(near(acc.mean(), 0.0));
        assert!(acc.geometric_mean().is_nan());
        assert!(acc.harmonic_mean().is_nan());
    }

    /// With one sample every mean equals that sample.
    #[test]
    fn test_means_single_value() {
        let mut acc = OnlineStats::new();
        acc.extend([5.0f64]);

        assert!(near(acc.mean(), 5.0));
        assert!(near(acc.geometric_mean(), 5.0));
        assert!(near(acc.harmonic_mean(), 5.0));
    }

    /// Every mean of an empty accumulator is `NaN`.
    #[test]
    fn test_means_empty() {
        let acc = OnlineStats::new();

        assert!(acc.mean().is_nan());
        assert!(acc.geometric_mean().is_nan());
        assert!(acc.harmonic_mean().is_nan());
    }

    // ---- `mean` wrapper ----------------------------------------------------

    #[test]
    fn test_mean_wrapper_basic() {
        assert!(near(mean([1.0f64, 2.0, 3.0, 4.0, 5.0]), 3.0));
        assert!(near(mean([1i32, 2, 3, 4, 5]), 3.0));
        assert!(near(mean([10u32, 20, 30]), 20.0));
    }

    #[test]
    fn test_mean_wrapper_empty() {
        let no_samples = Vec::<f64>::new();
        assert!(mean(no_samples).is_nan());
    }

    #[test]
    fn test_mean_wrapper_single_element() {
        assert!(near(mean([42.0f64]), 42.0));
        assert!(near(mean([100i32]), 100.0));
        assert!(near(mean([0u8]), 0.0));
    }

    #[test]
    fn test_mean_wrapper_negative_values() {
        let symmetric = mean([-5.0f64, 5.0]);
        assert!(near(symmetric, 0.0));

        let all_negative = mean([-10i32, -20, -30]);
        assert!(near(all_negative, -20.0));
    }

    #[test]
    fn test_mean_wrapper_various_numeric_types() {
        assert!(near(mean([1u8, 2, 3]), 2.0));
        assert!(near(mean([1u16, 2, 3]), 2.0));
        assert!(near(mean([1u64, 2, 3]), 2.0));
        assert!(near(mean([1i8, 2, 3]), 2.0));
        assert!(near(mean([1i16, 2, 3]), 2.0));
        assert!(near(mean([1i64, 2, 3]), 2.0));
        // `f32` input is compared with a looser tolerance.
        assert!((mean([1.0f32, 2.0, 3.0]) - 2.0).abs() < 1e-6);
        assert!(near(mean([1usize, 2, 3]), 2.0));
        assert!(near(mean([1isize, 2, 3]), 2.0));
    }

    // ---- `variance` wrapper ------------------------------------------------

    #[test]
    fn test_variance_wrapper_basic() {
        assert!(near(variance([1.0f64, 2.0, 3.0, 4.0, 5.0]), 2.0));
        assert!(near(variance([1i32, 2, 3, 4, 5]), 2.0));
    }

    #[test]
    fn test_variance_wrapper_empty() {
        let no_samples = Vec::<f64>::new();
        assert!(variance(no_samples).is_nan());
    }

    #[test]
    fn test_variance_wrapper_single_element() {
        assert!(near(variance([42.0f64]), 0.0));
        assert!(near(variance([100i32]), 0.0));
    }

    #[test]
    fn test_variance_wrapper_identical_values() {
        assert!(near(variance([5.0f64, 5.0, 5.0, 5.0]), 0.0));
    }

    #[test]
    fn test_variance_wrapper_various_numeric_types() {
        let expected = 2.0 / 3.0;

        assert!(near(variance([1u8, 2, 3]), expected));
        assert!(near(variance([1u16, 2, 3]), expected));
        assert!(near(variance([1i32, 2, 3]), expected));
        assert!(near(variance([1i64, 2, 3]), expected));
        assert!(near(variance([1usize, 2, 3]), expected));
    }

    // ---- `stddev` wrapper --------------------------------------------------

    #[test]
    fn test_stddev_wrapper_basic() {
        let expected = 2.0f64.sqrt();

        assert!(near(stddev([1.0f64, 2.0, 3.0, 4.0, 5.0]), expected));
        assert!(near(stddev([1i32, 2, 3, 4, 5]), expected));
    }

    #[test]
    fn test_stddev_wrapper_empty() {
        let no_samples = Vec::<f64>::new();
        assert!(stddev(no_samples).is_nan());
    }

    #[test]
    fn test_stddev_wrapper_single_element() {
        assert!(near(stddev([42.0f64]), 0.0));
        assert!(near(stddev([100i32]), 0.0));
    }

    #[test]
    fn test_stddev_wrapper_identical_values() {
        assert!(near(stddev([5.0f64, 5.0, 5.0, 5.0]), 0.0));
    }

    #[test]
    fn test_stddev_wrapper_various_numeric_types() {
        let expected = (2.0f64 / 3.0).sqrt();

        assert!(near(stddev([1u8, 2, 3]), expected));
        assert!(near(stddev([1u16, 2, 3]), expected));
        assert!(near(stddev([1i32, 2, 3]), expected));
        assert!(near(stddev([1i64, 2, 3]), expected));
        assert!(near(stddev([1usize, 2, 3]), expected));
    }

    // ---- wrappers vs. accumulator ------------------------------------------

    /// The free functions agree with the accumulator's own methods.
    #[test]
    fn test_wrapper_functions_consistency() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let reference = OnlineStats::from_slice(&data);

        assert!(near(mean(data.clone()), reference.mean()));
        assert!(near(variance(data.clone()), reference.variance()));
        assert!(near(stddev(data), reference.stddev()));
    }

    /// The wrappers accept arrays, ranges and adapted iterators.
    #[test]
    fn test_wrapper_functions_with_iterators() {
        let arr = [1, 2, 3, 4, 5];
        assert!(near(mean(arr), 3.0));

        assert!(near(mean(1..=5), 3.0));

        let doubled = (1..=5).map(|x| x * 2);
        assert!(near(mean(doubled), 6.0));
    }

    // ---- sign counters -----------------------------------------------------

    #[test]
    fn test_n_counts_basic() {
        let mut acc = OnlineStats::new();
        acc.extend([-5, -3, 0, 0, 2, 4, 6]);

        let (neg, zero, pos) = acc.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values");
        assert_eq!(zero, 2, "Should have 2 zero values");
        assert_eq!(pos, 3, "Should have 3 positive values");
    }

    #[test]
    fn test_n_counts_all_positive() {
        let mut acc = OnlineStats::new();
        acc.extend([1.0, 2.0, 3.0, 4.0]);

        assert_eq!(acc.n_counts(), (0, 0, 4));
    }

    #[test]
    fn test_n_counts_all_negative() {
        let mut acc = OnlineStats::new();
        acc.extend([-1.0, -2.0, -3.0]);

        assert_eq!(acc.n_counts(), (3, 0, 0));
    }

    #[test]
    fn test_n_counts_all_zeros() {
        let mut acc = OnlineStats::new();
        acc.extend([0.0, 0.0, 0.0]);

        assert_eq!(acc.n_counts(), (0, 3, 0));
    }

    #[test]
    fn test_n_counts_with_merge() {
        let mut left = OnlineStats::new();
        left.extend([-2, 0, 3]);

        let mut right = OnlineStats::new();
        right.extend([-1, 5, 7]);

        left.merge(right);

        let (neg, zero, pos) = left.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values after merge");
        assert_eq!(zero, 1, "Should have 1 zero value after merge");
        assert_eq!(pos, 3, "Should have 3 positive values after merge");
    }

    #[test]
    fn test_n_counts_empty() {
        let acc = OnlineStats::new();

        assert_eq!(acc.n_counts(), (0, 0, 0));
    }

    /// `-0.0` carries a negative sign bit and is counted as negative.
    #[test]
    fn test_n_counts_negative_zero() {
        let mut acc = OnlineStats::new();
        acc.extend([-0.0f64, 0.0]);

        let (neg, zero, pos) = acc.n_counts();
        assert_eq!(neg, 1, "-0.0 has negative sign bit");
        assert_eq!(zero, 1, "+0.0 is zero");
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_floats_boundary() {
        let mut acc = OnlineStats::new();
        acc.extend([-0.0001f64, 0.0, 0.0001]);

        assert_eq!(acc.n_counts(), (1, 1, 1));
    }
}