quantrs2_ml/anomaly_detection/
streaming.rs

1//! Real-time streaming anomaly detection
2
3use crate::error::{MLError, Result};
4use scirs2_core::ndarray::{Array1, Array2, Axis};
5use std::collections::VecDeque;
6
7use super::config::*;
8use super::core::QuantumAnomalyDetector;
9
/// Streaming anomaly detector for real-time processing
///
/// Wraps a [`QuantumAnomalyDetector`] and feeds it a sliding window of the
/// most recently received samples, optionally tracking concept drift.
pub struct StreamingAnomalyDetector {
    /// Underlying detector that scores the buffered window and receives
    /// online updates.
    detector: QuantumAnomalyDetector,
    /// Sliding window of recent samples, oldest at the front; trimmed to at
    /// most `config.buffer_size` entries whenever a sample is processed.
    buffer: VecDeque<Array1<f64>>,
    /// Real-time settings; this file reads `buffer_size`, `drift_detection`
    /// and `online_learning` from it.
    config: RealtimeConfig,
    /// Drift tracker, present only when `config.drift_detection` was set at
    /// construction time.
    drift_detector: Option<DriftDetector>,
}
17
/// Drift detection for concept drift in streaming data
///
/// Tracks a moving average over a window of recent binary outcomes and
/// latches a warning and a drift flag once that average exceeds the
/// corresponding threshold.
#[derive(Debug)]
pub struct DriftDetector {
    /// Average error rate strictly above which drift is flagged.
    threshold: f64,
    /// Average error rate strictly above which a warning is flagged.
    warning_threshold: f64,
    /// Latched drift flag: set by `update`, cleared only by `reset`.
    drift_detected: bool,
    /// Latched warning flag: set by `update`, cleared only by `reset`.
    warning_detected: bool,
    /// Recent outcomes, oldest first: `1.0` for an error, `0.0` otherwise.
    /// `update` keeps only the most recent 100 entries.
    error_rate_history: VecDeque<f64>,
}
27
28impl StreamingAnomalyDetector {
29    /// Create new streaming anomaly detector
30    pub fn new(detector: QuantumAnomalyDetector, config: RealtimeConfig) -> Self {
31        let drift_detector = if config.drift_detection {
32            Some(DriftDetector::new(0.05, 0.02))
33        } else {
34            None
35        };
36
37        StreamingAnomalyDetector {
38            detector,
39            buffer: VecDeque::with_capacity(config.buffer_size),
40            config,
41            drift_detector,
42        }
43    }
44
45    /// Process a single sample
46    pub fn process_sample(&mut self, sample: Array1<f64>) -> Result<f64> {
47        // Add sample to buffer
48        self.buffer.push_back(sample.clone());
49
50        // Remove old samples if buffer is full
51        while self.buffer.len() > self.config.buffer_size {
52            self.buffer.pop_front();
53        }
54
55        // Detect anomaly if buffer has enough samples
56        if self.buffer.len() >= self.config.buffer_size / 2 {
57            let data = self.buffer_to_array()?;
58            let result = self.detector.detect(&data)?;
59            let anomaly_score = result.anomaly_scores[result.anomaly_scores.len() - 1];
60
61            // Check for drift if enabled
62            if let Some(ref mut drift_detector) = self.drift_detector {
63                let is_anomaly = result.anomaly_labels[result.anomaly_labels.len() - 1] == 1;
64                drift_detector.update(is_anomaly);
65
66                if drift_detector.is_drift_detected() {
67                    // Handle drift (could retrain model here)
68                    drift_detector.reset();
69                }
70            }
71
72            return Ok(anomaly_score);
73        }
74
75        // Not enough data for detection
76        Ok(0.0)
77    }
78
79    /// Process a batch of samples
80    pub fn process_batch(&mut self, batch: &Array2<f64>) -> Result<Array1<f64>> {
81        let mut scores = Array1::zeros(batch.nrows());
82
83        for (i, sample) in batch.outer_iter().enumerate() {
84            scores[i] = self.process_sample(sample.to_owned())?;
85        }
86
87        Ok(scores)
88    }
89
90    /// Update the underlying detector with new labeled data
91    pub fn update_detector(
92        &mut self,
93        data: &Array2<f64>,
94        labels: Option<&Array1<i32>>,
95    ) -> Result<()> {
96        if self.config.online_learning {
97            self.detector.update(data, labels)?;
98        }
99        Ok(())
100    }
101
102    /// Check if drift is detected
103    pub fn is_drift_detected(&self) -> bool {
104        self.drift_detector
105            .as_ref()
106            .map(|d| d.is_drift_detected())
107            .unwrap_or(false)
108    }
109
110    /// Get buffer size
111    pub fn buffer_size(&self) -> usize {
112        self.buffer.len()
113    }
114
115    /// Clear the buffer
116    pub fn clear_buffer(&mut self) {
117        self.buffer.clear();
118    }
119
120    // Helper methods
121
122    fn buffer_to_array(&self) -> Result<Array2<f64>> {
123        if self.buffer.is_empty() {
124            return Err(MLError::DataError("Buffer is empty".to_string()));
125        }
126
127        let n_samples = self.buffer.len();
128        let n_features = self.buffer[0].len();
129
130        let data = Array2::from_shape_vec(
131            (n_samples, n_features),
132            self.buffer.iter().flat_map(|s| s.iter().cloned()).collect(),
133        )
134        .map_err(|e| MLError::DataError(e.to_string()))?;
135
136        Ok(data)
137    }
138}
139
140impl DriftDetector {
141    /// Create new drift detector
142    pub fn new(drift_threshold: f64, warning_threshold: f64) -> Self {
143        DriftDetector {
144            threshold: drift_threshold,
145            warning_threshold,
146            drift_detected: false,
147            warning_detected: false,
148            error_rate_history: VecDeque::with_capacity(1000),
149        }
150    }
151
152    /// Update with a new prediction result
153    pub fn update(&mut self, is_error: bool) {
154        let error_rate = if is_error { 1.0 } else { 0.0 };
155        self.error_rate_history.push_back(error_rate);
156
157        // Keep only recent history
158        while self.error_rate_history.len() > 100 {
159            self.error_rate_history.pop_front();
160        }
161
162        // Calculate moving average error rate
163        let avg_error_rate =
164            self.error_rate_history.iter().sum::<f64>() / self.error_rate_history.len() as f64;
165
166        // Check for warning level
167        if avg_error_rate > self.warning_threshold {
168            self.warning_detected = true;
169        }
170
171        // Check for drift
172        if avg_error_rate > self.threshold {
173            self.drift_detected = true;
174        }
175    }
176
177    /// Check if drift is detected
178    pub fn is_drift_detected(&self) -> bool {
179        self.drift_detected
180    }
181
182    /// Check if warning is detected
183    pub fn is_warning_detected(&self) -> bool {
184        self.warning_detected
185    }
186
187    /// Reset the drift detector
188    pub fn reset(&mut self) {
189        self.drift_detected = false;
190        self.warning_detected = false;
191        self.error_rate_history.clear();
192    }
193
194    /// Get current error rate
195    pub fn current_error_rate(&self) -> f64 {
196        if self.error_rate_history.is_empty() {
197            0.0
198        } else {
199            self.error_rate_history.iter().sum::<f64>() / self.error_rate_history.len() as f64
200        }
201    }
202}