use alloc::vec;
use alloc::vec::Vec;
/// Online, exponentially weighted estimator of the mean vector and the
/// covariance matrix of a stream of `f64` observation vectors.
///
/// Construct one with [`OnlineCovarianceF64::builder`], then feed observations
/// through `update`.
#[derive(Debug, Clone)]
pub struct OnlineCovarianceF64 {
/// Covariance matrix stored row-major in a flat `dim * dim` buffer
/// (entry `(i, j)` lives at index `i * dim + j`).
cov: Vec<f64>,
/// Per-component exponentially weighted means (length `dim`).
means: Vec<f64>,
/// Number of components in every observation vector.
dim: usize,
/// Smoothing factor, validated by the builder to lie in `(0, 1)`; larger
/// values give recent observations more weight.
alpha: f64,
/// Number of observations accepted since construction or the last `reset`.
count: u64,
}
/// Builder for [`OnlineCovarianceF64`]; obtain one from
/// [`OnlineCovarianceF64::builder`] and finish with `build`.
#[derive(Debug, Clone)]
pub struct OnlineCovarianceF64Builder {
/// Requested number of components; required by `build`.
dim: Option<usize>,
/// Smoothing factor, set directly via `alpha` or derived via `halflife`;
/// required by `build`.
alpha: Option<f64>,
}
impl OnlineCovarianceF64 {
    /// Returns a builder with neither the dimension nor the smoothing factor
    /// set; both must be provided before `build` succeeds.
    #[inline]
    #[must_use]
    pub fn builder() -> OnlineCovarianceF64Builder {
        OnlineCovarianceF64Builder {
            dim: None,
            alpha: None,
        }
    }

    /// Folds one observation vector into the exponentially weighted mean and
    /// covariance estimates.
    ///
    /// The first observation only seeds the means; covariance accumulates
    /// from the second observation on. With `delta = observation - mean`,
    /// each later call applies
    ///
    /// ```text
    /// cov  <- (1 - alpha) * cov + alpha * (1 - alpha) * delta * delta^T
    /// mean <- mean + alpha * delta
    /// ```
    ///
    /// # Errors
    ///
    /// Propagates the error produced by `check_finite!` for any rejected
    /// component (presumably a non-finite value). All components are checked
    /// before any state is touched, so — assuming the macro returns early on
    /// rejection — a rejected observation leaves the estimator unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `observation.len()` differs from the configured dimension.
    pub fn update(&mut self, observation: &[f64]) -> Result<(), crate::DataError> {
        assert_eq!(
            observation.len(),
            self.dim,
            "observation dimension ({}) != configured dim ({})",
            observation.len(),
            self.dim,
        );
        // Validate the whole vector before mutating anything.
        for &v in observation {
            check_finite!(v);
        }
        self.count += 1;
        if self.count == 1 {
            // First sample: seed the means; there is no spread to measure yet.
            self.means.copy_from_slice(observation);
            return Ok(());
        }
        let d = self.dim;
        let alpha = self.alpha;
        let one_minus_alpha = 1.0 - alpha;
        // Reuse `means` as scratch space for `delta = observation - mean` to
        // avoid an allocation; it is turned back into the new mean below.
        for (m, &x) in self.means.iter_mut().zip(observation) {
            *m = x - *m;
        }
        let alpha_times_one_minus = alpha * one_minus_alpha;
        // Update the upper triangle (diagonal included) and mirror each
        // off-diagonal entry, keeping the row-major matrix exactly symmetric.
        for i in 0..d {
            for j in i..d {
                let idx = i * d + j;
                self.cov[idx] = one_minus_alpha * self.cov[idx]
                    + alpha_times_one_minus * self.means[i] * self.means[j];
                if i != j {
                    self.cov[j * d + i] = self.cov[idx];
                }
            }
        }
        // new_mean = x - (1 - alpha) * delta, which equals mean + alpha * delta.
        // NOTE(review): `f64::mul_add` is a std method; confirm this compiles
        // in the crate's no_std configuration (if it has one).
        for (m, &x) in self.means.iter_mut().zip(observation) {
            *m = (-*m).mul_add(one_minus_alpha, x);
        }
        Ok(())
    }

    /// Returns the covariance estimate between components `i` and `j`, or
    /// `None` until at least two observations have been recorded.
    ///
    /// # Panics
    ///
    /// Once primed, panics if `i` or `j` is not less than the configured
    /// dimension.
    #[inline]
    #[must_use]
    pub fn covariance(&self, i: usize, j: usize) -> Option<f64> {
        if !self.is_primed() {
            return None;
        }
        // A hard assert rather than `debug_assert!`: in release builds an
        // out-of-range `j` would otherwise alias into the next row of the flat
        // matrix and silently return an unrelated entry.
        assert!(
            i < self.dim && j < self.dim,
            "covariance index ({i}, {j}) out of range for dim {}",
            self.dim
        );
        Some(self.cov[i * self.dim + j])
    }

    /// Returns the Pearson correlation between components `i` and `j`.
    ///
    /// Returns `None` while unprimed, or when either variance is below
    /// `f64::EPSILON`. That threshold is absolute, so it does not adapt to the
    /// scale of the data. The result is clamped to `[-1.0, 1.0]`.
    ///
    /// # Panics
    ///
    /// Once primed, panics if `i` or `j` is out of range (see `covariance`).
    #[cfg(any(feature = "std", feature = "libm"))]
    #[inline]
    #[must_use]
    pub fn correlation(&self, i: usize, j: usize) -> Option<f64> {
        let var_i = self.variance(i)?;
        let var_j = self.variance(j)?;
        if var_i < f64::EPSILON || var_j < f64::EPSILON {
            return None;
        }
        let denom = crate::math::sqrt(var_i) * crate::math::sqrt(var_j);
        // Mathematically |corr| <= 1, but rounding can push the ratio a hair
        // outside that range; clamp so callers can rely on the bound.
        Some((self.covariance(i, j)? / denom).clamp(-1.0, 1.0))
    }

    /// Returns the variance estimate of component `i` (the diagonal entry of
    /// the covariance matrix), or `None` while unprimed.
    #[inline]
    #[must_use]
    pub fn variance(&self, i: usize) -> Option<f64> {
        self.covariance(i, i)
    }

    /// Returns the exponentially weighted mean of component `i`, or `None`
    /// before the first observation.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range once at least one observation exists.
    #[inline]
    #[must_use]
    pub fn mean(&self, i: usize) -> Option<f64> {
        if self.count == 0 {
            return None;
        }
        Some(self.means[i])
    }

    /// Returns the whole covariance matrix as a flat row-major slice, where
    /// entry `(i, j)` is at index `i * dim + j`.
    ///
    /// Unlike `covariance`, this is not gated on priming; the entries are
    /// untouched by `update` until the second observation.
    #[inline]
    #[must_use]
    pub fn as_matrix(&self) -> &[f64] {
        &self.cov
    }

    /// Returns the number of components per observation.
    #[inline]
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns `true` once at least two observations have been recorded, i.e.
    /// when covariance-based accessors start returning values.
    #[inline]
    #[must_use]
    pub fn is_primed(&self) -> bool {
        self.count >= 2
    }

    /// Returns the number of observations accepted since construction or the
    /// last `reset`.
    #[inline]
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Discards all accumulated state (means, covariance and count), keeping
    /// the configured dimension and smoothing factor.
    pub fn reset(&mut self) {
        self.cov.fill(0.0);
        self.means.fill(0.0);
        self.count = 0;
    }
}
impl OnlineCovarianceF64Builder {
    /// Sets the number of components in every observation vector.
    ///
    /// Must be non-zero; `build` rejects `0`.
    #[inline]
    #[must_use]
    pub fn dim(mut self, d: usize) -> Self {
        self.dim = Some(d);
        self
    }

    /// Sets the smoothing factor from a half-life `h` measured in
    /// observations, using `alpha = 1 - exp(-ln(2) / h)`. Then
    /// `(1 - alpha)^h = 1/2`, so an observation's weight halves after `h`
    /// further updates.
    ///
    /// This overwrites any factor set earlier through `alpha` or a previous
    /// call. An `h` that is not positive and finite yields an `alpha` outside
    /// `(0, 1)`, and so does a very large `h` whose `alpha` rounds to `0.0`.
    /// `build` rejects such values.
    #[cfg(any(feature = "std", feature = "libm"))]
    #[inline]
    #[must_use]
    pub fn halflife(mut self, h: f64) -> Self {
        let alpha = 1.0 - crate::math::exp(-core::f64::consts::LN_2 / h);
        self.alpha = Some(alpha);
        self
    }

    /// Sets the smoothing factor directly. It must lie strictly between `0`
    /// and `1`, and it overwrites any value derived from `halflife`.
    #[inline]
    #[must_use]
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Validates the configuration and allocates a zeroed estimator.
    ///
    /// # Errors
    ///
    /// Checks run in this order:
    /// - `ConfigError::Missing("dim")` if no dimension was set.
    /// - `ConfigError::Invalid(..)` if the dimension is `0`.
    /// - `ConfigError::Missing("halflife or alpha")` if no smoothing factor
    ///   was set.
    /// - `ConfigError::Invalid(..)` if `alpha` is not strictly inside
    ///   `(0, 1)`; NaN is included.
    /// - `ConfigError::Invalid(..)` if the `dim * dim` matrix size overflows
    ///   `usize`.
    pub fn build(self) -> Result<OnlineCovarianceF64, crate::ConfigError> {
        let dim = self.dim.ok_or(crate::ConfigError::Missing("dim"))?;
        if dim == 0 {
            return Err(crate::ConfigError::Invalid("dim must be > 0"));
        }
        let alpha = self
            .alpha
            .ok_or(crate::ConfigError::Missing("halflife or alpha"))?;
        // Written as a negated conjunction so that NaN is rejected as well.
        if !(alpha > 0.0 && alpha < 1.0) {
            return Err(crate::ConfigError::Invalid("alpha must be in (0, 1)"));
        }
        // An unchecked `dim * dim` panics in debug builds and wraps in
        // release builds. The wrapped size would allocate a matrix too small
        // for `dim`, so report the overflow as a configuration error.
        let cells = dim
            .checked_mul(dim)
            .ok_or(crate::ConfigError::Invalid("dim * dim overflows usize"))?;
        Ok(OnlineCovarianceF64 {
            cov: vec![0.0; cells],
            means: vec![0.0; dim],
            dim,
            alpha,
            count: 0,
        })
    }
}
#[cfg(test)]
#[allow(clippy::float_cmp)] // exact comparisons are intentional (mirrored / zeroed values)
mod tests {
    use super::*;

    /// Estimator with a 50-observation half-life. `halflife` is only compiled
    /// with the `std` or `libm` feature, so this helper is gated the same way.
    #[cfg(any(feature = "std", feature = "libm"))]
    fn basic_cov(dim: usize) -> OnlineCovarianceF64 {
        OnlineCovarianceF64::builder()
            .dim(dim)
            .halflife(50.0)
            .build()
            .unwrap()
    }

    /// Estimator configured through `alpha` directly, so it is available in
    /// every feature configuration.
    fn fixed_alpha_cov(dim: usize) -> OnlineCovarianceF64 {
        OnlineCovarianceF64::builder()
            .dim(dim)
            .alpha(0.1)
            .build()
            .unwrap()
    }

    /// `true` when `a` and `b` differ by less than `1e-12`. This avoids
    /// `f64::abs`, which may be unavailable without `std`.
    fn close(a: f64, b: f64) -> bool {
        a - b < 1e-12 && b - a < 1e-12
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn uncorrelated_2d() {
        let mut cov = basic_cov(2);
        for i in 0..200 {
            let x = (i % 10) as f64;
            let y = ((i * 7) % 13) as f64;
            cov.update(&[x, y]).unwrap();
        }
        assert!(cov.variance(0).unwrap() > 0.0);
        assert!(cov.variance(1).unwrap() > 0.0);
        let corr = cov.correlation(0, 1).unwrap().abs();
        assert!(corr < 0.5, "expected low correlation, got {corr}");
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn perfectly_correlated_2d() {
        let mut cov = basic_cov(2);
        for i in 0..200 {
            let x = i as f64;
            cov.update(&[x, x * 2.0 + 1.0]).unwrap();
        }
        let corr = cov.correlation(0, 1).unwrap();
        assert!(corr > 0.95, "expected high correlation, got {corr}");
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn halflife_of_one_gives_alpha_half() {
        let cov = OnlineCovarianceF64::builder()
            .dim(1)
            .halflife(1.0)
            .build()
            .unwrap();
        // alpha = 1 - exp(-ln 2) = 1/2.
        assert!(close(cov.alpha, 0.5));
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn correlation_requires_priming() {
        let mut cov = fixed_alpha_cov(2);
        assert!(cov.correlation(0, 1).is_none());
        cov.update(&[1.0, 2.0]).unwrap();
        assert!(cov.correlation(0, 1).is_none());
    }

    #[test]
    fn symmetry() {
        let mut cov = fixed_alpha_cov(3);
        for i in 0..100 {
            let v = [i as f64, (i * 2) as f64, (i * 3) as f64];
            cov.update(&v).unwrap();
        }
        // Off-diagonal entries are mirrored by copy, so symmetry is exact.
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(
                    cov.covariance(i, j),
                    cov.covariance(j, i),
                    "cov({i},{j}) != cov({j},{i})"
                );
            }
        }
    }

    #[test]
    #[should_panic(expected = "observation dimension")]
    fn wrong_dimension_panics() {
        let mut cov = fixed_alpha_cov(3);
        let _ = cov.update(&[1.0, 2.0]);
    }

    #[test]
    fn priming() {
        let mut cov = fixed_alpha_cov(2);
        assert!(!cov.is_primed());
        assert!(cov.covariance(0, 1).is_none());
        assert!(cov.variance(0).is_none());
        assert!(cov.mean(0).is_none());
        cov.update(&[1.0, 2.0]).unwrap();
        assert!(!cov.is_primed());
        assert!(cov.covariance(0, 1).is_none());
        assert!(cov.mean(0).is_some());
        cov.update(&[3.0, 4.0]).unwrap();
        assert!(cov.is_primed());
        assert!(cov.covariance(0, 1).is_some());
    }

    #[test]
    fn second_update_matches_closed_form() {
        let mut cov = fixed_alpha_cov(2);
        cov.update(&[1.0, 2.0]).unwrap();
        cov.update(&[3.0, 4.0]).unwrap();
        // delta = [2, 2]; mean = old + 0.1 * delta; cov = 0.1 * 0.9 * delta * delta^T.
        assert!(close(cov.mean(0).unwrap(), 1.2));
        assert!(close(cov.mean(1).unwrap(), 2.2));
        for i in 0..2 {
            for j in 0..2 {
                assert!(close(cov.covariance(i, j).unwrap(), 0.36));
            }
        }
    }

    #[test]
    fn reset_clears() {
        let mut cov = fixed_alpha_cov(2);
        for i in 0..50 {
            cov.update(&[i as f64, i as f64]).unwrap();
        }
        cov.reset();
        assert_eq!(cov.count(), 0);
        assert!(!cov.is_primed());
        assert!(cov.mean(0).is_none());
        assert!(cov.covariance(0, 1).is_none());
        assert!(cov.as_matrix().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn invalid_config() {
        // Missing dimension.
        assert!(OnlineCovarianceF64::builder().alpha(0.1).build().is_err());
        // Missing smoothing factor.
        assert!(OnlineCovarianceF64::builder().dim(2).build().is_err());
        // Zero dimension.
        assert!(OnlineCovarianceF64::builder().dim(0).alpha(0.1).build().is_err());
        // alpha outside the open interval (0, 1), including NaN.
        for bad in [0.0, 1.0, -0.5, 1.5, f64::NAN] {
            assert!(OnlineCovarianceF64::builder().dim(2).alpha(bad).build().is_err());
        }
    }
}