1use crate::error::{EvalError, EvalResult};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
/// A single observation in a drift-detection time series, optionally
/// annotated with ground-truth labels for evaluating detectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetectionEntry {
    /// Time-period index of the observation.
    pub period: u32,
    /// Observed metric value for this period.
    pub value: f64,
    /// Ground truth: `Some(true)` if drift is known present,
    /// `Some(false)` if known absent, `None` if unlabeled.
    pub ground_truth_drift: Option<bool>,
    /// Free-form drift category label (e.g. "MeanShift"), if annotated.
    pub drift_type: Option<String>,
    /// Annotated drift magnitude, if known.
    pub drift_magnitude: Option<f64>,
    /// Annotated detection-difficulty score (typically in [0, 1]), if known.
    pub detection_difficulty: Option<f64>,
}
56
57impl DriftDetectionEntry {
58 pub fn new(period: u32, value: f64, ground_truth_drift: Option<bool>) -> Self {
60 Self {
61 period,
62 value,
63 ground_truth_drift,
64 drift_type: None,
65 drift_magnitude: None,
66 detection_difficulty: None,
67 }
68 }
69
70 pub fn with_metadata(
72 period: u32,
73 value: f64,
74 ground_truth_drift: bool,
75 drift_type: impl Into<String>,
76 drift_magnitude: f64,
77 detection_difficulty: f64,
78 ) -> Self {
79 Self {
80 period,
81 value,
82 ground_truth_drift: Some(ground_truth_drift),
83 drift_type: Some(drift_type.into()),
84 drift_magnitude: Some(drift_magnitude),
85 detection_difficulty: Some(detection_difficulty),
86 }
87 }
88}
89
/// A labeled drift event spanning one or more periods, used as ground
/// truth when evaluating drift detectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledDriftEvent {
    /// Unique identifier for this event.
    pub event_id: String,
    /// Category of the drift event.
    pub event_type: DriftEventCategory,
    /// First period affected by the event.
    pub start_period: u32,
    /// Last period affected, or `None` if open-ended.
    pub end_period: Option<u32>,
    /// Names of the data fields affected by this event.
    pub affected_fields: Vec<String>,
    /// Magnitude of the drift event.
    pub magnitude: f64,
    /// Expected difficulty of detecting this event.
    pub detection_difficulty: DetectionDifficulty,
}
112
/// Categories of drift events, covering both statistical pattern changes
/// and business-level causes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DriftEventCategory {
    /// Shift in the mean of a metric.
    MeanShift,
    /// Change in variance / spread.
    VarianceChange,
    /// Change in trend direction or slope.
    TrendChange,
    /// Change in seasonal pattern.
    SeasonalityChange,
    /// Shift in categorical proportions.
    ProportionShift,
    /// Appearance of a previously unseen category.
    NewCategory,
    /// Drift caused by an organizational event.
    OrganizationalEvent,
    /// Drift caused by a regulatory change.
    RegulatoryChange,
    /// Drift caused by a technology transition.
    TechnologyTransition,
    /// Drift tied to an economic cycle.
    EconomicCycle,
    /// Gradual evolution of a business process.
    ProcessEvolution,
}
139
140impl DriftEventCategory {
141 pub fn name(&self) -> &'static str {
143 match self {
144 Self::MeanShift => "Mean Shift",
145 Self::VarianceChange => "Variance Change",
146 Self::TrendChange => "Trend Change",
147 Self::SeasonalityChange => "Seasonality Change",
148 Self::ProportionShift => "Proportion Shift",
149 Self::NewCategory => "New Category",
150 Self::OrganizationalEvent => "Organizational Event",
151 Self::RegulatoryChange => "Regulatory Change",
152 Self::TechnologyTransition => "Technology Transition",
153 Self::EconomicCycle => "Economic Cycle",
154 Self::ProcessEvolution => "Process Evolution",
155 }
156 }
157
158 pub fn is_statistical(&self) -> bool {
160 matches!(
161 self,
162 Self::MeanShift | Self::VarianceChange | Self::TrendChange | Self::SeasonalityChange
163 )
164 }
165
166 pub fn is_business_event(&self) -> bool {
168 matches!(
169 self,
170 Self::OrganizationalEvent
171 | Self::RegulatoryChange
172 | Self::TechnologyTransition
173 | Self::ProcessEvolution
174 )
175 }
176}
177
/// Expected difficulty of detecting a drift event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectionDifficulty {
    /// Obvious drift; most detectors should find it.
    Easy,
    /// Moderately subtle drift.
    Medium,
    /// Subtle drift; hard to detect.
    Hard,
}
188
189impl DetectionDifficulty {
190 pub fn to_score(&self) -> f64 {
192 match self {
193 Self::Easy => 0.0,
194 Self::Medium => 0.5,
195 Self::Hard => 1.0,
196 }
197 }
198
199 pub fn from_score(score: f64) -> Self {
201 if score < 0.33 {
202 Self::Easy
203 } else if score < 0.67 {
204 Self::Medium
205 } else {
206 Self::Hard
207 }
208 }
209}
210
/// Configurable analyzer that detects distribution drift in a time
/// series and scores detection quality against ground-truth labels.
#[derive(Debug, Clone)]
pub struct DriftDetectionAnalyzer {
    /// Significance level used to derive the z-score threshold.
    significance_level: f64,
    /// Rolling-window length for means/stds (default 10).
    window_size: usize,
    /// Minimum relative magnitude considered meaningful drift (default 0.05).
    min_magnitude_threshold: f64,
    /// Whether to compute the Hellinger distance between series halves (default true).
    use_hellinger: bool,
    /// Whether to compute the Population Stability Index (default true).
    use_psi: bool,
}
229
impl DriftDetectionAnalyzer {
    /// Creates an analyzer with the given significance level and defaults:
    /// window size 10, magnitude threshold 0.05, both distribution
    /// distance measures (Hellinger, PSI) enabled.
    pub fn new(significance_level: f64) -> Self {
        Self {
            significance_level,
            window_size: 10,
            min_magnitude_threshold: 0.05,
            use_hellinger: true,
            use_psi: true,
        }
    }

    /// Builder: sets the rolling-window size.
    /// NOTE(review): a size of 0 would make the rolling computations
    /// divide by zero — confirm callers never pass 0.
    pub fn with_window_size(mut self, size: usize) -> Self {
        self.window_size = size;
        self
    }

    /// Builder: sets the minimum drift-magnitude threshold.
    pub fn with_min_magnitude(mut self, threshold: f64) -> Self {
        self.min_magnitude_threshold = threshold;
        self
    }

    /// Builder: enables/disables the Hellinger-distance computation.
    pub fn with_hellinger(mut self, enabled: bool) -> Self {
        self.use_hellinger = enabled;
        self
    }

    /// Builder: enables/disables the PSI computation.
    pub fn with_psi(mut self, enabled: bool) -> Self {
        self.use_psi = enabled;
        self
    }

    /// Runs the full drift analysis over `entries`: rolling statistics,
    /// drift-point detection, detection metrics against ground truth,
    /// and optional distribution-distance measures.
    ///
    /// # Errors
    /// Returns `EvalError::InsufficientData` when fewer than
    /// `2 * window_size` entries are supplied.
    pub fn analyze(&self, entries: &[DriftDetectionEntry]) -> EvalResult<DriftDetectionAnalysis> {
        // Need at least two full windows to compare a baseline region
        // against a post-baseline region.
        if entries.len() < self.window_size * 2 {
            return Err(EvalError::InsufficientData {
                required: self.window_size * 2,
                actual: entries.len(),
            });
        }

        let values: Vec<f64> = entries.iter().map(|e| e.value).collect();
        let ground_truth: Vec<Option<bool>> =
            entries.iter().map(|e| e.ground_truth_drift).collect();

        let rolling_means = self.calculate_rolling_means(&values);
        let rolling_stds = self.calculate_rolling_stds(&values);

        let detected_drift = self.detect_drift_points(&rolling_means, &rolling_stds);

        let metrics = self.calculate_detection_metrics(&detected_drift, &ground_truth);

        // Optional distribution-distance measures between the first and
        // second halves of the raw series.
        let hellinger_distance = if self.use_hellinger {
            Some(self.calculate_hellinger_distance(&values))
        } else {
            None
        };

        let psi = if self.use_psi {
            Some(self.calculate_psi(&values))
        } else {
            None
        };

        let drift_detected = detected_drift.iter().any(|&d| d);
        let drift_count = detected_drift.iter().filter(|&&d| d).count();

        let drift_magnitude = self.calculate_drift_magnitude(&rolling_means);

        let passes = self.evaluate_pass_status(&metrics, drift_magnitude);
        let issues = self.collect_issues(&metrics, drift_magnitude, drift_count);

        Ok(DriftDetectionAnalysis {
            sample_size: entries.len(),
            drift_detected,
            drift_count,
            drift_magnitude,
            detection_metrics: metrics,
            hellinger_distance,
            psi,
            rolling_mean_change: self.calculate_mean_change(&rolling_means),
            rolling_std_change: self.calculate_std_change(&rolling_stds),
            passes,
            issues,
        })
    }

    /// Summarizes a set of labeled drift events: category/difficulty
    /// distributions, averages, and covered period span.
    ///
    /// Returns a vacuous-pass empty analysis when `events` is empty.
    pub fn analyze_labeled_events(
        &self,
        events: &[LabeledDriftEvent],
    ) -> EvalResult<LabeledEventAnalysis> {
        if events.is_empty() {
            return Ok(LabeledEventAnalysis::empty());
        }

        // Count events per category.
        let mut category_counts: HashMap<DriftEventCategory, usize> = HashMap::new();
        for event in events {
            *category_counts.entry(event.event_type).or_insert(0) += 1;
        }

        // Count events per difficulty level.
        let mut difficulty_counts: HashMap<DetectionDifficulty, usize> = HashMap::new();
        for event in events {
            *difficulty_counts
                .entry(event.detection_difficulty)
                .or_insert(0) += 1;
        }

        let total_events = events.len();
        let statistical_events = events
            .iter()
            .filter(|e| e.event_type.is_statistical())
            .count();
        let business_events = events
            .iter()
            .filter(|e| e.event_type.is_business_event())
            .count();

        let avg_magnitude = events.iter().map(|e| e.magnitude).sum::<f64>() / total_events as f64;
        let avg_difficulty = events
            .iter()
            .map(|e| e.detection_difficulty.to_score())
            .sum::<f64>()
            / total_events as f64;

        // Covered span: earliest start through the latest labeled end
        // (falls back to the earliest start when no event has an end).
        let min_period = events.iter().map(|e| e.start_period).min().unwrap_or(0);
        let max_period = events
            .iter()
            .filter_map(|e| e.end_period)
            .max()
            .unwrap_or(min_period);

        let passes = total_events > 0 && avg_magnitude >= self.min_magnitude_threshold;
        let issues = if !passes {
            vec!["Insufficient drift events or magnitude too low".to_string()]
        } else {
            Vec::new()
        };

        Ok(LabeledEventAnalysis {
            total_events,
            statistical_events,
            business_events,
            category_distribution: category_counts,
            difficulty_distribution: difficulty_counts,
            avg_magnitude,
            avg_difficulty,
            period_coverage: (min_period, max_period),
            passes,
            issues,
        })
    }

    /// Rolling mean over windows of `window_size`; output has
    /// `values.len() - window_size + 1` elements.
    /// NOTE(review): underflows if `values.len() < window_size`; callers
    /// must uphold the precondition enforced by `analyze`.
    fn calculate_rolling_means(&self, values: &[f64]) -> Vec<f64> {
        let mut means = Vec::with_capacity(values.len() - self.window_size + 1);
        for i in 0..=(values.len() - self.window_size) {
            let window = &values[i..i + self.window_size];
            let mean = window.iter().sum::<f64>() / self.window_size as f64;
            means.push(mean);
        }
        means
    }

    /// Rolling population standard deviation (divides by the window
    /// size, not n-1) over windows of `window_size`.
    fn calculate_rolling_stds(&self, values: &[f64]) -> Vec<f64> {
        let mut stds = Vec::with_capacity(values.len() - self.window_size + 1);
        for i in 0..=(values.len() - self.window_size) {
            let window = &values[i..i + self.window_size];
            let mean = window.iter().sum::<f64>() / self.window_size as f64;
            let variance =
                window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / self.window_size as f64;
            stds.push(variance.sqrt());
        }
        stds
    }

    /// Flags drift points in the second half of the rolling-mean series
    /// by comparing each point against baseline (first-half) statistics:
    /// a z-score test on the mean plus a variance-ratio test on the std.
    fn detect_drift_points(&self, means: &[f64], stds: &[f64]) -> Vec<bool> {
        if means.len() < 2 {
            return vec![false; means.len()];
        }

        let mut detected = vec![false; means.len()];

        // Baseline statistics come from the first half of the series.
        let baseline_end = means.len() / 2;
        let baseline_mean = means[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
        let baseline_std = if baseline_end > 1 {
            let variance = means[..baseline_end]
                .iter()
                .map(|x| (x - baseline_mean).powi(2))
                .sum::<f64>()
                / baseline_end as f64;
            // Floor the std to avoid dividing by ~0 in the z-score below.
            variance.sqrt().max(0.001)
        } else {
            0.001
        };

        for i in baseline_end..means.len() {
            // Mean-shift test: z-score of the rolling mean vs. baseline.
            let z_score = (means[i] - baseline_mean).abs() / baseline_std;
            // NOTE(review): threshold = 1.96 / sqrt(alpha) grows as alpha
            // shrinks (~8.8 at alpha = 0.05); confirm this calibration is
            // intended rather than a standard normal-quantile lookup.
            let threshold = 1.96 / self.significance_level.sqrt();
            if z_score > threshold {
                detected[i] = true;
            }

            // Variance-change test: flag if the rolling std falls outside
            // [0.5x, 2x] of the baseline average std.
            if i < stds.len() && baseline_end > 0 {
                let baseline_var_mean =
                    stds[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
                if baseline_var_mean > 0.001 {
                    let var_ratio = stds[i] / baseline_var_mean;
                    if !(0.5..=2.0).contains(&var_ratio) {
                        detected[i] = true;
                    }
                }
            }
        }

        detected
    }

    /// Scores detected drift points against optional ground-truth labels,
    /// producing confusion counts, precision/recall/F1, and the mean
    /// detection delay (periods from drift onset to first detection).
    fn calculate_detection_metrics(
        &self,
        detected: &[bool],
        ground_truth: &[Option<bool>],
    ) -> DriftDetectionMetrics {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut true_negatives = 0;
        let mut false_negatives = 0;
        let mut detection_delays = Vec::new();

        // Right-align the two series when their lengths differ (nonzero
        // only if `detected` is longer than `ground_truth`).
        let offset = detected.len().saturating_sub(ground_truth.len());

        for (i, &gt) in ground_truth.iter().enumerate() {
            let detected_idx = i + offset;
            if detected_idx >= detected.len() {
                break;
            }

            let pred = detected[detected_idx];
            match gt {
                Some(true) => {
                    if pred {
                        true_positives += 1;
                    } else {
                        false_negatives += 1;
                    }
                }
                Some(false) => {
                    if pred {
                        false_positives += 1;
                    } else {
                        true_negatives += 1;
                    }
                }
                // Unlabeled periods contribute to no confusion count.
                None => {}
            }
        }

        // Detection delay: periods elapsed between the start of a
        // labeled drift run and the first detection within it.
        let mut last_drift_start: Option<usize> = None;
        for (i, &gt) in ground_truth.iter().enumerate() {
            if gt == Some(true) && last_drift_start.is_none() {
                last_drift_start = Some(i);
            } else if gt == Some(false) {
                last_drift_start = None;
            }

            let detected_idx = i + offset;
            if detected_idx < detected.len() && detected[detected_idx] {
                if let Some(start) = last_drift_start {
                    detection_delays.push((i - start) as f64);
                    last_drift_start = None;
                }
            }
        }

        let precision = if true_positives + false_positives > 0 {
            true_positives as f64 / (true_positives + false_positives) as f64
        } else {
            0.0
        };

        let recall = if true_positives + false_negatives > 0 {
            true_positives as f64 / (true_positives + false_negatives) as f64
        } else {
            0.0
        };

        let f1_score = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        let mean_detection_delay = if detection_delays.is_empty() {
            None
        } else {
            Some(detection_delays.iter().sum::<f64>() / detection_delays.len() as f64)
        };

        DriftDetectionMetrics {
            true_positives,
            false_positives,
            true_negatives,
            false_negatives,
            precision,
            recall,
            f1_score,
            mean_detection_delay,
        }
    }

    /// Hellinger distance between 10-bin histograms of the first and
    /// second halves of `values`. Returns 0.0 for short (< 20 points)
    /// or constant series. Result is in [0, 1].
    fn calculate_hellinger_distance(&self, values: &[f64]) -> f64 {
        if values.len() < 20 {
            return 0.0;
        }

        let mid = values.len() / 2;
        let first_half = &values[..mid];
        let second_half = &values[mid..];

        // Shared bin range over the whole series.
        let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
            (min.min(v), max.max(v))
        });

        if (max_val - min_val).abs() < f64::EPSILON {
            return 0.0;
        }

        let num_bins = 10;
        let bin_width = (max_val - min_val) / num_bins as f64;

        let mut hist1 = vec![0.0; num_bins];
        let mut hist2 = vec![0.0; num_bins];

        for &v in first_half {
            let bin = ((v - min_val) / bin_width).floor() as usize;
            // Clamp so the maximum value lands in the last bin.
            let bin = bin.min(num_bins - 1);
            hist1[bin] += 1.0;
        }

        for &v in second_half {
            let bin = ((v - min_val) / bin_width).floor() as usize;
            let bin = bin.min(num_bins - 1);
            hist2[bin] += 1.0;
        }

        let sum1: f64 = hist1.iter().sum();
        let sum2: f64 = hist2.iter().sum();

        if sum1 == 0.0 || sum2 == 0.0 {
            return 0.0;
        }

        // Normalize both histograms to probability distributions.
        for h in &mut hist1 {
            *h /= sum1;
        }
        for h in &mut hist2 {
            *h /= sum2;
        }

        // H(P, Q) = sqrt( sum_i (sqrt(p_i) - sqrt(q_i))^2 / 2 ).
        let mut sum_sq_diff = 0.0;
        for i in 0..num_bins {
            let diff = hist1[i].sqrt() - hist2[i].sqrt();
            sum_sq_diff += diff * diff;
        }

        (sum_sq_diff / 2.0).sqrt()
    }

    /// Population Stability Index between 10-bin histograms of the first
    /// (baseline) and second (current) halves of `values`. Returns 0.0
    /// for short (< 20 points) or constant series.
    fn calculate_psi(&self, values: &[f64]) -> f64 {
        if values.len() < 20 {
            return 0.0;
        }

        let mid = values.len() / 2;
        let baseline = &values[..mid];
        let current = &values[mid..];

        // Shared bin range over the whole series.
        let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
            (min.min(v), max.max(v))
        });

        if (max_val - min_val).abs() < f64::EPSILON {
            return 0.0;
        }

        let num_bins = 10;
        let bin_width = (max_val - min_val) / num_bins as f64;

        let mut hist_baseline = vec![0.0; num_bins];
        let mut hist_current = vec![0.0; num_bins];

        for &v in baseline {
            let bin = ((v - min_val) / bin_width).floor() as usize;
            let bin = bin.min(num_bins - 1);
            hist_baseline[bin] += 1.0;
        }

        for &v in current {
            let bin = ((v - min_val) / bin_width).floor() as usize;
            let bin = bin.min(num_bins - 1);
            hist_current[bin] += 1.0;
        }

        // Floor proportions at epsilon so ln() and division are defined
        // for empty bins.
        let epsilon = 0.0001;
        let sum_baseline: f64 = hist_baseline.iter().sum();
        let sum_current: f64 = hist_current.iter().sum();

        if sum_baseline == 0.0 || sum_current == 0.0 {
            return 0.0;
        }

        for h in &mut hist_baseline {
            *h = (*h / sum_baseline).max(epsilon);
        }
        for h in &mut hist_current {
            *h = (*h / sum_current).max(epsilon);
        }

        // PSI = sum_i (cur_i - base_i) * ln(cur_i / base_i).
        let mut psi = 0.0;
        for i in 0..num_bins {
            let diff = hist_current[i] - hist_baseline[i];
            let ratio = hist_current[i] / hist_baseline[i];
            psi += diff * ratio.ln();
        }

        psi
    }

    /// Relative change between the average of the first and second
    /// halves of the rolling-mean series (absolute change when the
    /// first-half mean is ~0).
    fn calculate_drift_magnitude(&self, means: &[f64]) -> f64 {
        if means.len() < 2 {
            return 0.0;
        }

        let mid = means.len() / 2;
        let first_mean = means[..mid].iter().sum::<f64>() / mid as f64;
        let second_mean = means[mid..].iter().sum::<f64>() / (means.len() - mid) as f64;

        // Avoid dividing by a near-zero baseline.
        if first_mean.abs() < f64::EPSILON {
            return (second_mean - first_mean).abs();
        }

        ((second_mean - first_mean) / first_mean).abs()
    }

    /// Signed relative change from the first to the last rolling mean;
    /// 0.0 when the series is too short or starts near zero.
    fn calculate_mean_change(&self, means: &[f64]) -> f64 {
        if means.len() < 2 {
            return 0.0;
        }
        let first = means.first().unwrap_or(&0.0);
        let last = means.last().unwrap_or(&0.0);
        if first.abs() < f64::EPSILON {
            return 0.0;
        }
        (last - first) / first
    }

    /// Signed relative change from the first to the last rolling std;
    /// 0.0 when the series is too short or starts near zero.
    fn calculate_std_change(&self, stds: &[f64]) -> f64 {
        if stds.len() < 2 {
            return 0.0;
        }
        let first = stds.first().unwrap_or(&0.0);
        let last = stds.last().unwrap_or(&0.0);
        if first.abs() < f64::EPSILON {
            return 0.0;
        }
        (last - first) / first
    }

    /// Pass criterion: trivially passes when the magnitude is below the
    /// threshold (nothing meaningful to detect); otherwise requires at
    /// least one of F1 >= 0.5, precision >= 0.6, or recall >= 0.6.
    fn evaluate_pass_status(&self, metrics: &DriftDetectionMetrics, drift_magnitude: f64) -> bool {
        if drift_magnitude < self.min_magnitude_threshold {
            return true;
        }

        metrics.f1_score >= 0.5 || metrics.precision >= 0.6 || metrics.recall >= 0.6
    }

    /// Builds a human-readable issue list; only meaningful-magnitude
    /// series are critiqued on precision/recall/delay.
    fn collect_issues(
        &self,
        metrics: &DriftDetectionMetrics,
        drift_magnitude: f64,
        drift_count: usize,
    ) -> Vec<String> {
        let mut issues = Vec::new();

        if drift_magnitude >= self.min_magnitude_threshold {
            if metrics.precision < 0.5 {
                issues.push(format!(
                    "Low precision ({:.2}): many false positives",
                    metrics.precision
                ));
            }
            if metrics.recall < 0.5 {
                issues.push(format!(
                    "Low recall ({:.2}): many drift events missed",
                    metrics.recall
                ));
            }
            if let Some(delay) = metrics.mean_detection_delay {
                if delay > 3.0 {
                    issues.push(format!("High detection delay ({:.1} periods)", delay));
                }
            }
        }

        if drift_count == 0 && drift_magnitude >= self.min_magnitude_threshold {
            issues.push("No drift detected despite significant magnitude change".to_string());
        }

        issues
    }
}
772
impl Default for DriftDetectionAnalyzer {
    /// Default analyzer: 0.05 significance level with the standard
    /// defaults from [`DriftDetectionAnalyzer::new`].
    fn default() -> Self {
        Self::new(0.05)
    }
}
778
/// Result of [`DriftDetectionAnalyzer::analyze`] over a time series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetectionAnalysis {
    /// Number of entries analyzed.
    pub sample_size: usize,
    /// Whether any drift point was flagged.
    pub drift_detected: bool,
    /// Number of flagged drift points.
    pub drift_count: usize,
    /// Relative magnitude of change between series halves.
    pub drift_magnitude: f64,
    /// Confusion-matrix-based detection quality metrics.
    pub detection_metrics: DriftDetectionMetrics,
    /// Hellinger distance between series halves, if enabled.
    pub hellinger_distance: Option<f64>,
    /// Population Stability Index between series halves, if enabled.
    pub psi: Option<f64>,
    /// Relative change from first to last rolling mean.
    pub rolling_mean_change: f64,
    /// Relative change from first to last rolling std.
    pub rolling_std_change: f64,
    /// Overall pass/fail verdict.
    pub passes: bool,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
}
809
/// Confusion counts and derived scores for drift detection against
/// ground-truth labels.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DriftDetectionMetrics {
    /// Labeled-drift periods correctly flagged.
    pub true_positives: usize,
    /// Non-drift periods incorrectly flagged.
    pub false_positives: usize,
    /// Non-drift periods correctly left unflagged.
    pub true_negatives: usize,
    /// Labeled-drift periods missed.
    pub false_negatives: usize,
    /// TP / (TP + FP); 0.0 when undefined.
    pub precision: f64,
    /// TP / (TP + FN); 0.0 when undefined.
    pub recall: f64,
    /// Harmonic mean of precision and recall; 0.0 when undefined.
    pub f1_score: f64,
    /// Mean periods from drift onset to first detection, if any
    /// detections occurred inside labeled drift runs.
    pub mean_detection_delay: Option<f64>,
}
830
/// Summary produced by [`DriftDetectionAnalyzer::analyze_labeled_events`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledEventAnalysis {
    /// Total number of labeled events.
    pub total_events: usize,
    /// Events with a statistical category (mean/variance/trend/seasonality).
    pub statistical_events: usize,
    /// Events with a business-level category.
    pub business_events: usize,
    /// Event count per category.
    pub category_distribution: HashMap<DriftEventCategory, usize>,
    /// Event count per difficulty level.
    pub difficulty_distribution: HashMap<DetectionDifficulty, usize>,
    /// Mean event magnitude.
    pub avg_magnitude: f64,
    /// Mean difficulty score (Easy = 0.0 .. Hard = 1.0).
    pub avg_difficulty: f64,
    /// (earliest start period, latest labeled end period).
    pub period_coverage: (u32, u32),
    /// Overall pass/fail verdict.
    pub passes: bool,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
}
855
856impl LabeledEventAnalysis {
857 pub fn empty() -> Self {
859 Self {
860 total_events: 0,
861 statistical_events: 0,
862 business_events: 0,
863 category_distribution: HashMap::new(),
864 difficulty_distribution: HashMap::new(),
865 avg_magnitude: 0.0,
866 avg_difficulty: 0.0,
867 period_coverage: (0, 0),
868 passes: true,
869 issues: Vec::new(),
870 }
871 }
872}
873
#[cfg(test)]
mod tests {
    use super::*;

    // Plain constructor leaves all metadata fields unset.
    #[test]
    fn test_drift_detection_entry_creation() {
        let entry = DriftDetectionEntry::new(1, 100.0, Some(true));
        assert_eq!(entry.period, 1);
        assert_eq!(entry.value, 100.0);
        assert_eq!(entry.ground_truth_drift, Some(true));
    }

    // Metadata constructor populates all optional annotation fields.
    #[test]
    fn test_drift_detection_entry_with_metadata() {
        let entry = DriftDetectionEntry::with_metadata(5, 150.0, true, "MeanShift", 0.15, 0.3);
        assert_eq!(entry.period, 5);
        assert_eq!(entry.drift_type, Some("MeanShift".to_string()));
        assert_eq!(entry.drift_magnitude, Some(0.15));
        assert_eq!(entry.detection_difficulty, Some(0.3));
    }

    // Display names for representative categories.
    #[test]
    fn test_drift_event_category_names() {
        assert_eq!(DriftEventCategory::MeanShift.name(), "Mean Shift");
        assert_eq!(
            DriftEventCategory::OrganizationalEvent.name(),
            "Organizational Event"
        );
    }

    // Statistical vs. business classification is mutually exclusive
    // for these variants.
    #[test]
    fn test_drift_event_category_classification() {
        assert!(DriftEventCategory::MeanShift.is_statistical());
        assert!(!DriftEventCategory::MeanShift.is_business_event());
        assert!(DriftEventCategory::OrganizationalEvent.is_business_event());
        assert!(!DriftEventCategory::OrganizationalEvent.is_statistical());
    }

    // to_score / from_score round-trip at representative points.
    #[test]
    fn test_detection_difficulty_conversion() {
        assert_eq!(DetectionDifficulty::Easy.to_score(), 0.0);
        assert_eq!(DetectionDifficulty::Medium.to_score(), 0.5);
        assert_eq!(DetectionDifficulty::Hard.to_score(), 1.0);

        assert_eq!(
            DetectionDifficulty::from_score(0.1),
            DetectionDifficulty::Easy
        );
        assert_eq!(
            DetectionDifficulty::from_score(0.5),
            DetectionDifficulty::Medium
        );
        assert_eq!(
            DetectionDifficulty::from_score(0.8),
            DetectionDifficulty::Hard
        );
    }

    // Builder chain applies each configuration setter.
    #[test]
    fn test_analyzer_creation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05)
            .with_window_size(15)
            .with_min_magnitude(0.1)
            .with_hellinger(true)
            .with_psi(true);

        assert_eq!(analyzer.significance_level, 0.05);
        assert_eq!(analyzer.window_size, 15);
        assert_eq!(analyzer.min_magnitude_threshold, 0.1);
    }

    // A near-flat series should yield little or no detected drift.
    #[test]
    fn test_analyze_no_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);

        let entries: Vec<DriftDetectionEntry> = (0..30)
            .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 * 0.01), Some(false)))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        assert!(!result.drift_detected || result.drift_count < 5);
        assert!(result.drift_magnitude < 0.1);
    }

    // A 50% level shift halfway through must be detected.
    #[test]
    fn test_analyze_with_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);

        let mut entries: Vec<DriftDetectionEntry> = (0..15)
            .map(|i| DriftDetectionEntry::new(i, 100.0, Some(false)))
            .collect();

        for i in 15..30 {
            entries.push(DriftDetectionEntry::new(i, 150.0, Some(true)));
        }

        let result = analyzer.analyze(&entries).unwrap();
        assert!(result.drift_detected);
        assert!(result.drift_magnitude > 0.3);
    }

    // Fewer than 2 * window_size entries is an error.
    #[test]
    fn test_analyze_insufficient_data() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(10);

        let entries: Vec<DriftDetectionEntry> = (0..5)
            .map(|i| DriftDetectionEntry::new(i, 100.0, None))
            .collect();

        let result = analyzer.analyze(&entries);
        assert!(result.is_err());
    }

    // One statistical + one business event summarized correctly.
    #[test]
    fn test_analyze_labeled_events() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        let events = vec![
            LabeledDriftEvent {
                event_id: "E1".to_string(),
                event_type: DriftEventCategory::MeanShift,
                start_period: 10,
                end_period: Some(15),
                affected_fields: vec!["amount".to_string()],
                magnitude: 0.15,
                detection_difficulty: DetectionDifficulty::Easy,
            },
            LabeledDriftEvent {
                event_id: "E2".to_string(),
                event_type: DriftEventCategory::OrganizationalEvent,
                start_period: 20,
                end_period: Some(25),
                affected_fields: vec!["volume".to_string()],
                magnitude: 0.30,
                detection_difficulty: DetectionDifficulty::Medium,
            },
        ];

        let result = analyzer.analyze_labeled_events(&events).unwrap();
        assert_eq!(result.total_events, 2);
        assert_eq!(result.statistical_events, 1);
        assert_eq!(result.business_events, 1);
        assert!(result.avg_magnitude > 0.2);
        assert!(result.passes);
    }

    // An empty event set yields the vacuous-pass empty analysis.
    #[test]
    fn test_empty_labeled_events() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);
        let result = analyzer.analyze_labeled_events(&[]).unwrap();
        assert_eq!(result.total_events, 0);
        assert!(result.passes);
    }

    // Halves drawn from the same repeating pattern stay close in
    // Hellinger distance.
    #[test]
    fn test_hellinger_distance_no_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        let entries: Vec<DriftDetectionEntry> = (0..40)
            .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 % 5.0), None))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        assert!(result.hellinger_distance.unwrap() < 0.3);
    }

    // A large level shift must produce a positive PSI.
    #[test]
    fn test_psi_calculation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        let mut entries: Vec<DriftDetectionEntry> = (0..20)
            .map(|i| DriftDetectionEntry::new(i, 100.0, None))
            .collect();
        for i in 20..40 {
            entries.push(DriftDetectionEntry::new(i, 200.0, None));
        }

        let result = analyzer.analyze(&entries).unwrap();
        assert!(result.psi.is_some());
        assert!(result.psi.unwrap() > 0.0);
    }

    // Precision/recall are well-defined (non-negative) on labeled data.
    #[test]
    fn test_detection_metrics_calculation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(3);

        let mut entries = Vec::new();
        for i in 0..10 {
            entries.push(DriftDetectionEntry::new(i, 100.0, Some(false)));
        }
        for i in 10..20 {
            entries.push(DriftDetectionEntry::new(i, 200.0, Some(true)));
        }

        let result = analyzer.analyze(&entries).unwrap();

        assert!(result.detection_metrics.precision >= 0.0);
        assert!(result.detection_metrics.recall >= 0.0);
    }
}