use anofox_ml_core::{Result, RustMlError};
use std::collections::HashMap;
/// Unfitted ordinal encoder for string-valued categorical columns.
///
/// The type carries no configuration or state. Calling
/// [`OrdinalEncoder::fit`] produces a [`FittedOrdinalEncoder`] that holds the
/// learned per-column vocabularies.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OrdinalEncoder;
impl OrdinalEncoder {
    /// Creates a new, unfitted encoder.
    pub fn new() -> Self {
        Self
    }

    /// Learns one vocabulary per column and returns the fitted encoder.
    ///
    /// `columns` is column-major: `columns[j]` holds every row value of
    /// column `j`. For each column the distinct values are sorted with
    /// `String`'s ordering, and each category is assigned its position in
    /// that sorted list.
    ///
    /// # Errors
    ///
    /// * `RustMlError::EmptyInput` if `columns` is empty or the first column
    ///   has no rows.
    /// * `RustMlError::ShapeMismatch` if any column's length differs from the
    ///   first column's length.
    pub fn fit(&self, columns: &[Vec<String>]) -> Result<FittedOrdinalEncoder> {
        let nrows = match columns.first() {
            Some(first) => first.len(),
            None => return Err(RustMlError::EmptyInput("columns slice is empty".into())),
        };
        if nrows == 0 {
            return Err(RustMlError::EmptyInput("columns contain no rows".into()));
        }
        for (j, col) in columns.iter().enumerate() {
            if col.len() != nrows {
                return Err(RustMlError::ShapeMismatch(format!(
                    "column {} has {} rows, expected {}",
                    j,
                    col.len(),
                    nrows
                )));
            }
        }

        let mut vocabularies = Vec::with_capacity(columns.len());
        let mut mappings = Vec::with_capacity(columns.len());
        for col in columns {
            // Sort and dedup borrowed values first so that only the distinct
            // categories are cloned, not every row. An unstable sort is fine:
            // equal strings are indistinguishable, so the result is the same.
            let mut distinct: Vec<&String> = col.iter().collect();
            distinct.sort_unstable();
            distinct.dedup();
            let vocab: Vec<String> = distinct.into_iter().cloned().collect();

            // Category -> position in the sorted vocabulary.
            let mapping: HashMap<String, usize> = vocab
                .iter()
                .enumerate()
                .map(|(i, s)| (s.clone(), i))
                .collect();
            vocabularies.push(vocab);
            mappings.push(mapping);
        }
        Ok(FittedOrdinalEncoder {
            vocabularies,
            mappings,
        })
    }
}
impl Default for OrdinalEncoder {
fn default() -> Self {
Self::new()
}
}
/// Result of fitting an [`OrdinalEncoder`]: one vocabulary and one lookup
/// table per input column.
///
/// Both fields are private and indexed by column position.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedOrdinalEncoder {
// Per-column category lists; `OrdinalEncoder::fit` stores them sorted and
// de-duplicated, so a category's code is its index in this list.
vocabularies: Vec<Vec<String>>,
// Per-column category -> code lookup; `OrdinalEncoder::fit` builds it from
// the matching entry of `vocabularies`.
mappings: Vec<HashMap<String, usize>>,
}
impl FittedOrdinalEncoder {
    /// Encodes each category as its integer code.
    ///
    /// `columns` is column-major and must contain exactly as many columns as
    /// the encoder was fitted on.
    ///
    /// # Errors
    ///
    /// * `RustMlError::ShapeMismatch` if the number of columns differs from
    ///   the fitted number of columns.
    /// * `RustMlError::InvalidParameter` for the first value that is not in
    ///   the corresponding column's mapping.
    pub fn transform(&self, columns: &[Vec<String>]) -> Result<Vec<Vec<usize>>> {
        let expected = self.vocabularies.len();
        if columns.len() != expected {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} columns, got {}",
                expected,
                columns.len()
            )));
        }

        columns
            .iter()
            .enumerate()
            .map(|(col_idx, values)| {
                let lookup = &self.mappings[col_idx];
                values
                    .iter()
                    .map(|value| {
                        lookup.get(value).copied().ok_or_else(|| {
                            RustMlError::InvalidParameter(format!(
                                "unknown category '{}' in column {}",
                                value, col_idx
                            ))
                        })
                    })
                    .collect::<Result<Vec<usize>>>()
            })
            .collect()
    }

    /// Decodes integer codes back into their category strings.
    ///
    /// `columns` is column-major and must contain exactly as many columns as
    /// the encoder was fitted on.
    ///
    /// # Errors
    ///
    /// * `RustMlError::ShapeMismatch` if the number of columns differs from
    ///   the fitted number of columns.
    /// * `RustMlError::InvalidParameter` for the first code that is not a
    ///   valid index into the corresponding column's vocabulary.
    pub fn inverse_transform(&self, columns: &[Vec<usize>]) -> Result<Vec<Vec<String>>> {
        let expected = self.vocabularies.len();
        if columns.len() != expected {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} columns, got {}",
                expected,
                columns.len()
            )));
        }

        columns
            .iter()
            .enumerate()
            .map(|(col_idx, codes)| {
                let categories = &self.vocabularies[col_idx];
                codes
                    .iter()
                    .map(|&code| {
                        categories.get(code).cloned().ok_or_else(|| {
                            RustMlError::InvalidParameter(format!(
                                "encoded index {} is out of range for column {} (vocabulary size {})",
                                code,
                                col_idx,
                                categories.len()
                            ))
                        })
                    })
                    .collect::<Result<Vec<String>>>()
            })
            .collect()
    }

    /// Returns the vocabulary of `column`, or `None` if there is no such
    /// column.
    pub fn vocabulary(&self, column: usize) -> Option<&[String]> {
        self.vocabularies.get(column).map(Vec::as_slice)
    }

    /// Returns the number of columns this encoder holds vocabularies for.
    pub fn n_columns(&self) -> usize {
        self.vocabularies.len()
    }

    /// Returns the vocabulary size of every column, in column order.
    pub fn n_categories(&self) -> Vec<usize> {
        self.vocabularies.iter().map(Vec::len).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for turning a literal into an owned `String`.
    fn s(val: &str) -> String {
        String::from(val)
    }

    /// Builds one owned column from string literals.
    fn column(values: &[&str]) -> Vec<String> {
        values.iter().map(|&v| s(v)).collect()
    }

    /// Fits a fresh encoder on `data`, panicking if fitting fails.
    fn fit_ok(data: &[Vec<String>]) -> FittedOrdinalEncoder {
        OrdinalEncoder::new().fit(data).unwrap()
    }

    #[test]
    fn test_fit_transform_single_column() {
        let data = vec![column(&["cat", "dog", "cat", "bird"])];
        let codes = fit_ok(&data).transform(&data).unwrap();
        assert_eq!(codes, vec![vec![1, 2, 1, 0]]);
    }

    #[test]
    fn test_fit_transform_multiple_columns() {
        let data = vec![
            column(&["red", "blue", "green"]),
            column(&["small", "large", "small"]),
        ];
        let codes = fit_ok(&data).transform(&data).unwrap();
        assert_eq!(codes[0], vec![2, 0, 1]);
        assert_eq!(codes[1], vec![1, 0, 1]);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let data = vec![
            column(&["apple", "banana", "cherry"]),
            column(&["x", "y", "z"]),
        ];
        let fitted = fit_ok(&data);
        let codes = fitted.transform(&data).unwrap();
        let decoded = fitted.inverse_transform(&codes).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_unknown_category() {
        let fitted = fit_ok(&[column(&["cat", "dog"])]);
        let unseen = vec![column(&["fish"])];
        assert!(fitted.transform(&unseen).is_err());
    }

    #[test]
    fn test_out_of_range_index() {
        let fitted = fit_ok(&[column(&["a", "b"])]);
        let bad_codes: Vec<Vec<usize>> = vec![vec![99]];
        assert!(fitted.inverse_transform(&bad_codes).is_err());
    }

    #[test]
    fn test_empty_columns() {
        let no_columns: Vec<Vec<String>> = Vec::new();
        assert!(OrdinalEncoder::new().fit(&no_columns).is_err());
    }

    #[test]
    fn test_empty_rows() {
        let no_rows: Vec<Vec<String>> = vec![Vec::new()];
        assert!(OrdinalEncoder::new().fit(&no_rows).is_err());
    }

    #[test]
    fn test_column_length_mismatch() {
        let ragged = vec![column(&["a", "b"]), column(&["x"])];
        assert!(OrdinalEncoder::new().fit(&ragged).is_err());
    }

    #[test]
    fn test_shape_mismatch_transform() {
        let fitted = fit_ok(&[column(&["a", "b"])]);
        let too_many = vec![column(&["a"]), column(&["b"])];
        assert!(fitted.transform(&too_many).is_err());
    }

    #[test]
    fn test_shape_mismatch_inverse() {
        let fitted = fit_ok(&[column(&["a", "b"])]);
        let too_many: Vec<Vec<usize>> = vec![vec![0], vec![1]];
        assert!(fitted.inverse_transform(&too_many).is_err());
    }

    #[test]
    fn test_vocabulary_accessor() {
        let data = vec![
            column(&["z", "a", "m"]),
            column(&["big", "small", "big"]),
        ];
        let fitted = fit_ok(&data);
        assert_eq!(
            fitted.vocabulary(0).unwrap(),
            column(&["a", "m", "z"]).as_slice()
        );
        assert_eq!(
            fitted.vocabulary(1).unwrap(),
            column(&["big", "small"]).as_slice()
        );
        assert!(fitted.vocabulary(5).is_none());
    }

    #[test]
    fn test_n_categories() {
        let data = vec![column(&["a", "b", "c"]), column(&["x", "y", "x"])];
        let fitted = fit_ok(&data);
        assert_eq!(fitted.n_columns(), 2);
        assert_eq!(fitted.n_categories(), vec![3, 2]);
    }

    #[test]
    fn test_default() {
        let data = vec![column(&["a"])];
        let fitted = OrdinalEncoder::default().fit(&data).unwrap();
        assert_eq!(fitted.n_columns(), 1);
    }

    #[test]
    fn test_sorted_vocabulary() {
        let fitted = fit_ok(&[column(&["zebra", "apple", "mango"])]);
        assert_eq!(
            fitted.vocabulary(0).unwrap(),
            column(&["apple", "mango", "zebra"]).as_slice()
        );
    }

    #[test]
    fn test_duplicate_values() {
        let data = vec![column(&["a", "a", "b", "b", "a"])];
        let codes = fit_ok(&data).transform(&data).unwrap();
        assert_eq!(codes[0], vec![0, 0, 1, 1, 0]);
    }
}