use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;
/// Controls which indicator columns are omitted from the encoded output
/// (useful to avoid perfect collinearity in linear models).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DropStrategy {
/// Keep one indicator column per category (the default).
#[default]
None,
/// Drop the first (smallest) category of every encoded feature.
First,
/// Drop the first category only for features with exactly two categories.
IfBinary,
}
/// Controls how `transform` treats values that were not seen during `fit`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum UnknownStrategy {
/// Return an error on the first unseen value (the default).
#[default]
Error,
/// Encode unseen values as an all-zero indicator row.
Ignore,
}
/// One-hot encodes selected numeric feature columns of a dataset.
///
/// Fit learns the sorted distinct values of each selected column; transform
/// replaces each selected column with one 0/1 indicator column per category
/// (minus any dropped baseline, per [`DropStrategy`]).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct OneHotEncoder {
/// Column indices (into the dataset's features) to encode.
feature_indices: Vec<usize>,
/// Which indicator columns to omit from the output.
drop_strategy: DropStrategy,
/// How to handle values unseen during `fit`.
unknown_strategy: UnknownStrategy,
/// Sorted, deduplicated category values, one Vec per entry of
/// `feature_indices` (same order).
categories: Vec<Vec<f64>>,
/// Feature names of the dataset passed to `fit`, used to rebuild
/// output names in `get_feature_names` and `inverse_transform`.
orig_feature_names: Vec<String>,
/// Set by `fit`; guards `transform`/`inverse_transform`.
fitted: bool,
}
impl OneHotEncoder {
    /// Creates an encoder for the columns at `feature_indices`.
    ///
    /// The encoder must be fitted before `transform`/`inverse_transform`;
    /// defaults are [`DropStrategy::None`] and [`UnknownStrategy::Error`].
    pub fn new(feature_indices: Vec<usize>) -> Self {
        Self {
            feature_indices,
            drop_strategy: DropStrategy::None,
            unknown_strategy: UnknownStrategy::Error,
            categories: Vec::new(),
            orig_feature_names: Vec::new(),
            fitted: false,
        }
    }

    /// Builder-style setter for the column-dropping strategy.
    pub fn drop(mut self, strategy: DropStrategy) -> Self {
        self.drop_strategy = strategy;
        self
    }

    /// Builder-style setter for how values unseen during `fit` are handled.
    pub fn handle_unknown(mut self, strategy: UnknownStrategy) -> Self {
        self.unknown_strategy = strategy;
        self
    }

    /// Sorted category values learned for each encoded feature, in
    /// `feature_indices` order. Empty until fitted.
    pub fn categories(&self) -> &[Vec<f64>] {
        &self.categories
    }

    /// Output feature names after encoding, in output column order.
    ///
    /// Encoded columns are named `"{orig}_{category}"` (category rendered as
    /// an integer, matching `transform`); pass-through columns keep their
    /// original names. Returns an empty vector if the encoder is not fitted.
    pub fn get_feature_names(&self) -> Vec<String> {
        if !self.fitted || self.orig_feature_names.is_empty() {
            return Vec::new();
        }
        // Map original column index -> position in `feature_indices` up
        // front, so the loop below does one O(1) lookup per column instead
        // of a membership test plus a linear `position` re-scan.
        let encoded: std::collections::HashMap<usize, usize> = self
            .feature_indices
            .iter()
            .enumerate()
            .map(|(pos, &fi)| (fi, pos))
            .collect();
        let mut names = Vec::new();
        for (j, orig_name) in self.orig_feature_names.iter().enumerate() {
            if let Some(&cat_idx) = encoded.get(&j) {
                let skip = self.n_drop(cat_idx);
                for &cat_val in self.categories[cat_idx].iter().skip(skip) {
                    // NOTE(review): truncating to i64 means non-integer
                    // category values can produce colliding names — confirm
                    // categories are intended to be integral.
                    names.push(format!("{}_{}", orig_name, cat_val as i64));
                }
            } else {
                names.push(orig_name.clone());
            }
        }
        names
    }
}
impl OneHotEncoder {
    /// Number of leading category columns omitted for the encoded feature at
    /// `cat_idx` (an index into `self.categories`), per the drop strategy.
    fn n_drop(&self, cat_idx: usize) -> usize {
        match self.drop_strategy {
            DropStrategy::None => 0,
            DropStrategy::First => 1,
            // Drop the baseline only when the feature is exactly binary.
            DropStrategy::IfBinary => {
                if self.categories[cat_idx].len() == 2 {
                    1
                } else {
                    0
                }
            }
        }
    }
}
impl Transformer for OneHotEncoder {
    /// Learns the sorted set of distinct values of every configured column.
    ///
    /// # Errors
    /// [`ScryLearnError::EmptyDataset`] for a dataset with no samples;
    /// [`ScryLearnError::InvalidParameter`] if any configured feature index
    /// is out of range.
    fn fit(&mut self, data: &Dataset) -> Result<()> {
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        for &idx in &self.feature_indices {
            if idx >= data.n_features() {
                return Err(ScryLearnError::InvalidParameter(format!(
                    "feature index {idx} out of range (dataset has {} features)",
                    data.n_features()
                )));
            }
        }
        self.categories.clear();
        // Remember the original layout so `get_feature_names` and
        // `inverse_transform` can reconstruct names and column groupings.
        self.orig_feature_names.clone_from(&data.feature_names);
        for &idx in &self.feature_indices {
            let mut unique: Vec<f64> = data.features[idx].clone();
            // total_cmp yields a total order even if NaN slips in.
            unique.sort_by(|a, b| a.total_cmp(b));
            unique.dedup();
            self.categories.push(unique);
        }
        self.fitted = true;
        Ok(())
    }

    /// Replaces each encoded column with one 0/1 indicator column per learned
    /// category (minus any dropped baseline); other columns pass through
    /// unchanged, preserving the original column order.
    ///
    /// # Errors
    /// [`ScryLearnError::NotFitted`] before `fit`;
    /// [`ScryLearnError::InvalidParameter`] when an unseen value is met under
    /// [`UnknownStrategy::Error`] (under `Ignore`, unseen values produce an
    /// all-zero indicator row).
    fn transform(&self, data: &mut Dataset) -> Result<()> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = data.n_samples();
        let encoded_set: std::collections::HashSet<usize> =
            self.feature_indices.iter().copied().collect();
        let mut new_features: Vec<Vec<f64>> = Vec::new();
        let mut new_names: Vec<String> = Vec::new();
        for j in 0..data.n_features() {
            if !encoded_set.contains(&j) {
                new_features.push(data.features[j].clone());
                new_names.push(data.feature_names[j].clone());
                continue;
            }
            let cat_idx = self
                .feature_indices
                .iter()
                .position(|&fi| fi == j)
                .ok_or(ScryLearnError::InvalidFeatureIndex(j))?;
            let cats = &self.categories[cat_idx];
            let skip = self.n_drop(cat_idx);
            let orig_name = &data.feature_names[j];
            let mut cols = vec![vec![0.0; n]; cats.len() - skip];
            if cols.is_empty() {
                // Degenerate case (single category dropped by `First`):
                // emit nothing, matching the previous behavior.
                continue;
            }
            // One category lookup per sample instead of re-scanning the
            // category list for every output column.
            for s in 0..n {
                let val = data.features[j][s];
                match cats.iter().position(|&c| (val - c).abs() < 1e-10) {
                    Some(ci) if ci >= skip => cols[ci - skip][s] = 1.0,
                    Some(_) => {} // dropped baseline category: all-zero row
                    None => match self.unknown_strategy {
                        UnknownStrategy::Error => {
                            return Err(ScryLearnError::InvalidParameter(format!(
                                "unknown category {val} in feature '{orig_name}'"
                            )));
                        }
                        UnknownStrategy::Ignore => {}
                    },
                }
            }
            for (offset, col) in cols.into_iter().enumerate() {
                new_features.push(col);
                new_names.push(format!("{}_{}", orig_name, cats[skip + offset] as i64));
            }
        }
        data.features = new_features;
        data.feature_names = new_names;
        data.sync_matrix();
        Ok(())
    }

    /// Collapses each group of indicator columns back into its original
    /// categorical column; pass-through columns are copied unchanged.
    ///
    /// Rows with no active indicator decode to the dropped baseline category
    /// when one exists, otherwise to NaN (an unknown encoded under `Ignore`).
    ///
    /// # Errors
    /// [`ScryLearnError::NotFitted`] before `fit`;
    /// [`ScryLearnError::InvalidParameter`] if the dataset's column count does
    /// not match this encoder's transformed layout.
    fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = data.n_samples();
        let mut new_features: Vec<Vec<f64>> = Vec::new();
        let mut new_names: Vec<String> = Vec::new();
        // Walk the *original* feature layout recorded during fit. The
        // previous positional scan consumed indicator groups starting at
        // column 0, which mis-decoded any encoder whose `feature_indices`
        // did not precede all pass-through columns. It also recovered names
        // by splitting on '_', which the stored names make unnecessary.
        let mut j = 0; // cursor over the encoded dataset's columns
        for (orig_j, orig_name) in self.orig_feature_names.iter().enumerate() {
            if let Some(cat_idx) = self.feature_indices.iter().position(|&fi| fi == orig_j) {
                let cats = &self.categories[cat_idx];
                let skip = self.n_drop(cat_idx);
                let n_cols = cats.len() - skip;
                if j + n_cols > data.n_features() {
                    return Err(ScryLearnError::InvalidParameter(format!(
                        "dataset has {} features but the encoded layout needs more",
                        data.n_features()
                    )));
                }
                let mut col = Vec::with_capacity(n);
                for s in 0..n {
                    // First active indicator wins.
                    match (0..n_cols).find(|&k| data.features[j + k][s] > 0.5) {
                        Some(k) => col.push(cats[skip + k]),
                        // All zeros: the dropped baseline if one exists,
                        // otherwise an unknown value encoded via `Ignore`.
                        None if skip > 0 => col.push(cats[0]),
                        None => col.push(f64::NAN),
                    }
                }
                new_features.push(col);
                new_names.push(orig_name.clone());
                j += n_cols;
            } else {
                if j >= data.n_features() {
                    return Err(ScryLearnError::InvalidParameter(format!(
                        "dataset has {} features but the encoded layout needs more",
                        data.n_features()
                    )));
                }
                new_features.push(data.features[j].clone());
                new_names.push(data.feature_names[j].clone());
                j += 1;
            }
        }
        if j != data.n_features() {
            return Err(ScryLearnError::InvalidParameter(format!(
                "expected {j} feature columns to inverse-transform, found {}",
                data.n_features()
            )));
        }
        data.features = new_features;
        data.feature_names = new_names;
        data.sync_matrix();
        Ok(())
    }
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Six samples: a three-category "color" column and a numeric "value" column.
    fn color_dataset() -> Dataset {
        Dataset::new(
            vec![
                vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0],
            vec!["color".into(), "value".into()],
            "target",
        )
    }

    #[test]
    fn onehot_basic_encoding() {
        let mut data = color_dataset();
        let mut encoder = OneHotEncoder::new(vec![0]);
        encoder.fit_transform(&mut data).unwrap();
        assert_eq!(data.n_features(), 4);
        for (i, want) in ["color_0", "color_1", "color_2", "value"].iter().enumerate() {
            assert_eq!(&data.feature_names[i], want);
        }
        // Sample 0 holds category 0 -> indicator (1, 0, 0).
        assert_eq!(
            [data.features[0][0], data.features[1][0], data.features[2][0]],
            [1.0, 0.0, 0.0]
        );
        // Sample 2 holds category 2 -> indicator (0, 0, 1).
        assert_eq!(
            [data.features[0][2], data.features[1][2], data.features[2][2]],
            [0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn onehot_drop_first() {
        let mut data = color_dataset();
        let mut encoder = OneHotEncoder::new(vec![0]).drop(DropStrategy::First);
        encoder.fit_transform(&mut data).unwrap();
        // Baseline category 0 is dropped, leaving two indicators.
        assert_eq!(data.n_features(), 3);
        assert_eq!(data.feature_names[0], "color_1");
        assert_eq!(data.feature_names[1], "color_2");
    }

    #[test]
    fn onehot_drop_if_binary() {
        // Exactly two categories: the baseline gets dropped.
        let mut binary = Dataset::new(
            vec![vec![0.0, 1.0, 0.0, 1.0], vec![10.0, 20.0, 30.0, 40.0]],
            vec![0.0; 4],
            vec!["binary".into(), "num".into()],
            "y",
        );
        let mut binary_enc = OneHotEncoder::new(vec![0]).drop(DropStrategy::IfBinary);
        binary_enc.fit_transform(&mut binary).unwrap();
        assert_eq!(binary.n_features(), 2);
        assert_eq!(binary.feature_names[0], "binary_1");

        // Three categories: nothing is dropped.
        let mut ternary = color_dataset();
        let mut ternary_enc = OneHotEncoder::new(vec![0]).drop(DropStrategy::IfBinary);
        ternary_enc.fit_transform(&mut ternary).unwrap();
        assert_eq!(ternary.n_features(), 4);
    }

    #[test]
    fn onehot_unknown_error() {
        let mut data = color_dataset();
        let mut encoder = OneHotEncoder::new(vec![0]);
        encoder.fit(&data).unwrap();
        // Inject a value never seen during fit; default strategy must error.
        data.features[0][0] = 99.0;
        assert!(encoder.transform(&mut data).is_err());
    }

    #[test]
    fn onehot_unknown_ignore() {
        let mut data = color_dataset();
        let mut encoder = OneHotEncoder::new(vec![0]).handle_unknown(UnknownStrategy::Ignore);
        encoder.fit(&data).unwrap();
        data.features[0][0] = 99.0;
        encoder.transform(&mut data).unwrap();
        // Unknown value encodes to an all-zero indicator row.
        for col in 0..3 {
            assert_eq!(data.features[col][0], 0.0);
        }
    }

    #[test]
    fn onehot_roundtrip_inverse() {
        let original = color_dataset();
        let mut data = original.clone();
        let mut encoder = OneHotEncoder::new(vec![0]);
        encoder.fit_transform(&mut data).unwrap();
        encoder.inverse_transform(&mut data).unwrap();
        assert_eq!(data.n_features(), 2);
        let pairs = data.features[0].iter().zip(original.features[0].iter());
        for (i, (&got, &want)) in pairs.enumerate() {
            assert!((got - want).abs() < 1e-10, "roundtrip mismatch at sample {i}");
        }
    }

    #[test]
    fn onehot_feature_names() {
        let mut data = color_dataset();
        let mut encoder = OneHotEncoder::new(vec![0]);
        encoder.fit_transform(&mut data).unwrap();
        assert_eq!(
            encoder.get_feature_names(),
            &["color_0", "color_1", "color_2", "value"]
        );
    }

    #[test]
    fn onehot_not_fitted_error() {
        let encoder = OneHotEncoder::new(vec![0]);
        let mut data = color_dataset();
        assert!(encoder.transform(&mut data).is_err());
    }

    #[test]
    fn onehot_multiple_features() {
        let mut data = Dataset::new(
            vec![
                vec![0.0, 1.0, 0.0, 1.0],
                vec![0.0, 1.0, 2.0, 0.0],
                vec![5.0, 6.0, 7.0, 8.0],
            ],
            vec![0.0; 4],
            vec!["a".into(), "b".into(), "num".into()],
            "y",
        );
        let mut encoder = OneHotEncoder::new(vec![0, 1]);
        encoder.fit_transform(&mut data).unwrap();
        assert_eq!(data.n_features(), 6);
        let expected = ["a_0", "a_1", "b_0", "b_1", "b_2", "num"];
        for (got, want) in data.feature_names.iter().zip(expected) {
            assert_eq!(got, want);
        }
    }
}