// alimentar/drift.rs
1//! Data drift detection for ML pipelines
2//!
3//! Detects distribution changes between dataset versions or time periods.
4//! Implements Jidoka—building quality in at the data layer before training.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use alimentar::drift::{DriftDetector, DriftTest};
10//!
11//! let detector = DriftDetector::new(reference_dataset)
12//!     .with_test(DriftTest::KolmogorovSmirnov)
13//!     .with_test(DriftTest::PSI)
14//!     .with_alpha(0.05);
15//!
16//! let report = detector.detect(&current_dataset)?;
17//! if report.drift_detected {
18//!     println!("Drift detected in columns: {:?}", report.drifted_columns());
19//! }
20//! ```
21
22// Statistical computation requires casts, similar variable names, and float literals
23#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_truncation)]
25#![allow(clippy::cast_sign_loss)]
26#![allow(clippy::similar_names)]
27#![allow(clippy::unreadable_literal)]
28#![allow(clippy::suboptimal_flops)]
29
30use std::collections::HashMap;
31
32use crate::{
33    dataset::{ArrowDataset, Dataset},
34    error::{Error, Result},
35};
36
/// Statistical tests for drift detection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftTest {
    /// Kolmogorov-Smirnov test for continuous features
    KolmogorovSmirnov,
    /// Chi-squared test for categorical features
    ChiSquared,
    /// Population Stability Index (PSI)
    PSI,
    /// Jensen-Shannon divergence
    JensenShannon,
}

impl DriftTest {
    /// Human-readable name of the test.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::KolmogorovSmirnov => "Kolmogorov-Smirnov",
            Self::ChiSquared => "Chi-Squared",
            Self::PSI => "Population Stability Index",
            Self::JensenShannon => "Jensen-Shannon Divergence",
        }
    }

    /// True when the test operates on continuous-valued samples.
    pub fn is_continuous(&self) -> bool {
        // Exhaustive match: adding a variant forces a decision here.
        match self {
            Self::KolmogorovSmirnov | Self::JensenShannon => true,
            Self::ChiSquared | Self::PSI => false,
        }
    }

    /// True when the test operates on categorical (binned) samples.
    ///
    /// For the current set of tests this is exactly the complement of
    /// [`Self::is_continuous`].
    pub fn is_categorical(&self) -> bool {
        !self.is_continuous()
    }
}
71
/// Severity of detected drift
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DriftSeverity {
    /// No drift detected (p > 0.05, or PSI < 0.1)
    None,
    /// Low drift (0.01 < p <= 0.05, or 0.1 <= PSI < 0.2)
    Low,
    /// Medium drift (0.001 < p <= 0.01, or 0.2 <= PSI < 0.25)
    Medium,
    /// High drift (0.0001 < p <= 0.001, or 0.25 <= PSI < 0.5)
    High,
    /// Critical - distribution fundamentally changed
    Critical,
}

impl DriftSeverity {
    /// Map a p-value to a severity band (smaller p => more severe).
    ///
    /// Non-finite input (e.g. NaN) falls through every guard and is
    /// classified as `Critical`.
    pub fn from_p_value(p_value: f64) -> Self {
        match p_value {
            p if p > 0.05 => Self::None,
            p if p > 0.01 => Self::Low,
            p if p > 0.001 => Self::Medium,
            p if p > 0.0001 => Self::High,
            _ => Self::Critical,
        }
    }

    /// Map a Population Stability Index value to a severity band
    /// (larger PSI => more severe).
    pub fn from_psi(psi: f64) -> Self {
        match psi {
            p if p < 0.1 => Self::None,
            p if p < 0.2 => Self::Low,
            p if p < 0.25 => Self::Medium,
            p if p < 0.5 => Self::High,
            _ => Self::Critical,
        }
    }

    /// Whether this severity level counts as drift (everything but `None`).
    pub fn is_drift(&self) -> bool {
        !matches!(self, Self::None)
    }
}
123
/// Per-column drift result
///
/// One instance describes the outcome of a single statistical test applied
/// to a single column; a column checked with several tests produces several
/// `ColumnDrift` values.
#[derive(Debug, Clone)]
pub struct ColumnDrift {
    /// Column name
    pub column: String,
    /// Test used
    pub test: DriftTest,
    /// Test statistic value (its meaning depends on `test`)
    pub statistic: f64,
    /// P-value, when the test produces one (PSI and JSD report `None`)
    pub p_value: Option<f64>,
    /// Whether drift was detected for this column
    pub drift_detected: bool,
    /// Severity of drift
    pub severity: DriftSeverity,
}
140
141impl ColumnDrift {
142    /// Create a new column drift result
143    pub fn new(
144        column: impl Into<String>,
145        test: DriftTest,
146        statistic: f64,
147        p_value: Option<f64>,
148        severity: DriftSeverity,
149    ) -> Self {
150        Self {
151            column: column.into(),
152            test,
153            statistic,
154            p_value,
155            drift_detected: severity.is_drift(),
156            severity,
157        }
158    }
159}
160
/// Overall drift detection report
#[derive(Debug, Clone)]
pub struct DriftReport {
    /// Per-(column, test) drift results, keyed as `"column:TestVariant"`
    /// (see `from_columns`) so multiple tests per column can coexist.
    pub column_scores: HashMap<String, ColumnDrift>,
    /// Overall drift detected (true if any entry detected drift)
    pub drift_detected: bool,
    /// Timestamp of analysis (Unix epoch seconds; 0 if the system clock is
    /// before the epoch)
    pub timestamp: u64,
}
171
172impl DriftReport {
173    /// Create a new drift report from column results
174    pub fn from_columns(columns: Vec<ColumnDrift>) -> Self {
175        let drift_detected = columns.iter().any(|c| c.drift_detected);
176        let timestamp = std::time::SystemTime::now()
177            .duration_since(std::time::UNIX_EPOCH)
178            .map(|d| d.as_secs())
179            .unwrap_or(0);
180
181        // Key by column_name:test_name to allow multiple tests per column
182        let column_scores = columns
183            .into_iter()
184            .map(|c| (format!("{}:{:?}", c.column, c.test), c))
185            .collect();
186
187        Self {
188            column_scores,
189            drift_detected,
190            timestamp,
191        }
192    }
193
194    /// Get columns with detected drift
195    pub fn drifted_columns(&self) -> Vec<&str> {
196        self.column_scores
197            .values()
198            .filter(|c| c.drift_detected)
199            .map(|c| c.column.as_str())
200            .collect()
201    }
202
203    /// Get the maximum severity across all columns
204    pub fn max_severity(&self) -> DriftSeverity {
205        self.column_scores
206            .values()
207            .map(|c| c.severity)
208            .max()
209            .unwrap_or(DriftSeverity::None)
210    }
211
212    /// Get number of columns analyzed
213    pub fn num_columns(&self) -> usize {
214        self.column_scores.len()
215    }
216
217    /// Get number of columns with drift
218    pub fn num_drifted(&self) -> usize {
219        self.column_scores
220            .values()
221            .filter(|c| c.drift_detected)
222            .count()
223    }
224}
225
/// Statistical drift detector
///
/// Compares a "current" dataset against a fixed reference (baseline)
/// dataset using one or more configured statistical tests.
pub struct DriftDetector {
    /// Reference dataset (baseline distribution)
    reference: ArrowDataset,
    /// Statistical tests applied to every numeric column
    tests: Vec<DriftTest>,
    /// Significance threshold for p-value based tests (default: 0.05)
    alpha: f64,
}
235
236impl DriftDetector {
237    /// Create a new drift detector with a reference dataset
238    pub fn new(reference: ArrowDataset) -> Self {
239        Self {
240            reference,
241            tests: vec![DriftTest::KolmogorovSmirnov],
242            alpha: 0.05,
243        }
244    }
245
246    /// Add a statistical test
247    #[must_use]
248    pub fn with_test(mut self, test: DriftTest) -> Self {
249        if !self.tests.contains(&test) {
250            self.tests.push(test);
251        }
252        self
253    }
254
255    /// Set significance threshold
256    #[must_use]
257    pub fn with_alpha(mut self, alpha: f64) -> Self {
258        self.alpha = alpha;
259        self
260    }
261
262    /// Set all tests at once
263    #[must_use]
264    pub fn with_tests(mut self, tests: Vec<DriftTest>) -> Self {
265        self.tests = tests;
266        self
267    }
268
269    /// Get the reference dataset
270    pub fn reference(&self) -> &ArrowDataset {
271        &self.reference
272    }
273
274    /// Get configured tests
275    pub fn tests(&self) -> &[DriftTest] {
276        &self.tests
277    }
278
279    /// Get significance threshold
280    pub fn alpha(&self) -> f64 {
281        self.alpha
282    }
283
284    /// Compare current dataset against reference
285    pub fn detect(&self, current: &ArrowDataset) -> Result<DriftReport> {
286        // Verify schemas match
287        if self.reference.schema() != current.schema() {
288            return Err(Error::invalid_config(
289                "Schema mismatch between reference and current dataset",
290            ));
291        }
292
293        let schema = self.reference.schema();
294        let mut results = Vec::new();
295
296        // Extract data from datasets
297        let ref_data = collect_dataset_data(&self.reference);
298        let cur_data = collect_dataset_data(current);
299
300        // Test each column
301        for field in schema.fields() {
302            let column_name = field.name();
303
304            let ref_col = ref_data.get(column_name);
305            let cur_col = cur_data.get(column_name);
306
307            if let (Some(ref_values), Some(cur_values)) = (ref_col, cur_col) {
308                // Run each configured test
309                for test in &self.tests {
310                    let result = run_test(*test, ref_values, cur_values, self.alpha)?;
311                    results.push(ColumnDrift::new(
312                        column_name,
313                        *test,
314                        result.statistic,
315                        result.p_value,
316                        result.severity,
317                    ));
318                }
319            }
320        }
321
322        Ok(DriftReport::from_columns(results))
323    }
324}
325
/// Internal result from a statistical test
struct TestResult {
    /// Raw test statistic (KS D, chi-squared, PSI, or normalized JSD)
    statistic: f64,
    /// P-value when the test defines one (`None` for PSI / JSD)
    p_value: Option<f64>,
    /// Severity derived from the p-value or statistic
    severity: DriftSeverity,
}
332
/// Run a specific statistical test
///
/// Dispatches to the matching implementation. `alpha` is used only by the
/// p-value based tests (KS, chi-squared); PSI and JSD classify severity
/// from the statistic itself.
fn run_test(test: DriftTest, reference: &[f64], current: &[f64], alpha: f64) -> Result<TestResult> {
    match test {
        DriftTest::KolmogorovSmirnov => ks_test(reference, current, alpha),
        DriftTest::ChiSquared => chi_squared_test(reference, current, alpha),
        DriftTest::PSI => psi_test(reference, current),
        DriftTest::JensenShannon => jensen_shannon_test(reference, current),
    }
}
342
343/// Kolmogorov-Smirnov two-sample test
344///
345/// Tests whether two samples come from the same distribution.
346/// The statistic D is the maximum absolute difference between CDFs.
347fn ks_test(reference: &[f64], current: &[f64], alpha: f64) -> Result<TestResult> {
348    if reference.is_empty() || current.is_empty() {
349        return Err(Error::invalid_config(
350            "Cannot perform KS test on empty data",
351        ));
352    }
353
354    // Sort both samples
355    let mut ref_sorted: Vec<f64> = reference
356        .iter()
357        .copied()
358        .filter(|x| x.is_finite())
359        .collect();
360    let mut cur_sorted: Vec<f64> = current.iter().copied().filter(|x| x.is_finite()).collect();
361
362    if ref_sorted.is_empty() || cur_sorted.is_empty() {
363        return Err(Error::invalid_config("No finite values in data"));
364    }
365
366    ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
367    cur_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
368
369    let n1 = ref_sorted.len() as f64;
370    let n2 = cur_sorted.len() as f64;
371
372    // Compute empirical CDFs and find maximum difference
373    let d_statistic = compute_ks_statistic(&ref_sorted, &cur_sorted);
374
375    // Compute approximate p-value using asymptotic distribution
376    let en = (n1 * n2 / (n1 + n2)).sqrt();
377    let p_value = ks_p_value(d_statistic * en);
378
379    let severity = if p_value <= alpha {
380        DriftSeverity::from_p_value(p_value)
381    } else {
382        DriftSeverity::None
383    };
384
385    Ok(TestResult {
386        statistic: d_statistic,
387        p_value: Some(p_value),
388        severity,
389    })
390}
391
/// Compute the two-sample KS statistic (maximum CDF difference).
///
/// Both inputs must be sorted ascending. Runs in O(n1 + n2) using a
/// merge-style two-pointer sweep over the sorted samples, instead of
/// re-counting ranks for every distinct merged value (which was
/// O((n1 + n2)^2)).
fn compute_ks_statistic(ref_sorted: &[f64], cur_sorted: &[f64]) -> f64 {
    let n1 = ref_sorted.len() as f64;
    let n2 = cur_sorted.len() as f64;

    let mut i = 0; // items of `ref_sorted` consumed so far
    let mut j = 0; // items of `cur_sorted` consumed so far
    let mut max_diff = 0.0_f64;

    // Visit each distinct value of the merged samples in ascending order.
    // Once one sample is exhausted its CDF is pinned at 1.0 and the gap can
    // only shrink afterwards, so the sweep may stop at that point.
    while i < ref_sorted.len() && j < cur_sorted.len() {
        let x = ref_sorted[i].min(cur_sorted[j]);
        // Advance both cursors past every sample point <= x, so i and j are
        // the unnormalized empirical CDF counts at x.
        while i < ref_sorted.len() && ref_sorted[i] <= x {
            i += 1;
        }
        while j < cur_sorted.len() && cur_sorted[j] <= x {
            j += 1;
        }
        let diff = (i as f64 / n1 - j as f64 / n2).abs();
        if diff > max_diff {
            max_diff = diff;
        }
    }

    max_diff
}
422
/// Approximate the Kolmogorov survival function P(D > z).
///
/// Evaluates the alternating asymptotic series
/// `2 * sum_{k>=1} (-1)^(k-1) * exp(-2 k^2 z^2)`, truncating once terms
/// become negligible, and clamps the result into [0, 1].
fn ks_p_value(z: f64) -> f64 {
    if z <= 0.0 {
        return 1.0;
    }
    if z > 3.0 {
        // The tail is below ~1e-7 here; treat it as zero.
        return 0.0;
    }

    let z_sq = z * z;
    let mut sum = 0.0_f64;
    let mut sign = 1.0_f64;

    for k in 1..=100_u32 {
        let k_f = f64::from(k);
        let term = sign * (-2.0 * k_f * k_f * z_sq).exp();
        sum += term;
        // Series converges fast; stop once the contribution is negligible.
        if term.abs() < 1e-12 {
            break;
        }
        sign = -sign;
    }

    (2.0 * sum).clamp(0.0, 1.0)
}
447
448/// Chi-squared test for categorical data
449///
450/// Bins continuous data and tests for independence.
451fn chi_squared_test(reference: &[f64], current: &[f64], alpha: f64) -> Result<TestResult> {
452    if reference.is_empty() || current.is_empty() {
453        return Err(Error::invalid_config(
454            "Cannot perform chi-squared test on empty data",
455        ));
456    }
457
458    // Bin the data
459    let num_bins = ((reference.len() as f64).sqrt().ceil() as usize).clamp(5, 20);
460    let (ref_bins, cur_bins) = bin_data(reference, current, num_bins)?;
461
462    // Compute chi-squared statistic
463    let n_ref = reference.len() as f64;
464    let n_cur = current.len() as f64;
465    let total = n_ref + n_cur;
466
467    let mut chi_sq = 0.0;
468    let mut df: usize = 0;
469
470    for (r, c) in ref_bins.iter().zip(cur_bins.iter()) {
471        let r = *r as f64;
472        let c = *c as f64;
473        let row_total = r + c;
474
475        if row_total > 0.0 {
476            let expected_r = row_total * n_ref / total;
477            let expected_c = row_total * n_cur / total;
478
479            if expected_r > 0.0 {
480                chi_sq += (r - expected_r).powi(2) / expected_r;
481            }
482            if expected_c > 0.0 {
483                chi_sq += (c - expected_c).powi(2) / expected_c;
484            }
485            df += 1;
486        }
487    }
488
489    df = df.saturating_sub(1); // degrees of freedom = bins - 1
490
491    // Approximate p-value using chi-squared distribution
492    let p_value = chi_squared_p_value(chi_sq, df);
493
494    let severity = if p_value <= alpha {
495        DriftSeverity::from_p_value(p_value)
496    } else {
497        DriftSeverity::None
498    };
499
500    Ok(TestResult {
501        statistic: chi_sq,
502        p_value: Some(p_value),
503        severity,
504    })
505}
506
507/// Bin continuous data into histogram
508fn bin_data(
509    reference: &[f64],
510    current: &[f64],
511    num_bins: usize,
512) -> Result<(Vec<usize>, Vec<usize>)> {
513    // Find global min/max
514    let all_data: Vec<f64> = reference
515        .iter()
516        .chain(current.iter())
517        .copied()
518        .filter(|x| x.is_finite())
519        .collect();
520
521    if all_data.is_empty() {
522        return Err(Error::invalid_config("No finite values in data"));
523    }
524
525    let min_val = all_data.iter().copied().fold(f64::INFINITY, f64::min);
526    let max_val = all_data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
527
528    if (max_val - min_val).abs() < f64::EPSILON {
529        // All values are the same
530        return Ok((vec![reference.len()], vec![current.len()]));
531    }
532
533    let bin_width = (max_val - min_val) / num_bins as f64;
534
535    let bin_value = |v: f64| -> usize {
536        if !v.is_finite() {
537            return 0;
538        }
539        let bin = ((v - min_val) / bin_width).floor() as usize;
540        bin.min(num_bins - 1)
541    };
542
543    let mut ref_bins = vec![0usize; num_bins];
544    let mut cur_bins = vec![0usize; num_bins];
545
546    for &v in reference {
547        ref_bins[bin_value(v)] += 1;
548    }
549    for &v in current {
550        cur_bins[bin_value(v)] += 1;
551    }
552
553    Ok((ref_bins, cur_bins))
554}
555
/// Approximate chi-squared upper-tail p-value.
///
/// Uses the Wilson-Hilferty cube-root transformation, which maps a
/// chi-squared variate with `df` degrees of freedom to an approximately
/// standard-normal variate. Zero degrees of freedom yields p = 1.
fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
    if df == 0 {
        return 1.0;
    }

    let k = df as f64;
    // Under Wilson-Hilferty, (X/k)^(1/3) ~ Normal(1 - 2/(9k), 2/(9k)).
    let mean = 1.0 - 2.0 / (9.0 * k);
    let sd = (2.0 / (9.0 * k)).sqrt();
    let z = ((chi_sq / k).cbrt() - mean) / sd;

    1.0 - standard_normal_cdf(z)
}

/// Standard normal CDF expressed via the error function.
fn standard_normal_cdf(z: f64) -> f64 {
    0.5 * (1.0 + erf(z / std::f64::consts::SQRT_2))
}

/// Error function approximation (Abramowitz & Stegun 7.1.26,
/// max absolute error ~1.5e-7). Odd symmetry handles negative inputs.
fn erf(x: f64) -> f64 {
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;
    const P: f64 = 0.3275911;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + P * x);
    let poly = ((((A5 * t + A4) * t + A3) * t + A2) * t + A1) * t;

    sign * (1.0 - poly * (-x * x).exp())
}
595
596/// Population Stability Index (PSI)
597///
598/// Measures how much a distribution has shifted.
599/// PSI < 0.1: No significant change
600/// PSI 0.1-0.2: Moderate change
601/// PSI > 0.2: Significant change
602fn psi_test(reference: &[f64], current: &[f64]) -> Result<TestResult> {
603    if reference.is_empty() || current.is_empty() {
604        return Err(Error::invalid_config("Cannot compute PSI on empty data"));
605    }
606
607    // Bin the data (10 bins is standard for PSI)
608    let num_bins = 10;
609    let (ref_bins, cur_bins) = bin_data(reference, current, num_bins)?;
610
611    let n_ref = reference.len() as f64;
612    let n_cur = current.len() as f64;
613
614    let mut psi = 0.0;
615
616    for (r, c) in ref_bins.iter().zip(cur_bins.iter()) {
617        // Convert counts to proportions with smoothing to avoid log(0)
618        let p_ref = (*r as f64 + 0.5) / (n_ref + num_bins as f64 * 0.5);
619        let p_cur = (*c as f64 + 0.5) / (n_cur + num_bins as f64 * 0.5);
620
621        psi += (p_cur - p_ref) * (p_cur / p_ref).ln();
622    }
623
624    let severity = DriftSeverity::from_psi(psi);
625
626    Ok(TestResult {
627        statistic: psi,
628        p_value: None, // PSI doesn't have a p-value
629        severity,
630    })
631}
632
633/// Jensen-Shannon divergence
634///
635/// Symmetric measure of distribution difference.
636/// JSD = 0: Identical distributions
637/// JSD = 1: Completely different distributions
638fn jensen_shannon_test(reference: &[f64], current: &[f64]) -> Result<TestResult> {
639    if reference.is_empty() || current.is_empty() {
640        return Err(Error::invalid_config("Cannot compute JSD on empty data"));
641    }
642
643    // Bin the data
644    let num_bins = 20;
645    let (ref_bins, cur_bins) = bin_data(reference, current, num_bins)?;
646
647    let n_ref = reference.len() as f64;
648    let n_cur = current.len() as f64;
649
650    // Convert to probability distributions with smoothing
651    let p: Vec<f64> = ref_bins
652        .iter()
653        .map(|&c| (c as f64 + 0.5) / (n_ref + num_bins as f64 * 0.5))
654        .collect();
655    let q: Vec<f64> = cur_bins
656        .iter()
657        .map(|&c| (c as f64 + 0.5) / (n_cur + num_bins as f64 * 0.5))
658        .collect();
659
660    // M = (P + Q) / 2
661    let m: Vec<f64> = p
662        .iter()
663        .zip(q.iter())
664        .map(|(pi, qi)| (pi + qi) / 2.0)
665        .collect();
666
667    // JSD = 0.5 * KL(P||M) + 0.5 * KL(Q||M)
668    let kl_pm: f64 = p
669        .iter()
670        .zip(m.iter())
671        .map(|(pi, mi)| if *pi > 0.0 { pi * (pi / mi).ln() } else { 0.0 })
672        .sum();
673
674    let kl_qm: f64 = q
675        .iter()
676        .zip(m.iter())
677        .map(|(qi, mi)| if *qi > 0.0 { qi * (qi / mi).ln() } else { 0.0 })
678        .sum();
679
680    let jsd = 0.5 * kl_pm + 0.5 * kl_qm;
681
682    // JSD is in [0, ln(2)] for base e, normalize to [0, 1]
683    let jsd_normalized = jsd / std::f64::consts::LN_2;
684
685    // Map JSD to severity (similar thresholds to PSI)
686    let severity = if jsd_normalized < 0.05 {
687        DriftSeverity::None
688    } else if jsd_normalized < 0.1 {
689        DriftSeverity::Low
690    } else if jsd_normalized < 0.2 {
691        DriftSeverity::Medium
692    } else if jsd_normalized < 0.4 {
693        DriftSeverity::High
694    } else {
695        DriftSeverity::Critical
696    };
697
698    Ok(TestResult {
699        statistic: jsd_normalized,
700        p_value: None, // JSD doesn't have a traditional p-value
701        severity,
702    })
703}
704
/// Extract non-null f64 values from a numeric Arrow array into the output
/// vector.
///
/// Supported types: Float64, Float32, Int32, Int64. Float32/Int32 widen
/// losslessly via `f64::from`; Int64 uses an `as` cast, which can lose
/// precision for magnitudes above 2^53. Nulls are skipped. Any other data
/// type — or a failed downcast — leaves `out` untouched.
fn extract_numeric_values(
    array: &dyn arrow::array::Array,
    data_type: &arrow::datatypes::DataType,
    out: &mut Vec<f64>,
) {
    use arrow::{
        array::{Array, Float64Array, Int32Array, Int64Array},
        datatypes::DataType,
    };

    match data_type {
        DataType::Float64 => {
            if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
                out.extend(
                    (0..arr.len())
                        .filter(|&i| !arr.is_null(i))
                        .map(|i| arr.value(i)),
                );
            }
        }
        DataType::Float32 => {
            if let Some(arr) = array.as_any().downcast_ref::<arrow::array::Float32Array>() {
                out.extend(
                    (0..arr.len())
                        .filter(|&i| !arr.is_null(i))
                        .map(|i| f64::from(arr.value(i))),
                );
            }
        }
        DataType::Int32 => {
            if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
                out.extend(
                    (0..arr.len())
                        .filter(|&i| !arr.is_null(i))
                        .map(|i| f64::from(arr.value(i))),
                );
            }
        }
        DataType::Int64 => {
            if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
                out.extend(
                    (0..arr.len())
                        .filter(|&i| !arr.is_null(i))
                        // NOTE: i64 -> f64 may round values beyond 2^53.
                        .map(|i| arr.value(i) as f64),
                );
            }
        }
        // Non-numeric types are ignored by design (nothing to extract).
        _ => {}
    }
}
757
758/// Collect all numeric column data from a dataset
759fn collect_dataset_data(dataset: &ArrowDataset) -> HashMap<String, Vec<f64>> {
760    use arrow::datatypes::DataType;
761
762    let mut data: HashMap<String, Vec<f64>> = HashMap::new();
763    let schema = dataset.schema();
764
765    // Initialize vectors for each numeric column
766    for field in schema.fields() {
767        if matches!(
768            field.data_type(),
769            DataType::Int32 | DataType::Int64 | DataType::Float64 | DataType::Float32
770        ) {
771            data.insert(field.name().clone(), Vec::new());
772        }
773    }
774
775    // Collect data from all batches
776    for batch in dataset.iter() {
777        for (col_idx, field) in schema.fields().iter().enumerate() {
778            if let Some(col_data) = data.get_mut(field.name()) {
779                extract_numeric_values(batch.column(col_idx), field.data_type(), col_data);
780            }
781        }
782    }
783
784    data
785}
786
787#[cfg(test)]
788mod tests {
789    use std::sync::Arc;
790
791    use arrow::{
792        array::{Float64Array, Int32Array},
793        datatypes::{DataType, Field, Schema},
794        record_batch::RecordBatch,
795    };
796
797    use super::*;
798
799    // ========== DriftTest enum tests ==========
800
801    #[test]
802    fn test_drift_test_name() {
803        assert_eq!(DriftTest::KolmogorovSmirnov.name(), "Kolmogorov-Smirnov");
804        assert_eq!(DriftTest::ChiSquared.name(), "Chi-Squared");
805        assert_eq!(DriftTest::PSI.name(), "Population Stability Index");
806        assert_eq!(DriftTest::JensenShannon.name(), "Jensen-Shannon Divergence");
807    }
808
809    #[test]
810    fn test_drift_test_is_continuous() {
811        assert!(DriftTest::KolmogorovSmirnov.is_continuous());
812        assert!(DriftTest::JensenShannon.is_continuous());
813        assert!(!DriftTest::ChiSquared.is_continuous());
814        assert!(!DriftTest::PSI.is_continuous());
815    }
816
817    #[test]
818    fn test_drift_test_is_categorical() {
819        assert!(DriftTest::ChiSquared.is_categorical());
820        assert!(DriftTest::PSI.is_categorical());
821        assert!(!DriftTest::KolmogorovSmirnov.is_categorical());
822        assert!(!DriftTest::JensenShannon.is_categorical());
823    }
824
825    // ========== DriftSeverity tests ==========
826
827    #[test]
828    fn test_drift_severity_from_p_value() {
829        assert_eq!(DriftSeverity::from_p_value(0.1), DriftSeverity::None);
830        assert_eq!(DriftSeverity::from_p_value(0.06), DriftSeverity::None);
831        assert_eq!(DriftSeverity::from_p_value(0.04), DriftSeverity::Low);
832        assert_eq!(DriftSeverity::from_p_value(0.005), DriftSeverity::Medium);
833        assert_eq!(DriftSeverity::from_p_value(0.0005), DriftSeverity::High);
834        assert_eq!(
835            DriftSeverity::from_p_value(0.00001),
836            DriftSeverity::Critical
837        );
838    }
839
840    #[test]
841    fn test_drift_severity_from_psi() {
842        assert_eq!(DriftSeverity::from_psi(0.05), DriftSeverity::None);
843        assert_eq!(DriftSeverity::from_psi(0.15), DriftSeverity::Low);
844        assert_eq!(DriftSeverity::from_psi(0.22), DriftSeverity::Medium);
845        assert_eq!(DriftSeverity::from_psi(0.35), DriftSeverity::High);
846        assert_eq!(DriftSeverity::from_psi(0.6), DriftSeverity::Critical);
847    }
848
849    #[test]
850    fn test_drift_severity_is_drift() {
851        assert!(!DriftSeverity::None.is_drift());
852        assert!(DriftSeverity::Low.is_drift());
853        assert!(DriftSeverity::Medium.is_drift());
854        assert!(DriftSeverity::High.is_drift());
855        assert!(DriftSeverity::Critical.is_drift());
856    }
857
858    #[test]
859    fn test_drift_severity_ordering() {
860        assert!(DriftSeverity::None < DriftSeverity::Low);
861        assert!(DriftSeverity::Low < DriftSeverity::Medium);
862        assert!(DriftSeverity::Medium < DriftSeverity::High);
863        assert!(DriftSeverity::High < DriftSeverity::Critical);
864    }
865
866    // ========== ColumnDrift tests ==========
867
    #[test]
    fn test_column_drift_new() {
        let drift = ColumnDrift::new(
            "age",
            DriftTest::KolmogorovSmirnov,
            0.15,
            Some(0.03),
            DriftSeverity::Low,
        );

        // Constructor arguments are stored as-is...
        assert_eq!(drift.column, "age");
        assert_eq!(drift.test, DriftTest::KolmogorovSmirnov);
        assert!((drift.statistic - 0.15).abs() < f64::EPSILON);
        assert_eq!(drift.p_value, Some(0.03));
        // ...and drift_detected is derived from the severity (Low => true).
        assert!(drift.drift_detected);
        assert_eq!(drift.severity, DriftSeverity::Low);
    }

    #[test]
    fn test_column_drift_no_drift() {
        // Severity `None` must yield drift_detected == false.
        let drift = ColumnDrift::new("income", DriftTest::PSI, 0.05, None, DriftSeverity::None);

        assert!(!drift.drift_detected);
        assert_eq!(drift.severity, DriftSeverity::None);
    }
893
894    // ========== DriftReport tests ==========
895
    #[test]
    fn test_drift_report_from_columns() {
        // One clean column plus one drifted column: the report aggregates
        // to drift_detected == true with the drifted column's severity.
        let columns = vec![
            ColumnDrift::new(
                "a",
                DriftTest::KolmogorovSmirnov,
                0.1,
                Some(0.5),
                DriftSeverity::None,
            ),
            ColumnDrift::new("b", DriftTest::PSI, 0.25, None, DriftSeverity::Medium),
        ];

        let report = DriftReport::from_columns(columns);

        assert!(report.drift_detected);
        assert_eq!(report.num_columns(), 2);
        assert_eq!(report.num_drifted(), 1);
        assert_eq!(report.max_severity(), DriftSeverity::Medium);
    }

    #[test]
    fn test_drift_report_no_drift() {
        // All results are severity None: the report must show no drift.
        let columns = vec![
            ColumnDrift::new(
                "a",
                DriftTest::KolmogorovSmirnov,
                0.05,
                Some(0.5),
                DriftSeverity::None,
            ),
            ColumnDrift::new("b", DriftTest::PSI, 0.05, None, DriftSeverity::None),
        ];

        let report = DriftReport::from_columns(columns);

        assert!(!report.drift_detected);
        assert_eq!(report.num_drifted(), 0);
        assert_eq!(report.max_severity(), DriftSeverity::None);
    }

    #[test]
    fn test_drift_report_drifted_columns() {
        // Only the columns whose severity indicates drift ("b", "c") are
        // reported by drifted_columns(); "a" (None) is excluded.
        let columns = vec![
            ColumnDrift::new(
                "a",
                DriftTest::KolmogorovSmirnov,
                0.1,
                Some(0.5),
                DriftSeverity::None,
            ),
            ColumnDrift::new("b", DriftTest::PSI, 0.3, None, DriftSeverity::High),
            ColumnDrift::new(
                "c",
                DriftTest::ChiSquared,
                50.0,
                Some(0.001),
                DriftSeverity::Medium,
            ),
        ];

        let report = DriftReport::from_columns(columns);
        let drifted = report.drifted_columns();

        assert_eq!(drifted.len(), 2);
        assert!(drifted.contains(&"b"));
        assert!(drifted.contains(&"c"));
    }
964
965    // ========== KS test implementation tests ==========
966
967    #[test]
968    fn test_ks_identical_distributions() {
969        let data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
970        let result = ks_test(&data, &data, 0.05).expect("ks test");
971
972        assert!(
973            result.statistic < 0.01,
974            "KS statistic should be ~0 for identical data"
975        );
976        assert!(
977            result.p_value.unwrap_or(0.0) > 0.05,
978            "p-value should be high"
979        );
980        assert_eq!(result.severity, DriftSeverity::None);
981    }
982
983    #[test]
984    fn test_ks_different_distributions() {
985        // Uniform [0, 100] vs Uniform [50, 150]
986        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64 / 10.0).collect();
987        let cur_data: Vec<f64> = (0..1000).map(|i| 50.0 + i as f64 / 10.0).collect();
988
989        let result = ks_test(&ref_data, &cur_data, 0.05).expect("ks test");
990
991        assert!(
992            result.statistic > 0.3,
993            "KS statistic should be large for shifted data"
994        );
995        assert!(
996            result.p_value.unwrap_or(1.0) < 0.05,
997            "p-value should be small"
998        );
999        assert!(result.severity.is_drift());
1000    }
1001
1002    #[test]
1003    fn test_ks_empty_data_error() {
1004        let empty: Vec<f64> = vec![];
1005        let data = vec![1.0, 2.0, 3.0];
1006
1007        assert!(ks_test(&empty, &data, 0.05).is_err());
1008        assert!(ks_test(&data, &empty, 0.05).is_err());
1009    }
1010
1011    // ========== Chi-squared test implementation tests ==========
1012
1013    #[test]
1014    fn test_chi_squared_identical_distributions() {
1015        let data: Vec<f64> = (0..1000).map(|i| (i % 10) as f64).collect();
1016        let result = chi_squared_test(&data, &data, 0.05).expect("chi-squared test");
1017
1018        // Identical data should have chi-sq ≈ 0
1019        assert!(
1020            result.statistic < 1.0,
1021            "Chi-squared should be small for identical data"
1022        );
1023        assert!(result.p_value.unwrap_or(0.0) > 0.05);
1024        assert_eq!(result.severity, DriftSeverity::None);
1025    }
1026
1027    #[test]
1028    fn test_chi_squared_different_distributions() {
1029        // Very different distributions
1030        let ref_data: Vec<f64> = (0..1000).map(|_| 0.0).collect();
1031        let cur_data: Vec<f64> = (0..1000).map(|_| 100.0).collect();
1032
1033        let result = chi_squared_test(&ref_data, &cur_data, 0.05).expect("chi-squared test");
1034
1035        assert!(result.statistic > 100.0, "Chi-squared should be large");
1036        assert!(result.p_value.unwrap_or(1.0) < 0.001);
1037        assert!(result.severity.is_drift());
1038    }
1039
1040    // ========== PSI test implementation tests ==========
1041
1042    #[test]
1043    fn test_psi_identical_distributions() {
1044        let data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
1045        let result = psi_test(&data, &data).expect("psi test");
1046
1047        assert!(
1048            result.statistic < 0.05,
1049            "PSI should be ~0 for identical data"
1050        );
1051        assert_eq!(result.severity, DriftSeverity::None);
1052    }
1053
1054    #[test]
1055    fn test_psi_shifted_distribution() {
1056        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
1057        let cur_data: Vec<f64> = (0..1000).map(|i| 500.0 + i as f64).collect();
1058
1059        let result = psi_test(&ref_data, &cur_data).expect("psi test");
1060
1061        assert!(
1062            result.statistic > 0.2,
1063            "PSI should indicate drift: {}",
1064            result.statistic
1065        );
1066        assert!(result.severity.is_drift());
1067    }
1068
    #[test]
    fn test_psi_moderate_shift() {
        // Moderate scale + offset shift of the reference distribution.
        // NOTE(review): the assertion below only checks positivity — weaker
        // than the "PSI between 0.1 and 0.25" range originally claimed here.
        // The exact PSI depends on the binning scheme; tighten this bound once
        // the expected value for this shift is confirmed.
        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..1000).map(|i| i as f64 * 1.1 + 50.0).collect();

        let result = psi_test(&ref_data, &cur_data).expect("psi test");

        // A non-degenerate shift must yield a strictly positive PSI.
        assert!(result.statistic > 0.0, "PSI should be positive");
    }
1080
1081    // ========== Jensen-Shannon test implementation tests ==========
1082
1083    #[test]
1084    fn test_jsd_identical_distributions() {
1085        let data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
1086        let result = jensen_shannon_test(&data, &data).expect("jsd test");
1087
1088        assert!(
1089            result.statistic < 0.01,
1090            "JSD should be ~0 for identical data"
1091        );
1092        assert_eq!(result.severity, DriftSeverity::None);
1093    }
1094
1095    #[test]
1096    fn test_jsd_different_distributions() {
1097        // Completely non-overlapping distributions
1098        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
1099        let cur_data: Vec<f64> = (0..1000).map(|i| 10000.0 + i as f64).collect();
1100
1101        let result = jensen_shannon_test(&ref_data, &cur_data).expect("jsd test");
1102
1103        assert!(
1104            result.statistic > 0.5,
1105            "JSD should be high for non-overlapping: {}",
1106            result.statistic
1107        );
1108        assert!(result.severity.is_drift());
1109    }
1110
1111    // ========== DriftDetector tests ==========
1112
1113    fn make_test_dataset(values: Vec<f64>) -> ArrowDataset {
1114        let schema = Arc::new(Schema::new(vec![Field::new(
1115            "value",
1116            DataType::Float64,
1117            false,
1118        )]));
1119
1120        let batch = RecordBatch::try_new(
1121            Arc::clone(&schema),
1122            vec![Arc::new(Float64Array::from(values))],
1123        )
1124        .expect("batch");
1125
1126        ArrowDataset::from_batch(batch).expect("dataset")
1127    }
1128
1129    fn make_int_dataset(values: Vec<i32>) -> ArrowDataset {
1130        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1131
1132        let batch = RecordBatch::try_new(
1133            Arc::clone(&schema),
1134            vec![Arc::new(Int32Array::from(values))],
1135        )
1136        .expect("batch");
1137
1138        ArrowDataset::from_batch(batch).expect("dataset")
1139    }
1140
1141    #[test]
1142    fn test_drift_detector_new() {
1143        let dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
1144        let detector = DriftDetector::new(dataset);
1145
1146        assert_eq!(detector.alpha(), 0.05);
1147        assert_eq!(detector.tests().len(), 1);
1148        assert_eq!(detector.tests()[0], DriftTest::KolmogorovSmirnov);
1149    }
1150
1151    #[test]
1152    fn test_drift_detector_builder() {
1153        let dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
1154        let detector = DriftDetector::new(dataset)
1155            .with_test(DriftTest::PSI)
1156            .with_test(DriftTest::ChiSquared)
1157            .with_alpha(0.01);
1158
1159        assert_eq!(detector.alpha(), 0.01);
1160        assert_eq!(detector.tests().len(), 3);
1161    }
1162
1163    #[test]
1164    fn test_drift_detector_no_duplicate_tests() {
1165        let dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
1166        let detector = DriftDetector::new(dataset)
1167            .with_test(DriftTest::KolmogorovSmirnov) // duplicate
1168            .with_test(DriftTest::KolmogorovSmirnov); // duplicate
1169
1170        assert_eq!(detector.tests().len(), 1);
1171    }
1172
1173    #[test]
1174    fn test_drift_detector_detect_no_drift() {
1175        let ref_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
1176        let cur_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
1177
1178        let reference = make_test_dataset(ref_data);
1179        let current = make_test_dataset(cur_data);
1180
1181        let detector = DriftDetector::new(reference);
1182        let report = detector.detect(&current).expect("detect");
1183
1184        assert!(!report.drift_detected);
1185        assert_eq!(report.num_columns(), 1);
1186    }
1187
1188    #[test]
1189    fn test_drift_detector_detect_drift() {
1190        let ref_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
1191        let cur_data: Vec<f64> = (0..500).map(|i| 1000.0 + i as f64).collect();
1192
1193        let reference = make_test_dataset(ref_data);
1194        let current = make_test_dataset(cur_data);
1195
1196        let detector = DriftDetector::new(reference);
1197        let report = detector.detect(&current).expect("detect");
1198
1199        assert!(report.drift_detected);
1200        assert!(report.max_severity().is_drift());
1201    }
1202
1203    #[test]
1204    fn test_drift_detector_schema_mismatch() {
1205        let ref_dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
1206        let cur_dataset = make_int_dataset(vec![1, 2, 3]);
1207
1208        let detector = DriftDetector::new(ref_dataset);
1209        let result = detector.detect(&cur_dataset);
1210
1211        assert!(result.is_err());
1212    }
1213
1214    #[test]
1215    fn test_drift_detector_multiple_tests() {
1216        let ref_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
1217        let cur_data: Vec<f64> = (0..500).map(|i| 500.0 + i as f64).collect();
1218
1219        let reference = make_test_dataset(ref_data);
1220        let current = make_test_dataset(cur_data);
1221
1222        let detector = DriftDetector::new(reference)
1223            .with_test(DriftTest::PSI)
1224            .with_test(DriftTest::JensenShannon);
1225
1226        let report = detector.detect(&current).expect("detect");
1227
1228        // Should have results for each test
1229        assert_eq!(report.num_columns(), 3); // 1 column × 3 tests
1230    }
1231
1232    // ========== Edge cases ==========
1233
1234    #[test]
1235    fn test_ks_with_nan_values() {
1236        let ref_data = vec![1.0, 2.0, f64::NAN, 4.0, 5.0];
1237        let cur_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1238
1239        let result = ks_test(&ref_data, &cur_data, 0.05).expect("ks test");
1240        // Should handle NaN gracefully
1241        assert!(result.statistic >= 0.0);
1242    }
1243
1244    #[test]
1245    fn test_psi_with_small_sample() {
1246        let ref_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1247        let cur_data = vec![1.0, 2.0, 3.0, 4.0, 6.0];
1248
1249        let result = psi_test(&ref_data, &cur_data).expect("psi test");
1250        // Should work with small samples
1251        assert!(result.statistic >= 0.0);
1252    }
1253
1254    #[test]
1255    fn test_bin_data_constant_values() {
1256        let ref_data = vec![5.0; 100];
1257        let cur_data = vec![5.0; 100];
1258
1259        let result = bin_data(&ref_data, &cur_data, 10).expect("bin data");
1260        // Should handle constant data (all in one bin)
1261        assert_eq!(result.0.iter().sum::<usize>(), 100);
1262        assert_eq!(result.1.iter().sum::<usize>(), 100);
1263    }
1264}