use ndarray::Array2;
use crate::error::{Result, RustMlError};
use crate::float::Float;
use crate::pipeline::{FitTransform, TransformStep};
use crate::traits::{FitUnsupervised, Transform};
#[derive(Debug, Clone)]
pub enum ColumnSelector {
Indices(Vec<usize>),
All,
}
pub struct ColumnTransformer<F: Float> {
branches: Vec<Branch<F>>,
remainder: Remainder,
}
struct Branch<F: Float> {
name: String,
selector: ColumnSelector,
transformer: Box<dyn FitTransform<F>>,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Remainder {
Drop,
Passthrough,
}
impl Default for Remainder {
fn default() -> Self {
Remainder::Drop
}
}
impl<F: Float> ColumnTransformer<F> {
pub fn new() -> Self {
Self {
branches: Vec::new(),
remainder: Remainder::Drop,
}
}
pub fn push(
mut self,
name: impl Into<String>,
selector: ColumnSelector,
transformer: impl FitTransform<F> + 'static,
) -> Self {
self.branches.push(Branch {
name: name.into(),
selector,
transformer: Box::new(transformer),
});
self
}
pub fn with_remainder(mut self, remainder: Remainder) -> Self {
self.remainder = remainder;
self
}
}
impl<F: Float> std::fmt::Debug for ColumnTransformer<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ColumnTransformer")
.field("n_branches", &self.branches.len())
.field("remainder", &self.remainder)
.finish()
}
}
fn resolve_columns(selector: &ColumnSelector, n_cols: usize) -> Vec<usize> {
match selector {
ColumnSelector::Indices(indices) => indices.clone(),
ColumnSelector::All => (0..n_cols).collect(),
}
}
fn select_columns<F: Float>(x: &Array2<F>, cols: &[usize]) -> Array2<F> {
let n_rows = x.nrows();
let n_cols = cols.len();
let mut data = Vec::with_capacity(n_rows * n_cols);
for i in 0..n_rows {
for &c in cols {
data.push(x[[i, c]]);
}
}
Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
}
pub struct FittedColumnTransformer<F: Float> {
fitted_branches: Vec<FittedBranch<F>>,
passthrough_cols: Vec<usize>,
n_features_in: usize,
}
struct FittedBranch<F: Float> {
name: String,
cols: Vec<usize>,
fitted: Box<dyn TransformStep<F>>,
}
unsafe impl<F: Float> Send for FittedColumnTransformer<F> {}
unsafe impl<F: Float> Sync for FittedColumnTransformer<F> {}
impl<F: Float> Transform<F> for FittedColumnTransformer<F> {
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
if x.ncols() != self.n_features_in {
return Err(RustMlError::ShapeMismatch(format!(
"expected {} features, got {}",
self.n_features_in,
x.ncols()
)));
}
let n_rows = x.nrows();
let mut parts: Vec<Array2<F>> = Vec::with_capacity(self.fitted_branches.len() + 1);
for branch in &self.fitted_branches {
let sub_x = select_columns(x, &branch.cols);
let transformed = branch.fitted.transform(&sub_x)?;
parts.push(transformed);
}
if !self.passthrough_cols.is_empty() {
parts.push(select_columns(x, &self.passthrough_cols));
}
concat_horizontal(n_rows, &parts)
}
}
impl<F: Float + 'static> FitUnsupervised<F> for ColumnTransformer<F> {
type Fitted = FittedColumnTransformer<F>;
fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
let n_cols = x.ncols();
let mut used_cols = std::collections::HashSet::new();
let mut fitted_branches = Vec::with_capacity(self.branches.len());
for branch in &self.branches {
let cols = resolve_columns(&branch.selector, n_cols);
for &c in &cols {
if c >= n_cols {
return Err(RustMlError::InvalidParameter(format!(
"column index {} out of range for data with {} columns",
c, n_cols
)));
}
used_cols.insert(c);
}
let sub_x = select_columns(x, &cols);
let (fitted_step, _) = branch.transformer.fit_transform(&sub_x)?;
fitted_branches.push(FittedBranch {
name: branch.name.clone(),
cols,
fitted: fitted_step,
});
}
let passthrough_cols: Vec<usize> = if self.remainder == Remainder::Passthrough {
(0..n_cols).filter(|c| !used_cols.contains(c)).collect()
} else {
Vec::new()
};
Ok(FittedColumnTransformer {
fitted_branches,
passthrough_cols,
n_features_in: n_cols,
})
}
}
fn concat_horizontal<F: Float>(n_rows: usize, parts: &[Array2<F>]) -> Result<Array2<F>> {
if parts.is_empty() {
return Ok(Array2::zeros((n_rows, 0)));
}
let total_cols: usize = parts.iter().map(|p| p.ncols()).sum();
let mut result = Array2::zeros((n_rows, total_cols));
let mut col_offset = 0;
for part in parts {
for j in 0..part.ncols() {
for i in 0..n_rows {
result[[i, col_offset + j]] = part[[i, j]];
}
}
col_offset += part.ncols();
}
Ok(result)
}
impl<F: Float> FittedColumnTransformer<F> {
pub fn branch_names(&self) -> Vec<&str> {
self.fitted_branches
.iter()
.map(|b| b.name.as_str())
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::traits::{FitUnsupervised, Transform};
use ndarray::array;
#[derive(Debug, Clone)]
struct DoubleScaler;
struct FittedDoubleScaler;
impl<F: Float> FitUnsupervised<F> for DoubleScaler {
type Fitted = FittedDoubleScaler;
fn fit(&self, _x: &Array2<F>) -> Result<Self::Fitted> {
Ok(FittedDoubleScaler)
}
}
impl<F: Float> Transform<F> for FittedDoubleScaler {
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
Ok(x.mapv(|v| v + v))
}
}
#[derive(Debug, Clone)]
struct IdentityTransformer;
struct FittedIdentity;
impl<F: Float> FitUnsupervised<F> for IdentityTransformer {
type Fitted = FittedIdentity;
fn fit(&self, _x: &Array2<F>) -> Result<Self::Fitted> {
Ok(FittedIdentity)
}
}
impl<F: Float> Transform<F> for FittedIdentity {
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
Ok(x.to_owned())
}
}
#[test]
fn test_column_transformer_basic() {
let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
let ct = ColumnTransformer::<f64>::new()
.push(
"double_01",
ColumnSelector::Indices(vec![0, 1]),
DoubleScaler,
)
.push(
"identity_2",
ColumnSelector::Indices(vec![2]),
IdentityTransformer,
);
let fitted = FitUnsupervised::fit(&ct, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_eq!(transformed, array![[2.0, 20.0, 100.0], [4.0, 40.0, 200.0]]);
}
#[test]
fn test_column_transformer_passthrough() {
let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
let ct = ColumnTransformer::<f64>::new()
.push("double_0", ColumnSelector::Indices(vec![0]), DoubleScaler)
.with_remainder(Remainder::Passthrough);
let fitted = FitUnsupervised::fit(&ct, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_eq!(transformed.ncols(), 3);
assert_eq!(transformed[[0, 0]], 2.0);
assert_eq!(transformed[[0, 1]], 10.0);
assert_eq!(transformed[[0, 2]], 100.0);
}
#[test]
fn test_column_transformer_drop_remainder() {
let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
let ct = ColumnTransformer::<f64>::new().push(
"double_0",
ColumnSelector::Indices(vec![0]),
DoubleScaler,
);
let fitted = FitUnsupervised::fit(&ct, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_eq!(transformed.ncols(), 1);
assert_eq!(transformed[[0, 0]], 2.0);
assert_eq!(transformed[[1, 0]], 4.0);
}
#[test]
fn test_column_transformer_all_selector() {
let x = array![[1.0, 2.0], [3.0, 4.0]];
let ct =
ColumnTransformer::<f64>::new().push("double_all", ColumnSelector::All, DoubleScaler);
let fitted = FitUnsupervised::fit(&ct, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_eq!(transformed, array![[2.0, 4.0], [6.0, 8.0]]);
}
#[test]
fn test_column_transformer_shape_mismatch_predict() {
let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let ct = ColumnTransformer::<f64>::new().push(
"a",
ColumnSelector::Indices(vec![0]),
IdentityTransformer,
);
let fitted = FitUnsupervised::fit(&ct, &x).unwrap();
let x_bad = array![[1.0, 2.0]];
assert!(fitted.transform(&x_bad).is_err());
}
#[test]
fn test_column_transformer_invalid_column() {
let x = array![[1.0, 2.0]];
let ct = ColumnTransformer::<f64>::new().push(
"bad",
ColumnSelector::Indices(vec![5]),
IdentityTransformer,
);
assert!(FitUnsupervised::fit(&ct, &x).is_err());
}
}