pub(crate) use super::*;
/// Only a full `Drift` verdict requests retraining; `NoDrift` and `Warning`
/// must not.
#[test]
fn test_drift_status_needs_retraining() {
    let cases = [
        (DriftStatus::NoDrift, false),
        (DriftStatus::Warning { score: 0.15 }, false),
        (DriftStatus::Drift { score: 0.25 }, true),
    ];
    for (status, should_retrain) in cases {
        assert_eq!(status.needs_retraining(), should_retrain);
    }
}
/// `score()` exposes the payload of `Warning`/`Drift` and is `None` for
/// `NoDrift`.
#[test]
fn test_drift_status_score() {
    let cases = [
        (DriftStatus::NoDrift, None),
        (DriftStatus::Warning { score: 0.15 }, Some(0.15)),
        (DriftStatus::Drift { score: 0.25 }, Some(0.25)),
    ];
    for (status, expected_score) in cases {
        assert_eq!(status.score(), expected_score);
    }
}
/// Default configuration: warning at 0.1, drift at 0.2, and at least 30
/// samples before any verdict.
#[test]
fn test_drift_config_default() {
    let DriftConfig {
        warning_threshold,
        drift_threshold,
        min_samples,
        ..
    } = DriftConfig::default();
    assert!((warning_threshold - 0.1).abs() < 1e-6);
    assert!((drift_threshold - 0.2).abs() < 1e-6);
    assert_eq!(min_samples, 30);
}
/// `new` sets both thresholds and the `with_*` builder methods override
/// `min_samples` and `window_size`.
#[test]
fn test_drift_config_builder() {
    let built = DriftConfig::new(0.15, 0.3)
        .with_min_samples(50)
        .with_window_size(200);
    let DriftConfig {
        warning_threshold,
        drift_threshold,
        min_samples,
        window_size,
        ..
    } = built;
    assert!((warning_threshold - 0.15).abs() < 1e-6);
    assert!((drift_threshold - 0.3).abs() < 1e-6);
    assert_eq!(min_samples, 50);
    assert_eq!(window_size, 200);
}
/// A uniform +0.1 shift of a 0..100 ramp must be reported as `NoDrift`.
#[test]
fn test_detector_no_drift() {
    let reference = Vector::from_slice(&(0..100).map(|i| i as f32).collect::<Vec<_>>());
    let current = Vector::from_slice(&(0..100).map(|i| (i as f32) + 0.1).collect::<Vec<_>>());
    let detector = DriftDetector::new(DriftConfig::default().with_min_samples(10));
    let status = detector.detect_univariate(&reference, &current);
    assert!(matches!(status, DriftStatus::NoDrift));
}
/// Shifting a 0..100 ramp by +50 must be reported as full `Drift`.
#[test]
fn test_detector_significant_drift() {
    let reference = Vector::from_slice(&(0..100).map(|i| i as f32).collect::<Vec<_>>());
    let current = Vector::from_slice(&(0..100).map(|i| (i + 50) as f32).collect::<Vec<_>>());
    let detector = DriftDetector::new(DriftConfig::default().with_min_samples(10));
    let status = detector.detect_univariate(&reference, &current);
    assert!(matches!(status, DriftStatus::Drift { .. }));
}
/// With only 3 samples per side and `min_samples = 30`, the detector reports
/// `NoDrift` even though the values differ by 10x.
#[test]
fn test_detector_insufficient_samples() {
    let reference = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let current = Vector::from_slice(&[10.0, 20.0, 30.0]);
    let detector = DriftDetector::new(DriftConfig::default().with_min_samples(30));
    let status = detector.detect_univariate(&reference, &current);
    assert!(matches!(status, DriftStatus::NoDrift));
}
/// Identical 50x2 reference and current matrices yield one status per column
/// and an overall `NoDrift`.
#[test]
fn test_detector_multivariate() {
    let reference =
        Matrix::from_vec(50, 2, (0..100).map(|i| i as f32).collect()).expect("valid dimensions");
    let current =
        Matrix::from_vec(50, 2, (0..100).map(|i| i as f32).collect()).expect("valid dimensions");
    let detector = DriftDetector::new(DriftConfig::default().with_min_samples(10));
    let (overall, feature_statuses) = detector.detect_multivariate(&reference, &current);
    assert_eq!(feature_statuses.len(), 2);
    assert!(matches!(overall, DriftStatus::NoDrift));
}
/// A drop in performance scores from ~0.95 to ~0.74 must raise at least a
/// `Warning` (either `Warning` or `Drift` is accepted).
#[test]
fn test_performance_drift_degradation() {
    let baseline = [0.95, 0.94, 0.96, 0.95, 0.94];
    let current = [0.75, 0.74, 0.73, 0.74, 0.75];
    let detector = DriftDetector::new(DriftConfig::default());
    let status = detector.detect_performance_drift(&baseline, &current);
    assert!(matches!(
        status,
        DriftStatus::Drift { .. } | DriftStatus::Warning { .. }
    ));
}
/// An improvement in performance scores (~0.74 -> ~0.95) is not drift.
#[test]
fn test_performance_drift_improvement() {
    let baseline = [0.75, 0.74, 0.73, 0.74, 0.75];
    let current = [0.95, 0.94, 0.96, 0.95, 0.94];
    let detector = DriftDetector::new(DriftConfig::default());
    let status = detector.detect_performance_drift(&baseline, &current);
    assert!(matches!(status, DriftStatus::NoDrift));
}
/// Observing exactly the reference values 1..=10 must not request retraining.
#[test]
fn test_rolling_monitor() {
    let config = DriftConfig::default()
        .with_min_samples(5)
        .with_window_size(10);
    let mut monitor = RollingDriftMonitor::new(config);
    monitor.set_reference(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
    for step in 1..=10u8 {
        let _ = monitor.observe(f32::from(step));
    }
    assert!(!monitor.check_drift().needs_retraining());
}
/// Filling the window with 1000.0 against a 1..=10 reference must request
/// retraining.
#[test]
fn test_rolling_monitor_drift() {
    let config = DriftConfig::default()
        .with_min_samples(5)
        .with_window_size(10);
    let mut monitor = RollingDriftMonitor::new(config);
    monitor.set_reference(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
    std::iter::repeat(1000.0).take(10).for_each(|sample| {
        let _ = monitor.observe(sample);
    });
    assert!(monitor.check_drift().needs_retraining());
}
/// After `reset_current`, previously observed outliers no longer produce a
/// verdict: the check reports `NoDrift`.
#[test]
fn test_rolling_monitor_reset() {
    let mut monitor = RollingDriftMonitor::new(DriftConfig::default().with_min_samples(5));
    monitor.set_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    (0..5).for_each(|_| {
        let _ = monitor.observe(100.0);
    });
    monitor.reset_current();
    assert!(matches!(monitor.check_drift(), DriftStatus::NoDrift));
}
/// Observations in line with the baseline do not fire the trigger, and
/// `reset` leaves it un-triggered.
#[test]
fn test_retraining_trigger() {
    let config = DriftConfig::default().with_min_samples(3);
    let mut trigger = RetrainingTrigger::new(2, config).with_consecutive_required(2);
    trigger.set_baseline_performance(&[0.95, 0.94, 0.96, 0.95, 0.94]);
    for healthy in [0.94, 0.95] {
        assert!(!trigger.observe_performance(healthy));
    }
    trigger.reset();
    assert!(!trigger.is_triggered());
}
/// A sustained collapse (~0.95 -> ~0.5) with tight thresholds must fire the
/// trigger, either on the final observation or on an earlier one.
#[test]
fn test_retraining_trigger_activation() {
    let config = DriftConfig::new(0.01, 0.02).with_min_samples(3);
    let mut trigger = RetrainingTrigger::new(1, config).with_consecutive_required(2);
    trigger.set_baseline_performance(&[0.95, 0.94, 0.96]);
    for degraded in [0.50, 0.51, 0.49] {
        trigger.observe_performance(degraded);
    }
    // Evaluate the last observation first, then fall back to the sticky flag.
    let fired_now = trigger.observe_performance(0.48);
    assert!(fired_now || trigger.is_triggered());
}
/// `mean` of an empty slice is 0; otherwise it is the arithmetic mean.
#[test]
fn test_helper_mean() {
    let of_empty = mean(&[]);
    let of_single = mean(&[5.0]);
    let of_five = mean(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    assert!(of_empty.abs() < 1e-6);
    assert!((of_single - 5.0).abs() < 1e-6);
    assert!((of_five - 3.0).abs() < 1e-6);
}
/// `std_dev` is 0 for empty and single-element input. For 1..=5 it is
/// ~1.5811 = sqrt(10 / 4), i.e. the sample (n - 1) standard deviation.
#[test]
fn test_helper_std_dev() {
    let of_empty = std_dev(&[], 0.0);
    let of_single = std_dev(&[5.0], 5.0);
    let samples = [1.0, 2.0, 3.0, 4.0, 5.0];
    let spread = std_dev(&samples, mean(&samples));
    assert!(of_empty.abs() < 1e-6);
    assert!(of_single.abs() < 1e-6);
    assert!((spread - 1.5811).abs() < 0.001);
}
/// An empty baseline yields `NoDrift` rather than a verdict or a panic.
#[test]
fn test_performance_drift_empty_baseline() {
    let verdict = DriftDetector::new(DriftConfig::default())
        .detect_performance_drift(&[], &[0.9, 0.8]);
    assert!(matches!(verdict, DriftStatus::NoDrift));
}
/// An empty current window yields `NoDrift` rather than a verdict or a panic.
#[test]
fn test_performance_drift_empty_current() {
    let verdict = DriftDetector::new(DriftConfig::default())
        .detect_performance_drift(&[0.9, 0.8], &[]);
    assert!(matches!(verdict, DriftStatus::NoDrift));
}
/// A constant baseline (zero spread) that drops from 0.5 to 0.3 must be
/// reported as `Drift`. The drop is 40% of the baseline, above the 0.2
/// drift threshold. NOTE(review): presumably the detector falls back to a
/// relative-drop score when the baseline std is zero; confirm.
#[test]
fn test_performance_drift_zero_std_relative_drop() {
    let baseline = [0.5, 0.5, 0.5, 0.5, 0.5];
    let current = [0.3, 0.3, 0.3, 0.3, 0.3];
    let detector = DriftDetector::new(DriftConfig::new(0.1, 0.2));
    let status = detector.detect_performance_drift(&baseline, &current);
    assert!(
        matches!(status, DriftStatus::Drift { .. }),
        "Expected Drift, got {:?}",
        status
    );
}
/// Identical constant baseline and current scores are `NoDrift`, even though
/// the baseline has zero spread.
#[test]
fn test_performance_drift_zero_std_no_drop() {
    let baseline = [0.9, 0.9, 0.9, 0.9, 0.9];
    let current = [0.9, 0.9, 0.9, 0.9, 0.9];
    let detector = DriftDetector::new(DriftConfig::new(0.1, 0.2));
    let status = detector.detect_performance_drift(&baseline, &current);
    assert!(matches!(status, DriftStatus::NoDrift));
}
/// With a zero-spread baseline, an improvement (0.5 -> 0.8) is still
/// `NoDrift`.
#[test]
fn test_performance_drift_zero_std_improvement() {
    let baseline = [0.5, 0.5, 0.5, 0.5, 0.5];
    let current = [0.8, 0.8, 0.8, 0.8, 0.8];
    let detector = DriftDetector::new(DriftConfig::new(0.1, 0.2));
    let status = detector.detect_performance_drift(&baseline, &current);
    assert!(matches!(status, DriftStatus::NoDrift));
}
/// Pins current behavior: with a zero-variance reference (all 5.0), even a
/// large shift to 100.0 is reported as `NoDrift`.
/// NOTE(review): presumably this guards against dividing by a zero std;
/// confirm this is the intended contract rather than a blind spot.
#[test]
fn test_univariate_zero_std_reference() {
    let reference = Vector::from_slice(&[5.0; 50]);
    let current = Vector::from_slice(&[100.0; 50]);
    let detector = DriftDetector::new(DriftConfig::default().with_min_samples(10));
    let status = detector.detect_univariate(&reference, &current);
    assert!(matches!(status, DriftStatus::NoDrift));
}
/// Drift in a single column is reported for that column only, and it
/// escalates the overall verdict to `Drift`.
#[test]
fn test_multivariate_drift_on_one_feature() {
    let mut ref_data = Vec::with_capacity(100);
    let mut cur_data = Vec::with_capacity(100);
    // Each iteration pushes one (col0, col1) row. Assuming `Matrix::from_vec`
    // is row-major, column 0 is unchanged and column 1 is shifted by +200.
    for i in 0..50 {
        let x = i as f32;
        ref_data.push(x);
        ref_data.push(x);
        cur_data.push(x);
        cur_data.push((i + 200) as f32);
    }
    let reference = Matrix::from_vec(50, 2, ref_data).expect("valid dimensions");
    let current = Matrix::from_vec(50, 2, cur_data).expect("valid dimensions");
    let detector = DriftDetector::new(DriftConfig::default().with_min_samples(10));
    let (overall, feature_statuses) = detector.detect_multivariate(&reference, &current);
    assert_eq!(feature_statuses.len(), 2);
    assert!(matches!(feature_statuses[0], DriftStatus::NoDrift));
    assert!(matches!(feature_statuses[1], DriftStatus::Drift { .. }));
    assert!(matches!(overall, DriftStatus::Drift { .. }));
}
/// After observing 10.0 five times and calling `update_reference`, the check
/// reports `NoDrift`. NOTE(review): presumably the observed values become the
/// new reference; confirm.
#[test]
fn test_rolling_monitor_update_reference() {
    let config = DriftConfig::default()
        .with_min_samples(3)
        .with_window_size(10);
    let mut monitor = RollingDriftMonitor::new(config);
    monitor.set_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    (0..5).for_each(|_| {
        let _ = monitor.observe(10.0);
    });
    monitor.update_reference();
    assert!(matches!(monitor.check_drift(), DriftStatus::NoDrift));
}
/// Observing more values (10) than `window_size` (5) must not break
/// `check_drift`.
/// NOTE(review): the assertion accepts every `DriftStatus` variant, so this is
/// effectively a no-panic smoke test; tighten it once the expected verdict
/// for a shifted, evicted window is pinned down.
#[test]
fn test_rolling_monitor_window_overflow() {
    let config = DriftConfig::default()
        .with_min_samples(3)
        .with_window_size(5);
    let mut monitor = RollingDriftMonitor::new(config);
    monitor.set_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    for sample in (0..10).map(|i| i as f32) {
        let _ = monitor.observe(sample);
    }
    let any_known_variant = matches!(
        monitor.check_drift(),
        DriftStatus::NoDrift | DriftStatus::Warning { .. } | DriftStatus::Drift { .. }
    );
    assert!(any_known_variant);
}
/// Supplying more reference values (7) than `window_size` (3) must not
/// panic. With no observations, the check reports `NoDrift`.
#[test]
fn test_rolling_monitor_set_reference_overflow() {
    let mut monitor = RollingDriftMonitor::new(DriftConfig::default().with_window_size(3));
    monitor.set_reference(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    let verdict = monitor.check_drift();
    assert!(matches!(verdict, DriftStatus::NoDrift));
}
/// Smoke test: `set_baseline_features` accepts a 5x2 matrix when the trigger
/// is built with 2. NOTE(review): no observable state is asserted; the test
/// only fails if the call panics.
#[test]
fn test_retraining_trigger_set_baseline_features() {
    let mut trigger = RetrainingTrigger::new(2, DriftConfig::default().with_min_samples(3));
    let rows = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0];
    let features = Matrix::from_vec(5, 2, rows).expect("valid dimensions");
    trigger.set_baseline_features(&features);
}
/// Smoke test: the first constructor argument (5, presumably the monitor
/// count) exceeds the matrix's 2 columns, and `set_baseline_features` must
/// still not panic. NOTE(review): nothing else is asserted.
#[test]
fn test_retraining_trigger_set_baseline_features_more_monitors() {
    let mut trigger = RetrainingTrigger::new(5, DriftConfig::default().with_min_samples(3));
    let rows = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let features = Matrix::from_vec(3, 2, rows).expect("valid dimensions");
    trigger.set_baseline_features(&features);
}
/// With 3 consecutive bad observations required, two bad ones followed by a
/// healthy one must never fire the trigger.
#[test]
fn test_retraining_trigger_consecutive_reset() {
    let config = DriftConfig::new(0.01, 0.02).with_min_samples(3);
    let mut trigger = RetrainingTrigger::new(1, config).with_consecutive_required(3);
    trigger.set_baseline_performance(&[0.95, 0.94, 0.96]);
    for collapsed in [0.1, 0.1] {
        trigger.observe_performance(collapsed);
    }
    assert!(!trigger.is_triggered());
    // A healthy score breaks the streak before it reaches 3.
    trigger.observe_performance(0.95);
    assert!(!trigger.is_triggered());
}
/// A score between the warning (0.1) and drift (0.5) thresholds classifies as
/// `Warning`, carrying the score through unchanged.
#[test]
fn test_classify_drift_warning() {
    let detector = DriftDetector::new(DriftConfig::new(0.1, 0.5));
    match detector.classify_drift(0.3) {
        DriftStatus::Warning { score } => assert!((score - 0.3).abs() < 1e-6),
        other => panic!("expected Warning, got {:?}", other),
    }
}
/// Thresholds are inclusive: exactly 0.5 is `Drift` and exactly 0.1 is
/// `Warning`.
#[test]
fn test_classify_drift_exact_threshold() {
    let detector = DriftDetector::new(DriftConfig::new(0.1, 0.5));
    let at_drift = detector.classify_drift(0.5);
    assert!(matches!(at_drift, DriftStatus::Drift { .. }));
    let at_warning = detector.classify_drift(0.1);
    assert!(matches!(at_warning, DriftStatus::Warning { .. }));
}
/// A cloned `DriftStatus` compares equal to its source.
#[test]
fn test_drift_status_clone() {
    let original = DriftStatus::Warning { score: 0.15 };
    let duplicate = original.clone();
    assert_eq!(duplicate, original);
}
/// Cloning a `DriftConfig` preserves both thresholds.
#[test]
fn test_drift_config_clone() {
    let source = DriftConfig::new(0.1, 0.3);
    let DriftConfig {
        warning_threshold,
        drift_threshold,
        ..
    } = source.clone();
    assert!((warning_threshold - 0.1).abs() < 1e-6);
    assert!((drift_threshold - 0.3).abs() < 1e-6);
}