Skip to main content

ferrolearn_preprocess/
multi_label_binarizer.rs

1//! Multi-label binarizer.
2//!
3//! Transforms a list of label sets into a multi-hot binary indicator matrix.
4//! Each sample can belong to zero or more classes simultaneously.
5//!
6//! # Examples
7//!
8//! ```
9//! use ferrolearn_preprocess::multi_label_binarizer::MultiLabelBinarizer;
10//! use ferrolearn_core::traits::{Fit, Transform};
11//!
12//! let mlb = MultiLabelBinarizer::new();
13//! let y = vec![vec![0, 1], vec![1, 2], vec![0]];
14//! let fitted = mlb.fit(&y, &()).unwrap();
15//! let mat = fitted.transform(&y).unwrap();
16//! // 3 classes → (3, 3) multi-hot matrix
17//! assert_eq!(mat.shape(), &[3, 3]);
18//! assert_eq!(mat[[0, 0]], 1.0); // sample 0 has label 0
19//! assert_eq!(mat[[0, 1]], 1.0); // sample 0 has label 1
20//! assert_eq!(mat[[0, 2]], 0.0); // sample 0 does NOT have label 2
21//! ```
22
23use ferrolearn_core::error::FerroError;
24use ferrolearn_core::traits::{Fit, Transform};
25use ndarray::Array2;
26
27// ---------------------------------------------------------------------------
28// MultiLabelBinarizer (unfitted)
29// ---------------------------------------------------------------------------
30
/// An unfitted multi-label binarizer.
///
/// The type carries no configuration. Calling [`Fit::fit`] with a
/// `&Vec<Vec<usize>>` of per-sample label sets collects the sorted set of
/// unique labels across all samples and returns a
/// [`FittedMultiLabelBinarizer`].
#[derive(Debug, Clone, Default)]
pub struct MultiLabelBinarizer;

impl MultiLabelBinarizer {
    /// Create a new `MultiLabelBinarizer`.
    ///
    /// This is a `const fn`, so the binarizer can also be constructed in
    /// `const` and `static` initializers.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
45
46// ---------------------------------------------------------------------------
47// FittedMultiLabelBinarizer
48// ---------------------------------------------------------------------------
49
/// The fitted state of a multi-label binarizer: the set of known classes.
///
/// Instances are produced by [`Fit::fit`] on a [`MultiLabelBinarizer`]. Column
/// `j` of the indicator matrix corresponds to the `j`-th stored class.
#[derive(Clone, Debug)]
pub struct FittedMultiLabelBinarizer {
    /// Unique class labels seen while fitting, in ascending order.
    classes: Vec<usize>,
}
58
59impl FittedMultiLabelBinarizer {
60    /// Return the sorted class labels discovered during fitting.
61    #[must_use]
62    pub fn classes(&self) -> &[usize] {
63        &self.classes
64    }
65
66    /// Return the number of unique classes.
67    #[must_use]
68    pub fn n_classes(&self) -> usize {
69        self.classes.len()
70    }
71
72    /// Map a multi-hot indicator matrix back to label sets.
73    ///
74    /// Each column value is thresholded at 0.5: values >= 0.5 are included.
75    ///
76    /// # Errors
77    ///
78    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does
79    /// not match the number of classes.
80    pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Vec<Vec<usize>>, FerroError> {
81        let k = self.classes.len();
82        if y.ncols() != k {
83            return Err(FerroError::ShapeMismatch {
84                expected: vec![y.nrows(), k],
85                actual: vec![y.nrows(), y.ncols()],
86                context: "FittedMultiLabelBinarizer::inverse_transform".into(),
87            });
88        }
89
90        let n = y.nrows();
91        let mut result = Vec::with_capacity(n);
92
93        for i in 0..n {
94            let mut labels = Vec::new();
95            for (j, &cls) in self.classes.iter().enumerate() {
96                if y[[i, j]] >= 0.5 {
97                    labels.push(cls);
98                }
99            }
100            result.push(labels);
101        }
102
103        Ok(result)
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Trait implementations
109// ---------------------------------------------------------------------------
110
111impl Fit<Vec<Vec<usize>>, ()> for MultiLabelBinarizer {
112    type Fitted = FittedMultiLabelBinarizer;
113    type Error = FerroError;
114
115    /// Fit the binarizer by discovering all unique labels.
116    ///
117    /// # Errors
118    ///
119    /// Returns [`FerroError::InsufficientSamples`] if the input is empty.
120    fn fit(
121        &self,
122        y: &Vec<Vec<usize>>,
123        _target: &(),
124    ) -> Result<FittedMultiLabelBinarizer, FerroError> {
125        if y.is_empty() {
126            return Err(FerroError::InsufficientSamples {
127                required: 1,
128                actual: 0,
129                context: "MultiLabelBinarizer::fit".into(),
130            });
131        }
132
133        let mut classes: Vec<usize> = y.iter().flatten().copied().collect();
134        classes.sort_unstable();
135        classes.dedup();
136
137        Ok(FittedMultiLabelBinarizer { classes })
138    }
139}
140
141impl Transform<Vec<Vec<usize>>> for FittedMultiLabelBinarizer {
142    type Output = Array2<f64>;
143    type Error = FerroError;
144
145    /// Transform label sets into a multi-hot indicator matrix.
146    ///
147    /// Each row has a `1.0` in every column corresponding to one of its labels
148    /// and `0.0` elsewhere.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`FerroError::InvalidParameter`] if any label was not seen
153    /// during fitting.
154    fn transform(&self, y: &Vec<Vec<usize>>) -> Result<Array2<f64>, FerroError> {
155        let k = self.classes.len();
156        let n = y.len();
157
158        // Build lookup: class_value → column index
159        let class_to_idx: std::collections::HashMap<usize, usize> = self
160            .classes
161            .iter()
162            .enumerate()
163            .map(|(i, &c)| (c, i))
164            .collect();
165
166        let mut out = Array2::zeros((n, k));
167
168        for (i, labels) in y.iter().enumerate() {
169            for &label in labels {
170                let &idx = class_to_idx.get(&label).ok_or_else(|| {
171                    FerroError::InvalidParameter {
172                        name: "y".into(),
173                        reason: format!("unknown label {label} not seen during fit"),
174                    }
175                })?;
176                out[[i, idx]] = 1.0;
177            }
178        }
179
180        Ok(out)
181    }
182}
183
184// ===========================================================================
185// Tests
186// ===========================================================================
187
#[cfg(test)]
mod tests {
    //! Unit tests for fitting, transforming and inverse-transforming.
    //!
    //! Error-path tests pin the specific `FerroError` variant rather than
    //! only checking `is_err()`, and matrix checks compare whole arrays.

    use super::*;
    use ndarray::array;

    #[test]
    fn test_fit_discovers_sorted_classes() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![2, 0], vec![1]];
        let fitted = mlb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_fit_empty_input_error() {
        let mlb = MultiLabelBinarizer::new();
        let y: Vec<Vec<usize>> = vec![];
        let result = mlb.fit(&y, &());
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                ..
            })
        ));
    }

    #[test]
    fn test_transform_multi_hot() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = mlb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        // Rows: {0, 2} → [1, 0, 1]; {1} → [0, 1, 0]; {0, 1, 2} → [1, 1, 1]
        assert_eq!(
            mat,
            array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
        );
    }

    #[test]
    fn test_transform_unknown_label_error() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 1]];
        let fitted = mlb.fit(&y, &()).unwrap();
        let y2 = vec![vec![0, 5]]; // 5 not in {0, 1}
        let result = fitted.transform(&y2);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = mlb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 1, 2]];
        let fitted = mlb.fit(&y, &()).unwrap();
        // 3 classes expects 3 columns
        let bad = Array2::<f64>::zeros((2, 2));
        let result = fitted.inverse_transform(&bad);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_empty_label_set() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 1], vec![]]; // second sample has no labels
        let fitted = mlb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        // Row 1 should be all zeros
        assert_eq!(mat, array![[1.0, 1.0], [0.0, 0.0]]);
    }

    #[test]
    fn test_inverse_transform_empty_row() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 1], vec![]];
        let fitted = mlb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_non_contiguous_classes() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![10, 30], vec![20]];
        let fitted = mlb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[10, 20, 30]);
        let mat = fitted.transform(&y).unwrap();
        // Columns are ordered as classes 10, 20, 30.
        assert_eq!(mat, array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]);
    }

    #[test]
    fn test_inverse_transform_non_contiguous_roundtrip() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![10, 30], vec![20]];
        let fitted = mlb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_duplicate_labels_in_input() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 0, 1]]; // duplicate 0
        let fitted = mlb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        let mat = fitted.transform(&y).unwrap();
        // Still produces [1, 1] — duplicates don't cause double-counting
        assert_eq!(mat, array![[1.0, 1.0]]);
    }

    #[test]
    fn test_inverse_threshold() {
        let mlb = MultiLabelBinarizer::new();
        let y = vec![vec![0, 1, 2]];
        let fitted = mlb.fit(&y, &()).unwrap();
        // Values below 0.5 → not included; exactly 0.5 is included
        let mat = array![[0.4, 0.6, 0.5]];
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, vec![vec![1, 2]]); // 0.4 < 0.5 so label 0 excluded
    }
}
325}