use crate::pipeline::StreamingPreprocessor;
/// Streaming one-hot encoder for a chosen subset of input columns.
///
/// For every column listed in `categorical_indices`, the distinct values seen by
/// `update_and_transform` are collected (kept sorted, capped at `max_categories`
/// per column) and each such column is expanded into one slot per discovered
/// category. All other columns are passed through unchanged.
#[derive(Clone, Debug)]
pub struct OneHotEncoder {
    /// Input column indices to one-hot encode, in the order given at construction.
    categorical_indices: Vec<usize>,
    /// Discovered values per categorical column, sorted ascending;
    /// `categories[j]` belongs to input column `categorical_indices[j]`.
    categories: Vec<Vec<f64>>,
    /// Per-column category limit; once a column holds this many categories,
    /// further new values are not registered.
    max_categories: usize,
    /// Length of the last input given to `update_and_transform`, or `None`
    /// before the first update and after `reset`.
    n_input_features: Option<usize>,
}
impl OneHotEncoder {
    /// Per-column category cap used by [`OneHotEncoder::new`].
    pub const DEFAULT_MAX_CATEGORIES: usize = 64;

    /// Creates an encoder that one-hot encodes the input columns listed in
    /// `categorical_indices`, keeping at most [`Self::DEFAULT_MAX_CATEGORIES`]
    /// categories per column.
    pub fn new(categorical_indices: Vec<usize>) -> Self {
        Self::with_max_categories(categorical_indices, Self::DEFAULT_MAX_CATEGORIES)
    }

    /// Creates an encoder like [`OneHotEncoder::new`] but with a custom
    /// per-column category cap.
    ///
    /// # Panics
    ///
    /// Panics if `max_categories` is zero.
    pub fn with_max_categories(categorical_indices: Vec<usize>, max_categories: usize) -> Self {
        assert!(max_categories > 0, "max_categories must be > 0");
        let n_cat = categorical_indices.len();
        Self {
            categorical_indices,
            categories: vec![Vec::new(); n_cat],
            max_categories,
            n_input_features: None,
        }
    }

    /// Input column indices that are one-hot encoded, in construction order.
    pub fn categorical_indices(&self) -> &[usize] {
        &self.categorical_indices
    }

    /// Discovered categories per categorical column, each sorted ascending;
    /// entry `j` belongs to input column `categorical_indices()[j]`.
    pub fn categories(&self) -> &[Vec<f64>] {
        &self.categories
    }

    /// Maximum number of categories retained per categorical column.
    pub fn max_categories(&self) -> usize {
        self.max_categories
    }

    /// Number of categories discovered so far for the `cat_idx`-th categorical
    /// column (an index into `categorical_indices()`, not an input column index).
    ///
    /// # Panics
    ///
    /// Panics if `cat_idx >= categorical_indices().len()`.
    pub fn n_discovered_categories(&self, cat_idx: usize) -> usize {
        self.categories[cat_idx].len()
    }

    /// Maps an input column index to its (first) position in
    /// `categorical_indices`, or `None` for a pass-through column.
    fn cat_ordinal(&self, feature_idx: usize) -> Option<usize> {
        self.categorical_indices
            .iter()
            .position(|&ci| ci == feature_idx)
    }

    /// Inserts `value` into the sorted category list of categorical column `j`,
    /// unless it is already present, the column is at its cap, or `value` is NaN.
    ///
    /// NaN is skipped because it has no position in the `partial_cmp` order the
    /// binary search relies on; comparing against it would panic. `-0.0` and
    /// `0.0` compare equal, so they share a single category.
    fn register_category(&mut self, j: usize, value: f64) {
        if value.is_nan() {
            return;
        }
        let cats = &mut self.categories[j];
        if cats.len() >= self.max_categories {
            return;
        }
        if let Err(pos) = cats.binary_search_by(|probe| {
            probe
                .partial_cmp(&value)
                .expect("categories never contain NaN")
        }) {
            cats.insert(pos, value);
        }
    }
}
impl StreamingPreprocessor for OneHotEncoder {
    /// Registers the categorical values of `features`, then encodes the row.
    ///
    /// Categorical indices at or beyond `features.len()` are ignored, and NaN
    /// values are never registered as categories (they encode as all zeros).
    fn update_and_transform(&mut self, features: &[f64]) -> Vec<f64> {
        self.n_input_features = Some(features.len());
        for j in 0..self.categorical_indices.len() {
            let ci = self.categorical_indices[j];
            // NaN is unordered under `partial_cmp`, so it cannot be placed in the
            // sorted category list; skip it instead of panicking.
            if let Some(&value) = features.get(ci).filter(|v| !v.is_nan()) {
                self.register_category(j, value);
            }
        }
        self.transform(features)
    }

    /// Encodes `features` with the categories discovered so far, without
    /// updating them.
    ///
    /// Each categorical column expands into one slot per discovered category
    /// (1.0 for the matching category, 0.0 elsewhere). A value that is unknown,
    /// was dropped by the category cap, or is NaN produces all zeros. Every
    /// other column is copied through unchanged.
    fn transform(&self, features: &[f64]) -> Vec<f64> {
        let mut out = Vec::with_capacity(features.len());
        for (i, &fval) in features.iter().enumerate() {
            match self.cat_ordinal(i) {
                Some(j) => {
                    let cats = &self.categories[j];
                    let hot = if fval.is_nan() {
                        None
                    } else {
                        cats.binary_search_by(|probe| {
                            probe
                                .partial_cmp(&fval)
                                .expect("categories never contain NaN")
                        })
                        .ok()
                    };
                    out.extend((0..cats.len()).map(|k| if Some(k) == hot { 1.0 } else { 0.0 }));
                }
                None => out.push(fval),
            }
        }
        out
    }

    /// Number of values `transform` produces for an input of the length last
    /// seen by `update_and_transform`.
    ///
    /// Returns `None` before the first update, after `reset`, or while no
    /// category has been discovered for any categorical column (which includes
    /// having no categorical columns at all).
    fn output_dim(&self) -> Option<usize> {
        let n_input = self.n_input_features?;
        if self.categories.iter().all(|c| c.is_empty()) {
            return None;
        }
        // Mirror `transform` column by column, so categorical indices outside the
        // input and repeated indices are not counted.
        let dim: usize = (0..n_input)
            .map(|i| self.cat_ordinal(i).map_or(1, |j| self.categories[j].len()))
            .sum();
        Some(dim)
    }

    /// Forgets all discovered categories and the remembered input length; the
    /// categorical indices and category cap are kept.
    fn reset(&mut self) {
        for cats in &mut self.categories {
            cats.clear();
        }
        self.n_input_features = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encoder over `indices` and streams every row of `rows` through it.
    fn fitted(indices: Vec<usize>, rows: &[&[f64]]) -> OneHotEncoder {
        let mut encoder = OneHotEncoder::new(indices);
        for &row in rows {
            encoder.update_and_transform(row);
        }
        encoder
    }

    #[test]
    fn passthrough_non_categorical() {
        let mut encoder = OneHotEncoder::new(Vec::new());
        let encoded = encoder.update_and_transform(&[1.0, 2.5, 3.0]);
        assert_eq!(encoded, [1.0, 2.5, 3.0].to_vec());
    }

    #[test]
    fn discovers_categories_online() {
        let mut encoder = OneHotEncoder::new(vec![0]);
        // (row fed in, categories known afterwards, expected encoding)
        let steps: [(&[f64], usize, &[f64]); 3] = [
            (&[0.0, 7.0], 1, &[1.0, 7.0]),
            (&[1.0, 8.0], 2, &[0.0, 1.0, 8.0]),
            (&[2.0, 9.0], 3, &[0.0, 0.0, 1.0, 9.0]),
        ];
        for &(row, expected_count, expected_out) in &steps {
            let encoded = encoder.update_and_transform(row);
            assert_eq!(encoder.n_discovered_categories(0), expected_count);
            assert_eq!(encoded, expected_out);
        }
    }

    #[test]
    fn max_categories_cap() {
        let mut encoder = OneHotEncoder::with_max_categories(vec![0], 2);
        for value in [0.0, 1.0, 2.0].iter() {
            encoder.update_and_transform(&[*value]);
        }
        assert_eq!(encoder.n_discovered_categories(0), 2);
        assert_eq!(encoder.transform(&[2.0]), vec![0.0, 0.0]);
    }

    #[test]
    fn one_hot_encoding_correct() {
        let encoder = fitted(vec![0], &[&[0.0], &[1.0], &[2.0]]);
        let expected = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (value, row) in [0.0, 1.0, 2.0].iter().zip(expected.iter()) {
            assert_eq!(encoder.transform(&[*value]), row.to_vec());
        }
    }

    #[test]
    fn unknown_category_in_transform() {
        let encoder = fitted(vec![0], &[&[0.0, 1.0], &[1.0, 2.0]]);
        assert_eq!(encoder.transform(&[5.0, 3.0]), vec![0.0, 0.0, 3.0]);
    }

    #[test]
    fn mixed_categorical_and_continuous() {
        let encoder = fitted(
            vec![0, 2],
            &[&[0.0, 5.0, 10.0], &[1.0, 6.0, 20.0], &[0.0, 7.0, 30.0]],
        );
        assert_eq!(encoder.output_dim(), Some(6));
        assert_eq!(
            encoder.transform(&[1.0, 9.0, 20.0]),
            vec![0.0, 1.0, 9.0, 0.0, 1.0, 0.0]
        );
    }

    #[test]
    fn reset_clears_categories() {
        let mut encoder = fitted(vec![0], &[&[0.0, 1.0], &[1.0, 2.0]]);
        assert!(encoder.output_dim().is_some());
        encoder.reset();
        assert_eq!(encoder.n_discovered_categories(0), 0);
        assert!(encoder.output_dim().is_none());
    }

    #[test]
    fn deterministic_ordering() {
        // Categories arrive out of order but must be encoded in ascending order.
        let encoder = fitted(vec![0], &[&[2.0], &[0.0], &[1.0]]);
        for (hot, value) in [0.0, 1.0, 2.0].iter().enumerate() {
            let mut expected = vec![0.0; 3];
            expected[hot] = 1.0;
            assert_eq!(encoder.transform(&[*value]), expected);
        }
    }
}