use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::builder::{Pipeline, PipelineError};
use super::evaluator::r2_score;
use super::predictor::{LinearRegressor, KnnRegressor};
use super::scalers::StandardScaler;
use super::traits::FeatureTransformer;
/// A regression model that is first fitted on a feature matrix plus target vector and is
/// then asked to predict targets for new feature rows.
///
/// The `Send + Sync` supertraits make `Box<dyn RegressionPredictor>` itself `Send + Sync`.
pub trait RegressionPredictor: Send + Sync {
    /// Trains the model on features `x` (one row per sample) and targets `y`.
    fn fit_reg(&mut self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<(), PipelineError>;

    /// Predicts targets for the rows of `x` using the previously fitted model.
    fn predict_reg(&self, x: ArrayView2<f64>) -> Result<Array1<f64>, PipelineError>;

    /// Human-readable model name; used as the last segment of [`RegressionPipeline::name`].
    fn model_name(&self) -> &str;
}
impl RegressionPredictor for LinearRegressor {
fn fit_reg(
&mut self,
x: ArrayView2<f64>,
y: ArrayView1<f64>,
) -> Result<(), PipelineError> {
self.fit(x, y)
}
fn predict_reg(&self, x: ArrayView2<f64>) -> Result<Array1<f64>, PipelineError> {
self.predict(x)
}
fn model_name(&self) -> &str {
"LinearRegressor"
}
}
impl RegressionPredictor for KnnRegressor {
fn fit_reg(
&mut self,
x: ArrayView2<f64>,
y: ArrayView1<f64>,
) -> Result<(), PipelineError> {
self.fit(x, y)
}
fn predict_reg(&self, x: ArrayView2<f64>) -> Result<Array1<f64>, PipelineError> {
self.predict(x)
}
fn model_name(&self) -> &str {
"KnnRegressor"
}
}
/// A regression model preceded by a chain of feature transformers.
///
/// `fit` fits the preprocessing chain on the training features and trains the predictor on
/// the transformed output; `predict` reuses the fitted chain through `transform`.
pub struct RegressionPipeline {
    /// Transformers applied to the features before they reach the predictor.
    preprocessing: Pipeline<f64>,
    /// The regression model, boxed so any [`RegressionPredictor`] can be plugged in.
    predictor: Box<dyn RegressionPredictor>,
}

impl RegressionPipeline {
    /// Creates a pipeline from a preprocessing chain and a predictor.
    pub fn new(
        preprocessing: Pipeline<f64>,
        predictor: impl RegressionPredictor + 'static,
    ) -> Self {
        let predictor: Box<dyn RegressionPredictor> = Box::new(predictor);
        Self { preprocessing, predictor }
    }

    /// Fails with `PipelineError::EmptyInput(context)` when `x` has no rows.
    fn reject_empty(x: &ArrayView2<f64>, context: &str) -> Result<(), PipelineError> {
        match x.nrows() {
            0 => Err(PipelineError::EmptyInput(context.to_string())),
            _ => Ok(()),
        }
    }

    /// Fits the preprocessing chain on `x`, then trains the predictor on the transformed
    /// features and `y`.
    ///
    /// # Errors
    ///
    /// Returns [`PipelineError::EmptyInput`] when `x` has no rows; otherwise propagates any
    /// error from the preprocessing chain or the predictor.
    pub fn fit(&mut self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<(), PipelineError> {
        Self::reject_empty(&x, "RegressionPipeline.fit")?;
        let features = x.to_owned();
        let transformed = self.preprocessing.fit_transform(&features)?;
        self.predictor.fit_reg(transformed.view(), y)
    }

    /// Transforms `x` with the fitted preprocessing chain and predicts one value per row.
    ///
    /// # Errors
    ///
    /// Returns [`PipelineError::EmptyInput`] when `x` has no rows; otherwise propagates any
    /// error from the preprocessing chain or the predictor (e.g. when called before `fit`).
    pub fn predict(&self, x: ArrayView2<f64>) -> Result<Array1<f64>, PipelineError> {
        Self::reject_empty(&x, "RegressionPipeline.predict")?;
        let features = x.to_owned();
        let transformed = self.preprocessing.transform(&features)?;
        self.predictor.predict_reg(transformed.view())
    }

    /// Convenience for `fit(x, y)` followed by `predict(x)` on the same features.
    pub fn fit_predict(
        &mut self,
        x: ArrayView2<f64>,
        y: ArrayView1<f64>,
    ) -> Result<Array1<f64>, PipelineError> {
        self.fit(x, y).and_then(|()| self.predict(x))
    }

    /// Describes the pipeline as `"step1 → step2 → Model"`, or just the model name when the
    /// joined preprocessing step names are empty.
    pub fn name(&self) -> String {
        let model = self.predictor.model_name();
        let chain = self.preprocessing.step_names().join(" → ");
        match chain.as_str() {
            "" => model.to_string(),
            prefix => format!("{prefix} → {model}"),
        }
    }
}
pub fn cross_validate_regression(
pipeline_factory: impl Fn() -> RegressionPipeline,
x: ArrayView2<f64>,
y: ArrayView1<f64>,
n_folds: usize,
) -> Result<Vec<f64>, PipelineError> {
let n = x.nrows();
if n_folds < 2 {
return Err(PipelineError::StepError {
step: "cross_validate_regression".to_string(),
message: format!("n_folds must be >= 2, got {n_folds}"),
});
}
if n_folds > n {
return Err(PipelineError::StepError {
step: "cross_validate_regression".to_string(),
message: format!("n_folds ({n_folds}) must be <= n_samples ({n})"),
});
}
let fold_size = n / n_folds;
let mut scores = Vec::with_capacity(n_folds);
for fold in 0..n_folds {
let test_start = fold * fold_size;
let test_end = if fold == n_folds - 1 {
n } else {
(fold + 1) * fold_size
};
let train_indices: Vec<usize> = (0..n)
.filter(|&i| i < test_start || i >= test_end)
.collect();
let test_indices: Vec<usize> = (test_start..test_end).collect();
if train_indices.is_empty() || test_indices.is_empty() {
continue;
}
let n_train = train_indices.len();
let n_test = test_indices.len();
let ncols = x.ncols();
let mut x_train_data = Vec::with_capacity(n_train * ncols);
let mut y_train_data = Vec::with_capacity(n_train);
for &i in &train_indices {
for j in 0..ncols {
x_train_data.push(x[[i, j]]);
}
y_train_data.push(y[i]);
}
let x_train = Array2::from_shape_vec((n_train, ncols), x_train_data).map_err(|e| {
PipelineError::StepError {
step: "cross_validate_regression".to_string(),
message: format!("train array shape error: {e}"),
}
})?;
let y_train = Array1::from_vec(y_train_data);
let mut x_test_data = Vec::with_capacity(n_test * ncols);
let mut y_test_data = Vec::with_capacity(n_test);
for &i in &test_indices {
for j in 0..ncols {
x_test_data.push(x[[i, j]]);
}
y_test_data.push(y[i]);
}
let x_test = Array2::from_shape_vec((n_test, ncols), x_test_data).map_err(|e| {
PipelineError::StepError {
step: "cross_validate_regression".to_string(),
message: format!("test array shape error: {e}"),
}
})?;
let y_test = Array1::from_vec(y_test_data);
let mut pipeline = pipeline_factory();
pipeline.fit(x_train.view(), y_train.view())?;
let preds = pipeline.predict(x_test.view())?;
let score = r2_score(&y_test, &preds);
scores.push(score);
}
Ok(scores)
}
pub fn linear_regression_pipeline() -> RegressionPipeline {
RegressionPipeline::new(
Pipeline::<f64>::new().add_transformer(Box::new(StandardScaler::new())),
LinearRegressor::new(),
)
}
pub fn knn_regression_pipeline(k: usize) -> RegressionPipeline {
RegressionPipeline::new(
Pipeline::<f64>::new().add_transformer(Box::new(StandardScaler::new())),
KnnRegressor::new(k),
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    /// Builds `n` noiseless samples of `y = 3x + 2` with `x = 1..=n` as a single feature column.
    fn make_linear_data(n: usize) -> (Array2<f64>, Array1<f64>) {
        let x: Array2<f64> = Array2::from_shape_vec(
            (n, 1),
            (1..=n).map(|i| i as f64).collect::<Vec<_>>(),
        )
        .expect("shape ok");
        let y: Array1<f64> = Array1::from_vec(
            (1..=n).map(|i| 3.0 * i as f64 + 2.0).collect::<Vec<_>>(),
        );
        (x, y)
    }

    #[test]
    fn test_regression_pipeline_fit_predict() {
        let (x, y) = make_linear_data(6);
        let mut pipeline = linear_regression_pipeline();
        pipeline.fit(x.view(), y.view()).expect("fit ok");
        let preds = pipeline.predict(x.view()).expect("predict ok");
        assert_eq!(preds.len(), 6);
        for (i, (&yt, &yp)) in y.iter().zip(preds.iter()).enumerate() {
            assert!((yt - yp).abs() < 1.0, "row {i}: true={yt}, pred={yp}");
        }
    }

    #[test]
    fn test_regression_pipeline_knn() {
        let (x, y) = make_linear_data(8);
        let mut pipeline = knn_regression_pipeline(2);
        pipeline.fit(x.view(), y.view()).expect("fit ok");
        let preds = pipeline.predict(x.view()).expect("predict ok");
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_regression_pipeline_predict_before_fit_fails() {
        let pipeline = RegressionPipeline::new(Pipeline::new(), LinearRegressor::new());
        let x = Array2::from_shape_vec((2, 1), vec![1.0f64, 2.0]).expect("ok");
        let result = pipeline.predict(x.view());
        assert!(result.is_err(), "predict before fit should fail");
    }

    #[test]
    fn test_regression_pipeline_fit_predict_method() {
        let (x, y) = make_linear_data(4);
        let mut pipeline = linear_regression_pipeline();
        let preds = pipeline.fit_predict(x.view(), y.view()).expect("fit_predict ok");
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_regression_pipeline_empty_input() {
        let mut pipeline = linear_regression_pipeline();
        let x: Array2<f64> = Array2::zeros((0, 1));
        let y: Array1<f64> = Array1::zeros(0);
        let result = pipeline.fit(x.view(), y.view());
        assert!(matches!(result, Err(PipelineError::EmptyInput(_))));
    }

    #[test]
    fn test_regression_pipeline_predict_empty_input() {
        // The empty-input guard runs before the preprocessing chain, so no fit is needed.
        let pipeline = linear_regression_pipeline();
        let x: Array2<f64> = Array2::zeros((0, 1));
        let result = pipeline.predict(x.view());
        assert!(matches!(result, Err(PipelineError::EmptyInput(_))));
    }

    #[test]
    fn test_regression_pipeline_name() {
        let pipeline = linear_regression_pipeline();
        let name = pipeline.name();
        assert!(name.contains("StandardScaler"), "name: {name}");
        assert!(name.contains("LinearRegressor"), "name: {name}");
    }

    #[test]
    fn test_cross_validate_regression_linear_data() {
        let (x, y) = make_linear_data(20);
        let scores =
            cross_validate_regression(linear_regression_pipeline, x.view(), y.view(), 5)
                .expect("cv ok");
        assert_eq!(scores.len(), 5);
        for &s in &scores {
            assert!(s > 0.5, "fold R² should be high for linear data: {s}");
        }
    }

    #[test]
    fn test_cross_validate_regression_n_folds_too_small() {
        let (x, y) = make_linear_data(10);
        let result = cross_validate_regression(linear_regression_pipeline, x.view(), y.view(), 1);
        assert!(matches!(result, Err(PipelineError::StepError { .. })));
    }

    #[test]
    fn test_cross_validate_regression_n_folds_too_large() {
        let (x, y) = make_linear_data(3);
        let result = cross_validate_regression(linear_regression_pipeline, x.view(), y.view(), 5);
        assert!(matches!(result, Err(PipelineError::StepError { .. })));
    }

    #[test]
    fn test_cross_validate_regression_2_folds() {
        let (x, y) = make_linear_data(10);
        let scores =
            cross_validate_regression(linear_regression_pipeline, x.view(), y.view(), 2)
                .expect("cv ok");
        assert_eq!(scores.len(), 2);
    }

    #[test]
    fn test_cross_validate_regression_uneven_folds() {
        // 10 samples do not divide evenly into 3 folds; every fold must still be scored.
        let (x, y) = make_linear_data(10);
        let scores =
            cross_validate_regression(linear_regression_pipeline, x.view(), y.view(), 3)
                .expect("cv ok");
        assert_eq!(scores.len(), 3);
    }

    #[test]
    fn test_cross_validate_regression_knn() {
        let (x, y) = make_linear_data(12);
        let factory = || knn_regression_pipeline(2);
        let scores = cross_validate_regression(factory, x.view(), y.view(), 3).expect("cv ok");
        assert_eq!(scores.len(), 3);
    }
}