entrenar/train/metrics/classification.rs
1//! Classification metrics: Accuracy, Precision, Recall, F1
2//!
3//! Thresholding (continuous predictions → discrete labels) is entrenar's concern.
//! Accuracy delegates to `aprender::metrics::classification` once labels are discrete;
//! precision, recall and F1 are counted locally for the binary positive class.
5
6use crate::Tensor;
7
8use super::Metric;
9
10/// Convert continuous predictions and targets to discrete binary labels.
11///
12/// Entrenar's training concern: model outputs are continuous (logits/probabilities),
13/// so thresholding is part of evaluation. After thresholding, the discrete labels
14/// can be passed to aprender for metric computation.
15fn threshold_to_labels(
16 predictions: &Tensor,
17 targets: &Tensor,
18 threshold: f32,
19) -> (Vec<usize>, Vec<usize>) {
20 let y_pred: Vec<usize> =
21 predictions.data().iter().map(|&p| usize::from(p >= threshold)).collect();
22 let y_true: Vec<usize> = targets.data().iter().map(|&t| usize::from(t >= 0.5)).collect();
23 (y_pred, y_true)
24}
25
/// Accuracy metric for binary classification
///
/// Predictions are binarized at `threshold` before scoring, and the result is
/// the fraction of correct predictions.
///
/// Inputs are compared element by element: no argmax is taken, so one-hot or
/// per-class score tensors are scored per element rather than per sample.
///
/// # Example
///
/// ```
/// use entrenar::train::{Accuracy, Metric};
/// use entrenar::Tensor;
///
/// let metric = Accuracy::new(0.5); // threshold for binary
/// let pred = Tensor::from_vec(vec![0.9, 0.2, 0.8], false);
/// let target = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);
///
/// let acc = metric.compute(&pred, &target);
/// assert_eq!(acc, 1.0); // All correct
/// ```
#[derive(Debug, Clone)]
pub struct Accuracy {
    /// Cutoff at or above which a prediction counts as the positive class
    pub(crate) threshold: f32,
}

impl Accuracy {
    /// Create new accuracy metric with given threshold for binary classification
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }

    /// Create accuracy metric with default threshold of 0.5
    ///
    /// Same value as [`Default::default`]; kept as a named constructor for
    /// callers that prefer an explicit spelling.
    pub fn default_threshold() -> Self {
        Self::default()
    }
}

impl Default for Accuracy {
    /// Accuracy metric with a threshold of 0.5
    fn default() -> Self {
        Self::new(0.5)
    }
}
67
68impl Metric for Accuracy {
69 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
70 assert_eq!(
71 predictions.len(),
72 targets.len(),
73 "Predictions and targets must have same length"
74 );
75
76 if predictions.is_empty() {
77 return 0.0;
78 }
79
80 // Threshold to discrete labels (entrenar's concern), then delegate to aprender
81 let (y_pred, y_true) = threshold_to_labels(predictions, targets, self.threshold);
82 aprender::metrics::classification::accuracy(&y_pred, &y_true)
83 }
84
85 fn name(&self) -> &'static str {
86 "Accuracy"
87 }
88}
89
/// Precision metric: true positives divided by predicted positives
///
/// # Example
///
/// ```
/// use entrenar::train::{Precision, Metric};
/// use entrenar::Tensor;
///
/// let metric = Precision::new(0.5);
/// let pred = Tensor::from_vec(vec![0.9, 0.8, 0.2], false);
/// let target = Tensor::from_vec(vec![1.0, 0.0, 0.0], false);
///
/// let prec = metric.compute(&pred, &target);
/// assert_eq!(prec, 0.5); // 1 TP / 2 predicted positives
/// ```
#[derive(Debug, Clone)]
pub struct Precision {
    /// Cutoff at or above which a prediction counts as the positive class
    pub(crate) threshold: f32,
}

impl Precision {
    /// Create a precision metric that binarizes predictions at `threshold`
    pub fn new(threshold: f32) -> Self {
        Precision { threshold }
    }
}

impl Default for Precision {
    /// Precision metric with a threshold of 0.5
    fn default() -> Self {
        Precision { threshold: 0.5 }
    }
}
121
122impl Metric for Precision {
123 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
124 assert_eq!(predictions.len(), targets.len());
125
126 if predictions.is_empty() {
127 return 0.0;
128 }
129
130 // Threshold to discrete labels (entrenar), count from labels
131 // Note: aprender's precision() uses macro/micro/weighted averaging which
132 // differs from binary positive-class precision. We threshold via entrenar,
133 // then compute TP/FP from discrete labels.
134 let (y_pred, y_true) = threshold_to_labels(predictions, targets, self.threshold);
135
136 let mut true_positives = 0usize;
137 let mut predicted_positives = 0usize;
138
139 for (&p, &t) in y_pred.iter().zip(y_true.iter()) {
140 if p == 1 {
141 predicted_positives += 1;
142 if t == 1 {
143 true_positives += 1;
144 }
145 }
146 }
147
148 if predicted_positives == 0 {
149 return 0.0;
150 }
151
152 true_positives as f32 / predicted_positives as f32
153 }
154
155 fn name(&self) -> &'static str {
156 "Precision"
157 }
158}
159
/// Recall metric: true positives divided by actual positives
///
/// # Example
///
/// ```
/// use entrenar::train::{Recall, Metric};
/// use entrenar::Tensor;
///
/// let metric = Recall::new(0.5);
/// let pred = Tensor::from_vec(vec![0.9, 0.2, 0.8], false);
/// let target = Tensor::from_vec(vec![1.0, 1.0, 0.0], false);
///
/// let rec = metric.compute(&pred, &target);
/// assert_eq!(rec, 0.5); // 1 TP / 2 actual positives
/// ```
#[derive(Debug, Clone)]
pub struct Recall {
    /// Cutoff at or above which a prediction counts as the positive class
    pub(crate) threshold: f32,
}

impl Recall {
    /// Create a recall metric that binarizes predictions at `threshold`
    pub fn new(threshold: f32) -> Self {
        Recall { threshold }
    }
}

impl Default for Recall {
    /// Recall metric with a threshold of 0.5
    fn default() -> Self {
        Recall { threshold: 0.5 }
    }
}
191
192impl Metric for Recall {
193 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
194 assert_eq!(predictions.len(), targets.len());
195
196 if predictions.is_empty() {
197 return 0.0;
198 }
199
200 // Threshold to discrete labels (entrenar), count from labels
201 let (y_pred, y_true) = threshold_to_labels(predictions, targets, self.threshold);
202
203 let mut true_positives = 0usize;
204 let mut actual_positives = 0usize;
205
206 for (&p, &t) in y_pred.iter().zip(y_true.iter()) {
207 if t == 1 {
208 actual_positives += 1;
209 if p == 1 {
210 true_positives += 1;
211 }
212 }
213 }
214
215 if actual_positives == 0 {
216 return 0.0;
217 }
218
219 true_positives as f32 / actual_positives as f32
220 }
221
222 fn name(&self) -> &'static str {
223 "Recall"
224 }
225}
226
227/// F1 Score (harmonic mean of precision and recall)
228///
229/// F1 = 2 * (precision * recall) / (precision + recall)
230///
231/// # Example
232///
233/// ```
234/// use entrenar::train::{F1Score, Metric};
235/// use entrenar::Tensor;
236///
237/// let metric = F1Score::new(0.5);
238/// let pred = Tensor::from_vec(vec![0.9, 0.8, 0.2, 0.1], false);
239/// let target = Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], false);
240///
241/// let f1 = metric.compute(&pred, &target);
242/// assert!(f1 > 0.0 && f1 <= 1.0);
243/// ```
244#[derive(Debug, Clone)]
245pub struct F1Score {
246 precision: Precision,
247 recall: Recall,
248}
249
250impl F1Score {
251 pub fn new(threshold: f32) -> Self {
252 Self { precision: Precision::new(threshold), recall: Recall::new(threshold) }
253 }
254}
255
256impl Default for F1Score {
257 fn default() -> Self {
258 Self::new(0.5)
259 }
260}
261
262impl Metric for F1Score {
263 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
264 let precision = self.precision.compute(predictions, targets);
265 let recall = self.recall.compute(predictions, targets);
266
267 if precision + recall == 0.0 {
268 return 0.0;
269 }
270
271 2.0 * (precision * recall) / (precision + recall)
272 }
273
274 fn name(&self) -> &'static str {
275 "F1"
276 }
277}