1use crate::error::{MetricsError, Result};
54use scirs2_core::numeric::Float;
55use std::collections::VecDeque;
56
57pub mod advanced;
59
60pub use advanced::{
62 AdaptiveStreamingMetrics, AdwinDetector, AlertSeverity, AnomalyDetectionAlgorithm,
63 AnomalySummary, ConceptDriftDetector, DdmDetector, DriftDetectionMethod, DriftStatus,
64 PageHinkleyDetector, StreamingConfig, UpdateResult, WindowAdaptationStrategy,
65};
66
67#[derive(Debug, Clone)]
69pub struct StreamingClassificationMetrics {
70 total_samples: usize,
71 correct_predictions: usize,
72 true_positives: usize,
73 false_positives: usize,
74 true_negatives: usize,
75 false_negatives: usize,
76}
77
78impl Default for StreamingClassificationMetrics {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl StreamingClassificationMetrics {
85 pub fn new() -> Self {
87 Self {
88 total_samples: 0,
89 correct_predictions: 0,
90 true_positives: 0,
91 false_positives: 0,
92 true_negatives: 0,
93 false_negatives: 0,
94 }
95 }
96
97 pub fn update(&mut self, true_label: i32, predlabel: i32) {
99 self.total_samples += 1;
100
101 if true_label == predlabel {
102 self.correct_predictions += 1;
103 }
104
105 match (true_label, predlabel) {
107 (1, 1) => self.true_positives += 1,
108 (0, 1) => self.false_positives += 1,
109 (0, 0) => self.true_negatives += 1,
110 (1, 0) => self.false_negatives += 1,
111 _ => {} }
113 }
114
115 pub fn update_batch(&mut self, true_labels: &[i32], predlabels: &[i32]) -> Result<()> {
117 if true_labels.len() != predlabels.len() {
118 return Err(MetricsError::InvalidInput(
119 "True and predicted _labels must have the same length".to_string(),
120 ));
121 }
122
123 for (&true_label, &predlabel) in true_labels.iter().zip(predlabels.iter()) {
124 self.update(true_label, predlabel);
125 }
126
127 Ok(())
128 }
129
130 pub fn accuracy(&self) -> f64 {
132 if self.total_samples == 0 {
133 0.0
134 } else {
135 self.correct_predictions as f64 / self.total_samples as f64
136 }
137 }
138
139 pub fn precision(&self) -> f64 {
141 let total_positive_predictions = self.true_positives + self.false_positives;
142 if total_positive_predictions == 0 {
143 0.0
144 } else {
145 self.true_positives as f64 / total_positive_predictions as f64
146 }
147 }
148
149 pub fn recall(&self) -> f64 {
151 let total_actual_positives = self.true_positives + self.false_negatives;
152 if total_actual_positives == 0 {
153 0.0
154 } else {
155 self.true_positives as f64 / total_actual_positives as f64
156 }
157 }
158
159 pub fn f1_score(&self) -> f64 {
161 let precision = self.precision();
162 let recall = self.recall();
163
164 if precision + recall == 0.0 {
165 0.0
166 } else {
167 2.0 * precision * recall / (precision + recall)
168 }
169 }
170
171 pub fn specificity(&self) -> f64 {
173 let total_actual_negatives = self.true_negatives + self.false_positives;
174 if total_actual_negatives == 0 {
175 0.0
176 } else {
177 self.true_negatives as f64 / total_actual_negatives as f64
178 }
179 }
180
181 pub fn sample_count(&self) -> usize {
183 self.total_samples
184 }
185
186 pub fn confusion_matrix(&self) -> (usize, usize, usize, usize) {
188 (
189 self.true_positives,
190 self.false_positives,
191 self.true_negatives,
192 self.false_negatives,
193 )
194 }
195
196 pub fn reset(&mut self) {
198 self.total_samples = 0;
199 self.correct_predictions = 0;
200 self.true_positives = 0;
201 self.false_positives = 0;
202 self.true_negatives = 0;
203 self.false_negatives = 0;
204 }
205}
206
207#[derive(Debug, Clone)]
209pub struct StreamingRegressionMetrics<F: Float> {
210 total_samples: usize,
211 sum_squared_errors: F,
212 sum_absolute_errors: F,
213 sum_true_values: F,
214 sum_true_squared: F,
215 sum_pred_values: F,
216 min_error: Option<F>,
217 max_error: Option<F>,
218}
219
220impl<F: Float> Default for StreamingRegressionMetrics<F> {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
226impl<F: Float> StreamingRegressionMetrics<F> {
227 pub fn new() -> Self {
229 Self {
230 total_samples: 0,
231 sum_squared_errors: F::zero(),
232 sum_absolute_errors: F::zero(),
233 sum_true_values: F::zero(),
234 sum_true_squared: F::zero(),
235 sum_pred_values: F::zero(),
236 min_error: None,
237 max_error: None,
238 }
239 }
240
241 pub fn update(&mut self, true_value: F, predvalue: F) {
243 self.total_samples += 1;
244
245 let error = true_value - predvalue;
246 let abs_error = error.abs();
247 let squared_error = error * error;
248
249 self.sum_squared_errors = self.sum_squared_errors + squared_error;
250 self.sum_absolute_errors = self.sum_absolute_errors + abs_error;
251 self.sum_true_values = self.sum_true_values + true_value;
252 self.sum_true_squared = self.sum_true_squared + (true_value * true_value);
253 self.sum_pred_values = self.sum_pred_values + predvalue;
254
255 match self.min_error {
257 None => self.min_error = Some(abs_error),
258 Some(current_min) => {
259 if abs_error < current_min {
260 self.min_error = Some(abs_error);
261 }
262 }
263 }
264
265 match self.max_error {
266 None => self.max_error = Some(abs_error),
267 Some(current_max) => {
268 if abs_error > current_max {
269 self.max_error = Some(abs_error);
270 }
271 }
272 }
273 }
274
275 pub fn update_batch(&mut self, true_values: &[F], predvalues: &[F]) -> Result<()> {
277 if true_values.len() != predvalues.len() {
278 return Err(MetricsError::InvalidInput(
279 "True and predicted _values must have the same length".to_string(),
280 ));
281 }
282
283 for (&true_value, &predvalue) in true_values.iter().zip(predvalues.iter()) {
284 self.update(true_value, predvalue);
285 }
286
287 Ok(())
288 }
289
290 pub fn mse(&self) -> F {
292 if self.total_samples == 0 {
293 F::zero()
294 } else {
295 self.sum_squared_errors
296 / F::from(self.total_samples).expect("Failed to convert to float")
297 }
298 }
299
300 pub fn rmse(&self) -> F {
302 self.mse().sqrt()
303 }
304
305 pub fn mae(&self) -> F {
307 if self.total_samples == 0 {
308 F::zero()
309 } else {
310 self.sum_absolute_errors
311 / F::from(self.total_samples).expect("Failed to convert to float")
312 }
313 }
314
315 pub fn r2_score(&self) -> F {
317 if self.total_samples == 0 {
318 F::zero()
319 } else {
320 let n = F::from(self.total_samples).expect("Failed to convert to float");
321 let mean_true = self.sum_true_values / n;
322
323 let ss_tot = self.sum_true_squared - n * mean_true * mean_true;
325
326 let ss_res = self.sum_squared_errors;
328
329 if ss_tot == F::zero() {
330 F::zero()
331 } else {
332 F::one() - (ss_res / ss_tot)
333 }
334 }
335 }
336
337 pub fn min_error(&self) -> Option<F> {
339 self.min_error
340 }
341
342 pub fn max_error(&self) -> Option<F> {
344 self.max_error
345 }
346
347 pub fn sample_count(&self) -> usize {
349 self.total_samples
350 }
351
352 pub fn reset(&mut self) {
354 self.total_samples = 0;
355 self.sum_squared_errors = F::zero();
356 self.sum_absolute_errors = F::zero();
357 self.sum_true_values = F::zero();
358 self.sum_true_squared = F::zero();
359 self.sum_pred_values = F::zero();
360 self.min_error = None;
361 self.max_error = None;
362 }
363}
364
365#[derive(Debug, Clone)]
367pub struct WindowedClassificationMetrics {
368 _windowsize: usize,
369 predictions: VecDeque<(i32, i32)>, metrics: StreamingClassificationMetrics,
371}
372
373impl WindowedClassificationMetrics {
374 pub fn new(_windowsize: usize) -> Self {
376 Self {
377 _windowsize,
378 predictions: VecDeque::with_capacity(_windowsize),
379 metrics: StreamingClassificationMetrics::new(),
380 }
381 }
382
383 pub fn update(&mut self, true_label: i32, predlabel: i32) {
385 if self.predictions.len() >= self._windowsize {
387 if let Some((old_true, old_pred)) = self.predictions.pop_front() {
388 self.subtract_prediction(old_true, old_pred);
390 }
391 }
392
393 self.predictions.push_back((true_label, predlabel));
395 self.metrics.update(true_label, predlabel);
396 }
397
398 fn subtract_prediction(&mut self, true_label: i32, predlabel: i32) {
400 if self.metrics.total_samples > 0 {
401 self.metrics.total_samples -= 1;
402 }
403
404 if true_label == predlabel && self.metrics.correct_predictions > 0 {
405 self.metrics.correct_predictions -= 1;
406 }
407
408 match (true_label, predlabel) {
409 (1, 1) => {
410 if self.metrics.true_positives > 0 {
411 self.metrics.true_positives -= 1;
412 }
413 }
414 (0, 1) => {
415 if self.metrics.false_positives > 0 {
416 self.metrics.false_positives -= 1;
417 }
418 }
419 (0, 0) => {
420 if self.metrics.true_negatives > 0 {
421 self.metrics.true_negatives -= 1;
422 }
423 }
424 (1, 0) => {
425 if self.metrics.false_negatives > 0 {
426 self.metrics.false_negatives -= 1;
427 }
428 }
429 _ => {}
430 }
431 }
432
433 pub fn current_window_size(&self) -> usize {
435 self.predictions.len()
436 }
437
438 pub fn max_window_size(&self) -> usize {
440 self._windowsize
441 }
442
443 pub fn accuracy(&self) -> f64 {
445 self.metrics.accuracy()
446 }
447
448 pub fn precision(&self) -> f64 {
449 self.metrics.precision()
450 }
451
452 pub fn recall(&self) -> f64 {
453 self.metrics.recall()
454 }
455
456 pub fn f1_score(&self) -> f64 {
457 self.metrics.f1_score()
458 }
459
460 pub fn sample_count(&self) -> usize {
461 self.metrics.sample_count()
462 }
463
464 pub fn reset(&mut self) {
466 self.predictions.clear();
467 self.metrics.reset();
468 }
469}
470
471#[derive(Debug, Clone)]
473pub struct WindowedRegressionMetrics<F: Float> {
474 _windowsize: usize,
475 predictions: VecDeque<(F, F)>, }
477
478impl<F: Float> WindowedRegressionMetrics<F> {
479 pub fn new(_windowsize: usize) -> Self {
481 Self {
482 _windowsize,
483 predictions: VecDeque::with_capacity(_windowsize),
484 }
485 }
486
487 pub fn update(&mut self, true_value: F, predvalue: F) {
489 if self.predictions.len() >= self._windowsize {
491 self.predictions.pop_front();
492 }
493
494 self.predictions.push_back((true_value, predvalue));
496 }
497
498 pub fn mse(&self) -> F {
500 if self.predictions.is_empty() {
501 return F::zero();
502 }
503
504 let sum_squared_errors = self
505 .predictions
506 .iter()
507 .map(|(true_val, pred_val)| {
508 let error = *true_val - *pred_val;
509 error * error
510 })
511 .fold(F::zero(), |acc, x| acc + x);
512
513 sum_squared_errors / F::from(self.predictions.len()).expect("Operation failed")
514 }
515
516 pub fn rmse(&self) -> F {
518 self.mse().sqrt()
519 }
520
521 pub fn mae(&self) -> F {
523 if self.predictions.is_empty() {
524 return F::zero();
525 }
526
527 let sum_absolute_errors = self
528 .predictions
529 .iter()
530 .map(|(true_val, pred_val)| (*true_val - *pred_val).abs())
531 .fold(F::zero(), |acc, x| acc + x);
532
533 sum_absolute_errors / F::from(self.predictions.len()).expect("Operation failed")
534 }
535
536 pub fn current_window_size(&self) -> usize {
538 self.predictions.len()
539 }
540
541 pub fn max_window_size(&self) -> usize {
543 self._windowsize
544 }
545
546 pub fn reset(&mut self) {
548 self.predictions.clear();
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_classification_metrics() {
        let mut metrics = StreamingClassificationMetrics::new();

        // An empty accumulator reports zeros instead of dividing by zero.
        assert_eq!(metrics.accuracy(), 0.0);
        assert_eq!(metrics.sample_count(), 0);

        // Three correct predictions.
        metrics.update(1, 1);
        metrics.update(0, 0);
        metrics.update(1, 1);

        assert_eq!(metrics.accuracy(), 1.0);
        assert_eq!(metrics.sample_count(), 3);
        assert_eq!(metrics.precision(), 1.0);
        assert_eq!(metrics.recall(), 1.0);
        assert_eq!(metrics.f1_score(), 1.0);

        // One false negative and one false positive: 3 of 5 correct.
        metrics.update(1, 0);
        metrics.update(0, 1);
        assert_eq!(metrics.accuracy(), 0.6);
        assert_eq!(metrics.sample_count(), 5);

        let (tp, fp, tn, fn_) = metrics.confusion_matrix();
        assert_eq!(tp, 2);
        assert_eq!(fp, 1);
        assert_eq!(tn, 1);
        assert_eq!(fn_, 1);
    }

    #[test]
    fn test_specificity_and_non_binary_labels() {
        let mut metrics = StreamingClassificationMetrics::new();
        metrics.update(0, 0);
        metrics.update(0, 1);
        metrics.update(0, 0);
        metrics.update(1, 1);
        assert_eq!(metrics.specificity(), 2.0 / 3.0);

        // Labels outside {0, 1} count toward accuracy but no confusion cell.
        metrics.update(2, 2);
        assert_eq!(metrics.sample_count(), 5);
        assert_eq!(metrics.accuracy(), 0.8);
        assert_eq!(metrics.confusion_matrix(), (1, 1, 2, 0));
    }

    #[test]
    fn test_streaming_regression_metrics() {
        let mut metrics = StreamingRegressionMetrics::<f64>::new();

        // Empty state.
        assert_eq!(metrics.mse(), 0.0);
        assert_eq!(metrics.mae(), 0.0);
        assert_eq!(metrics.sample_count(), 0);

        // Perfect predictions.
        metrics.update(1.0, 1.0);
        metrics.update(2.0, 2.0);
        metrics.update(3.0, 3.0);

        assert_eq!(metrics.mse(), 0.0);
        assert_eq!(metrics.mae(), 0.0);
        assert_eq!(metrics.rmse(), 0.0);
        assert_eq!(metrics.sample_count(), 3);

        // Errors of 1 and 2: MSE = (1 + 4) / 5, MAE = (1 + 2) / 5.
        metrics.update(4.0, 5.0);
        metrics.update(6.0, 4.0);
        assert_eq!(metrics.mse(), 1.0);
        assert_eq!(metrics.mae(), 0.6);
        assert_eq!(metrics.rmse(), 1.0);
        assert_eq!(metrics.min_error(), Some(0.0));
        assert_eq!(metrics.max_error(), Some(2.0));
    }

    #[test]
    fn test_r2_score() {
        let mut metrics = StreamingRegressionMetrics::<f64>::new();
        assert_eq!(metrics.r2_score(), 0.0);

        // SS_res = 1, SS_tot = 2.
        metrics
            .update_batch(&[1.0, 2.0, 3.0], &[1.0, 2.0, 2.0])
            .expect("Operation failed");
        assert_eq!(metrics.r2_score(), 0.5);

        // Constant targets have zero total variance, reported as 0.0.
        let mut constant = StreamingRegressionMetrics::<f64>::new();
        constant
            .update_batch(&[2.0, 2.0], &[1.0, 3.0])
            .expect("Operation failed");
        assert_eq!(constant.r2_score(), 0.0);
    }

    #[test]
    fn test_windowed_classification_metrics() {
        let mut metrics = WindowedClassificationMetrics::new(3);

        assert_eq!(metrics.current_window_size(), 0);
        assert_eq!(metrics.max_window_size(), 3);

        metrics.update(1, 1);
        metrics.update(0, 0);
        metrics.update(1, 0);
        assert_eq!(metrics.current_window_size(), 3);
        assert_eq!(metrics.accuracy(), 2.0 / 3.0);

        // Evicts (1, 1); the window becomes (0, 0), (1, 0), (0, 1).
        metrics.update(0, 1);
        assert_eq!(metrics.current_window_size(), 3);
        assert_eq!(metrics.accuracy(), 1.0 / 3.0);
    }

    #[test]
    fn test_windowed_regression_metrics() {
        let mut metrics = WindowedRegressionMetrics::<f64>::new(2);

        assert_eq!(metrics.current_window_size(), 0);
        assert_eq!(metrics.max_window_size(), 2);

        // Errors 0 and -1.
        metrics.update(1.0, 1.0);
        metrics.update(2.0, 3.0);
        assert_eq!(metrics.current_window_size(), 2);
        assert_eq!(metrics.mse(), 0.5);
        assert_eq!(metrics.mae(), 0.5);

        // Evicts (1.0, 1.0); the window errors become -1 and 2.
        metrics.update(4.0, 2.0);
        assert_eq!(metrics.current_window_size(), 2);
        assert_eq!(metrics.mse(), 2.5);
        assert_eq!(metrics.mae(), 1.5);
    }

    #[test]
    fn test_windowed_reset() {
        let mut classification = WindowedClassificationMetrics::new(2);
        classification.update(1, 1);
        classification.reset();
        assert_eq!(classification.current_window_size(), 0);
        assert_eq!(classification.sample_count(), 0);

        let mut regression = WindowedRegressionMetrics::<f64>::new(2);
        regression.update(1.0, 3.0);
        regression.reset();
        assert_eq!(regression.current_window_size(), 0);
        assert_eq!(regression.mse(), 0.0);
    }

    #[test]
    fn test_batch_updates() {
        let mut metrics = StreamingClassificationMetrics::new();

        let true_labels = vec![1, 0, 1, 0];
        let predlabels = vec![1, 0, 0, 1];

        metrics
            .update_batch(&true_labels, &predlabels)
            .expect("Operation failed");

        assert_eq!(metrics.sample_count(), 4);
        assert_eq!(metrics.accuracy(), 0.5);
    }

    #[test]
    fn test_batch_length_mismatch_is_rejected() {
        let mut classification = StreamingClassificationMetrics::new();
        assert!(classification.update_batch(&[1, 0], &[1]).is_err());
        assert_eq!(classification.sample_count(), 0);

        let mut regression = StreamingRegressionMetrics::<f64>::new();
        assert!(regression.update_batch(&[1.0], &[]).is_err());
        assert_eq!(regression.sample_count(), 0);
    }

    #[test]
    fn test_reset_functionality() {
        let mut metrics = StreamingClassificationMetrics::new();

        metrics.update(1, 1);
        metrics.update(0, 0);

        assert_eq!(metrics.sample_count(), 2);
        assert_eq!(metrics.accuracy(), 1.0);

        metrics.reset();

        assert_eq!(metrics.sample_count(), 0);
        assert_eq!(metrics.accuracy(), 0.0);
    }
}