use crate::core::error::{Error, Result};
use crate::dataframe::DataFrame;
use std::collections::HashMap;
/// A step that can learn state from a [`DataFrame`] and then apply it.
///
/// Implementors provide [`fit`](Self::fit) and [`transform`](Self::transform);
/// [`fit_transform`](Self::fit_transform) is provided in terms of those two.
pub trait PipelineTransformer {
    /// Applies this transformer to `df` and returns the resulting frame.
    fn transform(&self, df: &DataFrame) -> Result<DataFrame>;

    /// Learns whatever state this transformer needs from `df`.
    fn fit(&mut self, df: &DataFrame) -> Result<()>;

    /// Fits on `df`, then transforms that same frame.
    ///
    /// # Errors
    ///
    /// If `fit` fails, its error is returned and `transform` is not called;
    /// otherwise the result of `transform` is returned as-is.
    fn fit_transform(&mut self, df: &DataFrame) -> Result<DataFrame> {
        self.fit(df).and_then(|()| self.transform(df))
    }
}
/// One configurable step of a [`Pipeline`].
///
/// Each variant carries its configuration plus `_`-prefixed fields that are
/// intended to hold state learned by `fit`. In the current implementation only
/// `StandardScaler` populates them, and only `FeatureSelector` changes the data
/// in `transform`; the other variants pass the frame through unchanged.
#[derive(Debug)]
pub enum PipelineStage {
    /// Standardizes columns using per-column means and standard deviations.
    StandardScaler {
        /// Columns to scale; `None` means every column of the frame given to `fit`.
        columns: Option<Vec<String>>,
        /// Per-column means recorded by `fit` (currently placeholder `0.0` values).
        _means: Option<HashMap<String, f64>>,
        /// Per-column standard deviations recorded by `fit` (currently placeholder `1.0` values).
        _stds: Option<HashMap<String, f64>>,
    },
    /// Rescales columns into `feature_range`.
    MinMaxScaler {
        /// Columns to scale — presumably `None` means all columns; TODO confirm.
        columns: Option<Vec<String>>,
        /// Presumably the `(min, max)` target range — TODO confirm.
        feature_range: (f64, f64),
        /// Per-column minimums; not populated by the current `fit`.
        _min_values: Option<HashMap<String, f64>>,
        /// Per-column maximums; not populated by the current `fit`.
        _max_values: Option<HashMap<String, f64>>,
    },
    /// One-hot encodes categorical columns.
    OneHotEncoder {
        /// Columns to encode — presumably `None` means all columns; TODO confirm.
        columns: Option<Vec<String>>,
        /// Presumably drops the first category of each column — TODO confirm.
        drop_first: bool,
        /// Presumably a prefix for generated column names — TODO confirm.
        prefix: Option<String>,
        /// Per-column category lists; not populated by the current `fit`.
        _categories: Option<HashMap<String, Vec<String>>>,
    },
    /// Fills missing values.
    Imputer {
        /// Columns to impute — presumably `None` means all columns; TODO confirm.
        columns: Option<Vec<String>>,
        /// Imputation strategy name; accepted values are not validated in this code.
        strategy: String,
        /// Optional constant fill value — presumably used by a constant strategy; TODO confirm.
        fill_value: Option<f64>,
        /// Per-column fill values; not populated by the current `fit`.
        _fill_values: Option<HashMap<String, f64>>,
    },
    /// Keeps only the listed columns.
    FeatureSelector {
        /// Columns to keep, in output order; `transform` fails if any is missing.
        columns: Vec<String>,
    },
}
impl PipelineTransformer for PipelineStage {
    /// Applies this stage to `df`.
    ///
    /// Only `FeatureSelector` currently changes the data: it builds a new frame
    /// containing exactly the requested columns, in the requested order. Every
    /// other stage returns a clone of `df`.
    ///
    /// # Errors
    ///
    /// `FeatureSelector` returns `Error::InvalidValue` when a requested column
    /// is missing, and propagates errors from `get_column` / `add_column`.
    fn transform(&self, df: &DataFrame) -> Result<DataFrame> {
        match self {
            // TODO: these stages do not apply their transformation yet.
            PipelineStage::StandardScaler { .. }
            | PipelineStage::MinMaxScaler { .. }
            | PipelineStage::OneHotEncoder { .. }
            | PipelineStage::Imputer { .. } => Ok(df.clone()),
            PipelineStage::FeatureSelector { columns } => {
                let mut result = DataFrame::new();
                for col_name in columns {
                    if !df.contains_column(col_name) {
                        return Err(Error::InvalidValue(format!(
                            "Column '{}' not found",
                            col_name
                        )));
                    }
                    // NOTE(review): the annotation requests a `Series<String>`;
                    // a column stored with another element type may fail here —
                    // confirm against `DataFrame::get_column`.
                    let col: &crate::series::Series<String> = df.get_column(col_name)?;
                    result.add_column(col_name.clone(), col.clone())?;
                }
                Ok(result)
            }
        }
    }

    /// Learns per-stage state from `df`.
    ///
    /// `StandardScaler` records a mean and a standard deviation for each
    /// selected column (all columns when `columns` is `None`). These are
    /// currently placeholder values. Every other stage is a no-op.
    ///
    /// # Errors
    ///
    /// `StandardScaler` returns `Error::InvalidInput` when a configured column
    /// is missing, and propagates errors from `get_column`. On error, the
    /// previously stored statistics are left unchanged.
    fn fit(&mut self, df: &DataFrame) -> Result<()> {
        match self {
            PipelineStage::StandardScaler {
                columns,
                _means: means_out,
                _stds: stds_out,
            } => {
                let cols: Vec<String> = columns.clone().unwrap_or_else(|| df.column_names());
                let mut means = HashMap::with_capacity(cols.len());
                let mut stds = HashMap::with_capacity(cols.len());
                for col_name in &cols {
                    // NOTE(review): `transform` reports a missing column as
                    // `InvalidValue`; this reports `InvalidInput`. Kept as-is
                    // because callers may match on the variant.
                    if !df.contains_column(col_name) {
                        return Err(Error::InvalidInput(format!(
                            "Column '{}' not found",
                            col_name
                        )));
                    }
                    // The lookup is kept so that unreadable columns still
                    // surface the same error. Its value is unused until real
                    // statistics are computed.
                    let _col: &crate::series::Series<String> = df.get_column(col_name)?;
                    // TODO: placeholder statistics; replace with the column's
                    // real mean and standard deviation.
                    means.insert(col_name.clone(), 0.0);
                    stds.insert(col_name.clone(), 1.0);
                }
                *means_out = Some(means);
                *stds_out = Some(stds);
                Ok(())
            }
            PipelineStage::MinMaxScaler { .. }
            | PipelineStage::OneHotEncoder { .. }
            | PipelineStage::Imputer { .. }
            | PipelineStage::FeatureSelector { .. } => Ok(()),
        }
    }
}
/// An ordered sequence of [`PipelineStage`]s applied one after another.
#[derive(Debug)]
pub struct Pipeline {
    /// Stages in execution order; `fit` and `transform` walk this vector front to back.
    pub stages: Vec<PipelineStage>,
}
impl Pipeline {
    /// Creates an empty pipeline with no stages.
    pub fn new() -> Self {
        Pipeline { stages: Vec::new() }
    }

    /// Appends `stage` to the end of the pipeline.
    ///
    /// Returns `&mut self` so that calls can be chained.
    pub fn add_stage(&mut self, stage: PipelineStage) -> &mut Self {
        self.stages.push(stage);
        self
    }

    /// Fits every stage in order, giving each stage the output of the
    /// previous stage's `transform`.
    ///
    /// # Errors
    ///
    /// Returns the first error from any stage's `fit` or `transform`. Stages
    /// after the failing one are not fitted.
    pub fn fit(&mut self, df: &DataFrame) -> Result<()> {
        // Fitting a stage requires the previous stage's transformed output,
        // so this is exactly `fit_transform` with the final frame discarded.
        self.fit_transform(df).map(|_| ())
    }

    /// Runs `df` through every stage's `transform`, in order.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by any stage.
    pub fn transform(&self, df: &DataFrame) -> Result<DataFrame> {
        self.stages
            .iter()
            .try_fold(df.clone(), |current, stage| stage.transform(&current))
    }

    /// Fits each stage on the previous stage's output, transforms with it,
    /// and returns the final frame.
    ///
    /// # Errors
    ///
    /// Returns the first error from any stage's `fit` or `transform`. Stages
    /// after the failing one are not fitted.
    pub fn fit_transform(&mut self, df: &DataFrame) -> Result<DataFrame> {
        let mut current_df = df.clone();
        for stage in &mut self.stages {
            stage.fit(&current_df)?;
            current_df = stage.transform(&current_df)?;
        }
        Ok(current_df)
    }
}