Skip to main content

lean_ctx/core/
anomaly.rs

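//! Anomaly detection for numeric metrics. Each metric tracks a running
//! mean/variance (Welford's online algorithm) plus a sliding window of recent
//! samples; a new sample raises an alert when it lies more than
//! `deviation_threshold` (default 3.0) windowed standard deviations away from
//! the windowed mean.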
3
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::{Mutex, OnceLock};
7
/// Number of most-recent samples kept per metric for the windowed mean/std-dev.
const DEFAULT_WINDOW: usize = 50;
/// Default alert threshold, in windowed standard deviations from the windowed mean.
const DEFAULT_DEVIATION_THRESHOLD: f64 = 3.0;
/// Minimum number of recorded samples before a metric is allowed to raise alerts.
const MIN_SAMPLES: usize = 10;
11
12// ---------------------------------------------------------------------------
13// Welford online statistics
14// ---------------------------------------------------------------------------
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct WelfordState {
18    pub count: u64,
19    pub mean: f64,
20    pub m2: f64,
21    #[serde(default = "default_window")]
22    window_values: Vec<f64>,
23    #[serde(default = "default_window_size")]
24    window_size: usize,
25}
26
27fn default_window() -> Vec<f64> {
28    Vec::new()
29}
30
31fn default_window_size() -> usize {
32    DEFAULT_WINDOW
33}
34
35impl Default for WelfordState {
36    fn default() -> Self {
37        Self {
38            count: 0,
39            mean: 0.0,
40            m2: 0.0,
41            window_values: Vec::new(),
42            window_size: DEFAULT_WINDOW,
43        }
44    }
45}
46
47impl WelfordState {
48    pub fn with_window(size: usize) -> Self {
49        Self {
50            window_size: size,
51            ..Default::default()
52        }
53    }
54
55    pub fn update(&mut self, value: f64) {
56        self.count += 1;
57        let delta = value - self.mean;
58        self.mean += delta / self.count as f64;
59        let delta2 = value - self.mean;
60        self.m2 += delta * delta2;
61
62        self.window_values.push(value);
63        if self.window_values.len() > self.window_size {
64            self.window_values.remove(0);
65        }
66    }
67
68    pub fn variance(&self) -> f64 {
69        if self.count < 2 {
70            return 0.0;
71        }
72        self.m2 / (self.count - 1) as f64
73    }
74
75    pub fn std_dev(&self) -> f64 {
76        self.variance().sqrt()
77    }
78
79    pub fn windowed_mean(&self) -> f64 {
80        if self.window_values.is_empty() {
81            return self.mean;
82        }
83        let sum: f64 = self.window_values.iter().sum();
84        sum / self.window_values.len() as f64
85    }
86
87    pub fn windowed_std_dev(&self) -> f64 {
88        if self.window_values.len() < 2 {
89            return self.std_dev();
90        }
91        let mean = self.windowed_mean();
92        let variance: f64 = self
93            .window_values
94            .iter()
95            .map(|v| (v - mean).powi(2))
96            .sum::<f64>()
97            / (self.window_values.len() - 1) as f64;
98        variance.sqrt()
99    }
100
101    pub fn has_enough_data(&self) -> bool {
102        self.count as usize >= MIN_SAMPLES
103    }
104}
105
106// ---------------------------------------------------------------------------
107// Anomaly detector
108// ---------------------------------------------------------------------------
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct AnomalyDetector {
112    pub metrics: HashMap<String, WelfordState>,
113    #[serde(default = "default_threshold")]
114    pub deviation_threshold: f64,
115}
116
117fn default_threshold() -> f64 {
118    DEFAULT_DEVIATION_THRESHOLD
119}
120
121impl Default for AnomalyDetector {
122    fn default() -> Self {
123        Self {
124            metrics: HashMap::new(),
125            deviation_threshold: DEFAULT_DEVIATION_THRESHOLD,
126        }
127    }
128}
129
/// Alert returned by `AnomalyDetector::record` when a sample deviates from
/// the metric's windowed mean by more than the detector's threshold.
#[derive(Debug, Clone, Serialize)]
pub struct AnomalyAlert {
    /// Name of the metric that triggered the alert.
    pub metric: String,
    /// Windowed mean computed before the offending sample was recorded.
    pub expected: f64,
    /// The offending sample.
    pub actual: f64,
    /// Windowed standard deviation used for the comparison.
    pub std_dev: f64,
    /// `|actual - expected| / std_dev`.
    pub deviation_factor: f64,
}
138
139impl AnomalyDetector {
140    pub fn record(&mut self, metric: &str, value: f64) -> Option<AnomalyAlert> {
141        let state = self
142            .metrics
143            .entry(metric.to_string())
144            .or_insert_with(|| WelfordState::with_window(DEFAULT_WINDOW));
145
146        let alert = if state.has_enough_data() {
147            let expected = state.windowed_mean();
148            let sd = state.windowed_std_dev();
149
150            if sd > 0.0 {
151                let deviation = (value - expected).abs() / sd;
152                if deviation > self.deviation_threshold {
153                    Some(AnomalyAlert {
154                        metric: metric.to_string(),
155                        expected,
156                        actual: value,
157                        std_dev: sd,
158                        deviation_factor: deviation,
159                    })
160                } else {
161                    None
162                }
163            } else {
164                None
165            }
166        } else {
167            None
168        };
169
170        state.update(value);
171        alert
172    }
173
174    pub fn summary(&self) -> Vec<MetricSummary> {
175        let mut out: Vec<MetricSummary> = self
176            .metrics
177            .iter()
178            .map(|(name, state)| MetricSummary {
179                metric: name.clone(),
180                count: state.count,
181                mean: state.windowed_mean(),
182                std_dev: state.windowed_std_dev(),
183                last_value: state.window_values.last().copied().unwrap_or(0.0),
184            })
185            .collect();
186        out.sort_by_key(|s| s.metric.clone());
187        out
188    }
189
190    pub fn save(&self) {
191        if let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() {
192            let path = dir.join("anomaly_detector.json");
193            if let Ok(json) = serde_json::to_string(self) {
194                let _ = std::fs::write(path, json);
195            }
196        }
197    }
198
199    pub fn load() -> Self {
200        crate::core::data_dir::lean_ctx_data_dir()
201            .ok()
202            .map(|d| d.join("anomaly_detector.json"))
203            .and_then(|p| std::fs::read_to_string(p).ok())
204            .and_then(|s| serde_json::from_str(&s).ok())
205            .unwrap_or_default()
206    }
207}
208
/// Snapshot of one metric's statistics, as produced by `AnomalyDetector::summary`.
#[derive(Debug, Clone, Serialize)]
pub struct MetricSummary {
    /// Metric name.
    pub metric: String,
    /// Total number of samples recorded for this metric.
    pub count: u64,
    /// Windowed mean (all-time mean when the window is empty).
    pub mean: f64,
    /// Windowed sample standard deviation.
    pub std_dev: f64,
    /// Most recent sample in the window, or `0.0` when the window is empty.
    pub last_value: f64,
}
217
218// ---------------------------------------------------------------------------
219// Global singleton
220// ---------------------------------------------------------------------------
221
222static DETECTOR: OnceLock<Mutex<AnomalyDetector>> = OnceLock::new();
223
224fn global_detector() -> &'static Mutex<AnomalyDetector> {
225    DETECTOR.get_or_init(|| Mutex::new(AnomalyDetector::load()))
226}
227
228pub fn record_metric(metric: &str, value: f64) -> Option<AnomalyAlert> {
229    let mut det = global_detector()
230        .lock()
231        .unwrap_or_else(std::sync::PoisonError::into_inner);
232    let alert = det.record(metric, value);
233
234    if let Some(ref a) = alert {
235        crate::core::events::emit_anomaly(&a.metric, a.expected, a.actual, a.deviation_factor);
236    }
237
238    alert
239}
240
241pub fn summary() -> Vec<MetricSummary> {
242    global_detector()
243        .lock()
244        .map(|d| d.summary())
245        .unwrap_or_default()
246}
247
248pub fn save() {
249    if let Ok(d) = global_detector().lock() {
250        d.save();
251    }
252}
253
254/// Debounced save: skips if less than 3s since last save.
255/// Use in hot paths (per-tool-call) to avoid excessive I/O.
256pub fn save_debounced() {
257    use std::sync::atomic::{AtomicU64, Ordering};
258    use std::time::{SystemTime, UNIX_EPOCH};
259
260    static LAST_SAVE_MS: AtomicU64 = AtomicU64::new(0);
261    let now_ms = SystemTime::now()
262        .duration_since(UNIX_EPOCH)
263        .map_or(0, |d| d.as_millis() as u64);
264    let prev = LAST_SAVE_MS.load(Ordering::Relaxed);
265    if prev != 0 && now_ms.saturating_sub(prev) < 3000 {
266        return;
267    }
268    LAST_SAVE_MS.store(now_ms, Ordering::Relaxed);
269    save();
270}
271
272// ---------------------------------------------------------------------------
273// Tests
274// ---------------------------------------------------------------------------
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn welford_basic_stats() {
        let mut w = WelfordState::default();
        for v in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            w.update(v);
        }
        assert!((w.mean - 5.0).abs() < 0.01);
        // Sample variance (n-1): 32/7 ≈ 4.571
        assert!((w.variance() - 4.571).abs() < 0.01);
        assert!((w.std_dev() - 2.138).abs() < 0.01);
    }

    #[test]
    fn welford_window_limits() {
        let mut w = WelfordState::with_window(5);
        for i in 0..20 {
            w.update(i as f64);
        }
        assert_eq!(w.window_values.len(), 5);
        assert_eq!(w.window_values[0], 15.0);
    }

    #[test]
    fn windowed_stats_track_only_recent_samples() {
        let mut w = WelfordState::with_window(5);
        for i in 0..20 {
            w.update(i as f64);
        }
        // All-time stats cover 0..=19, windowed stats only 15..=19.
        assert_eq!(w.count, 20);
        assert!((w.mean - 9.5).abs() < 1e-9);
        assert!((w.windowed_mean() - 17.0).abs() < 1e-9);
        // Sample variance of 15..=19 is 10 / 4 = 2.5.
        assert!((w.windowed_std_dev() - 2.5f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn enough_data_only_after_min_samples() {
        let mut w = WelfordState::default();
        for i in 0..MIN_SAMPLES {
            assert!(!w.has_enough_data());
            w.update(i as f64);
        }
        assert!(w.has_enough_data());
    }

    #[test]
    fn no_alert_with_few_samples() {
        let mut det = AnomalyDetector::default();
        for i in 0..5 {
            assert!(det.record("test", i as f64).is_none());
        }
    }

    #[test]
    fn alert_on_extreme_value() {
        let mut det = AnomalyDetector::default();
        for i in 0..20 {
            let v = 100.0 + (i % 5) as f64;
            det.record("tokens", v);
        }
        let alert = det.record("tokens", 1000.0);
        assert!(alert.is_some());
        let a = alert.unwrap();
        assert_eq!(a.metric, "tokens");
        assert_eq!(a.actual, 1000.0);
        // Window holds 100..=104 four times each, so the expected value is 102.
        assert!((a.expected - 102.0).abs() < 1e-9);
        assert!(a.deviation_factor > 3.0);
    }

    #[test]
    fn no_alert_on_normal_value() {
        let mut det = AnomalyDetector::default();
        for i in 0..20 {
            let v = 100.0 + (i % 3) as f64;
            assert!(det.record("tokens", v).is_none());
        }
    }

    #[test]
    fn constant_series_does_not_alert() {
        let mut det = AnomalyDetector::default();
        for _ in 0..20 {
            assert!(det.record("flat", 7.0).is_none());
        }
        // Zero windowed std-dev: no z-score is defined, so no alert is raised.
        assert!(det.record("flat", 1000.0).is_none());
    }

    #[test]
    fn summary_returns_all_metrics() {
        let mut det = AnomalyDetector::default();
        det.record("tokens", 100.0);
        det.record("cost", 0.5);
        det.record("tokens", 120.0);
        let s = det.summary();
        assert_eq!(s.len(), 2);
        // Rows are sorted by metric name.
        assert_eq!(s[0].metric, "cost");
        assert_eq!(s[1].metric, "tokens");
        assert_eq!(s[0].count, 1);
        assert_eq!(s[0].last_value, 0.5);
        assert_eq!(s[1].count, 2);
        assert_eq!(s[1].last_value, 120.0);
        assert!((s[1].mean - 110.0).abs() < 1e-9);
    }

    #[test]
    fn global_record_works() {
        let _ = record_metric("test_global", 42.0);
    }
}