#![allow(clippy::float_cmp)]
extern crate alloc;
use alloc::boxed::Box;
use alloc::vec;
// Generates a streaming lag-k autocorrelation estimator whose samples and
// internal accumulators share one float type.
//
// * `$name`    - estimator struct to define.
// * `$builder` - builder struct to define for `$name`.
// * `$ty`      - float type (`f32` / `f64`) used for samples and all state.
//
// The expansion site must have `Box`, `vec!`, a `check_finite!` macro and the
// path `nexus_stats_core::{ConfigError, DataError}` in scope.
macro_rules! impl_autocorrelation_float {
    ($name:ident, $builder:ident, $ty:ty) => {
        #[doc = concat!("Streaming lag-`k` autocorrelation estimator over `", stringify!($ty), "` samples.")]
        #[doc = ""]
        #[doc = "Maintains a running mean and sum of squared deviations (Welford's method), a"]
        #[doc = "running sum of lagged cross-deviations, and a ring buffer of the last `lag`"]
        #[doc = "samples. Each update does a constant amount of work; memory is `O(lag)`."]
        #[doc = ""]
        #[doc = "The cross term is evaluated against the running mean at the time each sample"]
        #[doc = "arrives, so results are an online approximation of the batch sample"]
        #[doc = "autocorrelation rather than an exact reproduction of it."]
        #[doc = ""]
        #[doc = "# Examples"]
        #[doc = ""]
        #[doc = "```"]
        #[doc = concat!("use nexus_stats_detection::signal::", stringify!($name), ";")]
        #[doc = ""]
        #[doc = concat!("let mut ac = ", stringify!($name), "::builder().lag(1).build().unwrap();")]
        #[doc = "for i in 0..200 {"]
        #[doc = concat!("    ac.update(if i % 2 == 0 { 1.0 as ", stringify!($ty), " } else { -1.0 as ", stringify!($ty), " }).unwrap();")]
        #[doc = "}"]
        #[doc = "// An alternating signal is strongly anti-correlated with itself at lag 1."]
        #[doc = "assert!(ac.correlation().unwrap() < -0.9);"]
        #[doc = "```"]
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Ring buffer of the most recent `lag` samples.
            buffer: Box<[$ty]>,
            /// Lag `k`; the builder guarantees `lag >= 1`.
            lag: usize,
            /// Slot the next sample is written to. Once `count >= lag` it holds the
            /// sample received `lag` steps before that next sample.
            head: usize,
            /// Number of samples accepted so far.
            count: u64,
            /// Running mean of all accepted samples.
            mean: $ty,
            /// Running sum of squared deviations from the mean (Welford's M2).
            m2: $ty,
            /// Running sum of `(x_t - mean) * (x_{t-lag} - mean)` products.
            cross_m: $ty,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`]; the `lag` setting is required.")]
        #[derive(Debug, Clone)]
        pub struct $builder {
            lag: Option<usize>,
        }

        impl $name {
            /// Returns a builder with no lag configured yet.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    lag: Option::None,
                }
            }

            /// Adds one sample to the estimator.
            ///
            /// # Errors
            ///
            /// Non-finite samples are rejected through `check_finite!` with a
            /// `nexus_stats_core::DataError` (e.g. `NotANumber` for NaN, `Infinite`
            /// for infinity) before any state is modified.
            #[inline]
            pub fn update(&mut self, sample: $ty) -> Result<(), nexus_stats_core::DataError> {
                check_finite!(sample);
                self.count += 1;
                // Welford update of the running mean and M2.
                let delta = sample - self.mean;
                self.mean += delta / self.count as $ty;
                let delta2 = sample - self.mean;
                self.m2 += delta * delta2;
                // Only after `lag` earlier samples does `buffer[head]` hold x_{t-lag}.
                if self.count > self.lag as u64 {
                    let x_lagged = self.buffer[self.head];
                    self.cross_m += (sample - self.mean) * (x_lagged - self.mean);
                }
                self.buffer[self.head] = sample;
                self.head = (self.head + 1) % self.lag;
                Ok(())
            }

            /// Lag-`k` autocorrelation estimate, clamped to `[-1, 1]`.
            ///
            /// Computed as `(cross_m / (n - lag)) / (m2 / (n - 1))`: the lagged
            /// covariance divided by the sample variance. The two terms use different
            /// normalisations, and the cross term uses a running mean. As a result, the raw
            /// ratio can overshoot ±1 on short or oddly shaped inputs. It is clamped so
            /// callers always receive a valid correlation coefficient.
            ///
            /// Returns `None` until [`is_primed`](Self::is_primed), or when the observed
            /// variance is exactly zero (e.g. constant input).
            #[inline]
            #[must_use]
            pub fn correlation(&self) -> Option<$ty> {
                if !self.is_primed() || self.m2 == 0.0 as $ty {
                    return Option::None;
                }
                let n_pairs = (self.count - self.lag as u64) as $ty;
                let n_samples = (self.count - 1) as $ty;
                let raw = self.cross_m * n_samples / (self.m2 * n_pairs);
                Option::Some(raw.clamp(-1.0 as $ty, 1.0 as $ty))
            }

            /// Lag-`k` autocovariance estimate `cross_m / (n - lag)` (not clamped).
            ///
            /// Returns `None` until [`is_primed`](Self::is_primed). Unlike
            /// [`correlation`](Self::correlation), a zero variance still yields a value.
            #[inline]
            #[must_use]
            pub fn covariance(&self) -> Option<$ty> {
                if !self.is_primed() {
                    return Option::None;
                }
                let n_pairs = (self.count - self.lag as u64) as $ty;
                Option::Some(self.cross_m / n_pairs)
            }

            /// The configured lag `k`.
            #[inline]
            #[must_use]
            pub fn lag(&self) -> usize {
                self.lag
            }

            /// Number of samples accepted so far (rejected samples are not counted).
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// `true` once at least `lag + 2` samples (two lagged pairs) have been seen,
            /// i.e. once the estimates may return a value.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.count >= self.lag as u64 + 2
            }

            /// Clears all accumulated state, keeping the configured lag.
            #[inline]
            pub fn reset(&mut self) {
                self.buffer.fill(0.0 as $ty);
                self.head = 0;
                self.count = 0;
                self.mean = 0.0 as $ty;
                self.m2 = 0.0 as $ty;
                self.cross_m = 0.0 as $ty;
            }
        }

        impl $builder {
            /// Sets the lag `k` (must be at least 1).
            #[inline]
            #[must_use]
            pub fn lag(mut self, lag: usize) -> Self {
                self.lag = Option::Some(lag);
                self
            }

            /// Builds the estimator.
            ///
            /// # Errors
            ///
            /// `ConfigError::Missing("lag")` if no lag was set, and
            /// `ConfigError::Invalid` if the lag is zero.
            #[inline]
            pub fn build(self) -> Result<$name, nexus_stats_core::ConfigError> {
                let lag = self.lag.ok_or(nexus_stats_core::ConfigError::Missing("lag"))?;
                if lag < 1 {
                    return Err(nexus_stats_core::ConfigError::Invalid("lag must be >= 1"));
                }
                Ok($name {
                    buffer: vec![0.0 as $ty; lag].into_boxed_slice(),
                    lag,
                    head: 0,
                    count: 0,
                    mean: 0.0 as $ty,
                    m2: 0.0 as $ty,
                    cross_m: 0.0 as $ty,
                })
            }
        }
    };
}
// Generates a streaming lag-k autocorrelation estimator for an integer sample
// type. Samples are converted to `f64` and all state is kept in `f64`.
//
// * `$name`    - estimator struct to define.
// * `$builder` - builder struct to define for `$name`.
// * `$input`   - integer sample type (e.g. `i32` / `i64`).
//
// The expansion site must have `Box`, `vec!` and the path
// `nexus_stats_core::ConfigError` in scope.
macro_rules! impl_autocorrelation_int {
    ($name:ident, $builder:ident, $input:ty) => {
        #[doc = concat!("Streaming lag-`k` autocorrelation estimator over `", stringify!($input), "` samples.")]
        #[doc = ""]
        #[doc = "Samples are converted to `f64`. The estimator keeps a running mean and sum of"]
        #[doc = "squared deviations (Welford's method), a running sum of lagged cross-deviations,"]
        #[doc = "and a ring buffer of the last `lag` samples. Each update does a constant amount"]
        #[doc = "of work; memory is `O(lag)`."]
        #[doc = ""]
        #[doc = "The cross term is evaluated against the running mean at the time each sample"]
        #[doc = "arrives, so results are an online approximation of the batch sample"]
        #[doc = "autocorrelation rather than an exact reproduction of it."]
        #[doc = ""]
        #[doc = "# Examples"]
        #[doc = ""]
        #[doc = "```"]
        #[doc = concat!("use nexus_stats_detection::signal::", stringify!($name), ";")]
        #[doc = ""]
        #[doc = concat!("let mut ac = ", stringify!($name), "::builder().lag(1).build().unwrap();")]
        #[doc = "for i in 0..200 {"]
        #[doc = concat!("    ac.update(if i % 2 == 0 { 1 as ", stringify!($input), " } else { -1 as ", stringify!($input), " });")]
        #[doc = "}"]
        #[doc = "// An alternating signal is strongly anti-correlated with itself at lag 1."]
        #[doc = "assert!(ac.correlation().unwrap() < -0.9);"]
        #[doc = "```"]
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Ring buffer of the most recent `lag` samples (already converted to `f64`).
            buffer: Box<[f64]>,
            /// Lag `k`; the builder guarantees `lag >= 1`.
            lag: usize,
            /// Slot the next sample is written to. Once `count >= lag` it holds the
            /// sample received `lag` steps before that next sample.
            head: usize,
            /// Number of samples accepted so far.
            count: u64,
            /// Running mean of all accepted samples.
            mean: f64,
            /// Running sum of squared deviations from the mean (Welford's M2).
            m2: f64,
            /// Running sum of `(x_t - mean) * (x_{t-lag} - mean)` products.
            cross_m: f64,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`]; the `lag` setting is required.")]
        #[derive(Debug, Clone)]
        pub struct $builder {
            lag: Option<usize>,
        }

        impl $name {
            /// Returns a builder with no lag configured yet.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    lag: Option::None,
                }
            }

            /// Adds one sample to the estimator.
            ///
            /// Infallible: integer samples are always finite. Magnitudes beyond 2^53
            /// are rounded by the conversion to `f64`.
            #[inline]
            pub fn update(&mut self, sample: $input) {
                // i32 -> f64 is exact (cast_lossless); i64 -> f64 may round (cast_precision_loss).
                #[allow(clippy::cast_lossless, clippy::cast_precision_loss)]
                let x = sample as f64;
                self.count += 1;
                // Welford update of the running mean and M2.
                let delta = x - self.mean;
                self.mean += delta / self.count as f64;
                let delta2 = x - self.mean;
                self.m2 += delta * delta2;
                // Only after `lag` earlier samples does `buffer[head]` hold x_{t-lag}.
                if self.count > self.lag as u64 {
                    let x_lagged = self.buffer[self.head];
                    self.cross_m += (x - self.mean) * (x_lagged - self.mean);
                }
                self.buffer[self.head] = x;
                self.head = (self.head + 1) % self.lag;
            }

            /// Lag-`k` autocorrelation estimate, clamped to `[-1, 1]`.
            ///
            /// Computed as `(cross_m / (n - lag)) / (m2 / (n - 1))`: the lagged
            /// covariance divided by the sample variance. The two terms use different
            /// normalisations, and the cross term uses a running mean. As a result, the raw
            /// ratio can overshoot ±1 on short or oddly shaped inputs. It is clamped so
            /// callers always receive a valid correlation coefficient.
            ///
            /// Returns `None` until [`is_primed`](Self::is_primed), or when the observed
            /// variance is exactly zero (e.g. constant input).
            #[inline]
            #[must_use]
            pub fn correlation(&self) -> Option<f64> {
                if !self.is_primed() || self.m2 == 0.0 {
                    return Option::None;
                }
                let n_pairs = (self.count - self.lag as u64) as f64;
                let n_samples = (self.count - 1) as f64;
                let raw = self.cross_m * n_samples / (self.m2 * n_pairs);
                Option::Some(raw.clamp(-1.0, 1.0))
            }

            /// Lag-`k` autocovariance estimate `cross_m / (n - lag)` (not clamped).
            ///
            /// Returns `None` until [`is_primed`](Self::is_primed). Unlike
            /// [`correlation`](Self::correlation), a zero variance still yields a value.
            #[inline]
            #[must_use]
            pub fn covariance(&self) -> Option<f64> {
                if !self.is_primed() {
                    return Option::None;
                }
                let n_pairs = (self.count - self.lag as u64) as f64;
                Option::Some(self.cross_m / n_pairs)
            }

            /// The configured lag `k`.
            #[inline]
            #[must_use]
            pub fn lag(&self) -> usize {
                self.lag
            }

            /// Number of samples accepted so far.
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// `true` once at least `lag + 2` samples (two lagged pairs) have been seen,
            /// i.e. once the estimates may return a value.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.count >= self.lag as u64 + 2
            }

            /// Clears all accumulated state, keeping the configured lag.
            #[inline]
            pub fn reset(&mut self) {
                self.buffer.fill(0.0);
                self.head = 0;
                self.count = 0;
                self.mean = 0.0;
                self.m2 = 0.0;
                self.cross_m = 0.0;
            }
        }

        impl $builder {
            /// Sets the lag `k` (must be at least 1).
            #[inline]
            #[must_use]
            pub fn lag(mut self, lag: usize) -> Self {
                self.lag = Option::Some(lag);
                self
            }

            /// Builds the estimator.
            ///
            /// # Errors
            ///
            /// `ConfigError::Missing("lag")` if no lag was set, and
            /// `ConfigError::Invalid` if the lag is zero.
            #[inline]
            pub fn build(self) -> Result<$name, nexus_stats_core::ConfigError> {
                let lag = self.lag.ok_or(nexus_stats_core::ConfigError::Missing("lag"))?;
                if lag < 1 {
                    return Err(nexus_stats_core::ConfigError::Invalid("lag must be >= 1"));
                }
                Ok($name {
                    buffer: vec![0.0; lag].into_boxed_slice(),
                    lag,
                    head: 0,
                    count: 0,
                    mean: 0.0,
                    m2: 0.0,
                    cross_m: 0.0,
                })
            }
        }
    };
}
// Float-input estimators: each sample is checked with `check_finite!` and the
// state is accumulated in the sample's own precision.
impl_autocorrelation_float!(AutocorrelationF32, AutocorrelationF32Builder, f32);
impl_autocorrelation_float!(AutocorrelationF64, AutocorrelationF64Builder, f64);

// Integer-input estimators: each sample is converted to `f64` and the state is
// accumulated in `f64`.
impl_autocorrelation_int!(AutocorrelationI32, AutocorrelationI32Builder, i32);
impl_autocorrelation_int!(AutocorrelationI64, AutocorrelationI64Builder, i64);
#[cfg(test)]
mod tests {
    //! Behavioural tests for the generated autocorrelation estimators.
    use super::*;

    #[test]
    fn alternating_negative_lag1() {
        let mut ac = AutocorrelationF64::builder().lag(1).build().unwrap();
        for i in 0..1000u64 {
            ac.update(if i % 2 == 0 { 1.0 } else { -1.0 }).unwrap();
        }
        let r = ac.correlation().unwrap();
        assert!(r < -0.9, "alternating should be strongly negative, got {r}");
    }

    #[test]
    fn trending_positive_lag1() {
        let mut ac = AutocorrelationF64::builder().lag(1).build().unwrap();
        for i in 0..1000u64 {
            ac.update(i as f64).unwrap();
        }
        let r = ac.correlation().unwrap();
        assert!(
            r > 0.9,
            "monotone trend should have positive lag-1, got {r}"
        );
    }

    #[test]
    fn lag10_periodic() {
        let mut ac = AutocorrelationF64::builder().lag(10).build().unwrap();
        for i in 0..2000u64 {
            ac.update((i % 10) as f64).unwrap();
        }
        let r = ac.correlation().unwrap();
        assert!(
            r > 0.8,
            "period-10 signal should correlate at lag 10, got {r}"
        );
    }

    /// With lag 2, an alternating signal lines up with itself, so the ring buffer
    /// must hand back the sample from two steps ago (not one).
    #[test]
    fn ring_buffer_respects_lag() {
        let mut ac = AutocorrelationF64::builder().lag(2).build().unwrap();
        for i in 0..1000u32 {
            ac.update(if i % 2 == 0 { 1.0 } else { -1.0 }).unwrap();
        }
        let r = ac.correlation().unwrap();
        assert!(r > 0.9, "alternating signal at lag 2 should be positive, got {r}");
    }

    #[test]
    fn constant_input_zero_variance() {
        let mut ac = AutocorrelationF64::builder().lag(1).build().unwrap();
        for _ in 0..100 {
            ac.update(42.0).unwrap();
        }
        assert!(ac.correlation().is_none());
    }

    #[test]
    fn not_primed_until_lag_plus_2() {
        let mut ac = AutocorrelationF64::builder().lag(5).build().unwrap();
        for i in 0..6 {
            ac.update(i as f64).unwrap();
            assert!(!ac.is_primed(), "should not be primed at count {}", i + 1);
        }
        ac.update(6.0).unwrap();
        assert!(ac.is_primed(), "should be primed at count 7 (lag+2)");
    }

    #[test]
    fn estimates_none_until_primed() {
        let mut ac = AutocorrelationF64::builder().lag(3).build().unwrap();
        for i in 0..4u32 {
            ac.update(f64::from(i)).unwrap();
            assert!(ac.correlation().is_none());
            assert!(ac.covariance().is_none());
        }
        ac.update(4.0).unwrap();
        assert!(ac.correlation().is_some());
        assert!(ac.covariance().is_some());
    }

    #[test]
    fn covariance_sign_matches_correlation() {
        let mut ac = AutocorrelationF64::builder().lag(1).build().unwrap();
        for i in 0..500u64 {
            ac.update(i as f64).unwrap();
        }
        let corr = ac.correlation().unwrap();
        let cov = ac.covariance().unwrap();
        assert!(
            corr.signum() == cov.signum(),
            "corr={corr}, cov={cov} — signs should match"
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut ac = AutocorrelationF64::builder().lag(1).build().unwrap();
        for i in 0..100 {
            ac.update(i as f64).unwrap();
        }
        ac.reset();
        assert_eq!(ac.count(), 0);
        assert!(!ac.is_primed());
        assert!(ac.correlation().is_none());
    }

    /// After `reset`, an estimator must behave exactly like a freshly built one.
    #[test]
    fn reset_then_reuse_matches_fresh() {
        let mut reused = AutocorrelationF64::builder().lag(3).build().unwrap();
        for i in 0..50u32 {
            reused.update(f64::from(i * i)).unwrap();
        }
        reused.reset();
        let mut fresh = AutocorrelationF64::builder().lag(3).build().unwrap();
        for i in 0..100u32 {
            let x = f64::from((i * 7) % 11);
            reused.update(x).unwrap();
            fresh.update(x).unwrap();
        }
        assert_eq!(reused.count(), fresh.count());
        assert_eq!(reused.correlation(), fresh.correlation());
        assert_eq!(reused.covariance(), fresh.covariance());
    }

    #[test]
    fn lag_accessor() {
        let ac = AutocorrelationF64::builder().lag(7).build().unwrap();
        assert_eq!(ac.lag(), 7);
    }

    #[test]
    fn i64_alternating() {
        let mut ac = AutocorrelationI64::builder().lag(1).build().unwrap();
        for i in 0..1000i64 {
            ac.update(if i % 2 == 0 { 100 } else { -100 });
        }
        let r = ac.correlation().unwrap();
        assert!(r < -0.9, "i64 alternating got {r}");
    }

    #[test]
    fn i32_trending() {
        let mut ac = AutocorrelationI32::builder().lag(1).build().unwrap();
        for i in 0..500i32 {
            ac.update(i);
        }
        let r = ac.correlation().unwrap();
        assert!(r > 0.9, "i32 trending got {r}");
    }

    /// The integer path converts to f64 and then runs the same arithmetic as the
    /// f64 estimator, so both must produce identical results.
    #[test]
    fn i64_matches_f64_path() {
        let mut ai = AutocorrelationI64::builder().lag(2).build().unwrap();
        let mut af = AutocorrelationF64::builder().lag(2).build().unwrap();
        for i in 0..300i64 {
            let v = (i * 7) % 13 - 6;
            ai.update(v);
            af.update(v as f64).unwrap();
        }
        assert_eq!(ai.correlation(), af.correlation());
        assert_eq!(ai.covariance(), af.covariance());
    }

    #[test]
    fn int_reset_clears_state() {
        let mut ac = AutocorrelationI32::builder().lag(2).build().unwrap();
        for i in 0..50 {
            ac.update(i);
        }
        ac.reset();
        assert_eq!(ac.count(), 0);
        assert!(!ac.is_primed());
        assert!(ac.correlation().is_none());
        assert!(ac.covariance().is_none());
    }

    #[test]
    fn int_builder_validation() {
        assert!(matches!(
            AutocorrelationI32::builder().build(),
            Err(nexus_stats_core::ConfigError::Missing("lag"))
        ));
        assert!(matches!(
            AutocorrelationI64::builder().lag(0).build(),
            Err(nexus_stats_core::ConfigError::Invalid(_))
        ));
    }

    #[test]
    fn f32_basic() {
        let mut ac = AutocorrelationF32::builder().lag(1).build().unwrap();
        for i in 0..200u32 {
            ac.update(i as f32).unwrap();
        }
        let r = ac.correlation().expect("primed after 200 samples");
        assert!(r > 0.9, "f32 trending got {r}");
    }

    #[test]
    fn rejects_nan_and_inf() {
        let mut ac = AutocorrelationF64::builder().lag(1).build().unwrap();
        assert_eq!(
            ac.update(f64::NAN),
            Err(nexus_stats_core::DataError::NotANumber)
        );
        assert_eq!(
            ac.update(f64::INFINITY),
            Err(nexus_stats_core::DataError::Infinite)
        );
        assert_eq!(ac.count(), 0);
    }

    #[test]
    fn builder_requires_lag() {
        let result = AutocorrelationF64::builder().build();
        assert!(matches!(
            result,
            Err(nexus_stats_core::ConfigError::Missing("lag"))
        ));
    }

    #[test]
    fn builder_rejects_zero_lag() {
        let result = AutocorrelationF64::builder().lag(0).build();
        assert!(matches!(
            result,
            Err(nexus_stats_core::ConfigError::Invalid(_))
        ));
    }
}