use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;
/// A chain of preprocessing transformers followed by an optional model.
///
/// Transformers are kept in the order they were added; [`Pipeline::fit`] and
/// [`Pipeline::predict`] run them in that same order before handing the
/// resulting data to the model.
#[non_exhaustive]
pub struct Pipeline {
/// Preprocessing steps in insertion order, boxed so that different
/// transformer types can live in one list.
transformers: Vec<Box<dyn TransformerBox>>,
/// Final estimator; `None` until `set_model` has been called.
model: Option<Box<dyn PipelineModel>>,
}
/// Module-private trait-object interface used by `Pipeline` to store
/// transformers as `Box<dyn TransformerBox>`.
///
/// It mirrors the two `Transformer` calls the pipeline needs; the blanket
/// impl in this file provides it for every `Transformer`.
trait TransformerBox {
/// Fits the transformer on `data` without modifying the dataset.
fn fit(&mut self, data: &Dataset) -> Result<()>;
/// Applies the transformer to `data` in place.
fn transform(&self, data: &mut Dataset) -> Result<()>;
}
/// Blanket adapter: every [`Transformer`] is usable as a `TransformerBox`.
///
/// Both methods forward to the `Transformer` method of the same name. The
/// calls are fully qualified because both traits declare `fit` and
/// `transform`, and a plain `self.fit(..)` would be ambiguous here.
impl<T> TransformerBox for T
where
    T: Transformer,
{
    /// Forwards to [`Transformer::fit`].
    fn fit(&mut self, data: &Dataset) -> Result<()> {
        <T as Transformer>::fit(self, data)
    }

    /// Forwards to [`Transformer::transform`].
    fn transform(&self, data: &mut Dataset) -> Result<()> {
        <T as Transformer>::transform(self, data)
    }
}
/// Interface a model must expose to serve as the final step of a [`Pipeline`].
///
/// `Pipeline::set_model` accepts any `PipelineModel + 'static`. The
/// `impl_pipeline_model!` invocations in this file implement the trait for
/// the crate's own estimators.
pub trait PipelineModel {
/// Fits the model to `data`.
///
/// `Pipeline::fit` calls this with the dataset obtained after every
/// transformer has been fitted and applied.
fn fit(&mut self, data: &Dataset) -> Result<()>;
/// Produces predictions for `features`.
///
/// `Pipeline::predict` passes the output of `Dataset::feature_matrix()`
/// taken from the transformed dataset.
/// NOTE(review): presumably one inner `Vec` per sample and one returned
/// value per sample — confirm against `Dataset::feature_matrix`.
fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;
}
// Implements `PipelineModel` for each listed type by forwarding to the
// type's own `fit` / `predict` methods.
//
// Accepts a comma-separated list of types with an optional trailing comma.
//
// NOTE(review): the generated bodies use method-call syntax, so they rely on
// each listed type having its own `fit(&mut self, &Dataset)` and
// `predict(&self, &[Vec<f64>])` that take precedence over the trait methods
// being defined here; a type without them would resolve back to this impl
// and recurse. Confirm when adding a new type.
macro_rules! impl_pipeline_model {
    ( $( $model:ty ),* $(,)? ) => {
        $(
            impl PipelineModel for $model {
                fn fit(&mut self, data: &Dataset) -> Result<()> {
                    self.fit(data)
                }

                fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                    self.predict(features)
                }
            }
        )*
    };
}
// Estimators that can be used as the final step of a `Pipeline`,
// grouped by the module that defines them.
impl_pipeline_model! {
    // crate::linear
    crate::linear::ElasticNet,
    crate::linear::LassoRegression,
    crate::linear::LinearRegression,
    crate::linear::LogisticRegression,
    // crate::naive_bayes
    crate::naive_bayes::BernoulliNB,
    crate::naive_bayes::GaussianNb,
    crate::naive_bayes::MultinomialNB,
    // crate::neighbors
    crate::neighbors::KnnClassifier,
    // crate::neural
    crate::neural::MLPClassifier,
    crate::neural::MLPRegressor,
    // crate::svm
    crate::svm::LinearSVC,
    crate::svm::LinearSVR,
    // crate::tree
    crate::tree::DecisionTreeClassifier,
    crate::tree::DecisionTreeRegressor,
    crate::tree::GradientBoostingClassifier,
    crate::tree::GradientBoostingRegressor,
    crate::tree::HistGradientBoostingClassifier,
    crate::tree::HistGradientBoostingRegressor,
    crate::tree::RandomForestClassifier,
    crate::tree::RandomForestRegressor,
}

// Kernel SVMs are only compiled in when the `experimental` feature is enabled.
#[cfg(feature = "experimental")]
impl_pipeline_model! {
    crate::svm::KernelSVC,
    crate::svm::KernelSVR,
}
impl Pipeline {
    /// Creates an empty pipeline with no transformers and no model.
    #[must_use]
    pub fn new() -> Self {
        Self {
            transformers: Vec::new(),
            model: None,
        }
    }

    /// Appends `t` as the next preprocessing step and returns the pipeline.
    ///
    /// Steps run in the order they were added. The method consumes `self`,
    /// so the returned value must be kept; dropping it discards the pipeline.
    #[must_use]
    pub fn add_transformer<T: Transformer + 'static>(mut self, t: T) -> Self {
        self.transformers.push(Box::new(t));
        self
    }

    /// Sets `m` as the final model, replacing any model set earlier, and
    /// returns the pipeline.
    ///
    /// The method consumes `self`, so the returned value must be kept.
    #[must_use]
    pub fn set_model<M: PipelineModel + 'static>(mut self, m: M) -> Self {
        self.model = Some(Box::new(m));
        self
    }

    /// Fits every transformer in order, then fits the model (if one is set).
    ///
    /// Each transformer is fitted on the output of the previous steps and
    /// immediately applied, so later steps and the model see transformed
    /// data. All work happens on a clone; `data` itself is never modified.
    /// When no model has been set, only the transformers are fitted and
    /// `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns the error from `Dataset::validate_finite` if the input is
    /// rejected, or the first error raised by a transformer's `fit` /
    /// `transform` or by the model's `fit`.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;

        let mut transformed = data.clone();
        for step in &mut self.transformers {
            step.fit(&transformed)?;
            step.transform(&mut transformed)?;
        }

        if let Some(model) = self.model.as_mut() {
            model.fit(&transformed)?;
        }
        Ok(())
    }

    /// Applies every transformer in order to a copy of `data` and returns
    /// the model's predictions for the result.
    ///
    /// # Errors
    ///
    /// Returns [`ScryLearnError::NotFitted`] when no model has been set.
    /// This is checked first, so no cloning or transforming is done for a
    /// pipeline that cannot predict. Otherwise returns the first error
    /// raised by a transformer's `transform` or by the model's `predict`.
    pub fn predict(&self, data: &Dataset) -> Result<Vec<f64>> {
        // Fail fast: without a model the transformed data would be thrown away.
        let model = self.model.as_ref().ok_or(ScryLearnError::NotFitted)?;

        let mut transformed = data.clone();
        for step in &self.transformers {
            step.transform(&mut transformed)?;
        }

        let features = transformed.feature_matrix();
        model.predict(&features)
    }
}
impl Default for Pipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::preprocess::{Pca, StandardScaler};
    use crate::tree::DecisionTreeClassifier;

    /// Number of samples in the fixture built by `two_cluster_dataset`.
    const N_SAMPLES: usize = 6;

    /// Builds a two-feature dataset whose samples fall into two well
    /// separated groups, labelled `0.0` and `1.0`. Both feature columns
    /// hold the same values.
    fn two_cluster_dataset() -> Dataset {
        let column = vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0];
        let features = vec![column.clone(), column];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        Dataset::new(features, target, vec!["x".into(), "y".into()], "class")
    }

    /// A scaler followed by a tree yields one prediction per sample.
    #[test]
    fn test_pipeline_fit_predict() {
        let data = two_cluster_dataset();
        let mut pipeline = Pipeline::new()
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&data).unwrap();
        let preds = pipeline.predict(&data).unwrap();

        assert_eq!(preds.len(), N_SAMPLES);
    }

    /// Two chained transformers (PCA, then scaling) in front of a tree
    /// still yield one prediction per sample.
    #[test]
    fn test_pipeline_pca_then_model() {
        let data = two_cluster_dataset();
        let mut pipeline = Pipeline::new()
            .add_transformer(Pca::with_n_components(2))
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&data).unwrap();
        let preds = pipeline.predict(&data).unwrap();

        assert_eq!(preds.len(), N_SAMPLES);
    }
}