Skip to main content

datasynth_eval/statistical/
drift_detection.rs

1//! Drift Detection Evaluation Module.
2//!
3//! Provides tools for evaluating drift detection ground truth labels and
4//! validating that generated drift events are detectable and properly labeled.
5//!
6//! # Overview
7//!
8//! This module evaluates the quality and detectability of drift events in
9//! synthetic data by analyzing:
10//!
11//! - Statistical distribution shifts (mean, variance changes)
12//! - Categorical shifts (proportion changes, new categories)
13//! - Temporal pattern changes (seasonality, trend)
14//! - Regulatory and organizational event impacts
15//!
16//! # Example
17//!
18//! ```ignore
19//! use datasynth_eval::statistical::{DriftDetectionAnalyzer, DriftDetectionEntry};
20//!
21//! let analyzer = DriftDetectionAnalyzer::new(0.05);
22//! let entries = vec![
23//!     DriftDetectionEntry::new(1, 100.0, Some(true)),
24//!     DriftDetectionEntry::new(2, 102.0, Some(false)),
25//!     // ...
26//! ];
27//!
28//! let analysis = analyzer.analyze(&entries)?;
29//! println!("Drift detected: {}", analysis.drift_detected);
30//! ```
31
32use crate::error::{EvalError, EvalResult};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
36// =============================================================================
37// Drift Detection Entry
38// =============================================================================
39
40/// A single data point for drift detection analysis.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct DriftDetectionEntry {
43    /// Period number (e.g., month number from start).
44    pub period: u32,
45    /// Observed value at this period.
46    pub value: f64,
47    /// Ground truth label: true if this period has drift, false otherwise.
48    pub ground_truth_drift: Option<bool>,
49    /// Drift event type if labeled.
50    pub drift_type: Option<String>,
51    /// Magnitude of drift if known.
52    pub drift_magnitude: Option<f64>,
53    /// Detection difficulty (0.0 = easy, 1.0 = hard).
54    pub detection_difficulty: Option<f64>,
55}
56
57impl DriftDetectionEntry {
58    /// Create a new drift detection entry.
59    pub fn new(period: u32, value: f64, ground_truth_drift: Option<bool>) -> Self {
60        Self {
61            period,
62            value,
63            ground_truth_drift,
64            drift_type: None,
65            drift_magnitude: None,
66            detection_difficulty: None,
67        }
68    }
69
70    /// Create entry with full metadata.
71    pub fn with_metadata(
72        period: u32,
73        value: f64,
74        ground_truth_drift: bool,
75        drift_type: impl Into<String>,
76        drift_magnitude: f64,
77        detection_difficulty: f64,
78    ) -> Self {
79        Self {
80            period,
81            value,
82            ground_truth_drift: Some(ground_truth_drift),
83            drift_type: Some(drift_type.into()),
84            drift_magnitude: Some(drift_magnitude),
85            detection_difficulty: Some(detection_difficulty),
86        }
87    }
88}
89
90// =============================================================================
91// Labeled Drift Event
92// =============================================================================
93
94/// A labeled drift event from ground truth data.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct LabeledDriftEvent {
97    /// Unique event identifier.
98    pub event_id: String,
99    /// Event type classification.
100    pub event_type: DriftEventCategory,
101    /// Start period of the drift.
102    pub start_period: u32,
103    /// End period of the drift (None if ongoing).
104    pub end_period: Option<u32>,
105    /// Affected fields/metrics.
106    pub affected_fields: Vec<String>,
107    /// Magnitude of the drift effect.
108    pub magnitude: f64,
109    /// Detection difficulty level.
110    pub detection_difficulty: DetectionDifficulty,
111}
112
113/// Categories of drift events.
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
115pub enum DriftEventCategory {
116    /// Mean shift in distribution.
117    MeanShift,
118    /// Variance change in distribution.
119    VarianceChange,
120    /// Trend change (slope).
121    TrendChange,
122    /// Seasonality pattern change.
123    SeasonalityChange,
124    /// Categorical proportion shift.
125    ProportionShift,
126    /// New category emergence.
127    NewCategory,
128    /// Organizational event (acquisition, merger, etc.).
129    OrganizationalEvent,
130    /// Regulatory change impact.
131    RegulatoryChange,
132    /// Technology transition impact.
133    TechnologyTransition,
134    /// Economic cycle effect.
135    EconomicCycle,
136    /// Process evolution.
137    ProcessEvolution,
138}
139
140impl DriftEventCategory {
141    /// Get human-readable name.
142    pub fn name(&self) -> &'static str {
143        match self {
144            Self::MeanShift => "Mean Shift",
145            Self::VarianceChange => "Variance Change",
146            Self::TrendChange => "Trend Change",
147            Self::SeasonalityChange => "Seasonality Change",
148            Self::ProportionShift => "Proportion Shift",
149            Self::NewCategory => "New Category",
150            Self::OrganizationalEvent => "Organizational Event",
151            Self::RegulatoryChange => "Regulatory Change",
152            Self::TechnologyTransition => "Technology Transition",
153            Self::EconomicCycle => "Economic Cycle",
154            Self::ProcessEvolution => "Process Evolution",
155        }
156    }
157
158    /// Check if this is a statistical drift type.
159    pub fn is_statistical(&self) -> bool {
160        matches!(
161            self,
162            Self::MeanShift | Self::VarianceChange | Self::TrendChange | Self::SeasonalityChange
163        )
164    }
165
166    /// Check if this is a business event drift type.
167    pub fn is_business_event(&self) -> bool {
168        matches!(
169            self,
170            Self::OrganizationalEvent
171                | Self::RegulatoryChange
172                | Self::TechnologyTransition
173                | Self::ProcessEvolution
174        )
175    }
176}
177
178/// Detection difficulty levels.
179#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
180pub enum DetectionDifficulty {
181    /// Easy to detect (large magnitude, clear signal).
182    Easy,
183    /// Medium difficulty.
184    Medium,
185    /// Hard to detect (small magnitude, noisy signal).
186    Hard,
187}
188
189impl DetectionDifficulty {
190    /// Convert to numeric score (0.0 = easy, 1.0 = hard).
191    pub fn to_score(&self) -> f64 {
192        match self {
193            Self::Easy => 0.0,
194            Self::Medium => 0.5,
195            Self::Hard => 1.0,
196        }
197    }
198
199    /// Create from numeric score.
200    pub fn from_score(score: f64) -> Self {
201        if score < 0.33 {
202            Self::Easy
203        } else if score < 0.67 {
204            Self::Medium
205        } else {
206            Self::Hard
207        }
208    }
209}
210
211// =============================================================================
212// Drift Detection Analyzer
213// =============================================================================
214
/// Configurable evaluator that searches a value series for drift and scores
/// the result against ground-truth labels.
#[derive(Clone, Debug)]
pub struct DriftDetectionAnalyzer {
    /// Significance level that scales the mean-shift detection threshold.
    significance_level: f64,
    /// Number of consecutive entries per rolling window.
    window_size: usize,
    /// Drift magnitude below which drift is treated as insignificant.
    min_magnitude_threshold: f64,
    /// Whether the Hellinger distance between series halves is computed.
    use_hellinger: bool,
    /// Whether the Population Stability Index between series halves is computed.
    use_psi: bool,
}
229
230impl DriftDetectionAnalyzer {
231    /// Create a new drift detection analyzer.
232    pub fn new(significance_level: f64) -> Self {
233        Self {
234            significance_level,
235            window_size: 10,
236            min_magnitude_threshold: 0.05,
237            use_hellinger: true,
238            use_psi: true,
239        }
240    }
241
242    /// Set the rolling window size.
243    pub fn with_window_size(mut self, size: usize) -> Self {
244        self.window_size = size;
245        self
246    }
247
248    /// Set the minimum magnitude threshold.
249    pub fn with_min_magnitude(mut self, threshold: f64) -> Self {
250        self.min_magnitude_threshold = threshold;
251        self
252    }
253
254    /// Enable or disable Hellinger distance calculation.
255    pub fn with_hellinger(mut self, enabled: bool) -> Self {
256        self.use_hellinger = enabled;
257        self
258    }
259
260    /// Enable or disable PSI calculation.
261    pub fn with_psi(mut self, enabled: bool) -> Self {
262        self.use_psi = enabled;
263        self
264    }
265
266    /// Analyze drift detection entries.
267    pub fn analyze(&self, entries: &[DriftDetectionEntry]) -> EvalResult<DriftDetectionAnalysis> {
268        if entries.len() < self.window_size * 2 {
269            return Err(EvalError::InsufficientData {
270                required: self.window_size * 2,
271                actual: entries.len(),
272            });
273        }
274
275        // Extract values and labels
276        let values: Vec<f64> = entries.iter().map(|e| e.value).collect();
277        let ground_truth: Vec<Option<bool>> =
278            entries.iter().map(|e| e.ground_truth_drift).collect();
279
280        // Calculate rolling statistics
281        let rolling_means = self.calculate_rolling_means(&values);
282        let rolling_stds = self.calculate_rolling_stds(&values);
283
284        // Detect drift points using CUSUM-like approach
285        let detected_drift = self.detect_drift_points(&rolling_means, &rolling_stds);
286
287        // Calculate detection metrics if ground truth is available
288        let metrics = self.calculate_detection_metrics(&detected_drift, &ground_truth);
289
290        // Calculate statistical measures
291        let hellinger_distance = if self.use_hellinger {
292            Some(self.calculate_hellinger_distance(&values))
293        } else {
294            None
295        };
296
297        let psi = if self.use_psi {
298            Some(self.calculate_psi(&values))
299        } else {
300            None
301        };
302
303        // Determine overall drift status
304        let drift_detected = detected_drift.iter().any(|&d| d);
305        let drift_count = detected_drift.iter().filter(|&&d| d).count();
306
307        // Calculate magnitude of detected drifts
308        let drift_magnitude = self.calculate_drift_magnitude(&rolling_means);
309
310        let passes = self.evaluate_pass_status(&metrics, drift_magnitude);
311        let issues = self.collect_issues(&metrics, drift_magnitude, drift_count);
312
313        Ok(DriftDetectionAnalysis {
314            sample_size: entries.len(),
315            drift_detected,
316            drift_count,
317            drift_magnitude,
318            detection_metrics: metrics,
319            hellinger_distance,
320            psi,
321            rolling_mean_change: self.calculate_mean_change(&rolling_means),
322            rolling_std_change: self.calculate_std_change(&rolling_stds),
323            passes,
324            issues,
325        })
326    }
327
328    /// Analyze labeled drift events for quality.
329    pub fn analyze_labeled_events(
330        &self,
331        events: &[LabeledDriftEvent],
332    ) -> EvalResult<LabeledEventAnalysis> {
333        if events.is_empty() {
334            return Ok(LabeledEventAnalysis::empty());
335        }
336
337        // Count events by category
338        let mut category_counts: HashMap<DriftEventCategory, usize> = HashMap::new();
339        for event in events {
340            *category_counts.entry(event.event_type).or_insert(0) += 1;
341        }
342
343        // Count events by difficulty
344        let mut difficulty_counts: HashMap<DetectionDifficulty, usize> = HashMap::new();
345        for event in events {
346            *difficulty_counts
347                .entry(event.detection_difficulty)
348                .or_insert(0) += 1;
349        }
350
351        // Calculate coverage metrics
352        let total_events = events.len();
353        let statistical_events = events
354            .iter()
355            .filter(|e| e.event_type.is_statistical())
356            .count();
357        let business_events = events
358            .iter()
359            .filter(|e| e.event_type.is_business_event())
360            .count();
361
362        // Calculate average magnitude and difficulty
363        let avg_magnitude = events.iter().map(|e| e.magnitude).sum::<f64>() / total_events as f64;
364        let avg_difficulty = events
365            .iter()
366            .map(|e| e.detection_difficulty.to_score())
367            .sum::<f64>()
368            / total_events as f64;
369
370        // Calculate period coverage
371        let min_period = events.iter().map(|e| e.start_period).min().unwrap_or(0);
372        let max_period = events
373            .iter()
374            .filter_map(|e| e.end_period)
375            .max()
376            .unwrap_or(min_period);
377
378        let passes = total_events > 0 && avg_magnitude >= self.min_magnitude_threshold;
379        let issues = if !passes {
380            vec!["Insufficient drift events or magnitude too low".to_string()]
381        } else {
382            Vec::new()
383        };
384
385        Ok(LabeledEventAnalysis {
386            total_events,
387            statistical_events,
388            business_events,
389            category_distribution: category_counts,
390            difficulty_distribution: difficulty_counts,
391            avg_magnitude,
392            avg_difficulty,
393            period_coverage: (min_period, max_period),
394            passes,
395            issues,
396        })
397    }
398
399    // Helper methods
400
401    fn calculate_rolling_means(&self, values: &[f64]) -> Vec<f64> {
402        let mut means = Vec::with_capacity(values.len() - self.window_size + 1);
403        for i in 0..=(values.len() - self.window_size) {
404            let window = &values[i..i + self.window_size];
405            let mean = window.iter().sum::<f64>() / self.window_size as f64;
406            means.push(mean);
407        }
408        means
409    }
410
411    fn calculate_rolling_stds(&self, values: &[f64]) -> Vec<f64> {
412        let mut stds = Vec::with_capacity(values.len() - self.window_size + 1);
413        for i in 0..=(values.len() - self.window_size) {
414            let window = &values[i..i + self.window_size];
415            let mean = window.iter().sum::<f64>() / self.window_size as f64;
416            let variance =
417                window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / self.window_size as f64;
418            stds.push(variance.sqrt());
419        }
420        stds
421    }
422
423    fn detect_drift_points(&self, means: &[f64], stds: &[f64]) -> Vec<bool> {
424        if means.len() < 2 {
425            return vec![false; means.len()];
426        }
427
428        let mut detected = vec![false; means.len()];
429
430        // Calculate baseline statistics from first half
431        let baseline_end = means.len() / 2;
432        let baseline_mean = means[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
433        let baseline_std = if baseline_end > 1 {
434            let variance = means[..baseline_end]
435                .iter()
436                .map(|x| (x - baseline_mean).powi(2))
437                .sum::<f64>()
438                / baseline_end as f64;
439            variance.sqrt().max(0.001) // Avoid division by zero
440        } else {
441            0.001
442        };
443
444        // Detect drift using z-score approach
445        for i in baseline_end..means.len() {
446            let z_score = (means[i] - baseline_mean).abs() / baseline_std;
447            let threshold = 1.96 / self.significance_level.sqrt(); // Adjust for significance
448
449            if z_score > threshold {
450                detected[i] = true;
451            }
452
453            // Also check for variance change
454            if i < stds.len() && baseline_end > 0 {
455                let baseline_var_mean =
456                    stds[..baseline_end].iter().sum::<f64>() / baseline_end as f64;
457                if baseline_var_mean > 0.001 {
458                    let var_ratio = stds[i] / baseline_var_mean;
459                    if !(0.5..=2.0).contains(&var_ratio) {
460                        detected[i] = true;
461                    }
462                }
463            }
464        }
465
466        detected
467    }
468
469    fn calculate_detection_metrics(
470        &self,
471        detected: &[bool],
472        ground_truth: &[Option<bool>],
473    ) -> DriftDetectionMetrics {
474        let mut true_positives = 0;
475        let mut false_positives = 0;
476        let mut true_negatives = 0;
477        let mut false_negatives = 0;
478        let mut detection_delays = Vec::new();
479
480        // Adjust for window offset
481        let offset = detected.len().saturating_sub(ground_truth.len());
482
483        for (i, &gt) in ground_truth.iter().enumerate() {
484            let detected_idx = i + offset;
485            if detected_idx >= detected.len() {
486                break;
487            }
488
489            let pred = detected[detected_idx];
490            match gt {
491                Some(true) => {
492                    if pred {
493                        true_positives += 1;
494                    } else {
495                        false_negatives += 1;
496                    }
497                }
498                Some(false) => {
499                    if pred {
500                        false_positives += 1;
501                    } else {
502                        true_negatives += 1;
503                    }
504                }
505                None => {}
506            }
507        }
508
509        // Calculate detection delay for true positives
510        let mut last_drift_start: Option<usize> = None;
511        for (i, &gt) in ground_truth.iter().enumerate() {
512            if gt == Some(true) && last_drift_start.is_none() {
513                last_drift_start = Some(i);
514            } else if gt == Some(false) {
515                last_drift_start = None;
516            }
517
518            let detected_idx = i + offset;
519            if detected_idx < detected.len() && detected[detected_idx] {
520                if let Some(start) = last_drift_start {
521                    detection_delays.push((i - start) as f64);
522                    last_drift_start = None;
523                }
524            }
525        }
526
527        let precision = if true_positives + false_positives > 0 {
528            true_positives as f64 / (true_positives + false_positives) as f64
529        } else {
530            0.0
531        };
532
533        let recall = if true_positives + false_negatives > 0 {
534            true_positives as f64 / (true_positives + false_negatives) as f64
535        } else {
536            0.0
537        };
538
539        let f1_score = if precision + recall > 0.0 {
540            2.0 * precision * recall / (precision + recall)
541        } else {
542            0.0
543        };
544
545        let mean_detection_delay = if detection_delays.is_empty() {
546            None
547        } else {
548            Some(detection_delays.iter().sum::<f64>() / detection_delays.len() as f64)
549        };
550
551        DriftDetectionMetrics {
552            true_positives,
553            false_positives,
554            true_negatives,
555            false_negatives,
556            precision,
557            recall,
558            f1_score,
559            mean_detection_delay,
560        }
561    }
562
563    fn calculate_hellinger_distance(&self, values: &[f64]) -> f64 {
564        if values.len() < 20 {
565            return 0.0;
566        }
567
568        let mid = values.len() / 2;
569        let first_half = &values[..mid];
570        let second_half = &values[mid..];
571
572        // Create histograms with 10 bins
573        let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
574            (min.min(v), max.max(v))
575        });
576
577        if (max_val - min_val).abs() < f64::EPSILON {
578            return 0.0;
579        }
580
581        let num_bins = 10;
582        let bin_width = (max_val - min_val) / num_bins as f64;
583
584        let mut hist1 = vec![0.0; num_bins];
585        let mut hist2 = vec![0.0; num_bins];
586
587        for &v in first_half {
588            let bin = ((v - min_val) / bin_width).floor() as usize;
589            let bin = bin.min(num_bins - 1);
590            hist1[bin] += 1.0;
591        }
592
593        for &v in second_half {
594            let bin = ((v - min_val) / bin_width).floor() as usize;
595            let bin = bin.min(num_bins - 1);
596            hist2[bin] += 1.0;
597        }
598
599        // Normalize
600        let sum1: f64 = hist1.iter().sum();
601        let sum2: f64 = hist2.iter().sum();
602
603        if sum1 == 0.0 || sum2 == 0.0 {
604            return 0.0;
605        }
606
607        for h in &mut hist1 {
608            *h /= sum1;
609        }
610        for h in &mut hist2 {
611            *h /= sum2;
612        }
613
614        // Calculate Hellinger distance
615        let mut sum_sq_diff = 0.0;
616        for i in 0..num_bins {
617            let diff = hist1[i].sqrt() - hist2[i].sqrt();
618            sum_sq_diff += diff * diff;
619        }
620
621        (sum_sq_diff / 2.0).sqrt()
622    }
623
624    fn calculate_psi(&self, values: &[f64]) -> f64 {
625        if values.len() < 20 {
626            return 0.0;
627        }
628
629        let mid = values.len() / 2;
630        let baseline = &values[..mid];
631        let current = &values[mid..];
632
633        // Create histograms with 10 bins
634        let (min_val, max_val) = values.iter().fold((f64::MAX, f64::MIN), |(min, max), &v| {
635            (min.min(v), max.max(v))
636        });
637
638        if (max_val - min_val).abs() < f64::EPSILON {
639            return 0.0;
640        }
641
642        let num_bins = 10;
643        let bin_width = (max_val - min_val) / num_bins as f64;
644
645        let mut hist_baseline = vec![0.0; num_bins];
646        let mut hist_current = vec![0.0; num_bins];
647
648        for &v in baseline {
649            let bin = ((v - min_val) / bin_width).floor() as usize;
650            let bin = bin.min(num_bins - 1);
651            hist_baseline[bin] += 1.0;
652        }
653
654        for &v in current {
655            let bin = ((v - min_val) / bin_width).floor() as usize;
656            let bin = bin.min(num_bins - 1);
657            hist_current[bin] += 1.0;
658        }
659
660        // Normalize and add small constant to avoid log(0)
661        let epsilon = 0.0001;
662        let sum_baseline: f64 = hist_baseline.iter().sum();
663        let sum_current: f64 = hist_current.iter().sum();
664
665        if sum_baseline == 0.0 || sum_current == 0.0 {
666            return 0.0;
667        }
668
669        for h in &mut hist_baseline {
670            *h = (*h / sum_baseline).max(epsilon);
671        }
672        for h in &mut hist_current {
673            *h = (*h / sum_current).max(epsilon);
674        }
675
676        // Calculate PSI
677        let mut psi = 0.0;
678        for i in 0..num_bins {
679            let diff = hist_current[i] - hist_baseline[i];
680            let ratio = hist_current[i] / hist_baseline[i];
681            psi += diff * ratio.ln();
682        }
683
684        psi
685    }
686
687    fn calculate_drift_magnitude(&self, means: &[f64]) -> f64 {
688        if means.len() < 2 {
689            return 0.0;
690        }
691
692        let mid = means.len() / 2;
693        let first_mean = means[..mid].iter().sum::<f64>() / mid as f64;
694        let second_mean = means[mid..].iter().sum::<f64>() / (means.len() - mid) as f64;
695
696        if first_mean.abs() < f64::EPSILON {
697            return (second_mean - first_mean).abs();
698        }
699
700        ((second_mean - first_mean) / first_mean).abs()
701    }
702
703    fn calculate_mean_change(&self, means: &[f64]) -> f64 {
704        if means.len() < 2 {
705            return 0.0;
706        }
707        let first = means.first().unwrap_or(&0.0);
708        let last = means.last().unwrap_or(&0.0);
709        if first.abs() < f64::EPSILON {
710            return 0.0;
711        }
712        (last - first) / first
713    }
714
715    fn calculate_std_change(&self, stds: &[f64]) -> f64 {
716        if stds.len() < 2 {
717            return 0.0;
718        }
719        let first = stds.first().unwrap_or(&0.0);
720        let last = stds.last().unwrap_or(&0.0);
721        if first.abs() < f64::EPSILON {
722            return 0.0;
723        }
724        (last - first) / first
725    }
726
727    fn evaluate_pass_status(&self, metrics: &DriftDetectionMetrics, drift_magnitude: f64) -> bool {
728        // Pass if we have reasonable detection metrics or magnitude is below threshold
729        if drift_magnitude < self.min_magnitude_threshold {
730            return true; // No significant drift to detect
731        }
732
733        // If there's significant drift, we need decent detection
734        metrics.f1_score >= 0.5 || metrics.precision >= 0.6 || metrics.recall >= 0.6
735    }
736
737    fn collect_issues(
738        &self,
739        metrics: &DriftDetectionMetrics,
740        drift_magnitude: f64,
741        drift_count: usize,
742    ) -> Vec<String> {
743        let mut issues = Vec::new();
744
745        if drift_magnitude >= self.min_magnitude_threshold {
746            if metrics.precision < 0.5 {
747                issues.push(format!(
748                    "Low precision ({:.2}): many false positives",
749                    metrics.precision
750                ));
751            }
752            if metrics.recall < 0.5 {
753                issues.push(format!(
754                    "Low recall ({:.2}): many drift events missed",
755                    metrics.recall
756                ));
757            }
758            if let Some(delay) = metrics.mean_detection_delay {
759                if delay > 3.0 {
760                    issues.push(format!("High detection delay ({:.1} periods)", delay));
761                }
762            }
763        }
764
765        if drift_count == 0 && drift_magnitude >= self.min_magnitude_threshold {
766            issues.push("No drift detected despite significant magnitude change".to_string());
767        }
768
769        issues
770    }
771}
772
773impl Default for DriftDetectionAnalyzer {
774    fn default() -> Self {
775        Self::new(0.05)
776    }
777}
778
779// =============================================================================
780// Analysis Results
781// =============================================================================
782
783/// Results from drift detection analysis.
784#[derive(Debug, Clone, Serialize, Deserialize)]
785pub struct DriftDetectionAnalysis {
786    /// Number of data points analyzed.
787    pub sample_size: usize,
788    /// Whether any drift was detected.
789    pub drift_detected: bool,
790    /// Number of drift points detected.
791    pub drift_count: usize,
792    /// Overall magnitude of detected drift.
793    pub drift_magnitude: f64,
794    /// Detection metrics (precision, recall, F1).
795    pub detection_metrics: DriftDetectionMetrics,
796    /// Hellinger distance between first and second half.
797    pub hellinger_distance: Option<f64>,
798    /// Population Stability Index.
799    pub psi: Option<f64>,
800    /// Relative change in rolling mean.
801    pub rolling_mean_change: f64,
802    /// Relative change in rolling standard deviation.
803    pub rolling_std_change: f64,
804    /// Whether the analysis passes quality thresholds.
805    pub passes: bool,
806    /// Issues identified during analysis.
807    pub issues: Vec<String>,
808}
809
810/// Drift detection performance metrics.
811#[derive(Debug, Clone, Default, Serialize, Deserialize)]
812pub struct DriftDetectionMetrics {
813    /// True positive count.
814    pub true_positives: usize,
815    /// False positive count.
816    pub false_positives: usize,
817    /// True negative count.
818    pub true_negatives: usize,
819    /// False negative count.
820    pub false_negatives: usize,
821    /// Precision (TP / (TP + FP)).
822    pub precision: f64,
823    /// Recall (TP / (TP + FN)).
824    pub recall: f64,
825    /// F1 score (harmonic mean of precision and recall).
826    pub f1_score: f64,
827    /// Mean delay in detecting drift (in periods).
828    pub mean_detection_delay: Option<f64>,
829}
830
831/// Analysis of labeled drift events.
832#[derive(Debug, Clone, Serialize, Deserialize)]
833pub struct LabeledEventAnalysis {
834    /// Total number of labeled events.
835    pub total_events: usize,
836    /// Number of statistical drift events.
837    pub statistical_events: usize,
838    /// Number of business event drifts.
839    pub business_events: usize,
840    /// Distribution by event category.
841    pub category_distribution: HashMap<DriftEventCategory, usize>,
842    /// Distribution by detection difficulty.
843    pub difficulty_distribution: HashMap<DetectionDifficulty, usize>,
844    /// Average drift magnitude.
845    pub avg_magnitude: f64,
846    /// Average detection difficulty score.
847    pub avg_difficulty: f64,
848    /// Period coverage (min_period, max_period).
849    pub period_coverage: (u32, u32),
850    /// Whether the analysis passes quality thresholds.
851    pub passes: bool,
852    /// Issues identified.
853    pub issues: Vec<String>,
854}
855
856impl LabeledEventAnalysis {
857    /// Create an empty analysis result.
858    pub fn empty() -> Self {
859        Self {
860            total_events: 0,
861            statistical_events: 0,
862            business_events: 0,
863            category_distribution: HashMap::new(),
864            difficulty_distribution: HashMap::new(),
865            avg_magnitude: 0.0,
866            avg_difficulty: 0.0,
867            period_coverage: (0, 0),
868            passes: true,
869            issues: Vec::new(),
870        }
871    }
872}
873
874// =============================================================================
875// Tests
876// =============================================================================
877
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drift_detection_entry_creation() {
        let entry = DriftDetectionEntry::new(1, 100.0, Some(true));
        assert_eq!(entry.period, 1);
        assert_eq!(entry.value, 100.0);
        assert_eq!(entry.ground_truth_drift, Some(true));
    }

    #[test]
    fn test_drift_detection_entry_with_metadata() {
        let entry = DriftDetectionEntry::with_metadata(5, 150.0, true, "MeanShift", 0.15, 0.3);
        assert_eq!(entry.period, 5);
        assert_eq!(entry.drift_type, Some("MeanShift".to_string()));
        assert_eq!(entry.drift_magnitude, Some(0.15));
        assert_eq!(entry.detection_difficulty, Some(0.3));
    }

    #[test]
    fn test_drift_event_category_names() {
        assert_eq!(DriftEventCategory::MeanShift.name(), "Mean Shift");
        assert_eq!(
            DriftEventCategory::OrganizationalEvent.name(),
            "Organizational Event"
        );
    }

    #[test]
    fn test_drift_event_category_classification() {
        assert!(DriftEventCategory::MeanShift.is_statistical());
        assert!(!DriftEventCategory::MeanShift.is_business_event());
        assert!(DriftEventCategory::OrganizationalEvent.is_business_event());
        assert!(!DriftEventCategory::OrganizationalEvent.is_statistical());
    }

    #[test]
    fn test_detection_difficulty_conversion() {
        assert_eq!(DetectionDifficulty::Easy.to_score(), 0.0);
        assert_eq!(DetectionDifficulty::Medium.to_score(), 0.5);
        assert_eq!(DetectionDifficulty::Hard.to_score(), 1.0);

        assert_eq!(
            DetectionDifficulty::from_score(0.1),
            DetectionDifficulty::Easy
        );
        assert_eq!(
            DetectionDifficulty::from_score(0.5),
            DetectionDifficulty::Medium
        );
        assert_eq!(
            DetectionDifficulty::from_score(0.8),
            DetectionDifficulty::Hard
        );
    }

    #[test]
    fn test_analyzer_creation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05)
            .with_window_size(15)
            .with_min_magnitude(0.1)
            .with_hellinger(true)
            .with_psi(true);

        assert_eq!(analyzer.significance_level, 0.05);
        assert_eq!(analyzer.window_size, 15);
        assert_eq!(analyzer.min_magnitude_threshold, 0.1);
    }

    #[test]
    fn test_analyze_no_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);

        // Create stable data with no drift
        let entries: Vec<DriftDetectionEntry> = (0..30)
            .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 * 0.01), Some(false)))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        assert!(!result.drift_detected || result.drift_count < 5);
        assert!(result.drift_magnitude < 0.1);
    }

    #[test]
    fn test_analyze_with_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(5);

        // Stable regime for periods 0..15, then a clear step change from
        // period 15 (inclusive) onward.
        let mut entries: Vec<DriftDetectionEntry> = (0..15)
            .map(|i| DriftDetectionEntry::new(i, 100.0, Some(false)))
            .collect();
        entries.extend((15..30).map(|i| DriftDetectionEntry::new(i, 150.0, Some(true))));

        let result = analyzer.analyze(&entries).unwrap();
        assert!(result.drift_detected);
        assert!(result.drift_magnitude > 0.3);
    }

    #[test]
    fn test_analyze_insufficient_data() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(10);

        let entries: Vec<DriftDetectionEntry> = (0..5)
            .map(|i| DriftDetectionEntry::new(i, 100.0, None))
            .collect();

        let result = analyzer.analyze(&entries);
        assert!(result.is_err());
    }

    #[test]
    fn test_analyze_labeled_events() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        let events = vec![
            LabeledDriftEvent {
                event_id: "E1".to_string(),
                event_type: DriftEventCategory::MeanShift,
                start_period: 10,
                end_period: Some(15),
                affected_fields: vec!["amount".to_string()],
                magnitude: 0.15,
                detection_difficulty: DetectionDifficulty::Easy,
            },
            LabeledDriftEvent {
                event_id: "E2".to_string(),
                event_type: DriftEventCategory::OrganizationalEvent,
                start_period: 20,
                end_period: Some(25),
                affected_fields: vec!["volume".to_string()],
                magnitude: 0.30,
                detection_difficulty: DetectionDifficulty::Medium,
            },
        ];

        let result = analyzer.analyze_labeled_events(&events).unwrap();
        assert_eq!(result.total_events, 2);
        assert_eq!(result.statistical_events, 1);
        assert_eq!(result.business_events, 1);
        assert!(result.avg_magnitude > 0.2);
        // An average of 0.15 and 0.30 can never exceed the larger magnitude.
        assert!(result.avg_magnitude <= 0.30);
        assert!(result.passes);
    }

    #[test]
    fn test_empty_labeled_events() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);
        let result = analyzer.analyze_labeled_events(&[]).unwrap();
        assert_eq!(result.total_events, 0);
        assert!(result.passes);
    }

    #[test]
    fn test_hellinger_distance_no_drift() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        // Stable data
        let entries: Vec<DriftDetectionEntry> = (0..40)
            .map(|i| DriftDetectionEntry::new(i, 100.0 + (i as f64 % 5.0), None))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();
        let hellinger = result
            .hellinger_distance
            .expect("Hellinger distance should be computed by the default analyzer");
        assert!(hellinger < 0.3);
    }

    #[test]
    fn test_psi_calculation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05);

        // Step change: 20 periods at 100.0 followed by 20 periods at 200.0.
        let mut entries: Vec<DriftDetectionEntry> = (0..20)
            .map(|i| DriftDetectionEntry::new(i, 100.0, None))
            .collect();
        entries.extend((20..40).map(|i| DriftDetectionEntry::new(i, 200.0, None)));

        let result = analyzer.analyze(&entries).unwrap();
        let psi = result
            .psi
            .expect("PSI should be computed by the default analyzer");
        // This only checks that a shift yields a positive PSI; it does not pin
        // the conventional "significant drift" threshold of PSI > 0.1.
        assert!(psi > 0.0);
    }

    #[test]
    fn test_detection_metrics_calculation() {
        let analyzer = DriftDetectionAnalyzer::new(0.05).with_window_size(3);

        // Ground truth: periods 0..10 are stable, periods 10..20 are drifted.
        let entries: Vec<DriftDetectionEntry> = (0..10)
            .map(|i| DriftDetectionEntry::new(i, 100.0, Some(false)))
            .chain((10..20).map(|i| DriftDetectionEntry::new(i, 200.0, Some(true))))
            .collect();

        let result = analyzer.analyze(&entries).unwrap();

        // Precision and recall are ratios, so they must lie in [0, 1]
        // (a NaN also fails this check).
        let metrics = &result.detection_metrics;
        assert!((0.0..=1.0).contains(&metrics.precision));
        assert!((0.0..=1.0).contains(&metrics.recall));
    }
}