reductionml_core/metrics/
mean_squared_error.rs1use crate::{metrics::Metric, utils::AsInner, Features, ScalarPrediction, SimpleLabel};
2
3use super::MetricValue;
4
/// Running accumulator for the mean squared error between scalar predictions
/// and simple labels.
///
/// `value` holds the sum of squared errors added so far and `count` the number
/// of points added; the reported metric is `value / count`.
#[derive(Debug, Clone)]
pub struct MeanSquaredErrorMetric {
    /// Sum of the squared errors of every point added so far.
    pub value: f32,
    /// Number of points that have been accumulated into `value`.
    pub count: u64,
}

10impl MeanSquaredErrorMetric {
11 pub fn new() -> MeanSquaredErrorMetric {
12 MeanSquaredErrorMetric {
13 value: 0.0,
14 count: 0,
15 }
16 }
17}
18
19impl Default for MeanSquaredErrorMetric {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Metric for MeanSquaredErrorMetric {
26 fn add_point(
27 &mut self,
28 _features: &Features,
29 label: &crate::types::Label,
30 prediction: &crate::types::Prediction,
31 ) {
32 let label: &SimpleLabel = label.as_inner().unwrap();
33 let pred: &ScalarPrediction = prediction.as_inner().unwrap();
34 self.value += (label.value() - pred.prediction) * (label.value() - pred.prediction);
35 self.count += 1;
36 }
37
38 fn get_value(&self) -> MetricValue {
39 MetricValue::Float(self.value / self.count as f32)
40 }
41
42 fn get_name(&self) -> String {
43 "MeanSquaredError".to_owned()
44 }
45}