ferrolearn_preprocess/label_binarizer.rs
1//! One-vs-rest label binarizer.
2//!
3//! Transforms a vector of integer class labels into a binary indicator matrix.
4//! For *K* classes the output has *K* columns (one-hot rows), except in the
5//! binary case (*K* = 2) and single-class case (*K* = 1) where a single column
6//! is produced.
7//!
8//! Translation target: scikit-learn 1.5.2 `class LabelBinarizer`
9//! (`sklearn/preprocessing/_label.py:180`) + `label_binarize` (`:430`). Design:
10//! `.design/preprocess/label_binarizer.md`. Tracking: #1238.
11//!
12//! `## REQ status`
13//!
14//! | REQ | Status | Anchor |
15//! |---|---|---|
16//! | REQ-1 fit → sorted-unique classes_ (usize) | SHIPPED | `LabelBinarizer::fit`; sklearn `_label.py:306` |
17//! | REQ-2 transform multiclass (k≥3) one-hot values | SHIPPED | `FittedLabelBinarizer::transform` else-branch; sklearn `_label.py:552-577` |
18//! | REQ-3 transform binary (k=2) single col, pos_label on 2nd class | SHIPPED | `transform` k==2 branch; sklearn `_label.py:531`,`:592-596` |
19//! | REQ-4 transform unknown-label: ignore (all-neg_label row) | SHIPPED (#1239) | `transform` `if let Some(&idx) = class_to_idx.get`; sklearn `_label.py:556-559` |
20//! | REQ-5 transform single-class (k=1) → all-neg_label column | SHIPPED (#1240) | `transform` k==1 arm `Array2::from_elem`; sklearn `_label.py:532-538` |
21//! | REQ-6 inverse_transform binary STRICT threshold (`> (pos+neg)/2`) | SHIPPED (#1241) | `inverse_transform` k==2 branch; sklearn `_label.py:667` |
22//! | REQ-6b inverse_transform binary accepts 1-col AND 2-col indicator (dispatch on fitted type, not col count) | SHIPPED (#2340) | `inverse_transform` k==2 branch accepts `ncols ∈ {1,2}`, decodes `classes[last_col > threshold ? 1 : 0]` (2-col → `classes[y[:,1]]`, 1-col → `classes[y.ravel()]`); sklearn `_label.py:402-407`,`:647`,`:670-679` |
23//! | REQ-7 inverse_transform multiclass argmax | SHIPPED | `inverse_transform` else-branch; sklearn `_label.py:641` |
24//! | REQ-8 neg_label/pos_label ctor params + validation | SHIPPED (#1242) | `LabelBinarizer::with_neg_label`/`with_pos_label` + `Fit::fit` validation; `transform` neg/pos base+active; `inverse_transform` `(pos+neg)/2` threshold; consumer crate re-export `lib.rs`; sklearn `_label.py:263`,`:283-287`,`:579-583`,`:667` |
25//! | REQ-9 sparse_output CSR + constraint | NOT-STARTED (#1243) | sklearn `_label.py:563`,`:584-585`,`:289-294` |
26//! | REQ-10 `label_binarize` free function | SHIPPED (#1244) | `pub fn label_binarize` (this file): `neg<pos` validation (verbatim msg, sklearn `_label.py:499-504`); GIVEN-order columns (sklearn's "preserve label ordering" reorder `:587-590`, so `label_binarize([0,2,1],classes=[2,0,1])` → `[[0,1,0],[1,0,0],[0,0,1]]`); k==1 all-neg col (`:532-538`); single-col collapse gated on `type_of_target(y)=="binary"` AND `len(classes)==2` (NOT `len(classes)` alone — `:519`,`:531`,`:592-596`; "binary" = ≤2 distinct values for 1D int y, verified live, #2233), giving pos where `y==classes[1]` (the kept `Y[:,-1]` after reorder, `:596`); k==2 with multiclass y (3+ distinct) emits 2 cols, no collapse; k>2 one-hot in given order (`:552-577`); unseen label → all-neg row (`:556-559`). Consumer: crate re-export `lib.rs` (`pub use label_binarizer::label_binarize`). Live-oracle parity: `tests/divergence_label_binarizer.rs` (basic/neg-pos/binary/`[2,0,1]`-ordering/unseen/neg≥pos-err/==estimator). |
27//! | REQ-11 arbitrary label types + type_of_target/multilabel input | NOT-STARTED (#1245) | sklearn `_label.py:296`,`:543-550` (usize-only, R-DEV-3) |
28//! | REQ-12 PyO3 binding | NOT-STARTED (#1246) | `ferrolearn-python/src/` (absent) |
29//!
30//! # Examples
31//!
32//! ```
33//! use ferrolearn_preprocess::label_binarizer::LabelBinarizer;
34//! use ferrolearn_core::traits::{Fit, Transform};
35//! use ndarray::array;
36//!
37//! let lb = LabelBinarizer::new();
38//! let y = array![0_usize, 1, 2, 1];
39//! let fitted = lb.fit(&y, &()).unwrap();
40//! let mat = fitted.transform(&y).unwrap();
41//! // 3 classes → (4, 3) indicator matrix
42//! assert_eq!(mat.shape(), &[4, 3]);
43//! assert_eq!(mat[[0, 0]], 1.0);
44//! assert_eq!(mat[[0, 1]], 0.0);
45//! ```
46
47use ferrolearn_core::error::FerroError;
48use ferrolearn_core::traits::{Fit, Transform};
49use ndarray::{Array1, Array2};
50
51// ---------------------------------------------------------------------------
52// LabelBinarizer (unfitted)
53// ---------------------------------------------------------------------------
54
/// An unfitted one-vs-rest label binarizer.
///
/// Calling [`Fit::fit`] on an `Array1<usize>` discovers the sorted set of
/// unique class labels and returns a [`FittedLabelBinarizer`].
///
/// `neg_label` / `pos_label` are the integer values written into the output
/// indicator matrix for absent / present classes, mirroring sklearn's
/// `LabelBinarizer(neg_label=0, pos_label=1)` (`sklearn/preprocessing/_label.py:263`).
/// The defaults `0` / `1` reproduce the canonical 0/1 indicator behavior.
///
/// The type is plain configuration (two integers), so it derives
/// `PartialEq`/`Eq` in addition to `Debug`/`Clone`; two binarizers compare
/// equal exactly when both labels match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelBinarizer {
    /// Value written for absent classes (sklearn `neg_label`, default `0`).
    neg_label: i64,
    /// Value written for the present class (sklearn `pos_label`, default `1`).
    pos_label: i64,
}

impl Default for LabelBinarizer {
    /// Same configuration as [`LabelBinarizer::new`] (`neg_label=0`, `pos_label=1`).
    fn default() -> Self {
        Self::new()
    }
}

impl LabelBinarizer {
    /// Create a new `LabelBinarizer` with the default `neg_label=0`,
    /// `pos_label=1` (the canonical 0/1 indicator encoding).
    #[must_use]
    pub fn new() -> Self {
        Self {
            neg_label: 0,
            pos_label: 1,
        }
    }

    /// Set the `neg_label` (value used for absent classes), keeping the
    /// current `pos_label`.
    ///
    /// Mirrors sklearn's `LabelBinarizer(neg_label=...)`
    /// (`sklearn/preprocessing/_label.py:263`). Must be strictly less than
    /// `pos_label`; that is NOT checked here but at [`Fit::fit`] time
    /// (`_label.py:283-287`), so builder calls may be chained in any order.
    #[must_use]
    pub fn with_neg_label(self, neg_label: i64) -> Self {
        Self { neg_label, ..self }
    }

    /// Set the `pos_label` (value used for the present class), keeping the
    /// current `neg_label`.
    ///
    /// Mirrors sklearn's `LabelBinarizer(pos_label=...)`
    /// (`sklearn/preprocessing/_label.py:263`). Must be strictly greater than
    /// `neg_label`; that is NOT checked here but at [`Fit::fit`] time
    /// (`_label.py:283-287`), so builder calls may be chained in any order.
    #[must_use]
    pub fn with_pos_label(self, pos_label: i64) -> Self {
        Self { pos_label, ..self }
    }

    /// Return the configured `neg_label`.
    #[must_use]
    pub fn neg_label(&self) -> i64 {
        self.neg_label
    }

    /// Return the configured `pos_label`.
    #[must_use]
    pub fn pos_label(&self) -> i64 {
        self.pos_label
    }
}
123
124// ---------------------------------------------------------------------------
125// FittedLabelBinarizer
126// ---------------------------------------------------------------------------
127
/// A fitted label binarizer holding the discovered class set.
///
/// Created by calling [`Fit::fit`] on a [`LabelBinarizer`].
///
/// Derives `PartialEq`/`Eq` in addition to `Debug`/`Clone`: two fitted
/// binarizers compare equal when they hold the same class list and the same
/// `neg_label` / `pos_label`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FittedLabelBinarizer {
    /// Sorted unique class labels observed during fitting.
    classes: Vec<usize>,
    /// Value written for absent classes (sklearn `neg_label`, default `0`).
    neg_label: i64,
    /// Value written for the present class (sklearn `pos_label`, default `1`).
    pos_label: i64,
}
140
141impl FittedLabelBinarizer {
142 /// Return the sorted class labels discovered during fitting.
143 #[must_use]
144 pub fn classes(&self) -> &[usize] {
145 &self.classes
146 }
147
148 /// Return the number of unique classes.
149 #[must_use]
150 pub fn n_classes(&self) -> usize {
151 self.classes.len()
152 }
153
154 /// Return the configured `neg_label` (value used for absent classes).
155 #[must_use]
156 pub fn neg_label(&self) -> i64 {
157 self.neg_label
158 }
159
160 /// Return the configured `pos_label` (value used for the present class).
161 #[must_use]
162 pub fn pos_label(&self) -> i64 {
163 self.pos_label
164 }
165
166 /// Map a binary indicator matrix back to integer class labels.
167 ///
168 /// Dispatch follows sklearn's `LabelBinarizer.inverse_transform`, which
169 /// branches on the FITTED `y_type_` ("binary" vs "multiclass"), NOT on the
170 /// column count of `Y` (`sklearn/preprocessing/_label.py:402-407`). Here the
171 /// fitted type is "binary" iff exactly two classes were discovered
172 /// (`k == 2`):
173 ///
174 /// - **Multiclass** (`k != 2`): the class with the largest value (argmax)
175 /// per row, mirroring `_inverse_binarize_multiclass`
176 /// (`classes.take(Y.argmax(axis=1))`, `_label.py:641`). Requires exactly
177 /// *K* columns.
178 /// - **Binary** (`k == 2`): `_inverse_binarize_thresholding` thresholds the
179 /// indicator with a STRICT `y > threshold`
180 /// (`threshold = (pos_label + neg_label) / 2`, `_label.py:399-400`,`:667`)
181 /// then decodes (`_label.py:670-679`):
182 /// - a **1-column** indicator → `classes[col0 > threshold ? 1 : 0]`
183 /// (`classes[y.ravel()]`, `:679`);
184 /// - a **2-column** indicator → `classes[col1 > threshold ? 1 : 0]`
185 /// (`classes[y[:, 1]]`, `:673-674`): the SECOND column (after
186 /// thresholding) selects the positive class; the first column is
187 /// ignored. So `fit([10,20]).inverse_transform([[1,0],[0,1]])` →
188 /// `[10, 20]` (verified vs the live sklearn 1.5.2 oracle).
189 ///
190 /// # Errors
191 ///
192 /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
193 /// match an accepted width: a binary (`k == 2`) fitted binarizer accepts
194 /// BOTH 1 and 2 columns (`_label.py:647`,`:670-679`); otherwise exactly *K*
195 /// columns are required.
196 pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Array1<usize>, FerroError> {
197 let k = self.classes.len();
198 let n = y.nrows();
199 let mut result = Array1::zeros(n);
200
201 if k == 2 {
202 // Binary fitted type: sklearn dispatches to
203 // `_inverse_binarize_thresholding` on the fitted `y_type_ == "binary"`,
204 // which accepts EITHER a 1-column or a 2-column indicator and rejects
205 // wider ones (`_label.py:647`: `y.shape[1] > 2` raises; `:670-679`:
206 // 2-col → `classes[y[:, 1]]`, else `classes[y.ravel()]`).
207 let ncols = y.ncols();
208 if ncols != 1 && ncols != 2 {
209 return Err(FerroError::ShapeMismatch {
210 expected: vec![y.nrows(), 1],
211 actual: vec![y.nrows(), ncols],
212 context: "FittedLabelBinarizer::inverse_transform".into(),
213 });
214 }
215
216 // Strict threshold at `(pos_label + neg_label) / 2`, matching sklearn
217 // `_inverse_binarize_thresholding` (`_label.py:667`): `y = np.array(y >
218 // threshold)` with default `threshold = (pos_label + neg_label) / 2`
219 // (`:399-400`). STRICT, so an exact-threshold value maps to `classes[0]`.
220 // With the default `neg_label=0, pos_label=1` this reduces to `> 0.5`.
221 // Cast EACH to f64 BEFORE the add: `i64 + i64` would overflow (and
222 // panic in debug, R-CODE-2) for large-but-valid neg/pos like 2^62
223 // (#2232). sklearn computes this in arbitrary-precision then /2.0.
224 let threshold = (self.pos_label as f64 + self.neg_label as f64) / 2.0;
225 // The decisive column is the LAST one: `col1` for a 2-column indicator
226 // (`classes[y[:, 1]]`, `:673-674`, first column ignored) and `col0` for
227 // a 1-column indicator (`classes[y.ravel()]`, `:679`).
228 let decisive_col = ncols - 1;
229 for i in 0..n {
230 result[i] = if y[[i, decisive_col]] > threshold {
231 self.classes[1]
232 } else {
233 self.classes[0]
234 };
235 }
236 } else {
237 if y.ncols() != k {
238 return Err(FerroError::ShapeMismatch {
239 expected: vec![y.nrows(), k],
240 actual: vec![y.nrows(), y.ncols()],
241 context: "FittedLabelBinarizer::inverse_transform".into(),
242 });
243 }
244 // Multiclass: argmax per row
245 for i in 0..n {
246 let row = y.row(i);
247 let mut best_j = 0;
248 let mut best_v = f64::NEG_INFINITY;
249 for (j, &v) in row.iter().enumerate() {
250 if v > best_v {
251 best_v = v;
252 best_j = j;
253 }
254 }
255 result[i] = self.classes[best_j];
256 }
257 }
258
259 Ok(result)
260 }
261}
262
263// ---------------------------------------------------------------------------
264// Trait implementations
265// ---------------------------------------------------------------------------
266
267impl Fit<Array1<usize>, ()> for LabelBinarizer {
268 type Fitted = FittedLabelBinarizer;
269 type Error = FerroError;
270
271 /// Fit the binarizer by discovering unique class labels.
272 ///
273 /// # Errors
274 ///
275 /// - Returns [`FerroError::InvalidParameter`] if `neg_label >= pos_label`,
276 /// mirroring sklearn's `neg_label={0} must be strictly less than
277 /// pos_label={1}.` raise (`sklearn/preprocessing/_label.py:283-287`).
278 /// - Returns [`FerroError::InsufficientSamples`] if the input is empty.
279 fn fit(&self, y: &Array1<usize>, _target: &()) -> Result<FittedLabelBinarizer, FerroError> {
280 // Validate neg_label < pos_label BEFORE class discovery, mirroring
281 // sklearn `fit` (`_label.py:283-287`): the message is verbatim
282 // `neg_label={neg} must be strictly less than pos_label={pos}.`.
283 if self.neg_label >= self.pos_label {
284 return Err(FerroError::InvalidParameter {
285 name: "neg_label".into(),
286 reason: format!(
287 "neg_label={} must be strictly less than pos_label={}.",
288 self.neg_label, self.pos_label
289 ),
290 });
291 }
292
293 if y.is_empty() {
294 return Err(FerroError::InsufficientSamples {
295 required: 1,
296 actual: 0,
297 context: "LabelBinarizer::fit".into(),
298 });
299 }
300
301 let mut classes: Vec<usize> = y.iter().copied().collect();
302 classes.sort_unstable();
303 classes.dedup();
304
305 Ok(FittedLabelBinarizer {
306 classes,
307 neg_label: self.neg_label,
308 pos_label: self.pos_label,
309 })
310 }
311}
312
313impl Transform<Array1<usize>> for FittedLabelBinarizer {
314 type Output = Array2<f64>;
315 type Error = FerroError;
316
317 /// Transform labels into a binary indicator matrix.
318 ///
319 /// - For *K* = 2 classes the output shape is `(n, 1)`.
320 /// - For *K* > 2 classes the output shape is `(n, K)`.
321 ///
322 /// Absent classes are written as `neg_label` and the present class as
323 /// `pos_label` (defaults `0` / `1`). Labels not seen during fitting are
324 /// silently ignored: their row is left at the `neg_label` base value, with
325 /// no error and no warning. This mirrors scikit-learn's `label_binarize`
326 /// (`sklearn/preprocessing/_label.py:556-559`), which selects only the
327 /// known labels (`y_in_classes = np.isin(y, classes)`) and leaves unseen
328 /// labels contributing nothing, then fills the dense base with `neg_label`
329 /// (`:579-583`).
330 fn transform(&self, y: &Array1<usize>) -> Result<Array2<f64>, FerroError> {
331 let k = self.classes.len();
332 let n = y.len();
333
334 // sklearn `LabelBinarizer.transform` delegates to `label_binarize`, which
335 // gates the binary single-column collapse on `type_of_target(y)=="binary"`
336 // (`_label.py:519`,`:531`) — computed on the TRANSFORM input `y`, NOT the
337 // fitted class count (#2234). For 1D integer `y`, "binary" means at most 2
338 // distinct values; a MULTICLASS transform input (3+ distinct) with 2
339 // fitted classes therefore emits the (n, 2) multi-column form, not the
340 // single column (e.g. fit([0,1]).transform([0,1,2]) -> [[1,0],[0,1],[0,0]]).
341 let y_is_binary = {
342 let mut distinct: Vec<usize> = y.iter().copied().collect();
343 distinct.sort_unstable();
344 distinct.dedup();
345 distinct.len() <= 2
346 };
347
348 // The base ("absent") value is `neg_label`; the active ("present")
349 // value is `pos_label`, mirroring sklearn `label_binarize`'s dense fill
350 // `Y[Y == 0] = neg_label` (`_label.py:579-583`) and the `pos_label`
351 // active positions (`:562`, `:599`).
352 let neg = self.neg_label as f64;
353 let pos = self.pos_label as f64;
354
355 // Build a lookup: class_value → column index
356 let class_to_idx: std::collections::HashMap<usize, usize> = self
357 .classes
358 .iter()
359 .enumerate()
360 .map(|(i, &c)| (c, i))
361 .collect();
362
363 if k == 1 {
364 // Single class (n_classes == 1): sklearn treats this as the binary
365 // degenerate case and returns an all-`neg_label` single column,
366 // never `pos_label` (`sklearn/preprocessing/_label.py:532-538`:
367 // `Y = np.zeros((len(y), 1)); Y += neg_label`).
368 Ok(Array2::from_elem((n, 1), neg))
369 } else if k == 2 && y_is_binary {
370 // Binary: single column, `pos_label` for the second class else
371 // `neg_label`. The base is filled with `neg_label` (NOT zeros).
372 // Only when the transform input is itself binary (#2234).
373 let mut out = Array2::from_elem((n, 1), neg);
374 for (i, &label) in y.iter().enumerate() {
375 // Unseen labels are silently ignored (row left at `neg_label`),
376 // mirroring sklearn `_label.py:556-559`.
377 if let Some(&idx) = class_to_idx.get(&label) {
378 out[[i, 0]] = if idx == 1 { pos } else { neg };
379 }
380 }
381 Ok(out)
382 } else {
383 // Multiclass: one-hot rows — `pos_label` at the class column,
384 // `neg_label` everywhere else. The base is filled with `neg_label`.
385 let mut out = Array2::from_elem((n, k), neg);
386 for (i, &label) in y.iter().enumerate() {
387 // Unseen labels are silently ignored (row left all-`neg_label`),
388 // mirroring sklearn `_label.py:556-559`.
389 if let Some(&idx) = class_to_idx.get(&label) {
390 out[[i, idx]] = pos;
391 }
392 }
393 Ok(out)
394 }
395 }
396}
397
398// ---------------------------------------------------------------------------
399// `label_binarize` free function (sklearn `label_binarize`, `_label.py:430`)
400// ---------------------------------------------------------------------------
401
402/// Binarize integer labels one-vs-all against an EXPLICIT class list — the
403/// standalone, estimator-less API mirroring scikit-learn's `label_binarize`
404/// free function (`sklearn/preprocessing/_label.py:430`).
405///
406/// Unlike [`LabelBinarizer`], which discovers its classes by fitting, this
407/// function takes the class set as an explicit `classes` argument and encodes
408/// `y` against it. The output is a binary indicator matrix written with
409/// `pos_label` at active positions and `neg_label` everywhere else (defaults
410/// `0` / `1`).
411///
412/// # Column ordering (the headline)
413///
414/// The output **columns follow the GIVEN `classes` order**, NOT a sorted order.
415/// sklearn builds the indicator in sorted-class order internally
416/// (`sorted_class = np.sort(classes)`, `_label.py:542`; columns via
417/// `np.searchsorted`, `:558`) but then **reorders the columns back to the given
418/// `classes` order** in the "preserve label ordering" step (`:587-590`:
419/// `indices = np.searchsorted(sorted_class, classes); Y = Y[:, indices]`).
420/// So `label_binarize([0,2,1], classes=[2,0,1])` yields
421/// `[[0,1,0],[1,0,0],[0,0,1]]` — column `j` corresponds to `classes[j]`, the
422/// *given* class, with `pos_label` where `y[i] == classes[j]`. (Verified live
423/// vs sklearn 1.5.2; see `tests/divergence_label_binarizer.rs`.)
424///
425/// # Shape / collapse rules
426///
427/// The single-column collapse is gated on `type_of_target(y) == "binary"`, NOT
428/// on `len(classes)` (`_label.py:519` `y_type = type_of_target(y)`; `:531`
429/// `if y_type == "binary":`; the collapse at `:592-596`). For 1D integer `y`,
430/// `type_of_target` is "binary" iff `y` has at most two distinct values, else
431/// "multiclass" (verified live vs sklearn 1.5.2). Writing
432/// `y_is_binary = (distinct count of y) <= 2`:
433/// - `k == 1`: a single all-`neg_label` column (`_label.py:532-538`).
434/// - `k == 2` AND `y_is_binary`: a single column — `pos_label` where
435/// `y == classes[last]` (the LAST given class), else `neg_label`. sklearn
436/// builds both columns then takes `Y[:, -1]` after the reorder
437/// (`_label.py:596`), so the kept column is the one for the last *given*
438/// class. (When `classes` is sorted — as the fitted estimator always is —
439/// `classes[last]` is the second-sorted class, so this coincides with
440/// [`FittedLabelBinarizer::transform`]'s `idx == 1` rule.)
441/// - `k == 2` but `y` is multiclass (3+ distinct values): NO collapse —
442/// `k == 2` columns in given order (`y_type` is not "binary", so the `:592`
443/// single-column step is skipped). E.g. `label_binarize([0,1,2], classes=[0,1])`
444/// → `(3, 2)` `[[1,0],[0,1],[0,0]]`, with the unseen `2` leaving an all-`neg`
445/// row.
446/// - `k > 2`: `k` columns in given order, `pos_label` at the value's column.
447///
448/// A value in `y` not present in `classes` leaves its row all-`neg_label`,
449/// mirroring sklearn's `y_in_classes = np.isin(y, classes)` silent ignore
450/// (`_label.py:556-559`).
451///
452/// `classes` is expected to be unique (sklearn: "Uniquely holds the label for
453/// each class", `_label.py:447`); duplicate entries are not part of the matched
454/// contract.
455///
456/// # Errors
457///
458/// Returns [`FerroError::InvalidParameter`] if `neg_label >= pos_label`, with
459/// the same verbatim message as [`LabelBinarizer`]'s `fit`
460/// (`_label.py:499-504`: `neg_label={neg} must be strictly less than
461/// pos_label={pos}.`). Returns [`FerroError::InsufficientSamples`] if `classes`
462/// is empty (sklearn cannot binarize against zero classes).
463#[must_use = "label_binarize returns a new indicator matrix"]
464pub fn label_binarize(
465 y: &Array1<usize>,
466 classes: &[usize],
467 neg_label: i64,
468 pos_label: i64,
469) -> Result<Array2<f64>, FerroError> {
470 // Validate neg_label < pos_label, mirroring sklearn `label_binarize`
471 // (`_label.py:499-504`) — the SAME verbatim message as the estimator's
472 // `fit` (`LabelBinarizer::fit`).
473 if neg_label >= pos_label {
474 return Err(FerroError::InvalidParameter {
475 name: "neg_label".into(),
476 reason: format!(
477 "neg_label={neg_label} must be strictly less than pos_label={pos_label}."
478 ),
479 });
480 }
481
482 let k = classes.len();
483 if k == 0 {
484 return Err(FerroError::InsufficientSamples {
485 required: 1,
486 actual: 0,
487 context: "label_binarize: classes".into(),
488 });
489 }
490
491 let n = y.len();
492 let neg = neg_label as f64;
493 let pos = pos_label as f64;
494
495 // Map each given class value to its GIVEN-order column index. The output
496 // columns follow the given `classes` order (sklearn's "preserve label
497 // ordering" reorder, `_label.py:587-590`), so column `j` belongs to
498 // `classes[j]`. For unique `classes` the last write wins identically; the
499 // contract assumes uniqueness (`_label.py:447`).
500 let class_to_col: std::collections::HashMap<usize, usize> =
501 classes.iter().enumerate().map(|(j, &c)| (c, j)).collect();
502
503 // The single-column collapse is gated on `type_of_target(y) == "binary"`,
504 // NOT on `len(classes)` (`_label.py:519` `y_type = type_of_target(y)`;
505 // `:531` `if y_type == "binary":`; the collapse itself at `:592-596`). For
506 // 1D integer `y`, `type_of_target` returns "binary" iff `y` has at most two
507 // distinct values, else "multiclass" (verified live vs sklearn 1.5.2:
508 // 1-distinct → "binary", 2-distinct → "binary", 3+ distinct → "multiclass").
509 // So `[5,5]` (1 distinct) and `[0,1,0]` (2 distinct) are binary, but
510 // `[0,1,2]` (3 distinct) is multiclass. When `k == 2` but `y` is multiclass,
511 // sklearn promotes to the `n_classes`-column form (`:539-540` only fires for
512 // `len(classes) >= 3`; here the non-binary `y_type` simply means the `:592`
513 // collapse is skipped), giving a `(n, 2)` indicator.
514 let mut distinct: Vec<usize> = y.iter().copied().collect();
515 distinct.sort_unstable();
516 distinct.dedup();
517 let y_is_binary = distinct.len() <= 2;
518
519 if k == 1 {
520 // n_classes == 1: all-`neg_label` single column (`_label.py:532-538`:
521 // `Y = np.zeros((len(y), 1)); Y += neg_label`). sklearn reaches this only
522 // when `y_type == "binary"` too; for plain integer `y` a single class
523 // implies `y` has ≤1 distinct value, which is always binary, so this
524 // single-column form is unconditional here.
525 Ok(Array2::from_elem((n, 1), neg))
526 } else if k == 2 && y_is_binary {
527 // Binary `y` with exactly two classes: the single column kept after the
528 // given-order reorder is `Y[:, -1]` (`_label.py:596`) — the column for
529 // the LAST given class. So `pos_label` where `y == classes[1]`, else
530 // `neg_label`. Unseen labels (not in `classes`) stay at `neg_label`
531 // (`:556-559`).
532 let last_class = classes[1];
533 let mut out = Array2::from_elem((n, 1), neg);
534 for (i, &label) in y.iter().enumerate() {
535 if label == last_class {
536 out[[i, 0]] = pos;
537 }
538 }
539 Ok(out)
540 } else {
541 // `k` columns in GIVEN order, `pos_label` at the value's column
542 // (`_label.py:552-577` + the `:587-590` reorder to the given order).
543 // Reached for genuine multiclass (`k > 2`) AND for `k == 2` with a
544 // multiclass `y` (3+ distinct values), where sklearn skips the `:592`
545 // single-column collapse and emits the full `(n, k)` indicator. Unseen
546 // labels leave the row all-`neg_label` (`:556-559`).
547 let mut out = Array2::from_elem((n, k), neg);
548 for (i, &label) in y.iter().enumerate() {
549 if let Some(&col) = class_to_col.get(&label) {
550 out[[i, col]] = pos;
551 }
552 }
553 Ok(out)
554 }
555}
556
557// ===========================================================================
558// Tests
559// ===========================================================================
560
561#[cfg(test)]
562mod tests {
563 use super::*;
564 use ndarray::array;
565
566 #[test]
567 fn test_fit_discovers_sorted_classes() {
568 let lb = LabelBinarizer::new();
569 let y = array![2_usize, 0, 1, 2, 0];
570 let fitted = lb.fit(&y, &()).unwrap();
571 assert_eq!(fitted.classes(), &[0, 1, 2]);
572 }
573
574 #[test]
575 fn test_fit_empty_input_error() {
576 let lb = LabelBinarizer::new();
577 let y: Array1<usize> = Array1::zeros(0);
578 assert!(lb.fit(&y, &()).is_err());
579 }
580
581 #[test]
582 fn test_binary_transform_single_column() {
583 let lb = LabelBinarizer::new();
584 let y = array![0_usize, 1, 0, 1];
585 let fitted = lb.fit(&y, &()).unwrap();
586 let mat = fitted.transform(&y).unwrap();
587 assert_eq!(mat.shape(), &[4, 1]);
588 assert_eq!(mat[[0, 0]], 0.0); // class 0 → 0
589 assert_eq!(mat[[1, 0]], 1.0); // class 1 → 1
590 assert_eq!(mat[[2, 0]], 0.0);
591 assert_eq!(mat[[3, 0]], 1.0);
592 }
593
594 #[test]
595 fn test_multiclass_transform_indicator_matrix() {
596 let lb = LabelBinarizer::new();
597 let y = array![0_usize, 1, 2, 1];
598 let fitted = lb.fit(&y, &()).unwrap();
599 let mat = fitted.transform(&y).unwrap();
600 assert_eq!(mat.shape(), &[4, 3]);
601 // Row 0: class 0 → [1, 0, 0]
602 assert_eq!(mat[[0, 0]], 1.0);
603 assert_eq!(mat[[0, 1]], 0.0);
604 assert_eq!(mat[[0, 2]], 0.0);
605 // Row 2: class 2 → [0, 0, 1]
606 assert_eq!(mat[[2, 0]], 0.0);
607 assert_eq!(mat[[2, 1]], 0.0);
608 assert_eq!(mat[[2, 2]], 1.0);
609 }
610
611 #[test]
612 fn test_inverse_transform_multiclass() {
613 let lb = LabelBinarizer::new();
614 let y = array![0_usize, 1, 2, 1];
615 let fitted = lb.fit(&y, &()).unwrap();
616 let mat = fitted.transform(&y).unwrap();
617 let recovered = fitted.inverse_transform(&mat).unwrap();
618 assert_eq!(recovered, y);
619 }
620
621 #[test]
622 fn test_inverse_transform_binary() {
623 let lb = LabelBinarizer::new();
624 let y = array![0_usize, 1, 0, 1];
625 let fitted = lb.fit(&y, &()).unwrap();
626 let mat = fitted.transform(&y).unwrap();
627 let recovered = fitted.inverse_transform(&mat).unwrap();
628 assert_eq!(recovered, y);
629 }
630
631 /// Unseen labels are silently ignored (row left all-neg_label), mirroring
632 /// sklearn `label_binarize` (`_label.py:556-559`).
633 ///
634 /// Live oracle (sklearn 1.5.2, from /tmp):
635 /// `LabelBinarizer().fit([0,1,2]).transform([0,3]).tolist()`
636 /// -> `[[1, 0, 0], [0, 0, 0]]`
637 #[test]
638 fn test_transform_unknown_label_ignored() {
639 let lb = LabelBinarizer::new();
640 let y = array![0_usize, 1, 2];
641 let y2 = array![0_usize, 3]; // 3 not in {0,1,2}
642 // Fit then transform, propagating any error into the Result we compare.
643 let got = lb.fit(&y, &()).and_then(|fitted| fitted.transform(&y2));
644 // sklearn-oracle value: [[1,0,0],[0,0,0]] (label 3 ignored, all-zero row).
645 // Compare via Option (FerroError is not PartialEq); Ok(_) is required.
646 let expected: Array2<f64> = array![[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
647 assert_eq!(got.ok(), Some(expected));
648 }
649
650 #[test]
651 fn test_inverse_transform_shape_mismatch() {
652 let lb = LabelBinarizer::new();
653 let y = array![0_usize, 1, 2];
654 let fitted = lb.fit(&y, &()).unwrap();
655 // 3 classes expects 3 columns, but we give 2
656 let bad = Array2::<f64>::zeros((2, 2));
657 assert!(fitted.inverse_transform(&bad).is_err());
658 }
659
660 /// Single class (n_classes == 1) → an all-zero single column, mirroring
661 /// sklearn's binary-degenerate case (`_label.py:532-538`).
662 ///
663 /// Live oracle (sklearn 1.5.2):
664 /// `LabelBinarizer().fit_transform([5,5,5]).tolist()` -> `[[0],[0],[0]]`
665 #[test]
666 fn test_single_class() {
667 let lb = LabelBinarizer::new();
668 let y = array![5_usize, 5, 5];
669 // Confirm exactly one class is discovered (degenerate single-class case).
670 let n_classes = lb.fit(&y, &()).map(|fitted| fitted.n_classes());
671 assert_eq!(n_classes.ok(), Some(1));
672 // Fit then transform, propagating any error into the Result we compare.
673 let got = lb.fit(&y, &()).and_then(|fitted| fitted.transform(&y));
674 // 1 class → 1 column, all zeros (never 1.0); sklearn-oracle [[0],[0],[0]].
675 let expected: Array2<f64> = array![[0.0], [0.0], [0.0]];
676 assert_eq!(got.ok(), Some(expected));
677 }
678
679 #[test]
680 fn test_non_contiguous_classes() {
681 let lb = LabelBinarizer::new();
682 let y = array![10_usize, 20, 30, 10];
683 let fitted = lb.fit(&y, &()).unwrap();
684 assert_eq!(fitted.classes(), &[10, 20, 30]);
685 let mat = fitted.transform(&y).unwrap();
686 assert_eq!(mat.shape(), &[4, 3]);
687 assert_eq!(mat[[0, 0]], 1.0); // 10 → col 0
688 assert_eq!(mat[[1, 1]], 1.0); // 20 → col 1
689 assert_eq!(mat[[2, 2]], 1.0); // 30 → col 2
690 }
691
692 #[test]
693 fn test_roundtrip_multiclass_non_contiguous() {
694 let lb = LabelBinarizer::new();
695 let y = array![10_usize, 20, 30, 20];
696 let fitted = lb.fit(&y, &()).unwrap();
697 let mat = fitted.transform(&y).unwrap();
698 let recovered = fitted.inverse_transform(&mat).unwrap();
699 assert_eq!(recovered, y);
700 }
701
702 // -- REQ-8: neg_label / pos_label ctor params + validation ----------------
703
704 /// REQ-8: builders + getters carry the configured neg/pos through fit.
705 /// Defaults preserve the canonical 0/1 encoding.
706 #[test]
707 fn test_neg_pos_label_builders_and_getters() {
708 let lb = LabelBinarizer::new();
709 assert_eq!(lb.neg_label(), 0);
710 assert_eq!(lb.pos_label(), 1);
711
712 let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(2);
713 assert_eq!(lb.neg_label(), -1);
714 assert_eq!(lb.pos_label(), 2);
715
716 let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
717 assert_eq!(fitted.neg_label(), -1);
718 assert_eq!(fitted.pos_label(), 2);
719 }
720
721 /// REQ-8: multiclass transform with neg_label=-1, pos_label=2.
722 /// Live oracle (sklearn 1.5.2, from /tmp):
723 /// `LabelBinarizer(neg_label=-1,pos_label=2).fit([0,1,2]).transform([0,2]).tolist()`
724 /// -> `[[2,-1,-1],[-1,-1,2]]` (present->2, absent->-1; base is -1, not 0)
725 #[test]
726 fn test_neg_pos_multiclass_transform() {
727 let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(2);
728 let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
729 let got = fitted.transform(&array![0_usize, 2]).unwrap();
730 let expected: Array2<f64> = array![[2.0, -1.0, -1.0], [-1.0, -1.0, 2.0]];
731 assert_eq!(got, expected);
732 }
733
/// REQ-8: two fitted classes with neg_label=-1, pos_label=1 produce a single
/// column: the second class maps to pos_label, the first to neg_label.
/// Live oracle (sklearn 1.5.2, from /tmp):
///   `LabelBinarizer(neg_label=-1,pos_label=1).fit([0,1]).transform([0,1,0]).tolist()`
///   -> `[[-1],[1],[-1]]`
#[test]
fn test_neg_pos_binary_transform() {
    let fitted = LabelBinarizer::new()
        .with_neg_label(-1)
        .with_pos_label(1)
        .fit(&array![0_usize, 1], &())
        .unwrap();
    let expected: Array2<f64> = array![[-1.0], [1.0], [-1.0]];
    let encoded = fitted.transform(&array![0_usize, 1, 0]).unwrap();
    assert_eq!(encoded, expected);
}
746
/// REQ-8: when only one class is fitted, the single output column is filled
/// with neg_label for every sample.
/// Live oracle (sklearn 1.5.2, from /tmp):
///   `LabelBinarizer(neg_label=-1,pos_label=2).fit_transform([5,5,5]).tolist()`
///   -> `[[-1],[-1],[-1]]`
#[test]
fn test_neg_pos_single_class_all_neg() {
    let labels = array![5_usize, 5, 5];
    let fitted = LabelBinarizer::new()
        .with_neg_label(-1)
        .with_pos_label(2)
        .fit(&labels, &())
        .unwrap();
    let expected: Array2<f64> = array![[-1.0], [-1.0], [-1.0]];
    let encoded = fitted.transform(&labels).unwrap();
    assert_eq!(encoded, expected);
}
760
/// REQ-8: a label that was not seen during fit is silently ignored, so its
/// row is neg_label (-1 here) in every column.
/// Live oracle (sklearn 1.5.2, from /tmp):
///   `LabelBinarizer(neg_label=-1,pos_label=2).fit([0,1,2]).transform([0,3]).tolist()`
///   -> `[[2,-1,-1],[-1,-1,-1]]`
#[test]
fn test_neg_pos_unseen_label_stays_neg() {
    let fitted = LabelBinarizer::new()
        .with_neg_label(-1)
        .with_pos_label(2)
        .fit(&array![0_usize, 1, 2], &())
        .unwrap();
    // Label 3 is not among the fitted classes.
    let expected: Array2<f64> = array![[2.0, -1.0, -1.0], [-1.0, -1.0, -1.0]];
    let encoded = fitted.transform(&array![0_usize, 3]).unwrap();
    assert_eq!(encoded, expected);
}
773
/// REQ-8: neg_label >= pos_label is rejected at fit time, verbatim message.
/// Live oracle (sklearn 1.5.2, from /tmp):
///   `LabelBinarizer(neg_label=2,pos_label=1).fit([0,1])`
///     -> ValueError: "neg_label=2 must be strictly less than pos_label=1."
///   `LabelBinarizer(neg_label=1,pos_label=1).fit([0,1])`
///     -> ValueError: "neg_label=1 must be strictly less than pos_label=1."
///
/// Both cases (neg > pos and neg == pos) are held to the same standard: the
/// error must be `FerroError::InvalidParameter` naming `neg_label` with the
/// exact reason text. On failure the actual error is printed.
#[test]
fn test_neg_ge_pos_rejected() {
    let cases = [
        // neg > pos
        (2, 1, "neg_label=2 must be strictly less than pos_label=1."),
        // neg == pos
        (1, 1, "neg_label=1 must be strictly less than pos_label=1."),
    ];

    for (neg, pos, expected_reason) in cases {
        let err = LabelBinarizer::new()
            .with_neg_label(neg)
            .with_pos_label(pos)
            .fit(&array![0_usize, 1], &())
            .unwrap_err();
        assert!(
            matches!(
                &err,
                FerroError::InvalidParameter { name, reason }
                    if name == "neg_label" && reason == expected_reason
            ),
            "neg_label={neg}, pos_label={pos}: unexpected error {err:?}"
        );
    }
}
807
/// REQ-8: the binary inverse_transform cut-off is (pos+neg)/2 and the
/// comparison is STRICT, so a value exactly on the cut-off decodes to the
/// first class.
/// Live oracle (sklearn 1.5.2, from /tmp):
///   neg=-1,pos=1 -> threshold 0.0:
///     `inverse_transform([[0.0]])` -> [0]; `[[0.1]]` -> [1]; `[[-0.1]]` -> [0]
///   neg=2,pos=4 -> threshold 3.0:
///     `inverse_transform([[3.0]])` -> [0]; `[[3.1]]` -> [1]
#[test]
fn test_neg_pos_inverse_threshold() {
    // threshold = (1 + -1)/2 = 0.0; strict `> 0.0`
    let around_zero = LabelBinarizer::new()
        .with_neg_label(-1)
        .with_pos_label(1)
        .fit(&array![0_usize, 1], &())
        .unwrap();
    for (score, class) in [(0.0_f64, 0_usize), (0.1, 1), (-0.1, 0)] {
        assert_eq!(
            around_zero.inverse_transform(&array![[score]]).unwrap(),
            array![class]
        );
    }

    // threshold = (4 + 2)/2 = 3.0; strict `> 3.0`
    let around_three = LabelBinarizer::new()
        .with_neg_label(2)
        .with_pos_label(4)
        .fit(&array![0_usize, 1], &())
        .unwrap();
    for (score, class) in [(3.0_f64, 0_usize), (3.1, 1)] {
        assert_eq!(
            around_three.inverse_transform(&array![[score]]).unwrap(),
            array![class]
        );
    }
}
850
/// REQ-8: multiclass inverse_transform with custom neg/pos labels. The
/// argmax decoding is unchanged because pos_label is the largest entry in
/// each row, so argmax still selects it.
/// Live oracle (sklearn 1.5.2, from /tmp):
///   `LabelBinarizer(neg_label=-1,pos_label=2).fit([0,1,2]).inverse_transform(
///      [[2,-1,-1],[-1,-1,2]])` -> [0, 2]
#[test]
fn test_neg_pos_inverse_multiclass_roundtrip() {
    let indicator: Array2<f64> = array![[2.0, -1.0, -1.0], [-1.0, -1.0, 2.0]];
    let fitted = LabelBinarizer::new()
        .with_neg_label(-1)
        .with_pos_label(2)
        .fit(&array![0_usize, 1, 2], &())
        .unwrap();
    assert_eq!(
        fitted.inverse_transform(&indicator).unwrap(),
        array![0_usize, 2]
    );
}
867
/// REQ-1/2/3 preserved: defaults (0/1) reproduce the canonical encoding.
///
/// Covers each requirement named above with the default labels:
/// - REQ-1: `classes()` is the sorted-unique label set;
/// - REQ-2: the multiclass transform is a plain 0/1 one-hot matrix;
/// - REQ-3: the binary transform is a single 0/1 column with 1 on the
///   second class.
#[test]
fn test_defaults_preserve_zero_one() {
    let lb = LabelBinarizer::new();

    // multiclass (REQ-1 + REQ-2)
    let y = array![0_usize, 1, 2, 1];
    let fitted = lb.fit(&y, &()).unwrap();
    assert_eq!(fitted.classes(), &[0, 1, 2]);
    let expected: Array2<f64> = array![
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 1.0, 0.0]
    ];
    assert_eq!(fitted.transform(&y).unwrap(), expected);

    // binary (REQ-3): a fresh binarizer so this check does not depend on
    // whether `fit` borrows or consumes the estimator.
    let fitted = LabelBinarizer::new()
        .fit(&array![0_usize, 1], &())
        .unwrap();
    let expected: Array2<f64> = array![[0.0], [1.0], [0.0]];
    assert_eq!(
        fitted.transform(&array![0_usize, 1, 0]).unwrap(),
        expected
    );

    // Default also via Default::default()
    let lb2 = LabelBinarizer::default();
    assert_eq!((lb2.neg_label(), lb2.pos_label()), (0, 1));
}
888}