use nexus_stats_core::math::MulAdd;
/// Generates a streaming robust z-score detector type and its builder for one float type.
///
/// * `$name` — identifier of the detector struct to define.
/// * `$builder` — identifier of the builder struct to define.
/// * `$ty` — float type the detector computes in (invoked in this file with `f64` and `f32`).
macro_rules! impl_robust_z {
    ($name:ident, $builder:ident, $ty:ty) => {
        #[doc = concat!(
            "Streaming robust z-score detector over `",
            stringify!($ty),
            "` samples."
        )]
        ///
        /// The baseline is an exponential moving average (EMA) of the signal, and the
        /// dispersion is an EMA of absolute deviations from that baseline. Each sample is
        /// scored as `0.6745 * (sample - baseline) / dispersion` against the state *before*
        /// the sample is absorbed.
        ///
        /// Samples with `|z| > reject_threshold` are scored but not absorbed: both EMAs are
        /// left untouched for that sample, so isolated outliers cannot drag the baseline.
        /// A rejected sample also leaves the dispersion unchanged. As a result, a sustained
        /// level shift larger than the threshold keeps producing the same score and keeps
        /// being rejected.
        ///
        /// Construct with [`builder`](Self::builder).
        #[derive(Debug, Clone)]
        pub struct $name {
            // Smoothing factor in (0, 1): weight given to the newest accepted sample.
            alpha: $ty,
            // Cached `1 - alpha`.
            one_minus_alpha: $ty,
            // EMA of accepted samples (the baseline).
            ema: $ty,
            // EMA of accepted samples' absolute deviation from the pre-update baseline.
            ema_abs_dev: $ty,
            // Score of the most recent finite sample (accepted or rejected).
            last_z: $ty,
            // Samples whose `|z|` exceeds this are not absorbed into the EMAs.
            reject_threshold: $ty,
            // Finite samples fed since construction / reset, accepted or rejected.
            count: u64,
            // `count` required before scores and estimates are reported.
            min_samples: u64,
            // Whether a first sample has seeded the baseline.
            initialized: bool,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`].")]
        ///
        /// `alpha` (directly or via `span`) and `reject_threshold` are required;
        /// `min_samples` defaults to 10.
        #[derive(Debug, Clone)]
        pub struct $builder {
            alpha: Option<$ty>,
            reject_threshold: Option<$ty>,
            min_samples: u64,
        }

        impl $name {
            /// Scale factor from the modified z-score formula, `0.6745 ≈ Φ⁻¹(0.75)`.
            ///
            /// NOTE(review): that constant calibrates a *median* absolute deviation to a
            /// standard-normal z. Here the dispersion is an EMA of absolute deviations,
            /// which is closer to a mean absolute deviation, so scores are not exactly
            /// standard-normal scaled. Confirm this is intended.
            const MAD_CONSTANT: $ty = 0.6745 as $ty;

            /// Returns a builder with no `alpha` / `reject_threshold` set and
            /// `min_samples = 10`.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    alpha: Option::None,
                    reject_threshold: Option::None,
                    min_samples: 10,
                }
            }

            /// Wraps `value` in `Some` once the detector is primed, `None` before.
            #[inline]
            fn when_primed(&self, value: $ty) -> Option<$ty> {
                if self.is_primed() {
                    Option::Some(value)
                } else {
                    Option::None
                }
            }

            /// Feeds one sample and returns its robust z-score once primed.
            ///
            /// The first sample after construction or [`reset`](Self::reset) seeds the
            /// baseline and scores `0`. Every later sample is scored against the current
            /// baseline and dispersion, then absorbed into both EMAs only if
            /// `|z| <= reject_threshold`. While the dispersion is still zero (e.g. a
            /// perfectly flat signal so far), the score is `0` and the sample is accepted.
            ///
            /// Returns `Ok(None)` until `min_samples` finite samples have been fed
            /// (rejected samples count too), then `Ok(Some(z))`.
            ///
            /// # Errors
            ///
            /// Returns a `nexus_stats_core::DataError` if `sample` is NaN or infinite. The
            /// check happens before any state changes, so the sample count is unaffected.
            #[inline]
            pub fn update(
                &mut self,
                sample: $ty,
            ) -> Result<Option<$ty>, nexus_stats_core::DataError> {
                check_finite!(sample);
                self.count += 1;

                if !self.initialized {
                    // First sample: seed the baseline; there is no dispersion to score
                    // against yet.
                    self.ema = sample;
                    self.ema_abs_dev = 0.0 as $ty;
                    self.last_z = 0.0 as $ty;
                    self.initialized = true;
                    return Ok(self.when_primed(self.last_z));
                }

                // Measured against the baseline *before* this sample is absorbed.
                let deviation = sample - self.ema;

                // With zero dispersion there is no scale to measure against; score 0 so
                // the sample is accepted and the dispersion can start to grow.
                self.last_z = if self.ema_abs_dev > 0.0 as $ty {
                    Self::MAD_CONSTANT * deviation / self.ema_abs_dev
                } else {
                    0.0 as $ty
                };

                if self.last_z.abs() <= self.reject_threshold {
                    // Accepted: absorb into both EMAs.
                    // NOTE: `fma` comes from `nexus_stats_core::math::MulAdd`; assumed to
                    // compute `self * a + b`, i.e. `alpha * x + (1 - alpha) * ema`.
                    self.ema = self.alpha.fma(sample, self.one_minus_alpha * self.ema);
                    self.ema_abs_dev = self
                        .alpha
                        .fma(deviation.abs(), self.one_minus_alpha * self.ema_abs_dev);
                }

                Ok(self.when_primed(self.last_z))
            }

            /// Score of the most recent finite sample, or `None` until primed.
            #[inline]
            #[must_use]
            pub fn z_score(&self) -> Option<$ty> {
                self.when_primed(self.last_z)
            }

            /// Current baseline (EMA of accepted samples), or `None` until primed.
            #[inline]
            #[must_use]
            pub fn baseline(&self) -> Option<$ty> {
                self.when_primed(self.ema)
            }

            /// Current dispersion (EMA of accepted absolute deviations), or `None` until
            /// primed.
            #[inline]
            #[must_use]
            pub fn mad(&self) -> Option<$ty> {
                self.when_primed(self.ema_abs_dev)
            }

            /// Number of finite samples fed since construction or the last reset.
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// Whether at least `min_samples` finite samples have been fed.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.count >= self.min_samples
            }

            /// Clears all accumulated state, keeping the configuration (`alpha`,
            /// `reject_threshold`, `min_samples`).
            #[inline]
            pub fn reset(&mut self) {
                self.ema = 0.0 as $ty;
                self.ema_abs_dev = 0.0 as $ty;
                self.last_z = 0.0 as $ty;
                self.count = 0;
                self.initialized = false;
            }
        }

        impl $builder {
            /// Sets the EMA smoothing factor. It must lie strictly inside `(0, 1)` or
            /// [`build`](Self::build) fails. Overrides any earlier `alpha` / `span` call.
            #[inline]
            #[must_use]
            pub fn alpha(mut self, alpha: $ty) -> Self {
                self.alpha = Option::Some(alpha);
                self
            }

            /// Sets the smoothing factor from an EMA span: `alpha = 2 / (n + 1)`.
            ///
            /// `n` must be at least 2 for [`build`](Self::build) to accept the result
            /// (`n = 1` gives `alpha = 1`, `n = 0` gives `alpha = 2`). Overrides any earlier
            /// `alpha` / `span` call.
            #[inline]
            #[must_use]
            pub fn span(mut self, n: u64) -> Self {
                self.alpha = Option::Some(2.0 as $ty / (n as $ty + 1.0 as $ty));
                self
            }

            /// Sets the `|z|` above which samples are scored but not absorbed into the
            /// estimators. Must be positive and not NaN or [`build`](Self::build) fails.
            #[inline]
            #[must_use]
            pub fn reject_threshold(mut self, z: $ty) -> Self {
                self.reject_threshold = Option::Some(z);
                self
            }

            /// Sets how many finite samples must be fed before scores and estimates are
            /// reported. Defaults to 10.
            #[inline]
            #[must_use]
            pub fn min_samples(mut self, min: u64) -> Self {
                self.min_samples = min;
                self
            }

            /// Validates the configuration and builds a fresh detector.
            ///
            /// # Errors
            ///
            /// * `ConfigError::Missing("alpha")` if neither `alpha` nor `span` was set.
            /// * `ConfigError::Missing("reject_threshold")` if no threshold was set.
            /// * `ConfigError::Invalid(..)` if `alpha` is not in `(0, 1)` (including NaN),
            ///   or if the threshold is NaN or not positive.
            #[inline]
            pub fn build(self) -> Result<$name, nexus_stats_core::ConfigError> {
                let alpha = self
                    .alpha
                    .ok_or(nexus_stats_core::ConfigError::Missing("alpha"))?;
                let reject = self
                    .reject_threshold
                    .ok_or(nexus_stats_core::ConfigError::Missing("reject_threshold"))?;
                // Phrased positively so NaN fails the check as well.
                if !(alpha > 0.0 as $ty && alpha < 1.0 as $ty) {
                    return Err(nexus_stats_core::ConfigError::Invalid(
                        "alpha must be in (0, 1)",
                    ));
                }
                // `reject <= 0` alone lets NaN through, because NaN compares false against
                // everything. A NaN threshold makes `|z| <= threshold` always false, so the
                // estimator would never update after the first sample.
                if reject.is_nan() || reject <= 0.0 as $ty {
                    return Err(nexus_stats_core::ConfigError::Invalid(
                        "reject_threshold must be positive",
                    ));
                }
                Ok($name {
                    alpha,
                    one_minus_alpha: 1.0 as $ty - alpha,
                    ema: 0.0 as $ty,
                    ema_abs_dev: 0.0 as $ty,
                    last_z: 0.0 as $ty,
                    reject_threshold: reject,
                    count: 0,
                    min_samples: self.min_samples,
                    initialized: false,
                })
            }
        }
    };
}
// Single-precision detector and builder.
impl_robust_z!(RobustZScoreF32, RobustZScoreF32Builder, f32);
// Double-precision detector and builder.
impl_robust_z!(RobustZScoreF64, RobustZScoreF64Builder, f64);
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f64` detector, panicking if the configuration is rejected.
    fn detector(alpha: f64, reject: f64, min: u64) -> RobustZScoreF64 {
        RobustZScoreF64::builder()
            .alpha(alpha)
            .reject_threshold(reject)
            .min_samples(min)
            .build()
            .unwrap()
    }

    /// Feeds every sample to `rz`, ignoring the per-sample results.
    fn feed(rz: &mut RobustZScoreF64, samples: impl IntoIterator<Item = f64>) {
        samples.into_iter().for_each(|s| {
            let _ = rz.update(s);
        });
    }

    #[test]
    fn stable_signal_low_z() {
        let mut rz = detector(0.1, 10.0, 5);
        feed(&mut rz, [100.0; 20]);
        let z = rz.z_score().unwrap();
        assert!(
            z.abs() < 0.1,
            "stable signal should have ~zero z-score, got {z}"
        );
    }

    #[test]
    fn outlier_high_z() {
        let mut rz = detector(0.1, 10.0, 5);
        feed(&mut rz, (0..20).map(|i| 100.0 + (i % 3) as f64));
        let z = rz.update(200.0).unwrap().unwrap();
        assert!(z.abs() > 3.0, "outlier should have high z-score, got {z}");
    }

    #[test]
    fn estimator_freeze_on_reject() {
        let mut rz = detector(0.1, 3.0, 5);
        feed(&mut rz, (0..20).map(|i| 100.0 + (i % 2) as f64));
        let before = rz.baseline().unwrap();
        let _ = rz.update(500.0);
        let after = rz.baseline().unwrap();
        assert!(
            (before - after).abs() < 1e-10,
            "baseline should not move on rejected sample"
        );
    }

    #[test]
    fn recovery_after_freeze() {
        let mut rz = detector(0.1, 5.0, 5);
        feed(&mut rz, [100.0; 20]);
        let _ = rz.update(500.0);
        feed(&mut rz, [100.0; 10]);
        let z = rz.z_score().unwrap();
        assert!(z.abs() < 1.0, "should recover after freeze, got {z}");
    }

    #[test]
    fn priming() {
        let mut rz = detector(0.1, 5.0, 10);
        // Nine samples fall short of `min_samples = 10`; the tenth primes the detector.
        for _ in 1..10 {
            assert!(rz.update(100.0).unwrap().is_none());
        }
        assert!(rz.update(100.0).unwrap().is_some());
    }

    #[test]
    fn reset() {
        let mut rz = detector(0.1, 5.0, 5);
        feed(&mut rz, [100.0; 20]);
        rz.reset();
        assert_eq!(rz.count(), 0);
    }

    #[test]
    fn f32_basic() {
        let mut rz = RobustZScoreF32::builder()
            .alpha(0.1)
            .reject_threshold(5.0)
            .min_samples(5)
            .build()
            .unwrap();
        (0..10).for_each(|_| {
            let _ = rz.update(100.0);
        });
        assert!(rz.is_primed());
    }

    #[test]
    fn errors_without_reject_threshold() {
        let outcome = RobustZScoreF64::builder().alpha(0.1).build();
        assert!(matches!(
            outcome,
            Err(nexus_stats_core::ConfigError::Missing("reject_threshold"))
        ));
    }

    #[test]
    fn rejects_nan_and_inf() {
        let mut rz = detector(0.1, 5.0, 5);
        let cases = [
            (f64::NAN, nexus_stats_core::DataError::NotANumber),
            (f64::INFINITY, nexus_stats_core::DataError::Infinite),
            (f64::NEG_INFINITY, nexus_stats_core::DataError::Infinite),
        ];
        for (bad, expected) in cases {
            assert_eq!(rz.update(bad).unwrap_err(), expected);
        }
        // Non-finite samples must not advance the sample counter.
        assert_eq!(rz.count(), 0);
    }
}