use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::fmt;
/// Anomaly detector that bundles a configuration with running detection counters.
///
/// Built via `AnomalyDetector::new`; each `detect_anomalies` call updates `stats`.
#[derive(Debug)]
pub struct AnomalyDetector {
/// Settings supplied at construction; `detect_anomalies` iterates `config.methods`.
config: AnomalyConfiguration,
/// Cumulative counters, starting from `AnomalyStatistics::default()`.
stats: AnomalyStatistics,
}
/// User-facing settings for an `AnomalyDetector`.
#[derive(Debug, Clone)]
pub struct AnomalyConfiguration {
/// Statistical methods the detector walks through, in order, on every detection run.
pub methods: Vec<StatisticalMethod>,
/// Sensitivity setting.
/// NOTE(review): not read by any code visible in this file section — confirm its
/// intended meaning and valid range before relying on it.
pub sensitivity: f64,
/// Minimum score threshold.
/// NOTE(review): not read by any code visible in this file section — presumably a
/// cut-off applied to anomaly scores; verify against callers.
pub min_score_threshold: f64,
}
impl Default for AnomalyConfiguration {
    /// Builds the default configuration: Z-score and IQR methods enabled,
    /// `sensitivity` of `0.05` and `min_score_threshold` of `0.8`.
    fn default() -> Self {
        let default_methods = Vec::from([StatisticalMethod::ZScore, StatisticalMethod::IQR]);

        Self {
            methods: default_methods,
            sensitivity: 0.05,
            min_score_threshold: 0.8,
        }
    }
}
/// Statistical technique that can be selected in `AnomalyConfiguration::methods`.
///
/// `AnomalyDetector::detect_anomalies` only dispatches `ZScore` and `IQR`; the
/// remaining variants yield no outliers there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StatisticalMethod {
/// Z-score test; dispatched to `OutlierDetection::z_score_method`.
ZScore,
/// Interquartile-range test; dispatched to `OutlierDetection::iqr_method`.
IQR,
/// Modified Z-score. No detection routine is wired up for it in `detect_anomalies`.
ModifiedZScore,
/// Isolation forest. No detection routine is wired up for it in `detect_anomalies`.
IsolationForest,
/// Presumably "Local Outlier Factor" (TODO confirm). No detection routine is wired
/// up for it in `detect_anomalies`.
LOF,
}
/// Outcome of a single detection run, as returned by `AnomalyDetector::detect_anomalies`.
#[derive(Debug, Clone)]
pub struct AnomalyResult {
/// Number of anomalies reported; `detect_anomalies` sets it to `anomalies.len()`.
pub anomalies_found: usize,
/// One entry per detected anomaly.
pub anomalies: Vec<AnomalyInfo>,
/// Aggregate score for the run; `detect_anomalies` sets `0.0` when nothing was
/// found and `0.8` otherwise.
pub overall_score: f64,
/// Methods that were actually executed during the run.
pub methods_used: Vec<StatisticalMethod>,
}
/// Description of one detected anomaly.
#[derive(Debug, Clone)]
pub struct AnomalyInfo {
/// Category of the anomaly.
pub anomaly_type: AnomalyType,
/// Confidence attached to the finding; `AnomalyDetector::detect_anomalies`
/// currently fills in the constant `0.9`.
pub confidence: f64,
/// Score attached to the finding; `AnomalyDetector::detect_anomalies`
/// currently fills in the constant `0.95`.
pub score: f64,
/// Position of the anomaly, if known; `detect_anomalies` stores a one-element
/// vector holding the index returned by the outlier routine.
pub location: Option<Vec<usize>>,
/// Human-readable explanation of the finding.
pub description: String,
}
/// Category assigned to a detected anomaly.
///
/// Derives `Hash` (in addition to `Eq`) so it can be used as a `HashMap` key,
/// which `AnomalyStatistics::anomalies_by_type` requires for any insert or lookup.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AnomalyType {
    /// Value flagged by a statistical outlier test.
    StatisticalOutlier,
    /// Extreme-value anomaly.
    ExtremeValue,
    /// Pattern anomaly.
    PatternAnomaly,
    /// Contextual anomaly.
    ContextualAnomaly,
}
/// Namespace for stateless outlier-detection routines over `f64` slices.
pub struct OutlierDetection;

impl OutlierDetection {
    /// Returns the indices of values whose absolute z-score exceeds `threshold`.
    ///
    /// The mean and (population) standard deviation are computed over the finite
    /// values only, so a stray NaN or infinity no longer poisons the statistics
    /// and suppresses every result. Indices always refer to positions in `data`.
    ///
    /// * NaN entries are never reported (their z-score is NaN).
    /// * Infinite entries have an infinite z-score and are reported for any
    ///   finite `threshold`.
    ///
    /// Returns an empty vector when fewer than two finite values are present or
    /// when the finite values have zero spread.
    pub fn z_score_method(data: &[f64], threshold: f64) -> Vec<usize> {
        let (finite_count, finite_sum) = data
            .iter()
            .filter(|v| v.is_finite())
            .fold((0usize, 0.0f64), |(count, sum), &v| (count + 1, sum + v));
        if finite_count < 2 {
            return Vec::new();
        }

        let n = finite_count as f64;
        let mean = finite_sum / n;
        let variance = data
            .iter()
            .filter(|v| v.is_finite())
            .map(|&v| (v - mean).powi(2))
            .sum::<f64>()
            / n;
        let std_dev = variance.sqrt();
        // Zero spread: every finite value equals the mean, so z-scores are undefined.
        if std_dev == 0.0 {
            return Vec::new();
        }

        data.iter()
            .enumerate()
            .filter(|&(_, &value)| (value - mean).abs() / std_dev > threshold)
            .map(|(index, _)| index)
            .collect()
    }

    /// Returns the indices of values lying outside the 1.5 x IQR fences.
    ///
    /// Quartiles are taken as the elements at positions `len / 4` and
    /// `3 * len / 4` of the sorted non-NaN values. NaN entries are excluded from
    /// the quartile computation (previously they caused a panic while sorting)
    /// and are never reported as outliers. Indices always refer to positions in
    /// `data`.
    ///
    /// Returns an empty vector when fewer than four non-NaN values are present.
    pub fn iqr_method(data: &[f64]) -> Vec<usize> {
        let mut sorted: Vec<f64> = data.iter().copied().filter(|v| !v.is_nan()).collect();
        if sorted.len() < 4 {
            return Vec::new();
        }
        // `total_cmp` is a total order on f64, so sorting cannot panic.
        sorted.sort_unstable_by(f64::total_cmp);

        let q1 = sorted[sorted.len() / 4];
        let q3 = sorted[3 * sorted.len() / 4];
        let iqr = q3 - q1;
        let lower_bound = q1 - 1.5 * iqr;
        let upper_bound = q3 + 1.5 * iqr;

        data.iter()
            .enumerate()
            .filter(|&(_, &value)| value < lower_bound || value > upper_bound)
            .map(|(index, _)| index)
            .collect()
    }
}
/// Running counters kept by an `AnomalyDetector` across detection runs.
#[derive(Debug, Default)]
pub struct AnomalyStatistics {
/// Number of detection runs; incremented once per `detect_anomalies` call.
pub total_detections: usize,
/// Cumulative number of anomalies reported over all runs.
pub total_anomalies: usize,
/// Per-type anomaly counts.
/// NOTE(review): nothing in the code visible in this file section writes to this
/// map, so it appears to stay empty — confirm whether it is populated elsewhere.
pub anomalies_by_type: HashMap<AnomalyType, usize>,
}
impl AnomalyDetector {
    /// Creates a detector with the given configuration and zeroed statistics.
    pub fn new(config: AnomalyConfiguration) -> Self {
        Self {
            config,
            stats: AnomalyStatistics::default(),
        }
    }

    /// Runs every configured statistical method and collects the reported outliers.
    ///
    /// NOTE(review): this is still a placeholder implementation — `_tensor` is
    /// not inspected and the methods run on a fixed sample instead. The z-score
    /// threshold (`2.0`) and the per-anomaly `confidence`/`score` values are
    /// likewise hard-coded rather than derived from the configuration.
    ///
    /// Only `StatisticalMethod::ZScore` and `StatisticalMethod::IQR` are
    /// executed; other configured methods are skipped and do not appear in
    /// `methods_used`.
    ///
    /// Each call increments `total_detections` by one and adds the number of
    /// reported anomalies to `total_anomalies`.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn detect_anomalies<T>(
        &mut self,
        _tensor: &crate::tensor::Tensor<T>,
    ) -> RusTorchResult<AnomalyResult>
    where
        T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
    {
        // Fixed sample standing in for the tensor contents (see NOTE above).
        let dummy_data = [1.0, 2.0, 3.0, 100.0, 4.0, 5.0];
        let mut all_anomalies = Vec::new();
        let mut methods_used = Vec::new();

        for method in &self.config.methods {
            // Exhaustive on purpose: adding a `StatisticalMethod` variant must
            // force a decision here instead of being silently ignored.
            let outliers = match method {
                StatisticalMethod::ZScore => OutlierDetection::z_score_method(&dummy_data, 2.0),
                StatisticalMethod::IQR => OutlierDetection::iqr_method(&dummy_data),
                StatisticalMethod::ModifiedZScore
                | StatisticalMethod::IsolationForest
                | StatisticalMethod::LOF => continue,
            };
            methods_used.push(method.clone());

            all_anomalies.extend(outliers.into_iter().map(|outlier_idx| AnomalyInfo {
                anomaly_type: AnomalyType::StatisticalOutlier,
                confidence: 0.9,
                score: 0.95,
                location: Some(vec![outlier_idx]),
                description: format!("Statistical outlier detected using {:?}", method),
            }));
        }

        let anomalies_found = all_anomalies.len();
        self.stats.total_detections += 1;
        self.stats.total_anomalies += anomalies_found;

        let overall_score = if anomalies_found == 0 { 0.0 } else { 0.8 };

        // The vector is moved into the result; no copy is needed.
        Ok(AnomalyResult {
            anomalies_found,
            anomalies: all_anomalies,
            overall_score,
            methods_used,
        })
    }

    /// Returns the cumulative number of anomalies reported by this detector.
    pub fn get_anomaly_count(&self) -> usize {
        self.stats.total_anomalies
    }
}