use alloc::format;
use alloc::vec;
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use crate::error::{RcfError, RcfResult};
#[cfg(feature = "std")]
use std::sync::Arc;
/// A default histogram bin count for callers constructing a detector.
/// Not referenced by the code in this module itself.
pub const DEFAULT_NUM_BINS: usize = 10;
/// Default floor applied to each bin proportion when computing PSI/KL
/// (used by `FeatureDriftDetector::new`).
pub const DEFAULT_SMOOTHING: f64 = 1e-4;
/// PSI at or above this value (and below [`PSI_ALERT_THRESHOLD`]) is classified
/// as [`DriftLevel::Watch`] by `DriftLevel::classify`.
pub const PSI_WATCH_THRESHOLD: f64 = 0.1;
/// PSI at or above this value is classified as [`DriftLevel::Alert`] by
/// `DriftLevel::classify`.
pub const PSI_ALERT_THRESHOLD: f64 = 0.25;
/// Coarse drift severity derived from a PSI value by `DriftLevel::classify`.
///
/// Variants are declared from least to most severe and `Ord` is derived, so
/// `Stable < Watch < Alert` and comparisons such as `level >= DriftLevel::Watch`
/// work. The enum is `#[non_exhaustive]`: code outside this crate must match it
/// with a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DriftLevel {
/// PSI below `PSI_WATCH_THRESHOLD`.
Stable,
/// PSI in `[PSI_WATCH_THRESHOLD, PSI_ALERT_THRESHOLD)`.
Watch,
/// PSI at or above `PSI_ALERT_THRESHOLD`.
Alert,
}
impl DriftLevel {
    /// Maps a PSI value onto a [`DriftLevel`].
    ///
    /// - `psi < PSI_WATCH_THRESHOLD` → [`DriftLevel::Stable`]
    /// - `PSI_WATCH_THRESHOLD <= psi < PSI_ALERT_THRESHOLD` → [`DriftLevel::Watch`]
    /// - `psi >= PSI_ALERT_THRESHOLD` → [`DriftLevel::Alert`]
    ///
    /// NaN carries no usable signal and is reported as `Stable`. Positive
    /// infinity is the most extreme divergence possible (e.g. a PSI computed
    /// without smoothing over a bin that is empty on one side), so it is
    /// reported as `Alert` rather than being silently treated as stable.
    #[must_use]
    pub fn classify(psi: f64) -> Self {
        // `-inf` already satisfies `psi < PSI_WATCH_THRESHOLD`; only NaN needs an
        // explicit guard because every comparison with NaN is false.
        if psi.is_nan() || psi < PSI_WATCH_THRESHOLD {
            Self::Stable
        } else if psi < PSI_ALERT_THRESHOLD {
            Self::Watch
        } else {
            Self::Alert
        }
    }
}
/// Per-feature histogram drift detector over `D`-dimensional points.
///
/// Points observed before the baseline is frozen are buffered; freezing turns
/// the buffer into per-feature baseline histograms with learned bin edges.
/// Later points fill production histograms that are compared against the
/// baseline with PSI or KL divergence.
pub struct FeatureDriftDetector<const D: usize> {
    // ---- configuration -------------------------------------------------
    /// Number of histogram bins per feature.
    num_bins: usize,
    /// Floor applied to every bin proportion when computing PSI / KL.
    smoothing: f64,

    // ---- learned state -------------------------------------------------
    /// Per-feature `(min, max)` range learned when the baseline is frozen.
    bin_edges: Option<[(f64, f64); D]>,
    /// Per-feature baseline histograms; `Some` once the baseline is frozen.
    baseline: Option<Vec<Vec<u64>>>,
    /// Per-feature histograms of observations made after the freeze.
    production: Vec<Vec<u64>>,
    /// Raw points buffered before the baseline is frozen.
    cold_samples: Vec<[f64; D]>,
    /// Saturating count of accepted observations (buffered or binned).
    observations_total: u64,

    // ---- instrumentation ------------------------------------------------
    /// Sink that receives this detector's counters and gauges.
    #[cfg(feature = "std")]
    metrics: Arc<dyn crate::metrics::MetricsSink>,
}
impl<const D: usize> core::fmt::Debug for FeatureDriftDetector<D> {
    /// Prints configuration and fill levels only: histograms and buffered
    /// samples are summarised by their lengths instead of being dumped.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let frozen = self.baseline.is_some();
        let production_rows = self.production.len();
        let buffered = self.cold_samples.len();

        let mut out = f.debug_struct("FeatureDriftDetector");
        out.field("D", &D);
        out.field("num_bins", &self.num_bins);
        out.field("smoothing", &self.smoothing);
        out.field("baseline_frozen", &frozen);
        out.field("bin_edges", &self.bin_edges);
        out.field("production_buckets", &production_rows);
        out.field("cold_samples", &buffered);
        out.field("observations_total", &self.observations_total);
        #[cfg(feature = "std")]
        out.field("metrics", &self.metrics);
        out.finish()
    }
}
impl<const D: usize> FeatureDriftDetector<D> {
    /// Creates a detector with `num_bins` bins per feature and
    /// [`DEFAULT_SMOOTHING`] as the proportion floor.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::with_smoothing`].
    pub fn new(num_bins: usize) -> RcfResult<Self> {
        Self::with_smoothing(num_bins, DEFAULT_SMOOTHING)
    }

    /// Creates a detector with an explicit smoothing floor.
    ///
    /// When PSI / KL are computed, every bin proportion is clamped to at least
    /// `smoothing`, which keeps the logarithms finite for empty bins.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when `D == 0`, when `num_bins < 2`,
    /// or when `smoothing` is not a finite value in `(0, 1]`.
    pub fn with_smoothing(num_bins: usize, smoothing: f64) -> RcfResult<Self> {
        if D == 0 {
            return Err(RcfError::InvalidConfig(
                "FeatureDriftDetector: D must be > 0".into(),
            ));
        }
        if num_bins < 2 {
            return Err(RcfError::InvalidConfig(
                format!("FeatureDriftDetector: num_bins must be >= 2, got {num_bins}").into(),
            ));
        }
        if !smoothing.is_finite() || smoothing <= 0.0 || smoothing > 1.0 {
            return Err(RcfError::InvalidConfig(
                format!("FeatureDriftDetector: smoothing must be in (0, 1], got {smoothing}")
                    .into(),
            ));
        }
        Ok(Self {
            num_bins,
            smoothing,
            baseline: None,
            production: vec![vec![0; num_bins]; D],
            bin_edges: None,
            cold_samples: Vec::new(),
            observations_total: 0,
            #[cfg(feature = "std")]
            metrics: crate::metrics::default_sink(),
        })
    }

    /// Replaces the metrics sink (builder style).
    #[cfg(feature = "std")]
    #[must_use]
    pub fn with_metrics_sink(mut self, sink: Arc<dyn crate::metrics::MetricsSink>) -> Self {
        self.metrics = sink;
        self
    }

    /// Returns the metrics sink this detector reports to.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn metrics_sink(&self) -> &Arc<dyn crate::metrics::MetricsSink> {
        &self.metrics
    }

    /// `true` once [`Self::freeze_baseline`] has succeeded.
    #[must_use]
    pub fn is_baseline_frozen(&self) -> bool {
        self.baseline.is_some()
    }

    /// Number of accepted observations, before and after the freeze (saturating).
    #[must_use]
    pub fn observations_total(&self) -> u64 {
        self.observations_total
    }

    /// Number of histogram bins per feature.
    #[must_use]
    pub fn num_bins(&self) -> usize {
        self.num_bins
    }

    /// Per-feature `(min, max)` bin range, available once the baseline is frozen.
    #[must_use]
    pub fn bin_edges(&self) -> Option<&[(f64, f64); D]> {
        self.bin_edges.as_ref()
    }

    /// Records one observation.
    ///
    /// Before [`Self::freeze_baseline`] the point is buffered as a baseline
    /// sample; afterwards each coordinate is binned into its feature's
    /// production histogram.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::NaNValue`] if any coordinate is NaN or infinite. In
    /// that case nothing is recorded and the observation counter is unchanged.
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    pub fn observe(&mut self, point: &[f64; D]) -> RcfResult<()> {
        if !point.iter().all(|v| v.is_finite()) {
            return Err(RcfError::NaNValue);
        }
        self.observations_total = self.observations_total.saturating_add(1);
        #[cfg(feature = "std")]
        self.metrics
            .inc_counter(crate::metrics::names::FEATURE_DRIFT_OBSERVED_TOTAL, 1);

        // Borrow the edges instead of copying the whole `[(f64, f64); D]` array on
        // every call; `bin_edges` and `production` are disjoint fields.
        match &self.bin_edges {
            Some(edges) => {
                let num_bins = self.num_bins;
                for ((&(min, max), &value), counts) in edges
                    .iter()
                    .zip(point.iter())
                    .zip(self.production.iter_mut())
                {
                    let bin = map_to_bin(value, min, max, num_bins);
                    counts[bin] = counts[bin].saturating_add(1);
                }
            }
            None => self.cold_samples.push(*point),
        }
        Ok(())
    }

    /// Turns the buffered samples into the frozen baseline.
    ///
    /// Each feature's bin range is the `[min, max]` of its buffered values. A
    /// constant feature is widened to `[v - 0.5, v + 0.5]`. The buffered samples
    /// are then histogrammed into the baseline, the production histograms are
    /// zeroed, and the sample buffer is released.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::EmptyForest`] when no samples are buffered. Points
    /// observed after a freeze go to production rather than the buffer, so
    /// calling this a second time also fails.
    pub fn freeze_baseline(&mut self) -> RcfResult<()> {
        if self.cold_samples.is_empty() {
            return Err(RcfError::EmptyForest);
        }

        let mut edges = [(f64::INFINITY, f64::NEG_INFINITY); D];
        for sample in &self.cold_samples {
            for ((lo, hi), &value) in edges.iter_mut().zip(sample.iter()) {
                if value < *lo {
                    *lo = value;
                }
                if value > *hi {
                    *hi = value;
                }
            }
        }
        // A zero-width range gives `map_to_bin` nothing to divide by; widen it so
        // the constant value sits inside a real interval.
        for (lo, hi) in &mut edges {
            #[allow(clippy::float_cmp)]
            let collapsed = *lo == *hi;
            if collapsed {
                *lo -= 0.5;
                *hi += 0.5;
            }
        }

        let num_bins = self.num_bins;
        let mut baseline = vec![vec![0_u64; num_bins]; D];
        for sample in &self.cold_samples {
            for ((counts, &(min, max)), &value) in
                baseline.iter_mut().zip(edges.iter()).zip(sample.iter())
            {
                let bin = map_to_bin(value, min, max, num_bins);
                counts[bin] = counts[bin].saturating_add(1);
            }
        }

        self.baseline = Some(baseline);
        self.bin_edges = Some(edges);
        self.reset_production();
        // The buffer is never refilled after a freeze, so free its (possibly
        // large) allocation instead of `clear()`-ing and keeping the capacity.
        self.cold_samples = Vec::new();
        Ok(())
    }

    /// Zeroes the production histograms. The baseline, bin edges and
    /// observation counter are kept.
    pub fn reset_production(&mut self) {
        for counts in &mut self.production {
            counts.fill(0);
        }
    }

    /// Per-feature Population Stability Index of production against baseline.
    ///
    /// A feature with an empty production histogram reports `0.0`. With the
    /// `std` feature, the largest value (floored at 0) is also published to the
    /// `FEATURE_DRIFT_MAX_PSI` gauge.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::EmptyForest`] if the baseline has not been frozen.
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    pub fn psi(&self) -> RcfResult<Vec<f64>> {
        let baseline = self.baseline.as_ref().ok_or(RcfError::EmptyForest)?;
        let out: Vec<f64> = baseline
            .iter()
            .zip(&self.production)
            .map(|(base, prod)| psi_one_dim(base, prod, self.smoothing))
            .collect();
        #[cfg(feature = "std")]
        self.metrics
            .set_gauge(crate::metrics::names::FEATURE_DRIFT_MAX_PSI, Self::floored_max(&out));
        Ok(out)
    }

    /// Per-feature KL divergence of the production histogram from the baseline.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::EmptyForest`] if the baseline has not been frozen.
    pub fn kl_divergence(&self) -> RcfResult<Vec<f64>> {
        let baseline = self.baseline.as_ref().ok_or(RcfError::EmptyForest)?;
        Ok(baseline
            .iter()
            .zip(&self.production)
            .map(|(base, prod)| kl_one_dim(base, prod, self.smoothing))
            .collect())
    }

    /// Largest per-feature PSI, floored at `0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::EmptyForest`] if the baseline has not been frozen.
    pub fn max_psi(&self) -> RcfResult<f64> {
        Ok(Self::floored_max(&self.psi()?))
    }

    /// Index of the feature with the largest strictly positive PSI. Ties keep
    /// the lowest index; `None` when no feature has a positive PSI.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::EmptyForest`] if the baseline has not been frozen.
    pub fn argmax_psi(&self) -> RcfResult<Option<usize>> {
        let all = self.psi()?;
        let mut best: Option<(usize, f64)> = None;
        for (d, &v) in all.iter().enumerate() {
            let to_beat = best.map_or(0.0, |(_, b)| b);
            if v > to_beat {
                best = Some((d, v));
            }
        }
        Ok(best.map(|(d, _)| d))
    }

    /// Maximum of `values` starting from `0.0`; NaN entries never win.
    fn floored_max(values: &[f64]) -> f64 {
        values
            .iter()
            .copied()
            .fold(0.0_f64, |acc, v| if v > acc { v } else { acc })
    }
}
/// Maps `v` onto one of `num_bins` equal-width bins spanning `[min, max]`.
///
/// Values at or below `min` (and non-finite values) go to bin 0. Values at or
/// above `max` go to the last bin. `num_bins` must be at least 1.
fn map_to_bin(v: f64, min: f64, max: f64, num_bins: usize) -> usize {
    if !v.is_finite() || v <= min {
        return 0;
    }
    if v >= max {
        return num_bins - 1;
    }
    let width = max - min;
    // For a range wider than `f64::MAX`, `max - min` overflows to +inf. The plain
    // ratio then becomes 0 or NaN, and NaN casts to 0, so values near `max` landed
    // in the FIRST bin. Halving every operand is exact for normal floats, so the
    // ratio is unchanged while every intermediate stays finite.
    let frac = if width.is_finite() {
        (v - min) / width
    } else {
        (v * 0.5 - min * 0.5) / (max * 0.5 - min * 0.5)
    };
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    let idx = (frac * num_bins as f64) as usize;
    idx.min(num_bins - 1)
}
/// Population Stability Index between two histograms over the same bins.
///
/// Both histograms are normalised to proportions, each proportion is floored
/// at `smoothing`, and the result is `Σ (q − p) · ln(q / p)`, with `p` taken
/// from `baseline` and `q` from `production`. Returns `0.0` when the lengths
/// differ, the input is empty, or either histogram has no counts.
#[allow(clippy::cast_precision_loss)]
fn psi_one_dim(baseline: &[u64], production: &[u64], smoothing: f64) -> f64 {
    let comparable = !baseline.is_empty() && baseline.len() == production.len();
    if !comparable {
        return 0.0;
    }

    let as_f64 = |count: &u64| *count as f64;
    let base_total: f64 = baseline.iter().map(as_f64).sum();
    let prod_total: f64 = production.iter().map(as_f64).sum();
    if base_total <= 0.0 || prod_total <= 0.0 {
        return 0.0;
    }

    baseline
        .iter()
        .zip(production)
        .fold(0.0_f64, |acc, (&base_count, &prod_count)| {
            let p_base = (base_count as f64 / base_total).max(smoothing);
            let q_prod = (prod_count as f64 / prod_total).max(smoothing);
            acc + (q_prod - p_base) * (q_prod / p_base).ln()
        })
}
/// Kullback–Leibler divergence `Σ q · ln(q / p)` of the production histogram
/// (`q`) from the baseline histogram (`p`).
///
/// Both histograms are normalised to proportions and each proportion is
/// floored at `smoothing`. Returns `0.0` when the lengths differ, the input is
/// empty, or either histogram has no counts.
#[allow(clippy::cast_precision_loss)]
fn kl_one_dim(baseline: &[u64], production: &[u64], smoothing: f64) -> f64 {
    let comparable = !baseline.is_empty() && baseline.len() == production.len();
    if !comparable {
        return 0.0;
    }

    let as_f64 = |count: &u64| *count as f64;
    let base_total: f64 = baseline.iter().map(as_f64).sum();
    let prod_total: f64 = production.iter().map(as_f64).sum();
    if base_total <= 0.0 || prod_total <= 0.0 {
        return 0.0;
    }

    baseline
        .iter()
        .zip(production)
        .fold(0.0_f64, |acc, (&base_count, &prod_count)| {
            let p_base = (base_count as f64 / base_total).max(smoothing);
            let q_prod = (prod_count as f64 / prod_total).max(smoothing);
            acc + q_prod * (q_prod / p_base).ln()
        })
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::cast_precision_loss,
    clippy::cast_lossless
)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_bad_bins() {
        assert!(FeatureDriftDetector::<4>::new(0).is_err());
        assert!(FeatureDriftDetector::<4>::new(1).is_err());
    }

    #[test]
    fn new_rejects_zero_dimensions() {
        assert!(FeatureDriftDetector::<0>::new(10).is_err());
    }

    #[test]
    fn new_rejects_bad_smoothing() {
        assert!(FeatureDriftDetector::<4>::with_smoothing(10, 0.0).is_err());
        assert!(FeatureDriftDetector::<4>::with_smoothing(10, f64::NAN).is_err());
        assert!(FeatureDriftDetector::<4>::with_smoothing(10, 2.0).is_err());
    }

    #[test]
    fn psi_before_freeze_errors() {
        let d = FeatureDriftDetector::<2>::new(10).unwrap();
        assert!(d.psi().is_err());
        assert!(d.kl_divergence().is_err());
    }

    #[test]
    fn freeze_without_samples_errors() {
        let mut d = FeatureDriftDetector::<2>::new(4).unwrap();
        assert!(d.freeze_baseline().is_err());
        assert!(!d.is_baseline_frozen());
    }

    #[test]
    fn second_freeze_errors_because_production_is_not_buffered() {
        let mut d = FeatureDriftDetector::<1>::new(4).unwrap();
        d.observe(&[1.0]).unwrap();
        d.freeze_baseline().unwrap();
        d.observe(&[1.0]).unwrap();
        assert!(d.freeze_baseline().is_err());
        assert!(d.is_baseline_frozen());
    }

    #[test]
    fn constant_feature_gets_widened_edges() {
        let mut d = FeatureDriftDetector::<1>::new(4).unwrap();
        for _ in 0..10 {
            d.observe(&[3.0]).unwrap();
        }
        d.freeze_baseline().unwrap();
        assert_eq!(d.bin_edges().unwrap()[0], (2.5, 3.5));
        for _ in 0..10 {
            d.observe(&[3.0]).unwrap();
        }
        assert_eq!(d.psi().unwrap()[0], 0.0);
    }

    #[test]
    fn observations_total_counts_only_accepted_points() {
        let mut d = FeatureDriftDetector::<1>::new(4).unwrap();
        d.observe(&[0.0]).unwrap();
        assert!(d.observe(&[f64::NAN]).is_err());
        d.freeze_baseline().unwrap();
        d.observe(&[0.0]).unwrap();
        assert_eq!(d.observations_total(), 2);
    }

    #[test]
    fn identical_distribution_has_zero_psi() {
        let mut d = FeatureDriftDetector::<2>::new(10).unwrap();
        for i in 0..200 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v, v + 0.5]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for i in 0..200 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v, v + 0.5]).unwrap();
        }
        let psi = d.psi().unwrap();
        for p in &psi {
            assert!(*p < 1.0e-6, "expected near-zero PSI, got {p}");
        }
    }

    #[test]
    fn shifted_distribution_raises_psi() {
        let mut d = FeatureDriftDetector::<1>::new(10).unwrap();
        for i in 0..1000 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for _ in 0..1000 {
            d.observe(&[0.95]).unwrap();
        }
        let psi = d.psi().unwrap();
        assert!(
            psi[0] > PSI_ALERT_THRESHOLD,
            "expected alert-level PSI, got {}",
            psi[0]
        );
        assert_eq!(DriftLevel::classify(psi[0]), DriftLevel::Alert);
    }

    #[test]
    fn drift_level_thresholds() {
        assert_eq!(DriftLevel::classify(0.0), DriftLevel::Stable);
        assert_eq!(DriftLevel::classify(0.09), DriftLevel::Stable);
        assert_eq!(DriftLevel::classify(0.10), DriftLevel::Watch);
        assert_eq!(DriftLevel::classify(0.24), DriftLevel::Watch);
        assert_eq!(DriftLevel::classify(0.25), DriftLevel::Alert);
        assert_eq!(DriftLevel::classify(f64::NAN), DriftLevel::Stable);
    }

    #[test]
    fn argmax_psi_none_on_zero() {
        let mut d = FeatureDriftDetector::<3>::new(10).unwrap();
        for i in 0..100 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v, v + 0.1, v + 0.2]).unwrap();
        }
        d.freeze_baseline().unwrap();
        let ap = d.argmax_psi().unwrap();
        assert!(ap.is_none());
    }

    #[test]
    fn argmax_psi_picks_drifting_dim() {
        let mut d = FeatureDriftDetector::<3>::new(10).unwrap();
        for i in 0..500 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v, v, v]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for i in 0..500 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v, 0.95, v]).unwrap();
        }
        let ap = d.argmax_psi().unwrap();
        assert_eq!(ap, Some(1));
    }

    #[test]
    fn max_psi_matches_psi_vector() {
        let mut d = FeatureDriftDetector::<2>::new(10).unwrap();
        for i in 0..300 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[v, v]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for i in 0..300 {
            let v = (i as f64 % 10.0) * 0.1;
            d.observe(&[0.95, v]).unwrap();
        }
        let all = d.psi().unwrap();
        let expected = all.iter().copied().fold(0.0_f64, f64::max);
        assert_eq!(d.max_psi().unwrap(), expected);
        assert!(expected > 0.0);
    }

    #[test]
    fn observe_rejects_nan() {
        let mut d = FeatureDriftDetector::<2>::new(10).unwrap();
        assert!(d.observe(&[f64::NAN, 0.0]).is_err());
        assert!(d.observe(&[0.0, f64::INFINITY]).is_err());
    }

    #[test]
    fn reset_production_leaves_baseline_intact() {
        let mut d = FeatureDriftDetector::<1>::new(10).unwrap();
        for i in 0..100 {
            d.observe(&[(i as f64) * 0.01]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for i in 0..100 {
            d.observe(&[(i as f64) * 0.01]).unwrap();
        }
        d.reset_production();
        assert!(d.is_baseline_frozen());
        let psi = d.psi().unwrap();
        assert!(psi[0].is_finite());
        // An empty production histogram reports zero drift.
        assert_eq!(psi[0], 0.0);
    }

    #[test]
    fn kl_is_zero_for_identical_distributions() {
        let mut d = FeatureDriftDetector::<1>::new(10).unwrap();
        for i in 0..100 {
            d.observe(&[(i % 10) as f64]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for i in 0..100 {
            d.observe(&[(i % 10) as f64]).unwrap();
        }
        assert!(d.kl_divergence().unwrap()[0].abs() < 1e-12);
    }

    #[test]
    fn kl_matches_psi_components_on_simple_drift() {
        let mut d = FeatureDriftDetector::<1>::new(10).unwrap();
        for i in 0..500 {
            d.observe(&[(i as f64 % 10.0) * 0.1]).unwrap();
        }
        d.freeze_baseline().unwrap();
        for _ in 0..500 {
            d.observe(&[0.95]).unwrap();
        }
        let kl = d.kl_divergence().unwrap();
        assert!(kl[0] > 0.0);
        assert!(kl[0].is_finite());
    }
}