quantrs2_ml/anomaly_detection/
streaming.rs1use crate::error::{MLError, Result};
4use scirs2_core::ndarray::{Array1, Array2, Axis};
5use std::collections::VecDeque;
6
7use super::config::*;
8use super::core::QuantumAnomalyDetector;
9
10pub struct StreamingAnomalyDetector {
12 detector: QuantumAnomalyDetector,
13 buffer: VecDeque<Array1<f64>>,
14 config: RealtimeConfig,
15 drift_detector: Option<DriftDetector>,
16}
17
/// Tracks a running error rate over recent outcomes and flags warning / drift
/// conditions when that rate exceeds the configured thresholds.
#[derive(Debug)]
pub struct DriftDetector {
    /// Error rate above which drift is flagged.
    threshold: f64,
    /// Error rate above which a warning is flagged.
    warning_threshold: f64,
    /// Set once the error rate has exceeded `threshold`; cleared by a reset.
    drift_detected: bool,
    /// Set once the error rate has exceeded `warning_threshold`; cleared by a reset.
    warning_detected: bool,
    /// Recent outcomes encoded as `1.0` (error) or `0.0` (no error), oldest first.
    error_rate_history: VecDeque<f64>,
}
27
28impl StreamingAnomalyDetector {
29 pub fn new(detector: QuantumAnomalyDetector, config: RealtimeConfig) -> Self {
31 let drift_detector = if config.drift_detection {
32 Some(DriftDetector::new(0.05, 0.02))
33 } else {
34 None
35 };
36
37 StreamingAnomalyDetector {
38 detector,
39 buffer: VecDeque::with_capacity(config.buffer_size),
40 config,
41 drift_detector,
42 }
43 }
44
45 pub fn process_sample(&mut self, sample: Array1<f64>) -> Result<f64> {
47 self.buffer.push_back(sample.clone());
49
50 while self.buffer.len() > self.config.buffer_size {
52 self.buffer.pop_front();
53 }
54
55 if self.buffer.len() >= self.config.buffer_size / 2 {
57 let data = self.buffer_to_array()?;
58 let result = self.detector.detect(&data)?;
59 let anomaly_score = result.anomaly_scores[result.anomaly_scores.len() - 1];
60
61 if let Some(ref mut drift_detector) = self.drift_detector {
63 let is_anomaly = result.anomaly_labels[result.anomaly_labels.len() - 1] == 1;
64 drift_detector.update(is_anomaly);
65
66 if drift_detector.is_drift_detected() {
67 drift_detector.reset();
69 }
70 }
71
72 return Ok(anomaly_score);
73 }
74
75 Ok(0.0)
77 }
78
79 pub fn process_batch(&mut self, batch: &Array2<f64>) -> Result<Array1<f64>> {
81 let mut scores = Array1::zeros(batch.nrows());
82
83 for (i, sample) in batch.outer_iter().enumerate() {
84 scores[i] = self.process_sample(sample.to_owned())?;
85 }
86
87 Ok(scores)
88 }
89
90 pub fn update_detector(
92 &mut self,
93 data: &Array2<f64>,
94 labels: Option<&Array1<i32>>,
95 ) -> Result<()> {
96 if self.config.online_learning {
97 self.detector.update(data, labels)?;
98 }
99 Ok(())
100 }
101
102 pub fn is_drift_detected(&self) -> bool {
104 self.drift_detector
105 .as_ref()
106 .map(|d| d.is_drift_detected())
107 .unwrap_or(false)
108 }
109
110 pub fn buffer_size(&self) -> usize {
112 self.buffer.len()
113 }
114
115 pub fn clear_buffer(&mut self) {
117 self.buffer.clear();
118 }
119
120 fn buffer_to_array(&self) -> Result<Array2<f64>> {
123 if self.buffer.is_empty() {
124 return Err(MLError::DataError("Buffer is empty".to_string()));
125 }
126
127 let n_samples = self.buffer.len();
128 let n_features = self.buffer[0].len();
129
130 let data = Array2::from_shape_vec(
131 (n_samples, n_features),
132 self.buffer.iter().flat_map(|s| s.iter().cloned()).collect(),
133 )
134 .map_err(|e| MLError::DataError(e.to_string()))?;
135
136 Ok(data)
137 }
138}
139
140impl DriftDetector {
141 pub fn new(drift_threshold: f64, warning_threshold: f64) -> Self {
143 DriftDetector {
144 threshold: drift_threshold,
145 warning_threshold,
146 drift_detected: false,
147 warning_detected: false,
148 error_rate_history: VecDeque::with_capacity(1000),
149 }
150 }
151
152 pub fn update(&mut self, is_error: bool) {
154 let error_rate = if is_error { 1.0 } else { 0.0 };
155 self.error_rate_history.push_back(error_rate);
156
157 while self.error_rate_history.len() > 100 {
159 self.error_rate_history.pop_front();
160 }
161
162 let avg_error_rate =
164 self.error_rate_history.iter().sum::<f64>() / self.error_rate_history.len() as f64;
165
166 if avg_error_rate > self.warning_threshold {
168 self.warning_detected = true;
169 }
170
171 if avg_error_rate > self.threshold {
173 self.drift_detected = true;
174 }
175 }
176
177 pub fn is_drift_detected(&self) -> bool {
179 self.drift_detected
180 }
181
182 pub fn is_warning_detected(&self) -> bool {
184 self.warning_detected
185 }
186
187 pub fn reset(&mut self) {
189 self.drift_detected = false;
190 self.warning_detected = false;
191 self.error_rate_history.clear();
192 }
193
194 pub fn current_error_rate(&self) -> f64 {
196 if self.error_rate_history.is_empty() {
197 0.0
198 } else {
199 self.error_rate_history.iter().sum::<f64>() / self.error_rate_history.len() as f64
200 }
201 }
202}