use crate::math::MulAdd;
/// Generates a streaming bivariate accumulator type `$name` over the float type `$ty`.
///
/// The generated type keeps running means, sums of squared deviations (`m2_x`, `m2_y`) and the
/// co-moment of paired samples. From these it derives the sample covariance and, when the `std`
/// or `libm` feature is enabled, the Pearson correlation coefficient.
macro_rules! impl_covariance {
    ($name:ident, $ty:ty) => {
        #[doc = concat!(
            "Single-pass accumulator for the covariance and correlation of paired `",
            stringify!($ty),
            "` samples.\n\n",
            "Means and central moments are updated incrementally (Welford's method), so the ",
            "samples themselves are never stored. Accumulators built over disjoint data can be ",
            "combined with [`",
            stringify!($name),
            "::merge`]."
        )]
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Number of accepted sample pairs.
            count: u64,
            /// Running mean of the `x` samples.
            mean_x: $ty,
            /// Running mean of the `y` samples.
            mean_y: $ty,
            /// Sum of squared deviations of `x` from its mean.
            m2_x: $ty,
            /// Sum of squared deviations of `y` from its mean.
            m2_y: $ty,
            /// Sum of products of the `x` and `y` deviations from their means.
            co_moment: $ty,
        }

        impl $name {
            /// Creates an empty accumulator containing no samples.
            #[inline]
            #[must_use]
            pub const fn new() -> Self {
                Self {
                    count: 0,
                    mean_x: 0.0 as $ty,
                    mean_y: 0.0 as $ty,
                    m2_x: 0.0 as $ty,
                    m2_y: 0.0 as $ty,
                    co_moment: 0.0 as $ty,
                }
            }

            /// Adds one `(x, y)` sample pair.
            ///
            /// # Errors
            ///
            /// Returns a [`crate::DataError`] if `x` or `y` is NaN or infinite. The check runs
            /// before any state is touched, so a rejected pair leaves the accumulator unchanged.
            #[inline]
            pub fn update(&mut self, x: $ty, y: $ty) -> Result<(), crate::DataError> {
                check_finite!(x);
                check_finite!(y);
                self.count += 1;
                let n = self.count as $ty;
                // Deviations from the means *before* this sample.
                let dx = x - self.mean_x;
                let dy = y - self.mean_y;
                self.mean_x += dx / n;
                self.mean_y += dy / n;
                // Welford: multiply the old-mean deviation by the new-mean deviation.
                let dx_new = x - self.mean_x;
                let dy_new = y - self.mean_y;
                self.m2_x += dx * dx_new;
                self.m2_y += dy * dy_new;
                self.co_moment += dx * dy_new;
                Ok(())
            }

            /// Returns the number of sample pairs accepted so far.
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// Returns the mean of the `x` samples, or `None` if no samples have been added.
            #[inline]
            #[must_use]
            pub fn mean_x(&self) -> Option<$ty> {
                if self.count == 0 {
                    Option::None
                } else {
                    Option::Some(self.mean_x)
                }
            }

            /// Returns the mean of the `y` samples, or `None` if no samples have been added.
            #[inline]
            #[must_use]
            pub fn mean_y(&self) -> Option<$ty> {
                if self.count == 0 {
                    Option::None
                } else {
                    Option::Some(self.mean_y)
                }
            }

            /// Returns the sample covariance (co-moment divided by `n - 1`), or `None` when
            /// fewer than two samples have been added.
            #[inline]
            #[must_use]
            pub fn covariance(&self) -> Option<$ty> {
                if self.count < 2 {
                    Option::None
                } else {
                    Option::Some(self.co_moment / (self.count - 1) as $ty)
                }
            }

            /// Returns the Pearson correlation coefficient of the samples.
            ///
            /// Returns `None` with fewer than two samples, or when either variable has zero
            /// variance (the coefficient is undefined there). The result is clamped to
            /// `[-1, 1]` to absorb floating-point rounding.
            #[cfg(any(feature = "std", feature = "libm"))]
            #[inline]
            #[must_use]
            pub fn correlation(&self) -> Option<$ty> {
                if self.count < 2 {
                    return Option::None;
                }
                if self.m2_x <= 0.0 as $ty || self.m2_y <= 0.0 as $ty {
                    return Option::None;
                }
                // Take the square roots separately (and in f64) instead of sqrt(m2_x * m2_y).
                // The product overflows to infinity for large-magnitude data, silently giving
                // r = 0. It can also underflow to zero for tiny but non-zero variances.
                let denom =
                    crate::math::sqrt(self.m2_x as f64) * crate::math::sqrt(self.m2_y as f64);
                let r = (self.co_moment as f64 / denom).clamp(-1.0, 1.0);
                Option::Some(r as $ty)
            }

            /// Merges `other` into `self`, so that `self` summarizes the samples of both.
            ///
            /// The merge uses the pairwise combination formulas. With `n = n_a + n_b` and
            /// `d = mean_b - mean_a`:
            /// - the merged mean is `mean_a + d * n_b / n`;
            /// - each moment becomes `M_a + M_b + d_x * d_y * n_a * n_b / n`.
            #[inline]
            pub fn merge(&mut self, other: &Self) {
                if other.count == 0 {
                    return;
                }
                if self.count == 0 {
                    *self = other.clone();
                    return;
                }
                let combined = self.count + other.count;
                let n_a = self.count as $ty;
                let n_b = other.count as $ty;
                let n = combined as $ty;
                let inv_n = 1.0 as $ty / n;
                let weight = n_a * n_b / n;
                let dx = other.mean_x - self.mean_x;
                let dy = other.mean_y - self.mean_y;
                // `a.fma(b, c)` is the fused multiply-add `a * b + c` from `MulAdd`.
                let new_mean_x = (dx * n_b).fma(inv_n, self.mean_x);
                let new_mean_y = (dy * n_b).fma(inv_n, self.mean_y);
                self.co_moment += (dx * dy).fma(weight, other.co_moment);
                self.m2_x += (dx * dx).fma(weight, other.m2_x);
                self.m2_y += (dy * dy).fma(weight, other.m2_y);
                self.mean_x = new_mean_x;
                self.mean_y = new_mean_y;
                self.count = combined;
            }

            /// Discards all samples, returning the accumulator to the state of [`Self::new`].
            #[inline]
            pub fn reset(&mut self) {
                *self = Self::new();
            }
        }

        impl Default for $name {
            /// Equivalent to [`Self::new`]: an empty accumulator.
            #[inline]
            fn default() -> Self {
                Self::new()
            }
        }
    };
}

impl_covariance!(CovarianceF64, f64);
impl_covariance!(CovarianceF32, f32);
#[cfg(test)]
mod tests {
    use super::*;

    const DATA_X: [f64; 8] = [1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
    const DATA_Y: [f64; 8] = [2.0, 6.0, 10.0, 14.0, 4.0, 8.0, 12.0, 16.0];

    /// Feeds every `(x, y)` pair into a fresh accumulator.
    fn accumulate(xs: &[f64], ys: &[f64]) -> CovarianceF64 {
        let mut c = CovarianceF64::new();
        for (&x, &y) in xs.iter().zip(ys) {
            c.update(x, y).unwrap();
        }
        c
    }

    /// Returns `(single_pass, merged)` over the same data; `merged` combines two halves.
    fn single_and_merged() -> (CovarianceF64, CovarianceF64) {
        let single = accumulate(&DATA_X, &DATA_Y);
        let mut merged = accumulate(&DATA_X[..4], &DATA_Y[..4]);
        merged.merge(&accumulate(&DATA_X[4..], &DATA_Y[4..]));
        (single, merged)
    }

    #[test]
    fn empty() {
        let c = CovarianceF64::new();
        assert_eq!(c.count(), 0);
        assert!(c.mean_x().is_none());
        assert!(c.mean_y().is_none());
        assert!(c.covariance().is_none());
    }

    // `correlation` only exists with the `std` or `libm` feature, so its tests are gated the
    // same way; otherwise the test build breaks without those features.
    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn empty_has_no_correlation() {
        assert!(CovarianceF64::new().correlation().is_none());
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn perfect_positive_correlation() {
        let mut c = CovarianceF64::new();
        for i in 0..100 {
            c.update(i as f64, i as f64 * 2.0).unwrap();
        }
        let r = c.correlation().unwrap();
        assert!(
            (r - 1.0).abs() < 1e-10,
            "perfect positive should be 1.0, got {r}"
        );
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn perfect_negative_correlation() {
        let mut c = CovarianceF64::new();
        for i in 0..100 {
            c.update(i as f64, -(i as f64)).unwrap();
        }
        let r = c.correlation().unwrap();
        assert!(
            (r + 1.0).abs() < 1e-10,
            "perfect negative should be -1.0, got {r}"
        );
    }

    #[test]
    fn known_covariance() {
        let c = accumulate(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);
        let cov = c.covariance().unwrap();
        assert!(
            (cov - 2.0).abs() < 1e-10,
            "covariance should be 2.0, got {cov}"
        );
    }

    #[test]
    fn means_track_samples() {
        let c = accumulate(&[1.0, 3.0], &[2.0, 6.0]);
        assert_eq!(c.mean_x(), Some(2.0));
        assert_eq!(c.mean_y(), Some(4.0));
        assert_eq!(c.covariance(), Some(4.0));
    }

    #[test]
    fn single_sample_has_mean_but_no_covariance() {
        let c = accumulate(&[7.0], &[-3.0]);
        assert_eq!(c.count(), 1);
        assert_eq!(c.mean_x(), Some(7.0));
        assert_eq!(c.mean_y(), Some(-3.0));
        assert!(c.covariance().is_none());
    }

    #[test]
    fn merge_matches_single() {
        let (single, merged) = single_and_merged();
        assert_eq!(merged.count(), single.count());
        assert!((merged.mean_x().unwrap() - single.mean_x().unwrap()).abs() < 1e-10);
        assert!((merged.mean_y().unwrap() - single.mean_y().unwrap()).abs() < 1e-10);
        assert!((merged.covariance().unwrap() - single.covariance().unwrap()).abs() < 1e-10);
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn merge_matches_single_correlation() {
        let (single, merged) = single_and_merged();
        assert!((merged.correlation().unwrap() - single.correlation().unwrap()).abs() < 1e-10);
    }

    #[test]
    fn merge_into_empty_copies_other() {
        let full = accumulate(&[1.0, 3.0], &[2.0, 6.0]);
        let mut empty = CovarianceF64::new();
        empty.merge(&full);
        assert_eq!(empty.count(), 2);
        assert_eq!(empty.mean_x(), full.mean_x());
        assert_eq!(empty.covariance(), full.covariance());
    }

    #[test]
    fn merge_with_empty_is_noop() {
        let mut full = accumulate(&[1.0, 3.0], &[2.0, 6.0]);
        full.merge(&CovarianceF64::new());
        assert_eq!(full.count(), 2);
        assert_eq!(full.mean_y(), Some(4.0));
        assert_eq!(full.covariance(), Some(4.0));
    }

    #[test]
    fn reset_clears() {
        let mut c = accumulate(&[1.0, 3.0], &[2.0, 4.0]);
        c.reset();
        assert_eq!(c.count(), 0);
        assert!(c.mean_x().is_none());
        assert!(c.covariance().is_none());
        c.update(1.0, 1.0).unwrap();
        assert_eq!(c.count(), 1);
    }

    #[test]
    fn f32_basic() {
        let mut c = CovarianceF32::new();
        c.update(1.0, 2.0).unwrap();
        c.update(2.0, 4.0).unwrap();
        let cov = c.covariance().unwrap();
        assert!((cov - 1.0).abs() < 1e-6, "covariance should be 1.0, got {cov}");
    }

    #[test]
    fn default_is_empty() {
        let c = CovarianceF64::default();
        assert_eq!(c.count(), 0);
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn zero_variance_returns_none_correlation() {
        let c = accumulate(&[5.0, 5.0], &[1.0, 2.0]);
        assert!(c.correlation().is_none());
    }

    #[test]
    fn rejects_nan_and_inf() {
        let mut c = CovarianceF64::new();
        assert_eq!(c.update(f64::NAN, 1.0), Err(crate::DataError::NotANumber));
        assert_eq!(c.update(1.0, f64::NAN), Err(crate::DataError::NotANumber));
        assert_eq!(
            c.update(f64::INFINITY, 1.0),
            Err(crate::DataError::Infinite)
        );
        assert_eq!(
            c.update(1.0, f64::NEG_INFINITY),
            Err(crate::DataError::Infinite)
        );
        assert_eq!(c.count(), 0);
        // Rejections must not poison the accumulator for later valid samples.
        c.update(1.0, 1.0).unwrap();
        assert_eq!(c.count(), 1);
        assert_eq!(c.mean_x(), Some(1.0));
    }
}