1use super::config::*;
8use super::optimizer::{Adaptation, AdaptationPriority, AdaptationType, StreamingDataPoint};
9
10use scirs2_core::numeric::Float;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, VecDeque};
13use std::time::{Duration, Instant};
14
/// Streaming drift detector that owns a set of statistical tests, distribution
/// comparators and model-based detectors and dispatches to them according to
/// the configured `DriftDetectionMethod`.
pub struct EnhancedDriftDetector<A: Float + Send + Sync> {
    /// Drift settings cloned from the streaming configuration at construction.
    config: DriftConfig,
    /// Detection method taken from `config`; selects which detector family runs.
    detection_method: DriftDetectionMethod,
    /// Statistical tests, keyed by the method they implement.
    statistical_tests: HashMap<StatisticalMethod, Box<dyn StatisticalTest<A>>>,
    /// Distribution comparators, keyed by the method they implement.
    distribution_methods: HashMap<DistributionMethod, Box<dyn DistributionComparator<A>>>,
    /// Model-based detectors, keyed by model type.
    model_detectors: HashMap<ModelType, Box<dyn ModelBasedDetector<A>>>,
    /// Voting strategy captured at construction when the method is an ensemble;
    /// `None` for every other method.
    ensemble_strategy: Option<VotingStrategy>,
    /// Recorded drift events, oldest first.
    detection_history: VecDeque<DriftEvent<A>>,
    /// Bookkeeping of detections used to estimate the false positive rate.
    false_positive_tracker: FalsePositiveTracker<A>,
    /// Sliding window of recent data points, bounded by `config.window_size`.
    reference_window: VecDeque<StreamingDataPoint<A>>,
    /// Current position in the drift state machine.
    drift_state: DriftState,
    /// Time of the most recent detection, if any.
    last_detection: Option<Instant>,
    /// Multiplier applied to reported confidences and to the significance
    /// level of statistical tests; starts at one.
    sensitivity_factor: A,
}
42
/// A single recorded drift detection.
#[derive(Debug, Clone)]
pub struct DriftEvent<A: Float + Send + Sync> {
    /// When the event was recorded.
    pub timestamp: Instant,
    /// Severity class assigned to the detection.
    pub severity: DriftSeverity,
    /// Confidence reported by the detector that raised the event.
    pub confidence: A,
    /// Textual identifier of the detection method that produced the event.
    pub detection_method: String,
    /// p-value of the underlying test, when one is available.
    pub p_value: Option<A>,
    /// Magnitude of the drift (the detector stores the test statistic here).
    pub magnitude: A,
    /// Indices of the features considered affected; may be empty.
    pub affected_features: Vec<usize>,
}
61
/// Severity class of a detected drift.
///
/// The derived ordering follows declaration order, so
/// `Minor < Moderate < Major < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum DriftSeverity {
    /// Lowest severity class.
    Minor,
    /// Second severity class.
    Moderate,
    /// Third severity class.
    Major,
    /// Highest severity class.
    Critical,
}
74
/// State of the drift detector's state machine.
///
/// The detector starts in `Stable`; transitions are driven by the outcome of
/// each detection pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DriftState {
    /// No drift is currently signalled.
    Stable,
    /// A detection has been seen; reported as drift by `is_drift_detected`.
    Warning,
    /// Drift is confirmed; reported as drift by `is_drift_detected`.
    Drift,
    /// Entered from `Drift` after a pass without a detection.
    Recovery,
}
87
/// Tracks timestamps of detections classified as false or true positives and
/// the false positive rate derived from them.
pub struct FalsePositiveTracker<A: Float + Send + Sync> {
    /// Timestamps of detections recorded as false positives.
    false_positives: VecDeque<Instant>,
    /// Timestamps of detections recorded as true positives.
    true_positives: VecDeque<Instant>,
    /// Fraction of tracked detections that are false positives.
    current_fp_rate: A,
    /// Desired false positive rate (initialised to 0.05).
    target_fp_rate: A,
}
99
/// A statistical drift test comparing a reference sample with a current sample.
pub trait StatisticalTest<A: Float + Send + Sync>: Send + Sync {
    /// Runs the test on `reference` versus `current`.
    ///
    /// Takes `&mut self` so implementations may keep running state between calls.
    ///
    /// # Errors
    /// Returns a message describing why the test could not be evaluated.
    fn test_for_drift(
        &mut self,
        reference: &[A],
        current: &[A],
    ) -> Result<DriftTestResult<A>, String>;

    /// Feeds performance feedback back into the test.
    ///
    /// NOTE(review): the meaning and scale of `performance_feedback` are not
    /// defined here; the implementations in this file ignore it.
    fn update_parameters(&mut self, performance_feedback: A) -> Result<(), String>;

    /// Clears any state accumulated across calls.
    fn reset(&mut self);
}
115
/// Outcome of a single drift test, also used as the common result type for
/// distribution-based, model-based and ensemble detection.
#[derive(Debug, Clone)]
pub struct DriftTestResult<A: Float + Send + Sync> {
    /// Whether the test signals drift.
    pub drift_detected: bool,
    /// p-value of the test (for non-statistical detectors this is derived as
    /// `1 - confidence`).
    pub p_value: A,
    /// Test statistic the decision was based on.
    pub test_statistic: A,
    /// Confidence attached to the decision.
    pub confidence: A,
    /// Additional named values; empty for every result produced in this file.
    pub metadata: HashMap<String, A>,
}
130
/// Compares two samples by a distance between their distributions and flags
/// drift when that distance crosses a threshold.
pub trait DistributionComparator<A: Float + Send + Sync>: Send + Sync {
    /// Compares `reference` with `current`.
    ///
    /// # Errors
    /// Returns a message describing why the comparison could not be made.
    fn compare_distributions(
        &self,
        reference: &[A],
        current: &[A],
    ) -> Result<DistributionComparison<A>, String>;

    /// Returns the current decision threshold.
    fn get_threshold(&self) -> A;

    /// Replaces the decision threshold with `new_threshold`.
    fn update_threshold(&mut self, new_threshold: A);
}
146
/// Result of comparing two distributions.
#[derive(Debug, Clone)]
pub struct DistributionComparison<A: Float + Send + Sync> {
    /// Computed distance between the two samples.
    pub distance: A,
    /// Threshold the distance was compared against.
    pub threshold: A,
    /// Whether the comparison signals drift.
    pub drift_detected: bool,
    /// Confidence attached to the decision.
    pub confidence: A,
}
159
/// Drift detector backed by a model fitted on streaming data points.
pub trait ModelBasedDetector<A: Float + Send + Sync>: Send + Sync {
    /// Updates the underlying model with `data`.
    fn update_model(&mut self, data: &[StreamingDataPoint<A>]) -> Result<(), String>;

    /// Evaluates `data` for drift.
    ///
    /// # Errors
    /// Returns a message describing why detection could not be performed.
    fn detect_drift(
        &mut self,
        data: &[StreamingDataPoint<A>],
    ) -> Result<ModelDriftResult<A>, String>;

    /// Resets the underlying model to its initial state.
    fn reset_model(&mut self) -> Result<(), String>;
}
174
/// Result of a model-based drift check.
#[derive(Debug, Clone)]
pub struct ModelDriftResult<A: Float + Send + Sync> {
    /// Whether the detector signals drift.
    pub drift_detected: bool,
    /// Measured performance degradation; used as the test statistic when the
    /// result is converted to a `DriftTestResult`.
    pub performance_degradation: A,
    /// Confidence attached to the decision.
    pub confidence: A,
    /// Per-feature importance changes; may be empty.
    pub feature_importance_changes: Vec<A>,
}
187
188impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum + 'static> EnhancedDriftDetector<A> {
189 pub fn new(config: &StreamingConfig) -> Result<Self, String> {
191 let drift_config = config.drift_config.clone();
192
193 let mut statistical_tests: HashMap<StatisticalMethod, Box<dyn StatisticalTest<A>>> =
194 HashMap::new();
195 let mut distribution_methods: HashMap<
196 DistributionMethod,
197 Box<dyn DistributionComparator<A>>,
198 > = HashMap::new();
199 let mut model_detectors: HashMap<ModelType, Box<dyn ModelBasedDetector<A>>> =
200 HashMap::new();
201
202 statistical_tests.insert(
204 StatisticalMethod::ADWIN,
205 Box::new(ADWINTest::new(drift_config.sensitivity)?),
206 );
207 statistical_tests.insert(
208 StatisticalMethod::DDM,
209 Box::new(DDMTest::new(drift_config.sensitivity)?),
210 );
211 statistical_tests.insert(
212 StatisticalMethod::PageHinkley,
213 Box::new(PageHinkleyTest::new(drift_config.sensitivity)?),
214 );
215
216 distribution_methods.insert(
218 DistributionMethod::KLDivergence,
219 Box::new(KLDivergenceComparator::new(drift_config.sensitivity)?),
220 );
221 distribution_methods.insert(
222 DistributionMethod::JSDivergence,
223 Box::new(JSDivergenceComparator::new(drift_config.sensitivity)?),
224 );
225
226 model_detectors.insert(ModelType::Linear, Box::new(LinearModelDetector::new()?));
228
229 let ensemble_strategy = match &drift_config.detection_method {
230 DriftDetectionMethod::Ensemble {
231 voting_strategy, ..
232 } => Some(voting_strategy.clone()),
233 _ => None,
234 };
235
236 let false_positive_tracker = FalsePositiveTracker::new();
237
238 Ok(Self {
239 config: drift_config.clone(),
240 detection_method: drift_config.detection_method,
241 statistical_tests,
242 distribution_methods,
243 model_detectors,
244 ensemble_strategy,
245 detection_history: VecDeque::with_capacity(1000),
246 false_positive_tracker,
247 reference_window: VecDeque::with_capacity(drift_config.window_size),
248 drift_state: DriftState::Stable,
249 last_detection: None,
250 sensitivity_factor: A::one(),
251 })
252 }
253
254 pub fn detect_drift(&mut self, batch: &[StreamingDataPoint<A>]) -> Result<bool, String> {
256 if !self.config.enable_detection || batch.len() < self.config.min_samples {
257 return Ok(false);
258 }
259
260 self.update_reference_window(batch)?;
262
263 if self.reference_window.len() < self.config.window_size / 2 {
265 return Ok(false);
266 }
267
268 let current_features = self.extract_features(batch)?;
270 let reference_features = self.extract_reference_features()?;
271
272 let detection_method = self.detection_method.clone();
274 let drift_result = match detection_method {
275 DriftDetectionMethod::Statistical(method) => {
276 self.detect_statistical_drift(&method, &reference_features, ¤t_features)?
277 }
278 DriftDetectionMethod::Distribution(method) => {
279 self.detect_distribution_drift(&method, &reference_features, ¤t_features)?
280 }
281 DriftDetectionMethod::ModelBased(model_type) => {
282 self.detect_model_drift(&model_type, batch)?
283 }
284 DriftDetectionMethod::Ensemble {
285 methods,
286 voting_strategy,
287 } => self.detect_ensemble_drift(
288 &methods,
289 &voting_strategy,
290 &reference_features,
291 ¤t_features,
292 batch,
293 )?,
294 };
295
296 if drift_result.drift_detected {
298 self.handle_drift_detection(drift_result)?;
299 Ok(true)
300 } else {
301 self.update_drift_state(false);
302 Ok(false)
303 }
304 }
305
306 fn update_reference_window(&mut self, batch: &[StreamingDataPoint<A>]) -> Result<(), String> {
308 for data_point in batch {
309 if self.reference_window.len() >= self.config.window_size {
310 self.reference_window.pop_front();
311 }
312 self.reference_window.push_back(data_point.clone());
313 }
314 Ok(())
315 }
316
317 fn extract_features(&self, batch: &[StreamingDataPoint<A>]) -> Result<Vec<A>, String> {
319 let mut features = Vec::new();
320
321 for data_point in batch {
322 features.extend(data_point.features.iter().cloned());
323 }
324
325 Ok(features)
326 }
327
328 fn extract_reference_features(&self) -> Result<Vec<A>, String> {
330 let reference_data: Vec<_> = self
331 .reference_window
332 .iter()
333 .take(self.reference_window.len() / 2)
334 .collect();
335
336 let mut features = Vec::new();
337 for data_point in reference_data {
338 features.extend(data_point.features.iter().cloned());
339 }
340
341 Ok(features)
342 }
343
344 fn detect_statistical_drift(
346 &mut self,
347 method: &StatisticalMethod,
348 reference: &[A],
349 current: &[A],
350 ) -> Result<DriftTestResult<A>, String> {
351 if let Some(test) = self.statistical_tests.get_mut(method) {
352 let mut result = test.test_for_drift(reference, current)?;
353
354 result.confidence = result.confidence * self.sensitivity_factor;
356 result.drift_detected = result.p_value
357 < A::from(self.config.significance_level).expect("unwrap failed")
358 * self.sensitivity_factor;
359
360 Ok(result)
361 } else {
362 Err(format!("Statistical method {:?} not implemented", method))
363 }
364 }
365
366 fn detect_distribution_drift(
368 &mut self,
369 method: &DistributionMethod,
370 reference: &[A],
371 current: &[A],
372 ) -> Result<DriftTestResult<A>, String> {
373 if let Some(comparator) = self.distribution_methods.get(method) {
374 let comparison = comparator.compare_distributions(reference, current)?;
375
376 let result = DriftTestResult {
377 drift_detected: comparison.drift_detected,
378 p_value: A::one() - comparison.confidence, test_statistic: comparison.distance,
380 confidence: comparison.confidence * self.sensitivity_factor,
381 metadata: HashMap::new(),
382 };
383
384 Ok(result)
385 } else {
386 Err(format!("Distribution method {:?} not implemented", method))
387 }
388 }
389
390 fn detect_model_drift(
392 &mut self,
393 model_type: &ModelType,
394 batch: &[StreamingDataPoint<A>],
395 ) -> Result<DriftTestResult<A>, String> {
396 if let Some(detector) = self.model_detectors.get_mut(model_type) {
397 let model_result = detector.detect_drift(batch)?;
398
399 let result = DriftTestResult {
400 drift_detected: model_result.drift_detected,
401 p_value: A::one() - model_result.confidence,
402 test_statistic: model_result.performance_degradation,
403 confidence: model_result.confidence * self.sensitivity_factor,
404 metadata: HashMap::new(),
405 };
406
407 Ok(result)
408 } else {
409 Err(format!("Model type {:?} not implemented", model_type))
410 }
411 }
412
413 fn detect_ensemble_drift(
415 &mut self,
416 methods: &[DriftDetectionMethod],
417 voting_strategy: &VotingStrategy,
418 reference: &[A],
419 current: &[A],
420 batch: &[StreamingDataPoint<A>],
421 ) -> Result<DriftTestResult<A>, String> {
422 let mut results = Vec::new();
423
424 for method in methods {
426 let result = match method {
427 DriftDetectionMethod::Statistical(stat_method) => {
428 self.detect_statistical_drift(stat_method, reference, current)?
429 }
430 DriftDetectionMethod::Distribution(dist_method) => {
431 self.detect_distribution_drift(dist_method, reference, current)?
432 }
433 DriftDetectionMethod::ModelBased(model_type) => {
434 self.detect_model_drift(model_type, batch)?
435 }
436 DriftDetectionMethod::Ensemble { .. } => {
437 continue;
439 }
440 };
441 results.push(result);
442 }
443
444 let ensemble_result = self.apply_voting_strategy(voting_strategy, &results)?;
446 Ok(ensemble_result)
447 }
448
449 fn apply_voting_strategy(
451 &self,
452 strategy: &VotingStrategy,
453 results: &[DriftTestResult<A>],
454 ) -> Result<DriftTestResult<A>, String> {
455 if results.is_empty() {
456 return Err("No results to vote on".to_string());
457 }
458
459 let drift_detected = match strategy {
460 VotingStrategy::Majority => {
461 let positive_votes = results.iter().filter(|r| r.drift_detected).count();
462 positive_votes > results.len() / 2
463 }
464 VotingStrategy::Weighted { weights } => {
465 if weights.len() != results.len() {
466 return Err("Number of weights doesn't match number of results".to_string());
467 }
468
469 let weighted_score: f64 = results
470 .iter()
471 .zip(weights.iter())
472 .map(|(result, &weight)| weight * if result.drift_detected { 1.0 } else { 0.0 })
473 .sum();
474
475 let total_weight: f64 = weights.iter().sum();
476 weighted_score / total_weight > 0.5
477 }
478 VotingStrategy::Unanimous => results.iter().all(|r| r.drift_detected),
479 VotingStrategy::Threshold { min_votes } => {
480 let positive_votes = results.iter().filter(|r| r.drift_detected).count();
481 positive_votes >= *min_votes
482 }
483 };
484
485 let avg_confidence = results.iter().map(|r| r.confidence).sum::<A>()
487 / A::from(results.len()).expect("unwrap failed");
488
489 let avg_p_value = results.iter().map(|r| r.p_value).sum::<A>()
490 / A::from(results.len()).expect("unwrap failed");
491
492 let avg_test_statistic = results.iter().map(|r| r.test_statistic).sum::<A>()
493 / A::from(results.len()).expect("unwrap failed");
494
495 Ok(DriftTestResult {
496 drift_detected,
497 p_value: avg_p_value,
498 test_statistic: avg_test_statistic,
499 confidence: avg_confidence,
500 metadata: HashMap::new(),
501 })
502 }
503
504 fn handle_drift_detection(&mut self, result: DriftTestResult<A>) -> Result<(), String> {
506 let severity = self.classify_drift_severity(&result);
507
508 let drift_event = DriftEvent {
509 timestamp: Instant::now(),
510 severity: severity.clone(),
511 confidence: result.confidence,
512 detection_method: format!("{:?}", self.detection_method),
513 p_value: Some(result.p_value),
514 magnitude: result.test_statistic,
515 affected_features: Vec::new(), };
517
518 if self.detection_history.len() >= 1000 {
520 self.detection_history.pop_front();
521 }
522 self.detection_history.push_back(drift_event);
523
524 self.update_drift_state(true);
526 self.last_detection = Some(Instant::now());
527
528 if self.config.enable_false_positive_tracking {
530 self.false_positive_tracker.record_detection(true)?;
531 }
532
533 Ok(())
534 }
535
536 fn classify_drift_severity(&self, result: &DriftTestResult<A>) -> DriftSeverity {
538 let confidence = result.confidence.to_f64().unwrap_or(0.0);
539 let p_value = result.p_value.to_f64().unwrap_or(1.0);
540
541 if p_value < 0.001 && confidence > 0.95 {
542 DriftSeverity::Critical
543 } else if p_value < 0.01 && confidence > 0.9 {
544 DriftSeverity::Major
545 } else if p_value < 0.05 && confidence > 0.8 {
546 DriftSeverity::Moderate
547 } else {
548 DriftSeverity::Minor
549 }
550 }
551
552 fn update_drift_state(&mut self, drift_detected: bool) {
554 self.drift_state = match (&self.drift_state, drift_detected) {
555 (DriftState::Stable, true) => DriftState::Warning,
556 (DriftState::Warning, true) => DriftState::Drift,
557 (DriftState::Drift, false) => DriftState::Recovery,
558 (DriftState::Recovery, false) => DriftState::Stable,
559 (state, _) => state.clone(),
560 };
561 }
562
563 pub fn compute_sensitivity_adaptation(&mut self) -> Result<Option<Adaptation<A>>, String> {
565 if self.config.enable_false_positive_tracking {
567 let current_fp_rate = self.false_positive_tracker.current_fp_rate;
568 let target_fp_rate = A::from(0.05).expect("unwrap failed"); if (current_fp_rate - target_fp_rate).abs() > A::from(0.02).expect("unwrap failed") {
571 let adjustment = if current_fp_rate > target_fp_rate {
572 -A::from(0.1).expect("unwrap failed")
574 } else {
575 A::from(0.1).expect("unwrap failed")
577 };
578
579 let adaptation = Adaptation {
580 adaptation_type: AdaptationType::DriftSensitivity,
581 magnitude: adjustment,
582 target_component: "drift_detector".to_string(),
583 parameters: HashMap::new(),
584 priority: AdaptationPriority::Normal,
585 timestamp: Instant::now(),
586 };
587
588 return Ok(Some(adaptation));
589 }
590 }
591
592 Ok(None)
593 }
594
595 pub fn apply_sensitivity_adaptation(
597 &mut self,
598 adaptation: &Adaptation<A>,
599 ) -> Result<(), String> {
600 if adaptation.adaptation_type == AdaptationType::DriftSensitivity {
601 self.sensitivity_factor = (self.sensitivity_factor + adaptation.magnitude)
602 .max(A::from(0.1).expect("unwrap failed"))
603 .min(A::from(2.0).expect("unwrap failed"));
604 }
605 Ok(())
606 }
607
608 pub fn is_drift_detected(&self) -> bool {
610 matches!(self.drift_state, DriftState::Drift | DriftState::Warning)
611 }
612
    /// Returns the current state of the drift state machine.
    pub fn get_drift_state(&self) -> &DriftState {
        &self.drift_state
    }
617
618 pub fn get_recent_drift_events(&self, count: usize) -> Vec<&DriftEvent<A>> {
620 self.detection_history.iter().rev().take(count).collect()
621 }
622
623 pub fn reset(&mut self) -> Result<(), String> {
625 self.detection_history.clear();
626 self.reference_window.clear();
627 self.drift_state = DriftState::Stable;
628 self.last_detection = None;
629 self.sensitivity_factor = A::one();
630
631 for test in self.statistical_tests.values_mut() {
633 test.reset();
634 }
635
636 for detector in self.model_detectors.values_mut() {
637 detector.reset_model()?;
638 }
639
640 Ok(())
641 }
642
643 pub fn get_diagnostics(&self) -> DriftDiagnostics {
645 DriftDiagnostics {
646 current_state: self.drift_state.clone(),
647 detection_count: self.detection_history.len(),
648 false_positive_rate: self
649 .false_positive_tracker
650 .current_fp_rate
651 .to_f64()
652 .unwrap_or(0.0),
653 sensitivity_factor: self.sensitivity_factor.to_f64().unwrap_or(1.0),
654 last_detection_time: self.last_detection,
655 reference_window_size: self.reference_window.len(),
656 }
657 }
658}
659
660impl<A: Float + Send + Sync + Send + Sync> FalsePositiveTracker<A> {
661 fn new() -> Self {
662 Self {
663 false_positives: VecDeque::new(),
664 true_positives: VecDeque::new(),
665 current_fp_rate: A::zero(),
666 target_fp_rate: A::from(0.05).expect("unwrap failed"),
667 }
668 }
669
670 fn record_detection(&mut self, is_true_positive: bool) -> Result<(), String> {
671 let now = Instant::now();
672
673 if is_true_positive {
674 self.true_positives.push_back(now);
675 } else {
676 self.false_positives.push_back(now);
677 }
678
679 let cutoff = now - Duration::from_secs(3600);
681 self.false_positives.retain(|&time| time > cutoff);
682 self.true_positives.retain(|&time| time > cutoff);
683
684 let total_detections = self.false_positives.len() + self.true_positives.len();
686 if total_detections > 0 {
687 self.current_fp_rate = A::from(self.false_positives.len()).expect("unwrap failed")
688 / A::from(total_detections).expect("unwrap failed");
689 }
690
691 Ok(())
692 }
693}
694
/// Snapshot of a drift detector's observable state.
#[derive(Debug, Clone)]
pub struct DriftDiagnostics {
    /// State of the drift state machine at snapshot time.
    pub current_state: DriftState,
    /// Number of drift events currently held in the detection history.
    pub detection_count: usize,
    /// Tracked false positive rate.
    pub false_positive_rate: f64,
    /// Sensitivity multiplier in effect.
    pub sensitivity_factor: f64,
    /// Time of the most recent detection, if any.
    pub last_detection_time: Option<Instant>,
    /// Number of data points currently in the reference window.
    pub reference_window_size: usize,
}
705
706struct ADWINTest<A: Float + Send + Sync> {
710 sensitivity: A,
711 window: VecDeque<A>,
712}
713
714impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> ADWINTest<A> {
715 fn new(sensitivity: f64) -> Result<Self, String> {
716 Ok(Self {
717 sensitivity: A::from(sensitivity).expect("unwrap failed"),
718 window: VecDeque::new(),
719 })
720 }
721}
722
723impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> StatisticalTest<A>
724 for ADWINTest<A>
725{
726 fn test_for_drift(
727 &mut self,
728 reference: &[A],
729 current: &[A],
730 ) -> Result<DriftTestResult<A>, String> {
731 let ref_mean =
733 reference.iter().cloned().sum::<A>() / A::from(reference.len()).expect("unwrap failed");
734 let cur_mean =
735 current.iter().cloned().sum::<A>() / A::from(current.len()).expect("unwrap failed");
736
737 let difference = (ref_mean - cur_mean).abs();
738 let threshold = self.sensitivity;
739
740 let drift_detected = difference > threshold;
741
742 Ok(DriftTestResult {
743 drift_detected,
744 p_value: if drift_detected {
745 A::from(0.01).expect("unwrap failed")
746 } else {
747 A::from(0.5).expect("unwrap failed")
748 },
749 test_statistic: difference,
750 confidence: if drift_detected {
751 A::from(0.9).expect("unwrap failed")
752 } else {
753 A::from(0.1).expect("unwrap failed")
754 },
755 metadata: HashMap::new(),
756 })
757 }
758
759 fn update_parameters(&mut self, _performance_feedback: A) -> Result<(), String> {
760 Ok(())
761 }
762
763 fn reset(&mut self) {
764 self.window.clear();
765 }
766}
767
768struct DDMTest<A: Float + Send + Sync> {
769 sensitivity: A,
770 error_rate: A,
771 std_dev: A,
772}
773
774impl<A: Float + Default + Send + Sync + std::iter::Sum> DDMTest<A> {
775 fn new(sensitivity: f64) -> Result<Self, String> {
776 Ok(Self {
777 sensitivity: A::from(sensitivity).expect("unwrap failed"),
778 error_rate: A::zero(),
779 std_dev: A::zero(),
780 })
781 }
782}
783
784impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> StatisticalTest<A> for DDMTest<A> {
785 fn test_for_drift(
786 &mut self,
787 reference: &[A],
788 current: &[A],
789 ) -> Result<DriftTestResult<A>, String> {
790 let ref_mean =
792 reference.iter().cloned().sum::<A>() / A::from(reference.len()).expect("unwrap failed");
793 let cur_mean =
794 current.iter().cloned().sum::<A>() / A::from(current.len()).expect("unwrap failed");
795
796 let difference = (ref_mean - cur_mean).abs();
797 let drift_detected = difference > self.sensitivity;
798
799 Ok(DriftTestResult {
800 drift_detected,
801 p_value: if drift_detected {
802 A::from(0.02).expect("unwrap failed")
803 } else {
804 A::from(0.6).expect("unwrap failed")
805 },
806 test_statistic: difference,
807 confidence: if drift_detected {
808 A::from(0.85).expect("unwrap failed")
809 } else {
810 A::from(0.15).expect("unwrap failed")
811 },
812 metadata: HashMap::new(),
813 })
814 }
815
816 fn update_parameters(&mut self, _performance_feedback: A) -> Result<(), String> {
817 Ok(())
818 }
819
820 fn reset(&mut self) {
821 self.error_rate = A::zero();
822 self.std_dev = A::zero();
823 }
824}
825
826struct PageHinkleyTest<A: Float + Send + Sync> {
827 sensitivity: A,
828 cumulative_sum: A,
829}
830
831impl<A: Float + Default + Send + Sync + std::iter::Sum> PageHinkleyTest<A> {
832 fn new(sensitivity: f64) -> Result<Self, String> {
833 Ok(Self {
834 sensitivity: A::from(sensitivity).expect("unwrap failed"),
835 cumulative_sum: A::zero(),
836 })
837 }
838}
839
840impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> StatisticalTest<A>
841 for PageHinkleyTest<A>
842{
843 fn test_for_drift(
844 &mut self,
845 reference: &[A],
846 current: &[A],
847 ) -> Result<DriftTestResult<A>, String> {
848 let ref_mean =
850 reference.iter().cloned().sum::<A>() / A::from(reference.len()).expect("unwrap failed");
851 let cur_mean =
852 current.iter().cloned().sum::<A>() / A::from(current.len()).expect("unwrap failed");
853
854 let difference = cur_mean - ref_mean;
855 self.cumulative_sum = self.cumulative_sum + difference;
856
857 let drift_detected = self.cumulative_sum.abs() > self.sensitivity;
858
859 Ok(DriftTestResult {
860 drift_detected,
861 p_value: if drift_detected {
862 A::from(0.015).expect("unwrap failed")
863 } else {
864 A::from(0.7).expect("unwrap failed")
865 },
866 test_statistic: self.cumulative_sum,
867 confidence: if drift_detected {
868 A::from(0.88).expect("unwrap failed")
869 } else {
870 A::from(0.12).expect("unwrap failed")
871 },
872 metadata: HashMap::new(),
873 })
874 }
875
876 fn update_parameters(&mut self, _performance_feedback: A) -> Result<(), String> {
877 Ok(())
878 }
879
880 fn reset(&mut self) {
881 self.cumulative_sum = A::zero();
882 }
883}
884
885struct KLDivergenceComparator<A: Float + Send + Sync> {
886 threshold: A,
887}
888
889impl<A: Float + Send + Sync + Send + Sync> KLDivergenceComparator<A> {
890 fn new(sensitivity: f64) -> Result<Self, String> {
891 Ok(Self {
892 threshold: A::from(sensitivity).expect("unwrap failed"),
893 })
894 }
895}
896
897impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> DistributionComparator<A>
898 for KLDivergenceComparator<A>
899{
900 fn compare_distributions(
901 &self,
902 reference: &[A],
903 current: &[A],
904 ) -> Result<DistributionComparison<A>, String> {
905 let ref_mean =
907 reference.iter().cloned().sum::<A>() / A::from(reference.len()).expect("unwrap failed");
908 let cur_mean =
909 current.iter().cloned().sum::<A>() / A::from(current.len()).expect("unwrap failed");
910
911 let distance = (ref_mean - cur_mean).abs();
912 let drift_detected = distance > self.threshold;
913
914 Ok(DistributionComparison {
915 distance,
916 threshold: self.threshold,
917 drift_detected,
918 confidence: if drift_detected {
919 A::from(0.8).expect("unwrap failed")
920 } else {
921 A::from(0.2).expect("unwrap failed")
922 },
923 })
924 }
925
926 fn get_threshold(&self) -> A {
927 self.threshold
928 }
929
930 fn update_threshold(&mut self, new_threshold: A) {
931 self.threshold = new_threshold;
932 }
933}
934
935struct JSDivergenceComparator<A: Float + Send + Sync> {
936 threshold: A,
937}
938
939impl<A: Float + Send + Sync + Send + Sync> JSDivergenceComparator<A> {
940 fn new(sensitivity: f64) -> Result<Self, String> {
941 Ok(Self {
942 threshold: A::from(sensitivity).expect("unwrap failed"),
943 })
944 }
945}
946
947impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> DistributionComparator<A>
948 for JSDivergenceComparator<A>
949{
950 fn compare_distributions(
951 &self,
952 reference: &[A],
953 current: &[A],
954 ) -> Result<DistributionComparison<A>, String> {
955 let ref_mean =
957 reference.iter().cloned().sum::<A>() / A::from(reference.len()).expect("unwrap failed");
958 let cur_mean =
959 current.iter().cloned().sum::<A>() / A::from(current.len()).expect("unwrap failed");
960
961 let distance = (ref_mean - cur_mean).abs() * A::from(0.5).expect("unwrap failed"); let drift_detected = distance > self.threshold;
963
964 Ok(DistributionComparison {
965 distance,
966 threshold: self.threshold,
967 drift_detected,
968 confidence: if drift_detected {
969 A::from(0.75).expect("unwrap failed")
970 } else {
971 A::from(0.25).expect("unwrap failed")
972 },
973 })
974 }
975
976 fn get_threshold(&self) -> A {
977 self.threshold
978 }
979
980 fn update_threshold(&mut self, new_threshold: A) {
981 self.threshold = new_threshold;
982 }
983}
984
985struct LinearModelDetector<A: Float + Send + Sync> {
986 model_performance: A,
987 baseline_performance: A,
988}
989
990impl<A: Float + Default + Send + Sync + Send + Sync> LinearModelDetector<A> {
991 fn new() -> Result<Self, String> {
992 Ok(Self {
993 model_performance: A::zero(),
994 baseline_performance: A::zero(),
995 })
996 }
997}
998
999impl<A: Float + Default + Clone + Send + Sync + std::iter::Sum> ModelBasedDetector<A>
1000 for LinearModelDetector<A>
1001{
1002 fn update_model(&mut self, _data: &[StreamingDataPoint<A>]) -> Result<(), String> {
1003 Ok(())
1005 }
1006
1007 fn detect_drift(
1008 &mut self,
1009 _data: &[StreamingDataPoint<A>],
1010 ) -> Result<ModelDriftResult<A>, String> {
1011 let performance_degradation = self.baseline_performance - self.model_performance;
1013 let drift_detected = performance_degradation > A::from(0.1).expect("unwrap failed");
1014
1015 Ok(ModelDriftResult {
1016 drift_detected,
1017 performance_degradation,
1018 confidence: if drift_detected {
1019 A::from(0.7).expect("unwrap failed")
1020 } else {
1021 A::from(0.3).expect("unwrap failed")
1022 },
1023 feature_importance_changes: Vec::new(),
1024 })
1025 }
1026
1027 fn reset_model(&mut self) -> Result<(), String> {
1028 self.model_performance = A::zero();
1029 self.baseline_performance = A::zero();
1030 Ok(())
1031 }
1032}