use crate::error::{Result, ScryLearnError};
/// Encoder that maps string class labels to numeric indices and back.
///
/// An encoder starts out unfitted; `fit` records the distinct labels it is
/// given, after which `transform` and `inverse_transform` can be used.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct LabelEncoder {
    /// Distinct labels recorded by `fit`, kept sorted in ascending order.
    /// A label's position in this list is its encoded value.
    classes: Vec<String>,
    /// `true` once `fit` has been called at least once.
    fitted: bool,
}
impl LabelEncoder {
    /// Creates an unfitted encoder with no known classes.
    pub fn new() -> Self {
        Self {
            classes: Vec::new(),
            fitted: false,
        }
    }

    /// Learns the set of distinct labels in `labels`.
    ///
    /// The distinct labels are stored sorted in ascending (byte-wise
    /// lexicographic) order; a label's position in that order is the value
    /// `transform` produces for it. Calling `fit` again replaces any
    /// previously learned classes.
    pub fn fit(&mut self, labels: &[&str]) {
        // Sort and dedup the borrowed slices first so that only the distinct
        // labels are copied into owned `String`s.
        let mut unique: Vec<&str> = labels.to_vec();
        unique.sort_unstable();
        unique.dedup();
        self.classes = unique.into_iter().map(str::to_owned).collect();
        self.fitted = true;
    }

    /// Encodes each label as the index of its class, returned as `f64`.
    ///
    /// # Errors
    ///
    /// Returns `ScryLearnError::NotFitted` if `fit` has not been called, and
    /// `ScryLearnError::InvalidParameter` for the first label that was not
    /// present when the encoder was fitted.
    pub fn transform(&self, labels: &[&str]) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        labels
            .iter()
            .map(|&label| {
                // `fit` keeps `classes` sorted and deduplicated, so a binary
                // search finds the same unique position a linear scan would.
                self.classes
                    .binary_search_by(|class| class.as_str().cmp(label))
                    .map(|i| i as f64)
                    .map_err(|_| {
                        ScryLearnError::InvalidParameter(format!("unknown label: {label}"))
                    })
            })
            .collect()
    }

    /// Decodes class indices back into their labels.
    ///
    /// Each index must be a non-negative, integer-valued `f64` smaller than
    /// the number of classes.
    ///
    /// # Errors
    ///
    /// Returns `ScryLearnError::NotFitted` if `fit` has not been called, and
    /// `ScryLearnError::InvalidParameter` for the first index that is NaN,
    /// infinite, fractional, negative, or not below `n_classes()`.
    pub fn inverse_transform(&self, indices: &[f64]) -> Result<Vec<String>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        indices
            .iter()
            .map(|&idx| {
                // A float-to-integer `as` cast saturates and truncates: NaN and
                // negative values become 0 and fractions are dropped. Validate
                // first so such inputs are not silently decoded as a real class.
                if !idx.is_finite() || idx.fract() != 0.0 {
                    return Err(ScryLearnError::InvalidParameter(format!(
                        "index is not an integer: {idx}"
                    )));
                }
                if idx < 0.0 {
                    return Err(ScryLearnError::InvalidParameter(format!(
                        "index out of range: {idx}"
                    )));
                }
                // `idx` is now a non-negative whole number; values too large for
                // `usize` saturate to `usize::MAX`, which `get` rejects below.
                let i = idx as usize;
                self.classes.get(i).cloned().ok_or_else(|| {
                    ScryLearnError::InvalidParameter(format!("index out of range: {i}"))
                })
            })
            .collect()
    }

    /// Returns the learned class labels in encoding order (empty before `fit`).
    pub fn classes(&self) -> &[String] {
        &self.classes
    }

    /// Returns the number of distinct classes learned by `fit`.
    pub fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
impl Default for LabelEncoder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding returns the original labels, and the encoded
    /// values follow the sorted order of the distinct training labels.
    #[test]
    fn test_label_encoder_roundtrip() {
        let training = ["cat", "dog", "bird", "cat"];
        let query = ["dog", "cat", "bird"];

        let mut encoder = LabelEncoder::new();
        encoder.fit(&training);
        assert_eq!(encoder.n_classes(), 3);

        let codes = encoder.transform(&query).unwrap();
        assert_eq!(codes, vec![2.0, 1.0, 0.0]);

        let labels = encoder.inverse_transform(&codes).unwrap();
        assert_eq!(labels, vec!["dog", "cat", "bird"]);
    }

    /// A label that was absent during fitting cannot be encoded.
    #[test]
    fn test_label_encoder_unknown() {
        let mut encoder = LabelEncoder::new();
        encoder.fit(&["a", "b"]);

        let outcome = encoder.transform(&["c"]);
        assert!(outcome.is_err());
    }
}