Skip to main content

entrenar/train/metrics/
regression.rs

//! Regression metrics: R2Score, MAE, RMSE
//!
//! These metrics delegate computation to `aprender::metrics` for the core math,
//! while wrapping them in entrenar's `Metric` trait for integration with the
//! training loop and evaluation framework.
6
7use crate::Tensor;
8use aprender::primitives::Vector;
9
10use super::Metric;
11
12/// Convert a Tensor's data to an aprender Vector for delegation.
13fn tensor_to_vector(t: &Tensor) -> Vector<f32> {
14    Vector::from_slice(t.data().as_slice().expect("contiguous tensor data"))
15}
16
17/// R² (coefficient of determination) for regression
18///
19/// Delegates to [`aprender::metrics::r_squared`] for computation.
20///
21/// R² = 1 - SS_res / SS_tot
22///
23/// Where:
24/// - SS_res = sum((y - y_pred)²)
25/// - SS_tot = sum((y - y_mean)²)
26///
27/// R² = 1.0 is perfect prediction, 0.0 means predicting the mean
28///
29/// # Example
30///
31/// ```
32/// use entrenar::train::{R2Score, Metric};
33/// use entrenar::Tensor;
34///
35/// let metric = R2Score;
36/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
37/// let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
38///
39/// let r2 = metric.compute(&pred, &target);
40/// assert!((r2 - 1.0).abs() < 1e-5);  // Perfect prediction
41/// ```
42#[derive(Debug, Clone, Copy, Default)]
43pub struct R2Score;
44
45impl Metric for R2Score {
46    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
47        assert_eq!(predictions.len(), targets.len());
48
49        if predictions.is_empty() {
50            return 0.0;
51        }
52
53        let y_pred = tensor_to_vector(predictions);
54        let y_true = tensor_to_vector(targets);
55        let r2 = aprender::metrics::r_squared(&y_pred, &y_true);
56
57        // aprender returns 0.0 for constant targets (ss_tot == 0);
58        // entrenar returns 1.0 when prediction is also perfect (ss_res == 0)
59        if r2 == 0.0 {
60            let ss_res: f32 = predictions
61                .data()
62                .iter()
63                .zip(targets.data().iter())
64                .map(|(&p, &t)| (t - p).powi(2))
65                .sum();
66            if ss_res == 0.0 {
67                return 1.0;
68            }
69        }
70
71        r2
72    }
73
74    fn name(&self) -> &'static str {
75        "R²"
76    }
77}
78
79/// Mean Absolute Error (MAE) metric
80///
81/// Delegates to [`aprender::metrics::mae`] for computation.
82///
83/// MAE = mean(|y - y_pred|)
84///
85/// # Example
86///
87/// ```
88/// use entrenar::train::{MAE, Metric};
89/// use entrenar::Tensor;
90///
91/// let metric = MAE;
92/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
93/// let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);
94///
95/// let mae = metric.compute(&pred, &target);
96/// assert!((mae - 0.5).abs() < 1e-5);
97/// ```
98#[derive(Debug, Clone, Copy, Default)]
99pub struct MAE;
100
101impl Metric for MAE {
102    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
103        assert_eq!(predictions.len(), targets.len());
104
105        if predictions.is_empty() {
106            return 0.0;
107        }
108
109        let y_pred = tensor_to_vector(predictions);
110        let y_true = tensor_to_vector(targets);
111        aprender::metrics::mae(&y_pred, &y_true)
112    }
113
114    fn name(&self) -> &'static str {
115        "MAE"
116    }
117
118    fn higher_is_better(&self) -> bool {
119        false
120    }
121}
122
123/// Root Mean Squared Error (RMSE) metric
124///
125/// Delegates to [`aprender::metrics::rmse`] for computation.
126///
127/// RMSE = sqrt(mean((y - y_pred)²))
128///
129/// # Example
130///
131/// ```
132/// use entrenar::train::{RMSE, Metric};
133/// use entrenar::Tensor;
134///
135/// let metric = RMSE;
136/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
137/// let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
138///
139/// let rmse = metric.compute(&pred, &target);
140/// assert!(rmse < 1e-5);  // Perfect prediction
141/// ```
142#[derive(Debug, Clone, Copy, Default)]
143pub struct RMSE;
144
145impl Metric for RMSE {
146    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
147        assert_eq!(predictions.len(), targets.len());
148
149        if predictions.is_empty() {
150            return 0.0;
151        }
152
153        let y_pred = tensor_to_vector(predictions);
154        let y_true = tensor_to_vector(targets);
155        aprender::metrics::rmse(&y_pred, &y_true)
156    }
157
158    fn name(&self) -> &'static str {
159        "RMSE"
160    }
161
162    fn higher_is_better(&self) -> bool {
163        false
164    }
165}