cross_validate

Function cross_validate 

pub fn cross_validate<E>(
    estimator: &E,
    x: &Matrix<f32>,
    y: &Vector<f32>,
    cv: &KFold,
) -> Result<CrossValidationResult, String>
where E: Estimator + Clone,

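Runs cross-validation on an estimator.

For each fold produced by cv, a clone of the estimator is trained on that fold's training split and scored on its held-out split. The per-fold scores are returned in a CrossValidationResult; failures are reported as an Err(String).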

§Arguments

  • estimator - The model to cross-validate (must implement Estimator and Clone)
  • x - Feature matrix, one row per sample
  • y - Target vector, one value per row of x
  • cv - Cross-validation splitter (e.g., KFold)
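
For intuition about what the cv splitter contributes, here is a minimal standalone sketch of a k-fold split over row indices. It does not use aprender: the helper k_fold_indices is hypothetical, and the choices it makes (contiguous folds, no shuffling, extra rows going to the earliest folds) are assumptions that may differ from how KFold actually partitions the data.

```rust
/// Hypothetical helper (not part of aprender): split the indices
/// 0..n_samples into n_splits contiguous folds and return one
/// (train_indices, test_indices) pair per fold.
fn k_fold_indices(n_samples: usize, n_splits: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(n_splits >= 2 && n_splits <= n_samples);
    let base = n_samples / n_splits;
    let remainder = n_samples % n_splits;

    let mut folds = Vec::with_capacity(n_splits);
    let mut start = 0;
    for fold in 0..n_splits {
        // Assumption: the first `remainder` folds each take one extra row.
        let size = base + usize::from(fold < remainder);
        let end = start + size;

        // Rows in start..end are held out; every other row is used for training.
        let test: Vec<usize> = (start..end).collect();
        let train: Vec<usize> = (0..start).chain(end..n_samples).collect();

        folds.push((train, test));
        start = end;
    }
    folds
}

fn main() {
    for (i, (train, test)) in k_fold_indices(50, 5).iter().enumerate() {
        println!("fold {i}: {} train rows, {} test rows", train.len(), test.len());
    }
}
```

Each row appears in exactly one test split, so every sample is used for evaluation once and for training in the remaining folds.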

§Example

use aprender::prelude::*;
use aprender::model_selection::{cross_validate, KFold};

let x = Matrix::from_vec(50, 1, (0..50).map(|i| i as f32).collect()).unwrap();
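// Targets follow y = 2x + 1; a constant target would leave R² undefined.
let targets: Vec<f32> = (0..50).map(|i| 2.0 * i as f32 + 1.0).collect();
let y = Vector::from_slice(&targets);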

let model = LinearRegression::new();
let kfold = KFold::new(5);

let results = cross_validate(&model, &x, &y, &kfold).unwrap();
println!("Mean R²: {:.3} ± {:.3}", results.mean(), results.std());
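
The mean() and std() calls above summarize the per-fold scores. The following standalone sketch shows one plausible way such a summary could be computed. FoldScores is a hypothetical stand-in for CrossValidationResult, the sample scores are made up for illustration, and the use of the population standard deviation (dividing by n rather than n - 1) is an assumption, not a statement about aprender's implementation.

```rust
/// Hypothetical stand-in for the result type; the real
/// CrossValidationResult may store or compute these values differently.
struct FoldScores {
    scores: Vec<f32>,
}

impl FoldScores {
    /// Arithmetic mean of the per-fold scores (NaN if there are no scores).
    fn mean(&self) -> f32 {
        self.scores.iter().sum::<f32>() / self.scores.len() as f32
    }

    /// Population standard deviation (divides by n). This is an assumption;
    /// a sample standard deviation would divide by n - 1 instead.
    fn std(&self) -> f32 {
        let mean = self.mean();
        let variance = self
            .scores
            .iter()
            .map(|score| (score - mean).powi(2))
            .sum::<f32>()
            / self.scores.len() as f32;
        variance.sqrt()
    }
}

fn main() {
    // Illustrative scores only, not output from a real model.
    let results = FoldScores {
        scores: vec![0.90, 0.80, 0.85, 0.95, 0.75],
    };
    println!("Mean score: {:.3} ± {:.3}", results.mean(), results.std());
}
```

A small spread relative to the mean suggests the model performs consistently across folds; a large spread suggests the score depends heavily on which rows were held out.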