mockforge_chaos/
ml_anomaly_detector.rs

1//! ML-based anomaly detection for orchestration patterns
2//!
//! Detects anomalies in execution metrics using statistical methods: z-score
//! outlier detection against a per-metric baseline, moving-average trend
//! analysis, and collective (cross-metric) anomaly detection.
5
6use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Metric baseline for anomaly detection
///
/// Summary statistics computed over every stored sample of one metric. The
/// `mean` and `std_dev` are what value-level outlier detection scores new
/// observations against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricBaseline {
    /// Name of the metric these statistics describe.
    pub metric_name: String,
    /// Arithmetic mean of the samples.
    pub mean: f64,
    /// Population standard deviation of the samples (divides by the sample count).
    pub std_dev: f64,
    /// Smallest sample.
    pub min: f64,
    /// Largest sample.
    pub max: f64,
    /// Median of the samples.
    pub median: f64,
    /// 95th-percentile sample.
    pub p95: f64,
    /// 99th-percentile sample.
    pub p99: f64,
    /// Number of samples the statistics were computed from.
    pub sample_count: usize,
    /// When this baseline was computed.
    pub last_updated: DateTime<Utc>,
}
24
/// Detected anomaly
///
/// Produced by the detection methods of `AnomalyDetector`. How
/// `observed_value`, `expected_range` and `deviation_score` should be read
/// depends on `anomaly_type` (see the field docs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    /// Identifier built from a type prefix, the metric name (for single-metric
    /// anomalies) and a millisecond timestamp.
    /// NOTE(review): not guaranteed unique; two anomalies of the same kind for
    /// the same metric in the same millisecond share an id.
    pub id: String,
    /// Metric the anomaly was detected on, or `"multiple"` for collective anomalies.
    pub metric_name: String,
    /// Value that triggered detection: the raw sample for statistical outliers,
    /// the current moving average for trend anomalies, and the number of
    /// affected metrics for collective anomalies.
    pub observed_value: f64,
    /// `(low, high)` range the detector reports as normal.
    pub expected_range: (f64, f64),
    /// Detector-specific score: the absolute z-score for statistical outliers
    /// (infinite when the baseline has zero spread), the relative change of the
    /// moving average for trend anomalies, and the affected-metric count divided
    /// by the number of queried metrics for collective anomalies.
    pub deviation_score: f64,
    /// How far outside normal behaviour the observation is.
    pub severity: AnomalySeverity,
    /// Which detection method produced this anomaly.
    pub anomaly_type: AnomalyType,
    /// When the anomaly occurred or was detected.
    pub timestamp: DateTime<Utc>,
    /// Free-form diagnostic key/value pairs.
    pub context: HashMap<String, String>,
}
38
/// Anomaly severity level
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// gives `Low < Medium < High < Critical`. Serialized in lowercase
/// (`"low"`, `"medium"`, `"high"`, `"critical"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "lowercase")]
pub enum AnomalySeverity {
    Low,
    Medium,
    High,
    Critical,
}
48
/// Type of anomaly detected
///
/// Serialized in snake_case (e.g. `"statistical_outlier"`).
/// NOTE(review): `AnomalyDetector` in this module only produces
/// `StatisticalOutlier`, `TrendAnomaly` and `CollectiveAnomaly`; the seasonal
/// and contextual variants are currently never emitted here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AnomalyType {
    /// Value outside normal statistical bounds.
    StatisticalOutlier,
    /// Unexpected trend change.
    TrendAnomaly,
    /// Deviation from seasonal pattern.
    SeasonalAnomaly,
    /// Unusual given context.
    ContextualAnomaly,
    /// Pattern across multiple metrics.
    CollectiveAnomaly,
}
59
/// Time-series data point
///
/// One timestamped observation of a metric, as stored by
/// `AnomalyDetector::add_data_point`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeSeriesPoint {
    /// When the observation was taken.
    pub timestamp: DateTime<Utc>,
    /// Observed metric value.
    pub value: f64,
    /// Arbitrary key/value annotations. Time-series outlier detection passes
    /// them through as the resulting anomaly's `context`.
    pub metadata: HashMap<String, String>,
}
67
/// Anomaly detection configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorConfig {
    /// Number of standard deviations for outlier detection. The effective
    /// z-score cutoff is this value divided by `sensitivity`.
    pub std_dev_threshold: f64,
    /// Minimum number of stored samples `update_baseline` requires before it
    /// computes a baseline.
    pub min_baseline_samples: usize,
    /// Window size for the moving average used by trend detection (in data
    /// points). Trend detection needs at least twice this many recent points.
    pub moving_average_window: usize,
    /// Enable seasonal decomposition.
    /// NOTE(review): not read anywhere in this module; seasonal detection
    /// appears to be unimplemented.
    pub enable_seasonal: bool,
    /// Seasonal period (in data points).
    /// NOTE(review): not read anywhere in this module.
    pub seasonal_period: usize,
    /// Sensitivity (0.0 - 1.0, higher = more sensitive). Both the z-score
    /// threshold and the trend-change threshold are divided by this value, so
    /// it should be strictly positive: 0 makes the thresholds infinite and
    /// disables detection.
    pub sensitivity: f64,
}
84
85impl Default for AnomalyDetectorConfig {
86    fn default() -> Self {
87        Self {
88            std_dev_threshold: 3.0,
89            min_baseline_samples: 30,
90            moving_average_window: 10,
91            enable_seasonal: false,
92            seasonal_period: 24, // e.g., 24 hours for hourly data
93            sensitivity: 0.7,
94        }
95    }
96}
97
/// Anomaly detector
///
/// Stores raw time-series points and per-metric baselines, both keyed by
/// metric name. It runs statistical-outlier, trend and collective anomaly
/// detection over them.
pub struct AnomalyDetector {
    /// Detection thresholds and tuning.
    config: AnomalyDetectorConfig,
    /// Most recently computed baseline per metric (written by `update_baseline`).
    baselines: HashMap<String, MetricBaseline>,
    /// Every point added per metric, in insertion order. Only `clear_data`
    /// removes them.
    time_series_data: HashMap<String, Vec<TimeSeriesPoint>>,
}
104
105impl AnomalyDetector {
106    /// Create a new anomaly detector
107    pub fn new(config: AnomalyDetectorConfig) -> Self {
108        Self {
109            config,
110            baselines: HashMap::new(),
111            time_series_data: HashMap::new(),
112        }
113    }
114
115    /// Add time-series data point
116    pub fn add_data_point(&mut self, metric_name: String, point: TimeSeriesPoint) {
117        self.time_series_data.entry(metric_name).or_default().push(point);
118    }
119
120    /// Update baseline for a metric
121    pub fn update_baseline(&mut self, metric_name: &str) -> Result<MetricBaseline, String> {
122        let data = self
123            .time_series_data
124            .get(metric_name)
125            .ok_or_else(|| format!("No data for metric '{}'", metric_name))?;
126
127        if data.len() < self.config.min_baseline_samples {
128            return Err(format!(
129                "Insufficient data for baseline: need {}, have {}",
130                self.config.min_baseline_samples,
131                data.len()
132            ));
133        }
134
135        let values: Vec<f64> = data.iter().map(|p| p.value).collect();
136        let baseline = Self::calculate_baseline(metric_name, &values);
137
138        self.baselines.insert(metric_name.to_string(), baseline.clone());
139
140        Ok(baseline)
141    }
142
143    /// Calculate baseline from values
144    fn calculate_baseline(metric_name: &str, values: &[f64]) -> MetricBaseline {
145        let mut sorted = values.to_vec();
146        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
147
148        let sum: f64 = sorted.iter().sum();
149        let mean = sum / sorted.len() as f64;
150
151        let variance: f64 =
152            sorted.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / sorted.len() as f64;
153        let std_dev = variance.sqrt();
154
155        let median = sorted[sorted.len() / 2];
156        let min = sorted[0];
157        let max = sorted[sorted.len() - 1];
158
159        let p95_idx = ((sorted.len() as f64) * 0.95) as usize;
160        let p99_idx = ((sorted.len() as f64) * 0.99) as usize;
161        let p95 = sorted[p95_idx.min(sorted.len() - 1)];
162        let p99 = sorted[p99_idx.min(sorted.len() - 1)];
163
164        MetricBaseline {
165            metric_name: metric_name.to_string(),
166            mean,
167            std_dev,
168            min,
169            max,
170            median,
171            p95,
172            p99,
173            sample_count: values.len(),
174            last_updated: Utc::now(),
175        }
176    }
177
178    /// Detect anomalies in a single value
179    pub fn detect_value_anomaly(
180        &self,
181        metric_name: &str,
182        value: f64,
183        context: HashMap<String, String>,
184    ) -> Option<Anomaly> {
185        let baseline = self.baselines.get(metric_name)?;
186
187        // Statistical outlier detection using z-score
188        let z_score = if baseline.std_dev > 0.0 {
189            ((value - baseline.mean) / baseline.std_dev).abs()
190        } else {
191            // When std_dev is 0, all baseline values are identical.
192            // If the new value differs from the mean, it's definitely an anomaly.
193            if (value - baseline.mean).abs() > f64::EPSILON {
194                f64::INFINITY
195            } else {
196                0.0
197            }
198        };
199
200        let threshold = self.config.std_dev_threshold * (1.0 / self.config.sensitivity);
201
202        if z_score > threshold {
203            let severity = if z_score > threshold * 2.0 {
204                AnomalySeverity::Critical
205            } else if z_score > threshold * 1.5 {
206                AnomalySeverity::High
207            } else if z_score > threshold * 1.2 {
208                AnomalySeverity::Medium
209            } else {
210                AnomalySeverity::Low
211            };
212
213            let expected_range = (
214                baseline.mean - baseline.std_dev * self.config.std_dev_threshold,
215                baseline.mean + baseline.std_dev * self.config.std_dev_threshold,
216            );
217
218            Some(Anomaly {
219                id: format!("anomaly_{}_{}", metric_name, Utc::now().timestamp_millis()),
220                metric_name: metric_name.to_string(),
221                observed_value: value,
222                expected_range,
223                deviation_score: z_score,
224                severity,
225                anomaly_type: AnomalyType::StatisticalOutlier,
226                timestamp: Utc::now(),
227                context,
228            })
229        } else {
230            None
231        }
232    }
233
234    /// Detect anomalies in time series using multiple methods
235    pub fn detect_timeseries_anomalies(
236        &self,
237        metric_name: &str,
238        lookback_hours: i64,
239    ) -> Result<Vec<Anomaly>, String> {
240        let data = self
241            .time_series_data
242            .get(metric_name)
243            .ok_or_else(|| format!("No data for metric '{}'", metric_name))?;
244
245        let cutoff = Utc::now() - Duration::hours(lookback_hours);
246        let recent_data: Vec<_> = data.iter().filter(|p| p.timestamp > cutoff).collect();
247
248        if recent_data.is_empty() {
249            return Ok(Vec::new());
250        }
251
252        let mut anomalies = Vec::new();
253
254        // 1. Statistical outliers
255        for point in &recent_data {
256            if let Some(anomaly) =
257                self.detect_value_anomaly(metric_name, point.value, point.metadata.clone())
258            {
259                anomalies.push(anomaly);
260            }
261        }
262
263        // 2. Trend anomalies (sudden changes in moving average)
264        if recent_data.len() >= self.config.moving_average_window * 2 {
265            let trend_anomalies = self.detect_trend_anomalies(metric_name, &recent_data)?;
266            anomalies.extend(trend_anomalies);
267        }
268
269        Ok(anomalies)
270    }
271
272    /// Detect trend anomalies using moving averages
273    fn detect_trend_anomalies(
274        &self,
275        metric_name: &str,
276        data: &[&TimeSeriesPoint],
277    ) -> Result<Vec<Anomaly>, String> {
278        let window = self.config.moving_average_window;
279        let mut anomalies = Vec::new();
280
281        if data.len() < window * 2 {
282            return Ok(anomalies);
283        }
284
285        // Calculate moving averages
286        let values: Vec<f64> = data.iter().map(|p| p.value).collect();
287        let moving_avgs = Self::calculate_moving_average(&values, window);
288
289        // Look for sudden changes in moving average
290        for i in window..moving_avgs.len() {
291            let prev_avg = moving_avgs[i - window];
292            let curr_avg = moving_avgs[i];
293
294            if prev_avg == 0.0 {
295                continue;
296            }
297
298            let change_pct = ((curr_avg - prev_avg) / prev_avg).abs();
299
300            // Detect if change exceeds threshold
301            let threshold = 0.3 / self.config.sensitivity; // 30% change baseline
302
303            if change_pct > threshold {
304                let severity = if change_pct > threshold * 2.0 {
305                    AnomalySeverity::High
306                } else if change_pct > threshold * 1.5 {
307                    AnomalySeverity::Medium
308                } else {
309                    AnomalySeverity::Low
310                };
311
312                let mut context = HashMap::new();
313                context.insert("previous_avg".to_string(), format!("{:.2}", prev_avg));
314                context.insert("current_avg".to_string(), format!("{:.2}", curr_avg));
315                context.insert("change_pct".to_string(), format!("{:.1}%", change_pct * 100.0));
316
317                anomalies.push(Anomaly {
318                    id: format!(
319                        "trend_anomaly_{}_{}",
320                        metric_name,
321                        data[i].timestamp.timestamp_millis()
322                    ),
323                    metric_name: metric_name.to_string(),
324                    observed_value: curr_avg,
325                    expected_range: (prev_avg * 0.8, prev_avg * 1.2),
326                    deviation_score: change_pct,
327                    severity,
328                    anomaly_type: AnomalyType::TrendAnomaly,
329                    timestamp: data[i].timestamp,
330                    context,
331                });
332            }
333        }
334
335        Ok(anomalies)
336    }
337
338    /// Calculate moving average
339    fn calculate_moving_average(values: &[f64], window: usize) -> Vec<f64> {
340        let mut moving_avgs = Vec::new();
341
342        for i in 0..values.len() {
343            let start = if i >= window { i - window + 1 } else { 0 };
344            let end = i + 1;
345            let window_values = &values[start..end];
346            let avg = window_values.iter().sum::<f64>() / window_values.len() as f64;
347            moving_avgs.push(avg);
348        }
349
350        moving_avgs
351    }
352
353    /// Detect collective anomalies (patterns across multiple metrics)
354    pub fn detect_collective_anomalies(
355        &self,
356        metric_names: &[String],
357        lookback_hours: i64,
358    ) -> Result<Vec<Anomaly>, String> {
359        let mut anomalies = Vec::new();
360
361        // Check if multiple metrics are anomalous at the same time
362        let cutoff = Utc::now() - Duration::hours(lookback_hours);
363
364        let mut anomaly_counts: HashMap<DateTime<Utc>, usize> = HashMap::new();
365        let mut anomalous_metrics: HashMap<DateTime<Utc>, Vec<String>> = HashMap::new();
366
367        for metric_name in metric_names {
368            if let Some(data) = self.time_series_data.get(metric_name) {
369                for point in data.iter().filter(|p| p.timestamp > cutoff) {
370                    if self.detect_value_anomaly(metric_name, point.value, HashMap::new()).is_some()
371                    {
372                        // Round to nearest minute for grouping
373                        let timestamp_rounded =
374                            point.timestamp - Duration::seconds(point.timestamp.timestamp() % 60);
375
376                        *anomaly_counts.entry(timestamp_rounded).or_insert(0) += 1;
377                        anomalous_metrics
378                            .entry(timestamp_rounded)
379                            .or_default()
380                            .push(metric_name.clone());
381                    }
382                }
383            }
384        }
385
386        // If multiple metrics are anomalous at the same time, it's a collective anomaly
387        for (timestamp, count) in anomaly_counts {
388            if count >= 2 {
389                let metrics = &anomalous_metrics[&timestamp];
390                let mut context = HashMap::new();
391                context.insert("affected_metrics".to_string(), metrics.join(", "));
392                context.insert("metric_count".to_string(), count.to_string());
393
394                let severity = if count >= metric_names.len() {
395                    AnomalySeverity::Critical
396                } else if count >= metric_names.len() / 2 {
397                    AnomalySeverity::High
398                } else {
399                    AnomalySeverity::Medium
400                };
401
402                anomalies.push(Anomaly {
403                    id: format!("collective_anomaly_{}", timestamp.timestamp_millis()),
404                    metric_name: "multiple".to_string(),
405                    observed_value: count as f64,
406                    expected_range: (0.0, 1.0),
407                    deviation_score: count as f64 / metric_names.len() as f64,
408                    severity,
409                    anomaly_type: AnomalyType::CollectiveAnomaly,
410                    timestamp,
411                    context,
412                });
413            }
414        }
415
416        Ok(anomalies)
417    }
418
419    /// Get baseline for a metric
420    pub fn get_baseline(&self, metric_name: &str) -> Option<&MetricBaseline> {
421        self.baselines.get(metric_name)
422    }
423
424    /// Get all baselines
425    pub fn get_all_baselines(&self) -> Vec<MetricBaseline> {
426        self.baselines.values().cloned().collect()
427    }
428
429    /// Clear all data
430    pub fn clear_data(&mut self) {
431        self.time_series_data.clear();
432        self.baselines.clear();
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Metric name shared by every test.
    const METRIC: &str = "test_metric";

    /// Build a point with no metadata.
    fn sample(timestamp: DateTime<Utc>, value: f64) -> TimeSeriesPoint {
        TimeSeriesPoint {
            timestamp,
            value,
            metadata: HashMap::new(),
        }
    }

    /// Record `values` for `METRIC`, one minute apart, starting now.
    fn feed<I: IntoIterator<Item = f64>>(detector: &mut AnomalyDetector, values: I) {
        let start = Utc::now();
        for (offset, value) in (0i64..).zip(values) {
            detector.add_data_point(
                METRIC.to_string(),
                sample(start + Duration::minutes(offset), value),
            );
        }
    }

    #[test]
    fn test_detector_creation() {
        let detector = AnomalyDetector::new(AnomalyDetectorConfig::default());
        assert!(detector.get_all_baselines().is_empty());
    }

    #[test]
    fn test_baseline_creation() {
        let mut detector = AnomalyDetector::new(AnomalyDetectorConfig {
            min_baseline_samples: 10,
            ..Default::default()
        });
        feed(&mut detector, (0..15).map(|i| 100.0 + f64::from(i)));

        let baseline = detector
            .update_baseline(METRIC)
            .expect("15 samples satisfy a minimum of 10");
        assert_eq!(baseline.sample_count, 15);
        assert!(baseline.mean > 0.0);
    }

    #[test]
    fn test_outlier_detection() {
        let mut detector = AnomalyDetector::new(AnomalyDetectorConfig {
            min_baseline_samples: 10,
            std_dev_threshold: 2.0,
            ..Default::default()
        });
        feed(&mut detector, (0..20).map(|_| 100.0));
        detector.update_baseline(METRIC).unwrap();

        // A value equal to a constant baseline is normal...
        assert!(detector
            .detect_value_anomaly(METRIC, 100.0, HashMap::new())
            .is_none());
        // ...while any departure from it is flagged.
        assert!(detector
            .detect_value_anomaly(METRIC, 200.0, HashMap::new())
            .is_some());
    }

    #[test]
    fn test_moving_average() {
        let averages = AnomalyDetector::calculate_moving_average(&[1.0, 2.0, 3.0, 4.0, 5.0], 3);

        assert_eq!(averages.len(), 5);
        // First full window: (1+2+3)/3 = 2.
        assert!((averages[2] - 2.0).abs() < 0.01);
        // Last full window: (3+4+5)/3 = 4.
        assert!((averages[4] - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_insufficient_baseline_data() {
        let mut detector = AnomalyDetector::new(AnomalyDetectorConfig {
            min_baseline_samples: 20,
            ..Default::default()
        });
        feed(&mut detector, (0..10).map(|_| 100.0));

        assert!(detector.update_baseline(METRIC).is_err());
    }
}