use std::collections::HashMap;
/// Statistical test used to check a feature for distribution drift.
///
/// Every variant stores a `threshold`, which [`DriftTest::threshold`] returns
/// regardless of the variant. How that threshold is compared against a test
/// outcome is left to the caller; this type only carries the value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DriftTest {
    /// Kolmogorov-Smirnov test.
    KS { threshold: f64 },
    /// Chi-Square test.
    ChiSquare { threshold: f64 },
    /// The `PSI` test (presumably the Population Stability Index — confirm).
    PSI { threshold: f64 },
}

impl DriftTest {
    /// Human-readable name of the test, suitable for reports and logs.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::KS { .. } => "Kolmogorov-Smirnov",
            Self::ChiSquare { .. } => "Chi-Square",
            Self::PSI { .. } => "PSI",
        }
    }

    /// The threshold configured for this test, whichever variant it is.
    pub fn threshold(&self) -> f64 {
        // All variants share the same field name, so one irrefutable
        // or-pattern extracts it without a per-variant arm.
        let (Self::KS { threshold } | Self::ChiSquare { threshold } | Self::PSI { threshold }) =
            *self;
        threshold
    }
}
/// Severity level attached to a drift result.
///
/// Variants are declared from least to most severe. The derived
/// `PartialOrd`/`Ord` follow declaration order, so
/// `Severity::None < Severity::Warning < Severity::Critical`. The worst level in
/// a collection can therefore be found with `Iterator::max`. `Hash` allows
/// severities to be used as `HashMap`/`HashSet` keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Lowest level.
    None,
    /// Intermediate level.
    Warning,
    /// Highest level.
    Critical,
}
/// Result of checking one feature with one [`DriftTest`].
#[derive(Clone, Debug)]
pub struct DriftResult {
/// Name of the feature that was checked (e.g. `"age"`).
pub feature: String,
/// The test that was applied, including its configured threshold.
pub test: DriftTest,
/// Value of the test statistic.
/// NOTE(review): its scale depends on `test`; nothing here constrains it.
pub statistic: f64,
/// P-value reported for the test.
/// NOTE(review): unclear what is stored for tests that do not naturally yield
/// a p-value (e.g. `PSI`) — confirm with the code that builds this struct.
pub p_value: f64,
/// Whether drift was flagged for this feature; the decision rule lives elsewhere.
pub drifted: bool,
/// Severity classification assigned to this result.
pub severity: Severity,
}
/// Aggregate counts describing drift across a set of features.
#[derive(Clone, Debug)]
pub struct DriftSummary {
    /// Number of features covered by this summary.
    pub total_features: usize,
    /// Number of features counted as drifted.
    pub drifted_features: usize,
    /// Number of warning-level results.
    pub warnings: usize,
    /// Number of critical-level results.
    pub critical: usize,
}

impl DriftSummary {
    /// Returns `true` if at least one critical result was counted.
    pub fn has_critical(&self) -> bool {
        self.critical != 0
    }

    /// Returns `true` if at least one feature was counted as drifted.
    pub fn has_drift(&self) -> bool {
        self.drifted_features != 0
    }

    /// Percentage of drifted features: `100 * drifted_features / total_features`.
    ///
    /// Returns `0.0` when `total_features` is zero, avoiding a division by zero.
    /// The counts are not cross-checked, so the value can exceed `100.0` if
    /// `drifted_features > total_features`.
    pub fn drift_percentage(&self) -> f64 {
        let total = self.total_features;
        if total != 0 {
            let drifted = self.drifted_features as f64;
            // Multiply before dividing, matching the established rounding behavior.
            (100.0 * drifted) / total as f64
        } else {
            0.0
        }
    }
}
/// Boxed callback that receives a slice of drift results.
///
/// The `Send + Sync` bounds let the boxed closure be moved to and shared
/// between threads.
pub type DriftCallback = Box<dyn Fn(&[DriftResult]) + Send + Sync>;
/// Baseline data for categorical features: a `Vec` of `usize -> usize` maps.
///
/// NOTE(review): presumably one map per feature, keyed by category index with
/// the observed count as value — confirm against the code that builds it.
pub type CategoricalBaseline = Vec<HashMap<usize, usize>>;
/// Unit tests for the drift types and `DriftSummary` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // --- DriftTest ---------------------------------------------------------

    #[test]
    fn test_drift_test_ks_name() {
        let test = DriftTest::KS { threshold: 0.05 };
        assert_eq!(test.name(), "Kolmogorov-Smirnov");
    }

    #[test]
    fn test_drift_test_chi_square_name() {
        let test = DriftTest::ChiSquare { threshold: 0.05 };
        assert_eq!(test.name(), "Chi-Square");
    }

    #[test]
    fn test_drift_test_psi_name() {
        let test = DriftTest::PSI { threshold: 0.1 };
        assert_eq!(test.name(), "PSI");
    }

    #[test]
    fn test_drift_test_ks_threshold() {
        let test = DriftTest::KS { threshold: 0.05 };
        assert!((test.threshold() - 0.05).abs() < 1e-9);
    }

    #[test]
    fn test_drift_test_chi_square_threshold() {
        let test = DriftTest::ChiSquare { threshold: 0.01 };
        assert!((test.threshold() - 0.01).abs() < 1e-9);
    }

    #[test]
    fn test_drift_test_psi_threshold() {
        let test = DriftTest::PSI { threshold: 0.25 };
        assert!((test.threshold() - 0.25).abs() < 1e-9);
    }

    #[test]
    fn test_drift_test_equality_includes_variant_and_threshold() {
        // The derived PartialEq must distinguish both the variant and the threshold.
        let ks = DriftTest::KS { threshold: 0.05 };
        assert_eq!(ks, DriftTest::KS { threshold: 0.05 });
        assert_ne!(ks, DriftTest::KS { threshold: 0.01 });
        assert_ne!(ks, DriftTest::ChiSquare { threshold: 0.05 });
        assert_ne!(ks, DriftTest::PSI { threshold: 0.05 });
    }

    #[test]
    fn test_drift_test_clone() {
        // DriftTest is Copy: assignment duplicates the value and `test` stays usable.
        let test = DriftTest::KS { threshold: 0.05 };
        let cloned = test;
        assert_eq!(test, cloned);
    }

    #[test]
    fn test_drift_test_debug() {
        let test = DriftTest::KS { threshold: 0.05 };
        let debug_str = format!("{test:?}");
        assert!(debug_str.contains("KS"));
        assert!(debug_str.contains("threshold"));
    }

    // --- Severity ----------------------------------------------------------
    // Each variant test also checks inequality against the other variants;
    // comparing a value only with itself would pass even if variants collapsed.

    #[test]
    fn test_severity_none() {
        let sev = Severity::None;
        assert_eq!(sev, Severity::None);
        assert_ne!(sev, Severity::Warning);
        assert_ne!(sev, Severity::Critical);
    }

    #[test]
    fn test_severity_warning() {
        let sev = Severity::Warning;
        assert_eq!(sev, Severity::Warning);
        assert_ne!(sev, Severity::None);
        assert_ne!(sev, Severity::Critical);
    }

    #[test]
    fn test_severity_critical() {
        let sev = Severity::Critical;
        assert_eq!(sev, Severity::Critical);
        assert_ne!(sev, Severity::None);
        assert_ne!(sev, Severity::Warning);
    }

    #[test]
    fn test_severity_clone() {
        let sev = Severity::Warning;
        let cloned = sev;
        assert_eq!(sev, cloned);
    }

    #[test]
    fn test_severity_debug() {
        assert_eq!(format!("{:?}", Severity::None), "None");
        assert_eq!(format!("{:?}", Severity::Warning), "Warning");
        assert_eq!(format!("{:?}", Severity::Critical), "Critical");
    }

    // --- DriftSummary ------------------------------------------------------

    #[test]
    fn test_drift_summary_has_critical() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        assert!(summary.has_critical());
        let no_critical =
            DriftSummary { total_features: 10, drifted_features: 2, warnings: 2, critical: 0 };
        assert!(!no_critical.has_critical());
    }

    #[test]
    fn test_drift_summary_has_drift() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 3, critical: 0 };
        assert!(summary.has_drift());
        let no_drift =
            DriftSummary { total_features: 10, drifted_features: 0, warnings: 0, critical: 0 };
        assert!(!no_drift.has_drift());
    }

    #[test]
    fn test_drift_summary_drift_percentage() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        assert!((summary.drift_percentage() - 30.0).abs() < 1e-9);
    }

    #[test]
    fn test_drift_summary_drift_percentage_all_drifted() {
        let summary =
            DriftSummary { total_features: 4, drifted_features: 4, warnings: 4, critical: 0 };
        assert!((summary.drift_percentage() - 100.0).abs() < 1e-9);
    }

    #[test]
    fn test_drift_summary_drift_percentage_zero_features() {
        // Zero features must not divide by zero; the helper reports 0%.
        let summary =
            DriftSummary { total_features: 0, drifted_features: 0, warnings: 0, critical: 0 };
        assert!((summary.drift_percentage() - 0.0).abs() < 1e-9);
    }

    // --- Clone / Debug -----------------------------------------------------

    #[test]
    fn test_drift_result_clone() {
        let result = DriftResult {
            feature: "age".to_string(),
            test: DriftTest::KS { threshold: 0.05 },
            statistic: 0.15,
            p_value: 0.02,
            drifted: true,
            severity: Severity::Warning,
        };
        let cloned = result.clone();
        // Compare every field; floats are compared bit-for-bit since a clone is exact.
        assert_eq!(result.feature, cloned.feature);
        assert_eq!(result.test, cloned.test);
        assert_eq!(result.statistic.to_bits(), cloned.statistic.to_bits());
        assert_eq!(result.p_value.to_bits(), cloned.p_value.to_bits());
        assert_eq!(result.drifted, cloned.drifted);
        assert_eq!(result.severity, cloned.severity);
    }

    #[test]
    fn test_drift_summary_clone() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        let cloned = summary.clone();
        assert_eq!(summary.total_features, cloned.total_features);
        assert_eq!(summary.drifted_features, cloned.drifted_features);
        assert_eq!(summary.warnings, cloned.warnings);
        assert_eq!(summary.critical, cloned.critical);
    }

    #[test]
    fn test_drift_summary_debug() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        let debug_str = format!("{summary:?}");
        assert!(debug_str.contains("DriftSummary"));
        assert!(debug_str.contains("total_features"));
    }
}