Skip to main content

ferrolearn_preprocess/
multi_label_binarizer.rs

1//! Multi-label binarizer.
2//!
3//! Transforms a list of label sets into a multi-hot binary indicator matrix.
4//! Each sample can belong to zero or more classes simultaneously.
5//!
6//! Translation target: scikit-learn 1.5.2 `class MultiLabelBinarizer`
7//! (`sklearn/preprocessing/_label.py:688`). Design:
8//! `.design/preprocess/multi_label_binarizer.md`. Tracking: #1229.
9//!
10//! `## REQ status`
11//!
12//! | REQ | Status | Anchor |
13//! |---|---|---|
14//! | REQ-1 fit → sorted-unique classes_ (usize) | SHIPPED | `MultiLabelBinarizer::fit`; sklearn `_label.py:779` |
15//! | REQ-2 transform → dense multi-hot (known labels) | SHIPPED | `FittedMultiLabelBinarizer::transform`; sklearn `_label.py:869-907` |
16//! | REQ-3 transform unknown-label: ignore, no error | SHIPPED (#1230) | `transform` skips unknown via `class_to_idx.get`; sklearn `_label.py:889-902` |
17//! | REQ-4 inverse_transform 0/1 validation | SHIPPED (#1231) | `inverse_transform` rejects non-0/1, selects `== 1.0`; sklearn `_label.py:941-947` |
18//! | REQ-5 `classes` ctor param | NOT-STARTED (#1232) | sklearn `_label.py:756`,`:780-785` |
19//! | REQ-6 sparse_output CSR | NOT-STARTED (#1233) | sklearn `_label.py:858-859`,`:905-907` |
20//! | REQ-7 arbitrary orderable+hashable labels + object dtype | NOT-STARTED (#1234) | sklearn `_label.py:788` (usize-only, R-DEV-3) |
21//! | REQ-8 optimized single-pass fit_transform | NOT-STARTED (#1235) | sklearn `_label.py:814-835` |
22//! | REQ-9 PyO3 binding | NOT-STARTED (#1236) | `ferrolearn-python/src/` (absent) |
23//! | REQ-1 edge: empty-`y` fit yields empty classes_ (no error) | SHIPPED (#2339) | `MultiLabelBinarizer::fit` (no empty-`y` rejection): empty `y` → `classes = []`; `transform([[]])` → `(1, 0)` `Array2`; sklearn `_label.py:779` |
24//!
25//! # Examples
26//!
27//! ```
28//! use ferrolearn_preprocess::multi_label_binarizer::MultiLabelBinarizer;
29//! use ferrolearn_core::traits::{Fit, Transform};
30//!
31//! let mlb = MultiLabelBinarizer::new();
32//! let y = vec![vec![0, 1], vec![1, 2], vec![0]];
33//! let fitted = mlb.fit(&y, &()).unwrap();
34//! let mat = fitted.transform(&y).unwrap();
35//! // 3 classes → (3, 3) multi-hot matrix
36//! assert_eq!(mat.shape(), &[3, 3]);
37//! assert_eq!(mat[[0, 0]], 1.0); // sample 0 has label 0
38//! assert_eq!(mat[[0, 1]], 1.0); // sample 0 has label 1
39//! assert_eq!(mat[[0, 2]], 0.0); // sample 0 does NOT have label 2
40//! ```
41
42use ferrolearn_core::error::FerroError;
43use ferrolearn_core::traits::{Fit, Transform};
44use ndarray::Array2;
45
46// ---------------------------------------------------------------------------
47// MultiLabelBinarizer (unfitted)
48// ---------------------------------------------------------------------------
49
/// An unfitted multi-label binarizer.
///
/// The type is a unit struct: it carries no configuration. Calling
/// [`Fit::fit`] with a `&Vec<Vec<usize>>` of per-sample label lists discovers
/// the sorted set of unique labels across all samples and returns a
/// [`FittedMultiLabelBinarizer`].
#[derive(Debug, Clone, Default)]
pub struct MultiLabelBinarizer;

impl MultiLabelBinarizer {
    /// Create a new `MultiLabelBinarizer`.
    ///
    /// Equivalent to [`Default::default`]. Declared `const` so the value can
    /// also be built in `const` and `static` initialisers.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
64
65// ---------------------------------------------------------------------------
66// FittedMultiLabelBinarizer
67// ---------------------------------------------------------------------------
68
/// The fitted state of a multi-label binarizer: the learned class set.
///
/// Values of this type are produced by calling [`Fit::fit`] on a
/// [`MultiLabelBinarizer`]. The field is private; read it through the
/// accessor methods on this type.
#[derive(Clone, Debug)]
pub struct FittedMultiLabelBinarizer {
    /// Class labels observed during fitting, sorted ascending with duplicates
    /// removed.
    classes: Vec<usize>,
}
77
78impl FittedMultiLabelBinarizer {
79    /// Return the sorted class labels discovered during fitting.
80    #[must_use]
81    pub fn classes(&self) -> &[usize] {
82        &self.classes
83    }
84
85    /// Return the number of unique classes.
86    #[must_use]
87    pub fn n_classes(&self) -> usize {
88        self.classes.len()
89    }
90
91    /// Map a multi-hot indicator matrix back to label sets.
92    ///
93    /// The indicator matrix must contain only exact `0.0` and `1.0` values; a
94    /// class is included for a sample iff its cell is exactly `1.0`. This
95    /// mirrors scikit-learn 1.5.2 `MultiLabelBinarizer.inverse_transform`
96    /// (`sklearn/preprocessing/_label.py:941-947`), which validates the matrix
97    /// with `np.setdiff1d(yt, [0, 1])` and raises `ValueError` on any value
98    /// outside `{0, 1}` before selecting classes where the cell `== 1`.
99    ///
100    /// # Errors
101    ///
102    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does
103    /// not match the number of classes. Returns [`FerroError::InvalidParameter`]
104    /// if any cell value is not exactly `0.0` or `1.0`.
105    #[allow(
106        clippy::float_cmp,
107        reason = "indicator matrix must be exactly 0/1 per sklearn _label.py:941-947"
108    )]
109    pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Vec<Vec<usize>>, FerroError> {
110        let k = self.classes.len();
111        if y.ncols() != k {
112            return Err(FerroError::ShapeMismatch {
113                expected: vec![y.nrows(), k],
114                actual: vec![y.nrows(), y.ncols()],
115                context: "FittedMultiLabelBinarizer::inverse_transform".into(),
116            });
117        }
118
119        // Validate the indicator contains only 0s and 1s, matching sklearn's
120        // `np.setdiff1d(yt, [0, 1])` check (_label.py:941-947).
121        if let Some(&v) = y.iter().find(|&&v| v != 0.0 && v != 1.0) {
122            return Err(FerroError::InvalidParameter {
123                name: "y".into(),
124                reason: format!("Expected only 0s and 1s in label indicator, got {v}"),
125            });
126        }
127
128        let n = y.nrows();
129        let mut result = Vec::with_capacity(n);
130
131        for i in 0..n {
132            let mut labels = Vec::new();
133            for (j, &cls) in self.classes.iter().enumerate() {
134                if y[[i, j]] == 1.0 {
135                    labels.push(cls);
136                }
137            }
138            result.push(labels);
139        }
140
141        Ok(result)
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Trait implementations
147// ---------------------------------------------------------------------------
148
149impl Fit<Vec<Vec<usize>>, ()> for MultiLabelBinarizer {
150    type Fitted = FittedMultiLabelBinarizer;
151    type Error = FerroError;
152
153    /// Fit the binarizer by discovering all unique labels.
154    ///
155    /// An empty input (no samples) is accepted and yields an empty `classes_`,
156    /// mirroring sklearn `MultiLabelBinarizer.fit` where
157    /// `classes_ = sorted(set(itertools.chain.from_iterable(y)))` is the empty
158    /// set for `y == []` (`sklearn/preprocessing/_label.py:779`); sklearn raises
159    /// no error and a subsequent `transform([[]])` yields a `(1, 0)` matrix.
160    ///
161    /// # Errors
162    ///
163    /// This method does not return an error in the `usize` domain (kept as
164    /// `Result` for `Fit`-trait conformance and forward compatibility).
165    fn fit(
166        &self,
167        y: &Vec<Vec<usize>>,
168        _target: &(),
169    ) -> Result<FittedMultiLabelBinarizer, FerroError> {
170        let mut classes: Vec<usize> = y.iter().flatten().copied().collect();
171        classes.sort_unstable();
172        classes.dedup();
173
174        Ok(FittedMultiLabelBinarizer { classes })
175    }
176}
177
178impl Transform<Vec<Vec<usize>>> for FittedMultiLabelBinarizer {
179    type Output = Array2<f64>;
180    type Error = FerroError;
181
182    /// Transform label sets into a multi-hot indicator matrix.
183    ///
184    /// Each row has a `1.0` in every column corresponding to one of its labels
185    /// and `0.0` elsewhere.
186    ///
187    /// Labels not seen during fitting are silently ignored: the indicator is
188    /// built only from known labels (mirroring scikit-learn 1.5.2
189    /// `MultiLabelBinarizer._transform`, `sklearn/preprocessing/_label.py:889-902`).
190    /// scikit-learn additionally emits a `warnings.warn("unknown class(es) ...
191    /// will be ignored")`; that warning is intentionally not emitted here because
192    /// the crate has no logging facade and adding one would be out of scope.
193    ///
194    /// The [`Result`] return type is retained because the [`Transform`] trait
195    /// requires it; `transform` always returns [`Ok`].
196    fn transform(&self, y: &Vec<Vec<usize>>) -> Result<Array2<f64>, FerroError> {
197        let k = self.classes.len();
198        let n = y.len();
199
200        // Build lookup: class_value → column index
201        let class_to_idx: std::collections::HashMap<usize, usize> = self
202            .classes
203            .iter()
204            .enumerate()
205            .map(|(i, &c)| (c, i))
206            .collect();
207
208        let mut out = Array2::zeros((n, k));
209
210        for (i, labels) in y.iter().enumerate() {
211            for &label in labels {
212                // Unknown labels (not seen during fit) are silently ignored,
213                // matching scikit-learn's `_transform` (_label.py:889-902).
214                if let Some(&idx) = class_to_idx.get(&label) {
215                    out[[i, idx]] = 1.0;
216                }
217            }
218        }
219
220        Ok(out)
221    }
222}
223
224// ===========================================================================
225// Tests
226// ===========================================================================
227
#[cfg(test)]
mod tests {
    //! Unit tests for `MultiLabelBinarizer` and `FittedMultiLabelBinarizer`.
    //!
    //! Tests return `Result` and use `?`, so an unexpected error from `fit`,
    //! `transform` or `inverse_transform` fails the test with the error's
    //! `Debug` output. Error-path tests assert the specific `FerroError`
    //! variant rather than just `is_err()`.

    use super::*;
    use ndarray::array;

    /// Return type shared by the fallible tests below.
    type TestResult = Result<(), FerroError>;

    #[test]
    fn test_fit_discovers_sorted_classes() -> TestResult {
        let y = vec![vec![2, 0], vec![1]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
        Ok(())
    }

    #[test]
    fn test_fit_empty_input_yields_empty_classes() -> TestResult {
        // sklearn 1.5.2: `MultiLabelBinarizer().fit([])` SUCCEEDS with
        // `classes_ == []` (`sorted(set(chain.from_iterable([])))` is empty,
        // `_label.py:779`), and `transform([[]])` is a `(1, 0)` matrix.
        // Live oracle (from /tmp):
        //   mlb = MultiLabelBinarizer().fit([]); mlb.classes_.tolist() -> []
        //   mlb.transform([[]]).shape -> (1, 0)
        let empty: Vec<Vec<usize>> = Vec::new();
        let one_empty_sample: Vec<Vec<usize>> = vec![vec![]];

        let fitted = MultiLabelBinarizer::new().fit(&empty, &())?;
        assert!(fitted.classes().is_empty());

        let mat = fitted.transform(&one_empty_sample)?;
        assert_eq!(mat.shape(), &[1, 0]);
        Ok(())
    }

    #[test]
    fn test_transform_multi_hot() -> TestResult {
        let y = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        let mat = fitted.transform(&y)?;
        assert_eq!(mat.shape(), &[3, 3]);
        assert_eq!(
            mat,
            array![
                [1.0, 0.0, 1.0], // labels {0, 2}
                [0.0, 1.0, 0.0], // labels {1}
                [1.0, 1.0, 1.0], // labels {0, 1, 2}
            ]
        );
        Ok(())
    }

    #[test]
    fn test_transform_unknown_label_ignored() -> TestResult {
        // Live oracle (sklearn 1.5.2):
        //   python3 -c "from sklearn.preprocessing import MultiLabelBinarizer; \
        //     import warnings; warnings.simplefilter('ignore'); \
        //     mlb=MultiLabelBinarizer().fit([[0,1]]); \
        //     print(mlb.transform([[0,5]]).tolist())"
        //   => [[1, 0]]
        // Unknown labels are skipped, not errored (_label.py:889-902).
        let train = vec![vec![0, 1]];
        let fitted = MultiLabelBinarizer::new().fit(&train, &())?;

        // 5 is not in {0, 1}: `?` fails the test if transform errors on it.
        let with_unknown = vec![vec![0, 5]];
        let mat = fitted.transform(&with_unknown)?;
        assert_eq!(mat, array![[1.0, 0.0]]);
        Ok(())
    }

    #[test]
    fn test_inverse_transform_roundtrip() -> TestResult {
        let y = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        let mat = fitted.transform(&y)?;
        assert_eq!(fitted.inverse_transform(&mat)?, y);
        Ok(())
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() -> TestResult {
        let y = vec![vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        // Three fitted classes require three columns; this matrix has two.
        let bad = Array2::<f64>::zeros((2, 2));
        assert!(matches!(
            fitted.inverse_transform(&bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_empty_label_set() -> TestResult {
        let y = vec![vec![0, 1], vec![]]; // second sample has no labels
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        let mat = fitted.transform(&y)?;
        assert_eq!(mat.shape(), &[2, 2]);
        // The label-free sample encodes as an all-zero row.
        assert_eq!(mat, array![[1.0, 1.0], [0.0, 0.0]]);
        Ok(())
    }

    #[test]
    fn test_inverse_transform_empty_row() -> TestResult {
        let y = vec![vec![0, 1], vec![]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        let mat = fitted.transform(&y)?;
        assert_eq!(fitted.inverse_transform(&mat)?, y);
        Ok(())
    }

    #[test]
    fn test_non_contiguous_classes() -> TestResult {
        let y = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        assert_eq!(fitted.classes(), &[10, 20, 30]);
        let mat = fitted.transform(&y)?;
        assert_eq!(mat.shape(), &[2, 3]);
        // Columns are 10, 20, 30 in that order.
        assert_eq!(mat, array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]);
        Ok(())
    }

    #[test]
    fn test_inverse_transform_non_contiguous_roundtrip() -> TestResult {
        let y = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        let mat = fitted.transform(&y)?;
        assert_eq!(fitted.inverse_transform(&mat)?, y);
        Ok(())
    }

    #[test]
    fn test_duplicate_labels_in_input() -> TestResult {
        let y = vec![vec![0, 0, 1]]; // duplicate 0
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;
        assert_eq!(fitted.classes(), &[0, 1]);
        let mat = fitted.transform(&y)?;
        // Duplicates collapse: one column per class, each cell a plain 1.0.
        assert_eq!(mat.shape(), &[1, 2]);
        assert_eq!(mat, array![[1.0, 1.0]]);
        Ok(())
    }

    #[test]
    fn test_inverse_rejects_non_01() -> TestResult {
        // sklearn 1.5.2 validates the indicator with `np.setdiff1d(yt, [0, 1])`
        // and raises `ValueError` on any value outside {0, 1}
        // (sklearn/preprocessing/_label.py:941-947). It does NOT threshold.
        //
        // Live oracle (sklearn 1.5.2), valid 0/1 round-trip (R-CHAR-3):
        //   python3 -c "import numpy as np; \
        //     from sklearn.preprocessing import MultiLabelBinarizer; \
        //     mlb=MultiLabelBinarizer().fit([[0,1,2]]); \
        //     print(mlb.inverse_transform(np.array([[1,0,1]])))"
        //   => [(0, 2)]  ==  vec![vec![0, 2]]
        let y = vec![vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &())?;

        // Fractional values are rejected rather than thresholded.
        let fractional = array![[0.4, 0.6, 0.5]];
        assert!(matches!(
            fitted.inverse_transform(&fractional),
            Err(FerroError::InvalidParameter { .. })
        ));

        // NaN compares unequal to both 0.0 and 1.0, so it is rejected too.
        let with_nan = array![[f64::NAN, 0.0, 1.0]];
        assert!(matches!(
            fitted.inverse_transform(&with_nan),
            Err(FerroError::InvalidParameter { .. })
        ));

        // A valid 0/1 indicator round-trips to the live-oracle result.
        let good = array![[1.0, 0.0, 1.0]];
        assert_eq!(fitted.inverse_transform(&good)?, vec![vec![0, 2]]);
        Ok(())
    }
}