1use crate::error::{EvalError, EvalResult};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
/// A single observation in a drift-detection time series, optionally
/// labeled with ground truth about whether drift occurs at this period.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetectionEntry {
    /// Time-period index of this observation.
    pub period: u32,
    /// Observed metric value for the period.
    pub value: f64,
    /// Ground-truth label: `Some(true)` if drift is known to occur here,
    /// `Some(false)` if known not to, `None` when unlabeled.
    pub ground_truth_drift: Option<bool>,
    /// Free-form drift-type label (e.g. "MeanShift"), when known.
    pub drift_type: Option<String>,
    /// Labeled magnitude of the drift, when known.
    pub drift_magnitude: Option<f64>,
    /// Labeled detection-difficulty score, when known.
    /// NOTE(review): appears to use the 0.0–1.0 scale of
    /// `DetectionDifficulty::to_score` — confirm with data producers.
    pub detection_difficulty: Option<f64>,
}
56
57impl DriftDetectionEntry {
58 pub fn new(period: u32, value: f64, ground_truth_drift: Option<bool>) -> Self {
60 Self {
61 period,
62 value,
63 ground_truth_drift,
64 drift_type: None,
65 drift_magnitude: None,
66 detection_difficulty: None,
67 }
68 }
69
70 pub fn with_metadata(
72 period: u32,
73 value: f64,
74 ground_truth_drift: bool,
75 drift_type: impl Into<String>,
76 drift_magnitude: f64,
77 detection_difficulty: f64,
78 ) -> Self {
79 Self {
80 period,
81 value,
82 ground_truth_drift: Some(ground_truth_drift),
83 drift_type: Some(drift_type.into()),
84 drift_magnitude: Some(drift_magnitude),
85 detection_difficulty: Some(detection_difficulty),
86 }
87 }
88}
89
/// A ground-truth drift event with known timing, scope, and difficulty,
/// used to evaluate how well a detector recovers labeled drift.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledDriftEvent {
    /// Unique identifier for the event.
    pub event_id: String,
    /// Category of drift this event represents.
    pub event_type: DriftEventCategory,
    /// Period at which the drift begins.
    pub start_period: u32,
    /// Period at which the drift ends; `None` for open-ended events.
    pub end_period: Option<u32>,
    /// Names of the data fields affected by the event.
    pub affected_fields: Vec<String>,
    /// Magnitude of the drift; compared against the analyzer's
    /// `min_magnitude_threshold` during labeled-event analysis.
    pub magnitude: f64,
    /// How hard the event is expected to be to detect.
    pub detection_difficulty: DetectionDifficulty,
}
112
/// Categories of drift events, spanning purely statistical changes and
/// business-driven causes.
///
/// Note: `ProportionShift`, `NewCategory`, and `EconomicCycle` are
/// classified as neither statistical nor business events by the
/// `is_statistical` / `is_business_event` predicates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DriftEventCategory {
    /// Shift in the distribution mean.
    MeanShift,
    /// Change in variance / spread.
    VarianceChange,
    /// Change in trend direction or slope.
    TrendChange,
    /// Change in seasonal pattern.
    SeasonalityChange,
    /// Shift in categorical proportions.
    ProportionShift,
    /// Appearance of a previously unseen category.
    NewCategory,
    /// Drift caused by an organizational event.
    OrganizationalEvent,
    /// Drift caused by a regulatory change.
    RegulatoryChange,
    /// Drift caused by a technology transition.
    TechnologyTransition,
    /// Drift tracking an economic cycle.
    EconomicCycle,
    /// Drift from gradual process evolution.
    ProcessEvolution,
}
139
140impl DriftEventCategory {
141 pub fn name(&self) -> &'static str {
143 match self {
144 Self::MeanShift => "Mean Shift",
145 Self::VarianceChange => "Variance Change",
146 Self::TrendChange => "Trend Change",
147 Self::SeasonalityChange => "Seasonality Change",
148 Self::ProportionShift => "Proportion Shift",
149 Self::NewCategory => "New Category",
150 Self::OrganizationalEvent => "Organizational Event",
151 Self::RegulatoryChange => "Regulatory Change",
152 Self::TechnologyTransition => "Technology Transition",
153 Self::EconomicCycle => "Economic Cycle",
154 Self::ProcessEvolution => "Process Evolution",
155 }
156 }
157
158 pub fn is_statistical(&self) -> bool {
160 matches!(
161 self,
162 Self::MeanShift | Self::VarianceChange | Self::TrendChange | Self::SeasonalityChange
163 )
164 }
165
166 pub fn is_business_event(&self) -> bool {
168 matches!(
169 self,
170 Self::OrganizationalEvent
171 | Self::RegulatoryChange
172 | Self::TechnologyTransition
173 | Self::ProcessEvolution
174 )
175 }
176}
177
/// Coarse difficulty rating for detecting a drift event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectionDifficulty {
    /// Easily detected (score 0.0).
    Easy,
    /// Moderately hard to detect (score 0.5).
    Medium,
    /// Hard to detect (score 1.0).
    Hard,
}
188
189impl DetectionDifficulty {
190 pub fn to_score(&self) -> f64 {
192 match self {
193 Self::Easy => 0.0,
194 Self::Medium => 0.5,
195 Self::Hard => 1.0,
196 }
197 }
198
199 pub fn from_score(score: f64) -> Self {
201 if score < 0.33 {
202 Self::Easy
203 } else if score < 0.67 {
204 Self::Medium
205 } else {
206 Self::Hard
207 }
208 }
209}
210
/// Configurable analyzer that detects distribution drift in a value series
/// and scores the detections against optional ground-truth labels.
#[derive(Debug, Clone)]
pub struct DriftDetectionAnalyzer {
    /// Statistical significance level (e.g. 0.05); used to derive the
    /// z-score threshold in drift-point detection.
    significance_level: f64,
    /// Rolling-window length for mean/std computations (default 10).
    window_size: usize,
    /// Minimum relative magnitude below which drift is not expected
    /// (default 0.05).
    min_magnitude_threshold: f64,
    /// Whether to compute the Hellinger distance between series halves.
    use_hellinger: bool,
    /// Whether to compute the Population Stability Index (PSI).
    use_psi: bool,
}
229
230impl DriftDetectionAnalyzer {
231 pub fn new(significance_level: f64) -> Self {
233 Self {
234 significance_level,
235 window_size: 10,
236 min_magnitude_threshold: 0.05,
237 use_hellinger: true,
238 use_psi: true,
239 }
240 }
241
242 pub fn with_window_size(mut self, size: usize) -> Self {
244 self.window_size = size;
245 self
246 }
247
248 pub fn with_min_magnitude(mut self, threshold: f64) -> Self {
250 self.min_magnitude_threshold = threshold;
251 self
252 }
253
254 pub fn with_hellinger(mut self, enabled: bool) -> Self {
256 self.use_hellinger = enabled;
257 self
258 }
259
260 pub fn with_psi(mut self, enabled: bool) -> Self {
262 self.use_psi = enabled;
263 self
264 }
265
266 pub fn analyze(&self, entries: &[DriftDetectionEntry]) -> EvalResult<DriftDetectionAnalysis> {
268 if entries.len() < self.window_size * 2 {
269 return Err(EvalError::InsufficientData {
270 required: self.window_size * 2,
271 actual: entries.len(),
272 });
273 }
274
275 let values: Vec<f64> = entries.iter().map(|e| e.value).collect();
277 let ground_truth: Vec<Option<bool>> =
278 entries.iter().map(|e| e.ground_truth_drift).collect();
279
280 let rolling_means = self.calculate_rolling_means(&values);
282 let rolling_stds = self.calculate_rolling_stds(&values);
283
284 let detected_drift = self.detect_drift_points(&rolling_means, &rolling_stds);
286
287 let metrics = self.calculate_detection_metrics(&detected_drift, &ground_truth);
289
290 let hellinger_distance = if self.use_hellinger {
292 Some(self.calculate_hellinger_distance(&values))
293 } else {
294 None
295 };
296
297 let psi = if self.use_psi {
298 Some(self.calculate_psi(&values))
299 } else {
300 None
301 };
302
303 let drift_detected = detected_drift.iter().any(|&d| d);
305 let drift_count = detected_drift.iter().filter(|&&d| d).count();
306
307 let drift_magnitude = self.calculate_drift_magnitude(&rolling_means);
309
310 let passes = self.evaluate_pass_status(&metrics, drift_magnitude);
311 let issues = self.collect_issues(&metrics, drift_magnitude, drift_count);
312
313 Ok(DriftDetectionAnalysis {
314 sample_size: entries.len(),
315 drift_detected,
316 drift_count,
317 drift_magnitude,
318 detection_metrics: metrics,
319 hellinger_distance,
320 psi,
321 rolling_mean_change: self.calculate_mean_change(&rolling_means),
322 rolling_std_change: self.calculate_std_change(&rolling_stds),
323 passes,
324 issues,
325 })
326 }
327
328 pub fn analyze_labeled_events(
330 &self,
331 events: &[LabeledDriftEvent],
332 ) -> EvalResult<LabeledEventAnalysis> {
333 if events.is_empty() {
334 return Ok(LabeledEventAnalysis::empty());
335 }
336
337 let mut category_counts: HashMap<DriftEventCategory, usize> = HashMap::new();
339 for event in events {
340 *category_counts.entry(event.event_type).or_insert(0) += 1;
341 }
342
343 let mut difficulty_counts: HashMap<DetectionDifficulty, usize> = HashMap::new();
345 for event in events {
346 *difficulty_counts
347 .entry(event.detection_difficulty)
348 .or_insert(0) += 1;
349 }
350
351 let total_events = events.len();
353 let statistical_events = events
354 .iter()
355 .filter(|e| e.event_type.is_statistical())
356 .count();
357 let business_events = events
358 .iter()
359 .filter(|e| e.event_type.is_business_event())
360 .count();
361
362 let avg_magnitude = events.iter().map(|e| e.magnitude).sum::<f64>() / total_events as f64;
364 let avg_difficulty = events
365 .iter()
366 .map(|e| e.detection_difficulty.to_score())
367 .sum::<f64>()
368 / total_events as f64;
369
370 let min_period = events.iter().map(|e| e.start_period).min().unwrap_or(0);
372 let max_period = events
373 .iter()
374 .filter_map(|e| e.end_period)
375 .max()
376 .unwrap_or(min_period);
377
378 let passes = total_events > 0 && avg_magnitude >= self.min_magnitude_threshold;
379 let issues = if !passes {
380 vec!["Insufficient drift events or magnitude too low".to_string()]
381 } else {
382 Vec::new()
383 };
384
385 Ok(LabeledEventAnalysis {
386 total_events,
387 statistical_events,
388 business_events,
389 category_distribution: category_counts,
390 difficulty_distribution: difficulty_counts,
391 avg_magnitude,
392 avg_difficulty,
393 period_coverage: (min_period, max_period),
394 passes,
395 issues,
396 })
397 }
398
399 fn calculate_rolling_means(&self, values: &[f64]) -> Vec<f64> {
402 if values.len() < self.window_size {
403 tracing::debug!(
404 "Drift detection: not enough values ({}) for window size ({}), returning empty",
405 values.len(),
406 self.window_size
407 );
408 return Vec::new();
409 }
410 let mut means = Vec::with_capacity(values.len() - self.window_size + 1);
411 for i in 0..=(values.len() - self.window_size) {
412 let window = &values[i..i + self.window_size];
413 let mean = window.iter().sum::<f64>() / self.window_size as f64;
414 means.push(mean);
415 }
416 means
417 }
418
419 fn calculate_rolling_stds(&self, values: &[f64]) -> Vec<f64> {
420 if values.len() < self.window_size {
421 tracing::debug!(
422 "Drift detection: not enough values ({}) for window size ({}), returning empty",
423 values.len(),
424 self.window_size
425 );
426 return Vec::new();
427 }
428 let mut stds = Vec::with_capacity(values.len() - self.window_size + 1);
429 for i in 0..=(values.len() - self.window_size) {
430 let window = &values[i..i + self.window_size];
431 let mean = window.iter().sum::<f64>() / self.window_size as f64;
432 let variance =
433 window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / self.window_size as f64;
434 stds.push(variance.sqrt());
435 }
436 stds
437 }
438
439 fn detect_drift_points(&self, means: &[f64], stds: &[f64]) -> Vec<bool> {
440 if means.len() < 2 {
441 return vec![false; means.len()];
442 }
443
444 let mut detected = vec![false; means.len()];
445
446 let baseline_end = means.len() / 2;
448 let baseline_mean = means[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
449 let baseline_std = if baseline_end > 1 {
450 let variance = means[..baseline_end]
451 .iter()
452 .map(|x| (x - baseline_mean).powi(2))
453 .sum::<f64>()
454 / baseline_end as f64;
455 variance.sqrt().max(0.001) } else {
457 0.001
458 };
459
460 for i in baseline_end..means.len() {
462 let z_score = (means[i] - baseline_mean).abs() / baseline_std;
463 let threshold = 1.96 / self.significance_level.sqrt(); if z_score > threshold {
466 detected[i] = true;
467 }
468
469 if i < stds.len() && baseline_end > 0 {
471 let baseline_var_mean =
472 stds[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
473 if baseline_var_mean > 0.001 {
474 let var_ratio = stds[i] / baseline_var_mean;
475 if !(0.5..=2.0).contains(&var_ratio) {
476 detected[i] = true;
477 }
478 }
479 }
480 }
481
482 detected
483 }
484
485 fn calculate_detection_metrics(
486 &self,
487 detected: &[bool],
488 ground_truth: &[Option<bool>],
489 ) -> DriftDetectionMetrics {
490 let mut true_positives = 0;
491 let mut false_positives = 0;
492 let mut true_negatives = 0;
493 let mut false_negatives = 0;
494 let mut detection_delays = Vec::new();
495
496 let offset = detected.len().saturating_sub(ground_truth.len());
498
499 for (i, >) in ground_truth.iter().enumerate() {
500 let detected_idx = i + offset;
501 if detected_idx >= detected.len() {
502 break;
503 }
504
505 let pred = detected[detected_idx];
506 match gt {
507 Some(true) => {
508 if pred {
509 true_positives += 1;
510 } else {
511 false_negatives += 1;
512 }
513 }
514 Some(false) => {
515 if pred {
516 false_positives += 1;
517 } else {
518 true_negatives += 1;
519 }
520 }
521 None => {}
522 }
523 }
524
525 let mut last_drift_start: Option<usize> = None;
527 for (i, >) in ground_truth.iter().enumerate() {
528 if gt == Some(true) && last_drift_start.is_none() {
529 last_drift_start = Some(i);
530 } else if gt == Some(false) {
531 last_drift_start = None;
532 }
533
534 let detected_idx = i + offset;
535 if detected_idx < detected.len() && detected[detected_idx] {
536 if let Some(start) = last_drift_start {
537 detection_delays.push((i - start) as f64);
538 last_drift_start = None;
539 }
540 }
541 }
542
543 let precision = if true_positives + false_positives > 0 {
544 true_positives as f64 / (true_positives + false_positives) as f64
545 } else {
546 0.0
547 };
548
549 let recall = if true_positives + false_negatives > 0 {
550 true_positives as f64 / (true_positives + false_negatives) as f64
551 } else {
552 0.0
553 };
554
555 let f1_score = if precision + recall > 0.0 {
556 2.0 * precision * recall / (precision + recall)
557 } else {
558 0.0
559 };
560
561 let mean_detection_delay = if detection_delays.is_empty() {
562 None
563 } else {
564 Some(detection_delays.iter().sum::<f64>() / detection_delays.len() as f64)
565 };
566
567 DriftDetectionMetrics {
568 true_positives,
569 false_positives,
570 true_negatives,
571 false_negatives,
572 precision,
573 recall,
574 f1_score,
575 mean_detection_delay,
576 }
577 }
578
579 fn calculate_hellinger_distance(&self, values: &[f64]) -> f64 {
580 if values.len() < 20 {
581 return 0.0;
582 }
583
584 let mid = values.len() / 2;
585 let first_half = &values[..mid];
586 let second_half = &values[mid..];
587
588 let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
590 (min.min(v), max.max(v))
591 });
592
593 if (max_val - min_val).abs() < f64::EPSILON {
594 return 0.0;
595 }
596
597 let num_bins = 10;
598 let bin_width = (max_val - min_val) / num_bins as f64;
599
600 let mut hist1 = vec![0.0; num_bins];
601 let mut hist2 = vec![0.0; num_bins];
602
603 for &v in first_half {
604 let bin = ((v - min_val) / bin_width).floor() as usize;
605 let bin = bin.min(num_bins - 1);
606 hist1[bin] += 1.0;
607 }
608
609 for &v in second_half {
610 let bin = ((v - min_val) / bin_width).floor() as usize;
611 let bin = bin.min(num_bins - 1);
612 hist2[bin] += 1.0;
613 }
614
615 let sum1: f64 = hist1.iter().sum();
617 let sum2: f64 = hist2.iter().sum();
618
619 if sum1 == 0.0 || sum2 == 0.0 {
620 return 0.0;
621 }
622
623 for h in &mut hist1 {
624 *h /= sum1;
625 }
626 for h in &mut hist2 {
627 *h /= sum2;
628 }
629
630 let mut sum_sq_diff = 0.0;
632 for i in 0..num_bins {
633 let diff = hist1[i].sqrt() - hist2[i].sqrt();
634 sum_sq_diff += diff * diff;
635 }
636
637 (sum_sq_diff / 2.0).sqrt()
638 }
639
640 fn calculate_psi(&self, values: &[f64]) -> f64 {
641 if values.len() < 20 {
642 return 0.0;
643 }
644
645 let mid = values.len() / 2;
646 let baseline = &values[..mid];
647 let current = &values[mid..];
648
649 let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
651 (min.min(v), max.max(v))
652 });
653
654 if (max_val - min_val).abs() < f64::EPSILON {
655 return 0.0;
656 }
657
658 let num_bins = 10;
659 let bin_width = (max_val - min_val) / num_bins as f64;
660
661 let mut hist_baseline = vec![0.0; num_bins];
662 let mut hist_current = vec![0.0; num_bins];
663
664 for &v in baseline {
665 let bin = ((v - min_val) / bin_width).floor() as usize;
666 let bin = bin.min(num_bins - 1);
667 hist_baseline[bin] += 1.0;
668 }
669
670 for &v in current {
671 let bin = ((v - min_val) / bin_width).floor() as usize;
672 let bin = bin.min(num_bins - 1);
673 hist_current[bin] += 1.0;
674 }
675
676 let epsilon = 0.0001;
678 let sum_baseline: f64 = hist_baseline.iter().sum();
679 let sum_current: f64 = hist_current.iter().sum();
680
681 if sum_baseline == 0.0 || sum_current == 0.0 {
682 return 0.0;
683 }
684
685 for h in &mut hist_baseline {
686 *h = (*h / sum_baseline).max(epsilon);
687 }
688 for h in &mut hist_current {
689 *h = (*h / sum_current).max(epsilon);
690 }
691
692 let mut psi = 0.0;
694 for i in 0..num_bins {
695 let diff = hist_current[i] - hist_baseline[i];
696 let ratio = hist_current[i] / hist_baseline[i];
697 psi += diff * ratio.ln();
698 }
699
700 psi
701 }
702
703 fn calculate_drift_magnitude(&self, means: &[f64]) -> f64 {
704 if means.len() < 2 {
705 return 0.0;
706 }
707
708 let mid = means.len() / 2;
709 let first_mean = means[..mid].iter().sum::<f64>() / mid as f64;
710 let second_mean = means[mid..].iter().sum::<f64>() / (means.len() - mid) as f64;
711
712 if first_mean.abs() < f64::EPSILON {
713 return (second_mean - first_mean).abs();
714 }
715
716 ((second_mean - first_mean) / first_mean).abs()
717 }
718
719 fn calculate_mean_change(&self, means: &[f64]) -> f64 {
720 if means.len() < 2 {
721 return 0.0;
722 }
723 let first = means.first().unwrap_or(&0.0);
724 let last = means.last().unwrap_or(&0.0);
725 if first.abs() < f64::EPSILON {
726 return 0.0;
727 }
728 (last - first) / first
729 }
730
731 fn calculate_std_change(&self, stds: &[f64]) -> f64 {
732 if stds.len() < 2 {
733 return 0.0;
734 }
735 let first = stds.first().unwrap_or(&0.0);
736 let last = stds.last().unwrap_or(&0.0);
737 if first.abs() < f64::EPSILON {
738 return 0.0;
739 }
740 (last - first) / first
741 }
742
743 fn evaluate_pass_status(&self, metrics: &DriftDetectionMetrics, drift_magnitude: f64) -> bool {
744 if drift_magnitude < self.min_magnitude_threshold {
746 return true; }
748
749 metrics.f1_score >= 0.5 || metrics.precision >= 0.6 || metrics.recall >= 0.6
751 }
752
753 fn collect_issues(
754 &self,
755 metrics: &DriftDetectionMetrics,
756 drift_magnitude: f64,
757 drift_count: usize,
758 ) -> Vec<String> {
759 let mut issues = Vec::new();
760
761 if drift_magnitude >= self.min_magnitude_threshold {
762 if metrics.precision < 0.5 {
763 issues.push(format!(
764 "Low precision ({:.2}): many false positives",
765 metrics.precision
766 ));
767 }
768 if metrics.recall < 0.5 {
769 issues.push(format!(
770 "Low recall ({:.2}): many drift events missed",
771 metrics.recall
772 ));
773 }
774 if let Some(delay) = metrics.mean_detection_delay {
775 if delay > 3.0 {
776 issues.push(format!("High detection delay ({:.1} periods)", delay));
777 }
778 }
779 }
780
781 if drift_count == 0 && drift_magnitude >= self.min_magnitude_threshold {
782 issues.push("No drift detected despite significant magnitude change".to_string());
783 }
784
785 issues
786 }
787}
788
/// Default analyzer: significance level 0.05 plus `new`'s other defaults
/// (window size 10, magnitude threshold 0.05, Hellinger and PSI enabled).
impl Default for DriftDetectionAnalyzer {
    fn default() -> Self {
        // 0.05 is the conventional statistical significance level.
        Self::new(0.05)
    }
}
794
/// Result of `DriftDetectionAnalyzer::analyze` over a series of entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetectionAnalysis {
    /// Number of entries analyzed.
    pub sample_size: usize,
    /// True if any drift point was detected.
    pub drift_detected: bool,
    /// Number of detected drift points.
    pub drift_count: usize,
    /// Relative change in rolling mean between the two series halves.
    pub drift_magnitude: f64,
    /// Confusion-matrix metrics against ground-truth labels.
    pub detection_metrics: DriftDetectionMetrics,
    /// Hellinger distance between the two halves, if enabled.
    pub hellinger_distance: Option<f64>,
    /// Population Stability Index between the two halves, if enabled.
    pub psi: Option<f64>,
    /// Relative change from the first to the last rolling mean.
    pub rolling_mean_change: f64,
    /// Relative change from the first to the last rolling std.
    pub rolling_std_change: f64,
    /// Overall pass/fail verdict for the detection quality check.
    pub passes: bool,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
}
825
/// Confusion-matrix metrics comparing detected drift points against
/// ground-truth labels.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DriftDetectionMetrics {
    /// Labeled-drift periods where drift was detected.
    pub true_positives: usize,
    /// Non-drift periods flagged as drift.
    pub false_positives: usize,
    /// Non-drift periods correctly left unflagged.
    pub true_negatives: usize,
    /// Labeled-drift periods that were missed.
    pub false_negatives: usize,
    /// TP / (TP + FP); 0.0 when the denominator is zero.
    pub precision: f64,
    /// TP / (TP + FN); 0.0 when the denominator is zero.
    pub recall: f64,
    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub f1_score: f64,
    /// Mean periods between drift onset and detection; `None` if no delay
    /// was measured.
    pub mean_detection_delay: Option<f64>,
}
846
/// Summary statistics over a collection of labeled drift events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledEventAnalysis {
    /// Total number of labeled events.
    pub total_events: usize,
    /// Events with a statistical category (mean/variance/trend/seasonality).
    pub statistical_events: usize,
    /// Events with a business-event category.
    pub business_events: usize,
    /// Event counts per category.
    pub category_distribution: HashMap<DriftEventCategory, usize>,
    /// Event counts per detection difficulty.
    pub difficulty_distribution: HashMap<DetectionDifficulty, usize>,
    /// Mean event magnitude.
    pub avg_magnitude: f64,
    /// Mean difficulty score (0.0 easy .. 1.0 hard).
    pub avg_difficulty: f64,
    /// (earliest start period, latest labeled end period).
    pub period_coverage: (u32, u32),
    /// Whether the event set passes the magnitude check.
    pub passes: bool,
    /// Problems found with the event set.
    pub issues: Vec<String>,
}
871
872impl LabeledEventAnalysis {
873 pub fn empty() -> Self {
875 Self {
876 total_events: 0,
877 statistical_events: 0,
878 business_events: 0,
879 category_distribution: HashMap::new(),
880 difficulty_distribution: HashMap::new(),
881 avg_magnitude: 0.0,
882 avg_difficulty: 0.0,
883 period_coverage: (0, 0),
884 passes: true,
885 issues: Vec::new(),
886 }
887 }
888}
889
890#[cfg(test)]
895#[allow(clippy::unwrap_used)]
896mod tests {
897 use super::*;
898
899 #[test]
900 fn test_drift_detection_entry_creation() {
901 let entry = DriftDetectionEntry::new(1, 100.0, Some(true));
902 assert_eq!(entry.period, 1);
903 assert_eq!(entry.value, 100.0);
904 assert_eq!(entry.ground_truth_drift, Some(true));
905 }
906
907 #[test]
908 fn test_drift_detection_entry_with_metadata() {
909 let entry = DriftDetectionEntry::with_metadata(5, 150.0, true, "MeanShift", 0.15, 0.3);
910 assert_eq!(entry.period, 5);
911 assert_eq!(entry.drift_type, Some("MeanShift".to_string()));
912 assert_eq!(entry.drift_magnitude, Some(0.15));
913 assert_eq!(entry.detection_difficulty, Some(0.3));
914 }
915
916 #[test]
917 fn test_drift_event_category_names() {
918 assert_eq!(DriftEventCategory::MeanShift.name(), "Mean Shift");
919 assert_eq!(
920 DriftEventCategory::OrganizationalEvent.name(),
921 "Organizational Event"
922 );
923 }
924
925 #[test]
926 fn test_drift_event_category_classification() {
927 assert!(DriftEventCategory::MeanShift.is_statistical());
928 assert!(!DriftEventCategory::MeanShift.is_business_event());
929 assert!(DriftEventCategory::OrganizationalEvent.is_business_event());
930 assert!(!DriftEventCategory::OrganizationalEvent.is_statistical());
931 }
932
933 #[test]
934 fn test_detection_difficulty_conversion() {
935 assert_eq!(DetectionDifficulty::Easy.to_score(), 0.0);
936 assert_eq!(DetectionDifficulty::Medium.to_score(), 0.5);
937 assert_eq!(DetectionDifficulty::Hard.to_score(), 1.0);
938
939 assert_eq!(
940 DetectionDifficulty::from_score(0.1),
941 DetectionDifficulty::Easy
942 );
943 assert_eq!(
944 DetectionDifficulty::from_score(0.5),
945 DetectionDifficulty::Medium
946 );
947 assert_eq!(
948 DetectionDifficulty::from_score(0.8),
949 DetectionDifficulty::Hard
950 );
951 }
952
953 #[test]
954 fn test_analyzer_creation() {
955 let analyzer = DriftDetectionAnalyzer::new(0.05)
956 .with_window_size(15)
957 .with_min_magnitude(0.1)
958 .with_hellinger(true)
959 .with_psi(true);
960
961 assert_eq!(analyzer.significance_level, 0.05);
962 assert_eq!(analyzer.window_size, 15);
963 assert_eq!(analyzer.min_magnitude_threshold, 0.1);
964 }
965
966 #[test]
967 fn test_analyze_no_drift() {
968 let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);
969
970 let entries: Vec<DriftDetectionEntry> = (0..30)
972 .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 * 0.01), Some(false)))
973 .collect();
974
975 let result = analyzer.analyze(&entries).unwrap();
976 assert!(!result.drift_detected || result.drift_count < 5);
977 assert!(result.drift_magnitude < 0.1);
978 }
979
980 #[test]
981 fn test_analyze_with_drift() {
982 let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);
983
984 let mut entries: Vec<DriftDetectionEntry> = (0..15)
986 .map(|i| DriftDetectionEntry::new(i, 100.0, Some(false)))
987 .collect();
988
989 for i in 15..30 {
991 entries.push(DriftDetectionEntry::new(i, 150.0, Some(true)));
992 }
993
994 let result = analyzer.analyze(&entries).unwrap();
995 assert!(result.drift_detected);
996 assert!(result.drift_magnitude > 0.3);
997 }
998
999 #[test]
1000 fn test_analyze_insufficient_data() {
1001 let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(10);
1002
1003 let entries: Vec<DriftDetectionEntry> = (0..5)
1004 .map(|i| DriftDetectionEntry::new(i, 100.0, None))
1005 .collect();
1006
1007 let result = analyzer.analyze(&entries);
1008 assert!(result.is_err());
1009 }
1010
1011 #[test]
1012 fn test_analyze_labeled_events() {
1013 let analyzer = DriftDetectionAnalyzer::new(0.05);
1014
1015 let events = vec![
1016 LabeledDriftEvent {
1017 event_id: "E1".to_string(),
1018 event_type: DriftEventCategory::MeanShift,
1019 start_period: 10,
1020 end_period: Some(15),
1021 affected_fields: vec!["amount".to_string()],
1022 magnitude: 0.15,
1023 detection_difficulty: DetectionDifficulty::Easy,
1024 },
1025 LabeledDriftEvent {
1026 event_id: "E2".to_string(),
1027 event_type: DriftEventCategory::OrganizationalEvent,
1028 start_period: 20,
1029 end_period: Some(25),
1030 affected_fields: vec!["volume".to_string()],
1031 magnitude: 0.30,
1032 detection_difficulty: DetectionDifficulty::Medium,
1033 },
1034 ];
1035
1036 let result = analyzer.analyze_labeled_events(&events).unwrap();
1037 assert_eq!(result.total_events, 2);
1038 assert_eq!(result.statistical_events, 1);
1039 assert_eq!(result.business_events, 1);
1040 assert!(result.avg_magnitude > 0.2);
1041 assert!(result.passes);
1042 }
1043
1044 #[test]
1045 fn test_empty_labeled_events() {
1046 let analyzer = DriftDetectionAnalyzer::new(0.05);
1047 let result = analyzer.analyze_labeled_events(&[]).unwrap();
1048 assert_eq!(result.total_events, 0);
1049 assert!(result.passes);
1050 }
1051
1052 #[test]
1053 fn test_hellinger_distance_no_drift() {
1054 let analyzer = DriftDetectionAnalyzer::new(0.05);
1055
1056 let entries: Vec<DriftDetectionEntry> = (0..40)
1058 .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 % 5.0), None))
1059 .collect();
1060
1061 let result = analyzer.analyze(&entries).unwrap();
1062 assert!(result.hellinger_distance.unwrap() < 0.3);
1063 }
1064
1065 #[test]
1066 fn test_psi_calculation() {
1067 let analyzer = DriftDetectionAnalyzer::new(0.05);
1068
1069 let mut entries: Vec<DriftDetectionEntry> = (0..20)
1071 .map(|i| DriftDetectionEntry::new(i, 100.0, None))
1072 .collect();
1073 for i in 20..40 {
1074 entries.push(DriftDetectionEntry::new(i, 200.0, None));
1075 }
1076
1077 let result = analyzer.analyze(&entries).unwrap();
1078 assert!(result.psi.is_some());
1079 assert!(result.psi.unwrap() > 0.0);
1081 }
1082
1083 #[test]
1084 fn test_detection_metrics_calculation() {
1085 let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(3);
1086
1087 let mut entries = Vec::new();
1089 for i in 0..10 {
1090 entries.push(DriftDetectionEntry::new(i, 100.0, Some(false)));
1091 }
1092 for i in 10..20 {
1093 entries.push(DriftDetectionEntry::new(i, 200.0, Some(true)));
1094 }
1095
1096 let result = analyzer.analyze(&entries).unwrap();
1097
1098 assert!(result.detection_metrics.precision >= 0.0);
1100 assert!(result.detection_metrics.recall >= 0.0);
1101 }
1102}