use crate::primitives::{Matrix, Vector};
use crate::traits::Estimator;
/// Per-fold scores produced by a cross-validation run, with summary statistics.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// One score per cross-validation fold, in fold order.
    pub scores: Vec<f32>,
}

impl CrossValidationResult {
    /// Arithmetic mean of the fold scores; `0.0` when there are no scores.
    #[must_use]
    pub fn mean(&self) -> f32 {
        if self.scores.is_empty() {
            return 0.0;
        }
        self.scores.iter().sum::<f32>() / self.scores.len() as f32
    }

    /// Population standard deviation of the fold scores (divides by `n`,
    /// not `n - 1`); `0.0` when there are no scores.
    #[must_use]
    pub fn std(&self) -> f32 {
        if self.scores.is_empty() {
            return 0.0;
        }
        let mean = self.mean();
        let variance = self
            .scores
            .iter()
            .map(|&score| (score - mean).powi(2))
            .sum::<f32>()
            / self.scores.len() as f32;
        variance.sqrt()
    }

    /// Smallest fold score; `f32::INFINITY` (fold identity) when empty.
    #[must_use]
    pub fn min(&self) -> f32 {
        self.scores.iter().copied().fold(f32::INFINITY, f32::min)
    }

    /// Largest fold score; `f32::NEG_INFINITY` (fold identity) when empty.
    #[must_use]
    pub fn max(&self) -> f32 {
        self.scores
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max)
    }
}
/// Runs k-fold cross-validation: for every split produced by `cv`, clones
/// `estimator`, fits the clone on the training rows and scores it on the
/// held-out test rows.
///
/// # Errors
///
/// Returns an error string as soon as fitting any fold model fails.
pub fn cross_validate<E>(
    estimator: &E,
    x: &Matrix<f32>,
    y: &Vector<f32>,
    cv: &KFold,
) -> Result<CrossValidationResult, String>
where
    E: Estimator + Clone,
{
    let n_samples = x.shape().0;
    // Collecting into Result<Vec<_>, _> short-circuits on the first
    // failed fit, matching an early-return loop.
    let scores = cv
        .split(n_samples)
        .into_iter()
        .map(|(train_idx, test_idx)| {
            let (x_train, y_train) = extract_samples(x, y, &train_idx);
            let (x_test, y_test) = extract_samples(x, y, &test_idx);
            let mut fold_model = estimator.clone();
            fold_model
                .fit(&x_train, &y_train)
                .map_err(|e| format!("Training failed: {e}"))?;
            Ok(fold_model.score(&x_test, &y_test))
        })
        .collect::<Result<Vec<f32>, String>>()?;
    Ok(CrossValidationResult { scores })
}
/// Builds the row subset of `x` and `y` selected by `indices`; rows appear
/// in the order the indices are listed (duplicates are repeated).
fn extract_samples(
    x: &Matrix<f32>,
    y: &Vector<f32>,
    indices: &[usize],
) -> (Matrix<f32>, Vector<f32>) {
    let n_features = x.shape().1;
    let y_values = y.as_slice();

    // Row-major flatten of the selected rows.
    let x_data: Vec<_> = indices
        .iter()
        .flat_map(|&row| (0..n_features).map(move |col| x.get(row, col)))
        .collect();
    let y_data: Vec<_> = indices.iter().map(|&row| y_values[row]).collect();

    let x_subset =
        Matrix::from_vec(indices.len(), n_features, x_data).expect("Failed to create matrix");
    (x_subset, Vector::from_vec(y_data))
}
/// K-fold cross-validation splitter: partitions sample indices into
/// `n_splits` contiguous folds, optionally shuffling the indices first.
#[derive(Debug, Clone)]
pub struct KFold {
    // Number of folds to produce.
    n_splits: usize,
    // Whether indices are shuffled before being partitioned.
    shuffle: bool,
    // Seed for deterministic shuffling; `None` uses a thread-local RNG.
    random_state: Option<u64>,
}

impl KFold {
    /// Creates a splitter with `n_splits` folds and no shuffling.
    #[must_use]
    pub fn new(n_splits: usize) -> Self {
        Self {
            n_splits,
            shuffle: false,
            random_state: None,
        }
    }

    /// Enables or disables shuffling of the indices before splitting.
    #[must_use]
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets a deterministic shuffle seed; implies `shuffle = true`.
    #[must_use]
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self.shuffle = true;
        self
    }

    /// Returns `(train_indices, test_indices)` pairs, one per fold.
    ///
    /// The first `n_samples % n_splits` folds receive one extra test sample,
    /// so every index appears in exactly one test fold and fold sizes differ
    /// by at most one.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero.
    #[must_use]
    pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        // Fail with a clear message instead of an opaque divide-by-zero below.
        assert!(self.n_splits > 0, "KFold requires n_splits >= 1");
        let mut indices: Vec<usize> = (0..n_samples).collect();
        if self.shuffle {
            if let Some(seed) = self.random_state {
                let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
                indices.shuffle(&mut rng);
            } else {
                let mut rng = rand::rng();
                indices.shuffle(&mut rng);
            }
        }
        // Spread the remainder over the leading folds (sizes differ by <= 1).
        let fold_size = n_samples / self.n_splits;
        let remainder = n_samples % self.n_splits;
        let mut result = Vec::with_capacity(self.n_splits);
        let mut start = 0;
        for i in 0..self.n_splits {
            let current_fold_size = if i < remainder {
                fold_size + 1
            } else {
                fold_size
            };
            let end = start + current_fold_size;
            let test_indices: Vec<usize> = indices[start..end].to_vec();
            // Train set is everything outside the [start, end) test window.
            let mut train_indices = Vec::with_capacity(n_samples - current_fold_size);
            train_indices.extend_from_slice(&indices[..start]);
            train_indices.extend_from_slice(&indices[end..]);
            result.push((train_indices, test_indices));
            start = end;
        }
        result
    }
}
/// Stratified k-fold splitter: assigns each class's samples to folds
/// separately, so every fold's class mix approximates the overall mix.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    // Number of folds to produce.
    n_splits: usize,
    // Whether per-class index lists are shuffled before assignment.
    shuffle: bool,
    // Seed for deterministic shuffling; `None` uses a thread-local RNG.
    random_state: Option<u64>,
}

impl StratifiedKFold {
    /// Creates a splitter with `n_splits` folds and no shuffling.
    #[must_use]
    pub fn new(n_splits: usize) -> Self {
        Self {
            n_splits,
            shuffle: false,
            random_state: None,
        }
    }

    /// Enables or disables shuffling within each class before splitting.
    #[must_use]
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets a deterministic shuffle seed; implies `shuffle = true`.
    #[must_use]
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self.shuffle = true;
        self
    }

    /// Returns `(train_indices, test_indices)` pairs, one per fold, with
    /// each class spread as evenly as possible across the folds.
    ///
    /// Classes are processed in sorted label order so that a fixed
    /// `random_state` yields reproducible folds (previously the assignment
    /// followed `HashMap` iteration order, which is unspecified and made
    /// seeded splits nondeterministic across runs).
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero.
    #[must_use]
    pub fn split(&self, y: &Vector<f32>) -> Vec<(Vec<usize>, Vec<usize>)> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        use std::collections::HashMap;
        // Fail with a clear message instead of an opaque divide-by-zero below.
        assert!(self.n_splits > 0, "StratifiedKFold requires n_splits >= 1");
        let n_samples = y.len();
        // NOTE(review): labels are bucketed via `as i32`, which truncates
        // fractional labels (0.4 and 0.9 both become class 0) and maps NaN
        // to 0 — confirm labels are integral class ids.
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
        for (i, &label) in y.as_slice().iter().enumerate() {
            class_indices.entry(label as i32).or_default().push(i);
        }
        if self.shuffle {
            for indices in class_indices.values_mut() {
                // NOTE(review): a fixed seed re-seeds the RNG per class, so
                // every class receives the same permutation pattern — verify
                // this determinism model is intended.
                if let Some(seed) = self.random_state {
                    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
                    indices.shuffle(&mut rng);
                } else {
                    let mut rng = rand::rng();
                    indices.shuffle(&mut rng);
                }
            }
        }
        // Iterate classes in sorted label order for reproducible fold
        // contents (HashMap iteration order is unspecified).
        let mut class_labels: Vec<i32> = class_indices.keys().copied().collect();
        class_labels.sort_unstable();
        // Deal each class's indices chunk-wise into the folds; the first
        // `remainder` folds take one extra sample of that class.
        let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];
        for label in &class_labels {
            let indices = &class_indices[label];
            let class_size = indices.len();
            let fold_size = class_size / self.n_splits;
            let remainder = class_size % self.n_splits;
            let mut start = 0;
            for (i, fold) in fold_indices.iter_mut().enumerate() {
                let current_size = if i < remainder {
                    fold_size + 1
                } else {
                    fold_size
                };
                let end = start + current_size;
                fold.extend_from_slice(&indices[start..end]);
                start = end;
            }
        }
        // Each fold in turn becomes the test set; all others form the
        // train set.
        let mut result = Vec::with_capacity(self.n_splits);
        for i in 0..self.n_splits {
            let test_indices = fold_indices[i].clone();
            let mut train_indices = Vec::with_capacity(n_samples - test_indices.len());
            for (j, fold) in fold_indices.iter().enumerate() {
                if i != j {
                    train_indices.extend_from_slice(fold);
                }
            }
            result.push((train_indices, test_indices));
        }
        result
    }
}
/// Result of a grid search over alpha values; populated by code in
/// `alpha.rs` (included below), which is not visible in this file view.
#[derive(Debug, Clone)]
pub struct GridSearchResult {
    /// Alpha value selected as best — presumably the one achieving
    /// `best_score`; confirm against the search loop in `alpha.rs`.
    pub best_alpha: f32,
    /// Score associated with `best_alpha`.
    pub best_score: f32,
    /// Candidate alpha values evaluated by the search.
    pub alphas: Vec<f32>,
    /// Score for each entry of `alphas` — presumably parallel vectors;
    /// verify in `alpha.rs`.
    pub scores: Vec<f32>,
}
include!("alpha.rs");
#[cfg(test)]
#[path = "tests_kfold_contract.rs"]
mod tests_kfold_contract;