kaneru/predictive/
anomaly.rs1use crate::Observation;
7use std::collections::VecDeque;
8
/// Sliding-window anomaly detector based on per-feature z-scores.
///
/// Feature vectors extracted from observations are kept in a bounded
/// history; the per-feature mean and variance of that history are used to
/// score new observations against `threshold`.
#[derive(Debug, Clone)]
pub struct AnomalyDetector {
    /// Most recent feature vectors, oldest at the front.
    history: VecDeque<Vec<f64>>,
    /// Per-feature mean of `history`; empty until statistics are first computed.
    mean: Vec<f64>,
    /// Per-feature variance of `history`; empty until statistics are first computed.
    variance: Vec<f64>,
    /// Score above which an observation is reported as anomalous.
    threshold: f64,
    /// Maximum number of feature vectors retained in `history`.
    window_size: usize,
}
21
22impl AnomalyDetector {
23 pub fn new(threshold: f64, window_size: usize) -> Self {
32 Self {
33 history: VecDeque::with_capacity(window_size),
34 mean: Vec::new(),
35 variance: Vec::new(),
36 threshold,
37 window_size,
38 }
39 }
40
41 pub fn update(&mut self, obs: &Observation) {
43 let features = self.extract_features(obs);
44
45 if self.history.len() >= self.window_size {
47 self.history.pop_front();
48 }
49 self.history.push_back(features.clone());
50
51 if self.history.len() >= 10 {
53 self.update_statistics();
54 }
55 }
56
57 pub fn is_anomaly(&self, obs: &Observation) -> bool {
59 if self.mean.is_empty() || self.history.len() < 10 {
60 return false;
62 }
63
64 let score = self.anomaly_score(obs);
65 score > self.threshold
66 }
67
68 pub fn anomaly_score(&self, obs: &Observation) -> f64 {
74 if self.mean.is_empty() {
75 return 0.0;
76 }
77
78 let features = self.extract_features(obs);
79 self.compute_zscore(&features)
80 }
81
82 fn compute_zscore(&self, features: &[f64]) -> f64 {
84 if self.mean.is_empty() || features.len() != self.mean.len() {
85 return 0.0;
86 }
87
88 let mut total_zscore = 0.0;
89 let mut count = 0;
90
91 for (i, &feature) in features.iter().enumerate().take(self.mean.len()) {
92 if self.variance[i] > 1e-10 {
93 let std_dev = self.variance[i].sqrt();
95 let zscore = ((feature - self.mean[i]) / std_dev).abs();
96 total_zscore += zscore;
97 count += 1;
98 }
99 }
100
101 if count > 0 {
102 total_zscore / count as f64
103 } else {
104 0.0
105 }
106 }
107
108 fn extract_features(&self, obs: &Observation) -> Vec<f64> {
110 let mut features = Vec::new();
111
112 if let Some(f) = obs.value.as_f64() {
114 features.push(f);
115 } else if let Some(i) = obs.value.as_i64() {
116 features.push(i as f64);
117 } else if let Some(b) = obs.value.as_bool() {
118 features.push(if b { 1.0 } else { 0.0 });
119 } else {
120 features.push(0.0);
122 }
123
124 features.push(obs.confidence.value() as f64);
126
127 features
128 }
129
130 fn update_statistics(&mut self) {
132 if self.history.is_empty() {
133 return;
134 }
135
136 let n = self.history.len();
137 let feature_dim = self.history[0].len();
138
139 self.mean = vec![0.0; feature_dim];
141 self.variance = vec![0.0; feature_dim];
142
143 for features in &self.history {
145 for (i, &value) in features.iter().enumerate() {
146 if i < feature_dim {
147 self.mean[i] += value;
148 }
149 }
150 }
151
152 for mean_val in &mut self.mean {
153 *mean_val /= n as f64;
154 }
155
156 for features in &self.history {
158 for (i, &value) in features.iter().enumerate() {
159 if i < feature_dim {
160 let diff = value - self.mean[i];
161 self.variance[i] += diff * diff;
162 }
163 }
164 }
165
166 for var_val in &mut self.variance {
167 *var_val /= n as f64;
168 *var_val = var_val.max(1e-10);
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the threshold and window size unchanged.
    #[test]
    fn test_anomaly_detector_creation() {
        let detector = AnomalyDetector::new(2.0, 100);
        assert_eq!(detector.threshold, 2.0);
        assert_eq!(detector.window_size, 100);
    }

    /// With fewer than 10 samples nothing is ever reported as anomalous.
    #[test]
    fn test_anomaly_detection_insufficient_data() {
        let mut detector = AnomalyDetector::new(2.0, 100);

        for i in 0..5 {
            detector.update(&Observation::sensor("temp", 20.0 + i as f64));
        }

        // Far outside the training range, but the detector is still warming up.
        let obs = Observation::sensor("temp", 100.0);
        assert!(!detector.is_anomaly(&obs));
    }

    /// After training on values in 20..=24, an in-range value is normal and a
    /// far-away value is flagged.
    #[test]
    fn test_anomaly_detection() {
        let mut detector = AnomalyDetector::new(2.0, 100);

        for i in 0..100 {
            let value = 20.0 + (i % 5) as f64;
            detector.update(&Observation::sensor("temp", value));
        }

        let normal = Observation::sensor("temp", 22.0);
        assert!(!detector.is_anomaly(&normal));

        let anomaly = Observation::sensor("temp", 100.0);
        assert!(detector.is_anomaly(&anomaly));
    }

    /// Scores grow with the distance from the training data.
    #[test]
    fn test_anomaly_score() {
        let mut detector = AnomalyDetector::new(2.0, 100);

        for i in 0..50 {
            detector.update(&Observation::sensor("temp", 20.0 + (i % 3) as f64));
        }

        let normal = Observation::sensor("temp", 20.0);
        let slightly_off = Observation::sensor("temp", 25.0);
        let very_off = Observation::sensor("temp", 100.0);

        let score_normal = detector.anomaly_score(&normal);
        let score_slightly = detector.anomaly_score(&slightly_off);
        let score_very = detector.anomaly_score(&very_off);

        assert!(score_normal < score_slightly);
        assert!(score_slightly < score_very);
    }

    /// The history never grows beyond the configured window size.
    #[test]
    fn test_window_size_limit() {
        let mut detector = AnomalyDetector::new(2.0, 5);

        for i in 0..10 {
            detector.update(&Observation::sensor("temp", i as f64));
        }

        assert_eq!(detector.history.len(), 5);
    }

    /// Constant input yields a mean equal to that constant and a near-zero
    /// variance.
    #[test]
    fn test_statistics_update() {
        let mut detector = AnomalyDetector::new(2.0, 100);

        for _ in 0..20 {
            detector.update(&Observation::sensor("temp", 10.0));
        }

        assert!(!detector.mean.is_empty());
        assert!((detector.mean[0] - 10.0).abs() < 0.1);

        assert!(detector.variance[0] < 0.1);
    }

    /// Integer and boolean observations are accepted alongside floats.
    #[test]
    fn test_different_value_types() {
        let mut detector = AnomalyDetector::new(2.0, 100);

        for i in 0..20 {
            detector.update(&Observation::sensor("count", i));
        }

        detector.update(&Observation::sensor("flag", true));
        detector.update(&Observation::sensor("flag", false));

        // 20 integer updates plus 2 boolean updates, all within the window of 100.
        assert_eq!(detector.history.len(), 22);
    }

    /// A value at the centre of the training data scores low; one far outside
    /// the range scores above the threshold.
    #[test]
    fn test_zscore_calculation() {
        let mut detector = AnomalyDetector::new(2.0, 100);

        // Training values cycle through 40.0..=59.0.
        for i in 0..100 {
            let value = 50.0 + ((i % 20) as f64 - 10.0);
            detector.update(&Observation::sensor("value", value));
        }

        let at_mean = Observation::sensor("value", 50.0);
        let score_mean = detector.anomaly_score(&at_mean);
        assert!(score_mean < 1.0);

        let far_away = Observation::sensor("value", 80.0);
        let score_far = detector.anomaly_score(&far_away);
        assert!(score_far > 2.0);
    }
}