use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;
trait BoxedTransformer {
fn fit(&mut self, data: &Dataset) -> Result<()>;
fn transform(&self, data: &mut Dataset) -> Result<()>;
}
impl<T: Transformer> BoxedTransformer for T {
fn fit(&mut self, data: &Dataset) -> Result<()> {
Transformer::fit(self, data)
}
fn transform(&self, data: &mut Dataset) -> Result<()> {
Transformer::transform(self, data)
}
}
struct TransformerStep {
columns: Vec<usize>,
transformer: Box<dyn BoxedTransformer>,
}
#[non_exhaustive]
pub struct ColumnTransformer {
steps: Vec<TransformerStep>,
fitted: bool,
}
impl ColumnTransformer {
pub fn new() -> Self {
Self {
steps: Vec::new(),
fitted: false,
}
}
pub fn add<T: Transformer + 'static>(mut self, columns: &[usize], transformer: T) -> Self {
self.steps.push(TransformerStep {
columns: columns.to_vec(),
transformer: Box::new(transformer),
});
self
}
}
impl Default for ColumnTransformer {
fn default() -> Self {
Self::new()
}
}
fn extract_columns(data: &Dataset, cols: &[usize]) -> Dataset {
let features: Vec<Vec<f64>> = cols.iter().map(|&c| data.features[c].clone()).collect();
let names: Vec<String> = cols
.iter()
.map(|&c| data.feature_names[c].clone())
.collect();
Dataset::new(features, data.target.clone(), names, &data.target_name)
}
impl Transformer for ColumnTransformer {
fn fit(&mut self, data: &Dataset) -> Result<()> {
if data.n_samples() == 0 {
return Err(ScryLearnError::EmptyDataset);
}
for step in &mut self.steps {
for &c in &step.columns {
if c >= data.n_features() {
return Err(ScryLearnError::InvalidColumn(format!(
"column index {c} out of range (dataset has {} features)",
data.n_features()
)));
}
}
let sub = extract_columns(data, &step.columns);
step.transformer.fit(&sub)?;
}
self.fitted = true;
Ok(())
}
fn transform(&self, data: &mut Dataset) -> Result<()> {
if !self.fitted {
return Err(ScryLearnError::NotFitted);
}
let mut result_cols: Vec<Vec<f64>> = Vec::new();
let mut result_names: Vec<String> = Vec::new();
for step in &self.steps {
let mut sub = extract_columns(data, &step.columns);
step.transformer.transform(&mut sub)?;
for (col, name) in sub.features.into_iter().zip(sub.feature_names) {
result_cols.push(col);
result_names.push(name);
}
}
data.features = result_cols;
data.feature_names = result_names;
data.sync_matrix();
Ok(())
}
fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
Err(ScryLearnError::InvalidParameter(
"ColumnTransformer is not invertible".into(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::preprocess::{MinMaxScaler, StandardScaler};
#[test]
fn test_column_transformer_basic() {
let mut ds = Dataset::new(
vec![
vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![10.0, 20.0, 30.0, 40.0, 50.0], vec![100.0, 200.0, 300.0, 400.0, 500.0], vec![5.0, 10.0, 15.0, 20.0, 25.0], ],
vec![0.0; 5],
vec!["a".into(), "b".into(), "c".into(), "d".into()],
"y",
);
let mut ct = ColumnTransformer::new()
.add(&[0, 1], StandardScaler::new())
.add(&[2, 3], MinMaxScaler::new());
ct.fit_transform(&mut ds).unwrap();
assert_eq!(ds.n_features(), 4);
let mean_a: f64 = ds.features[0].iter().sum::<f64>() / 5.0;
assert!(
mean_a.abs() < 1e-10,
"col 0 should be zero-mean, got {mean_a}"
);
let mean_b: f64 = ds.features[1].iter().sum::<f64>() / 5.0;
assert!(
mean_b.abs() < 1e-10,
"col 1 should be zero-mean, got {mean_b}"
);
assert!(ds.features[2][0].abs() < 1e-10, "col 2 min should be 0");
assert!(
(ds.features[2][4] - 1.0).abs() < 1e-10,
"col 2 max should be 1"
);
assert!(ds.features[3][0].abs() < 1e-10, "col 3 min should be 0");
assert!(
(ds.features[3][4] - 1.0).abs() < 1e-10,
"col 3 max should be 1"
);
}
#[test]
fn test_column_transformer_not_fitted() {
let ct = ColumnTransformer::new().add(&[0], StandardScaler::new());
let mut ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0; 2], vec!["x".into()], "y");
assert!(Transformer::transform(&ct, &mut ds).is_err());
}
#[test]
fn test_column_transformer_invalid_column() {
let mut ct = ColumnTransformer::new().add(&[99], StandardScaler::new());
let ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0; 2], vec!["x".into()], "y");
assert!(Transformer::fit(&mut ct, &ds).is_err());
}
}