Skip to main content

ferrolearn_preprocess/
label_binarizer.rs

//! One-vs-rest label binarizer.
//!
//! Transforms a vector of integer class labels into a binary indicator matrix.
//! For *K* classes the output has *K* columns (one-hot rows), except in the
//! single-class case (*K* = 1) and the binary case (*K* = 2 with a transform
//! input that itself holds at most two distinct labels), where a single column
//! is produced.
//!
//! Translation target: scikit-learn 1.5.2 `class LabelBinarizer`
//! (`sklearn/preprocessing/_label.py:180`) + `label_binarize` (`:430`). Design:
//! `.design/preprocess/label_binarizer.md`. Tracking: #1238.
//!
//! # REQ status
//!
//! | REQ | Status | Anchor |
//! |---|---|---|
//! | REQ-1 fit → sorted-unique classes_ (usize) | SHIPPED | `LabelBinarizer::fit`; sklearn `_label.py:306` |
//! | REQ-2 transform multiclass (k≥3) one-hot values | SHIPPED | `FittedLabelBinarizer::transform` else-branch; sklearn `_label.py:552-577` |
//! | REQ-3 transform binary (k=2) single col, pos_label on 2nd class | SHIPPED | `transform` k==2 branch; sklearn `_label.py:531`,`:592-596` |
//! | REQ-4 transform unknown-label: ignore (all-neg_label row) | SHIPPED (#1239) | `transform` `if let Some(&idx) = class_to_idx.get`; sklearn `_label.py:556-559` |
//! | REQ-5 transform single-class (k=1) → all-neg_label column | SHIPPED (#1240) | `transform` k==1 arm `Array2::from_elem`; sklearn `_label.py:532-538` |
//! | REQ-6 inverse_transform binary STRICT threshold (`> (pos+neg)/2`) | SHIPPED (#1241) | `inverse_transform` k==2 branch; sklearn `_label.py:667` |
//! | REQ-6b inverse_transform binary accepts 1-col AND 2-col indicator (dispatch on fitted type, not col count) | SHIPPED (#2340) | `inverse_transform` k==2 branch accepts `ncols ∈ {1,2}`, decodes `classes[last_col > threshold ? 1 : 0]` (2-col → `classes[y[:,1]]`, 1-col → `classes[y.ravel()]`); sklearn `_label.py:402-407`,`:647`,`:670-679` |
//! | REQ-7 inverse_transform multiclass argmax | SHIPPED | `inverse_transform` else-branch; sklearn `_label.py:641` |
//! | REQ-8 neg_label/pos_label ctor params + validation | SHIPPED (#1242) | `LabelBinarizer::with_neg_label`/`with_pos_label` + `Fit::fit` validation; `transform` neg/pos base+active; `inverse_transform` `(pos+neg)/2` threshold; consumer crate re-export `lib.rs`; sklearn `_label.py:263`,`:283-287`,`:579-583`,`:667` |
//! | REQ-9 sparse_output CSR + constraint | NOT-STARTED (#1243) | sklearn `_label.py:563`,`:584-585`,`:289-294` |
//! | REQ-10 `label_binarize` free function | SHIPPED (#1244) | `pub fn label_binarize` (this file): `neg<pos` validation (verbatim msg, sklearn `_label.py:499-504`); GIVEN-order columns (sklearn's "preserve label ordering" reorder `:587-590`, so `label_binarize([0,2,1],classes=[2,0,1])` → `[[0,1,0],[1,0,0],[0,0,1]]`); k==1 all-neg col (`:532-538`); single-col collapse gated on `type_of_target(y)=="binary"` AND `len(classes)==2` (NOT `len(classes)` alone — `:519`,`:531`,`:592-596`; "binary" = ≤2 distinct values for 1D int y, verified live, #2233), giving pos where `y==classes[1]` (the kept `Y[:,-1]` after reorder, `:596`); k==2 with multiclass y (3+ distinct) emits 2 cols, no collapse; k>2 one-hot in given order (`:552-577`); unseen label → all-neg row (`:556-559`). Consumer: crate re-export `lib.rs` (`pub use label_binarizer::label_binarize`). Live-oracle parity: `tests/divergence_label_binarizer.rs` (basic/neg-pos/binary/`[2,0,1]`-ordering/unseen/neg≥pos-err/==estimator). |
//! | REQ-11 arbitrary label types + type_of_target/multilabel input | NOT-STARTED (#1245) | sklearn `_label.py:296`,`:543-550` (usize-only, R-DEV-3) |
//! | REQ-12 PyO3 binding | NOT-STARTED (#1246) | `ferrolearn-python/src/` (absent) |
//!
//! # Examples
//!
//! ```
//! use ferrolearn_preprocess::label_binarizer::LabelBinarizer;
//! use ferrolearn_core::traits::{Fit, Transform};
//! use ndarray::array;
//!
//! let lb = LabelBinarizer::new();
//! let y = array![0_usize, 1, 2, 1];
//! let fitted = lb.fit(&y, &()).unwrap();
//! let mat = fitted.transform(&y).unwrap();
//! // 3 classes → (4, 3) indicator matrix
//! assert_eq!(mat.shape(), &[4, 3]);
//! assert_eq!(mat[[0, 0]], 1.0);
//! assert_eq!(mat[[0, 1]], 0.0);
//! ```
46
47use ferrolearn_core::error::FerroError;
48use ferrolearn_core::traits::{Fit, Transform};
49use ndarray::{Array1, Array2};
50
51// ---------------------------------------------------------------------------
52// LabelBinarizer (unfitted)
53// ---------------------------------------------------------------------------
54
/// An unfitted one-vs-rest label binarizer.
///
/// Calling [`Fit::fit`] on an `Array1<usize>` discovers the sorted set of
/// unique class labels and returns a [`FittedLabelBinarizer`].
///
/// `neg_label` / `pos_label` are the integer values written into the output
/// indicator matrix for absent / present classes, mirroring sklearn's
/// `LabelBinarizer(neg_label=0, pos_label=1)` (`sklearn/preprocessing/_label.py:263`).
/// The defaults `0` / `1` reproduce the canonical 0/1 indicator behavior.
///
/// The type is a plain pair of integers, so it is comparable (`PartialEq` /
/// `Eq`) and every constructor, builder and getter is a `const fn`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelBinarizer {
    /// Value written for absent classes (sklearn `neg_label`, default `0`).
    neg_label: i64,
    /// Value written for the present class (sklearn `pos_label`, default `1`).
    pos_label: i64,
}

impl Default for LabelBinarizer {
    /// Same as [`LabelBinarizer::new`]: `neg_label = 0`, `pos_label = 1`.
    fn default() -> Self {
        Self::new()
    }
}

impl LabelBinarizer {
    /// Create a new `LabelBinarizer` with the default `neg_label=0`,
    /// `pos_label=1` (the canonical 0/1 indicator encoding).
    #[must_use]
    pub const fn new() -> Self {
        Self {
            neg_label: 0,
            pos_label: 1,
        }
    }

    /// Set the `neg_label` (value used for absent classes).
    ///
    /// Mirrors sklearn's `LabelBinarizer(neg_label=...)`
    /// (`sklearn/preprocessing/_label.py:263`). Must be strictly less than
    /// `pos_label`; this builder does not check that — the constraint is
    /// validated at [`Fit::fit`] time (`_label.py:283-287`), so the two labels
    /// may be set in either order.
    #[must_use]
    pub const fn with_neg_label(self, neg_label: i64) -> Self {
        Self {
            neg_label,
            pos_label: self.pos_label,
        }
    }

    /// Set the `pos_label` (value used for the present class).
    ///
    /// Mirrors sklearn's `LabelBinarizer(pos_label=...)`
    /// (`sklearn/preprocessing/_label.py:263`). Must be strictly greater than
    /// `neg_label`; this builder does not check that — the constraint is
    /// validated at [`Fit::fit`] time (`_label.py:283-287`).
    #[must_use]
    pub const fn with_pos_label(self, pos_label: i64) -> Self {
        Self {
            neg_label: self.neg_label,
            pos_label,
        }
    }

    /// Return the configured `neg_label`.
    #[must_use]
    pub const fn neg_label(&self) -> i64 {
        self.neg_label
    }

    /// Return the configured `pos_label`.
    #[must_use]
    pub const fn pos_label(&self) -> i64 {
        self.pos_label
    }
}
123
124// ---------------------------------------------------------------------------
125// FittedLabelBinarizer
126// ---------------------------------------------------------------------------
127
128/// A fitted label binarizer holding the discovered class set.
129///
130/// Created by calling [`Fit::fit`] on a [`LabelBinarizer`].
131#[derive(Debug, Clone)]
132pub struct FittedLabelBinarizer {
133    /// Sorted unique class labels observed during fitting.
134    classes: Vec<usize>,
135    /// Value written for absent classes (sklearn `neg_label`, default `0`).
136    neg_label: i64,
137    /// Value written for the present class (sklearn `pos_label`, default `1`).
138    pos_label: i64,
139}
140
141impl FittedLabelBinarizer {
142    /// Return the sorted class labels discovered during fitting.
143    #[must_use]
144    pub fn classes(&self) -> &[usize] {
145        &self.classes
146    }
147
148    /// Return the number of unique classes.
149    #[must_use]
150    pub fn n_classes(&self) -> usize {
151        self.classes.len()
152    }
153
154    /// Return the configured `neg_label` (value used for absent classes).
155    #[must_use]
156    pub fn neg_label(&self) -> i64 {
157        self.neg_label
158    }
159
160    /// Return the configured `pos_label` (value used for the present class).
161    #[must_use]
162    pub fn pos_label(&self) -> i64 {
163        self.pos_label
164    }
165
166    /// Map a binary indicator matrix back to integer class labels.
167    ///
168    /// Dispatch follows sklearn's `LabelBinarizer.inverse_transform`, which
169    /// branches on the FITTED `y_type_` ("binary" vs "multiclass"), NOT on the
170    /// column count of `Y` (`sklearn/preprocessing/_label.py:402-407`). Here the
171    /// fitted type is "binary" iff exactly two classes were discovered
172    /// (`k == 2`):
173    ///
174    /// - **Multiclass** (`k != 2`): the class with the largest value (argmax)
175    ///   per row, mirroring `_inverse_binarize_multiclass`
176    ///   (`classes.take(Y.argmax(axis=1))`, `_label.py:641`). Requires exactly
177    ///   *K* columns.
178    /// - **Binary** (`k == 2`): `_inverse_binarize_thresholding` thresholds the
179    ///   indicator with a STRICT `y > threshold`
180    ///   (`threshold = (pos_label + neg_label) / 2`, `_label.py:399-400`,`:667`)
181    ///   then decodes (`_label.py:670-679`):
182    ///     - a **1-column** indicator → `classes[col0 > threshold ? 1 : 0]`
183    ///       (`classes[y.ravel()]`, `:679`);
184    ///     - a **2-column** indicator → `classes[col1 > threshold ? 1 : 0]`
185    ///       (`classes[y[:, 1]]`, `:673-674`): the SECOND column (after
186    ///       thresholding) selects the positive class; the first column is
187    ///       ignored. So `fit([10,20]).inverse_transform([[1,0],[0,1]])` →
188    ///       `[10, 20]` (verified vs the live sklearn 1.5.2 oracle).
189    ///
190    /// # Errors
191    ///
192    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
193    /// match an accepted width: a binary (`k == 2`) fitted binarizer accepts
194    /// BOTH 1 and 2 columns (`_label.py:647`,`:670-679`); otherwise exactly *K*
195    /// columns are required.
196    pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Array1<usize>, FerroError> {
197        let k = self.classes.len();
198        let n = y.nrows();
199        let mut result = Array1::zeros(n);
200
201        if k == 2 {
202            // Binary fitted type: sklearn dispatches to
203            // `_inverse_binarize_thresholding` on the fitted `y_type_ == "binary"`,
204            // which accepts EITHER a 1-column or a 2-column indicator and rejects
205            // wider ones (`_label.py:647`: `y.shape[1] > 2` raises; `:670-679`:
206            // 2-col → `classes[y[:, 1]]`, else `classes[y.ravel()]`).
207            let ncols = y.ncols();
208            if ncols != 1 && ncols != 2 {
209                return Err(FerroError::ShapeMismatch {
210                    expected: vec![y.nrows(), 1],
211                    actual: vec![y.nrows(), ncols],
212                    context: "FittedLabelBinarizer::inverse_transform".into(),
213                });
214            }
215
216            // Strict threshold at `(pos_label + neg_label) / 2`, matching sklearn
217            // `_inverse_binarize_thresholding` (`_label.py:667`): `y = np.array(y >
218            // threshold)` with default `threshold = (pos_label + neg_label) / 2`
219            // (`:399-400`). STRICT, so an exact-threshold value maps to `classes[0]`.
220            // With the default `neg_label=0, pos_label=1` this reduces to `> 0.5`.
221            // Cast EACH to f64 BEFORE the add: `i64 + i64` would overflow (and
222            // panic in debug, R-CODE-2) for large-but-valid neg/pos like 2^62
223            // (#2232). sklearn computes this in arbitrary-precision then /2.0.
224            let threshold = (self.pos_label as f64 + self.neg_label as f64) / 2.0;
225            // The decisive column is the LAST one: `col1` for a 2-column indicator
226            // (`classes[y[:, 1]]`, `:673-674`, first column ignored) and `col0` for
227            // a 1-column indicator (`classes[y.ravel()]`, `:679`).
228            let decisive_col = ncols - 1;
229            for i in 0..n {
230                result[i] = if y[[i, decisive_col]] > threshold {
231                    self.classes[1]
232                } else {
233                    self.classes[0]
234                };
235            }
236        } else {
237            if y.ncols() != k {
238                return Err(FerroError::ShapeMismatch {
239                    expected: vec![y.nrows(), k],
240                    actual: vec![y.nrows(), y.ncols()],
241                    context: "FittedLabelBinarizer::inverse_transform".into(),
242                });
243            }
244            // Multiclass: argmax per row
245            for i in 0..n {
246                let row = y.row(i);
247                let mut best_j = 0;
248                let mut best_v = f64::NEG_INFINITY;
249                for (j, &v) in row.iter().enumerate() {
250                    if v > best_v {
251                        best_v = v;
252                        best_j = j;
253                    }
254                }
255                result[i] = self.classes[best_j];
256            }
257        }
258
259        Ok(result)
260    }
261}
262
263// ---------------------------------------------------------------------------
264// Trait implementations
265// ---------------------------------------------------------------------------
266
267impl Fit<Array1<usize>, ()> for LabelBinarizer {
268    type Fitted = FittedLabelBinarizer;
269    type Error = FerroError;
270
271    /// Fit the binarizer by discovering unique class labels.
272    ///
273    /// # Errors
274    ///
275    /// - Returns [`FerroError::InvalidParameter`] if `neg_label >= pos_label`,
276    ///   mirroring sklearn's `neg_label={0} must be strictly less than
277    ///   pos_label={1}.` raise (`sklearn/preprocessing/_label.py:283-287`).
278    /// - Returns [`FerroError::InsufficientSamples`] if the input is empty.
279    fn fit(&self, y: &Array1<usize>, _target: &()) -> Result<FittedLabelBinarizer, FerroError> {
280        // Validate neg_label < pos_label BEFORE class discovery, mirroring
281        // sklearn `fit` (`_label.py:283-287`): the message is verbatim
282        // `neg_label={neg} must be strictly less than pos_label={pos}.`.
283        if self.neg_label >= self.pos_label {
284            return Err(FerroError::InvalidParameter {
285                name: "neg_label".into(),
286                reason: format!(
287                    "neg_label={} must be strictly less than pos_label={}.",
288                    self.neg_label, self.pos_label
289                ),
290            });
291        }
292
293        if y.is_empty() {
294            return Err(FerroError::InsufficientSamples {
295                required: 1,
296                actual: 0,
297                context: "LabelBinarizer::fit".into(),
298            });
299        }
300
301        let mut classes: Vec<usize> = y.iter().copied().collect();
302        classes.sort_unstable();
303        classes.dedup();
304
305        Ok(FittedLabelBinarizer {
306            classes,
307            neg_label: self.neg_label,
308            pos_label: self.pos_label,
309        })
310    }
311}
312
313impl Transform<Array1<usize>> for FittedLabelBinarizer {
314    type Output = Array2<f64>;
315    type Error = FerroError;
316
317    /// Transform labels into a binary indicator matrix.
318    ///
319    /// - For *K* = 2 classes the output shape is `(n, 1)`.
320    /// - For *K* > 2 classes the output shape is `(n, K)`.
321    ///
322    /// Absent classes are written as `neg_label` and the present class as
323    /// `pos_label` (defaults `0` / `1`). Labels not seen during fitting are
324    /// silently ignored: their row is left at the `neg_label` base value, with
325    /// no error and no warning. This mirrors scikit-learn's `label_binarize`
326    /// (`sklearn/preprocessing/_label.py:556-559`), which selects only the
327    /// known labels (`y_in_classes = np.isin(y, classes)`) and leaves unseen
328    /// labels contributing nothing, then fills the dense base with `neg_label`
329    /// (`:579-583`).
330    fn transform(&self, y: &Array1<usize>) -> Result<Array2<f64>, FerroError> {
331        let k = self.classes.len();
332        let n = y.len();
333
334        // sklearn `LabelBinarizer.transform` delegates to `label_binarize`, which
335        // gates the binary single-column collapse on `type_of_target(y)=="binary"`
336        // (`_label.py:519`,`:531`) — computed on the TRANSFORM input `y`, NOT the
337        // fitted class count (#2234). For 1D integer `y`, "binary" means at most 2
338        // distinct values; a MULTICLASS transform input (3+ distinct) with 2
339        // fitted classes therefore emits the (n, 2) multi-column form, not the
340        // single column (e.g. fit([0,1]).transform([0,1,2]) -> [[1,0],[0,1],[0,0]]).
341        let y_is_binary = {
342            let mut distinct: Vec<usize> = y.iter().copied().collect();
343            distinct.sort_unstable();
344            distinct.dedup();
345            distinct.len() <= 2
346        };
347
348        // The base ("absent") value is `neg_label`; the active ("present")
349        // value is `pos_label`, mirroring sklearn `label_binarize`'s dense fill
350        // `Y[Y == 0] = neg_label` (`_label.py:579-583`) and the `pos_label`
351        // active positions (`:562`, `:599`).
352        let neg = self.neg_label as f64;
353        let pos = self.pos_label as f64;
354
355        // Build a lookup: class_value → column index
356        let class_to_idx: std::collections::HashMap<usize, usize> = self
357            .classes
358            .iter()
359            .enumerate()
360            .map(|(i, &c)| (c, i))
361            .collect();
362
363        if k == 1 {
364            // Single class (n_classes == 1): sklearn treats this as the binary
365            // degenerate case and returns an all-`neg_label` single column,
366            // never `pos_label` (`sklearn/preprocessing/_label.py:532-538`:
367            // `Y = np.zeros((len(y), 1)); Y += neg_label`).
368            Ok(Array2::from_elem((n, 1), neg))
369        } else if k == 2 && y_is_binary {
370            // Binary: single column, `pos_label` for the second class else
371            // `neg_label`. The base is filled with `neg_label` (NOT zeros).
372            // Only when the transform input is itself binary (#2234).
373            let mut out = Array2::from_elem((n, 1), neg);
374            for (i, &label) in y.iter().enumerate() {
375                // Unseen labels are silently ignored (row left at `neg_label`),
376                // mirroring sklearn `_label.py:556-559`.
377                if let Some(&idx) = class_to_idx.get(&label) {
378                    out[[i, 0]] = if idx == 1 { pos } else { neg };
379                }
380            }
381            Ok(out)
382        } else {
383            // Multiclass: one-hot rows — `pos_label` at the class column,
384            // `neg_label` everywhere else. The base is filled with `neg_label`.
385            let mut out = Array2::from_elem((n, k), neg);
386            for (i, &label) in y.iter().enumerate() {
387                // Unseen labels are silently ignored (row left all-`neg_label`),
388                // mirroring sklearn `_label.py:556-559`.
389                if let Some(&idx) = class_to_idx.get(&label) {
390                    out[[i, idx]] = pos;
391                }
392            }
393            Ok(out)
394        }
395    }
396}
397
398// ---------------------------------------------------------------------------
399// `label_binarize` free function (sklearn `label_binarize`, `_label.py:430`)
400// ---------------------------------------------------------------------------
401
402/// Binarize integer labels one-vs-all against an EXPLICIT class list — the
403/// standalone, estimator-less API mirroring scikit-learn's `label_binarize`
404/// free function (`sklearn/preprocessing/_label.py:430`).
405///
406/// Unlike [`LabelBinarizer`], which discovers its classes by fitting, this
407/// function takes the class set as an explicit `classes` argument and encodes
408/// `y` against it. The output is a binary indicator matrix written with
409/// `pos_label` at active positions and `neg_label` everywhere else (defaults
410/// `0` / `1`).
411///
412/// # Column ordering (the headline)
413///
414/// The output **columns follow the GIVEN `classes` order**, NOT a sorted order.
415/// sklearn builds the indicator in sorted-class order internally
416/// (`sorted_class = np.sort(classes)`, `_label.py:542`; columns via
417/// `np.searchsorted`, `:558`) but then **reorders the columns back to the given
418/// `classes` order** in the "preserve label ordering" step (`:587-590`:
419/// `indices = np.searchsorted(sorted_class, classes); Y = Y[:, indices]`).
420/// So `label_binarize([0,2,1], classes=[2,0,1])` yields
421/// `[[0,1,0],[1,0,0],[0,0,1]]` — column `j` corresponds to `classes[j]`, the
422/// *given* class, with `pos_label` where `y[i] == classes[j]`. (Verified live
423/// vs sklearn 1.5.2; see `tests/divergence_label_binarizer.rs`.)
424///
425/// # Shape / collapse rules
426///
427/// The single-column collapse is gated on `type_of_target(y) == "binary"`, NOT
428/// on `len(classes)` (`_label.py:519` `y_type = type_of_target(y)`; `:531`
429/// `if y_type == "binary":`; the collapse at `:592-596`). For 1D integer `y`,
430/// `type_of_target` is "binary" iff `y` has at most two distinct values, else
431/// "multiclass" (verified live vs sklearn 1.5.2). Writing
432/// `y_is_binary = (distinct count of y) <= 2`:
433/// - `k == 1`: a single all-`neg_label` column (`_label.py:532-538`).
434/// - `k == 2` AND `y_is_binary`: a single column — `pos_label` where
435///   `y == classes[last]` (the LAST given class), else `neg_label`. sklearn
436///   builds both columns then takes `Y[:, -1]` after the reorder
437///   (`_label.py:596`), so the kept column is the one for the last *given*
438///   class. (When `classes` is sorted — as the fitted estimator always is —
439///   `classes[last]` is the second-sorted class, so this coincides with
440///   [`FittedLabelBinarizer::transform`]'s `idx == 1` rule.)
441/// - `k == 2` but `y` is multiclass (3+ distinct values): NO collapse —
442///   `k == 2` columns in given order (`y_type` is not "binary", so the `:592`
443///   single-column step is skipped). E.g. `label_binarize([0,1,2], classes=[0,1])`
444///   → `(3, 2)` `[[1,0],[0,1],[0,0]]`, with the unseen `2` leaving an all-`neg`
445///   row.
446/// - `k > 2`: `k` columns in given order, `pos_label` at the value's column.
447///
448/// A value in `y` not present in `classes` leaves its row all-`neg_label`,
449/// mirroring sklearn's `y_in_classes = np.isin(y, classes)` silent ignore
450/// (`_label.py:556-559`).
451///
452/// `classes` is expected to be unique (sklearn: "Uniquely holds the label for
453/// each class", `_label.py:447`); duplicate entries are not part of the matched
454/// contract.
455///
456/// # Errors
457///
458/// Returns [`FerroError::InvalidParameter`] if `neg_label >= pos_label`, with
459/// the same verbatim message as [`LabelBinarizer`]'s `fit`
460/// (`_label.py:499-504`: `neg_label={neg} must be strictly less than
461/// pos_label={pos}.`). Returns [`FerroError::InsufficientSamples`] if `classes`
462/// is empty (sklearn cannot binarize against zero classes).
463#[must_use = "label_binarize returns a new indicator matrix"]
464pub fn label_binarize(
465    y: &Array1<usize>,
466    classes: &[usize],
467    neg_label: i64,
468    pos_label: i64,
469) -> Result<Array2<f64>, FerroError> {
470    // Validate neg_label < pos_label, mirroring sklearn `label_binarize`
471    // (`_label.py:499-504`) — the SAME verbatim message as the estimator's
472    // `fit` (`LabelBinarizer::fit`).
473    if neg_label >= pos_label {
474        return Err(FerroError::InvalidParameter {
475            name: "neg_label".into(),
476            reason: format!(
477                "neg_label={neg_label} must be strictly less than pos_label={pos_label}."
478            ),
479        });
480    }
481
482    let k = classes.len();
483    if k == 0 {
484        return Err(FerroError::InsufficientSamples {
485            required: 1,
486            actual: 0,
487            context: "label_binarize: classes".into(),
488        });
489    }
490
491    let n = y.len();
492    let neg = neg_label as f64;
493    let pos = pos_label as f64;
494
495    // Map each given class value to its GIVEN-order column index. The output
496    // columns follow the given `classes` order (sklearn's "preserve label
497    // ordering" reorder, `_label.py:587-590`), so column `j` belongs to
498    // `classes[j]`. For unique `classes` the last write wins identically; the
499    // contract assumes uniqueness (`_label.py:447`).
500    let class_to_col: std::collections::HashMap<usize, usize> =
501        classes.iter().enumerate().map(|(j, &c)| (c, j)).collect();
502
503    // The single-column collapse is gated on `type_of_target(y) == "binary"`,
504    // NOT on `len(classes)` (`_label.py:519` `y_type = type_of_target(y)`;
505    // `:531` `if y_type == "binary":`; the collapse itself at `:592-596`). For
506    // 1D integer `y`, `type_of_target` returns "binary" iff `y` has at most two
507    // distinct values, else "multiclass" (verified live vs sklearn 1.5.2:
508    // 1-distinct → "binary", 2-distinct → "binary", 3+ distinct → "multiclass").
509    // So `[5,5]` (1 distinct) and `[0,1,0]` (2 distinct) are binary, but
510    // `[0,1,2]` (3 distinct) is multiclass. When `k == 2` but `y` is multiclass,
511    // sklearn promotes to the `n_classes`-column form (`:539-540` only fires for
512    // `len(classes) >= 3`; here the non-binary `y_type` simply means the `:592`
513    // collapse is skipped), giving a `(n, 2)` indicator.
514    let mut distinct: Vec<usize> = y.iter().copied().collect();
515    distinct.sort_unstable();
516    distinct.dedup();
517    let y_is_binary = distinct.len() <= 2;
518
519    if k == 1 {
520        // n_classes == 1: all-`neg_label` single column (`_label.py:532-538`:
521        // `Y = np.zeros((len(y), 1)); Y += neg_label`). sklearn reaches this only
522        // when `y_type == "binary"` too; for plain integer `y` a single class
523        // implies `y` has ≤1 distinct value, which is always binary, so this
524        // single-column form is unconditional here.
525        Ok(Array2::from_elem((n, 1), neg))
526    } else if k == 2 && y_is_binary {
527        // Binary `y` with exactly two classes: the single column kept after the
528        // given-order reorder is `Y[:, -1]` (`_label.py:596`) — the column for
529        // the LAST given class. So `pos_label` where `y == classes[1]`, else
530        // `neg_label`. Unseen labels (not in `classes`) stay at `neg_label`
531        // (`:556-559`).
532        let last_class = classes[1];
533        let mut out = Array2::from_elem((n, 1), neg);
534        for (i, &label) in y.iter().enumerate() {
535            if label == last_class {
536                out[[i, 0]] = pos;
537            }
538        }
539        Ok(out)
540    } else {
541        // `k` columns in GIVEN order, `pos_label` at the value's column
542        // (`_label.py:552-577` + the `:587-590` reorder to the given order).
543        // Reached for genuine multiclass (`k > 2`) AND for `k == 2` with a
544        // multiclass `y` (3+ distinct values), where sklearn skips the `:592`
545        // single-column collapse and emits the full `(n, k)` indicator. Unseen
546        // labels leave the row all-`neg_label` (`:556-559`).
547        let mut out = Array2::from_elem((n, k), neg);
548        for (i, &label) in y.iter().enumerate() {
549            if let Some(&col) = class_to_col.get(&label) {
550                out[[i, col]] = pos;
551            }
552        }
553        Ok(out)
554    }
555}
556
557// ===========================================================================
558// Tests
559// ===========================================================================
560
561#[cfg(test)]
562mod tests {
563    use super::*;
564    use ndarray::array;
565
566    #[test]
567    fn test_fit_discovers_sorted_classes() {
568        let lb = LabelBinarizer::new();
569        let y = array![2_usize, 0, 1, 2, 0];
570        let fitted = lb.fit(&y, &()).unwrap();
571        assert_eq!(fitted.classes(), &[0, 1, 2]);
572    }
573
574    #[test]
575    fn test_fit_empty_input_error() {
576        let lb = LabelBinarizer::new();
577        let y: Array1<usize> = Array1::zeros(0);
578        assert!(lb.fit(&y, &()).is_err());
579    }
580
581    #[test]
582    fn test_binary_transform_single_column() {
583        let lb = LabelBinarizer::new();
584        let y = array![0_usize, 1, 0, 1];
585        let fitted = lb.fit(&y, &()).unwrap();
586        let mat = fitted.transform(&y).unwrap();
587        assert_eq!(mat.shape(), &[4, 1]);
588        assert_eq!(mat[[0, 0]], 0.0); // class 0 → 0
589        assert_eq!(mat[[1, 0]], 1.0); // class 1 → 1
590        assert_eq!(mat[[2, 0]], 0.0);
591        assert_eq!(mat[[3, 0]], 1.0);
592    }
593
594    #[test]
595    fn test_multiclass_transform_indicator_matrix() {
596        let lb = LabelBinarizer::new();
597        let y = array![0_usize, 1, 2, 1];
598        let fitted = lb.fit(&y, &()).unwrap();
599        let mat = fitted.transform(&y).unwrap();
600        assert_eq!(mat.shape(), &[4, 3]);
601        // Row 0: class 0 → [1, 0, 0]
602        assert_eq!(mat[[0, 0]], 1.0);
603        assert_eq!(mat[[0, 1]], 0.0);
604        assert_eq!(mat[[0, 2]], 0.0);
605        // Row 2: class 2 → [0, 0, 1]
606        assert_eq!(mat[[2, 0]], 0.0);
607        assert_eq!(mat[[2, 1]], 0.0);
608        assert_eq!(mat[[2, 2]], 1.0);
609    }
610
611    #[test]
612    fn test_inverse_transform_multiclass() {
613        let lb = LabelBinarizer::new();
614        let y = array![0_usize, 1, 2, 1];
615        let fitted = lb.fit(&y, &()).unwrap();
616        let mat = fitted.transform(&y).unwrap();
617        let recovered = fitted.inverse_transform(&mat).unwrap();
618        assert_eq!(recovered, y);
619    }
620
621    #[test]
622    fn test_inverse_transform_binary() {
623        let lb = LabelBinarizer::new();
624        let y = array![0_usize, 1, 0, 1];
625        let fitted = lb.fit(&y, &()).unwrap();
626        let mat = fitted.transform(&y).unwrap();
627        let recovered = fitted.inverse_transform(&mat).unwrap();
628        assert_eq!(recovered, y);
629    }
630
631    /// Unseen labels are silently ignored (row left all-neg_label), mirroring
632    /// sklearn `label_binarize` (`_label.py:556-559`).
633    ///
634    /// Live oracle (sklearn 1.5.2, from /tmp):
635    ///   `LabelBinarizer().fit([0,1,2]).transform([0,3]).tolist()`
636    ///     -> `[[1, 0, 0], [0, 0, 0]]`
637    #[test]
638    fn test_transform_unknown_label_ignored() {
639        let lb = LabelBinarizer::new();
640        let y = array![0_usize, 1, 2];
641        let y2 = array![0_usize, 3]; // 3 not in {0,1,2}
642        // Fit then transform, propagating any error into the Result we compare.
643        let got = lb.fit(&y, &()).and_then(|fitted| fitted.transform(&y2));
644        // sklearn-oracle value: [[1,0,0],[0,0,0]] (label 3 ignored, all-zero row).
645        // Compare via Option (FerroError is not PartialEq); Ok(_) is required.
646        let expected: Array2<f64> = array![[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
647        assert_eq!(got.ok(), Some(expected));
648    }
649
650    #[test]
651    fn test_inverse_transform_shape_mismatch() {
652        let lb = LabelBinarizer::new();
653        let y = array![0_usize, 1, 2];
654        let fitted = lb.fit(&y, &()).unwrap();
655        // 3 classes expects 3 columns, but we give 2
656        let bad = Array2::<f64>::zeros((2, 2));
657        assert!(fitted.inverse_transform(&bad).is_err());
658    }
659
660    /// Single class (n_classes == 1) → an all-zero single column, mirroring
661    /// sklearn's binary-degenerate case (`_label.py:532-538`).
662    ///
663    /// Live oracle (sklearn 1.5.2):
664    ///   `LabelBinarizer().fit_transform([5,5,5]).tolist()` -> `[[0],[0],[0]]`
665    #[test]
666    fn test_single_class() {
667        let lb = LabelBinarizer::new();
668        let y = array![5_usize, 5, 5];
669        // Confirm exactly one class is discovered (degenerate single-class case).
670        let n_classes = lb.fit(&y, &()).map(|fitted| fitted.n_classes());
671        assert_eq!(n_classes.ok(), Some(1));
672        // Fit then transform, propagating any error into the Result we compare.
673        let got = lb.fit(&y, &()).and_then(|fitted| fitted.transform(&y));
674        // 1 class → 1 column, all zeros (never 1.0); sklearn-oracle [[0],[0],[0]].
675        let expected: Array2<f64> = array![[0.0], [0.0], [0.0]];
676        assert_eq!(got.ok(), Some(expected));
677    }
678
679    #[test]
680    fn test_non_contiguous_classes() {
681        let lb = LabelBinarizer::new();
682        let y = array![10_usize, 20, 30, 10];
683        let fitted = lb.fit(&y, &()).unwrap();
684        assert_eq!(fitted.classes(), &[10, 20, 30]);
685        let mat = fitted.transform(&y).unwrap();
686        assert_eq!(mat.shape(), &[4, 3]);
687        assert_eq!(mat[[0, 0]], 1.0); // 10 → col 0
688        assert_eq!(mat[[1, 1]], 1.0); // 20 → col 1
689        assert_eq!(mat[[2, 2]], 1.0); // 30 → col 2
690    }
691
692    #[test]
693    fn test_roundtrip_multiclass_non_contiguous() {
694        let lb = LabelBinarizer::new();
695        let y = array![10_usize, 20, 30, 20];
696        let fitted = lb.fit(&y, &()).unwrap();
697        let mat = fitted.transform(&y).unwrap();
698        let recovered = fitted.inverse_transform(&mat).unwrap();
699        assert_eq!(recovered, y);
700    }
701
702    // -- REQ-8: neg_label / pos_label ctor params + validation ----------------
703
704    /// REQ-8: builders + getters carry the configured neg/pos through fit.
705    /// Defaults preserve the canonical 0/1 encoding.
706    #[test]
707    fn test_neg_pos_label_builders_and_getters() {
708        let lb = LabelBinarizer::new();
709        assert_eq!(lb.neg_label(), 0);
710        assert_eq!(lb.pos_label(), 1);
711
712        let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(2);
713        assert_eq!(lb.neg_label(), -1);
714        assert_eq!(lb.pos_label(), 2);
715
716        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
717        assert_eq!(fitted.neg_label(), -1);
718        assert_eq!(fitted.pos_label(), 2);
719    }
720
721    /// REQ-8: multiclass transform with neg_label=-1, pos_label=2.
722    /// Live oracle (sklearn 1.5.2, from /tmp):
723    ///   `LabelBinarizer(neg_label=-1,pos_label=2).fit([0,1,2]).transform([0,2]).tolist()`
724    ///     -> `[[2,-1,-1],[-1,-1,2]]` (present->2, absent->-1; base is -1, not 0)
725    #[test]
726    fn test_neg_pos_multiclass_transform() {
727        let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(2);
728        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
729        let got = fitted.transform(&array![0_usize, 2]).unwrap();
730        let expected: Array2<f64> = array![[2.0, -1.0, -1.0], [-1.0, -1.0, 2.0]];
731        assert_eq!(got, expected);
732    }
733
734    /// REQ-8: binary (k==2) transform with neg_label=-1, pos_label=1.
735    /// Live oracle (sklearn 1.5.2, from /tmp):
736    ///   `LabelBinarizer(neg_label=-1,pos_label=1).fit([0,1]).transform([0,1,0]).tolist()`
737    ///     -> `[[-1],[1],[-1]]` (2nd class -> pos_label, else neg_label)
738    #[test]
739    fn test_neg_pos_binary_transform() {
740        let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(1);
741        let fitted = lb.fit(&array![0_usize, 1], &()).unwrap();
742        let got = fitted.transform(&array![0_usize, 1, 0]).unwrap();
743        let expected: Array2<f64> = array![[-1.0], [1.0], [-1.0]];
744        assert_eq!(got, expected);
745    }
746
747    /// REQ-8: single-class (k==1) transform -> all neg_label.
748    /// Live oracle (sklearn 1.5.2, from /tmp):
749    ///   `LabelBinarizer(neg_label=-1,pos_label=2).fit_transform([5,5,5]).tolist()`
750    ///     -> `[[-1],[-1],[-1]]`
751    #[test]
752    fn test_neg_pos_single_class_all_neg() {
753        let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(2);
754        let y = array![5_usize, 5, 5];
755        let fitted = lb.fit(&y, &()).unwrap();
756        let got = fitted.transform(&y).unwrap();
757        let expected: Array2<f64> = array![[-1.0], [-1.0], [-1.0]];
758        assert_eq!(got, expected);
759    }
760
761    /// REQ-8: unseen labels stay at neg_label (silent-ignore, now -1).
762    /// Live oracle (sklearn 1.5.2, from /tmp):
763    ///   `LabelBinarizer(neg_label=-1,pos_label=2).fit([0,1,2]).transform([0,3]).tolist()`
764    ///     -> `[[2,-1,-1],[-1,-1,-1]]` (label 3 ignored, row all neg_label)
765    #[test]
766    fn test_neg_pos_unseen_label_stays_neg() {
767        let lb = LabelBinarizer::new().with_neg_label(-1).with_pos_label(2);
768        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
769        let got = fitted.transform(&array![0_usize, 3]).unwrap();
770        let expected: Array2<f64> = array![[2.0, -1.0, -1.0], [-1.0, -1.0, -1.0]];
771        assert_eq!(got, expected);
772    }
773
774    /// REQ-8: neg_label >= pos_label is rejected at fit time, verbatim message.
775    /// Live oracle (sklearn 1.5.2, from /tmp):
776    ///   `LabelBinarizer(neg_label=2,pos_label=1).fit([0,1])`
777    ///     -> ValueError: "neg_label=2 must be strictly less than pos_label=1."
778    ///   `LabelBinarizer(neg_label=1,pos_label=1).fit([0,1])`
779    ///     -> ValueError: "neg_label=1 must be strictly less than pos_label=1."
780    #[test]
781    fn test_neg_ge_pos_rejected() {
782        // neg > pos
783        let err = LabelBinarizer::new()
784            .with_neg_label(2)
785            .with_pos_label(1)
786            .fit(&array![0_usize, 1], &())
787            .unwrap_err();
788        assert!(matches!(
789            &err,
790            FerroError::InvalidParameter { name, reason }
791                if name == "neg_label"
792                    && reason == "neg_label=2 must be strictly less than pos_label=1."
793        ));
794
795        // neg == pos
796        let err = LabelBinarizer::new()
797            .with_neg_label(1)
798            .with_pos_label(1)
799            .fit(&array![0_usize, 1], &())
800            .unwrap_err();
801        assert!(matches!(
802            &err,
803            FerroError::InvalidParameter { reason, .. }
804                if reason == "neg_label=1 must be strictly less than pos_label=1."
805        ));
806    }
807
808    /// REQ-8: inverse_transform binary threshold = (pos+neg)/2 (STRICT).
809    /// Live oracle (sklearn 1.5.2, from /tmp):
810    ///   neg=-1,pos=1 -> threshold 0.0:
811    ///     `inverse_transform([[0.0]])` -> [0]; `[[0.1]]` -> [1]; `[[-0.1]]` -> [0]
812    ///   neg=2,pos=4 -> threshold 3.0:
813    ///     `inverse_transform([[3.0]])` -> [0]; `[[3.1]]` -> [1]
814    #[test]
815    fn test_neg_pos_inverse_threshold() {
816        let fitted = LabelBinarizer::new()
817            .with_neg_label(-1)
818            .with_pos_label(1)
819            .fit(&array![0_usize, 1], &())
820            .unwrap();
821        // threshold = (1 + -1)/2 = 0.0; strict `> 0.0`
822        assert_eq!(
823            fitted.inverse_transform(&array![[0.0_f64]]).unwrap(),
824            array![0_usize]
825        );
826        assert_eq!(
827            fitted.inverse_transform(&array![[0.1_f64]]).unwrap(),
828            array![1_usize]
829        );
830        assert_eq!(
831            fitted.inverse_transform(&array![[-0.1_f64]]).unwrap(),
832            array![0_usize]
833        );
834
835        let fitted = LabelBinarizer::new()
836            .with_neg_label(2)
837            .with_pos_label(4)
838            .fit(&array![0_usize, 1], &())
839            .unwrap();
840        // threshold = (4 + 2)/2 = 3.0; strict `> 3.0`
841        assert_eq!(
842            fitted.inverse_transform(&array![[3.0_f64]]).unwrap(),
843            array![0_usize]
844        );
845        assert_eq!(
846            fitted.inverse_transform(&array![[3.1_f64]]).unwrap(),
847            array![1_usize]
848        );
849    }
850
851    /// REQ-8: inverse_transform multiclass round-trip with neg/pos (argmax
852    /// unchanged — pos_label is the largest so argmax still selects it).
853    /// Live oracle (sklearn 1.5.2, from /tmp):
854    ///   `LabelBinarizer(neg_label=-1,pos_label=2).fit([0,1,2]).inverse_transform(
855    ///       [[2,-1,-1],[-1,-1,2]])` -> [0, 2]
856    #[test]
857    fn test_neg_pos_inverse_multiclass_roundtrip() {
858        let fitted = LabelBinarizer::new()
859            .with_neg_label(-1)
860            .with_pos_label(2)
861            .fit(&array![0_usize, 1, 2], &())
862            .unwrap();
863        let mat: Array2<f64> = array![[2.0, -1.0, -1.0], [-1.0, -1.0, 2.0]];
864        let recovered = fitted.inverse_transform(&mat).unwrap();
865        assert_eq!(recovered, array![0_usize, 2]);
866    }
867
868    /// REQ-1/2/3 preserved: defaults (0/1) reproduce the canonical encoding.
869    #[test]
870    fn test_defaults_preserve_zero_one() {
871        let lb = LabelBinarizer::new();
872        // multiclass
873        let fitted = lb.fit(&array![0_usize, 1, 2, 1], &()).unwrap();
874        let expected: Array2<f64> = array![
875            [1.0, 0.0, 0.0],
876            [0.0, 1.0, 0.0],
877            [0.0, 0.0, 1.0],
878            [0.0, 1.0, 0.0]
879        ];
880        assert_eq!(
881            fitted.transform(&array![0_usize, 1, 2, 1]).unwrap(),
882            expected
883        );
884        // Default also via Default::default()
885        let lb2 = LabelBinarizer::default();
886        assert_eq!((lb2.neg_label(), lb2.pos_label()), (0, 1));
887    }
888}