Skip to main content

datasynth_eval/statistical/
drift_detection.rs

1//! Drift Detection Evaluation Module.
2//!
3//! Provides tools for evaluating drift detection ground truth labels and
4//! validating that generated drift events are detectable and properly labeled.
5//!
6//! # Overview
7//!
8//! This module evaluates the quality and detectability of drift events in
9//! synthetic data by analyzing:
10//!
11//! - Statistical distribution shifts (mean, variance changes)
12//! - Categorical shifts (proportion changes, new categories)
13//! - Temporal pattern changes (seasonality, trend)
14//! - Regulatory and organizational event impacts
15//!
16//! # Example
17//!
18//! ```ignore
19//! use datasynth_eval::statistical::{DriftDetectionAnalyzer, DriftDetectionEntry};
20//!
21//! let analyzer = DriftDetectionAnalyzer::new(0.05);
22//! let entries = vec![
23//!     DriftDetectionEntry::new(1, 100.0, Some(true)),
24//!     DriftDetectionEntry::new(2, 102.0, Some(false)),
25//!     // ...
26//! ];
27//!
28//! let analysis = analyzer.analyze(&entries)?;
29//! println!("Drift detected: {}", analysis.drift_detected);
30//! ```
31
32use crate::error::{EvalError, EvalResult};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
36// =============================================================================
37// Drift Detection Entry
38// =============================================================================
39
/// A single data point for drift detection analysis.
///
/// `DriftDetectionAnalyzer::analyze` reads only `value` and
/// `ground_truth_drift`, and processes entries in slice order (it does not
/// sort by `period`); the remaining fields are descriptive metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetectionEntry {
    /// Period number (e.g., month number from start).
    pub period: u32,
    /// Observed value at this period.
    pub value: f64,
    /// Ground truth label: `Some(true)` if this period has drift,
    /// `Some(false)` if it does not, `None` if unlabeled (unlabeled periods
    /// are skipped when scoring detections).
    pub ground_truth_drift: Option<bool>,
    /// Drift event type if labeled (free-form name, e.g. `"MeanShift"`).
    pub drift_type: Option<String>,
    /// Magnitude of drift if known.
    pub drift_magnitude: Option<f64>,
    /// Detection difficulty (0.0 = easy, 1.0 = hard).
    pub detection_difficulty: Option<f64>,
}
56
57impl DriftDetectionEntry {
58    /// Create a new drift detection entry.
59    pub fn new(period: u32, value: f64, ground_truth_drift: Option<bool>) -> Self {
60        Self {
61            period,
62            value,
63            ground_truth_drift,
64            drift_type: None,
65            drift_magnitude: None,
66            detection_difficulty: None,
67        }
68    }
69
70    /// Create entry with full metadata.
71    pub fn with_metadata(
72        period: u32,
73        value: f64,
74        ground_truth_drift: bool,
75        drift_type: impl Into<String>,
76        drift_magnitude: f64,
77        detection_difficulty: f64,
78    ) -> Self {
79        Self {
80            period,
81            value,
82            ground_truth_drift: Some(ground_truth_drift),
83            drift_type: Some(drift_type.into()),
84            drift_magnitude: Some(drift_magnitude),
85            detection_difficulty: Some(detection_difficulty),
86        }
87    }
88}
89
90// =============================================================================
91// Labeled Drift Event
92// =============================================================================
93
/// A labeled drift event from ground truth data.
///
/// Summarised by `DriftDetectionAnalyzer::analyze_labeled_events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledDriftEvent {
    /// Unique event identifier.
    pub event_id: String,
    /// Event type classification.
    pub event_type: DriftEventCategory,
    /// Start period of the drift.
    pub start_period: u32,
    /// End period of the drift (None if ongoing).
    pub end_period: Option<u32>,
    /// Affected fields/metrics.
    pub affected_fields: Vec<String>,
    /// Magnitude of the drift effect. The average over all events is compared
    /// against the analyzer's minimum magnitude threshold.
    pub magnitude: f64,
    /// Detection difficulty level.
    pub detection_difficulty: DetectionDifficulty,
}
112
/// Categories of drift events.
///
/// `is_statistical` covers the four distribution-shape variants and
/// `is_business_event` covers organizational, regulatory, technology and
/// process changes. `ProportionShift`, `NewCategory` and `EconomicCycle`
/// belong to neither group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DriftEventCategory {
    /// Mean shift in distribution.
    MeanShift,
    /// Variance change in distribution.
    VarianceChange,
    /// Trend change (slope).
    TrendChange,
    /// Seasonality pattern change.
    SeasonalityChange,
    /// Categorical proportion shift.
    ProportionShift,
    /// New category emergence.
    NewCategory,
    /// Organizational event (acquisition, merger, etc.).
    OrganizationalEvent,
    /// Regulatory change impact.
    RegulatoryChange,
    /// Technology transition impact.
    TechnologyTransition,
    /// Economic cycle effect.
    EconomicCycle,
    /// Process evolution.
    ProcessEvolution,
}
139
140impl DriftEventCategory {
141    /// Get human-readable name.
142    pub fn name(&self) -> &'static str {
143        match self {
144            Self::MeanShift => "Mean Shift",
145            Self::VarianceChange => "Variance Change",
146            Self::TrendChange => "Trend Change",
147            Self::SeasonalityChange => "Seasonality Change",
148            Self::ProportionShift => "Proportion Shift",
149            Self::NewCategory => "New Category",
150            Self::OrganizationalEvent => "Organizational Event",
151            Self::RegulatoryChange => "Regulatory Change",
152            Self::TechnologyTransition => "Technology Transition",
153            Self::EconomicCycle => "Economic Cycle",
154            Self::ProcessEvolution => "Process Evolution",
155        }
156    }
157
158    /// Check if this is a statistical drift type.
159    pub fn is_statistical(&self) -> bool {
160        matches!(
161            self,
162            Self::MeanShift | Self::VarianceChange | Self::TrendChange | Self::SeasonalityChange
163        )
164    }
165
166    /// Check if this is a business event drift type.
167    pub fn is_business_event(&self) -> bool {
168        matches!(
169            self,
170            Self::OrganizationalEvent
171                | Self::RegulatoryChange
172                | Self::TechnologyTransition
173                | Self::ProcessEvolution
174        )
175    }
176}
177
/// Detection difficulty levels.
///
/// Maps to the numeric scores 0.0 / 0.5 / 1.0 via `to_score`; `from_score`
/// buckets a score back using the cut-offs 0.33 and 0.67.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectionDifficulty {
    /// Easy to detect (large magnitude, clear signal).
    Easy,
    /// Medium difficulty.
    Medium,
    /// Hard to detect (small magnitude, noisy signal).
    Hard,
}
188
189impl DetectionDifficulty {
190    /// Convert to numeric score (0.0 = easy, 1.0 = hard).
191    pub fn to_score(&self) -> f64 {
192        match self {
193            Self::Easy => 0.0,
194            Self::Medium => 0.5,
195            Self::Hard => 1.0,
196        }
197    }
198
199    /// Create from numeric score.
200    pub fn from_score(score: f64) -> Self {
201        if score < 0.33 {
202            Self::Easy
203        } else if score < 0.67 {
204            Self::Medium
205        } else {
206            Self::Hard
207        }
208    }
209}
210
211// =============================================================================
212// Drift Detection Analyzer
213// =============================================================================
214
/// Analyzer for drift detection evaluation.
///
/// Build it with `DriftDetectionAnalyzer::new` plus the `with_*` builder
/// methods. Then call `analyze` on a time series of `DriftDetectionEntry`
/// values, or `analyze_labeled_events` on `LabeledDriftEvent`s.
#[derive(Debug, Clone)]
pub struct DriftDetectionAnalyzer {
    /// Significance level for statistical tests. It scales the z-score
    /// threshold of the mean-shift check as `1.96 / sqrt(significance_level)`.
    significance_level: f64,
    /// Window size for rolling statistics (10 when created via `new`).
    window_size: usize,
    /// Minimum magnitude threshold to consider as drift. It is also the bar
    /// that the average magnitude of labeled events must reach.
    min_magnitude_threshold: f64,
    /// Enable Hellinger distance calculation.
    use_hellinger: bool,
    /// Enable Population Stability Index (PSI) calculation.
    use_psi: bool,
}
229
230impl DriftDetectionAnalyzer {
231    /// Create a new drift detection analyzer.
232    pub fn new(significance_level: f64) -> Self {
233        Self {
234            significance_level,
235            window_size: 10,
236            min_magnitude_threshold: 0.05,
237            use_hellinger: true,
238            use_psi: true,
239        }
240    }
241
242    /// Set the rolling window size.
243    pub fn with_window_size(mut self, size: usize) -> Self {
244        self.window_size = size;
245        self
246    }
247
248    /// Set the minimum magnitude threshold.
249    pub fn with_min_magnitude(mut self, threshold: f64) -> Self {
250        self.min_magnitude_threshold = threshold;
251        self
252    }
253
254    /// Enable or disable Hellinger distance calculation.
255    pub fn with_hellinger(mut self, enabled: bool) -> Self {
256        self.use_hellinger = enabled;
257        self
258    }
259
260    /// Enable or disable PSI calculation.
261    pub fn with_psi(mut self, enabled: bool) -> Self {
262        self.use_psi = enabled;
263        self
264    }
265
266    /// Analyze drift detection entries.
267    pub fn analyze(&self, entries: &[DriftDetectionEntry]) -> EvalResult<DriftDetectionAnalysis> {
268        if entries.len() < self.window_size * 2 {
269            return Err(EvalError::InsufficientData {
270                required: self.window_size * 2,
271                actual: entries.len(),
272            });
273        }
274
275        // Extract values and labels
276        let values: Vec<f64> = entries.iter().map(|e| e.value).collect();
277        let ground_truth: Vec<Option<bool>> =
278            entries.iter().map(|e| e.ground_truth_drift).collect();
279
280        // Calculate rolling statistics
281        let rolling_means = self.calculate_rolling_means(&values);
282        let rolling_stds = self.calculate_rolling_stds(&values);
283
284        // Detect drift points using CUSUM-like approach
285        let detected_drift = self.detect_drift_points(&rolling_means, &rolling_stds);
286
287        // Calculate detection metrics if ground truth is available
288        let metrics = self.calculate_detection_metrics(&detected_drift, &ground_truth);
289
290        // Calculate statistical measures
291        let hellinger_distance = if self.use_hellinger {
292            Some(self.calculate_hellinger_distance(&values))
293        } else {
294            None
295        };
296
297        let psi = if self.use_psi {
298            Some(self.calculate_psi(&values))
299        } else {
300            None
301        };
302
303        // Determine overall drift status
304        let drift_detected = detected_drift.iter().any(|&d| d);
305        let drift_count = detected_drift.iter().filter(|&&d| d).count();
306
307        // Calculate magnitude of detected drifts
308        let drift_magnitude = self.calculate_drift_magnitude(&rolling_means);
309
310        let passes = self.evaluate_pass_status(&metrics, drift_magnitude);
311        let issues = self.collect_issues(&metrics, drift_magnitude, drift_count);
312
313        Ok(DriftDetectionAnalysis {
314            sample_size: entries.len(),
315            drift_detected,
316            drift_count,
317            drift_magnitude,
318            detection_metrics: metrics,
319            hellinger_distance,
320            psi,
321            rolling_mean_change: self.calculate_mean_change(&rolling_means),
322            rolling_std_change: self.calculate_std_change(&rolling_stds),
323            passes,
324            issues,
325        })
326    }
327
328    /// Analyze labeled drift events for quality.
329    pub fn analyze_labeled_events(
330        &self,
331        events: &[LabeledDriftEvent],
332    ) -> EvalResult<LabeledEventAnalysis> {
333        if events.is_empty() {
334            return Ok(LabeledEventAnalysis::empty());
335        }
336
337        // Count events by category
338        let mut category_counts: HashMap<DriftEventCategory, usize> = HashMap::new();
339        for event in events {
340            *category_counts.entry(event.event_type).or_insert(0) += 1;
341        }
342
343        // Count events by difficulty
344        let mut difficulty_counts: HashMap<DetectionDifficulty, usize> = HashMap::new();
345        for event in events {
346            *difficulty_counts
347                .entry(event.detection_difficulty)
348                .or_insert(0) += 1;
349        }
350
351        // Calculate coverage metrics
352        let total_events = events.len();
353        let statistical_events = events
354            .iter()
355            .filter(|e| e.event_type.is_statistical())
356            .count();
357        let business_events = events
358            .iter()
359            .filter(|e| e.event_type.is_business_event())
360            .count();
361
362        // Calculate average magnitude and difficulty
363        let avg_magnitude = events.iter().map(|e| e.magnitude).sum::<f64>() / total_events as f64;
364        let avg_difficulty = events
365            .iter()
366            .map(|e| e.detection_difficulty.to_score())
367            .sum::<f64>()
368            / total_events as f64;
369
370        // Calculate period coverage
371        let min_period = events.iter().map(|e| e.start_period).min().unwrap_or(0);
372        let max_period = events
373            .iter()
374            .filter_map(|e| e.end_period)
375            .max()
376            .unwrap_or(min_period);
377
378        let passes = total_events > 0 && avg_magnitude >= self.min_magnitude_threshold;
379        let issues = if !passes {
380            vec!["Insufficient drift events or magnitude too low".to_string()]
381        } else {
382            Vec::new()
383        };
384
385        Ok(LabeledEventAnalysis {
386            total_events,
387            statistical_events,
388            business_events,
389            category_distribution: category_counts,
390            difficulty_distribution: difficulty_counts,
391            avg_magnitude,
392            avg_difficulty,
393            period_coverage: (min_period, max_period),
394            passes,
395            issues,
396        })
397    }
398
399    // Helper methods
400
401    fn calculate_rolling_means(&self, values: &[f64]) -> Vec<f64> {
402        if values.len() < self.window_size {
403            tracing::debug!(
404                "Drift detection: not enough values ({}) for window size ({}), returning empty",
405                values.len(),
406                self.window_size
407            );
408            return Vec::new();
409        }
410        let mut means = Vec::with_capacity(values.len() - self.window_size + 1);
411        for i in 0..=(values.len() - self.window_size) {
412            let window = &values[i..i + self.window_size];
413            let mean = window.iter().sum::<f64>() / self.window_size as f64;
414            means.push(mean);
415        }
416        means
417    }
418
419    fn calculate_rolling_stds(&self, values: &[f64]) -> Vec<f64> {
420        if values.len() < self.window_size {
421            tracing::debug!(
422                "Drift detection: not enough values ({}) for window size ({}), returning empty",
423                values.len(),
424                self.window_size
425            );
426            return Vec::new();
427        }
428        let mut stds = Vec::with_capacity(values.len() - self.window_size + 1);
429        for i in 0..=(values.len() - self.window_size) {
430            let window = &values[i..i + self.window_size];
431            let mean = window.iter().sum::<f64>() / self.window_size as f64;
432            let variance =
433                window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / self.window_size as f64;
434            stds.push(variance.sqrt());
435        }
436        stds
437    }
438
439    fn detect_drift_points(&self, means: &[f64], stds: &[f64]) -> Vec<bool> {
440        if means.len() < 2 {
441            return vec![false; means.len()];
442        }
443
444        let mut detected = vec![false; means.len()];
445
446        // Calculate baseline statistics from first half
447        let baseline_end = means.len() / 2;
448        let baseline_mean = means[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
449        let baseline_std = if baseline_end > 1 {
450            let variance = means[..baseline_end]
451                .iter()
452                .map(|x| (x - baseline_mean).powi(2))
453                .sum::<f64>()
454                / baseline_end as f64;
455            variance.sqrt().max(0.001) // Avoid division by zero
456        } else {
457            0.001
458        };
459
460        // Detect drift using z-score approach
461        for i in baseline_end..means.len() {
462            let z_score = (means[i] - baseline_mean).abs() / baseline_std;
463            let threshold = 1.96 / self.significance_level.sqrt(); // Adjust for significance
464
465            if z_score > threshold {
466                detected[i] = true;
467            }
468
469            // Also check for variance change
470            if i < stds.len() && baseline_end > 0 {
471                let baseline_var_mean =
472                    stds[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
473                if baseline_var_mean > 0.001 {
474                    let var_ratio = stds[i] / baseline_var_mean;
475                    if !(0.5..=2.0).contains(&var_ratio) {
476                        detected[i] = true;
477                    }
478                }
479            }
480        }
481
482        detected
483    }
484
485    fn calculate_detection_metrics(
486        &self,
487        detected: &[bool],
488        ground_truth: &[Option<bool>],
489    ) -> DriftDetectionMetrics {
490        let mut true_positives = 0;
491        let mut false_positives = 0;
492        let mut true_negatives = 0;
493        let mut false_negatives = 0;
494        let mut detection_delays = Vec::new();
495
496        // Adjust for window offset
497        let offset = detected.len().saturating_sub(ground_truth.len());
498
499        for (i, &gt) in ground_truth.iter().enumerate() {
500            let detected_idx = i + offset;
501            if detected_idx >= detected.len() {
502                break;
503            }
504
505            let pred = detected[detected_idx];
506            match gt {
507                Some(true) => {
508                    if pred {
509                        true_positives += 1;
510                    } else {
511                        false_negatives += 1;
512                    }
513                }
514                Some(false) => {
515                    if pred {
516                        false_positives += 1;
517                    } else {
518                        true_negatives += 1;
519                    }
520                }
521                None => {}
522            }
523        }
524
525        // Calculate detection delay for true positives
526        let mut last_drift_start: Option<usize> = None;
527        for (i, &gt) in ground_truth.iter().enumerate() {
528            if gt == Some(true) && last_drift_start.is_none() {
529                last_drift_start = Some(i);
530            } else if gt == Some(false) {
531                last_drift_start = None;
532            }
533
534            let detected_idx = i + offset;
535            if detected_idx < detected.len() && detected[detected_idx] {
536                if let Some(start) = last_drift_start {
537                    detection_delays.push((i - start) as f64);
538                    last_drift_start = None;
539                }
540            }
541        }
542
543        let precision = if true_positives + false_positives > 0 {
544            true_positives as f64 / (true_positives + false_positives) as f64
545        } else {
546            0.0
547        };
548
549        let recall = if true_positives + false_negatives > 0 {
550            true_positives as f64 / (true_positives + false_negatives) as f64
551        } else {
552            0.0
553        };
554
555        let f1_score = if precision + recall > 0.0 {
556            2.0 * precision * recall / (precision + recall)
557        } else {
558            0.0
559        };
560
561        let mean_detection_delay = if detection_delays.is_empty() {
562            None
563        } else {
564            Some(detection_delays.iter().sum::<f64>() / detection_delays.len() as f64)
565        };
566
567        DriftDetectionMetrics {
568            true_positives,
569            false_positives,
570            true_negatives,
571            false_negatives,
572            precision,
573            recall,
574            f1_score,
575            mean_detection_delay,
576        }
577    }
578
579    fn calculate_hellinger_distance(&self, values: &[f64]) -> f64 {
580        if values.len() < 20 {
581            return 0.0;
582        }
583
584        let mid = values.len() / 2;
585        let first_half = &values[..mid];
586        let second_half = &values[mid..];
587
588        // Create histograms with 10 bins
589        let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
590            (min.min(v), max.max(v))
591        });
592
593        if (max_val - min_val).abs() < f64::EPSILON {
594            return 0.0;
595        }
596
597        let num_bins = 10;
598        let bin_width = (max_val - min_val) / num_bins as f64;
599
600        let mut hist1 = vec![0.0; num_bins];
601        let mut hist2 = vec![0.0; num_bins];
602
603        for &v in first_half {
604            let bin = ((v - min_val) / bin_width).floor() as usize;
605            let bin = bin.min(num_bins - 1);
606            hist1[bin] += 1.0;
607        }
608
609        for &v in second_half {
610            let bin = ((v - min_val) / bin_width).floor() as usize;
611            let bin = bin.min(num_bins - 1);
612            hist2[bin] += 1.0;
613        }
614
615        // Normalize
616        let sum1: f64 = hist1.iter().sum();
617        let sum2: f64 = hist2.iter().sum();
618
619        if sum1 == 0.0 || sum2 == 0.0 {
620            return 0.0;
621        }
622
623        for h in &mut hist1 {
624            *h /= sum1;
625        }
626        for h in &mut hist2 {
627            *h /= sum2;
628        }
629
630        // Calculate Hellinger distance
631        let mut sum_sq_diff = 0.0;
632        for i in 0..num_bins {
633            let diff = hist1[i].sqrt() - hist2[i].sqrt();
634            sum_sq_diff += diff * diff;
635        }
636
637        (sum_sq_diff / 2.0).sqrt()
638    }
639
640    fn calculate_psi(&self, values: &[f64]) -> f64 {
641        if values.len() < 20 {
642            return 0.0;
643        }
644
645        let mid = values.len() / 2;
646        let baseline = &values[..mid];
647        let current = &values[mid..];
648
649        // Create histograms with 10 bins
650        let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
651            (min.min(v), max.max(v))
652        });
653
654        if (max_val - min_val).abs() < f64::EPSILON {
655            return 0.0;
656        }
657
658        let num_bins = 10;
659        let bin_width = (max_val - min_val) / num_bins as f64;
660
661        let mut hist_baseline = vec![0.0; num_bins];
662        let mut hist_current = vec![0.0; num_bins];
663
664        for &v in baseline {
665            let bin = ((v - min_val) / bin_width).floor() as usize;
666            let bin = bin.min(num_bins - 1);
667            hist_baseline[bin] += 1.0;
668        }
669
670        for &v in current {
671            let bin = ((v - min_val) / bin_width).floor() as usize;
672            let bin = bin.min(num_bins - 1);
673            hist_current[bin] += 1.0;
674        }
675
676        // Normalize and add small constant to avoid log(0)
677        let epsilon = 0.0001;
678        let sum_baseline: f64 = hist_baseline.iter().sum();
679        let sum_current: f64 = hist_current.iter().sum();
680
681        if sum_baseline == 0.0 || sum_current == 0.0 {
682            return 0.0;
683        }
684
685        for h in &mut hist_baseline {
686            *h = (*h / sum_baseline).max(epsilon);
687        }
688        for h in &mut hist_current {
689            *h = (*h / sum_current).max(epsilon);
690        }
691
692        // Calculate PSI
693        let mut psi = 0.0;
694        for i in 0..num_bins {
695            let diff = hist_current[i] - hist_baseline[i];
696            let ratio = hist_current[i] / hist_baseline[i];
697            psi += diff * ratio.ln();
698        }
699
700        psi
701    }
702
703    fn calculate_drift_magnitude(&self, means: &[f64]) -> f64 {
704        if means.len() < 2 {
705            return 0.0;
706        }
707
708        let mid = means.len() / 2;
709        let first_mean = means[..mid].iter().sum::<f64>() / mid as f64;
710        let second_mean = means[mid..].iter().sum::<f64>() / (means.len() - mid) as f64;
711
712        if first_mean.abs() < f64::EPSILON {
713            return (second_mean - first_mean).abs();
714        }
715
716        ((second_mean - first_mean) / first_mean).abs()
717    }
718
719    fn calculate_mean_change(&self, means: &[f64]) -> f64 {
720        if means.len() < 2 {
721            return 0.0;
722        }
723        let first = means.first().unwrap_or(&0.0);
724        let last = means.last().unwrap_or(&0.0);
725        if first.abs() < f64::EPSILON {
726            return 0.0;
727        }
728        (last - first) / first
729    }
730
731    fn calculate_std_change(&self, stds: &[f64]) -> f64 {
732        if stds.len() < 2 {
733            return 0.0;
734        }
735        let first = stds.first().unwrap_or(&0.0);
736        let last = stds.last().unwrap_or(&0.0);
737        if first.abs() < f64::EPSILON {
738            return 0.0;
739        }
740        (last - first) / first
741    }
742
743    fn evaluate_pass_status(&self, metrics: &DriftDetectionMetrics, drift_magnitude: f64) -> bool {
744        // Pass if we have reasonable detection metrics or magnitude is below threshold
745        if drift_magnitude < self.min_magnitude_threshold {
746            return true; // No significant drift to detect
747        }
748
749        // If there's significant drift, we need decent detection
750        metrics.f1_score >= 0.5 || metrics.precision >= 0.6 || metrics.recall >= 0.6
751    }
752
753    fn collect_issues(
754        &self,
755        metrics: &DriftDetectionMetrics,
756        drift_magnitude: f64,
757        drift_count: usize,
758    ) -> Vec<String> {
759        let mut issues = Vec::new();
760
761        if drift_magnitude >= self.min_magnitude_threshold {
762            if metrics.precision < 0.5 {
763                issues.push(format!(
764                    "Low precision ({:.2}): many false positives",
765                    metrics.precision
766                ));
767            }
768            if metrics.recall < 0.5 {
769                issues.push(format!(
770                    "Low recall ({:.2}): many drift events missed",
771                    metrics.recall
772                ));
773            }
774            if let Some(delay) = metrics.mean_detection_delay {
775                if delay > 3.0 {
776                    issues.push(format!("High detection delay ({:.1} periods)", delay));
777                }
778            }
779        }
780
781        if drift_count == 0 && drift_magnitude >= self.min_magnitude_threshold {
782            issues.push("No drift detected despite significant magnitude change".to_string());
783        }
784
785        issues
786    }
787}
788
789impl Default for DriftDetectionAnalyzer {
790    fn default() -> Self {
791        Self::new(0.05)
792    }
793}
794
795// =============================================================================
796// Analysis Results
797// =============================================================================
798
/// Results from drift detection analysis.
///
/// Produced by `DriftDetectionAnalyzer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetectionAnalysis {
    /// Number of data points analyzed.
    pub sample_size: usize,
    /// Whether any rolling window was flagged as drifting.
    pub drift_detected: bool,
    /// Number of rolling windows flagged as drifting.
    pub drift_count: usize,
    /// Relative change between the average rolling mean of the first and the
    /// second half of the windows (the absolute difference when the first-half
    /// average is approximately zero). Computed whether or not any drift was
    /// detected.
    pub drift_magnitude: f64,
    /// Detection metrics (precision, recall, F1).
    pub detection_metrics: DriftDetectionMetrics,
    /// Hellinger distance between first and second half; `None` when disabled.
    pub hellinger_distance: Option<f64>,
    /// Population Stability Index of the second half against the first;
    /// `None` when disabled.
    pub psi: Option<f64>,
    /// Relative change from the first to the last rolling mean.
    pub rolling_mean_change: f64,
    /// Relative change from the first to the last rolling standard deviation.
    pub rolling_std_change: f64,
    /// Whether the analysis passes quality thresholds.
    pub passes: bool,
    /// Issues identified during analysis.
    pub issues: Vec<String>,
}
825
/// Drift detection performance metrics.
///
/// The counts cover only periods that have a ground-truth label and a
/// matching detection result.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DriftDetectionMetrics {
    /// True positive count.
    pub true_positives: usize,
    /// False positive count.
    pub false_positives: usize,
    /// True negative count.
    pub true_negatives: usize,
    /// False negative count.
    pub false_negatives: usize,
    /// Precision (TP / (TP + FP)); 0.0 when TP + FP = 0.
    pub precision: f64,
    /// Recall (TP / (TP + FN)); 0.0 when TP + FN = 0.
    pub recall: f64,
    /// F1 score (harmonic mean of precision and recall); 0.0 when both are 0.
    pub f1_score: f64,
    /// Mean delay in detecting drift (in periods); `None` when no delay could
    /// be measured.
    pub mean_detection_delay: Option<f64>,
}
846
/// Analysis of labeled drift events.
///
/// Produced by `DriftDetectionAnalyzer::analyze_labeled_events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledEventAnalysis {
    /// Total number of labeled events.
    pub total_events: usize,
    /// Number of events whose category is statistical
    /// (`DriftEventCategory::is_statistical`).
    pub statistical_events: usize,
    /// Number of events whose category is a business event
    /// (`DriftEventCategory::is_business_event`).
    pub business_events: usize,
    /// Distribution by event category.
    pub category_distribution: HashMap<DriftEventCategory, usize>,
    /// Distribution by detection difficulty.
    pub difficulty_distribution: HashMap<DetectionDifficulty, usize>,
    /// Average drift magnitude.
    pub avg_magnitude: f64,
    /// Average detection difficulty score (0.0 = easy, 1.0 = hard).
    pub avg_difficulty: f64,
    /// Period coverage (min_period, max_period).
    pub period_coverage: (u32, u32),
    /// Whether the analysis passes quality thresholds.
    pub passes: bool,
    /// Issues identified.
    pub issues: Vec<String>,
}
871
872impl LabeledEventAnalysis {
873    /// Create an empty analysis result.
874    pub fn empty() -> Self {
875        Self {
876            total_events: 0,
877            statistical_events: 0,
878            business_events: 0,
879            category_distribution: HashMap::new(),
880            difficulty_distribution: HashMap::new(),
881            avg_magnitude: 0.0,
882            avg_difficulty: 0.0,
883            period_coverage: (0, 0),
884            passes: true,
885            issues: Vec::new(),
886        }
887    }
888}
889
890// =============================================================================
891// Tests
892// =============================================================================
893
#[cfg(test)]
// Panicking on an unexpected `Err`/`None` is the desired failure mode in tests:
// it points straight at the broken case, so `unwrap` is allowed here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_drift_detection_entry_creation() {
        let entry = DriftDetectionEntry::new(1, 100.0, Some(true));
        assert_eq!(entry.period, 1);
        assert_eq!(entry.value, 100.0);
        assert_eq!(entry.ground_truth_drift, Some(true));
    }

    #[test]
    fn test_drift_detection_entry_with_metadata() {
        let entry = DriftDetectionEntry::with_metadata(5, 150.0, true, "MeanShift", 0.15, 0.3);
        assert_eq!(entry.period, 5);
        assert_eq!(entry.value, 150.0);
        assert_eq!(entry.ground_truth_drift, Some(true));
        assert_eq!(entry.drift_type.as_deref(), Some("MeanShift"));
        assert_eq!(entry.drift_magnitude, Some(0.15));
        assert_eq!(entry.detection_difficulty, Some(0.3));
    }

    #[test]
    fn test_drift_event_category_names() {
        assert_eq!(DriftEventCategory::MeanShift.name(), "Mean Shift");
        assert_eq!(
            DriftEventCategory::OrganizationalEvent.name(),
            "Organizational Event"
        );
    }

    #[test]
    fn test_drift_event_category_classification() {
        assert!(DriftEventCategory::MeanShift.is_statistical());
        assert!(!DriftEventCategory::MeanShift.is_business_event());
        assert!(DriftEventCategory::OrganizationalEvent.is_business_event());
        assert!(!DriftEventCategory::OrganizationalEvent.is_statistical());
    }

    #[test]
    fn test_detection_difficulty_conversion() {
        assert_eq!(DetectionDifficulty::Easy.to_score(), 0.0);
        assert_eq!(DetectionDifficulty::Medium.to_score(), 0.5);
        assert_eq!(DetectionDifficulty::Hard.to_score(), 1.0);

        assert_eq!(
            DetectionDifficulty::from_score(0.1),
            DetectionDifficulty::Easy
        );
        assert_eq!(
            DetectionDifficulty::from_score(0.5),
            DetectionDifficulty::Medium
        );
        assert_eq!(
            DetectionDifficulty::from_score(0.8),
            DetectionDifficulty::Hard
        );
    }

    #[test]
    fn test_analyzer_creation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05)
            .with_window_size(15)
            .with_min_magnitude(0.1)
            .with_hellinger(true)
            .with_psi(true);

        assert_eq!(analyzer.significance_level, 0.05);
        assert_eq!(analyzer.window_size, 15);
        assert_eq!(analyzer.min_magnitude_threshold, 0.1);
    }

    #[test]
    fn test_analyze_no_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);

        // Stable data: a very gentle linear ramp with no real drift.
        let entries: Vec<DriftDetectionEntry> = (0..30)
            .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 * 0.01), Some(false)))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        // Tolerate a handful of spurious window flags, but not a systematic detection.
        assert!(!result.drift_detected || result.drift_count < 5);
        assert!(result.drift_magnitude < 0.1);
    }

    #[test]
    fn test_analyze_with_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);

        // Periods 0..=14 sit at 100.0; drift begins at period 15 and holds
        // at 150.0 through period 29.
        let entries: Vec<DriftDetectionEntry> = (0..15)
            .map(|i| DriftDetectionEntry::new(i, 100.0, Some(false)))
            .chain((15..30).map(|i| DriftDetectionEntry::new(i, 150.0, Some(true))))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        assert!(result.drift_detected);
        assert!(result.drift_magnitude > 0.3);
    }

    #[test]
    fn test_analyze_insufficient_data() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(10);

        // Five entries are fewer than the configured window size of ten.
        let entries: Vec<DriftDetectionEntry> = (0..5)
            .map(|i| DriftDetectionEntry::new(i, 100.0, None))
            .collect();

        let result = analyzer.analyze(&entries);
        assert!(result.is_err());
    }

    #[test]
    fn test_analyze_labeled_events() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        let events = vec![
            LabeledDriftEvent {
                event_id: "E1".to_string(),
                event_type: DriftEventCategory::MeanShift,
                start_period: 10,
                end_period: Some(15),
                affected_fields: vec!["amount".to_string()],
                magnitude: 0.15,
                detection_difficulty: DetectionDifficulty::Easy,
            },
            LabeledDriftEvent {
                event_id: "E2".to_string(),
                event_type: DriftEventCategory::OrganizationalEvent,
                start_period: 20,
                end_period: Some(25),
                affected_fields: vec!["volume".to_string()],
                magnitude: 0.30,
                detection_difficulty: DetectionDifficulty::Medium,
            },
        ];

        let result = analyzer.analyze_labeled_events(&events).unwrap();
        assert_eq!(result.total_events, 2);
        assert_eq!(result.statistical_events, 1);
        assert_eq!(result.business_events, 1);

        // One event per category and per difficulty level.
        assert_eq!(
            result
                .category_distribution
                .get(&DriftEventCategory::MeanShift),
            Some(&1)
        );
        assert_eq!(
            result
                .category_distribution
                .get(&DriftEventCategory::OrganizationalEvent),
            Some(&1)
        );
        assert_eq!(
            result
                .difficulty_distribution
                .get(&DetectionDifficulty::Easy),
            Some(&1)
        );
        assert_eq!(
            result
                .difficulty_distribution
                .get(&DetectionDifficulty::Medium),
            Some(&1)
        );

        // Mean of 0.15 and 0.30.
        assert!((result.avg_magnitude - 0.225).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_empty_labeled_events() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);
        let result = analyzer.analyze_labeled_events(&[]).unwrap();
        assert_eq!(result.total_events, 0);
        assert!(result.passes);
    }

    #[test]
    fn test_hellinger_distance_no_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        // Stable data: a repeating 100..=104 sawtooth.
        let entries: Vec<DriftDetectionEntry> = (0..40)
            .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 % 5.0), None))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        let hellinger = result
            .hellinger_distance
            .expect("Hellinger distance should be computed with default settings");
        assert!(hellinger < 0.3);
    }

    #[test]
    fn test_psi_calculation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        // A full level shift from 100.0 to 200.0 halfway through.
        let entries: Vec<DriftDetectionEntry> = (0..20)
            .map(|i| DriftDetectionEntry::new(i, 100.0, None))
            .chain((20..40).map(|i| DriftDetectionEntry::new(i, 200.0, None)))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        let psi = result
            .psi
            .expect("PSI should be computed with default settings");
        // PSI is zero for identical distributions, so a shift must register as
        // positive. (Values above 0.1 are conventionally read as significant
        // drift; this test only pins that the shift is registered at all.)
        assert!(psi > 0.0);
    }

    #[test]
    fn test_detection_metrics_calculation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(3);

        // Ground truth: periods 0..=9 are stable, periods 10..=19 have drifted.
        let entries: Vec<DriftDetectionEntry> = (0..10)
            .map(|i| DriftDetectionEntry::new(i, 100.0, Some(false)))
            .chain((10..20).map(|i| DriftDetectionEntry::new(i, 200.0, Some(true))))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();

        // Precision and recall are ratios: they must be well-defined and lie in
        // [0, 1] (a bare `>= 0.0` check would accept any non-negative value).
        let metrics = &result.detection_metrics;
        assert!((0.0..=1.0).contains(&metrics.precision));
        assert!((0.0..=1.0).contains(&metrics.recall));
    }
}