use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::collections::HashMap;
/// Fitted state of a label encoder: the known classes plus a reverse lookup.
///
/// Both fields are built together in `fit`: `class_to_index` maps every entry
/// of `classes` to its position in that vector.
#[derive(Debug, Clone)]
struct LabelEncoderState {
/// Distinct labels seen during fitting, stored in sorted order.
classes: Vec<String>,
/// Maps each label in `classes` to its index in that vector.
class_to_index: HashMap<String, usize>,
}
/// Encoder that maps string labels to integer codes and back.
///
/// Exposed to Python under the class name `LabelEncoder`.
#[pyclass(name = "LabelEncoder")]
pub struct PyLabelEncoder {
/// `None` until `fit` (or `fit_transform`) succeeds; afterwards holds the
/// learned classes and their index lookup.
state: Option<LabelEncoderState>,
}
#[pymethods]
impl PyLabelEncoder {
#[new]
fn new() -> Self {
Self { state: None }
}
fn fit(&mut self, y: Vec<String>) -> PyResult<()> {
if y.is_empty() {
return Err(PyValueError::new_err("y cannot be empty"));
}
let mut classes: Vec<String> = y
.into_iter()
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
classes.sort();
let class_to_index: HashMap<String, usize> = classes
.iter()
.enumerate()
.map(|(i, c)| (c.clone(), i))
.collect();
self.state = Some(LabelEncoderState {
classes,
class_to_index,
});
Ok(())
}
fn fit_transform(&mut self, y: Vec<String>) -> PyResult<Vec<i64>> {
self.fit(y.clone())?;
self.transform(y)
}
fn transform(&self, y: Vec<String>) -> PyResult<Vec<i64>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("LabelEncoder not fitted. Call fit() first."))?;
let mut encoded = Vec::with_capacity(y.len());
for label in y.iter() {
match state.class_to_index.get(label) {
Some(&index) => encoded.push(index as i64),
None => {
return Err(PyValueError::new_err(format!(
"Unknown label '{}'. Label encoder has only seen: {:?}",
label, state.classes
)));
}
}
}
Ok(encoded)
}
fn inverse_transform(&self, y: Vec<i64>) -> PyResult<Vec<String>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("LabelEncoder not fitted. Call fit() first."))?;
let mut decoded = Vec::with_capacity(y.len());
for &index in y.iter() {
if index < 0 || index >= state.classes.len() as i64 {
return Err(PyValueError::new_err(format!(
"Index {} is out of bounds for {} classes",
index,
state.classes.len()
)));
}
decoded.push(state.classes[index as usize].clone());
}
Ok(decoded)
}
#[getter]
fn classes_(&self) -> PyResult<Vec<String>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("LabelEncoder not fitted. Call fit() first."))?;
Ok(state.classes.clone())
}
fn __repr__(&self) -> String {
"LabelEncoder()".to_string()
}
}