ferrolearn_preprocess/multi_label_binarizer.rs
1//! Multi-label binarizer.
2//!
3//! Transforms a list of label sets into a multi-hot binary indicator matrix.
4//! Each sample can belong to zero or more classes simultaneously.
5//!
6//! Translation target: scikit-learn 1.5.2 `class MultiLabelBinarizer`
7//! (`sklearn/preprocessing/_label.py:688`). Design:
8//! `.design/preprocess/multi_label_binarizer.md`. Tracking: #1229.
9//!
10//! `## REQ status`
11//!
12//! | REQ | Status | Anchor |
13//! |---|---|---|
14//! | REQ-1 fit → sorted-unique classes_ (usize) | SHIPPED | `MultiLabelBinarizer::fit`; sklearn `_label.py:779` |
15//! | REQ-2 transform → dense multi-hot (known labels) | SHIPPED | `FittedMultiLabelBinarizer::transform`; sklearn `_label.py:869-907` |
16//! | REQ-3 transform unknown-label: ignore, no error | SHIPPED (#1230) | `transform` skips unknown via `class_to_idx.get`; sklearn `_label.py:889-902` |
17//! | REQ-4 inverse_transform 0/1 validation | SHIPPED (#1231) | `inverse_transform` rejects non-0/1, selects `== 1.0`; sklearn `_label.py:941-947` |
18//! | REQ-5 `classes` ctor param | NOT-STARTED (#1232) | sklearn `_label.py:756`,`:780-785` |
19//! | REQ-6 sparse_output CSR | NOT-STARTED (#1233) | sklearn `_label.py:858-859`,`:905-907` |
20//! | REQ-7 arbitrary orderable+hashable labels + object dtype | NOT-STARTED (#1234) | sklearn `_label.py:788` (usize-only, R-DEV-3) |
21//! | REQ-8 optimized single-pass fit_transform | NOT-STARTED (#1235) | sklearn `_label.py:814-835` |
22//! | REQ-9 PyO3 binding | NOT-STARTED (#1236) | `ferrolearn-python/src/` (absent) |
23//! | REQ-1 edge: empty-`y` fit yields empty classes_ (no error) | SHIPPED (#2339) | `MultiLabelBinarizer::fit` (no empty-`y` rejection): empty `y` → `classes = []`; `transform([[]])` → `(1, 0)` `Array2`; sklearn `_label.py:779` |
24//!
25//! # Examples
26//!
27//! ```
28//! use ferrolearn_preprocess::multi_label_binarizer::MultiLabelBinarizer;
29//! use ferrolearn_core::traits::{Fit, Transform};
30//!
31//! let mlb = MultiLabelBinarizer::new();
32//! let y = vec![vec![0, 1], vec![1, 2], vec![0]];
33//! let fitted = mlb.fit(&y, &()).unwrap();
34//! let mat = fitted.transform(&y).unwrap();
35//! // 3 classes → (3, 3) multi-hot matrix
36//! assert_eq!(mat.shape(), &[3, 3]);
37//! assert_eq!(mat[[0, 0]], 1.0); // sample 0 has label 0
38//! assert_eq!(mat[[0, 1]], 1.0); // sample 0 has label 1
39//! assert_eq!(mat[[0, 2]], 0.0); // sample 0 does NOT have label 2
40//! ```
41
42use ferrolearn_core::error::FerroError;
43use ferrolearn_core::traits::{Fit, Transform};
44use ndarray::Array2;
45
46// ---------------------------------------------------------------------------
47// MultiLabelBinarizer (unfitted)
48// ---------------------------------------------------------------------------
49
/// Stateless, unfitted multi-label binarizer.
///
/// Fitting it through [`Fit::fit`] on a `&Vec<Vec<usize>>` collects every
/// distinct label found across all samples, in ascending order, and hands
/// back a [`FittedMultiLabelBinarizer`].
#[derive(Debug, Clone, Default)]
pub struct MultiLabelBinarizer;

impl MultiLabelBinarizer {
    /// Construct an unfitted `MultiLabelBinarizer`.
    ///
    /// The type carries no configuration, so the result is the same value
    /// that [`Default::default`] produces.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
64
65// ---------------------------------------------------------------------------
66// FittedMultiLabelBinarizer
67// ---------------------------------------------------------------------------
68
69/// A fitted multi-label binarizer holding the discovered class set.
70///
71/// Created by calling [`Fit::fit`] on a [`MultiLabelBinarizer`].
72#[derive(Debug, Clone)]
73pub struct FittedMultiLabelBinarizer {
74 /// Sorted unique class labels observed during fitting.
75 classes: Vec<usize>,
76}
77
78impl FittedMultiLabelBinarizer {
79 /// Return the sorted class labels discovered during fitting.
80 #[must_use]
81 pub fn classes(&self) -> &[usize] {
82 &self.classes
83 }
84
85 /// Return the number of unique classes.
86 #[must_use]
87 pub fn n_classes(&self) -> usize {
88 self.classes.len()
89 }
90
91 /// Map a multi-hot indicator matrix back to label sets.
92 ///
93 /// The indicator matrix must contain only exact `0.0` and `1.0` values; a
94 /// class is included for a sample iff its cell is exactly `1.0`. This
95 /// mirrors scikit-learn 1.5.2 `MultiLabelBinarizer.inverse_transform`
96 /// (`sklearn/preprocessing/_label.py:941-947`), which validates the matrix
97 /// with `np.setdiff1d(yt, [0, 1])` and raises `ValueError` on any value
98 /// outside `{0, 1}` before selecting classes where the cell `== 1`.
99 ///
100 /// # Errors
101 ///
102 /// Returns [`FerroError::ShapeMismatch`] if the number of columns does
103 /// not match the number of classes. Returns [`FerroError::InvalidParameter`]
104 /// if any cell value is not exactly `0.0` or `1.0`.
105 #[allow(
106 clippy::float_cmp,
107 reason = "indicator matrix must be exactly 0/1 per sklearn _label.py:941-947"
108 )]
109 pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Vec<Vec<usize>>, FerroError> {
110 let k = self.classes.len();
111 if y.ncols() != k {
112 return Err(FerroError::ShapeMismatch {
113 expected: vec![y.nrows(), k],
114 actual: vec![y.nrows(), y.ncols()],
115 context: "FittedMultiLabelBinarizer::inverse_transform".into(),
116 });
117 }
118
119 // Validate the indicator contains only 0s and 1s, matching sklearn's
120 // `np.setdiff1d(yt, [0, 1])` check (_label.py:941-947).
121 if let Some(&v) = y.iter().find(|&&v| v != 0.0 && v != 1.0) {
122 return Err(FerroError::InvalidParameter {
123 name: "y".into(),
124 reason: format!("Expected only 0s and 1s in label indicator, got {v}"),
125 });
126 }
127
128 let n = y.nrows();
129 let mut result = Vec::with_capacity(n);
130
131 for i in 0..n {
132 let mut labels = Vec::new();
133 for (j, &cls) in self.classes.iter().enumerate() {
134 if y[[i, j]] == 1.0 {
135 labels.push(cls);
136 }
137 }
138 result.push(labels);
139 }
140
141 Ok(result)
142 }
143}
144
145// ---------------------------------------------------------------------------
146// Trait implementations
147// ---------------------------------------------------------------------------
148
149impl Fit<Vec<Vec<usize>>, ()> for MultiLabelBinarizer {
150 type Fitted = FittedMultiLabelBinarizer;
151 type Error = FerroError;
152
153 /// Fit the binarizer by discovering all unique labels.
154 ///
155 /// An empty input (no samples) is accepted and yields an empty `classes_`,
156 /// mirroring sklearn `MultiLabelBinarizer.fit` where
157 /// `classes_ = sorted(set(itertools.chain.from_iterable(y)))` is the empty
158 /// set for `y == []` (`sklearn/preprocessing/_label.py:779`); sklearn raises
159 /// no error and a subsequent `transform([[]])` yields a `(1, 0)` matrix.
160 ///
161 /// # Errors
162 ///
163 /// This method does not return an error in the `usize` domain (kept as
164 /// `Result` for `Fit`-trait conformance and forward compatibility).
165 fn fit(
166 &self,
167 y: &Vec<Vec<usize>>,
168 _target: &(),
169 ) -> Result<FittedMultiLabelBinarizer, FerroError> {
170 let mut classes: Vec<usize> = y.iter().flatten().copied().collect();
171 classes.sort_unstable();
172 classes.dedup();
173
174 Ok(FittedMultiLabelBinarizer { classes })
175 }
176}
177
178impl Transform<Vec<Vec<usize>>> for FittedMultiLabelBinarizer {
179 type Output = Array2<f64>;
180 type Error = FerroError;
181
182 /// Transform label sets into a multi-hot indicator matrix.
183 ///
184 /// Each row has a `1.0` in every column corresponding to one of its labels
185 /// and `0.0` elsewhere.
186 ///
187 /// Labels not seen during fitting are silently ignored: the indicator is
188 /// built only from known labels (mirroring scikit-learn 1.5.2
189 /// `MultiLabelBinarizer._transform`, `sklearn/preprocessing/_label.py:889-902`).
190 /// scikit-learn additionally emits a `warnings.warn("unknown class(es) ...
191 /// will be ignored")`; that warning is intentionally not emitted here because
192 /// the crate has no logging facade and adding one would be out of scope.
193 ///
194 /// The [`Result`] return type is retained because the [`Transform`] trait
195 /// requires it; `transform` always returns [`Ok`].
196 fn transform(&self, y: &Vec<Vec<usize>>) -> Result<Array2<f64>, FerroError> {
197 let k = self.classes.len();
198 let n = y.len();
199
200 // Build lookup: class_value → column index
201 let class_to_idx: std::collections::HashMap<usize, usize> = self
202 .classes
203 .iter()
204 .enumerate()
205 .map(|(i, &c)| (c, i))
206 .collect();
207
208 let mut out = Array2::zeros((n, k));
209
210 for (i, labels) in y.iter().enumerate() {
211 for &label in labels {
212 // Unknown labels (not seen during fit) are silently ignored,
213 // matching scikit-learn's `_transform` (_label.py:889-902).
214 if let Some(&idx) = class_to_idx.get(&label) {
215 out[[i, idx]] = 1.0;
216 }
217 }
218 }
219
220 Ok(out)
221 }
222}
223
224// ===========================================================================
225// Tests
226// ===========================================================================
227
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Classes come back sorted even when the input lists them out of order.
    #[test]
    fn test_fit_discovers_sorted_classes() {
        let fitted = MultiLabelBinarizer::new()
            .fit(&vec![vec![2, 0], vec![1]], &())
            .unwrap();
        assert_eq!(fitted.classes(), &[0, 1, 2]);
    }

    /// Fitting on zero samples succeeds and yields zero classes.
    #[test]
    fn test_fit_empty_input_yields_empty_classes() {
        // sklearn 1.5.2 accepts `MultiLabelBinarizer().fit([])`: `classes_`
        // is `[]` because `sorted(set(chain.from_iterable([])))` is empty
        // (`_label.py:779`), and `transform([[]])` has shape `(1, 0)`.
        // Live oracle (from /tmp):
        //   mlb = MultiLabelBinarizer().fit([]); mlb.classes_.tolist() -> []
        //   mlb.transform([[]]).shape -> (1, 0)
        let no_samples: Vec<Vec<usize>> = Vec::new();
        let shape = MultiLabelBinarizer::new()
            .fit(&no_samples, &())
            .and_then(|fitted| fitted.transform(&vec![Vec::new()]))
            .map(|m| m.dim());
        assert_eq!(shape.ok(), Some((1, 0)));
    }

    /// Every cell of the indicator matches the sample's label set.
    #[test]
    fn test_transform_multi_hot() {
        let y = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        // Rows encode {0, 2}, {1} and {0, 1, 2} respectively.
        let expected = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
        assert_eq!(mat, expected);
    }

    /// A label unseen at fit time is skipped rather than reported as an error.
    #[test]
    fn test_transform_unknown_label_ignored() {
        // Live oracle (sklearn 1.5.2):
        //   python3 -c "from sklearn.preprocessing import MultiLabelBinarizer; \
        //     import warnings; warnings.simplefilter('ignore'); \
        //     mlb=MultiLabelBinarizer().fit([[0,1]]); \
        //     print(mlb.transform([[0,5]]).tolist())"
        //   => [[1, 0]]
        // Unknown labels are dropped, not rejected (_label.py:889-902).
        let got = MultiLabelBinarizer::new()
            .fit(&vec![vec![0, 1]], &())
            // 5 is not in {0, 1}; transform must succeed and ignore it.
            .and_then(|fitted| fitted.transform(&vec![vec![0, 5]]))
            .map_err(|e| format!("{e:?}"));
        assert_eq!(got, Ok(array![[1.0, 0.0]]));
    }

    /// transform followed by inverse_transform reproduces the input.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let y = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let recovered = fitted
            .transform(&y)
            .and_then(|mat| fitted.inverse_transform(&mat))
            .unwrap();
        assert_eq!(recovered, y);
    }

    /// A matrix with the wrong number of columns is rejected.
    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let fitted = MultiLabelBinarizer::new()
            .fit(&vec![vec![0, 1, 2]], &())
            .unwrap();
        // Three classes were fitted, so a two-column matrix cannot be decoded.
        let outcome = fitted.inverse_transform(&Array2::<f64>::zeros((2, 2)));
        assert!(outcome.is_err());
    }

    /// A sample without labels becomes an all-zero row.
    #[test]
    fn test_empty_label_set() {
        // The second sample carries no labels at all.
        let y = vec![vec![0, 1], vec![]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.dim(), (2, 2));
        assert_eq!(mat.row(1).to_vec(), vec![0.0, 0.0]);
    }

    /// An all-zero row decodes back to an empty label set.
    #[test]
    fn test_inverse_transform_empty_row() {
        let y = vec![vec![0, 1], vec![]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let recovered = fitted
            .transform(&y)
            .and_then(|mat| fitted.inverse_transform(&mat))
            .unwrap();
        assert_eq!(recovered, y);
    }

    /// Columns follow the sorted class order, not the label values themselves.
    #[test]
    fn test_non_contiguous_classes() {
        let y = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[10, 20, 30]);

        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.dim(), (2, 3));
        // Sample 0 has 10 and 30 but not 20.
        assert_eq!(mat.row(0).to_vec(), vec![1.0, 0.0, 1.0]);
    }

    /// Round-tripping also works when class values are sparse.
    #[test]
    fn test_inverse_transform_non_contiguous_roundtrip() {
        let y = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let recovered = fitted
            .transform(&y)
            .and_then(|mat| fitted.inverse_transform(&mat))
            .unwrap();
        assert_eq!(recovered, y);
    }

    /// A label repeated within one sample still yields a plain 0/1 indicator.
    #[test]
    fn test_duplicate_labels_in_input() {
        // Label 0 appears twice in the only sample.
        let y = vec![vec![0, 0, 1]];
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        // Shape (1, 2) with both cells set; duplicates are not double-counted.
        assert_eq!(mat, array![[1.0, 1.0]]);
    }

    /// Values outside {0, 1} are rejected; valid indicators still decode.
    #[test]
    fn test_inverse_rejects_non_01() {
        // sklearn 1.5.2 checks the indicator with `np.setdiff1d(yt, [0, 1])`
        // and raises `ValueError` for any value outside {0, 1}
        // (sklearn/preprocessing/_label.py:941-947). It does NOT threshold.
        //
        // Live oracle (sklearn 1.5.2), valid 0/1 round-trip (R-CHAR-3):
        //   python3 -c "import numpy as np; \
        //     from sklearn.preprocessing import MultiLabelBinarizer; \
        //     mlb=MultiLabelBinarizer().fit([[0,1,2]]); \
        //     print(mlb.inverse_transform(np.array([[1,0,1]])))"
        //   => [(0, 2)]  == vec![vec![0, 2]]
        let fitted = MultiLabelBinarizer::new()
            .fit(&vec![vec![0, 1, 2]], &())
            .map_err(|e| format!("{e:?}"));

        // Fractional cells must be refused (sklearn raises ValueError).
        let bad = array![[0.4, 0.6, 0.5]];
        let rejected = match &fitted {
            Ok(f) => Ok(f.inverse_transform(&bad).is_err()),
            Err(e) => Err(e.clone()),
        };
        assert_eq!(rejected, Ok(true));

        // A strict 0/1 indicator decodes to the live-oracle result.
        let good = array![[1.0, 0.0, 1.0]];
        let recovered =
            fitted.and_then(|f| f.inverse_transform(&good).map_err(|e| format!("{e:?}")));
        assert_eq!(recovered, Ok(vec![vec![0, 2]]));
    }
}