use super::DriftSignal;
/// Two-sided Page-Hinkley change detector for a stream of `f64` samples.
///
/// The detector keeps an incremental mean of the samples it has consumed and
/// two cumulative deviation sums, one that grows on upward shifts and one that
/// grows on downward shifts. For each direction the test statistic is the
/// distance between the cumulative sum and the smallest value that sum has
/// reached since the sums were last cleared.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PageHinkleyTest {
/// Tolerance subtracted from every deviation before it is accumulated.
delta: f64,
/// Threshold on either test statistic above which `update` reports `Drift`.
lambda: f64,
/// Threshold on either test statistic above which `update` reports
/// `Warning`; set to `lambda * 0.5` at construction.
lambda_warning: f64,
/// Incremental mean of the samples; it is not cleared when a drift is
/// reported, only by `reset`.
running_mean: f64,
/// Cumulative sum of `value - running_mean - delta` (upward direction).
sum_up: f64,
/// Smallest `sum_up` seen since the sums were last cleared; `f64::MAX`
/// means no sample has been folded in yet.
min_sum_up: f64,
/// Cumulative sum of `running_mean - delta - value` (downward direction).
sum_down: f64,
/// Smallest `sum_down` seen since the sums were last cleared; `f64::MAX`
/// means no sample has been folded in yet.
min_sum_down: f64,
/// Sample counter, incremented once per `update` call and zeroed by `reset`.
count: u64,
}
impl PageHinkleyTest {
    /// Creates a detector with the default parameters `delta = 0.005` and
    /// `lambda = 50.0`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_params(0.005, 50.0)
    }

    /// Creates a detector with explicit parameters.
    ///
    /// * `delta` - tolerance subtracted from every deviation before it is
    ///   accumulated; deviations smaller than `delta` shrink the cumulative
    ///   sums instead of growing them.
    /// * `lambda` - drift threshold. The warning threshold is derived from it
    ///   as `lambda * 0.5`.
    ///
    /// The parameters are stored exactly as given; no validation is performed.
    #[must_use]
    pub fn with_params(delta: f64, lambda: f64) -> Self {
        Self {
            delta,
            lambda,
            lambda_warning: lambda * 0.5,
            running_mean: 0.0,
            sum_up: 0.0,
            min_sum_up: f64::MAX,
            sum_down: 0.0,
            min_sum_down: f64::MAX,
            count: 0,
        }
    }

    /// Returns the tolerance `delta` this detector was built with.
    #[inline]
    #[must_use]
    pub fn delta(&self) -> f64 {
        self.delta
    }

    /// Returns the drift threshold `lambda` this detector was built with.
    #[inline]
    #[must_use]
    pub fn lambda(&self) -> f64 {
        self.lambda
    }

    /// Returns the number of samples folded into the detector since
    /// construction or the last [`reset`](Self::reset).
    #[inline]
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Upward test statistic: how far `sum_up` sits above its running minimum.
    #[inline]
    fn ph_up(&self) -> f64 {
        self.sum_up - self.min_sum_up
    }

    /// Downward test statistic: how far `sum_down` sits above its running minimum.
    #[inline]
    fn ph_down(&self) -> f64 {
        self.sum_down - self.min_sum_down
    }

    /// Clears both cumulative sums and their minima while leaving the running
    /// mean and the sample count untouched.
    fn reset_sums(&mut self) {
        self.sum_up = 0.0;
        self.min_sum_up = f64::MAX;
        self.sum_down = 0.0;
        self.min_sum_down = f64::MAX;
    }

    /// Folds one sample into the detector and reports the resulting state.
    ///
    /// Returns
    /// * [`DriftSignal::Drift`] when either test statistic exceeds `lambda`;
    ///   the cumulative sums are cleared, the running mean and count are kept,
    /// * [`DriftSignal::Warning`] when either statistic exceeds `lambda * 0.5`,
    /// * [`DriftSignal::Stable`] otherwise.
    ///
    /// Non-finite samples (NaN, `+inf`, `-inf`) are ignored: the state is left
    /// untouched, the sample is not counted and `Stable` is returned.
    pub fn update(&mut self, value: f64) -> DriftSignal {
        // A NaN or infinite sample would turn the running mean and both sums
        // into NaN/inf. Every later threshold comparison would then be false,
        // so the detector would silently report `Stable` until `reset`.
        if !value.is_finite() {
            return DriftSignal::Stable;
        }

        self.count += 1;
        // Incremental mean; it already includes the current sample when the
        // deviations below are computed.
        self.running_mean += (value - self.running_mean) / self.count as f64;

        self.sum_up += value - self.running_mean - self.delta;
        self.min_sum_up = crate::math::fmin(self.min_sum_up, self.sum_up);

        self.sum_down += self.running_mean - self.delta - value;
        self.min_sum_down = crate::math::fmin(self.min_sum_down, self.sum_down);

        let ph_up = self.ph_up();
        let ph_down = self.ph_down();

        if ph_up > self.lambda || ph_down > self.lambda {
            // Start accumulating evidence from scratch after a detection.
            self.reset_sums();
            return DriftSignal::Drift;
        }
        if ph_up > self.lambda_warning || ph_down > self.lambda_warning {
            return DriftSignal::Warning;
        }
        DriftSignal::Stable
    }

    /// Returns the detector to its freshly constructed state while keeping
    /// `delta` and `lambda`.
    pub fn reset(&mut self) {
        self.reset_sums();
        self.running_mean = 0.0;
        self.count = 0;
    }

    /// Returns the running mean of the samples, or `0.0` while the sample
    /// count is zero.
    #[must_use]
    pub fn estimated_mean(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.running_mean
        }
    }
}
impl Default for PageHinkleyTest {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "alloc")]
impl super::DriftDetector for PageHinkleyTest {
    /// Feeds one sample to the detector through the inherent `update` method.
    fn update(&mut self, value: f64) -> DriftSignal {
        // The type is named explicitly so the call targets the inherent method.
        PageHinkleyTest::update(self, value)
    }

    /// Clears all accumulated state through the inherent `reset` method.
    fn reset(&mut self) {
        PageHinkleyTest::reset(self);
    }

    /// Boxes a brand-new detector configured with this detector's `delta` and
    /// `lambda`; none of the accumulated state is carried over.
    fn clone_fresh(&self) -> alloc::boxed::Box<dyn super::DriftDetector> {
        let PageHinkleyTest { delta, lambda, .. } = *self;
        alloc::boxed::Box::new(PageHinkleyTest::with_params(delta, lambda))
    }

    /// Boxes an exact copy of this detector, accumulated state included.
    fn clone_boxed(&self) -> alloc::boxed::Box<dyn super::DriftDetector> {
        alloc::boxed::Box::new(Clone::clone(self))
    }

    /// Reports the running mean through the inherent `estimated_mean` method.
    fn estimated_mean(&self) -> f64 {
        PageHinkleyTest::estimated_mean(self)
    }

    /// Captures the mutable state (mean, both sums with their minima, count).
    ///
    /// `delta` and `lambda` are not part of the snapshot.
    fn serialize_state(&self) -> Option<super::DriftDetectorState> {
        let PageHinkleyTest {
            running_mean,
            sum_up,
            min_sum_up,
            sum_down,
            min_sum_down,
            count,
            ..
        } = *self;
        let snapshot = super::DriftDetectorState::PageHinkley {
            running_mean,
            sum_up,
            min_sum_up,
            sum_down,
            min_sum_down,
            count,
        };
        Some(snapshot)
    }

    /// Overwrites the mutable state from a `PageHinkley` snapshot.
    ///
    /// Returns `true` when the snapshot was applied and `false` when `state`
    /// is a different variant, in which case the detector is left unchanged.
    /// `delta` and `lambda` are never modified.
    fn restore_state(&mut self, state: &super::DriftDetectorState) -> bool {
        match state {
            super::DriftDetectorState::PageHinkley {
                running_mean,
                sum_up,
                min_sum_up,
                sum_down,
                min_sum_down,
                count,
            } => {
                self.running_mean = *running_mean;
                self.sum_up = *sum_up;
                self.min_sum_up = *min_sum_up;
                self.sum_down = *sum_down;
                self.min_sum_down = *min_sum_down;
                self.count = *count;
                true
            }
            // NOTE(review): presumably the state enum can collapse to this one
            // variant in some feature configurations, which would make the
            // fallback arm unreachable - confirm against the enum definition.
            #[allow(unreachable_patterns)]
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic pseudo-random stream of `count` samples lying in
    /// `[center - amplitude, center + amplitude]`.
    ///
    /// The values come from a 64-bit linear congruential generator seeded with
    /// `seed`, so a given argument tuple always yields the same stream.
    fn noisy_stream(center: f64, amplitude: f64, count: usize, seed: u64) -> alloc::vec::Vec<f64> {
        let mut values = alloc::vec::Vec::with_capacity(count);
        let mut state = seed;
        for _ in 0..count {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // Upper 32 bits scaled to [0, 1], then stretched to [-1, 1] so the
            // noise is symmetric around `center`.
            let unit = (state >> 32) as f64 / u32::MAX as f64;
            let frac = 2.0 * unit - 1.0;
            values.push(center + amplitude * frac);
        }
        values
    }

    /// A stationary, lightly noisy stream must never escalate to `Drift`.
    #[test]
    fn stationary_stream_no_drift() {
        let mut pht = PageHinkleyTest::new();
        for v in noisy_stream(5.0, 0.1, 5000, 42) {
            assert_ne!(
                pht.update(v),
                DriftSignal::Drift,
                "stationary stream should never trigger Drift"
            );
        }
    }

    /// A jump of the stream level from 0 to 10 must be reported as `Drift`.
    #[test]
    fn abrupt_upward_shift_detected() {
        let mut pht = PageHinkleyTest::new();
        for v in noisy_stream(0.0, 0.01, 1000, 100) {
            pht.update(v);
        }
        let drift_detected = noisy_stream(10.0, 0.01, 1000, 200)
            .into_iter()
            .any(|v| pht.update(v) == DriftSignal::Drift);
        assert!(
            drift_detected,
            "upward shift from 0 to 10 must trigger Drift"
        );
    }

    /// A drop of the stream level from 10 to 0 must be reported as `Drift`.
    #[test]
    fn abrupt_downward_shift_detected() {
        let mut pht = PageHinkleyTest::new();
        for v in noisy_stream(10.0, 0.01, 1000, 300) {
            pht.update(v);
        }
        let drift_detected = noisy_stream(0.0, 0.01, 1000, 400)
            .into_iter()
            .any(|v| pht.update(v) == DriftSignal::Drift);
        assert!(
            drift_detected,
            "downward shift from 10 to 0 must trigger Drift"
        );
    }

    /// On a stationary stream the reported mean stays close to the stream centre.
    #[test]
    fn estimated_mean_tracks_true_mean() {
        let mut pht = PageHinkleyTest::new();
        let center = 7.5;
        for v in noisy_stream(center, 0.01, 2000, 500) {
            pht.update(v);
        }
        let estimated = pht.estimated_mean();
        let error = crate::math::abs(estimated - center);
        assert!(
            error < 0.1,
            "estimated mean {estimated} should be close to true mean {center}, error={error}"
        );
    }

    /// Before any sample arrives the reported mean is exactly zero.
    #[test]
    fn estimated_mean_returns_zero_when_empty() {
        let pht = PageHinkleyTest::new();
        assert_eq!(
            pht.estimated_mean(),
            0.0,
            "empty detector should report mean=0.0"
        );
    }

    /// `reset` returns every piece of accumulated state to its initial value.
    #[test]
    fn reset_clears_all_state() {
        let mut pht = PageHinkleyTest::new();
        for v in noisy_stream(5.0, 0.1, 500, 600) {
            pht.update(v);
        }
        assert!(pht.count() > 0);
        assert!(pht.estimated_mean() != 0.0);

        pht.reset();

        assert_eq!(pht.count(), 0);
        assert_eq!(pht.estimated_mean(), 0.0);
        assert_eq!(pht.sum_up, 0.0);
        assert_eq!(pht.sum_down, 0.0);
        assert_eq!(pht.min_sum_up, f64::MAX);
        assert_eq!(pht.min_sum_down, f64::MAX);
        assert_eq!(pht.running_mean, 0.0);
    }

    /// `clone_fresh` hands back a detector that carries none of the source's history.
    #[cfg(feature = "alloc")]
    #[test]
    fn clone_fresh_returns_clean_detector_with_same_params() {
        use super::super::DriftDetector;
        let mut pht = PageHinkleyTest::with_params(0.01, 100.0);
        for v in noisy_stream(3.0, 0.5, 200, 700) {
            pht.update(v);
        }
        let mut fresh = DriftDetector::clone_fresh(&pht);
        assert_eq!(fresh.estimated_mean(), 0.0);
        let signal = fresh.update(1.0);
        assert_eq!(
            signal,
            DriftSignal::Stable,
            "single sample on fresh detector should be Stable"
        );
    }

    /// `with_params` stores its arguments and derives the warning threshold as half of `lambda`.
    #[test]
    fn custom_params_stored_correctly() {
        let pht = PageHinkleyTest::with_params(0.1, 200.0);
        assert_eq!(pht.delta(), 0.1);
        assert_eq!(pht.lambda(), 200.0);
        assert_eq!(pht.lambda_warning, 100.0);
    }

    /// `Default` and `new` agree on parameters and initial count.
    #[test]
    fn default_params_match_new() {
        let from_new = PageHinkleyTest::new();
        let from_default = PageHinkleyTest::default();
        assert_eq!(from_new.delta(), from_default.delta());
        assert_eq!(from_new.lambda(), from_default.lambda());
        assert_eq!(from_new.count(), from_default.count());
    }

    /// Large or noisy values at the very start of a stream must not be mistaken for drift.
    #[test]
    fn warmup_no_drift_on_extreme_early_values() {
        let mut pht = PageHinkleyTest::new();
        for i in 0..20 {
            let signal = pht.update(1_000_000.0);
            assert_ne!(
                signal,
                DriftSignal::Drift,
                "constant extreme stream should not trigger Drift at sample {i}"
            );
        }

        let mut pht2 = PageHinkleyTest::new();
        let early_values = [100.0, 102.0, 98.0, 105.0, 95.0, 101.0, 99.0, 103.0];
        for (i, &v) in early_values.iter().enumerate() {
            let signal = pht2.update(v);
            assert_ne!(
                signal,
                DriftSignal::Drift,
                "early noisy samples should not trigger Drift at sample {i}"
            );
        }
    }

    /// After a level shift the detector passes through `Warning` before it reaches `Drift`.
    #[test]
    fn warning_fires_before_drift() {
        let mut pht = PageHinkleyTest::with_params(0.005, 50.0);
        for v in noisy_stream(0.0, 0.001, 500, 800) {
            pht.update(v);
        }
        let mut saw_warning = false;
        let mut saw_drift = false;
        for v in noisy_stream(5.0, 0.001, 1000, 900) {
            match pht.update(v) {
                // The loop stops at the first Drift, so any Warning seen here precedes it.
                DriftSignal::Warning => saw_warning = true,
                DriftSignal::Drift => {
                    saw_drift = true;
                    break;
                }
                DriftSignal::Stable => {}
            }
        }
        assert!(saw_warning, "Warning should fire before Drift");
        assert!(saw_drift, "Drift should eventually fire");
    }

    /// Reporting `Drift` clears the cumulative sums but keeps counting samples.
    #[test]
    fn drift_resets_sums_not_mean() {
        let mut pht = PageHinkleyTest::new();
        for v in noisy_stream(0.0, 0.01, 500, 1000) {
            pht.update(v);
        }
        let count_before = pht.count();
        // `any` stops at the first Drift, so the state inspected below is the
        // one left behind by the detecting call.
        let drifted = noisy_stream(20.0, 0.01, 500, 1100)
            .into_iter()
            .any(|v| pht.update(v) == DriftSignal::Drift);
        assert!(drifted, "should detect drift");
        assert!(pht.count() > count_before);
        assert_eq!(pht.sum_up, 0.0);
        assert_eq!(pht.sum_down, 0.0);
        assert_eq!(pht.min_sum_up, f64::MAX);
        assert_eq!(pht.min_sum_down, f64::MAX);
    }

    /// With a very large `lambda` a modest shift does not reach the drift threshold.
    #[test]
    fn higher_lambda_needs_bigger_shift() {
        let mut pht = PageHinkleyTest::with_params(0.005, 5000.0);
        for v in noisy_stream(0.0, 0.01, 500, 1200) {
            pht.update(v);
        }
        let drifted = noisy_stream(2.0, 0.01, 500, 1300)
            .into_iter()
            .any(|v| pht.update(v) == DriftSignal::Drift);
        assert!(
            !drifted,
            "high lambda (5000) should not trigger on small shift within 500 samples"
        );
    }

    /// A serialized state restored into another detector reproduces count and mean.
    #[cfg(feature = "alloc")]
    #[test]
    fn serialize_restore_roundtrip() {
        use super::super::DriftDetector;
        let mut pht = PageHinkleyTest::new();
        for v in noisy_stream(5.0, 0.1, 200, 1400) {
            pht.update(v);
        }
        let state = DriftDetector::serialize_state(&pht);
        assert!(state.is_some(), "PHT should support state serialization");

        let mut pht2 = PageHinkleyTest::new();
        let restored = DriftDetector::restore_state(&mut pht2, state.as_ref().unwrap());
        assert!(restored, "restore should succeed for PageHinkley state");
        assert_eq!(pht.count(), pht2.count());
        assert_eq!(pht.estimated_mean(), pht2.estimated_mean());
    }
}