use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use std::any::Any;
use crate::error::{Result, TransformError};
/// Common interface for a single transformation step over 2-D `f64` data.
///
/// Implementors must be `Send + Sync`. The trait is used as a trait object
/// (`Box<dyn Transformer>`), so it has to stay object safe.
pub trait Transformer: Send + Sync {
    /// Fits the transformer to `x`, updating its internal state.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementor cannot be fitted on `x`.
    fn fit(&mut self, x: &Array2<f64>) -> Result<()>;

    /// Applies the transformation to `x` and returns the result as a new array.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementor cannot transform `x`.
    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>>;

    /// Fits on `x` and then transforms that same `x`.
    ///
    /// The default implementation simply chains [`Transformer::fit`] and
    /// [`Transformer::transform`]; the first error encountered is returned.
    fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
        self.fit(x).and_then(|()| self.transform(x))
    }

    /// Returns a boxed `Transformer`; intended to be an independent copy of `self`.
    fn clone_box(&self) -> Box<dyn Transformer>;

    /// Exposes the value as `&dyn Any`, e.g. so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Mutable counterpart of [`Transformer::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
/// An ordered sequence of named [`Transformer`] steps.
///
/// Fitting and transforming feed the output of each step into the next one.
pub struct Pipeline {
/// Steps as `(name, transformer)` pairs, in the order they were added.
steps: Vec<(String, Box<dyn Transformer>)>,
/// Set to `true` by a successful `fit`; `transform` returns an error while it is `false`.
fitted: bool,
}
impl Pipeline {
    /// Creates an empty, unfitted pipeline.
    pub fn new() -> Self {
        Pipeline {
            steps: Vec::new(),
            fitted: false,
        }
    }

    /// Appends a named step to the end of the pipeline (builder style).
    ///
    /// Step names are not checked for uniqueness. Adding a step does not touch
    /// the fitted flag, so a step appended after fitting is *not* fitted by the
    /// pipeline until [`Pipeline::fit`] is called again.
    pub fn add_step(mut self, name: impl Into<String>, transformer: Box<dyn Transformer>) -> Self {
        self.steps.push((name.into(), transformer));
        self
    }

    /// Fits every step in order, feeding each step the output of the previous one.
    ///
    /// The input is converted element-wise to `f64`; a value that cannot be
    /// converted becomes `0.0`.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::TransformationError` naming the failing step if
    /// any step fails to fit or to transform. In that case the pipeline is left
    /// marked as not fitted.
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit_steps(x).map(|_| ())
    }

    /// Runs `x` through every fitted step in order and returns the final output.
    ///
    /// The input is converted element-wise to `f64`; a value that cannot be
    /// converted becomes `0.0`.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::TransformationError` if the pipeline has not
    /// been fitted, or if any step fails (the message names the step).
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        if !self.fitted {
            return Err(TransformError::TransformationError(
                "Pipeline has not been fitted".to_string(),
            ));
        }

        let mut current: Array2<f64> = x.mapv(|v| <f64 as NumCast>::from(v).unwrap_or(0.0));
        for (name, transformer) in &self.steps {
            current = transformer.transform(&current).map_err(|e| {
                TransformError::TransformationError(format!(
                    "Failed to transform in step '{name}': {e}"
                ))
            })?;
        }
        Ok(current)
    }

    /// Fits the pipeline on `x` and returns the transformed `x`.
    ///
    /// Fitting already pushes the data through every step, so the output of
    /// that pass is returned directly instead of transforming a second time.
    ///
    /// # Errors
    ///
    /// Same as [`Pipeline::fit`].
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit_steps(x)
    }

    /// Returns the number of steps in the pipeline.
    pub fn len(&self) -> usize {
        self.steps.len()
    }

    /// Returns `true` if the pipeline contains no steps.
    pub fn is_empty(&self) -> bool {
        self.steps.is_empty()
    }

    /// Returns the first step registered under `name`, if any.
    pub fn get_step(&self, name: &str) -> Option<&dyn Transformer> {
        self.steps
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, t)| t.as_ref())
    }

    /// Returns mutable access to the first step registered under `name`, if any.
    pub fn get_step_mut(&mut self, name: &str) -> Option<&mut Box<dyn Transformer>> {
        self.steps
            .iter_mut()
            .find(|(n, _)| n == name)
            .map(|(_, t)| t)
    }

    /// Fits all steps in order and returns the output of the last step.
    fn fit_steps<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        // Steps are mutated one by one below; if any of them fails, the pipeline
        // must not keep reporting itself as fitted from an earlier successful fit.
        self.fitted = false;

        let mut current: Array2<f64> = x.mapv(|v| <f64 as NumCast>::from(v).unwrap_or(0.0));
        for (name, transformer) in &mut self.steps {
            transformer.fit(&current).map_err(|e| {
                TransformError::TransformationError(format!("Failed to fit step '{name}': {e}"))
            })?;
            current = transformer.transform(&current).map_err(|e| {
                TransformError::TransformationError(format!(
                    "Failed to transform in step '{name}': {e}"
                ))
            })?;
        }

        self.fitted = true;
        Ok(current)
    }
}
impl Default for Pipeline {
fn default() -> Self {
Self::new()
}
}
/// Applies separate [`Transformer`]s to selected column subsets of a 2-D array
/// and concatenates their outputs horizontally.
pub struct ColumnTransformer {
/// Registered entries as `(name, transformer, column indices)`, in registration order.
transformers: Vec<(String, Box<dyn Transformer>, Vec<usize>)>,
/// What to do with input columns that no registered transformer selects.
remainder: RemainderOption,
/// Set to `true` by a successful `fit`; `transform` returns an error while it is `false`.
fitted: bool,
}
/// Policy for input columns that are not selected by any transformer of a
/// [`ColumnTransformer`].
#[derive(Debug, Clone, Copy)]
pub enum RemainderOption {
/// Unselected columns are left out of the output.
Drop,
/// Unselected columns are appended, unchanged, after the transformed parts.
Passthrough,
}
impl ColumnTransformer {
pub fn new(remainder: RemainderOption) -> Self {
ColumnTransformer {
transformers: Vec::new(),
remainder,
fitted: false,
}
}
pub fn add_transformer(
mut self,
name: impl Into<String>,
transformer: Box<dyn Transformer>,
columns: Vec<usize>,
) -> Self {
self.transformers.push((name.into(), transformer, columns));
self
}
pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
where
S: Data,
S::Elem: Float + NumCast,
{
let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
let n_features = x_f64.shape()[1];
for (name_, transformer, columns) in &self.transformers {
for &col in columns {
if col >= n_features {
return Err(TransformError::InvalidInput(format!(
"Column index {col} in transformer '{name_}' exceeds number of features {n_features}"
)));
}
}
}
for (name, transformer, columns) in &mut self.transformers {
let subset = extract_columns(&x_f64, columns);
transformer.fit(&subset).map_err(|e| {
TransformError::TransformationError(format!(
"Failed to fit transformer '{name}': {e}"
))
})?;
}
self.fitted = true;
Ok(())
}
pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
where
S: Data,
S::Elem: Float + NumCast,
{
if !self.fitted {
return Err(TransformError::TransformationError(
"ColumnTransformer has not been fitted".to_string(),
));
}
let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
let n_samples = x_f64.shape()[0];
let n_features = x_f64.shape()[1];
let mut used_columns = vec![false; n_features];
let mut transformed_parts = Vec::new();
for (name, transformer, columns) in &self.transformers {
for &col in columns {
used_columns[col] = true;
}
let subset = extract_columns(&x_f64, columns);
let transformed = transformer.transform(&subset).map_err(|e| {
TransformError::TransformationError(format!(
"Failed to transform with '{name}': {e}"
))
})?;
transformed_parts.push(transformed);
}
match self.remainder {
RemainderOption::Passthrough => {
let unused_columns: Vec<usize> =
(0..n_features).filter(|&i| !used_columns[i]).collect();
if !unused_columns.is_empty() {
let remainder = extract_columns(&x_f64, &unused_columns);
transformed_parts.push(remainder);
}
}
RemainderOption::Drop => {
}
}
if transformed_parts.is_empty() {
return Ok(Array2::zeros((n_samples, 0)));
}
concatenate_horizontal(&transformed_parts)
}
pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
where
S: Data,
S::Elem: Float + NumCast,
{
self.fit(x)?;
self.transform(x)
}
}
/// Copies the selected columns of `data` into a new array, in the order given.
///
/// The result has one row per row of `data` and one column per entry of
/// `columns`; the same index may appear more than once.
///
/// # Panics
///
/// Panics if an element has to be read from a column index that is out of
/// bounds for `data`.
#[allow(dead_code)]
fn extract_columns(data: &Array2<f64>, columns: &[usize]) -> Array2<f64> {
    let shape = (data.shape()[0], columns.len());
    // Output cell (row, k) takes its value from source column `columns[k]`.
    Array2::from_shape_fn(shape, |(row, k)| data[[row, columns[k]]])
}
/// Joins `arrays` side by side into one array, preserving their order.
///
/// # Errors
///
/// Returns `TransformError::InvalidInput` if `arrays` is empty or if the
/// arrays do not all have the same number of rows as the first one.
#[allow(dead_code)]
fn concatenate_horizontal(arrays: &[Array2<f64>]) -> Result<Array2<f64>> {
    // The first array fixes the row count every other array must match.
    let n_rows = arrays
        .first()
        .ok_or_else(|| {
            TransformError::InvalidInput("Cannot concatenate empty array list".to_string())
        })?
        .shape()[0];

    if arrays.iter().any(|part| part.shape()[0] != n_rows) {
        return Err(TransformError::InvalidInput(
            "All _arrays must have the same number of samples".to_string(),
        ));
    }

    let total_cols: usize = arrays.iter().map(|part| part.shape()[1]).sum();
    let mut joined = Array2::<f64>::zeros((n_rows, total_cols));

    // `offset` is the first output column belonging to the current part.
    let mut offset = 0;
    for part in arrays {
        let width = part.shape()[1];
        for row in 0..n_rows {
            for col in 0..width {
                joined[[row, offset + col]] = part[[row, col]];
            }
        }
        offset += width;
    }
    Ok(joined)
}
/// Builds a [`Pipeline`] from `(name, transformer)` pairs, keeping their order.
///
/// The returned pipeline is not fitted.
#[allow(dead_code)]
pub fn make_pipeline(steps: Vec<(&str, Box<dyn Transformer>)>) -> Pipeline {
    steps
        .into_iter()
        .fold(Pipeline::new(), |pipeline, (name, transformer)| {
            pipeline.add_step(name, transformer)
        })
}
/// Builds a [`ColumnTransformer`] from `(name, transformer, columns)` entries,
/// keeping their order, with the given remainder policy.
///
/// The returned column transformer is not fitted.
#[allow(dead_code)]
pub fn make_column_transformer(
    transformers: Vec<(&str, Box<dyn Transformer>, Vec<usize>)>,
    remainder: RemainderOption,
) -> ColumnTransformer {
    transformers.into_iter().fold(
        ColumnTransformer::new(remainder),
        |ct, (name, transformer, columns)| ct.add_transformer(name, transformer, columns),
    )
}