Skip to main content

yscv_eval/
counting.rs

1use crate::EvalError;
2
/// Aggregate error statistics produced by [`evaluate_counts`] for a sequence of
/// per-frame count predictions.
///
/// The `Default` value (zero frames, zero errors) matches what
/// [`evaluate_counts`] returns for empty input.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct CountingMetrics {
    /// Number of frames compared (the common length of the input slices).
    pub num_frames: usize,
    /// Mean absolute error between predicted and ground-truth counts.
    pub mae: f32,
    /// Root mean squared error between predicted and ground-truth counts.
    pub rmse: f32,
    /// Largest absolute error observed on any single frame.
    pub max_abs_error: usize,
}
10
11pub fn evaluate_counts(
12    ground_truth: &[usize],
13    predictions: &[usize],
14) -> Result<CountingMetrics, EvalError> {
15    if ground_truth.len() != predictions.len() {
16        return Err(EvalError::CountLengthMismatch {
17            ground_truth: ground_truth.len(),
18            predictions: predictions.len(),
19        });
20    }
21
22    if ground_truth.is_empty() {
23        return Ok(CountingMetrics {
24            num_frames: 0,
25            mae: 0.0,
26            rmse: 0.0,
27            max_abs_error: 0,
28        });
29    }
30
31    let mut abs_error_sum = 0.0f32;
32    let mut sq_error_sum = 0.0f32;
33    let mut max_abs_error = 0usize;
34    for (&gt, &prediction) in ground_truth.iter().zip(predictions.iter()) {
35        let error = prediction as i64 - gt as i64;
36        let abs_error = error.unsigned_abs() as usize;
37        abs_error_sum += abs_error as f32;
38        sq_error_sum += (error as f32).powi(2);
39        max_abs_error = max_abs_error.max(abs_error);
40    }
41
42    let denom = ground_truth.len() as f32;
43    Ok(CountingMetrics {
44        num_frames: ground_truth.len(),
45        mae: abs_error_sum / denom,
46        rmse: (sq_error_sum / denom).sqrt(),
47        max_abs_error,
48    })
49}