use std::fmt;
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use crate::Commute;
#[inline]
pub fn stddev<I, T>(x: I) -> f64
where
I: IntoIterator<Item = T>,
T: ToPrimitive,
{
x.into_iter().collect::<OnlineStats>().stddev()
}
#[inline]
pub fn variance<I, T>(x: I) -> f64
where
I: IntoIterator<Item = T>,
T: ToPrimitive,
{
x.into_iter().collect::<OnlineStats>().variance()
}
#[inline]
pub fn mean<I, T>(x: I) -> f64
where
I: IntoIterator<Item = T>,
T: ToPrimitive,
{
x.into_iter().collect::<OnlineStats>().mean()
}
/// Single-pass accumulator for mean, variance and related summary statistics.
///
/// Samples are folded in one at a time (or merged from another accumulator),
/// so the full data set never has to be held in memory. Field order is part of
/// the derived serialization layout and must not be rearranged.
// The `unsafe_derive_deserialize` lint concerns `Deserialize` derives on types
// whose methods use `unsafe`; it is silenced for this type.
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Clone, Copy, Serialize, Deserialize, PartialEq)]
pub struct OnlineStats {
    /// Number of samples accumulated so far (NaN samples are not counted).
    size: u64,
    /// Running arithmetic mean of the accumulated samples.
    mean: f64,
    /// Running sum of squared deviations from the mean; variance is `q / size`.
    q: f64,
    /// Whether `harmonic_sum` and `geometric_sum` are still being accumulated;
    /// cleared as soon as a non-positive sample is seen.
    hg_sums: bool,
    /// Sum of `1 / sample` over positive samples, accumulated while `hg_sums` is set.
    harmonic_sum: f64,
    /// Sum of `ln(sample)` over positive samples, accumulated while `hg_sums` is set.
    geometric_sum: f64,
    /// Count of samples strictly greater than zero.
    n_positive: u64,
    /// Count of non-positive samples whose sign bit is clear (i.e. `+0.0`).
    n_zero: u64,
    /// Count of non-positive samples whose sign bit is set (includes `-0.0`).
    n_negative: u64,
}
impl OnlineStats {
#[must_use]
pub fn new() -> OnlineStats {
Default::default()
}
#[must_use]
pub fn from_slice<T: ToPrimitive>(samples: &[T]) -> OnlineStats {
samples
.iter()
.map(|n| unsafe { n.to_f64().unwrap_unchecked() })
.collect()
}
#[must_use]
pub const fn mean(&self) -> f64 {
if self.is_empty() { f64::NAN } else { self.mean }
}
#[must_use]
pub fn stddev(&self) -> f64 {
self.variance().sqrt()
}
#[must_use]
pub const fn variance(&self) -> f64 {
if self.is_empty() {
f64::NAN
} else {
self.q / (self.size as f64)
}
}
#[must_use]
pub fn harmonic_mean(&self) -> f64 {
if self.is_empty() || self.n_zero > 0 || self.n_negative > 0 {
f64::NAN
} else {
(self.size as f64) / self.harmonic_sum
}
}
#[must_use]
pub fn geometric_mean(&self) -> f64 {
if self.is_empty()
|| self.n_negative > 0
|| self.geometric_sum.is_nan()
|| self.geometric_sum == f64::INFINITY
{
f64::NAN
} else if self.n_zero > 0 || self.geometric_sum == f64::NEG_INFINITY {
0.0
} else {
(self.geometric_sum / (self.size as f64)).exp()
}
}
#[must_use]
pub const fn n_counts(&self) -> (u64, u64, u64) {
(self.n_negative, self.n_zero, self.n_positive)
}
#[inline]
pub fn add<T: ToPrimitive>(&mut self, sample: &T) {
let sample = unsafe { sample.to_f64().unwrap_unchecked() };
if sample.is_nan() {
return;
}
self.size += 1;
let delta = sample - self.mean;
self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
self.q = delta.mul_add(sample - self.mean, self.q);
if sample > 0.0 {
if self.hg_sums {
self.harmonic_sum += 1.0 / sample;
self.geometric_sum += sample.ln();
}
self.n_positive += 1;
} else {
if sample.is_sign_negative() {
self.n_negative += 1;
} else {
self.n_zero += 1;
}
self.hg_sums = false;
}
}
#[inline]
pub fn add_f64(&mut self, sample: f64) {
if sample.is_nan() {
return;
}
self.size += 1;
let delta = sample - self.mean;
self.mean = delta.mul_add(1.0 / (self.size as f64), self.mean);
self.q = delta.mul_add(sample - self.mean, self.q);
if sample > 0.0 {
if self.hg_sums {
self.harmonic_sum += 1.0 / sample;
self.geometric_sum += sample.ln();
}
self.n_positive += 1;
} else {
if sample.is_sign_negative() {
self.n_negative += 1;
} else {
self.n_zero += 1;
}
self.hg_sums = false;
}
}
#[inline]
pub fn add_null(&mut self) {
self.add_f64(0.0);
}
#[inline]
#[must_use]
pub const fn len(&self) -> usize {
self.size as usize
}
#[inline]
#[must_use]
pub const fn is_empty(&self) -> bool {
self.size == 0
}
}
impl Commute for OnlineStats {
    /// Folds the samples summarised by `v` into `self`.
    ///
    /// The mean and the sum of squared deviations are combined from the two
    /// partial results using their sizes as weights; the harmonic/geometric
    /// sums and the sign counters are added, and `hg_sums` stays set only if
    /// it was set on both sides. Merging an empty `v` changes nothing.
    #[inline]
    fn merge(&mut self, v: OnlineStats) {
        if !v.is_empty() {
            let n_self = self.size as f64;
            let n_other = v.size as f64;
            let n_total = n_self + n_other;

            // Captured before `self.mean` is overwritten below.
            let mean_gap = self.mean - v.mean;
            let mean_gap_sq = mean_gap * mean_gap;

            self.mean = (v.mean - self.mean).mul_add(n_other / n_total, self.mean);
            self.q += mean_gap_sq.mul_add(n_self * n_other / n_total, v.q);
            self.size += v.size;

            self.hg_sums = self.hg_sums && v.hg_sums;
            self.harmonic_sum += v.harmonic_sum;
            self.geometric_sum += v.geometric_sum;

            self.n_negative += v.n_negative;
            self.n_zero += v.n_zero;
            self.n_positive += v.n_positive;
        }
    }
}
impl Default for OnlineStats {
fn default() -> OnlineStats {
OnlineStats {
size: 0,
mean: 0.0,
q: 0.0,
harmonic_sum: 0.0,
geometric_sum: 0.0,
n_zero: 0,
n_negative: 0,
n_positive: 0,
hg_sums: true,
}
}
}
impl fmt::Debug for OnlineStats {
    /// Writes the accumulator as `<mean> +/- <stddev>`, each value printed
    /// with ten digits after the decimal point.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let center = self.mean();
        let spread = self.stddev();
        f.write_fmt(format_args!("{:.10} +/- {:.10}", center, spread))
    }
}
impl<T: ToPrimitive> FromIterator<T> for OnlineStats {
    /// Builds an accumulator by adding every item of `it`, in iteration order,
    /// to a fresh `OnlineStats`.
    #[inline]
    fn from_iter<I: IntoIterator<Item = T>>(it: I) -> OnlineStats {
        it.into_iter().fold(OnlineStats::new(), |mut stats, sample| {
            stats.add(&sample);
            stats
        })
    }
}
impl<T: ToPrimitive> Extend<T> for OnlineStats {
    /// Adds every item of `it` to this accumulator, in iteration order.
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, it: I) {
        it.into_iter().for_each(|sample| self.add(&sample));
    }
}
#[cfg(test)]
mod test {
    use super::{OnlineStats, mean, stddev, variance};
    use crate::{Commute, merge_all};

    /// Absolute tolerance used by the approximate float comparisons below.
    const TOL: f64 = 1e-10;

    /// Reports whether `actual` lies strictly within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Expands to a fresh `OnlineStats` extended with a `vec!` of the samples.
    macro_rules! stats_of {
        ($($sample:expr),+ $(,)?) => {{
            let mut acc = OnlineStats::new();
            acc.extend(vec![$($sample),+]);
            acc
        }};
    }

    // ---- merging ---------------------------------------------------------

    /// Merging two partial accumulators matches accumulating everything at once.
    #[test]
    fn online() {
        let whole = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6]);
        let mut merged = OnlineStats::from_slice(&[1usize, 2, 3]);
        merged.merge(OnlineStats::from_slice(&[2usize, 4, 6]));

        assert_eq!(whole.stddev(), merged.stddev());
        assert_eq!(whole.mean(), merged.mean());
        assert_eq!(whole.variance(), merged.variance());
    }

    /// A new accumulator reports itself as empty.
    #[test]
    fn online_empty() {
        assert!(OnlineStats::new().is_empty());
    }

    /// `merge_all` over three parts matches accumulating everything at once.
    #[test]
    fn online_many() {
        let whole = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6, 3, 6, 9]);
        let parts = vec![
            OnlineStats::from_slice(&[1usize, 2, 3]),
            OnlineStats::from_slice(&[2usize, 4, 6]),
            OnlineStats::from_slice(&[3usize, 6, 9]),
        ];
        // Each assertion merges a fresh copy of the parts.
        let merged = || merge_all(parts.clone().into_iter()).unwrap();

        assert_eq!(whole.stddev(), merged().stddev());
        assert_eq!(whole.mean(), merged().mean());
        assert_eq!(whole.variance(), merged().variance());
    }

    // ---- arithmetic / harmonic / geometric means ---------------------------

    #[test]
    fn test_means() {
        let stats = stats_of![2.0f64, 4.0, 8.0];
        assert!(close(stats.mean(), 4.666666666667));
        assert_eq!("3.42857143", format!("{:.8}", stats.harmonic_mean()));
        assert!(close(stats.geometric_mean(), 4.0));
    }

    #[test]
    fn test_means_with_negative() {
        let stats = stats_of![-2.0f64, 2.0];
        assert!(close(stats.mean(), 0.0));
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_with_zero() {
        let stats = stats_of![0.0f64, 4.0, 8.0];
        assert!(close(stats.mean(), 4.0));
        assert!(close(stats.geometric_mean(), 0.0));
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_with_zero_and_negative_values() {
        let stats = stats_of![-10i32, -5, 0, 5, 10];
        assert!(close(stats.mean(), 0.0));
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    #[test]
    fn test_means_single_value() {
        let stats = stats_of![5.0f64];
        assert!(close(stats.mean(), 5.0));
        assert!(close(stats.geometric_mean(), 5.0));
        assert!(close(stats.harmonic_mean(), 5.0));
    }

    #[test]
    fn test_means_empty() {
        let stats = OnlineStats::new();
        assert!(stats.mean().is_nan());
        assert!(stats.geometric_mean().is_nan());
        assert!(stats.harmonic_mean().is_nan());
    }

    // ---- `mean` wrapper ----------------------------------------------------

    #[test]
    fn test_mean_wrapper_basic() {
        assert!(close(mean(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]), 3.0));
        assert!(close(mean(vec![1i32, 2, 3, 4, 5]), 3.0));
        assert!(close(mean(vec![10u32, 20, 30]), 20.0));
    }

    #[test]
    fn test_mean_wrapper_empty() {
        assert!(mean(Vec::<f64>::new()).is_nan());
    }

    #[test]
    fn test_mean_wrapper_single_element() {
        assert!(close(mean(vec![42.0f64]), 42.0));
        assert!(close(mean(vec![100i32]), 100.0));
        assert!(close(mean(vec![0u8]), 0.0));
    }

    #[test]
    fn test_mean_wrapper_negative_values() {
        assert!(close(mean(vec![-5.0f64, 5.0]), 0.0));
        assert!(close(mean(vec![-10i32, -20, -30]), -20.0));
    }

    #[test]
    fn test_mean_wrapper_various_numeric_types() {
        assert!(close(mean(vec![1u8, 2, 3]), 2.0));
        assert!(close(mean(vec![1u16, 2, 3]), 2.0));
        assert!(close(mean(vec![1u64, 2, 3]), 2.0));
        assert!(close(mean(vec![1i8, 2, 3]), 2.0));
        assert!(close(mean(vec![1i16, 2, 3]), 2.0));
        assert!(close(mean(vec![1i64, 2, 3]), 2.0));
        // `f32` input gets a looser tolerance than the other types.
        assert!((mean(vec![1.0f32, 2.0, 3.0]) - 2.0).abs() < 1e-6);
        assert!(close(mean(vec![1usize, 2, 3]), 2.0));
        assert!(close(mean(vec![1isize, 2, 3]), 2.0));
    }

    // ---- `variance` wrapper ------------------------------------------------

    #[test]
    fn test_variance_wrapper_basic() {
        assert!(close(variance(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]), 2.0));
        assert!(close(variance(vec![1i32, 2, 3, 4, 5]), 2.0));
    }

    #[test]
    fn test_variance_wrapper_empty() {
        assert!(variance(Vec::<f64>::new()).is_nan());
    }

    #[test]
    fn test_variance_wrapper_single_element() {
        assert!(close(variance(vec![42.0f64]), 0.0));
        assert!(close(variance(vec![100i32]), 0.0));
    }

    #[test]
    fn test_variance_wrapper_identical_values() {
        assert!(close(variance(vec![5.0f64, 5.0, 5.0, 5.0]), 0.0));
    }

    #[test]
    fn test_variance_wrapper_various_numeric_types() {
        let expected = 2.0 / 3.0;
        assert!(close(variance(vec![1u8, 2, 3]), expected));
        assert!(close(variance(vec![1u16, 2, 3]), expected));
        assert!(close(variance(vec![1i32, 2, 3]), expected));
        assert!(close(variance(vec![1i64, 2, 3]), expected));
        assert!(close(variance(vec![1usize, 2, 3]), expected));
    }

    // ---- `stddev` wrapper --------------------------------------------------

    #[test]
    fn test_stddev_wrapper_basic() {
        assert!(close(stddev(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]), 2.0f64.sqrt()));
        assert!(close(stddev(vec![1i32, 2, 3, 4, 5]), 2.0f64.sqrt()));
    }

    #[test]
    fn test_stddev_wrapper_empty() {
        assert!(stddev(Vec::<f64>::new()).is_nan());
    }

    #[test]
    fn test_stddev_wrapper_single_element() {
        assert!(close(stddev(vec![42.0f64]), 0.0));
        assert!(close(stddev(vec![100i32]), 0.0));
    }

    #[test]
    fn test_stddev_wrapper_identical_values() {
        assert!(close(stddev(vec![5.0f64, 5.0, 5.0, 5.0]), 0.0));
    }

    #[test]
    fn test_stddev_wrapper_various_numeric_types() {
        let expected = (2.0f64 / 3.0).sqrt();
        assert!(close(stddev(vec![1u8, 2, 3]), expected));
        assert!(close(stddev(vec![1u16, 2, 3]), expected));
        assert!(close(stddev(vec![1i32, 2, 3]), expected));
        assert!(close(stddev(vec![1i64, 2, 3]), expected));
        assert!(close(stddev(vec![1usize, 2, 3]), expected));
    }

    // ---- wrappers vs. accumulator -------------------------------------------

    #[test]
    fn test_wrapper_functions_consistency() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = OnlineStats::from_slice(&data);
        assert!(close(mean(data.clone()), stats.mean()));
        assert!(close(variance(data.clone()), stats.variance()));
        assert!(close(stddev(data), stats.stddev()));
    }

    #[test]
    fn test_wrapper_functions_with_iterators() {
        let arr = [1, 2, 3, 4, 5];
        assert!(close(mean(arr), 3.0));
        assert!(close(mean(1..=5), 3.0));
        assert!(close(mean((1..=5).map(|x| x * 2)), 6.0));
    }

    // ---- sign counters -------------------------------------------------------

    #[test]
    fn test_n_counts_basic() {
        let (neg, zero, pos) = stats_of![-5, -3, 0, 0, 2, 4, 6].n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values");
        assert_eq!(zero, 2, "Should have 2 zero values");
        assert_eq!(pos, 3, "Should have 3 positive values");
    }

    #[test]
    fn test_n_counts_all_positive() {
        assert_eq!(stats_of![1.0, 2.0, 3.0, 4.0].n_counts(), (0, 0, 4));
    }

    #[test]
    fn test_n_counts_all_negative() {
        assert_eq!(stats_of![-1.0, -2.0, -3.0].n_counts(), (3, 0, 0));
    }

    #[test]
    fn test_n_counts_all_zeros() {
        assert_eq!(stats_of![0.0, 0.0, 0.0].n_counts(), (0, 3, 0));
    }

    #[test]
    fn test_n_counts_with_merge() {
        let mut combined = stats_of![-2, 0, 3];
        combined.merge(stats_of![-1, 5, 7]);
        let (neg, zero, pos) = combined.n_counts();
        assert_eq!(neg, 2, "Should have 2 negative values after merge");
        assert_eq!(zero, 1, "Should have 1 zero value after merge");
        assert_eq!(pos, 3, "Should have 3 positive values after merge");
    }

    #[test]
    fn test_n_counts_empty() {
        assert_eq!(OnlineStats::new().n_counts(), (0, 0, 0));
    }

    #[test]
    fn test_n_counts_negative_zero() {
        let (neg, zero, pos) = stats_of![-0.0f64, 0.0].n_counts();
        assert_eq!(neg, 1, "-0.0 has negative sign bit");
        assert_eq!(zero, 1, "+0.0 is zero");
        assert_eq!(pos, 0);
    }

    #[test]
    fn test_n_counts_floats_boundary() {
        assert_eq!(stats_of![-0.0001f64, 0.0, 0.0001].n_counts(), (1, 1, 1));
    }

    // ---- NaN and infinity handling -------------------------------------------

    #[test]
    fn test_nan_skipped() {
        let mut stats = OnlineStats::new();
        for sample in [1.0f64, f64::NAN, 3.0f64] {
            stats.add(&sample);
        }
        assert_eq!(stats.len(), 2);
        assert!(close(stats.mean(), 2.0));
    }

    #[test]
    fn test_nan_only() {
        let mut stats = OnlineStats::new();
        for sample in [f64::NAN, f64::NAN] {
            stats.add(&sample);
        }
        assert_eq!(stats.len(), 0);
        assert!(stats.mean().is_nan());
        assert!(stats.variance().is_nan());
        assert!(stats.stddev().is_nan());
    }

    #[test]
    fn test_infinity_add() {
        let mut stats = OnlineStats::new();
        stats.add(&f64::INFINITY);
        assert_eq!(stats.len(), 1);
        assert!(stats.mean().is_infinite());
    }

    #[test]
    fn test_neg_infinity_add() {
        let mut stats = OnlineStats::new();
        stats.add(&f64::NEG_INFINITY);
        assert_eq!(stats.len(), 1);
        assert!(stats.mean().is_infinite());
    }

    #[test]
    fn test_infinity_mixed() {
        let mut stats = OnlineStats::new();
        for sample in [f64::INFINITY, f64::NEG_INFINITY] {
            stats.add(&sample);
        }
        assert!(stats.mean().is_nan());
    }

    #[test]
    fn test_add_f64_nan_skipped() {
        let mut stats = OnlineStats::new();
        for sample in [1.0, f64::NAN, 3.0] {
            stats.add_f64(sample);
        }
        assert_eq!(stats.len(), 2);
        assert!(close(stats.mean(), 2.0));
    }

    #[test]
    fn test_geometric_mean_infinity() {
        let mut stats = OnlineStats::new();
        stats.add(&f64::INFINITY);
        assert!(stats.geometric_mean().is_nan());
    }

    #[test]
    fn test_harmonic_mean_infinity() {
        let mut stats = OnlineStats::new();
        stats.add(&f64::INFINITY);
        stats.add(&1.0f64);
        assert!(close(stats.harmonic_mean(), 2.0));
    }
}