use ndarray::{Array1, Array2};
use num_traits::Float;
use crate::error::FerroError;
use crate::traits::{Fit, Predict};
/// An unfitted pipeline step that learns from training data.
///
/// Fitting produces a boxed [`FittedPipelineTransformer`] that can then map
/// feature matrices. The `Send + Sync` supertraits let a pipeline holding
/// boxed transformers be `Send + Sync` itself.
pub trait PipelineTransformer<F>: Send + Sync
where
    F: Float + Send + Sync + 'static,
{
    /// Fits this transformer on features `x` and targets `y`.
    ///
    /// # Errors
    ///
    /// Failures are reported as a [`FerroError`]; the specific conditions
    /// are up to each implementation.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError>;
}
/// A fitted transformer that maps one feature matrix to another.
pub trait FittedPipelineTransformer<F>: Send + Sync
where
    F: Float + Send + Sync + 'static,
{
    /// Transforms `x` and returns the resulting matrix.
    ///
    /// # Errors
    ///
    /// Failures are reported as a [`FerroError`]; the specific conditions
    /// are up to each implementation.
    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError>;
}
/// An unfitted final pipeline step (the model).
///
/// Fitting produces a boxed [`FittedPipelineEstimator`] that makes predictions.
pub trait PipelineEstimator<F>: Send + Sync
where
    F: Float + Send + Sync + 'static,
{
    /// Fits this estimator on features `x` and targets `y`.
    ///
    /// # Errors
    ///
    /// Failures are reported as a [`FerroError`]; the specific conditions
    /// are up to each implementation.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError>;
}
/// A fitted estimator that produces one prediction vector per input matrix.
pub trait FittedPipelineEstimator<F>: Send + Sync
where
    F: Float + Send + Sync + 'static,
{
    /// Predicts targets for the feature matrix `x`.
    ///
    /// # Errors
    ///
    /// Failures are reported as a [`FerroError`]; the specific conditions
    /// are up to each implementation.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError>;
}
/// A named, not-yet-fitted transform step stored inside a [`Pipeline`].
struct TransformStep<F>
where
    F: Float + Send + Sync + 'static,
{
    /// Step name as given to the builder.
    name: String,
    /// The transformer to fit when the pipeline is fitted.
    step: Box<dyn PipelineTransformer<F>>,
}
/// An ordered chain of named transform steps followed by one final estimator.
///
/// The element type `F` defaults to `f64`.
pub struct Pipeline<F = f64>
where
    F: Float + Send + Sync + 'static,
{
    /// Transform steps, kept in insertion order.
    transforms: Vec<TransformStep<F>>,
    /// Name and estimator of the final step; `None` until one is configured.
    estimator: Option<(String, Box<dyn PipelineEstimator<F>>)>,
}
impl<F: Float + Send + Sync + 'static> Pipeline<F> {
pub fn new() -> Self {
Self {
transforms: Vec::new(),
estimator: None,
}
}
#[must_use]
pub fn transform_step(mut self, name: &str, step: Box<dyn PipelineTransformer<F>>) -> Self {
self.transforms.push(TransformStep {
name: name.to_owned(),
step,
});
self
}
#[must_use]
pub fn estimator_step(mut self, name: &str, estimator: Box<dyn PipelineEstimator<F>>) -> Self {
self.estimator = Some((name.to_owned(), estimator));
self
}
#[must_use]
pub fn step(self, name: &str, step: Box<dyn PipelineStep<F>>) -> Self {
step.add_to_pipeline(self, name)
}
}
impl<F: Float + Send + Sync + 'static> Default for Pipeline<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for Pipeline<F> {
    type Fitted = FittedPipeline<F>;
    type Error = FerroError;

    /// Fits every transform step in order, then the final estimator.
    ///
    /// Each transformer is fitted on the output of the previous step, and its
    /// transformed output becomes the next step's input. The estimator is
    /// fitted on the last transformer's output, or on `x` directly when there
    /// are no transform steps. `y` is passed unchanged to every step.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] if no estimator step is
    /// configured. Otherwise propagates the first error returned by any step's
    /// `fit_pipeline` or `transform_pipeline`.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedPipeline<F>, FerroError> {
        // Validate before fitting anything, so no transformer work is wasted
        // on a pipeline that can never be completed.
        let (est_name, est) = match &self.estimator {
            Some((name, est)) => (name, est),
            None => {
                return Err(FerroError::InvalidParameter {
                    name: "estimator".into(),
                    reason: "pipeline must have a final estimator step".into(),
                });
            }
        };

        // Holds the latest intermediate matrix. While it is `None`, the
        // caller's `x` is borrowed directly instead of being copied up front.
        let mut transformed: Option<Array2<F>> = None;
        let mut fitted_transforms = Vec::with_capacity(self.transforms.len());
        for ts in &self.transforms {
            let input = transformed.as_ref().unwrap_or(x);
            let fitted = ts.step.fit_pipeline(input, y)?;
            let output = fitted.transform_pipeline(input)?;
            transformed = Some(output);
            fitted_transforms.push(FittedTransformStep {
                name: ts.name.clone(),
                step: fitted,
            });
        }

        let fitted_est = est.fit_pipeline(transformed.as_ref().unwrap_or(x), y)?;
        Ok(FittedPipeline {
            transforms: fitted_transforms,
            estimator: (est_name.clone(), fitted_est),
        })
    }
}
/// A named, fitted transform step stored inside a [`FittedPipeline`].
struct FittedTransformStep<F>
where
    F: Float + Send + Sync + 'static,
{
    /// Name carried over from the unfitted step.
    name: String,
    /// The fitted transformer applied at prediction time.
    step: Box<dyn FittedPipelineTransformer<F>>,
}
/// The result of fitting a [`Pipeline`]: fitted transforms plus a fitted estimator.
///
/// The element type `F` defaults to `f64`.
pub struct FittedPipeline<F = f64>
where
    F: Float + Send + Sync + 'static,
{
    /// Fitted transform steps, in the same order as in the source pipeline.
    transforms: Vec<FittedTransformStep<F>>,
    /// Name and fitted model of the final estimator step.
    estimator: (String, Box<dyn FittedPipelineEstimator<F>>),
}
impl<F> FittedPipeline<F>
where
    F: Float + Send + Sync + 'static,
{
    /// Returns the step names in order: every transform step first, then the
    /// final estimator.
    pub fn step_names(&self) -> Vec<&str> {
        let estimator_name = self.estimator.0.as_str();
        self.transforms
            .iter()
            .map(|fitted_step| fitted_step.name.as_str())
            .chain(std::iter::once(estimator_name))
            .collect()
    }
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedPipeline<F> {
    type Output = Array1<F>;
    type Error = FerroError;

    /// Runs `x` through every fitted transform in order and returns the final
    /// estimator's predictions on the result.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a transform step or by the
    /// estimator.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        // While no transform has run, borrow the caller's `x` instead of
        // cloning the whole matrix up front.
        let mut transformed: Option<Array2<F>> = None;
        for ts in &self.transforms {
            let output = ts.step.transform_pipeline(transformed.as_ref().unwrap_or(x))?;
            transformed = Some(output);
        }
        self.estimator
            .1
            .predict_pipeline(transformed.as_ref().unwrap_or(x))
    }
}
/// A type-erased step that knows how to register itself on a [`Pipeline`].
pub trait PipelineStep<F>: Send + Sync
where
    F: Float + Send + Sync + 'static,
{
    /// Consumes the boxed step, adds it to `pipeline` under `name`, and
    /// returns the updated pipeline.
    fn add_to_pipeline(self: Box<Self>, pipeline: Pipeline<F>, name: &str) -> Pipeline<F>;
}
/// Adapter that lets a boxed [`PipelineTransformer`] be passed to [`Pipeline::step`].
pub struct TransformerStepWrapper<F>(Box<dyn PipelineTransformer<F>>)
where
    F: Float + Send + Sync + 'static;
impl<F: Float + Send + Sync + 'static> PipelineStep<F> for TransformerStepWrapper<F> {
fn add_to_pipeline(self: Box<Self>, pipeline: Pipeline<F>, name: &str) -> Pipeline<F> {
pipeline.transform_step(name, self.0)
}
}
/// Adapter that lets a boxed [`PipelineEstimator`] be passed to [`Pipeline::step`].
pub struct EstimatorStepWrapper<F>(Box<dyn PipelineEstimator<F>>)
where
    F: Float + Send + Sync + 'static;
impl<F: Float + Send + Sync + 'static> PipelineStep<F> for EstimatorStepWrapper<F> {
fn add_to_pipeline(self: Box<Self>, pipeline: Pipeline<F>, name: &str) -> Pipeline<F> {
pipeline.estimator_step(name, self.0)
}
}
/// Boxes a transformer as a [`PipelineStep`] for use with [`Pipeline::step`].
pub fn as_transform_step<F: Float + Send + Sync + 'static>(
    t: impl PipelineTransformer<F> + 'static,
) -> Box<dyn PipelineStep<F>> {
    let transformer: Box<dyn PipelineTransformer<F>> = Box::new(t);
    Box::new(TransformerStepWrapper(transformer))
}
/// Boxes an estimator as a [`PipelineStep`] for use with [`Pipeline::step`].
pub fn as_estimator_step<F: Float + Send + Sync + 'static>(
    e: impl PipelineEstimator<F> + 'static,
) -> Box<dyn PipelineStep<F>> {
    let estimator: Box<dyn PipelineEstimator<F>> = Box::new(e);
    Box::new(EstimatorStepWrapper(estimator))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close<T: Float>(actual: T, expected: T, tol: T) {
        assert!((actual - expected).abs() < tol);
    }

    /// The 2x3 matrix `[[1, 2, 3], [4, 5, 6]]` shared by several f64 tests.
    fn sample_x() -> Array2<f64> {
        Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap()
    }

    /// A two-sample target vector; the test doubles never read its values.
    fn sample_y() -> Array1<f64> {
        Array1::from_vec(vec![0.0, 1.0])
    }

    // ----- f64 test doubles -----

    /// Transformer whose fitted form doubles every entry.
    struct DoublingTransformer;
    /// Fitted counterpart of `DoublingTransformer`.
    struct FittedDoublingTransformer;

    impl PipelineTransformer<f64> for DoublingTransformer {
        fn fit_pipeline(
            &self,
            _x: &Array2<f64>,
            _y: &Array1<f64>,
        ) -> Result<Box<dyn FittedPipelineTransformer<f64>>, FerroError> {
            Ok(Box::new(FittedDoublingTransformer))
        }
    }

    impl FittedPipelineTransformer<f64> for FittedDoublingTransformer {
        fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
            Ok(x.mapv(|value| value * 2.0))
        }
    }

    /// Estimator whose fitted form predicts the sum of each row.
    struct SumEstimator;
    /// Fitted counterpart of `SumEstimator`.
    struct FittedSumEstimator;

    impl PipelineEstimator<f64> for SumEstimator {
        fn fit_pipeline(
            &self,
            _x: &Array2<f64>,
            _y: &Array1<f64>,
        ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
            Ok(Box::new(FittedSumEstimator))
        }
    }

    impl FittedPipelineEstimator<f64> for FittedSumEstimator {
        fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
            Ok(x.rows().into_iter().map(|row| row.sum()).collect())
        }
    }

    // ----- f32 test doubles -----

    /// f32 variant of `DoublingTransformer`.
    struct DoublingTransformerF32;
    /// Fitted counterpart of `DoublingTransformerF32`.
    struct FittedDoublingTransformerF32;

    impl PipelineTransformer<f32> for DoublingTransformerF32 {
        fn fit_pipeline(
            &self,
            _x: &Array2<f32>,
            _y: &Array1<f32>,
        ) -> Result<Box<dyn FittedPipelineTransformer<f32>>, FerroError> {
            Ok(Box::new(FittedDoublingTransformerF32))
        }
    }

    impl FittedPipelineTransformer<f32> for FittedDoublingTransformerF32 {
        fn transform_pipeline(&self, x: &Array2<f32>) -> Result<Array2<f32>, FerroError> {
            Ok(x.mapv(|value| value * 2.0))
        }
    }

    /// f32 variant of `SumEstimator`.
    struct SumEstimatorF32;
    /// Fitted counterpart of `SumEstimatorF32`.
    struct FittedSumEstimatorF32;

    impl PipelineEstimator<f32> for SumEstimatorF32 {
        fn fit_pipeline(
            &self,
            _x: &Array2<f32>,
            _y: &Array1<f32>,
        ) -> Result<Box<dyn FittedPipelineEstimator<f32>>, FerroError> {
            Ok(Box::new(FittedSumEstimatorF32))
        }
    }

    impl FittedPipelineEstimator<f32> for FittedSumEstimatorF32 {
        fn predict_pipeline(&self, x: &Array2<f32>) -> Result<Array1<f32>, FerroError> {
            Ok(x.rows().into_iter().map(|row| row.sum()).collect())
        }
    }

    // ----- tests -----

    #[test]
    fn test_pipeline_fit_predict() {
        let (x, y) = (sample_x(), sample_y());
        let pipeline = Pipeline::new()
            .transform_step("doubler", Box::new(DoublingTransformer))
            .estimator_step("sum", Box::new(SumEstimator));
        let preds = pipeline.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds.len(), 2);
        assert_close(preds[0], 12.0, 1e-10);
        assert_close(preds[1], 30.0, 1e-10);
    }

    #[test]
    fn test_pipeline_f32_fit_predict() {
        let x = Array2::from_shape_vec((2, 3), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0f32, 1.0]);
        let pipeline = Pipeline::<f32>::new()
            .transform_step("doubler", Box::new(DoublingTransformerF32))
            .estimator_step("sum", Box::new(SumEstimatorF32));
        let preds = pipeline.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds.len(), 2);
        assert_close(preds[0], 12.0, 1e-5);
        assert_close(preds[1], 30.0, 1e-5);
    }

    #[test]
    fn test_pipeline_step_builder() {
        let (x, y) = (sample_x(), sample_y());
        let pipeline = Pipeline::new()
            .step("doubler", as_transform_step(DoublingTransformer))
            .step("sum", as_estimator_step(SumEstimator));
        let preds = pipeline.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_close(preds[0], 12.0, 1e-10);
        assert_close(preds[1], 30.0, 1e-10);
    }

    #[test]
    fn test_pipeline_no_estimator_returns_error() {
        let pipeline = Pipeline::new().transform_step("doubler", Box::new(DoublingTransformer));
        let x = Array2::<f64>::zeros((2, 3));
        let y = sample_y();
        assert!(pipeline.fit(&x, &y).is_err());
    }

    #[test]
    fn test_pipeline_estimator_only() {
        let (x, y) = (sample_x(), sample_y());
        let pipeline = Pipeline::new().estimator_step("sum", Box::new(SumEstimator));
        let preds = pipeline.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_close(preds[0], 6.0, 1e-10);
        assert_close(preds[1], 15.0, 1e-10);
    }

    #[test]
    fn test_fitted_pipeline_step_names() {
        let x = Array2::<f64>::zeros((2, 3));
        let y = sample_y();
        let fitted = Pipeline::new()
            .transform_step("scaler", Box::new(DoublingTransformer))
            .transform_step("normalizer", Box::new(DoublingTransformer))
            .estimator_step("clf", Box::new(SumEstimator))
            .fit(&x, &y)
            .unwrap();
        assert_eq!(fitted.step_names(), vec!["scaler", "normalizer", "clf"]);
    }

    #[test]
    fn test_multiple_transform_steps() {
        let x = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
        let y = Array1::from_vec(vec![0.0]);
        let pipeline = Pipeline::new()
            .transform_step("double1", Box::new(DoublingTransformer))
            .transform_step("double2", Box::new(DoublingTransformer))
            .estimator_step("sum", Box::new(SumEstimator));
        let preds = pipeline.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_close(preds[0], 8.0, 1e-10);
    }

    #[test]
    fn test_pipeline_default() {
        let x = Array2::<f64>::zeros((2, 3));
        let y = sample_y();
        let result = Pipeline::<f64>::default().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_pipeline_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Pipeline<f64>>();
        assert_send_sync::<Pipeline<f32>>();
        assert_send_sync::<FittedPipeline<f64>>();
        assert_send_sync::<FittedPipeline<f32>>();
    }
}