// reductionml_core/metrics/mean_squared_error.rs

1use crate::{metrics::Metric, utils::AsInner, Features, ScalarPrediction, SimpleLabel};
2
3use super::MetricValue;
4
/// Running accumulator for the mean squared error between labels and scalar predictions.
///
/// `value` holds the running sum of squared errors and `count` the number of points
/// folded into it; the reported metric is `value / count`.
#[derive(Debug, Clone, PartialEq)]
pub struct MeanSquaredErrorMetric {
    /// Sum of squared differences between label and prediction over all points seen.
    pub value: f32,
    /// Number of points accumulated into `value`.
    pub count: u64,
}
9
10impl MeanSquaredErrorMetric {
11    pub fn new() -> MeanSquaredErrorMetric {
12        MeanSquaredErrorMetric {
13            value: 0.0,
14            count: 0,
15        }
16    }
17}
18
19impl Default for MeanSquaredErrorMetric {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl Metric for MeanSquaredErrorMetric {
26    fn add_point(
27        &mut self,
28        _features: &Features,
29        label: &crate::types::Label,
30        prediction: &crate::types::Prediction,
31    ) {
32        let label: &SimpleLabel = label.as_inner().unwrap();
33        let pred: &ScalarPrediction = prediction.as_inner().unwrap();
34        self.value += (label.value() - pred.prediction) * (label.value() - pred.prediction);
35        self.count += 1;
36    }
37
38    fn get_value(&self) -> MetricValue {
39        MetricValue::Float(self.value / self.count as f32)
40    }
41
42    fn get_name(&self) -> String {
43        "MeanSquaredError".to_owned()
44    }
45}