use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, FitTransform, Transform};
use ndarray::Array1;
use std::collections::HashMap;
/// Unfitted label encoder.
///
/// This is a stateless unit type: it holds no configuration and no learned
/// vocabulary. It only serves as the entry point from which a fitted encoder
/// is produced.
#[derive(Debug, Clone, Default)]
pub struct LabelEncoder;

impl LabelEncoder {
    /// Creates a new, unfitted encoder.
    ///
    /// Equivalent to `LabelEncoder::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
#[derive(Debug, Clone)]
pub struct FittedLabelEncoder {
pub(crate) classes: Vec<String>,
pub(crate) label_to_index: HashMap<String, usize>,
}
impl FittedLabelEncoder {
#[must_use]
pub fn classes(&self) -> &[String] {
&self.classes
}
#[must_use]
pub fn n_classes(&self) -> usize {
self.classes.len()
}
pub fn inverse_transform(&self, y: &Array1<usize>) -> Result<Array1<String>, FerroError> {
let n_classes = self.classes.len();
let mut out = Vec::with_capacity(y.len());
for (i, &idx) in y.iter().enumerate() {
if idx >= n_classes {
return Err(FerroError::InvalidParameter {
name: format!("y[{i}]"),
reason: format!("index {idx} is out of range (n_classes = {n_classes})"),
});
}
out.push(self.classes[idx].clone());
}
Ok(Array1::from_vec(out))
}
}
/// Learns a label vocabulary from a one-dimensional array of strings.
impl Fit<Array1<String>, ()> for LabelEncoder {
    type Fitted = FittedLabelEncoder;
    type Error = FerroError;

    /// Collects the distinct labels in `x` and assigns each an integer code.
    ///
    /// Labels are sorted using `String`'s ordering, and the code of a label is
    /// its position in that sorted list. The `_y` argument is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InsufficientSamples`] when `x` is empty.
    fn fit(&self, x: &Array1<String>, _y: &()) -> Result<FittedLabelEncoder, FerroError> {
        if x.is_empty() {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "LabelEncoder::fit".into(),
            });
        }

        // Sorting first makes duplicates adjacent, so `dedup` leaves exactly
        // the sorted set of distinct labels.
        let mut classes: Vec<String> = x.iter().cloned().collect();
        classes.sort_unstable();
        classes.dedup();

        let mut label_to_index = HashMap::with_capacity(classes.len());
        for (position, label) in classes.iter().enumerate() {
            label_to_index.insert(label.clone(), position);
        }

        Ok(FittedLabelEncoder {
            classes,
            label_to_index,
        })
    }
}
/// Encodes string labels as integer codes using the learned vocabulary.
impl Transform<Array1<String>> for FittedLabelEncoder {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Replaces every label in `x` with its integer code.
    ///
    /// The output has one code per element of `x`, in the same order.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] for the first label that is
    /// missing from the lookup table; the error names the offending element
    /// as `x[i]`.
    fn transform(&self, x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
        // Collecting into `Result` stops at the first unknown label.
        let codes = x
            .iter()
            .enumerate()
            .map(|(i, label)| {
                self.label_to_index
                    .get(label)
                    .copied()
                    .ok_or_else(|| FerroError::InvalidParameter {
                        name: format!("x[{i}]"),
                        reason: format!("unknown label \"{label}\""),
                    })
            })
            .collect::<Result<Vec<usize>, FerroError>>()?;
        Ok(Array1::from_vec(codes))
    }
}
/// `Transform` on the unfitted encoder always fails: it has no vocabulary to
/// encode with.
impl Transform<Array1<String>> for LabelEncoder {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Always returns an error telling the caller to fit the encoder first.
    ///
    /// # Errors
    ///
    /// Unconditionally returns [`FerroError::InvalidParameter`]; the input is
    /// never inspected.
    fn transform(&self, _x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
        let name = String::from("LabelEncoder");
        let reason =
            String::from("encoder must be fitted before calling transform; use fit() first");
        Err(FerroError::InvalidParameter { name, reason })
    }
}
/// Fits on an array of labels and encodes that same array in one call.
impl FitTransform<Array1<String>> for LabelEncoder {
    type FitError = FerroError;

    /// Learns the vocabulary from `x`, then encodes `x` with it.
    ///
    /// # Errors
    ///
    /// Propagates whichever error `fit` or the subsequent `transform` returns.
    fn fit_transform(&self, x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
        self.fit(x, &()).and_then(|encoder| encoder.transform(x))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Builds an owned `Array1<String>` from borrowed string slices.
    fn str_arr(v: &[&str]) -> Array1<String> {
        let owned: Vec<String> = v.iter().map(|s| String::from(*s)).collect();
        Array1::from_vec(owned)
    }

    /// Classes come out sorted and each label maps to its sorted position.
    #[test]
    fn test_label_encoder_basic() {
        let input = str_arr(&["cat", "dog", "cat", "bird"]);
        let model = LabelEncoder::new().fit(&input, &()).unwrap();
        assert_eq!(model.classes(), &["bird", "cat", "dog"]);
        assert_eq!(model.n_classes(), 3);

        let codes = model.transform(&input).unwrap();
        let expected: [usize; 4] = [1, 2, 1, 0];
        for (pos, want) in expected.iter().enumerate() {
            assert_eq!(codes[pos], *want);
        }
    }

    /// Encoding followed by decoding yields the original labels.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let input = str_arr(&["a", "b", "c", "a", "b"]);
        let model = LabelEncoder::new().fit(&input, &()).unwrap();
        let codes = model.transform(&input).unwrap();
        let decoded = model.inverse_transform(&codes).unwrap();
        input
            .iter()
            .zip(decoded.iter())
            .for_each(|(orig, rec)| assert_eq!(orig, rec));
    }

    /// A label that was not seen during fitting is rejected.
    #[test]
    fn test_unknown_label_error() {
        let model = LabelEncoder::new().fit(&str_arr(&["a", "b"]), &()).unwrap();
        let unseen = str_arr(&["c"]);
        let outcome = model.transform(&unseen);
        assert!(outcome.is_err());
    }

    /// A code beyond the number of classes is rejected.
    #[test]
    fn test_inverse_transform_out_of_range() {
        let model = LabelEncoder::new().fit(&str_arr(&["x", "y"]), &()).unwrap();
        let out_of_range = array![5usize];
        let outcome = model.inverse_transform(&out_of_range);
        assert!(outcome.is_err());
    }

    /// `fit_transform` matches `fit` followed by `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let encoder = LabelEncoder::new();
        let input = str_arr(&["foo", "bar", "foo", "baz"]);
        let one_step = encoder.fit_transform(&input).unwrap();
        let two_step = encoder
            .fit(&input, &())
            .unwrap()
            .transform(&input)
            .unwrap();
        assert_eq!(one_step, two_step);
    }

    /// Fitting on an empty array is an error.
    #[test]
    fn test_empty_input_error() {
        let nothing: Array1<String> = Array1::from_vec(Vec::new());
        let outcome = LabelEncoder::new().fit(&nothing, &());
        assert!(outcome.is_err());
    }

    /// A single repeated label produces one class, encoded as zero.
    #[test]
    fn test_single_class() {
        let input = str_arr(&["only", "only", "only"]);
        let model = LabelEncoder::new().fit(&input, &()).unwrap();
        assert_eq!(model.n_classes(), 1);

        let codes = model.transform(&input).unwrap();
        assert!(codes.iter().all(|&code| code == 0));
    }
}