/// Encodes labels of an ordered type as dense integer codes.
///
/// After [`fit`](LabelEncoder::fit) the distinct labels are stored in ascending
/// order, and the code of a label is its index in that sorted list.
#[derive(Debug, Clone)]
pub struct LabelEncoder<T> {
    /// Distinct labels from the most recent fit, sorted ascending.
    /// The position of a label in this vector is its integer code.
    classes: Vec<T>,
}

// Written by hand on purpose: `#[derive(Default)]` would add a `T: Default`
// bound, which an encoder with no classes does not need.
impl<T> Default for LabelEncoder<T> {
    /// Creates an unfitted encoder with no known classes.
    fn default() -> Self {
        Self {
            classes: Vec::new(),
        }
    }
}

impl<T: Ord + Clone> LabelEncoder<T> {
    /// Creates an unfitted encoder; equivalent to [`LabelEncoder::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Learns the set of classes from `y`, replacing any previously learned set.
    ///
    /// The classes are the distinct values of `y` in ascending order. An empty
    /// `y` leaves the encoder with no classes.
    ///
    /// Returns `self` so that calls can be chained.
    pub fn fit(&mut self, y: &[T]) -> &mut Self {
        // Sort and deduplicate borrowed labels first so that only the distinct
        // ones are cloned, rather than cloning every element of `y`.
        let mut distinct: Vec<&T> = y.iter().collect();
        distinct.sort();
        distinct.dedup();
        self.classes = distinct.into_iter().cloned().collect();
        self
    }

    /// Maps every label in `y` to its integer code.
    ///
    /// The result has the same length and order as `y`. Each lookup is a
    /// binary search over the sorted classes.
    ///
    /// A label that is not among the fitted classes is mapped to
    /// `self.classes().len()`, i.e. one past the largest valid code. Note that
    /// [`inverse_transform`](LabelEncoder::inverse_transform) skips that code.
    #[must_use]
    pub fn transform(&self, y: &[T]) -> Vec<usize> {
        // Code reserved for labels that were not seen during fitting.
        let unknown = self.classes.len();
        y.iter()
            .map(|label| self.classes.binary_search(label).unwrap_or(unknown))
            .collect()
    }

    /// Fits the encoder on `y` and returns the codes of `y` in one call.
    ///
    /// Equivalent to calling [`fit`](LabelEncoder::fit) followed by
    /// [`transform`](LabelEncoder::transform) on the same slice.
    pub fn fit_transform(&mut self, y: &[T]) -> Vec<usize> {
        self.fit(y);
        self.transform(y)
    }

    /// Maps integer codes back to clones of the labels they stand for.
    ///
    /// Codes that are out of range (`>= self.classes().len()`) are skipped
    /// rather than reported, so the result can be shorter than `codes` and its
    /// positions then no longer line up with the input.
    #[must_use]
    pub fn inverse_transform(&self, codes: &[usize]) -> Vec<T> {
        codes
            .iter()
            .filter_map(|&code| self.classes.get(code).cloned())
            .collect()
    }

    /// Returns the fitted classes in ascending order; empty before any fit.
    #[must_use]
    pub fn classes(&self) -> &[T] {
        &self.classes
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips string and integer labels through the encoder.
    ///
    /// Per the test name, the expected values are meant to match what
    /// scikit-learn's `LabelEncoder` yields for the same inputs.
    #[test]
    fn label_encoder_matches_sklearn() {
        // String labels: codes follow sorted order and repeats share a code.
        let mut text_encoder = LabelEncoder::new();
        let text_codes = text_encoder.fit_transform(&["A", "B", "C", "A"]);
        assert_eq!(text_codes, vec![0, 1, 2, 0]);
        assert_eq!(text_encoder.classes(), &["A", "B", "C"]);

        // Decoding maps each code back to its label, in the order given.
        let decoded = text_encoder.inverse_transform(&[2, 0, 1]);
        assert_eq!(decoded, vec!["C", "A", "B"]);

        // Integer labels supplied out of order are coded by their sorted rank.
        let mut int_encoder = LabelEncoder::new();
        let int_codes = int_encoder.fit_transform(&[10i64, 5, 10, 20]);
        assert_eq!(int_codes, vec![1, 0, 1, 2]);
        assert_eq!(int_encoder.classes(), &[5, 10, 20]);
    }
}