entrenar/train/metrics/regression.rs
1//! Regression metrics: R2Score, MAE, RMSE
2//!
3//! These metrics delegate computation to `aprender::metrics` for the core math,
4//! while wrapping them in entrenar's `Metric` trait for integration with the
5//! training loop and evaluation framework.
6
7use crate::Tensor;
8use aprender::primitives::Vector;
9
10use super::Metric;
11
12/// Convert a Tensor's data to an aprender Vector for delegation.
13fn tensor_to_vector(t: &Tensor) -> Vector<f32> {
14 Vector::from_slice(t.data().as_slice().expect("contiguous tensor data"))
15}
16
17/// R² (coefficient of determination) for regression
18///
19/// Delegates to [`aprender::metrics::r_squared`] for computation.
20///
21/// R² = 1 - SS_res / SS_tot
22///
23/// Where:
24/// - SS_res = sum((y - y_pred)²)
25/// - SS_tot = sum((y - y_mean)²)
26///
27/// R² = 1.0 is perfect prediction, 0.0 means predicting the mean
28///
29/// # Example
30///
31/// ```
32/// use entrenar::train::{R2Score, Metric};
33/// use entrenar::Tensor;
34///
35/// let metric = R2Score;
36/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
37/// let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
38///
39/// let r2 = metric.compute(&pred, &target);
40/// assert!((r2 - 1.0).abs() < 1e-5); // Perfect prediction
41/// ```
42#[derive(Debug, Clone, Copy, Default)]
43pub struct R2Score;
44
45impl Metric for R2Score {
46 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
47 assert_eq!(predictions.len(), targets.len());
48
49 if predictions.is_empty() {
50 return 0.0;
51 }
52
53 let y_pred = tensor_to_vector(predictions);
54 let y_true = tensor_to_vector(targets);
55 let r2 = aprender::metrics::r_squared(&y_pred, &y_true);
56
57 // aprender returns 0.0 for constant targets (ss_tot == 0);
58 // entrenar returns 1.0 when prediction is also perfect (ss_res == 0)
59 if r2 == 0.0 {
60 let ss_res: f32 = predictions
61 .data()
62 .iter()
63 .zip(targets.data().iter())
64 .map(|(&p, &t)| (t - p).powi(2))
65 .sum();
66 if ss_res == 0.0 {
67 return 1.0;
68 }
69 }
70
71 r2
72 }
73
74 fn name(&self) -> &'static str {
75 "R²"
76 }
77}
78
79/// Mean Absolute Error (MAE) metric
80///
81/// Delegates to [`aprender::metrics::mae`] for computation.
82///
83/// MAE = mean(|y - y_pred|)
84///
85/// # Example
86///
87/// ```
88/// use entrenar::train::{MAE, Metric};
89/// use entrenar::Tensor;
90///
91/// let metric = MAE;
92/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
93/// let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);
94///
95/// let mae = metric.compute(&pred, &target);
96/// assert!((mae - 0.5).abs() < 1e-5);
97/// ```
98#[derive(Debug, Clone, Copy, Default)]
99pub struct MAE;
100
101impl Metric for MAE {
102 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
103 assert_eq!(predictions.len(), targets.len());
104
105 if predictions.is_empty() {
106 return 0.0;
107 }
108
109 let y_pred = tensor_to_vector(predictions);
110 let y_true = tensor_to_vector(targets);
111 aprender::metrics::mae(&y_pred, &y_true)
112 }
113
114 fn name(&self) -> &'static str {
115 "MAE"
116 }
117
118 fn higher_is_better(&self) -> bool {
119 false
120 }
121}
122
123/// Root Mean Squared Error (RMSE) metric
124///
125/// Delegates to [`aprender::metrics::rmse`] for computation.
126///
127/// RMSE = sqrt(mean((y - y_pred)²))
128///
129/// # Example
130///
131/// ```
132/// use entrenar::train::{RMSE, Metric};
133/// use entrenar::Tensor;
134///
135/// let metric = RMSE;
136/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
137/// let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
138///
139/// let rmse = metric.compute(&pred, &target);
140/// assert!(rmse < 1e-5); // Perfect prediction
141/// ```
142#[derive(Debug, Clone, Copy, Default)]
143pub struct RMSE;
144
145impl Metric for RMSE {
146 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
147 assert_eq!(predictions.len(), targets.len());
148
149 if predictions.is_empty() {
150 return 0.0;
151 }
152
153 let y_pred = tensor_to_vector(predictions);
154 let y_true = tensor_to_vector(targets);
155 aprender::metrics::rmse(&y_pred, &y_true)
156 }
157
158 fn name(&self) -> &'static str {
159 "RMSE"
160 }
161
162 fn higher_is_better(&self) -> bool {
163 false
164 }
165}