use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::Array2;
/// Unfitted multi-label binarizer.
///
/// The type carries no configuration. Fitting it on a collection of label
/// sets yields a [`FittedMultiLabelBinarizer`] that remembers the distinct
/// labels it saw, sorted in ascending order.
#[derive(Debug, Clone, Default)]
pub struct MultiLabelBinarizer;
impl MultiLabelBinarizer {
#[must_use]
pub fn new() -> Self {
Self
}
}
/// Fitted state of a [`MultiLabelBinarizer`]: the labels it learned.
///
/// Each learned label owns one column of the multi-hot matrix built by
/// `transform`; the column index is the label's position in `classes`.
#[derive(Debug, Clone)]
pub struct FittedMultiLabelBinarizer {
/// Labels collected by `fit`, sorted ascending with duplicates removed.
classes: Vec<usize>,
}
impl FittedMultiLabelBinarizer {
    /// Returns the learned labels in column order.
    #[must_use]
    pub fn classes(&self) -> &[usize] {
        self.classes.as_slice()
    }

    /// Returns how many labels were learned, i.e. the number of columns an
    /// encoded matrix must have.
    #[must_use]
    pub fn n_classes(&self) -> usize {
        self.classes().len()
    }

    /// Decodes a score/indicator matrix back into per-row label lists.
    ///
    /// Column `j` of `y` corresponds to `classes()[j]`. A label is reported
    /// for a row when the matching entry is `>= 0.5`; labels within a row are
    /// emitted in column order.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] when the number of columns of
    /// `y` differs from the number of learned labels.
    pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Vec<Vec<usize>>, FerroError> {
        let n_rows = y.nrows();
        let n_cols = y.ncols();
        let n_classes = self.classes.len();

        if n_cols != n_classes {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_rows, n_classes],
                actual: vec![n_rows, n_cols],
                context: "FittedMultiLabelBinarizer::inverse_transform".into(),
            });
        }

        let decoded: Vec<Vec<usize>> = (0..n_rows)
            .map(|row| {
                self.classes
                    .iter()
                    .enumerate()
                    // Entries at or above the 0.5 threshold count as "label present".
                    .filter(|&(col, _)| y[[row, col]] >= 0.5)
                    .map(|(_, &class)| class)
                    .collect()
            })
            .collect();

        Ok(decoded)
    }
}
impl Fit<Vec<Vec<usize>>, ()> for MultiLabelBinarizer {
    type Fitted = FittedMultiLabelBinarizer;
    type Error = FerroError;

    /// Learns the set of labels that occur anywhere in `y`.
    ///
    /// Every inner vector is one sample's label set. The fitted value stores
    /// each distinct label once, in ascending order. Samples with no labels
    /// are allowed; `_target` is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InsufficientSamples`] when `y` contains no
    /// samples at all.
    fn fit(
        &self,
        y: &Vec<Vec<usize>>,
        _target: &(),
    ) -> Result<FittedMultiLabelBinarizer, FerroError> {
        if y.is_empty() {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "MultiLabelBinarizer::fit".into(),
            });
        }

        // An ordered set yields the labels de-duplicated and ascending.
        let distinct: std::collections::BTreeSet<usize> = y
            .iter()
            .flat_map(|labels| labels.iter().copied())
            .collect();

        Ok(FittedMultiLabelBinarizer {
            classes: distinct.into_iter().collect(),
        })
    }
}
impl Transform<Vec<Vec<usize>>> for FittedMultiLabelBinarizer {
    type Output = Array2<f64>;
    type Error = FerroError;

    /// Encodes label sets as a multi-hot matrix.
    ///
    /// The result has one row per entry of `y` and one column per learned
    /// label. Entry `(i, j)` is `1.0` when sample `i` contains the label
    /// stored at position `j` of the learned classes, and `0.0` otherwise.
    /// Repeating a label within one sample has no additional effect.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] as soon as a label is met
    /// that is not among the learned classes.
    fn transform(&self, y: &Vec<Vec<usize>>) -> Result<Array2<f64>, FerroError> {
        let n_classes = self.classes.len();

        // Label value -> column index lookup table.
        let mut column_of: std::collections::HashMap<usize, usize> =
            std::collections::HashMap::with_capacity(n_classes);
        for (column, &class) in self.classes.iter().enumerate() {
            column_of.insert(class, column);
        }

        let mut encoded = Array2::<f64>::zeros((y.len(), n_classes));
        for (row, labels) in y.iter().enumerate() {
            for &label in labels {
                match column_of.get(&label) {
                    Some(&column) => encoded[[row, column]] = 1.0,
                    None => {
                        return Err(FerroError::InvalidParameter {
                            name: "y".into(),
                            reason: format!("unknown label {label} not seen during fit"),
                        });
                    }
                }
            }
        }

        Ok(encoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Fits on `y`, encodes it, decodes the encoding again and checks that
    /// the original label sets come back unchanged.
    fn assert_roundtrip(y: Vec<Vec<usize>>) {
        let fitted = MultiLabelBinarizer::new().fit(&y, &()).unwrap();
        let encoded = fitted.transform(&y).unwrap();
        let decoded = fitted.inverse_transform(&encoded).unwrap();
        assert_eq!(decoded, y);
    }

    #[test]
    fn test_fit_discovers_sorted_classes() {
        let samples = vec![vec![2, 0], vec![1]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        assert_eq!(fitted.classes(), &[0, 1, 2]);
    }

    #[test]
    fn test_fit_empty_input_error() {
        let no_samples: Vec<Vec<usize>> = Vec::new();
        let outcome = MultiLabelBinarizer::new().fit(&no_samples, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_transform_multi_hot() {
        let samples = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let encoded = fitted.transform(&samples).unwrap();
        assert_eq!(encoded.shape(), &[3, 3]);
        assert_eq!(
            encoded,
            array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
        );
    }

    #[test]
    fn test_transform_unknown_label_error() {
        let seen = vec![vec![0, 1]];
        let fitted = MultiLabelBinarizer::new().fit(&seen, &()).unwrap();
        // Label 5 never appeared in the fitting data.
        let unseen = vec![vec![0, 5]];
        assert!(fitted.transform(&unseen).is_err());
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        assert_roundtrip(vec![vec![0, 2], vec![1], vec![0, 1, 2]]);
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let samples = vec![vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        // Three classes were learned, but this matrix only has two columns.
        let too_narrow = Array2::<f64>::zeros((2, 2));
        assert!(fitted.inverse_transform(&too_narrow).is_err());
    }

    #[test]
    fn test_empty_label_set() {
        // The second sample carries no labels at all.
        let samples = vec![vec![0, 1], vec![]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let encoded = fitted.transform(&samples).unwrap();
        assert_eq!(encoded.shape(), &[2, 2]);
        for col in 0..2 {
            assert_eq!(encoded[[1, col]], 0.0);
        }
    }

    #[test]
    fn test_inverse_transform_empty_row() {
        assert_roundtrip(vec![vec![0, 1], vec![]]);
    }

    #[test]
    fn test_non_contiguous_classes() {
        let samples = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        assert_eq!(fitted.classes(), &[10, 20, 30]);

        let encoded = fitted.transform(&samples).unwrap();
        assert_eq!(encoded.shape(), &[2, 3]);
        // First sample holds labels 10 and 30, i.e. columns 0 and 2.
        for (col, &want) in [1.0, 0.0, 1.0].iter().enumerate() {
            assert_eq!(encoded[[0, col]], want);
        }
    }

    #[test]
    fn test_inverse_transform_non_contiguous_roundtrip() {
        assert_roundtrip(vec![vec![10, 30], vec![20]]);
    }

    #[test]
    fn test_duplicate_labels_in_input() {
        // Label 0 is listed twice within the same sample.
        let samples = vec![vec![0, 0, 1]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let encoded = fitted.transform(&samples).unwrap();
        assert_eq!(encoded.shape(), &[1, 2]);
        assert_eq!(encoded, array![[1.0, 1.0]]);
    }

    #[test]
    fn test_inverse_threshold() {
        let samples = vec![vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        // 0.4 is below the cut-off; 0.6 and exactly 0.5 both count as present.
        let scores = array![[0.4, 0.6, 0.5]];
        let decoded = fitted.inverse_transform(&scores).unwrap();
        assert_eq!(decoded, vec![vec![1, 2]]);
    }
}