use std::collections::HashMap;
use super::statistical::{bin_counts, chi_square_p_value, ks_p_value};
use super::types::{
CategoricalBaseline, DriftCallback, DriftResult, DriftSummary, DriftTest, Severity,
};
/// Compares incoming feature data against a stored baseline and reports per-feature drift.
///
/// Numeric features are checked with the KS and PSI tests, categorical features with the
/// chi-square test. Callbacks registered with `on_drift` are fired by the `*_and_trigger`
/// checks when at least one feature drifts.
pub struct DriftDetector {
    /// Tests run against every feature, in this order.
    tests: Vec<DriftTest>,
    /// Numeric baseline stored column-wise (one `Vec` of values per feature), once set.
    baseline: Option<Vec<Vec<f64>>>,
    /// Categorical baseline (one category histogram per feature), once set.
    baseline_categorical: Option<CategoricalBaseline>,
    /// Scales a test's threshold to derive the "warning" band next to the critical one.
    warning_multiplier: f64,
    /// Drift callbacks, invoked in registration order.
    callbacks: Vec<DriftCallback>,
}
impl DriftDetector {
    /// Default ratio between the critical and warning bands.
    ///
    /// For p-value tests (KS, chi-square) a result is a warning when
    /// `threshold <= p < threshold / multiplier`; for PSI it is a warning when
    /// `threshold * multiplier <= psi < threshold`.
    const DEFAULT_WARNING_MULTIPLIER: f64 = 0.8;

    /// Number of baseline-quantile bins used by the PSI test.
    const PSI_BINS: usize = 10;

    /// Creates a detector that runs each of `tests` against every feature.
    ///
    /// No baseline is set yet: [`Self::check`] and [`Self::check_categorical`] return an
    /// empty list until [`Self::set_baseline`] / [`Self::set_baseline_categorical`] is called.
    pub fn new(tests: Vec<DriftTest>) -> Self {
        Self {
            tests,
            baseline: None,
            baseline_categorical: None,
            warning_multiplier: Self::DEFAULT_WARNING_MULTIPLIER,
            callbacks: Vec::new(),
        }
    }

    /// Registers `callback` to be invoked by [`Self::check_and_trigger`] and
    /// [`Self::check_categorical_and_trigger`] whenever at least one result has
    /// `drifted == true` (warning or critical).
    ///
    /// Callbacks run in registration order and receive the full result list, including
    /// features that did not drift.
    pub fn on_drift<F>(&mut self, callback: F)
    where
        F: Fn(&[DriftResult]) + Send + Sync + 'static,
    {
        self.callbacks.push(Box::new(callback));
    }

    /// Runs [`Self::check`] on `current` and fires the drift callbacks if anything drifted.
    pub fn check_and_trigger(&self, current: &[Vec<f64>]) -> Vec<DriftResult> {
        let results = self.check(current);
        self.trigger_callbacks(&results);
        results
    }

    /// Runs [`Self::check_categorical`] on `current` and fires the drift callbacks if
    /// anything drifted.
    pub fn check_categorical_and_trigger(&self, current: &[Vec<usize>]) -> Vec<DriftResult> {
        let results = self.check_categorical(current);
        self.trigger_callbacks(&results);
        results
    }

    /// Invokes every registered callback with `results` if at least one result drifted.
    fn trigger_callbacks(&self, results: &[DriftResult]) {
        if results.iter().any(|r| r.drifted) {
            for callback in &self.callbacks {
                callback(results);
            }
        }
    }

    /// Stores `data` (row-major: one `Vec` per sample) as the numeric baseline.
    ///
    /// The width of the first row fixes the number of features; extra values in longer
    /// rows are ignored and shorter rows simply contribute fewer values. An empty `data`
    /// leaves any existing baseline unchanged.
    pub fn set_baseline(&mut self, data: &[Vec<f64>]) {
        if let Some(first_row) = data.first() {
            self.baseline = Some(Self::to_columns(data, first_row.len()));
        }
    }

    /// Stores `data` (row-major category ids) as the categorical baseline, kept as one
    /// category → count histogram per feature.
    ///
    /// Feature count and empty-input handling follow [`Self::set_baseline`].
    pub fn set_baseline_categorical(&mut self, data: &[Vec<usize>]) {
        if let Some(first_row) = data.first() {
            self.baseline_categorical = Some(Self::to_histograms(data, first_row.len()));
        }
    }

    /// Runs every numeric test (KS, PSI) on each feature of `current` against the baseline.
    ///
    /// Results are ordered feature-major, then in the order of the configured tests.
    /// Chi-square tests are skipped here (they apply to [`Self::check_categorical`]).
    /// Only the first `min(current[0].len(), baseline features)` features are compared.
    /// Returns an empty list when no baseline is set or `current` is empty.
    pub fn check(&self, current: &[Vec<f64>]) -> Vec<DriftResult> {
        let baseline = match &self.baseline {
            Some(b) => b,
            None => return Vec::new(),
        };
        let n_features = match current.first() {
            Some(first_row) => first_row.len().min(baseline.len()),
            None => return Vec::new(),
        };
        let current_columns = Self::to_columns(current, n_features);

        let mut results = Vec::new();
        for (feature_idx, (baseline_col, current_col)) in
            baseline.iter().zip(&current_columns).enumerate()
        {
            for test in &self.tests {
                let result = match test {
                    DriftTest::KS { threshold } => {
                        self.ks_test(feature_idx, baseline_col, current_col, *threshold)
                    }
                    DriftTest::PSI { threshold } => {
                        self.psi_test(feature_idx, baseline_col, current_col, *threshold)
                    }
                    // Categorical-only test; handled by `check_categorical`.
                    DriftTest::ChiSquare { .. } => continue,
                };
                results.push(result);
            }
        }
        results
    }

    /// Runs every chi-square test on each categorical feature of `current` against the
    /// categorical baseline; numeric tests are ignored.
    ///
    /// Feature selection and empty-input behavior mirror [`Self::check`].
    pub fn check_categorical(&self, current: &[Vec<usize>]) -> Vec<DriftResult> {
        let baseline = match &self.baseline_categorical {
            Some(b) => b,
            None => return Vec::new(),
        };
        let n_features = match current.first() {
            Some(first_row) => first_row.len().min(baseline.len()),
            None => return Vec::new(),
        };
        let current_histograms = Self::to_histograms(current, n_features);

        let mut results = Vec::new();
        for (feature_idx, (baseline_hist, current_hist)) in
            baseline.iter().zip(&current_histograms).enumerate()
        {
            for test in &self.tests {
                if let DriftTest::ChiSquare { threshold } = test {
                    results.push(self.chi_square_test(
                        feature_idx,
                        baseline_hist,
                        current_hist,
                        *threshold,
                    ));
                }
            }
        }
        results
    }

    /// Transposes row-major `rows` into `n_features` columns. Values past `n_features` in a
    /// row are ignored; short rows contribute fewer values.
    fn to_columns(rows: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
        let mut columns: Vec<Vec<f64>> =
            (0..n_features).map(|_| Vec::with_capacity(rows.len())).collect();
        for row in rows {
            for (column, &value) in columns.iter_mut().zip(row) {
                column.push(value);
            }
        }
        columns
    }

    /// Builds one category → count histogram per feature from row-major `rows`, with the
    /// same truncation rules as [`Self::to_columns`].
    fn to_histograms(rows: &[Vec<usize>], n_features: usize) -> Vec<HashMap<usize, usize>> {
        let mut histograms: Vec<HashMap<usize, usize>> = vec![HashMap::new(); n_features];
        for row in rows {
            for (histogram, &category) in histograms.iter_mut().zip(row) {
                *histogram.entry(category).or_insert(0) += 1;
            }
        }
        histograms
    }

    /// Non-drifted result with a zero statistic, used when there is nothing to compare.
    fn no_drift(feature_idx: usize, test: DriftTest, p_value: f64) -> DriftResult {
        DriftResult {
            feature: format!("feature_{feature_idx}"),
            test,
            statistic: 0.0,
            p_value,
            drifted: false,
            severity: Severity::None,
        }
    }

    /// Total order on `f64` for sorting: numeric order, with every NaN after all numbers
    /// and NaNs equal to each other.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not transitive once NaN appears, which makes
    /// the sort result unspecified (and newer std sorts may panic on such comparators).
    fn cmp_nan_last(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.partial_cmp(b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Copies `values` without NaNs and sorts them ascending.
    fn sorted_without_nan(values: &[f64]) -> Vec<f64> {
        let mut sorted: Vec<f64> = values.iter().copied().filter(|v| !v.is_nan()).collect();
        sorted.sort_unstable_by(Self::cmp_nan_last);
        sorted
    }

    /// Two-sample Kolmogorov–Smirnov statistic `D = sup_x |F_a(x) - F_b(x)|`.
    ///
    /// Both inputs must be non-empty, NaN-free and sorted ascending.
    fn ks_statistic(sorted_a: &[f64], sorted_b: &[f64]) -> f64 {
        let (len_a, len_b) = (sorted_a.len(), sorted_b.len());
        let (n_a, n_b) = (len_a as f64, len_b as f64);
        let (mut i, mut j) = (0usize, 0usize);
        let mut d_max = 0.0f64;
        while i < len_a && j < len_b {
            // Next distinct value of the merged sample. Consuming every copy of it from BOTH
            // samples before measuring keeps ties from inflating the gap. At least one index
            // advances each round because `x` equals the current head of one sample.
            let x = sorted_a[i].min(sorted_b[j]);
            while i < len_a && sorted_a[i] <= x {
                i += 1;
            }
            while j < len_b && sorted_b[j] <= x {
                j += 1;
            }
            d_max = d_max.max((i as f64 / n_a - j as f64 / n_b).abs());
        }
        // Once one sample is exhausted its ECDF is 1 and the other's only rises toward 1,
        // so the gap cannot grow further.
        d_max
    }

    /// Kolmogorov–Smirnov test of `current` against `baseline` for one numeric feature.
    ///
    /// NaN observations are discarded (they have no position in a distribution). If either
    /// side is empty afterwards, a non-drifted result with `p_value = 1.0` is returned.
    /// Otherwise `D * sqrt(n1 * n2 / (n1 + n2))` is passed to `ks_p_value` (presumably the
    /// asymptotic Kolmogorov tail — see the `statistical` module) and classified with
    /// [`Self::classify_result`].
    fn ks_test(
        &self,
        feature_idx: usize,
        baseline: &[f64],
        current: &[f64],
        threshold: f64,
    ) -> DriftResult {
        let sorted_baseline = Self::sorted_without_nan(baseline);
        let sorted_current = Self::sorted_without_nan(current);
        if sorted_baseline.is_empty() || sorted_current.is_empty() {
            return Self::no_drift(feature_idx, DriftTest::KS { threshold }, 1.0);
        }

        let d_max = Self::ks_statistic(&sorted_baseline, &sorted_current);
        let n1 = sorted_baseline.len() as f64;
        let n2 = sorted_current.len() as f64;
        let n_eff = (n1 * n2) / (n1 + n2);
        let p_value = ks_p_value(d_max * n_eff.sqrt());
        let (drifted, severity) = self.classify_result(p_value, threshold);
        DriftResult {
            feature: format!("feature_{feature_idx}"),
            test: DriftTest::KS { threshold },
            statistic: d_max,
            p_value,
            drifted,
            severity,
        }
    }

    /// Population Stability Index of `current` against `baseline` for one numeric feature.
    ///
    /// Bins are the baseline's deciles (open-ended on both sides); bin shares get a small
    /// additive smoothing so empty bins do not produce `ln(0)`. Larger PSI means more
    /// drift: `psi >= threshold` is critical, `psi >= threshold * warning_multiplier` is a
    /// warning. PSI has no p-value, so the PSI value itself is reported in `p_value`
    /// (kept for compatibility). Empty input yields a non-drifted result with PSI 0.
    fn psi_test(
        &self,
        feature_idx: usize,
        baseline: &[f64],
        current: &[f64],
        threshold: f64,
    ) -> DriftResult {
        if baseline.is_empty() || current.is_empty() {
            return Self::no_drift(feature_idx, DriftTest::PSI { threshold }, 0.0);
        }

        let mut sorted_baseline = baseline.to_vec();
        sorted_baseline.sort_by(Self::cmp_nan_last);
        let last = sorted_baseline.len() - 1;

        // Interior edges at the baseline quantiles i / PSI_BINS.
        let mut edges = Vec::with_capacity(Self::PSI_BINS + 1);
        edges.push(f64::NEG_INFINITY);
        for i in 1..Self::PSI_BINS {
            let idx = (sorted_baseline.len() * i / Self::PSI_BINS).min(last);
            edges.push(sorted_baseline[idx]);
        }
        edges.push(f64::INFINITY);

        let baseline_counts = bin_counts(baseline, &edges);
        let current_counts = bin_counts(current, &edges);
        let total_baseline = baseline.len() as f64;
        let total_current = current.len() as f64;

        let mut psi = 0.0;
        for (b_count, c_count) in baseline_counts.iter().zip(current_counts.iter()) {
            let b_pct = (*b_count as f64 + 0.0001) / (total_baseline + 0.001);
            let c_pct = (*c_count as f64 + 0.0001) / (total_current + 0.001);
            psi += (c_pct - b_pct) * (c_pct / b_pct).max(f64::MIN_POSITIVE).ln();
        }

        let (drifted, severity) = if psi >= threshold {
            (true, Severity::Critical)
        } else if psi >= threshold * self.warning_multiplier {
            (true, Severity::Warning)
        } else {
            (false, Severity::None)
        };
        DriftResult {
            feature: format!("feature_{feature_idx}"),
            test: DriftTest::PSI { threshold },
            statistic: psi,
            p_value: psi,
            drifted,
            severity,
        }
    }

    /// Chi-square goodness-of-fit test of the `current` category counts against the
    /// baseline category proportions for one categorical feature.
    ///
    /// Returns a non-drifted result with `p_value = 1.0` if either histogram is empty.
    ///
    /// NOTE(review): categories that appear in `current` but not in the baseline have an
    /// expected count of 0 and are skipped, so brand-new categories do not raise the
    /// statistic. `df` can also reach 0 when only one baseline category exists; the outcome
    /// then depends on `chi_square_p_value`. Confirm both are intended.
    fn chi_square_test(
        &self,
        feature_idx: usize,
        baseline: &HashMap<usize, usize>,
        current: &HashMap<usize, usize>,
        threshold: f64,
    ) -> DriftResult {
        let total_baseline = baseline.values().sum::<usize>() as f64;
        let total_current = current.values().sum::<usize>() as f64;
        if total_baseline == 0.0 || total_current == 0.0 {
            return Self::no_drift(feature_idx, DriftTest::ChiSquare { threshold }, 1.0);
        }

        // Sorted union of categories so the summation order is deterministic.
        let mut categories: Vec<usize> = baseline.keys().chain(current.keys()).copied().collect();
        categories.sort_unstable();
        categories.dedup();

        let mut chi_sq = 0.0;
        let mut df: usize = 0;
        for &cat in &categories {
            let observed = *current.get(&cat).unwrap_or(&0) as f64;
            let baseline_pct = *baseline.get(&cat).unwrap_or(&0) as f64 / total_baseline;
            let expected = baseline_pct * total_current;
            if expected > 0.0 {
                chi_sq += (observed - expected).powi(2) / expected;
                df += 1;
            }
        }
        // Degrees of freedom = contributing categories - 1.
        let df = df.saturating_sub(1);
        let p_value = chi_square_p_value(chi_sq, df);
        let (drifted, severity) = self.classify_result(p_value, threshold);
        DriftResult {
            feature: format!("feature_{feature_idx}"),
            test: DriftTest::ChiSquare { threshold },
            statistic: chi_sq,
            p_value,
            drifted,
            severity,
        }
    }

    /// Classifies a p-value: `p < threshold` is critical, and
    /// `p < threshold / warning_multiplier` (a slightly looser bound) is a warning. Both
    /// count as drifted.
    fn classify_result(&self, p_value: f64, threshold: f64) -> (bool, Severity) {
        if p_value < threshold {
            (true, Severity::Critical)
        } else if p_value < threshold / self.warning_multiplier {
            (true, Severity::Warning)
        } else {
            (false, Severity::None)
        }
    }

    /// Aggregates `results` into totals.
    ///
    /// `total_features` is the number of results, so a feature checked by several tests is
    /// counted once per test.
    pub fn summary(results: &[DriftResult]) -> DriftSummary {
        let total = results.len();
        let drifted = results.iter().filter(|r| r.drifted).count();
        let warnings = results.iter().filter(|r| r.severity == Severity::Warning).count();
        let critical = results.iter().filter(|r| r.severity == Severity::Critical).count();
        DriftSummary { total_features: total, drifted_features: drifted, warnings, critical }
    }
}