1#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_truncation)]
25#![allow(clippy::cast_sign_loss)]
26#![allow(clippy::similar_names)]
27#![allow(clippy::unreadable_literal)]
28#![allow(clippy::suboptimal_flops)]
29
30use std::collections::HashMap;
31
32use crate::{
33 dataset::{ArrowDataset, Dataset},
34 error::{Error, Result},
35};
36
/// Statistical test that can be run to compare a reference sample with a
/// current sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftTest {
    /// Two-sample Kolmogorov-Smirnov test.
    KolmogorovSmirnov,
    /// Chi-squared test.
    ChiSquared,
    /// Population Stability Index.
    PSI,
    /// Jensen-Shannon divergence.
    JensenShannon,
}

impl DriftTest {
    /// Returns the human-readable display name of the test.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::JensenShannon => "Jensen-Shannon Divergence",
            Self::PSI => "Population Stability Index",
            Self::ChiSquared => "Chi-Squared",
            Self::KolmogorovSmirnov => "Kolmogorov-Smirnov",
        }
    }

    /// Returns `true` for the tests this module classes as continuous
    /// (Kolmogorov-Smirnov and Jensen-Shannon).
    pub fn is_continuous(&self) -> bool {
        match *self {
            Self::KolmogorovSmirnov | Self::JensenShannon => true,
            Self::ChiSquared | Self::PSI => false,
        }
    }

    /// Returns `true` for the tests this module classes as categorical
    /// (Chi-squared and PSI).
    pub fn is_categorical(&self) -> bool {
        match *self {
            Self::ChiSquared | Self::PSI => true,
            Self::KolmogorovSmirnov | Self::JensenShannon => false,
        }
    }
}
71
/// Graded severity of a distribution shift.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// gives `None < Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DriftSeverity {
    /// No drift.
    None,
    /// Low severity.
    Low,
    /// Medium severity.
    Medium,
    /// High severity.
    High,
    /// Critical severity.
    Critical,
}

impl DriftSeverity {
    /// Maps a p-value onto a severity; smaller p-values are more severe.
    ///
    /// * `p > 0.05` -> `None`
    /// * `0.01 < p <= 0.05` -> `Low`
    /// * `0.001 < p <= 0.01` -> `Medium`
    /// * `0.0001 < p <= 0.001` -> `High`
    /// * otherwise -> `Critical`
    ///
    /// A NaN p-value fails every comparison and therefore maps to `Critical`.
    pub fn from_p_value(p_value: f64) -> Self {
        match p_value {
            p if p > 0.05 => Self::None,
            p if p > 0.01 => Self::Low,
            p if p > 0.001 => Self::Medium,
            p if p > 0.0001 => Self::High,
            _ => Self::Critical,
        }
    }

    /// Maps a Population Stability Index value onto a severity; larger
    /// values are more severe.
    ///
    /// * `psi < 0.1` -> `None`
    /// * `0.1 <= psi < 0.2` -> `Low`
    /// * `0.2 <= psi < 0.25` -> `Medium`
    /// * `0.25 <= psi < 0.5` -> `High`
    /// * otherwise -> `Critical`
    ///
    /// A NaN value fails every comparison and therefore maps to `Critical`.
    pub fn from_psi(psi: f64) -> Self {
        match psi {
            v if v < 0.1 => Self::None,
            v if v < 0.2 => Self::Low,
            v if v < 0.25 => Self::Medium,
            v if v < 0.5 => Self::High,
            _ => Self::Critical,
        }
    }

    /// Returns `true` for every severity except `None`.
    pub fn is_drift(&self) -> bool {
        !matches!(*self, Self::None)
    }
}
123
124#[derive(Debug, Clone)]
126pub struct ColumnDrift {
127 pub column: String,
129 pub test: DriftTest,
131 pub statistic: f64,
133 pub p_value: Option<f64>,
135 pub drift_detected: bool,
137 pub severity: DriftSeverity,
139}
140
141impl ColumnDrift {
142 pub fn new(
144 column: impl Into<String>,
145 test: DriftTest,
146 statistic: f64,
147 p_value: Option<f64>,
148 severity: DriftSeverity,
149 ) -> Self {
150 Self {
151 column: column.into(),
152 test,
153 statistic,
154 p_value,
155 drift_detected: severity.is_drift(),
156 severity,
157 }
158 }
159}
160
161#[derive(Debug, Clone)]
163pub struct DriftReport {
164 pub column_scores: HashMap<String, ColumnDrift>,
166 pub drift_detected: bool,
168 pub timestamp: u64,
170}
171
172impl DriftReport {
173 pub fn from_columns(columns: Vec<ColumnDrift>) -> Self {
175 let drift_detected = columns.iter().any(|c| c.drift_detected);
176 let timestamp = std::time::SystemTime::now()
177 .duration_since(std::time::UNIX_EPOCH)
178 .map(|d| d.as_secs())
179 .unwrap_or(0);
180
181 let column_scores = columns
183 .into_iter()
184 .map(|c| (format!("{}:{:?}", c.column, c.test), c))
185 .collect();
186
187 Self {
188 column_scores,
189 drift_detected,
190 timestamp,
191 }
192 }
193
194 pub fn drifted_columns(&self) -> Vec<&str> {
196 self.column_scores
197 .values()
198 .filter(|c| c.drift_detected)
199 .map(|c| c.column.as_str())
200 .collect()
201 }
202
203 pub fn max_severity(&self) -> DriftSeverity {
205 self.column_scores
206 .values()
207 .map(|c| c.severity)
208 .max()
209 .unwrap_or(DriftSeverity::None)
210 }
211
212 pub fn num_columns(&self) -> usize {
214 self.column_scores.len()
215 }
216
217 pub fn num_drifted(&self) -> usize {
219 self.column_scores
220 .values()
221 .filter(|c| c.drift_detected)
222 .count()
223 }
224}
225
226pub struct DriftDetector {
228 reference: ArrowDataset,
230 tests: Vec<DriftTest>,
232 alpha: f64,
234}
235
236impl DriftDetector {
237 pub fn new(reference: ArrowDataset) -> Self {
239 Self {
240 reference,
241 tests: vec![DriftTest::KolmogorovSmirnov],
242 alpha: 0.05,
243 }
244 }
245
246 #[must_use]
248 pub fn with_test(mut self, test: DriftTest) -> Self {
249 if !self.tests.contains(&test) {
250 self.tests.push(test);
251 }
252 self
253 }
254
255 #[must_use]
257 pub fn with_alpha(mut self, alpha: f64) -> Self {
258 self.alpha = alpha;
259 self
260 }
261
262 #[must_use]
264 pub fn with_tests(mut self, tests: Vec<DriftTest>) -> Self {
265 self.tests = tests;
266 self
267 }
268
269 pub fn reference(&self) -> &ArrowDataset {
271 &self.reference
272 }
273
274 pub fn tests(&self) -> &[DriftTest] {
276 &self.tests
277 }
278
279 pub fn alpha(&self) -> f64 {
281 self.alpha
282 }
283
284 pub fn detect(&self, current: &ArrowDataset) -> Result<DriftReport> {
286 if self.reference.schema() != current.schema() {
288 return Err(Error::invalid_config(
289 "Schema mismatch between reference and current dataset",
290 ));
291 }
292
293 let schema = self.reference.schema();
294 let mut results = Vec::new();
295
296 let ref_data = collect_dataset_data(&self.reference);
298 let cur_data = collect_dataset_data(current);
299
300 for field in schema.fields() {
302 let column_name = field.name();
303
304 let ref_col = ref_data.get(column_name);
305 let cur_col = cur_data.get(column_name);
306
307 if let (Some(ref_values), Some(cur_values)) = (ref_col, cur_col) {
308 for test in &self.tests {
310 let result = run_test(*test, ref_values, cur_values, self.alpha)?;
311 results.push(ColumnDrift::new(
312 column_name,
313 *test,
314 result.statistic,
315 result.p_value,
316 result.severity,
317 ));
318 }
319 }
320 }
321
322 Ok(DriftReport::from_columns(results))
323 }
324}
325
326struct TestResult {
328 statistic: f64,
329 p_value: Option<f64>,
330 severity: DriftSeverity,
331}
332
333fn run_test(test: DriftTest, reference: &[f64], current: &[f64], alpha: f64) -> Result<TestResult> {
335 match test {
336 DriftTest::KolmogorovSmirnov => ks_test(reference, current, alpha),
337 DriftTest::ChiSquared => chi_squared_test(reference, current, alpha),
338 DriftTest::PSI => psi_test(reference, current),
339 DriftTest::JensenShannon => jensen_shannon_test(reference, current),
340 }
341}
342
343fn ks_test(reference: &[f64], current: &[f64], alpha: f64) -> Result<TestResult> {
348 if reference.is_empty() || current.is_empty() {
349 return Err(Error::invalid_config(
350 "Cannot perform KS test on empty data",
351 ));
352 }
353
354 let mut ref_sorted: Vec<f64> = reference
356 .iter()
357 .copied()
358 .filter(|x| x.is_finite())
359 .collect();
360 let mut cur_sorted: Vec<f64> = current.iter().copied().filter(|x| x.is_finite()).collect();
361
362 if ref_sorted.is_empty() || cur_sorted.is_empty() {
363 return Err(Error::invalid_config("No finite values in data"));
364 }
365
366 ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
367 cur_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
368
369 let n1 = ref_sorted.len() as f64;
370 let n2 = cur_sorted.len() as f64;
371
372 let d_statistic = compute_ks_statistic(&ref_sorted, &cur_sorted);
374
375 let en = (n1 * n2 / (n1 + n2)).sqrt();
377 let p_value = ks_p_value(d_statistic * en);
378
379 let severity = if p_value <= alpha {
380 DriftSeverity::from_p_value(p_value)
381 } else {
382 DriftSeverity::None
383 };
384
385 Ok(TestResult {
386 statistic: d_statistic,
387 p_value: Some(p_value),
388 severity,
389 })
390}
391
/// Returns the Kolmogorov-Smirnov statistic: the largest absolute difference
/// between the empirical CDFs of the two samples.
///
/// Both slices must already be sorted in ascending order and are expected to
/// contain only finite values (the caller filters and sorts them). The
/// empirical CDFs are step functions that change only at sample values, so
/// it is enough to evaluate the difference at every sample point.
///
/// Each CDF value is found with a binary search on the sorted input, which
/// makes the whole computation `O((n1 + n2) * log(n1 + n2))` with no extra
/// allocation.
///
/// Returns `0.0` when either slice is empty.
fn compute_ks_statistic(ref_sorted: &[f64], cur_sorted: &[f64]) -> f64 {
    let n1 = ref_sorted.len() as f64;
    let n2 = cur_sorted.len() as f64;

    ref_sorted
        .iter()
        .chain(cur_sorted.iter())
        .map(|&x| {
            // Number of samples <= x, divided by the sample size, is the ECDF at x.
            let cdf1 = ref_sorted.partition_point(|&v| v <= x) as f64 / n1;
            let cdf2 = cur_sorted.partition_point(|&v| v <= x) as f64 / n2;
            (cdf1 - cdf2).abs()
        })
        // `f64::max` ignores NaN, so a 0/0 from an empty sample leaves 0.0.
        .fold(0.0_f64, f64::max)
}
422
/// Approximates the p-value of the Kolmogorov-Smirnov test from the scaled
/// statistic `z`, using the asymptotic series
/// `Q(z) = 2 * sum_{k>=1} (-1)^(k-1) * exp(-2 * k^2 * z^2)`.
///
/// * `z <= 0.2` returns `1.0`. The true value differs from 1 by less than
///   `1e-12` there, while the alternating series converges too slowly for
///   its first 100 terms to be trusted (at `z = 0.001` the truncated sum
///   evaluates to roughly `0.02`).
/// * `z > 3.0` returns `0.0`.
/// * Otherwise the series is summed until a term drops below `1e-12`
///   (at most 100 terms) and the result is clamped to `[0, 1]`.
fn ks_p_value(z: f64) -> f64 {
    // Below this point the survival function is 1 to within 1e-12.
    const SMALL_Z_CUTOFF: f64 = 0.2;

    if z <= SMALL_Z_CUTOFF {
        return 1.0;
    }
    if z > 3.0 {
        return 0.0;
    }

    let z_sq = z * z;
    let mut sum = 0.0;
    let mut sign = 1.0_f64;

    for k in 1..=100 {
        let k_f = f64::from(k);
        let term = sign * (-2.0 * k_f * k_f * z_sq).exp();
        sum += term;
        if term.abs() < 1e-12 {
            break;
        }
        sign = -sign;
    }

    (2.0 * sum).clamp(0.0, 1.0)
}
447
448fn chi_squared_test(reference: &[f64], current: &[f64], alpha: f64) -> Result<TestResult> {
452 if reference.is_empty() || current.is_empty() {
453 return Err(Error::invalid_config(
454 "Cannot perform chi-squared test on empty data",
455 ));
456 }
457
458 let num_bins = ((reference.len() as f64).sqrt().ceil() as usize).clamp(5, 20);
460 let (ref_bins, cur_bins) = bin_data(reference, current, num_bins)?;
461
462 let n_ref = reference.len() as f64;
464 let n_cur = current.len() as f64;
465 let total = n_ref + n_cur;
466
467 let mut chi_sq = 0.0;
468 let mut df: usize = 0;
469
470 for (r, c) in ref_bins.iter().zip(cur_bins.iter()) {
471 let r = *r as f64;
472 let c = *c as f64;
473 let row_total = r + c;
474
475 if row_total > 0.0 {
476 let expected_r = row_total * n_ref / total;
477 let expected_c = row_total * n_cur / total;
478
479 if expected_r > 0.0 {
480 chi_sq += (r - expected_r).powi(2) / expected_r;
481 }
482 if expected_c > 0.0 {
483 chi_sq += (c - expected_c).powi(2) / expected_c;
484 }
485 df += 1;
486 }
487 }
488
489 df = df.saturating_sub(1); let p_value = chi_squared_p_value(chi_sq, df);
493
494 let severity = if p_value <= alpha {
495 DriftSeverity::from_p_value(p_value)
496 } else {
497 DriftSeverity::None
498 };
499
500 Ok(TestResult {
501 statistic: chi_sq,
502 p_value: Some(p_value),
503 severity,
504 })
505}
506
507fn bin_data(
509 reference: &[f64],
510 current: &[f64],
511 num_bins: usize,
512) -> Result<(Vec<usize>, Vec<usize>)> {
513 let all_data: Vec<f64> = reference
515 .iter()
516 .chain(current.iter())
517 .copied()
518 .filter(|x| x.is_finite())
519 .collect();
520
521 if all_data.is_empty() {
522 return Err(Error::invalid_config("No finite values in data"));
523 }
524
525 let min_val = all_data.iter().copied().fold(f64::INFINITY, f64::min);
526 let max_val = all_data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
527
528 if (max_val - min_val).abs() < f64::EPSILON {
529 return Ok((vec![reference.len()], vec![current.len()]));
531 }
532
533 let bin_width = (max_val - min_val) / num_bins as f64;
534
535 let bin_value = |v: f64| -> usize {
536 if !v.is_finite() {
537 return 0;
538 }
539 let bin = ((v - min_val) / bin_width).floor() as usize;
540 bin.min(num_bins - 1)
541 };
542
543 let mut ref_bins = vec![0usize; num_bins];
544 let mut cur_bins = vec![0usize; num_bins];
545
546 for &v in reference {
547 ref_bins[bin_value(v)] += 1;
548 }
549 for &v in current {
550 cur_bins[bin_value(v)] += 1;
551 }
552
553 Ok((ref_bins, cur_bins))
554}
555
556fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
558 if df == 0 {
559 return 1.0;
560 }
561
562 let k = df as f64;
563
564 let z = ((chi_sq / k).cbrt() - (1.0 - 2.0 / (9.0 * k))) / (2.0 / (9.0 * k)).sqrt();
566
567 1.0 - standard_normal_cdf(z)
569}
570
571fn standard_normal_cdf(z: f64) -> f64 {
573 0.5 * (1.0 + erf(z / std::f64::consts::SQRT_2))
575}
576
/// Polynomial approximation of the error function.
///
/// Computes `1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)`
/// with `t = 1 / (1 + p * |x|)` and restores the sign of `x` afterwards, so
/// the result is odd in `x`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Highest-order coefficient (a5) followed by a4..a1 for Horner evaluation.
    const A5: f64 = 1.061405429;
    const LOWER_COEFFICIENTS: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let polynomial = LOWER_COEFFICIENTS
        .iter()
        .fold(A5, |acc, &coefficient| acc * t + coefficient);
    let value = 1.0 - polynomial * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -value
    } else {
        value
    }
}
595
596fn psi_test(reference: &[f64], current: &[f64]) -> Result<TestResult> {
603 if reference.is_empty() || current.is_empty() {
604 return Err(Error::invalid_config("Cannot compute PSI on empty data"));
605 }
606
607 let num_bins = 10;
609 let (ref_bins, cur_bins) = bin_data(reference, current, num_bins)?;
610
611 let n_ref = reference.len() as f64;
612 let n_cur = current.len() as f64;
613
614 let mut psi = 0.0;
615
616 for (r, c) in ref_bins.iter().zip(cur_bins.iter()) {
617 let p_ref = (*r as f64 + 0.5) / (n_ref + num_bins as f64 * 0.5);
619 let p_cur = (*c as f64 + 0.5) / (n_cur + num_bins as f64 * 0.5);
620
621 psi += (p_cur - p_ref) * (p_cur / p_ref).ln();
622 }
623
624 let severity = DriftSeverity::from_psi(psi);
625
626 Ok(TestResult {
627 statistic: psi,
628 p_value: None, severity,
630 })
631}
632
633fn jensen_shannon_test(reference: &[f64], current: &[f64]) -> Result<TestResult> {
639 if reference.is_empty() || current.is_empty() {
640 return Err(Error::invalid_config("Cannot compute JSD on empty data"));
641 }
642
643 let num_bins = 20;
645 let (ref_bins, cur_bins) = bin_data(reference, current, num_bins)?;
646
647 let n_ref = reference.len() as f64;
648 let n_cur = current.len() as f64;
649
650 let p: Vec<f64> = ref_bins
652 .iter()
653 .map(|&c| (c as f64 + 0.5) / (n_ref + num_bins as f64 * 0.5))
654 .collect();
655 let q: Vec<f64> = cur_bins
656 .iter()
657 .map(|&c| (c as f64 + 0.5) / (n_cur + num_bins as f64 * 0.5))
658 .collect();
659
660 let m: Vec<f64> = p
662 .iter()
663 .zip(q.iter())
664 .map(|(pi, qi)| (pi + qi) / 2.0)
665 .collect();
666
667 let kl_pm: f64 = p
669 .iter()
670 .zip(m.iter())
671 .map(|(pi, mi)| if *pi > 0.0 { pi * (pi / mi).ln() } else { 0.0 })
672 .sum();
673
674 let kl_qm: f64 = q
675 .iter()
676 .zip(m.iter())
677 .map(|(qi, mi)| if *qi > 0.0 { qi * (qi / mi).ln() } else { 0.0 })
678 .sum();
679
680 let jsd = 0.5 * kl_pm + 0.5 * kl_qm;
681
682 let jsd_normalized = jsd / std::f64::consts::LN_2;
684
685 let severity = if jsd_normalized < 0.05 {
687 DriftSeverity::None
688 } else if jsd_normalized < 0.1 {
689 DriftSeverity::Low
690 } else if jsd_normalized < 0.2 {
691 DriftSeverity::Medium
692 } else if jsd_normalized < 0.4 {
693 DriftSeverity::High
694 } else {
695 DriftSeverity::Critical
696 };
697
698 Ok(TestResult {
699 statistic: jsd_normalized,
700 p_value: None, severity,
702 })
703}
704
705fn extract_numeric_values(
708 array: &dyn arrow::array::Array,
709 data_type: &arrow::datatypes::DataType,
710 out: &mut Vec<f64>,
711) {
712 use arrow::{
713 array::{Array, Float64Array, Int32Array, Int64Array},
714 datatypes::DataType,
715 };
716
717 match data_type {
718 DataType::Float64 => {
719 if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
720 out.extend(
721 (0..arr.len())
722 .filter(|&i| !arr.is_null(i))
723 .map(|i| arr.value(i)),
724 );
725 }
726 }
727 DataType::Float32 => {
728 if let Some(arr) = array.as_any().downcast_ref::<arrow::array::Float32Array>() {
729 out.extend(
730 (0..arr.len())
731 .filter(|&i| !arr.is_null(i))
732 .map(|i| f64::from(arr.value(i))),
733 );
734 }
735 }
736 DataType::Int32 => {
737 if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
738 out.extend(
739 (0..arr.len())
740 .filter(|&i| !arr.is_null(i))
741 .map(|i| f64::from(arr.value(i))),
742 );
743 }
744 }
745 DataType::Int64 => {
746 if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
747 out.extend(
748 (0..arr.len())
749 .filter(|&i| !arr.is_null(i))
750 .map(|i| arr.value(i) as f64),
751 );
752 }
753 }
754 _ => {}
755 }
756}
757
758fn collect_dataset_data(dataset: &ArrowDataset) -> HashMap<String, Vec<f64>> {
760 use arrow::datatypes::DataType;
761
762 let mut data: HashMap<String, Vec<f64>> = HashMap::new();
763 let schema = dataset.schema();
764
765 for field in schema.fields() {
767 if matches!(
768 field.data_type(),
769 DataType::Int32 | DataType::Int64 | DataType::Float64 | DataType::Float32
770 ) {
771 data.insert(field.name().clone(), Vec::new());
772 }
773 }
774
775 for batch in dataset.iter() {
777 for (col_idx, field) in schema.fields().iter().enumerate() {
778 if let Some(col_data) = data.get_mut(field.name()) {
779 extract_numeric_values(batch.column(col_idx), field.data_type(), col_data);
780 }
781 }
782 }
783
784 data
785}
786
/// Unit tests for the drift types, the individual statistical tests and the
/// end-to-end `DriftDetector`.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Float64Array, Int32Array},
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };

    use super::*;

    // ---- DriftTest ----

    #[test]
    fn test_drift_test_name() {
        assert_eq!(DriftTest::KolmogorovSmirnov.name(), "Kolmogorov-Smirnov");
        assert_eq!(DriftTest::ChiSquared.name(), "Chi-Squared");
        assert_eq!(DriftTest::PSI.name(), "Population Stability Index");
        assert_eq!(DriftTest::JensenShannon.name(), "Jensen-Shannon Divergence");
    }

    #[test]
    fn test_drift_test_is_continuous() {
        assert!(DriftTest::KolmogorovSmirnov.is_continuous());
        assert!(DriftTest::JensenShannon.is_continuous());
        assert!(!DriftTest::ChiSquared.is_continuous());
        assert!(!DriftTest::PSI.is_continuous());
    }

    #[test]
    fn test_drift_test_is_categorical() {
        assert!(DriftTest::ChiSquared.is_categorical());
        assert!(DriftTest::PSI.is_categorical());
        assert!(!DriftTest::KolmogorovSmirnov.is_categorical());
        assert!(!DriftTest::JensenShannon.is_categorical());
    }

    // ---- DriftSeverity ----

    #[test]
    fn test_drift_severity_from_p_value() {
        assert_eq!(DriftSeverity::from_p_value(0.1), DriftSeverity::None);
        assert_eq!(DriftSeverity::from_p_value(0.06), DriftSeverity::None);
        assert_eq!(DriftSeverity::from_p_value(0.04), DriftSeverity::Low);
        assert_eq!(DriftSeverity::from_p_value(0.005), DriftSeverity::Medium);
        assert_eq!(DriftSeverity::from_p_value(0.0005), DriftSeverity::High);
        assert_eq!(
            DriftSeverity::from_p_value(0.00001),
            DriftSeverity::Critical
        );
    }

    #[test]
    fn test_drift_severity_from_psi() {
        assert_eq!(DriftSeverity::from_psi(0.05), DriftSeverity::None);
        assert_eq!(DriftSeverity::from_psi(0.15), DriftSeverity::Low);
        assert_eq!(DriftSeverity::from_psi(0.22), DriftSeverity::Medium);
        assert_eq!(DriftSeverity::from_psi(0.35), DriftSeverity::High);
        assert_eq!(DriftSeverity::from_psi(0.6), DriftSeverity::Critical);
    }

    #[test]
    fn test_drift_severity_is_drift() {
        assert!(!DriftSeverity::None.is_drift());
        assert!(DriftSeverity::Low.is_drift());
        assert!(DriftSeverity::Medium.is_drift());
        assert!(DriftSeverity::High.is_drift());
        assert!(DriftSeverity::Critical.is_drift());
    }

    #[test]
    fn test_drift_severity_ordering() {
        assert!(DriftSeverity::None < DriftSeverity::Low);
        assert!(DriftSeverity::Low < DriftSeverity::Medium);
        assert!(DriftSeverity::Medium < DriftSeverity::High);
        assert!(DriftSeverity::High < DriftSeverity::Critical);
    }

    // ---- ColumnDrift ----

    #[test]
    fn test_column_drift_new() {
        let drift = ColumnDrift::new(
            "age",
            DriftTest::KolmogorovSmirnov,
            0.15,
            Some(0.03),
            DriftSeverity::Low,
        );

        assert_eq!(drift.column, "age");
        assert_eq!(drift.test, DriftTest::KolmogorovSmirnov);
        assert!((drift.statistic - 0.15).abs() < f64::EPSILON);
        assert_eq!(drift.p_value, Some(0.03));
        assert!(drift.drift_detected);
        assert_eq!(drift.severity, DriftSeverity::Low);
    }

    #[test]
    fn test_column_drift_no_drift() {
        let drift = ColumnDrift::new("income", DriftTest::PSI, 0.05, None, DriftSeverity::None);

        assert!(!drift.drift_detected);
        assert_eq!(drift.severity, DriftSeverity::None);
    }

    // ---- DriftReport ----

    #[test]
    fn test_drift_report_from_columns() {
        let columns = vec![
            ColumnDrift::new(
                "a",
                DriftTest::KolmogorovSmirnov,
                0.1,
                Some(0.5),
                DriftSeverity::None,
            ),
            ColumnDrift::new("b", DriftTest::PSI, 0.25, None, DriftSeverity::Medium),
        ];

        let report = DriftReport::from_columns(columns);

        assert!(report.drift_detected);
        assert_eq!(report.num_columns(), 2);
        assert_eq!(report.num_drifted(), 1);
        assert_eq!(report.max_severity(), DriftSeverity::Medium);
    }

    #[test]
    fn test_drift_report_no_drift() {
        let columns = vec![
            ColumnDrift::new(
                "a",
                DriftTest::KolmogorovSmirnov,
                0.05,
                Some(0.5),
                DriftSeverity::None,
            ),
            ColumnDrift::new("b", DriftTest::PSI, 0.05, None, DriftSeverity::None),
        ];

        let report = DriftReport::from_columns(columns);

        assert!(!report.drift_detected);
        assert_eq!(report.num_drifted(), 0);
        assert_eq!(report.max_severity(), DriftSeverity::None);
    }

    #[test]
    fn test_drift_report_drifted_columns() {
        let columns = vec![
            ColumnDrift::new(
                "a",
                DriftTest::KolmogorovSmirnov,
                0.1,
                Some(0.5),
                DriftSeverity::None,
            ),
            ColumnDrift::new("b", DriftTest::PSI, 0.3, None, DriftSeverity::High),
            ColumnDrift::new(
                "c",
                DriftTest::ChiSquared,
                50.0,
                Some(0.001),
                DriftSeverity::Medium,
            ),
        ];

        let report = DriftReport::from_columns(columns);
        let drifted = report.drifted_columns();

        assert_eq!(drifted.len(), 2);
        assert!(drifted.contains(&"b"));
        assert!(drifted.contains(&"c"));
    }

    // ---- Kolmogorov-Smirnov ----

    #[test]
    fn test_ks_identical_distributions() {
        let data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let result = ks_test(&data, &data, 0.05).expect("ks test");

        assert!(
            result.statistic < 0.01,
            "KS statistic should be ~0 for identical data"
        );
        assert!(
            result.p_value.unwrap_or(0.0) > 0.05,
            "p-value should be high"
        );
        assert_eq!(result.severity, DriftSeverity::None);
    }

    #[test]
    fn test_ks_different_distributions() {
        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64 / 10.0).collect();
        let cur_data: Vec<f64> = (0..1000).map(|i| 50.0 + i as f64 / 10.0).collect();

        let result = ks_test(&ref_data, &cur_data, 0.05).expect("ks test");

        assert!(
            result.statistic > 0.3,
            "KS statistic should be large for shifted data"
        );
        assert!(
            result.p_value.unwrap_or(1.0) < 0.05,
            "p-value should be small"
        );
        assert!(result.severity.is_drift());
    }

    #[test]
    fn test_ks_empty_data_error() {
        let empty: Vec<f64> = vec![];
        let data = vec![1.0, 2.0, 3.0];

        assert!(ks_test(&empty, &data, 0.05).is_err());
        assert!(ks_test(&data, &empty, 0.05).is_err());
    }

    // ---- Chi-squared ----

    #[test]
    fn test_chi_squared_identical_distributions() {
        let data: Vec<f64> = (0..1000).map(|i| (i % 10) as f64).collect();
        let result = chi_squared_test(&data, &data, 0.05).expect("chi-squared test");

        assert!(
            result.statistic < 1.0,
            "Chi-squared should be small for identical data"
        );
        assert!(result.p_value.unwrap_or(0.0) > 0.05);
        assert_eq!(result.severity, DriftSeverity::None);
    }

    #[test]
    fn test_chi_squared_different_distributions() {
        let ref_data: Vec<f64> = (0..1000).map(|_| 0.0).collect();
        let cur_data: Vec<f64> = (0..1000).map(|_| 100.0).collect();

        let result = chi_squared_test(&ref_data, &cur_data, 0.05).expect("chi-squared test");

        assert!(result.statistic > 100.0, "Chi-squared should be large");
        assert!(result.p_value.unwrap_or(1.0) < 0.001);
        assert!(result.severity.is_drift());
    }

    // ---- Population Stability Index ----

    #[test]
    fn test_psi_identical_distributions() {
        let data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let result = psi_test(&data, &data).expect("psi test");

        assert!(
            result.statistic < 0.05,
            "PSI should be ~0 for identical data"
        );
        assert_eq!(result.severity, DriftSeverity::None);
    }

    #[test]
    fn test_psi_shifted_distribution() {
        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..1000).map(|i| 500.0 + i as f64).collect();

        let result = psi_test(&ref_data, &cur_data).expect("psi test");

        assert!(
            result.statistic > 0.2,
            "PSI should indicate drift: {}",
            result.statistic
        );
        assert!(result.severity.is_drift());
    }

    #[test]
    fn test_psi_moderate_shift() {
        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..1000).map(|i| i as f64 * 1.1 + 50.0).collect();

        let result = psi_test(&ref_data, &cur_data).expect("psi test");

        assert!(result.statistic > 0.0, "PSI should be positive");
    }

    // ---- Jensen-Shannon divergence ----

    #[test]
    fn test_jsd_identical_distributions() {
        let data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let result = jensen_shannon_test(&data, &data).expect("jsd test");

        assert!(
            result.statistic < 0.01,
            "JSD should be ~0 for identical data"
        );
        assert_eq!(result.severity, DriftSeverity::None);
    }

    #[test]
    fn test_jsd_different_distributions() {
        let ref_data: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..1000).map(|i| 10000.0 + i as f64).collect();

        let result = jensen_shannon_test(&ref_data, &cur_data).expect("jsd test");

        assert!(
            result.statistic > 0.5,
            "JSD should be high for non-overlapping: {}",
            result.statistic
        );
        assert!(result.severity.is_drift());
    }

    // ---- DriftDetector ----

    /// Builds a single-batch dataset with one non-nullable `Float64` column
    /// named `value`.
    fn make_test_dataset(values: Vec<f64>) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float64Array::from(values))],
        )
        .expect("batch");

        ArrowDataset::from_batch(batch).expect("dataset")
    }

    /// Builds a single-batch dataset with one non-nullable `Int32` column
    /// named `id`.
    fn make_int_dataset(values: Vec<i32>) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(values))],
        )
        .expect("batch");

        ArrowDataset::from_batch(batch).expect("dataset")
    }

    #[test]
    fn test_drift_detector_new() {
        let dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
        let detector = DriftDetector::new(dataset);

        assert_eq!(detector.alpha(), 0.05);
        assert_eq!(detector.tests().len(), 1);
        assert_eq!(detector.tests()[0], DriftTest::KolmogorovSmirnov);
    }

    #[test]
    fn test_drift_detector_builder() {
        let dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
        let detector = DriftDetector::new(dataset)
            .with_test(DriftTest::PSI)
            .with_test(DriftTest::ChiSquared)
            .with_alpha(0.01);

        assert_eq!(detector.alpha(), 0.01);
        assert_eq!(detector.tests().len(), 3);
    }

    #[test]
    fn test_drift_detector_no_duplicate_tests() {
        let dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
        let detector = DriftDetector::new(dataset)
            .with_test(DriftTest::KolmogorovSmirnov)
            .with_test(DriftTest::KolmogorovSmirnov);

        // The default list already contains the KS test, so nothing is added.
        assert_eq!(detector.tests().len(), 1);
    }

    #[test]
    fn test_drift_detector_detect_no_drift() {
        let ref_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..500).map(|i| i as f64).collect();

        let reference = make_test_dataset(ref_data);
        let current = make_test_dataset(cur_data);

        let detector = DriftDetector::new(reference);
        let report = detector.detect(&current).expect("detect");

        assert!(!report.drift_detected);
        assert_eq!(report.num_columns(), 1);
    }

    #[test]
    fn test_drift_detector_detect_drift() {
        let ref_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..500).map(|i| 1000.0 + i as f64).collect();

        let reference = make_test_dataset(ref_data);
        let current = make_test_dataset(cur_data);

        let detector = DriftDetector::new(reference);
        let report = detector.detect(&current).expect("detect");

        assert!(report.drift_detected);
        assert!(report.max_severity().is_drift());
    }

    #[test]
    fn test_drift_detector_schema_mismatch() {
        let ref_dataset = make_test_dataset(vec![1.0, 2.0, 3.0]);
        let cur_dataset = make_int_dataset(vec![1, 2, 3]);

        let detector = DriftDetector::new(ref_dataset);
        let result = detector.detect(&cur_dataset);

        assert!(result.is_err());
    }

    #[test]
    fn test_drift_detector_multiple_tests() {
        let ref_data: Vec<f64> = (0..500).map(|i| i as f64).collect();
        let cur_data: Vec<f64> = (0..500).map(|i| 500.0 + i as f64).collect();

        let reference = make_test_dataset(ref_data);
        let current = make_test_dataset(cur_data);

        let detector = DriftDetector::new(reference)
            .with_test(DriftTest::PSI)
            .with_test(DriftTest::JensenShannon);

        let report = detector.detect(&current).expect("detect");

        // One column, three tests: one report entry per column/test pair.
        assert_eq!(report.num_columns(), 3);
    }

    // ---- Edge cases ----

    #[test]
    fn test_ks_with_nan_values() {
        let ref_data = vec![1.0, 2.0, f64::NAN, 4.0, 5.0];
        let cur_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = ks_test(&ref_data, &cur_data, 0.05).expect("ks test");
        assert!(result.statistic >= 0.0);
    }

    #[test]
    fn test_psi_with_small_sample() {
        let ref_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cur_data = vec![1.0, 2.0, 3.0, 4.0, 6.0];

        let result = psi_test(&ref_data, &cur_data).expect("psi test");
        assert!(result.statistic >= 0.0);
    }

    #[test]
    fn test_bin_data_constant_values() {
        let ref_data = vec![5.0; 100];
        let cur_data = vec![5.0; 100];

        let result = bin_data(&ref_data, &cur_data, 10).expect("bin data");
        assert_eq!(result.0.iter().sum::<usize>(), 100);
        assert_eq!(result.1.iter().sum::<usize>(), 100);
    }
}