Skip to main content

ferrolearn_preprocess/
label_binarizer.rs

//! One-vs-rest label binarizer.
//!
//! Transforms a vector of integer class labels into a binary indicator matrix.
//! For *K* classes the output has *K* columns (one-hot rows), except in the
//! binary case (*K* = 2) where a single column is produced, holding `1.0` for
//! the larger of the two classes and `0.0` for the smaller. A single class
//! (*K* = 1) therefore yields one column filled with `1.0`.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_preprocess::label_binarizer::LabelBinarizer;
//! use ferrolearn_core::traits::{Fit, Transform};
//! use ndarray::array;
//!
//! let lb = LabelBinarizer::new();
//! let y = array![0_usize, 1, 2, 1];
//! let fitted = lb.fit(&y, &()).unwrap();
//! let mat = fitted.transform(&y).unwrap();
//! // 3 classes → (4, 3) indicator matrix
//! assert_eq!(mat.shape(), &[4, 3]);
//! assert_eq!(mat[[0, 0]], 1.0);
//! assert_eq!(mat[[0, 1]], 0.0);
//! ```

24use ferrolearn_core::error::FerroError;
25use ferrolearn_core::traits::{Fit, Transform};
26use ndarray::{Array1, Array2};

// ---------------------------------------------------------------------------
// LabelBinarizer (unfitted)
// ---------------------------------------------------------------------------

/// One-vs-rest label binarizer before fitting.
///
/// This type carries no configuration. Fitting it with [`Fit::fit`] on an
/// `Array1<usize>` collects the ascending set of distinct class labels and
/// returns a [`FittedLabelBinarizer`].
#[derive(Debug, Clone, Default)]
pub struct LabelBinarizer;

impl LabelBinarizer {
    /// Construct an unfitted binarizer (the same value as `LabelBinarizer::default()`).
    #[must_use]
    pub fn new() -> Self {
        LabelBinarizer
    }
}

// ---------------------------------------------------------------------------
// FittedLabelBinarizer
// ---------------------------------------------------------------------------

/// A fitted label binarizer holding the discovered class set.
///
/// Created by calling [`Fit::fit`] on a [`LabelBinarizer`]. The class list is
/// non-empty, sorted in ascending order and free of duplicates. In the
/// multiclass output, column *j* corresponds to the *j*-th class in this
/// order. In the binary case, the single column flags the second (larger)
/// class.
///
/// Two fitted binarizers compare equal when they learned the same class set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FittedLabelBinarizer {
    /// Sorted unique class labels observed during fitting.
    classes: Vec<usize>,
}

60impl FittedLabelBinarizer {
61    /// Return the sorted class labels discovered during fitting.
62    #[must_use]
63    pub fn classes(&self) -> &[usize] {
64        &self.classes
65    }
66
67    /// Return the number of unique classes.
68    #[must_use]
69    pub fn n_classes(&self) -> usize {
70        self.classes.len()
71    }
72
73    /// Map a binary indicator matrix back to integer class labels.
74    ///
75    /// For each row the class with the largest value (argmax) is chosen.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does
80    /// not match the expected output width (1 for binary, *K* for multiclass).
81    pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Array1<usize>, FerroError> {
82        let k = self.classes.len();
83        let expected_cols = if k == 2 { 1 } else { k };
84
85        if y.ncols() != expected_cols {
86            return Err(FerroError::ShapeMismatch {
87                expected: vec![y.nrows(), expected_cols],
88                actual: vec![y.nrows(), y.ncols()],
89                context: "FittedLabelBinarizer::inverse_transform".into(),
90            });
91        }
92
93        let n = y.nrows();
94        let mut result = Array1::zeros(n);
95
96        if k == 2 {
97            // Single column: threshold at 0.5
98            for i in 0..n {
99                result[i] = if y[[i, 0]] >= 0.5 {
100                    self.classes[1]
101                } else {
102                    self.classes[0]
103                };
104            }
105        } else {
106            // Multiclass: argmax per row
107            for i in 0..n {
108                let row = y.row(i);
109                let mut best_j = 0;
110                let mut best_v = f64::NEG_INFINITY;
111                for (j, &v) in row.iter().enumerate() {
112                    if v > best_v {
113                        best_v = v;
114                        best_j = j;
115                    }
116                }
117                result[i] = self.classes[best_j];
118            }
119        }
120
121        Ok(result)
122    }
123}

// ---------------------------------------------------------------------------
// Trait implementations
// ---------------------------------------------------------------------------

129impl Fit<Array1<usize>, ()> for LabelBinarizer {
130    type Fitted = FittedLabelBinarizer;
131    type Error = FerroError;
132
133    /// Fit the binarizer by discovering unique class labels.
134    ///
135    /// # Errors
136    ///
137    /// Returns [`FerroError::InsufficientSamples`] if the input is empty.
138    fn fit(&self, y: &Array1<usize>, _target: &()) -> Result<FittedLabelBinarizer, FerroError> {
139        if y.is_empty() {
140            return Err(FerroError::InsufficientSamples {
141                required: 1,
142                actual: 0,
143                context: "LabelBinarizer::fit".into(),
144            });
145        }
146
147        let mut classes: Vec<usize> = y.iter().copied().collect();
148        classes.sort_unstable();
149        classes.dedup();
150
151        Ok(FittedLabelBinarizer { classes })
152    }
153}
154
155impl Transform<Array1<usize>> for FittedLabelBinarizer {
156    type Output = Array2<f64>;
157    type Error = FerroError;
158
159    /// Transform labels into a binary indicator matrix.
160    ///
161    /// - For *K* = 2 classes the output shape is `(n, 1)`.
162    /// - For *K* > 2 classes the output shape is `(n, K)`.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`FerroError::InvalidParameter`] if any label in `y` was not
167    /// seen during fitting.
168    fn transform(&self, y: &Array1<usize>) -> Result<Array2<f64>, FerroError> {
169        let k = self.classes.len();
170        let n = y.len();
171
172        // Build a lookup: class_value → column index
173        let class_to_idx: std::collections::HashMap<usize, usize> = self
174            .classes
175            .iter()
176            .enumerate()
177            .map(|(i, &c)| (c, i))
178            .collect();
179
180        if k == 2 {
181            // Binary: single column, 1.0 for the second class
182            let mut out = Array2::zeros((n, 1));
183            for (i, &label) in y.iter().enumerate() {
184                let idx = class_to_idx.get(&label).ok_or_else(|| {
185                    FerroError::InvalidParameter {
186                        name: "y".into(),
187                        reason: format!("unknown label {label} not seen during fit"),
188                    }
189                })?;
190                out[[i, 0]] = if *idx == 1 { 1.0 } else { 0.0 };
191            }
192            Ok(out)
193        } else {
194            // Multiclass: one-hot rows
195            let mut out = Array2::zeros((n, k));
196            for (i, &label) in y.iter().enumerate() {
197                let &idx = class_to_idx.get(&label).ok_or_else(|| {
198                    FerroError::InvalidParameter {
199                        name: "y".into(),
200                        reason: format!("unknown label {label} not seen during fit"),
201                    }
202                })?;
203                out[[i, idx]] = 1.0;
204            }
205            Ok(out)
206        }
207    }
208}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_fit_discovers_sorted_classes() {
        let lb = LabelBinarizer::new();
        let y = array![2_usize, 0, 1, 2, 0];
        let fitted = lb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[0, 1, 2]);
    }

    #[test]
    fn test_fit_empty_input_error() {
        let lb = LabelBinarizer::new();
        let y: Array1<usize> = Array1::zeros(0);
        assert!(lb.fit(&y, &()).is_err());
    }

    #[test]
    fn test_binary_transform_single_column() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 0, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[4, 1]);
        assert_eq!(mat[[0, 0]], 0.0); // class 0 → 0
        assert_eq!(mat[[1, 0]], 1.0); // class 1 → 1
        assert_eq!(mat[[2, 0]], 0.0);
        assert_eq!(mat[[3, 0]], 1.0);
    }

    #[test]
    fn test_multiclass_transform_indicator_matrix() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[4, 3]);
        // Row 0: class 0 → [1, 0, 0]
        assert_eq!(mat[[0, 0]], 1.0);
        assert_eq!(mat[[0, 1]], 0.0);
        assert_eq!(mat[[0, 2]], 0.0);
        // Row 2: class 2 → [0, 0, 1]
        assert_eq!(mat[[2, 0]], 0.0);
        assert_eq!(mat[[2, 1]], 0.0);
        assert_eq!(mat[[2, 2]], 1.0);
    }

    #[test]
    fn test_inverse_transform_multiclass() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_inverse_transform_binary() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 0, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_inverse_transform_binary_threshold() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![3_usize, 7], &()).unwrap();
        // Scores at or above 0.5 map to the larger class, below to the smaller.
        let scores = array![[0.0], [0.49], [0.5], [0.9]];
        let recovered = fitted.inverse_transform(&scores).unwrap();
        assert_eq!(recovered, array![3_usize, 3, 7, 7]);
    }

    #[test]
    fn test_inverse_transform_multiclass_tie_picks_first() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
        let scores = array![[0.2, 0.4, 0.4], [0.9, 0.1, 0.9]];
        let recovered = fitted.inverse_transform(&scores).unwrap();
        assert_eq!(recovered, array![1_usize, 0]);
    }

    #[test]
    fn test_transform_unknown_label_error() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2];
        let fitted = lb.fit(&y, &()).unwrap();
        let y2 = array![0_usize, 3]; // 3 not in {0,1,2}
        assert!(fitted.transform(&y2).is_err());
    }

    #[test]
    fn test_binary_transform_unknown_label_error() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1], &()).unwrap();
        assert!(fitted.transform(&array![0_usize, 2]).is_err()); // 2 not in {0,1}
    }

    #[test]
    fn test_transform_empty_input() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
        let empty: Array1<usize> = Array1::zeros(0);
        let mat = fitted.transform(&empty).unwrap();
        assert_eq!(mat.shape(), &[0, 3]);
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2];
        let fitted = lb.fit(&y, &()).unwrap();
        // 3 classes expects 3 columns, but we give 2
        let bad = Array2::<f64>::zeros((2, 2));
        assert!(fitted.inverse_transform(&bad).is_err());
    }

    #[test]
    fn test_inverse_transform_binary_shape_mismatch() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1], &()).unwrap();
        // 2 classes expects a single column, but we give 2
        let bad = Array2::<f64>::zeros((2, 2));
        assert!(fitted.inverse_transform(&bad).is_err());
    }

    #[test]
    fn test_single_class() {
        let lb = LabelBinarizer::new();
        let y = array![5_usize, 5, 5];
        let fitted = lb.fit(&y, &()).unwrap();
        assert_eq!(fitted.n_classes(), 1);
        // 1 class → 1 column; each row is one-hot in that column, i.e. 1.0.
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[3, 1]);
        assert!(mat.iter().all(|&v| v == 1.0));
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_non_contiguous_classes() {
        let lb = LabelBinarizer::new();
        let y = array![10_usize, 20, 30, 10];
        let fitted = lb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[10, 20, 30]);
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[4, 3]);
        assert_eq!(mat[[0, 0]], 1.0); // 10 → col 0
        assert_eq!(mat[[1, 1]], 1.0); // 20 → col 1
        assert_eq!(mat[[2, 2]], 1.0); // 30 → col 2
    }

    #[test]
    fn test_roundtrip_multiclass_non_contiguous() {
        let lb = LabelBinarizer::new();
        let y = array![10_usize, 20, 30, 20];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }
}