#![allow(clippy::float_cmp)]
extern crate alloc;
use alloc::boxed::Box;
use alloc::vec;
macro_rules! impl_cross_correlation {
($name:ident, $builder:ident, $ty:ty) => {
#[doc = concat!("use nexus_stats_detection::signal::", stringify!($name), ";")]
#[doc = concat!("let mut cc = ", stringify!($name), "::builder().lag(10).build().unwrap();")]
#[doc = concat!(" cc.update(a[i] as ", stringify!($ty), ", b as ", stringify!($ty), ").unwrap();")]
#[derive(Debug, Clone)]
pub struct $name {
buffer_a: Box<[$ty]>,
cross_m: Box<[$ty]>,
lag: usize,
head: usize,
count: u64,
mean_a: $ty,
mean_b: $ty,
m2_a: $ty,
m2_b: $ty,
}
#[doc = stringify!($name)]
#[derive(Debug, Clone)]
pub struct $builder {
lag: Option<usize>,
}
impl $name {
#[inline]
#[must_use]
pub fn builder() -> $builder {
$builder {
lag: Option::None,
}
}
#[inline]
pub fn update(&mut self, a: $ty, b: $ty) -> Result<(), nexus_stats_core::DataError> {
check_finite!(a);
check_finite!(b);
self.count += 1;
let n = self.count as $ty;
let lag = self.lag;
let da_old = a - self.mean_a;
let db_old = b - self.mean_b;
self.mean_a += da_old / n;
let da_new = a - self.mean_a;
self.m2_a += da_old * da_new;
self.mean_b += db_old / n;
let db_new = b - self.mean_b;
self.m2_b += db_old * db_new;
self.cross_m[0] += da_old * db_new;
if self.count > 1 {
let filled = (self.count - 1).min(lag as u64) as usize;
for k in 1..filled.min(lag) {
let idx = (self.head + lag - k) % lag;
let a_lagged = self.buffer_a[idx];
self.cross_m[k] +=
(a_lagged - self.mean_a) * db_new;
}
}
self.buffer_a[self.head] = a;
self.head = (self.head + 1) % lag;
Ok(())
}
#[cfg(any(feature = "std", feature = "libm"))]
#[inline]
#[must_use]
pub fn correlation(&self, lag: usize) -> Option<$ty> {
if lag >= self.lag {
return Option::None;
}
if self.count < (lag as u64 + 2) {
return Option::None;
}
let var_product = self.m2_a * self.m2_b;
if var_product <= 0.0 as $ty {
return Option::None;
}
let scale = if lag == 0 {
1.0 as $ty
} else {
let n_pairs = (self.count - lag as u64) as $ty;
let n_samples = (self.count - 1) as $ty;
n_samples / n_pairs
};
#[allow(clippy::cast_possible_truncation)]
let denom = nexus_stats_core::math::sqrt(var_product as f64) as $ty;
Option::Some(self.cross_m[lag] * scale / denom)
}
#[cfg(any(feature = "std", feature = "libm"))]
#[inline]
#[must_use]
pub fn peak_lag(&self) -> Option<usize> {
if self.count < 2 {
return Option::None;
}
let var_product = self.m2_a * self.m2_b;
if var_product <= 0.0 as $ty {
return Option::None;
}
let mut best_lag = 0;
let mut best_abs = 0.0 as $ty;
let max_lag = (self.count - 1).min(self.lag as u64) as usize;
let n_samples = (self.count - 1) as $ty;
for k in 0..max_lag {
let normalized = if k == 0 {
self.cross_m[k]
} else {
let n_pairs = (self.count - k as u64) as $ty;
self.cross_m[k] * n_samples / n_pairs
};
let abs_cm = if normalized < 0.0 as $ty {
-normalized
} else {
normalized
};
if abs_cm > best_abs {
best_abs = abs_cm;
best_lag = k;
}
}
Option::Some(best_lag)
}
#[inline]
#[must_use]
pub fn covariance(&self, lag: usize) -> Option<$ty> {
if lag >= self.lag {
return Option::None;
}
if self.count < (lag as u64 + 2) {
return Option::None;
}
let n_pairs = (self.count - lag as u64) as $ty;
Option::Some(self.cross_m[lag] / n_pairs)
}
#[inline]
#[must_use]
pub fn lag(&self) -> usize {
self.lag
}
#[inline]
#[must_use]
pub fn count(&self) -> u64 {
self.count
}
#[inline]
#[must_use]
pub fn is_primed(&self) -> bool {
self.count > self.lag as u64
}
#[inline]
pub fn reset(&mut self) {
self.buffer_a.fill(0.0 as $ty);
self.cross_m.fill(0.0 as $ty);
self.head = 0;
self.count = 0;
self.mean_a = 0.0 as $ty;
self.mean_b = 0.0 as $ty;
self.m2_a = 0.0 as $ty;
self.m2_b = 0.0 as $ty;
}
}
impl $builder {
#[inline]
#[must_use]
pub fn lag(mut self, lag: usize) -> Self {
self.lag = Option::Some(lag);
self
}
#[inline]
pub fn build(self) -> Result<$name, nexus_stats_core::ConfigError> {
let lag = self.lag.ok_or(nexus_stats_core::ConfigError::Missing("lag"))?;
if lag < 1 {
return Err(nexus_stats_core::ConfigError::Invalid("lag must be >= 1"));
}
Ok($name {
buffer_a: vec![0.0 as $ty; lag].into_boxed_slice(),
cross_m: vec![0.0 as $ty; lag].into_boxed_slice(),
lag,
head: 0,
count: 0,
mean_a: 0.0 as $ty,
mean_b: 0.0 as $ty,
m2_a: 0.0 as $ty,
m2_b: 0.0 as $ty,
})
}
}
};
}
impl_cross_correlation!(CrossCorrelationF64, CrossCorrelationF64Builder, f64);
impl_cross_correlation!(CrossCorrelationF32, CrossCorrelationF32Builder, f32);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn identical_streams_correlation_one() {
let mut cc = CrossCorrelationF64::builder().lag(1).build().unwrap();
for i in 0..1000u64 {
let x = i as f64;
cc.update(x, x).unwrap();
}
let r = cc.correlation(0).unwrap();
assert!(
(r - 1.0).abs() < 1e-6,
"identical streams should correlate at 1.0, got {r}"
);
}
#[test]
fn opposite_streams_correlation_negative() {
let mut cc = CrossCorrelationF64::builder().lag(1).build().unwrap();
for i in 0..1000u64 {
let x = i as f64;
cc.update(x, -x).unwrap();
}
let r = cc.correlation(0).unwrap();
assert!(
(r - (-1.0)).abs() < 1e-6,
"opposite streams should correlate at -1.0, got {r}"
);
}
#[test]
fn shifted_signal_peak_lag() {
let mut cc = CrossCorrelationF64::builder().lag(10).build().unwrap();
let shift = 3;
let a: Vec<f64> = (0..1000).map(|i| ((i as f64) * 0.1).sin()).collect();
for i in 0..1000 {
let b = if i >= shift { a[i - shift] } else { 0.0 };
cc.update(a[i], b).unwrap();
}
let peak = cc.peak_lag().unwrap();
assert!(
(peak as i32 - shift as i32).unsigned_abs() <= 1,
"peak lag should be near {shift}, got {peak}"
);
}
#[test]
fn lag0_matches_covariance_type() {
let mut cc = CrossCorrelationF64::builder().lag(1).build().unwrap();
let mut cov = nexus_stats_core::statistics::CovarianceF64::new();
for i in 0..500u64 {
let x = i as f64;
let y = x * 2.0 + 1.0;
cc.update(x, y).unwrap();
let _ = cov.update(x, y);
}
let r_cc = cc.correlation(0).unwrap();
let r_cov = cov.correlation().unwrap();
assert!(
(r_cc - r_cov).abs() < 0.01,
"lag-0 cross-correlation ({r_cc}) should match covariance correlation ({r_cov})"
);
}
#[test]
fn not_primed_until_enough_samples() {
let mut cc = CrossCorrelationF64::builder().lag(5).build().unwrap();
for i in 0..5 {
cc.update(i as f64, i as f64).unwrap();
assert!(!cc.is_primed());
}
cc.update(5.0, 5.0).unwrap();
assert!(cc.is_primed());
}
#[test]
fn lag_out_of_range_returns_none() {
let mut cc = CrossCorrelationF64::builder().lag(5).build().unwrap();
for i in 0..20 {
cc.update(i as f64, i as f64).unwrap();
}
assert!(cc.correlation(5).is_none()); assert!(cc.covariance(5).is_none());
}
#[test]
fn zero_variance_returns_none() {
let mut cc = CrossCorrelationF64::builder().lag(1).build().unwrap();
for _ in 0..100 {
cc.update(42.0, 42.0).unwrap();
}
assert!(cc.correlation(0).is_none());
}
#[test]
fn reset_clears_state() {
let mut cc = CrossCorrelationF64::builder().lag(3).build().unwrap();
for i in 0..100 {
cc.update(i as f64, (i * 2) as f64).unwrap();
}
cc.reset();
assert_eq!(cc.count(), 0);
assert!(!cc.is_primed());
}
#[test]
fn lag_accessor() {
let cc = CrossCorrelationF64::builder().lag(7).build().unwrap();
assert_eq!(cc.lag(), 7);
}
#[test]
fn f32_basic() {
let mut cc = CrossCorrelationF32::builder().lag(1).build().unwrap();
for i in 0..200u32 {
cc.update(i as f32, (i * 2) as f32).unwrap();
}
let r = cc.correlation(0).unwrap();
assert!(r > 0.9, "f32 perfect linear should be near 1.0, got {r}");
}
#[test]
fn rejects_nan_and_inf() {
let mut cc = CrossCorrelationF64::builder().lag(1).build().unwrap();
assert_eq!(
cc.update(f64::NAN, 1.0),
Err(nexus_stats_core::DataError::NotANumber)
);
assert_eq!(
cc.update(1.0, f64::INFINITY),
Err(nexus_stats_core::DataError::Infinite)
);
assert_eq!(cc.count(), 0);
}
#[test]
fn builder_requires_lag() {
let result = CrossCorrelationF64::builder().build();
assert!(matches!(
result,
Err(nexus_stats_core::ConfigError::Missing("lag"))
));
}
#[test]
fn builder_rejects_zero_lag() {
let result = CrossCorrelationF64::builder().lag(0).build();
assert!(result.is_err());
}
}