1use alloc::format;
44use alloc::vec;
45use alloc::vec::Vec;
46
47#[cfg(not(feature = "std"))]
48#[allow(unused_imports)]
49use num_traits::Float;
50
51use crate::error::{RcfError, RcfResult};
52
53#[cfg(feature = "std")]
54use std::sync::Arc;
55
/// Suggested number of equal-width histogram bins per feature.
///
/// `FeatureDriftDetector::new` takes the bin count explicitly; this is a
/// convenient value to pass there.
pub const DEFAULT_NUM_BINS: usize = 10;

/// Default floor applied to every bin proportion before the PSI / KL
/// logarithms are evaluated, so an empty bin never produces `ln(0)`.
/// Used by `FeatureDriftDetector::new`.
pub const DEFAULT_SMOOTHING: f64 = 1e-4;

/// PSI values at or above this (and below [`PSI_ALERT_THRESHOLD`]) are
/// classified as `DriftLevel::Watch`.
pub const PSI_WATCH_THRESHOLD: f64 = 0.1;
/// PSI values at or above this are classified as `DriftLevel::Alert`.
pub const PSI_ALERT_THRESHOLD: f64 = 0.25;
69
/// Severity bucket for a population-stability-index (PSI) reading.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks `Stable < Watch < Alert`. See [`DriftLevel::classify`] for how a
/// PSI value is mapped onto a bucket.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DriftLevel {
    /// Lowest severity: no meaningful drift.
    Stable,
    /// Intermediate severity: drift worth keeping an eye on.
    Watch,
    /// Highest severity: drift that warrants action.
    Alert,
}
82
83impl DriftLevel {
84 #[must_use]
86 pub fn classify(psi: f64) -> Self {
87 if !psi.is_finite() || psi < PSI_WATCH_THRESHOLD {
88 Self::Stable
89 } else if psi < PSI_ALERT_THRESHOLD {
90 Self::Watch
91 } else {
92 Self::Alert
93 }
94 }
95}
96
/// Per-feature histogram drift detector for fixed-width points `[f64; D]`.
///
/// Points observed before `freeze_baseline` are buffered; freezing turns them
/// into per-feature bin edges and the baseline histogram. Points observed
/// afterwards are counted into the production histograms, which `psi` and
/// `kl_divergence` compare against the baseline.
pub struct FeatureDriftDetector<const D: usize> {
    // --- configuration -------------------------------------------------
    /// Number of equal-width bins per feature.
    num_bins: usize,
    /// Floor applied to every bin proportion before taking logarithms.
    smoothing: f64,

    // --- frozen state (all `None` until the baseline is frozen) ----------
    /// Per-feature `(min, max)` range the bins span.
    bin_edges: Option<[(f64, f64); D]>,
    /// Per-feature baseline bin counts.
    baseline: Option<Vec<Vec<u64>>>,

    // --- live state ------------------------------------------------------
    /// Per-feature production bin counts: `D` rows of `num_bins` counters.
    production: Vec<Vec<u64>>,
    /// Points buffered while the baseline is still unfrozen.
    cold_samples: Vec<[f64; D]>,
    /// Accepted observations, before and after freezing (saturating).
    observations_total: u64,

    /// Destination for the detector's counters and gauges.
    #[cfg(feature = "std")]
    metrics: Arc<dyn crate::metrics::MetricsSink>,
}
130
131impl<const D: usize> core::fmt::Debug for FeatureDriftDetector<D> {
132 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
133 let mut s = f.debug_struct("FeatureDriftDetector");
134 s.field("D", &D)
135 .field("num_bins", &self.num_bins)
136 .field("smoothing", &self.smoothing)
137 .field("baseline_frozen", &self.baseline.is_some())
138 .field("bin_edges", &self.bin_edges)
139 .field("production_buckets", &self.production.len())
140 .field("cold_samples", &self.cold_samples.len())
141 .field("observations_total", &self.observations_total);
142 #[cfg(feature = "std")]
143 s.field("metrics", &self.metrics);
144 s.finish()
145 }
146}
147
148impl<const D: usize> FeatureDriftDetector<D> {
149 pub fn new(num_bins: usize) -> RcfResult<Self> {
157 Self::with_smoothing(num_bins, DEFAULT_SMOOTHING)
158 }
159
160 pub fn with_smoothing(num_bins: usize, smoothing: f64) -> RcfResult<Self> {
168 if D == 0 {
169 return Err(RcfError::InvalidConfig(
170 "FeatureDriftDetector: D must be > 0".into(),
171 ));
172 }
173 if num_bins < 2 {
174 return Err(RcfError::InvalidConfig(
175 format!("FeatureDriftDetector: num_bins must be >= 2, got {num_bins}").into(),
176 ));
177 }
178 if !smoothing.is_finite() || smoothing <= 0.0 || smoothing > 1.0 {
179 return Err(RcfError::InvalidConfig(
180 format!("FeatureDriftDetector: smoothing must be in (0, 1], got {smoothing}")
181 .into(),
182 ));
183 }
184 Ok(Self {
185 num_bins,
186 smoothing,
187 baseline: None,
188 production: vec![vec![0; num_bins]; D],
189 bin_edges: None,
190 cold_samples: Vec::new(),
191 observations_total: 0,
192 #[cfg(feature = "std")]
193 metrics: crate::metrics::default_sink(),
194 })
195 }
196
197 #[cfg(feature = "std")]
200 #[must_use]
201 pub fn with_metrics_sink(mut self, sink: Arc<dyn crate::metrics::MetricsSink>) -> Self {
202 self.metrics = sink;
203 self
204 }
205
206 #[cfg(feature = "std")]
208 #[must_use]
209 pub fn metrics_sink(&self) -> &Arc<dyn crate::metrics::MetricsSink> {
210 &self.metrics
211 }
212
213 #[must_use]
215 pub fn is_baseline_frozen(&self) -> bool {
216 self.baseline.is_some()
217 }
218
219 #[must_use]
221 pub fn observations_total(&self) -> u64 {
222 self.observations_total
223 }
224
225 #[must_use]
227 pub fn num_bins(&self) -> usize {
228 self.num_bins
229 }
230
231 #[must_use]
234 pub fn bin_edges(&self) -> Option<&[(f64, f64); D]> {
235 self.bin_edges.as_ref()
236 }
237
238 #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
246 pub fn observe(&mut self, point: &[f64; D]) -> RcfResult<()> {
247 if !point.iter().all(|v| v.is_finite()) {
248 return Err(RcfError::NaNValue);
249 }
250 self.observations_total = self.observations_total.saturating_add(1);
251 #[cfg(feature = "std")]
252 self.metrics
253 .inc_counter(crate::metrics::names::FEATURE_DRIFT_OBSERVED_TOTAL, 1);
254
255 if let Some(edges) = self.bin_edges {
259 for (d, (min, max)) in edges.iter().enumerate() {
260 let bin = map_to_bin(point[d], *min, *max, self.num_bins);
261 self.production[d][bin] = self.production[d][bin].saturating_add(1);
262 }
263 } else {
264 self.cold_samples.push(*point);
270 }
271 Ok(())
272 }
273
274 pub fn freeze_baseline(&mut self) -> RcfResult<()> {
285 if self.cold_samples.is_empty() {
286 return Err(RcfError::EmptyForest);
287 }
288 let mut edges = [(f64::INFINITY, f64::NEG_INFINITY); D];
290 for p in &self.cold_samples {
291 for d in 0..D {
292 if p[d] < edges[d].0 {
293 edges[d].0 = p[d];
294 }
295 if p[d] > edges[d].1 {
296 edges[d].1 = p[d];
297 }
298 }
299 }
300 for pair in &mut edges {
303 #[allow(clippy::float_cmp)]
304 let collapsed = pair.0 == pair.1;
305 if collapsed {
306 pair.0 -= 0.5;
307 pair.1 += 0.5;
308 }
309 }
310
311 let mut baseline = vec![vec![0_u64; self.num_bins]; D];
313 for p in &self.cold_samples {
314 for d in 0..D {
315 let bin = map_to_bin(p[d], edges[d].0, edges[d].1, self.num_bins);
316 baseline[d][bin] = baseline[d][bin].saturating_add(1);
317 }
318 }
319
320 self.baseline = Some(baseline);
321 self.bin_edges = Some(edges);
322 self.production = vec![vec![0_u64; self.num_bins]; D];
324 self.cold_samples.clear();
325 Ok(())
326 }
327
328 pub fn reset_production(&mut self) {
331 self.production = vec![vec![0_u64; self.num_bins]; D];
332 }
333
334 #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
343 pub fn psi(&self) -> RcfResult<Vec<f64>> {
344 let baseline = self.baseline.as_ref().ok_or(RcfError::EmptyForest)?;
345 let mut out = Vec::with_capacity(D);
346 for (base, prod) in baseline.iter().zip(self.production.iter()) {
347 out.push(psi_one_dim(base, prod, self.smoothing));
348 }
349 #[cfg(feature = "std")]
350 {
351 let max_psi = out
352 .iter()
353 .copied()
354 .fold(0.0_f64, |a, b| if b > a { b } else { a });
355 self.metrics
356 .set_gauge(crate::metrics::names::FEATURE_DRIFT_MAX_PSI, max_psi);
357 }
358 Ok(out)
359 }
360
361 pub fn kl_divergence(&self) -> RcfResult<Vec<f64>> {
369 let baseline = self.baseline.as_ref().ok_or(RcfError::EmptyForest)?;
370 let mut out = Vec::with_capacity(D);
371 for (base, prod) in baseline.iter().zip(self.production.iter()) {
372 out.push(kl_one_dim(base, prod, self.smoothing));
373 }
374 Ok(out)
375 }
376
377 pub fn max_psi(&self) -> RcfResult<f64> {
384 let all = self.psi()?;
385 Ok(all
386 .iter()
387 .copied()
388 .fold(0.0_f64, |a, b| if b > a { b } else { a }))
389 }
390
391 pub fn argmax_psi(&self) -> RcfResult<Option<usize>> {
399 let all = self.psi()?;
400 let mut best = 0_usize;
401 let mut best_val = 0.0_f64;
402 for (d, v) in all.iter().enumerate() {
403 if *v > best_val {
404 best_val = *v;
405 best = d;
406 }
407 }
408 if best_val == 0.0 {
409 Ok(None)
410 } else {
411 Ok(Some(best))
412 }
413 }
414}
415
/// Maps `v` to one of `num_bins` equal-width bins spanning `[min, max]`.
///
/// * NaN and values `<= min` (including `-inf`) map to bin `0`.
/// * Values `>= max` (including `+inf`) map to the last bin, `num_bins - 1`.
/// * Interior values are placed by linear interpolation. The index is clamped
///   to the last bin to absorb floating-point rounding.
///
/// A collapsed range (`min == max`) is safe: every value is caught by one of
/// the two boundary checks, so the division never sees a zero span.
/// `num_bins` is expected to be at least 1. With `0` the result is `0`, which
/// callers must not use as an index.
fn map_to_bin(v: f64, min: f64, max: f64, num_bins: usize) -> usize {
    let last = num_bins.saturating_sub(1);
    if v.is_nan() || v <= min {
        return 0;
    }
    if v >= max {
        return last;
    }
    let span = max - min;
    let frac = if span.is_finite() {
        (v - min) / span
    } else {
        // `max - min` overflowed (range wider than `f64::MAX`). Halving every
        // operand (exact outside the subnormal range) keeps the differences
        // finite while preserving the ratio.
        (v * 0.5 - min * 0.5) / (max * 0.5 - min * 0.5)
    };
    // Float→int `as` casts saturate, and `frac` is in (0, 1) here, so the
    // result is a small non-negative index.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    let idx = (frac * num_bins as f64) as usize;
    idx.min(last)
}
434
/// Population stability index between one feature's baseline and production
/// histograms.
///
/// Both histograms are normalised to proportions, and each proportion is
/// floored at `smoothing`. The result is `Σ (q − p)·ln(q/p)`, with `p` the
/// baseline and `q` the production proportion.
///
/// Returns `0.0` when the histograms differ in length, are empty, or either
/// one has no counts.
fn psi_one_dim(baseline: &[u64], production: &[u64], smoothing: f64) -> f64 {
    let comparable = !baseline.is_empty() && baseline.len() == production.len();
    if !comparable {
        return 0.0;
    }

    #[allow(clippy::cast_precision_loss)]
    let to_f64 = |count: u64| count as f64;
    let base_total = baseline.iter().copied().map(to_f64).sum::<f64>();
    let prod_total = production.iter().copied().map(to_f64).sum::<f64>();
    if base_total <= 0.0 || prod_total <= 0.0 {
        return 0.0;
    }

    let floored = |count: u64, total: f64| (to_f64(count) / total).max(smoothing);
    baseline
        .iter()
        .zip(production)
        .map(|(&b, &p)| {
            let expected = floored(b, base_total);
            let actual = floored(p, prod_total);
            (actual - expected) * (actual / expected).ln()
        })
        .fold(0.0_f64, |acc, term| acc + term)
}
458
/// KL divergence `D_KL(production ‖ baseline)` for one feature's histograms.
///
/// Counts are normalised to proportions, and each proportion is floored at
/// `smoothing`. The result is `Σ q·ln(q/p)`, with `q` the production and `p`
/// the baseline proportion. Flooring happens after normalisation, so the
/// floored proportions need not sum to 1.
///
/// Returns `0.0` for mismatched or empty histograms, or when either histogram
/// has no counts.
fn kl_one_dim(baseline: &[u64], production: &[u64], smoothing: f64) -> f64 {
    if baseline.is_empty() || baseline.len() != production.len() {
        return 0.0;
    }

    #[allow(clippy::cast_precision_loss)]
    let as_float = |n: u64| n as f64;
    let base_total: f64 = baseline.iter().map(|&n| as_float(n)).sum();
    let prod_total: f64 = production.iter().map(|&n| as_float(n)).sum();
    if base_total <= 0.0 || prod_total <= 0.0 {
        return 0.0;
    }

    let mut divergence = 0.0_f64;
    for (&b, &p) in baseline.iter().zip(production) {
        let p_ratio = (as_float(b) / base_total).max(smoothing);
        let q_ratio = (as_float(p) / prod_total).max(smoothing);
        divergence += q_ratio * (q_ratio / p_ratio).ln();
    }
    divergence
}
481
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::cast_precision_loss,
    clippy::cast_lossless
)]
mod tests {
    use super::*;

    /// Repeating ramp `0.0, 0.1, …, 0.9, 0.0, …` used as a spread-out feature.
    fn ramp(i: i32) -> f64 {
        (f64::from(i) % 10.0) * 0.1
    }

    /// A 10-bin detector whose baseline is frozen from `point(0..n)`.
    fn frozen_with<const D: usize>(
        n: i32,
        point: impl Fn(i32) -> [f64; D],
    ) -> FeatureDriftDetector<D> {
        let mut det = FeatureDriftDetector::<D>::new(10).unwrap();
        for i in 0..n {
            det.observe(&point(i)).unwrap();
        }
        det.freeze_baseline().unwrap();
        det
    }

    #[test]
    fn new_rejects_bad_bins() {
        for bins in [0_usize, 1] {
            assert!(FeatureDriftDetector::<4>::new(bins).is_err());
        }
    }

    #[test]
    fn new_rejects_bad_smoothing() {
        for smoothing in [0.0, f64::NAN, 2.0] {
            assert!(FeatureDriftDetector::<4>::with_smoothing(10, smoothing).is_err());
        }
    }

    #[test]
    fn psi_before_freeze_errors() {
        let det = FeatureDriftDetector::<2>::new(10).unwrap();
        assert!(det.psi().is_err());
        assert!(det.kl_divergence().is_err());
    }

    #[test]
    fn identical_distribution_has_zero_psi() {
        let point = |i: i32| {
            let v = ramp(i);
            [v, v + 0.5]
        };
        let mut det = frozen_with(200, point);
        // Replay exactly the baseline stream as production traffic.
        for i in 0..200 {
            det.observe(&point(i)).unwrap();
        }
        for &p in &det.psi().unwrap() {
            assert!(p < 1.0e-6, "expected near-zero PSI, got {p}");
        }
    }

    #[test]
    fn shifted_distribution_raises_psi() {
        let mut det = frozen_with(1000, |i| [ramp(i)]);
        // All production mass collapses into the top bin.
        for _ in 0..1000 {
            det.observe(&[0.95]).unwrap();
        }
        let first = det.psi().unwrap()[0];
        assert!(
            first > PSI_ALERT_THRESHOLD,
            "expected alert-level PSI, got {}",
            first
        );
        assert_eq!(DriftLevel::classify(first), DriftLevel::Alert);
    }

    #[test]
    fn drift_level_thresholds() {
        let cases = [
            (0.0, DriftLevel::Stable),
            (0.09, DriftLevel::Stable),
            (0.10, DriftLevel::Watch),
            (0.24, DriftLevel::Watch),
            (0.25, DriftLevel::Alert),
            (f64::NAN, DriftLevel::Stable),
        ];
        for (psi, expected) in cases {
            assert_eq!(DriftLevel::classify(psi), expected);
        }
    }

    #[test]
    fn argmax_psi_none_on_zero() {
        let det = frozen_with(100, |i| {
            let v = ramp(i);
            [v, v + 0.1, v + 0.2]
        });
        // No production traffic yet, so every per-feature PSI is zero.
        assert!(det.argmax_psi().unwrap().is_none());
    }

    #[test]
    fn argmax_psi_picks_drifting_dim() {
        let mut det = frozen_with(500, |i| {
            let v = ramp(i);
            [v, v, v]
        });
        // Only the middle feature drifts.
        for i in 0..500 {
            let v = ramp(i);
            det.observe(&[v, 0.95, v]).unwrap();
        }
        assert_eq!(det.argmax_psi().unwrap(), Some(1));
    }

    #[test]
    fn observe_rejects_nan() {
        let mut det = FeatureDriftDetector::<2>::new(10).unwrap();
        assert!(det.observe(&[f64::NAN, 0.0]).is_err());
        assert!(det.observe(&[0.0, f64::INFINITY]).is_err());
    }

    #[test]
    fn reset_production_leaves_baseline_intact() {
        let linear = |i: i32| [f64::from(i) * 0.01];
        let mut det = frozen_with(100, linear);
        for i in 0..100 {
            det.observe(&linear(i)).unwrap();
        }
        det.reset_production();
        assert!(det.is_baseline_frozen());
        assert!(det.psi().unwrap()[0].is_finite());
    }

    #[test]
    fn kl_matches_psi_components_on_simple_drift() {
        let mut det = frozen_with(500, |i| [ramp(i)]);
        for _ in 0..500 {
            det.observe(&[0.95]).unwrap();
        }
        let first = det.kl_divergence().unwrap()[0];
        assert!(first > 0.0);
        assert!(first.is_finite());
    }
}