use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use std::collections::VecDeque;
use super::config::*;
use super::core::QuantumAnomalyDetector;
/// Streaming front-end that keeps a bounded buffer of recent samples and
/// scores them with a wrapped `QuantumAnomalyDetector`.
pub struct StreamingAnomalyDetector {
/// Detector that scores the buffered samples and receives online updates.
detector: QuantumAnomalyDetector,
/// Recent samples, oldest at the front; trimmed to `config.buffer_size`
/// entries each time a sample is processed.
buffer: VecDeque<Array1<f64>>,
/// Real-time settings; this type reads `buffer_size`, `drift_detection`
/// and `online_learning` from it.
config: RealtimeConfig,
/// Drift monitor; `Some` only when `config.drift_detection` was set at
/// construction time.
drift_detector: Option<DriftDetector>,
}
/// Drift monitor that compares the mean of a sliding window of binary
/// outcomes against a warning threshold and a drift threshold.
#[derive(Debug)]
pub struct DriftDetector {
/// Mean error rate that must be strictly exceeded to flag drift.
threshold: f64,
/// Mean error rate that must be strictly exceeded to flag a warning.
warning_threshold: f64,
/// Set once the drift threshold has been exceeded; cleared only by `reset`.
drift_detected: bool,
/// Set once the warning threshold has been exceeded; cleared only by `reset`.
warning_detected: bool,
/// Sliding window of outcomes, stored as `1.0` for an error and `0.0`
/// otherwise; `update` keeps at most the 100 most recent entries.
error_rate_history: VecDeque<f64>,
}
impl StreamingAnomalyDetector {
    /// Creates a streaming wrapper around `detector`.
    ///
    /// The sample buffer is pre-allocated for `config.buffer_size` entries.
    /// When `config.drift_detection` is set, a [`DriftDetector`] is attached
    /// with a drift threshold of `0.05` and a warning threshold of `0.02`.
    pub fn new(detector: QuantumAnomalyDetector, config: RealtimeConfig) -> Self {
        let drift_detector = config
            .drift_detection
            .then(|| DriftDetector::new(0.05, 0.02));

        StreamingAnomalyDetector {
            detector,
            buffer: VecDeque::with_capacity(config.buffer_size),
            config,
            drift_detector,
        }
    }

    /// Appends `sample` to the buffer and returns the anomaly score of that
    /// sample.
    ///
    /// While the buffer holds fewer than `config.buffer_size / 2` samples the
    /// detector is not run and `0.0` is returned. Otherwise the whole buffer
    /// is scored and the score of the newest row is returned. If drift
    /// detection is enabled, the newest label (`1` meaning anomaly) is fed to
    /// the drift detector.
    ///
    /// # Errors
    ///
    /// Returns `MLError::DataError` when `sample` does not have the same
    /// length as the samples already buffered, when the buffer is empty at
    /// scoring time (only possible with `config.buffer_size == 0`), or when
    /// the detector returns no scores or labels. Errors from the underlying
    /// detector are propagated unchanged.
    pub fn process_sample(&mut self, sample: Array1<f64>) -> Result<f64> {
        // Reject mis-sized samples *before* they enter the buffer. Otherwise a
        // single bad sample would make every later call fail until evicted.
        if let Some(first) = self.buffer.front() {
            if first.len() != sample.len() {
                return Err(MLError::DataError(format!(
                    "Sample has {} features but buffered samples have {}",
                    sample.len(),
                    first.len()
                )));
            }
        }

        self.buffer.push_back(sample);
        while self.buffer.len() > self.config.buffer_size {
            self.buffer.pop_front();
        }

        // Warm-up: not enough history to score yet.
        if self.buffer.len() < self.config.buffer_size / 2 {
            return Ok(0.0);
        }

        let data = self.buffer_to_array()?;
        let result = self.detector.detect(&data)?;

        // Guard the "last element" lookups so an empty result is reported as
        // an error instead of panicking on `len() - 1`.
        let score_count = result.anomaly_scores.len();
        if score_count == 0 {
            return Err(MLError::DataError(
                "Detector returned no anomaly scores".to_string(),
            ));
        }
        let anomaly_score = result.anomaly_scores[score_count - 1];

        if let Some(drift_detector) = self.drift_detector.as_mut() {
            let label_count = result.anomaly_labels.len();
            if label_count == 0 {
                return Err(MLError::DataError(
                    "Detector returned no anomaly labels".to_string(),
                ));
            }

            // A drift flagged by the previous sample is cleared here, right
            // before the next update, rather than immediately after it was
            // raised. The window history evolves exactly as before, but the
            // flag stays visible through `is_drift_detected` between calls.
            if drift_detector.is_drift_detected() {
                drift_detector.reset();
            }

            let is_anomaly = result.anomaly_labels[label_count - 1] == 1;
            drift_detector.update(is_anomaly);
        }

        Ok(anomaly_score)
    }

    /// Processes every row of `batch` in order through [`Self::process_sample`]
    /// and returns one score per row.
    ///
    /// # Errors
    ///
    /// Stops at the first row that fails and returns its error; rows before
    /// it have already been added to the buffer.
    pub fn process_batch(&mut self, batch: &Array2<f64>) -> Result<Array1<f64>> {
        let mut scores = Array1::<f64>::zeros(batch.nrows());
        for (slot, row) in scores.iter_mut().zip(batch.outer_iter()) {
            *slot = self.process_sample(row.to_owned())?;
        }
        Ok(scores)
    }

    /// Forwards `data` and optional `labels` to the wrapped detector's
    /// `update` when `config.online_learning` is enabled; otherwise does
    /// nothing and returns `Ok(())`.
    pub fn update_detector(
        &mut self,
        data: &Array2<f64>,
        labels: Option<&Array1<i32>>,
    ) -> Result<()> {
        if self.config.online_learning {
            self.detector.update(data, labels)?;
        }
        Ok(())
    }

    /// Returns `true` when the most recently scored sample left the drift
    /// detector in the drifted state.
    ///
    /// The flag is cleared when the next sample is scored. Always `false`
    /// when drift detection is disabled.
    pub fn is_drift_detected(&self) -> bool {
        self.drift_detector
            .as_ref()
            .map_or(false, |d| d.is_drift_detected())
    }

    /// Returns the number of samples currently buffered (not the configured
    /// capacity).
    pub fn buffer_size(&self) -> usize {
        self.buffer.len()
    }

    /// Discards all buffered samples.
    pub fn clear_buffer(&mut self) {
        self.buffer.clear();
    }

    /// Copies the buffered samples into a `(n_samples, n_features)` matrix,
    /// oldest sample first.
    ///
    /// The feature count is taken from the oldest sample. Returns
    /// `MLError::DataError` when the buffer is empty or the collected values
    /// do not fit that shape.
    fn buffer_to_array(&self) -> Result<Array2<f64>> {
        let first = self
            .buffer
            .front()
            .ok_or_else(|| MLError::DataError("Buffer is empty".to_string()))?;

        let n_samples = self.buffer.len();
        let n_features = first.len();

        let mut flat = Vec::with_capacity(n_samples * n_features);
        for sample in &self.buffer {
            flat.extend(sample.iter().copied());
        }

        Array2::from_shape_vec((n_samples, n_features), flat)
            .map_err(|e| MLError::DataError(e.to_string()))
    }
}
impl DriftDetector {
    /// Maximum number of most-recent outcomes kept in the sliding window.
    const WINDOW_SIZE: usize = 100;

    /// Creates a detector with no history and both flags cleared.
    ///
    /// `drift_threshold` and `warning_threshold` are mean error rates; each
    /// flag is raised when the windowed mean strictly exceeds its threshold.
    pub fn new(drift_threshold: f64, warning_threshold: f64) -> Self {
        DriftDetector {
            threshold: drift_threshold,
            warning_threshold,
            drift_detected: false,
            warning_detected: false,
            // Sized to the window actually retained by `update`.
            error_rate_history: VecDeque::with_capacity(Self::WINDOW_SIZE),
        }
    }

    /// Records one outcome and re-evaluates the warning and drift flags.
    ///
    /// The outcome is stored as `1.0` (error) or `0.0` (no error). Only the
    /// `WINDOW_SIZE` most recent outcomes are kept. Flags are latched: once
    /// raised they stay set until [`Self::reset`] is called.
    pub fn update(&mut self, is_error: bool) {
        // Evict first so the window never grows past its reserved capacity.
        while self.error_rate_history.len() >= Self::WINDOW_SIZE {
            self.error_rate_history.pop_front();
        }
        self.error_rate_history
            .push_back(if is_error { 1.0 } else { 0.0 });

        let mean_error = self.current_error_rate();
        if mean_error > self.warning_threshold {
            self.warning_detected = true;
        }
        if mean_error > self.threshold {
            self.drift_detected = true;
        }
    }

    /// Returns `true` once the drift threshold has been exceeded since the
    /// last reset.
    pub fn is_drift_detected(&self) -> bool {
        self.drift_detected
    }

    /// Returns `true` once the warning threshold has been exceeded since the
    /// last reset.
    pub fn is_warning_detected(&self) -> bool {
        self.warning_detected
    }

    /// Clears both flags and discards the outcome history.
    pub fn reset(&mut self) {
        self.drift_detected = false;
        self.warning_detected = false;
        self.error_rate_history.clear();
    }

    /// Returns the mean of the outcomes currently in the window, or `0.0`
    /// when the window is empty.
    pub fn current_error_rate(&self) -> f64 {
        if self.error_rate_history.is_empty() {
            return 0.0;
        }
        let total: f64 = self.error_rate_history.iter().sum();
        total / self.error_rate_history.len() as f64
    }
}