Skip to main content

entrenar/train/metrics/
classification.rs

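//! Classification metrics: Accuracy, Precision, Recall, F1
//!
//! Thresholding (continuous predictions → discrete labels) is entrenar's concern.
//! Accuracy delegates the metric computation on discrete labels to
//! `aprender::metrics::classification`; Precision and Recall count the positive
//! class locally, because aprender's variants use macro/micro/weighted averaging.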
5
6use crate::Tensor;
7
8use super::Metric;
9
10/// Convert continuous predictions and targets to discrete binary labels.
11///
12/// Entrenar's training concern: model outputs are continuous (logits/probabilities),
13/// so thresholding is part of evaluation. After thresholding, the discrete labels
14/// can be passed to aprender for metric computation.
15fn threshold_to_labels(
16    predictions: &Tensor,
17    targets: &Tensor,
18    threshold: f32,
19) -> (Vec<usize>, Vec<usize>) {
20    let y_pred: Vec<usize> =
21        predictions.data().iter().map(|&p| usize::from(p >= threshold)).collect();
22    let y_true: Vec<usize> = targets.data().iter().map(|&t| usize::from(t >= 0.5)).collect();
23    (y_pred, y_true)
24}
25
/// Accuracy metric for binary classification.
///
/// Each prediction is labelled positive when it is `>= threshold`, and each
/// target when it is `>= 0.5`. The metric is the fraction of positions where the
/// two labels agree. An empty input yields `0.0`. Mismatched lengths panic.
///
/// This metric is binary only. Multi-class inputs (one-hot targets or per-class
/// scores) are thresholded element-wise like any other values; there is no
/// argmax-based comparison.
///
/// # Example
///
/// ```
/// use entrenar::train::{Accuracy, Metric};
/// use entrenar::Tensor;
///
/// let metric = Accuracy::new(0.5);  // threshold for binary
/// let pred = Tensor::from_vec(vec![0.9, 0.2, 0.8], false);
/// let target = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);
///
/// let acc = metric.compute(&pred, &target);
/// assert_eq!(acc, 1.0);  // All correct
/// ```
#[derive(Debug, Clone)]
pub struct Accuracy {
    /// Decision threshold applied to predictions (`>= threshold` → positive).
    pub(crate) threshold: f32,
}

impl Accuracy {
    /// Create an accuracy metric that treats predictions `>= threshold` as positive.
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }

    /// Create an accuracy metric with the conventional threshold of `0.5`.
    ///
    /// Equivalent to [`Accuracy::default`].
    pub fn default_threshold() -> Self {
        Self::default()
    }
}

impl Default for Accuracy {
    /// Accuracy with a threshold of `0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}
67
68impl Metric for Accuracy {
69    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
70        assert_eq!(
71            predictions.len(),
72            targets.len(),
73            "Predictions and targets must have same length"
74        );
75
76        if predictions.is_empty() {
77            return 0.0;
78        }
79
80        // Threshold to discrete labels (entrenar's concern), then delegate to aprender
81        let (y_pred, y_true) = threshold_to_labels(predictions, targets, self.threshold);
82        aprender::metrics::classification::accuracy(&y_pred, &y_true)
83    }
84
85    fn name(&self) -> &'static str {
86        "Accuracy"
87    }
88}
89
/// Precision metric for binary classification: true positives / predicted positives.
///
/// Predictions are positive when `>= threshold`; targets are positive when
/// `>= 0.5`. Returns `0.0` for empty input and when nothing is predicted
/// positive (where precision is otherwise undefined).
///
/// # Example
///
/// ```
/// use entrenar::train::{Precision, Metric};
/// use entrenar::Tensor;
///
/// let metric = Precision::new(0.5);
/// let pred = Tensor::from_vec(vec![0.9, 0.8, 0.2], false);
/// let target = Tensor::from_vec(vec![1.0, 0.0, 0.0], false);
///
/// let prec = metric.compute(&pred, &target);
/// assert_eq!(prec, 0.5);  // 1 TP / 2 predicted positives
/// ```
#[derive(Debug, Clone)]
pub struct Precision {
    /// Decision threshold applied to predictions (`>= threshold` → positive).
    pub(crate) threshold: f32,
}

impl Precision {
    /// Create a precision metric that treats predictions `>= threshold` as positive.
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }
}

impl Default for Precision {
    /// Precision with a threshold of `0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}
121
122impl Metric for Precision {
123    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
124        assert_eq!(predictions.len(), targets.len());
125
126        if predictions.is_empty() {
127            return 0.0;
128        }
129
130        // Threshold to discrete labels (entrenar), count from labels
131        // Note: aprender's precision() uses macro/micro/weighted averaging which
132        // differs from binary positive-class precision. We threshold via entrenar,
133        // then compute TP/FP from discrete labels.
134        let (y_pred, y_true) = threshold_to_labels(predictions, targets, self.threshold);
135
136        let mut true_positives = 0usize;
137        let mut predicted_positives = 0usize;
138
139        for (&p, &t) in y_pred.iter().zip(y_true.iter()) {
140            if p == 1 {
141                predicted_positives += 1;
142                if t == 1 {
143                    true_positives += 1;
144                }
145            }
146        }
147
148        if predicted_positives == 0 {
149            return 0.0;
150        }
151
152        true_positives as f32 / predicted_positives as f32
153    }
154
155    fn name(&self) -> &'static str {
156        "Precision"
157    }
158}
159
/// Recall metric for binary classification: true positives / actual positives.
///
/// Predictions are positive when `>= threshold`; targets are positive when
/// `>= 0.5`. Returns `0.0` for empty input and when no target is positive
/// (where recall is otherwise undefined).
///
/// # Example
///
/// ```
/// use entrenar::train::{Recall, Metric};
/// use entrenar::Tensor;
///
/// let metric = Recall::new(0.5);
/// let pred = Tensor::from_vec(vec![0.9, 0.2, 0.8], false);
/// let target = Tensor::from_vec(vec![1.0, 1.0, 0.0], false);
///
/// let rec = metric.compute(&pred, &target);
/// assert_eq!(rec, 0.5);  // 1 TP / 2 actual positives
/// ```
#[derive(Debug, Clone)]
pub struct Recall {
    /// Decision threshold applied to predictions (`>= threshold` → positive).
    pub(crate) threshold: f32,
}

impl Recall {
    /// Create a recall metric that treats predictions `>= threshold` as positive.
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }
}

impl Default for Recall {
    /// Recall with a threshold of `0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}
191
192impl Metric for Recall {
193    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
194        assert_eq!(predictions.len(), targets.len());
195
196        if predictions.is_empty() {
197            return 0.0;
198        }
199
200        // Threshold to discrete labels (entrenar), count from labels
201        let (y_pred, y_true) = threshold_to_labels(predictions, targets, self.threshold);
202
203        let mut true_positives = 0usize;
204        let mut actual_positives = 0usize;
205
206        for (&p, &t) in y_pred.iter().zip(y_true.iter()) {
207            if t == 1 {
208                actual_positives += 1;
209                if p == 1 {
210                    true_positives += 1;
211                }
212            }
213        }
214
215        if actual_positives == 0 {
216            return 0.0;
217        }
218
219        true_positives as f32 / actual_positives as f32
220    }
221
222    fn name(&self) -> &'static str {
223        "Recall"
224    }
225}
226
227/// F1 Score (harmonic mean of precision and recall)
228///
229/// F1 = 2 * (precision * recall) / (precision + recall)
230///
231/// # Example
232///
233/// ```
234/// use entrenar::train::{F1Score, Metric};
235/// use entrenar::Tensor;
236///
237/// let metric = F1Score::new(0.5);
238/// let pred = Tensor::from_vec(vec![0.9, 0.8, 0.2, 0.1], false);
239/// let target = Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], false);
240///
241/// let f1 = metric.compute(&pred, &target);
242/// assert!(f1 > 0.0 && f1 <= 1.0);
243/// ```
244#[derive(Debug, Clone)]
245pub struct F1Score {
246    precision: Precision,
247    recall: Recall,
248}
249
250impl F1Score {
251    pub fn new(threshold: f32) -> Self {
252        Self { precision: Precision::new(threshold), recall: Recall::new(threshold) }
253    }
254}
255
256impl Default for F1Score {
257    fn default() -> Self {
258        Self::new(0.5)
259    }
260}
261
262impl Metric for F1Score {
263    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f32 {
264        let precision = self.precision.compute(predictions, targets);
265        let recall = self.recall.compute(predictions, targets);
266
267        if precision + recall == 0.0 {
268            return 0.0;
269        }
270
271        2.0 * (precision * recall) / (precision + recall)
272    }
273
274    fn name(&self) -> &'static str {
275        "F1"
276    }
277}