use alloc::format;
use crate::error::{RcfError, RcfResult};
use crate::thresholded::EmaStats;
/// Default CUSUM allowance (slack) `k`, in units of the reference standard
/// deviation: `observe` subtracts `allowance_k * stddev` from every step.
pub const DEFAULT_ALLOWANCE_K: f64 = 0.5;
/// Default CUSUM decision interval `h`, in units of the reference standard
/// deviation: drift fires once an accumulator exceeds `threshold_h * stddev`.
pub const DEFAULT_THRESHOLD_H: f64 = 5.0;
/// Default number of scores that must already be in the reference statistics
/// before the detector starts stepping its accumulators.
pub const DEFAULT_MIN_OBSERVATIONS: u64 = 32;
/// Default decay passed to `EmaStats::new`; must lie in `(0.0, 1.0]`.
/// NOTE(review): presumably the EMA weight given to each new score — confirm
/// against `EmaStats`.
pub const DEFAULT_DECAY: f64 = 0.01;
/// Direction of a shift reported by [`MetaDriftDetector::observe`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DriftKind {
/// The upward accumulator `s_high` exceeded the threshold: scores have been
/// running above the reference mean. Reported in preference to `Downward`
/// when both accumulators exceed the threshold.
Upward,
/// The downward accumulator `s_low` exceeded the threshold: scores have been
/// running below the reference mean.
Downward,
}
/// Tuning parameters for [`MetaDriftDetector`]; checked by [`CusumConfig::validate`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CusumConfig {
/// Slack per step, in reference standard deviations. Deviations smaller than
/// `allowance_k * stddev` drain the accumulators. Must be finite and `>= 0`.
pub allowance_k: f64,
/// Decision interval, in reference standard deviations. Drift fires when an
/// accumulator exceeds `threshold_h * stddev`. Must be finite and `> 0`.
pub threshold_h: f64,
/// Number of scores that must already be in the reference statistics before
/// the CUSUM step runs (any value is accepted).
pub min_observations: u64,
/// Decay handed to `EmaStats::new`; must lie in `(0.0, 1.0]`.
/// NOTE(review): presumably the EMA weight of each new score — confirm.
pub decay: f64,
}
impl Default for CusumConfig {
    /// Builds the configuration from the module's `DEFAULT_*` constants.
    fn default() -> Self {
        let decay = DEFAULT_DECAY;
        let min_observations = DEFAULT_MIN_OBSERVATIONS;
        let threshold_h = DEFAULT_THRESHOLD_H;
        let allowance_k = DEFAULT_ALLOWANCE_K;
        Self {
            decay,
            min_observations,
            threshold_h,
            allowance_k,
        }
    }
}
impl CusumConfig {
    /// Checks that the configuration describes a usable detector.
    ///
    /// Fields are checked in a fixed order: `allowance_k`, then `threshold_h`,
    /// then `decay`. Only the first failure is reported. `min_observations` is not
    /// checked; every `u64` is accepted.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when:
    /// * `allowance_k` is NaN, infinite, or negative;
    /// * `threshold_h` is NaN, infinite, zero, or negative;
    /// * `decay` is NaN, infinite, or outside `(0.0, 1.0]`.
    pub fn validate(&self) -> RcfResult<()> {
        let Self {
            allowance_k,
            threshold_h,
            decay,
            ..
        } = *self;

        // Each predicate states what a *valid* value looks like; NaN fails every
        // comparison, so it is rejected by construction.
        let allowance_ok = allowance_k.is_finite() && allowance_k >= 0.0;
        if !allowance_ok {
            return Err(RcfError::InvalidConfig(
                format!("allowance_k must be finite and >= 0, got {}", allowance_k).into(),
            ));
        }

        let threshold_ok = threshold_h.is_finite() && threshold_h > 0.0;
        if !threshold_ok {
            return Err(RcfError::InvalidConfig(
                format!("threshold_h must be finite and > 0, got {}", threshold_h).into(),
            ));
        }

        let decay_ok = decay.is_finite() && decay > 0.0 && decay <= 1.0;
        if !decay_ok {
            return Err(RcfError::InvalidConfig(
                format!("decay must be in (0.0, 1.0], got {}", decay).into(),
            ));
        }

        Ok(())
    }
}
/// Result of a single [`MetaDriftDetector::observe`] call.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DriftVerdict {
/// Upward accumulator after this call (unchanged when `ready` is false).
pub s_high: f64,
/// Downward accumulator after this call (unchanged when `ready` is false).
pub s_low: f64,
/// Decision interval `threshold_h * stddev` used for this step; `0.0` when
/// `ready` is false.
pub threshold: f64,
/// Reference mean the score was compared against (taken before the score was
/// folded into the statistics).
pub mean: f64,
/// Reference standard deviation the score was compared against (taken before
/// the score was folded into the statistics).
pub stddev: f64,
/// Whether the CUSUM step ran. False for non-finite scores, during warm-up,
/// or while the reference standard deviation is not positive.
pub ready: bool,
/// Direction of the detected drift, if an accumulator exceeded `threshold`.
pub drift: Option<DriftKind>,
}
/// Two-sided CUSUM drift detector over a stream of `f64` scores.
///
/// Each score is compared against reference statistics held in an `EmaStats`
/// (NOTE(review): presumably an exponential moving average, given the name —
/// confirm). Deviations are accumulated into separate upward (`s_high`) and
/// downward (`s_low`) sums.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MetaDriftDetector {
// Validated configuration (checked in `new`).
config: CusumConfig,
// Reference mean / stddev / observation count that scores are compared against.
stats: EmaStats,
// Upward CUSUM accumulator, clamped at 0.0.
s_high: f64,
// Downward CUSUM accumulator, clamped at 0.0.
s_low: f64,
// Sink for accumulator histograms and drift counters (std builds only). It is
// not serialized; on deserialization serde fills it from `default_sink()`.
#[cfg(feature = "std")]
#[cfg_attr(
feature = "serde",
serde(skip, default = "crate::metrics::default_sink")
)]
metrics: std::sync::Arc<dyn crate::metrics::MetricsSink>,
}
impl MetaDriftDetector {
    /// Builds a detector for `config` with fresh reference statistics and both
    /// CUSUM accumulators at zero.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`CusumConfig::validate`] if `config` is invalid,
    /// and any error returned by [`EmaStats::new`] for `config.decay`.
    pub fn new(config: CusumConfig) -> RcfResult<Self> {
        config.validate()?;
        let stats = EmaStats::new(config.decay)?;
        let detector = Self {
            config,
            stats,
            s_high: 0.0,
            s_low: 0.0,
            #[cfg(feature = "std")]
            metrics: crate::metrics::default_sink(),
        };
        Ok(detector)
    }

    /// Replaces the metrics sink and returns the updated detector (builder style).
    #[cfg(feature = "std")]
    #[must_use]
    pub fn with_metrics_sink(
        mut self,
        sink: std::sync::Arc<dyn crate::metrics::MetricsSink>,
    ) -> Self {
        self.metrics = sink;
        self
    }

    /// Returns the sink that [`Self::observe`] reports to.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn metrics_sink(&self) -> &std::sync::Arc<dyn crate::metrics::MetricsSink> {
        &self.metrics
    }

    /// Builds a detector from [`CusumConfig::default`].
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`Self::new`].
    pub fn with_defaults() -> RcfResult<Self> {
        let config: CusumConfig = Default::default();
        Self::new(config)
    }

    /// Returns the configuration this detector was built with.
    #[must_use]
    pub fn config(&self) -> &CusumConfig {
        &self.config
    }

    /// Returns the reference statistics that scores are compared against.
    #[must_use]
    pub fn stats(&self) -> &EmaStats {
        &self.stats
    }

    /// Returns the current upward accumulator (always `>= 0.0` once stepped).
    #[must_use]
    pub fn s_high(&self) -> f64 {
        self.s_high
    }

    /// Returns the current downward accumulator (always `>= 0.0` once stepped).
    #[must_use]
    pub fn s_low(&self) -> f64 {
        self.s_low
    }

    /// Feeds one score into the detector and reports the resulting CUSUM state.
    ///
    /// The score is compared against the reference mean and standard deviation as
    /// they stood *before* the score was folded into the statistics, so a jump is
    /// measured against history rather than partly against itself.
    ///
    /// There are three possible outcomes:
    /// * **Non-finite `score`.** The score is ignored entirely. Statistics and
    ///   accumulators are untouched; the verdict has `ready == false` and
    ///   `threshold == 0.0`.
    /// * **Warm-up.** The score updates the statistics, but either fewer than
    ///   `min_observations` scores had been seen beforehand or the reference
    ///   standard deviation was not positive. The accumulators are left alone and
    ///   the verdict has `ready == false` and `threshold == 0.0`.
    /// * **Ready.** Both one-sided accumulators are stepped and compared against
    ///   `threshold_h * stddev`. Upward drift takes precedence if both exceed it.
    ///
    /// Accumulators are not cleared after a drift fires, so the verdict keeps
    /// reporting drift until [`Self::reset`] is called.
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    pub fn observe(&mut self, score: f64) -> DriftVerdict {
        if !score.is_finite() {
            let (mean, stddev) = (self.stats.mean(), self.stats.stddev());
            return self.verdict_without_step(mean, stddev);
        }

        // Snapshot the reference before the update so the deviation is against history.
        let ref_mean = self.stats.mean();
        let ref_stddev = self.stats.stddev();
        let seen_before = self.stats.observations();
        self.stats.update(score);

        let warmed_up = seen_before >= self.config.min_observations;
        let ready = warmed_up && ref_stddev > 0.0;
        if !ready {
            return self.verdict_without_step(ref_mean, ref_stddev);
        }

        // Slack and decision interval both scale with the reference spread, so the
        // configured k and h are expressed in standard deviations.
        let slack = self.config.allowance_k * ref_stddev;
        let threshold = self.config.threshold_h * ref_stddev;
        let deviation = score - ref_mean;
        self.s_high = (self.s_high + deviation - slack).max(0.0);
        self.s_low = (self.s_low - deviation - slack).max(0.0);

        let drift = match (self.s_high > threshold, self.s_low > threshold) {
            (true, _) => Some(DriftKind::Upward),
            (false, true) => Some(DriftKind::Downward),
            (false, false) => None,
        };

        #[cfg(feature = "std")]
        {
            self.record_metrics(drift);
        }

        DriftVerdict {
            s_high: self.s_high,
            s_low: self.s_low,
            threshold,
            mean: ref_mean,
            stddev: ref_stddev,
            ready: true,
            drift,
        }
    }

    /// Builds the "not ready" verdict: current accumulators, no threshold, no drift.
    fn verdict_without_step(&self, mean: f64, stddev: f64) -> DriftVerdict {
        DriftVerdict {
            s_high: self.s_high,
            s_low: self.s_low,
            threshold: 0.0,
            mean,
            stddev,
            ready: false,
            drift: None,
        }
    }

    /// Reports both accumulators as histograms and, when drift fired, bumps the
    /// total and per-direction counters.
    #[cfg(feature = "std")]
    fn record_metrics(&self, drift: Option<DriftKind>) {
        use crate::metrics::names;
        self.metrics.observe_histogram(names::DRIFT_S_HIGH, self.s_high);
        self.metrics.observe_histogram(names::DRIFT_S_LOW, self.s_low);
        let direction = match drift {
            Some(DriftKind::Upward) => names::DRIFT_UP_TOTAL,
            Some(DriftKind::Downward) => names::DRIFT_DOWN_TOTAL,
            None => return,
        };
        self.metrics.inc_counter(names::DRIFT_FIRES_TOTAL, 1);
        self.metrics.inc_counter(direction, 1);
    }

    /// Zeroes both CUSUM accumulators and keeps the learned reference statistics.
    /// Call this after handling a drift so the detector can fire again.
    pub fn reset(&mut self) {
        self.s_high = 0.0;
        self.s_low = 0.0;
    }

    /// Zeroes both accumulators and also resets the reference statistics, so the
    /// detector goes through warm-up again.
    pub fn reset_stats(&mut self) {
        self.reset();
        self.stats.reset();
    }
}
#[cfg(test)]
// Exact float comparisons are intentional here: the accumulators are clamped to
// exactly 0.0 by `max(0.0)`, and config fields are copied through verbatim.
#[allow(clippy::float_cmp)] mod tests {
use super::*;
// Detector with k = 0.5, a custom threshold `h`, an 8-observation warm-up and
// decay 0.1, so the tests reach the ready state quickly.
fn detector(h: f64) -> MetaDriftDetector {
MetaDriftDetector::new(CusumConfig {
allowance_k: 0.5,
threshold_h: h,
min_observations: 8,
decay: 0.1,
})
.unwrap()
}
#[test]
fn default_config_validates() {
CusumConfig::default().validate().unwrap();
}
// Shorthand for building a config with explicit k, h, warm-up and decay.
fn cfg(k: f64, h: f64, min_obs: u64, decay: f64) -> CusumConfig {
CusumConfig {
allowance_k: k,
threshold_h: h,
min_observations: min_obs,
decay,
}
}
#[test]
fn validate_rejects_negative_allowance_k() {
assert!(
cfg(-0.1, DEFAULT_THRESHOLD_H, 8, DEFAULT_DECAY)
.validate()
.is_err()
);
}
#[test]
fn validate_rejects_zero_threshold_h() {
assert!(
cfg(DEFAULT_ALLOWANCE_K, 0.0, 8, DEFAULT_DECAY)
.validate()
.is_err()
);
}
// Covers the lower bound (0.0 excluded), the upper bound (> 1.0) and NaN.
#[test]
fn validate_rejects_decay_outside_range() {
assert!(
cfg(DEFAULT_ALLOWANCE_K, DEFAULT_THRESHOLD_H, 8, 0.0)
.validate()
.is_err()
);
assert!(
cfg(DEFAULT_ALLOWANCE_K, DEFAULT_THRESHOLD_H, 8, 1.5)
.validate()
.is_err()
);
assert!(
cfg(DEFAULT_ALLOWANCE_K, DEFAULT_THRESHOLD_H, 8, f64::NAN)
.validate()
.is_err()
);
}
// The first `min_observations` (8) scores only build the reference and must
// never be reported as ready or drifting.
#[test]
fn warmup_never_fires() {
let mut d = detector(5.0);
for _ in 0..8 {
let v = d.observe(1.0);
assert!(!v.ready);
assert!(v.drift.is_none());
}
}
// Every deviation from the mean of a constant stream is zero (and the
// reference stddev presumably stays zero too), so nothing can accumulate.
#[test]
fn constant_stream_does_not_fire() {
let mut d = detector(5.0);
for _ in 0..200 {
let v = d.observe(1.0);
assert!(v.drift.is_none());
}
assert_eq!(d.s_high(), 0.0);
assert_eq!(d.s_low(), 0.0);
}
// Alternating 0.95 / 1.05 gives a reference near 1.0 with nonzero spread.
// A sustained jump to 5.0 must trip the upward accumulator within 100 steps.
#[test]
fn upward_shift_fires_upward() {
let mut d = detector(3.0);
for i in 0..64 {
let noise = if i % 2 == 0 { 0.95 } else { 1.05 };
let _ = d.observe(noise);
}
let mut saw_upward = false;
for _ in 0..100 {
let v = d.observe(5.0);
if matches!(v.drift, Some(DriftKind::Upward)) {
saw_upward = true;
break;
}
}
assert!(saw_upward, "CUSUM should fire upward on sustained shift");
}
// Mirror of the upward test: reference near 5.0, sustained drop to 1.0.
#[test]
fn downward_shift_fires_downward() {
let mut d = detector(3.0);
for i in 0..64 {
let noise = if i % 2 == 0 { 4.95 } else { 5.05 };
let _ = d.observe(noise);
}
let mut saw_downward = false;
for _ in 0..100 {
let v = d.observe(1.0);
if matches!(v.drift, Some(DriftKind::Downward)) {
saw_downward = true;
break;
}
}
assert!(
saw_downward,
"CUSUM should fire downward on sustained shift"
);
}
// NaN and infinity must not be folded into the reference statistics.
#[test]
fn non_finite_input_ignored() {
let mut d = detector(3.0);
for _ in 0..16 {
let _ = d.observe(1.0);
}
let obs_before = d.stats().observations();
let v_nan = d.observe(f64::NAN);
let v_inf = d.observe(f64::INFINITY);
assert!(v_nan.drift.is_none());
assert!(v_inf.drift.is_none());
assert_eq!(d.stats().observations(), obs_before);
}
// Builds up a nonzero `s_high`, then checks `reset()` zeroes the
// accumulators without touching the observation count.
#[test]
fn reset_clears_accumulators_but_keeps_stats() {
let mut d = detector(3.0);
for i in 0..64 {
let noise = if i % 2 == 0 { 0.95 } else { 1.05 };
let _ = d.observe(noise);
}
for _ in 0..50 {
let _ = d.observe(5.0);
}
assert!(d.s_high() > 0.0);
let stats_obs = d.stats().observations();
d.reset();
assert_eq!(d.s_high(), 0.0);
assert_eq!(d.s_low(), 0.0);
assert_eq!(
d.stats().observations(),
stats_obs,
"reset() must keep the EMA reference"
);
}
#[test]
fn reset_stats_clears_everything() {
let mut d = detector(3.0);
for _ in 0..64 {
let _ = d.observe(1.0);
}
d.reset_stats();
assert_eq!(d.s_high(), 0.0);
assert_eq!(d.s_low(), 0.0);
assert_eq!(d.stats().observations(), 0);
}
// The reported mean is the reference from before 2.5 was folded in, so it
// stays close to 2.0.
#[test]
fn verdict_exposes_reference_mean_and_stddev() {
let mut d = detector(5.0);
for _ in 0..32 {
let _ = d.observe(2.0);
}
let v = d.observe(2.5);
assert!((v.mean - 2.0).abs() < 0.5);
assert!(v.stddev >= 0.0);
}
#[test]
fn with_defaults_builds() {
let d = MetaDriftDetector::with_defaults().unwrap();
assert_eq!(d.config().allowance_k, DEFAULT_ALLOWANCE_K);
assert_eq!(d.config().threshold_h, DEFAULT_THRESHOLD_H);
}
}