use crate::dataframe::DataFrame;
use crate::error::{Error, Result};
use crate::ml::models::SupervisedModel;
use crate::series::Series;
use std::collections::HashMap;
/// A named collection of candidate hyperparameter values.
///
/// Every value is stored in its string form, keyed by parameter name.
#[derive(Debug, Clone)]
pub struct HyperparameterGrid {
    /// Candidate values (already stringified) for each parameter name.
    pub params: HashMap<String, Vec<String>>,
}

impl HyperparameterGrid {
    /// Creates a grid with no parameters registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `values` as the candidates for parameter `name`.
    ///
    /// Each value is converted with `ToString`. Values previously registered
    /// under the same name are replaced. Returns `self` so calls can be chained.
    pub fn add_param<T: ToString>(&mut self, name: &str, values: Vec<T>) -> &mut Self {
        let rendered: Vec<String> = values.iter().map(ToString::to_string).collect();
        self.params.insert(name.to_string(), rendered);
        self
    }

    /// Returns the Cartesian product of all registered candidate values.
    ///
    /// Parameter names are processed in sorted order and the values of later
    /// names vary fastest, so the output order is deterministic. A grid without
    /// parameters yields exactly one empty combination; a parameter registered
    /// with no values makes the whole product empty.
    pub fn parameter_combinations(&self) -> Vec<HashMap<String, String>> {
        let mut names: Vec<&String> = self.params.keys().collect();
        names.sort();

        // Start from the single empty assignment and widen it one parameter at
        // a time; with no parameters the seed is returned untouched.
        let seed: Vec<HashMap<String, String>> = vec![HashMap::new()];
        names.into_iter().fold(seed, |partial, name| {
            let choices = &self.params[name];
            let mut extended = Vec::with_capacity(partial.len() * choices.len());
            for base in &partial {
                extended.extend(choices.iter().map(|choice| {
                    let mut combo = base.clone();
                    combo.insert(name.clone(), choice.clone());
                    combo
                }));
            }
            extended
        })
    }
}

impl Default for HyperparameterGrid {
    /// Equivalent to [`HyperparameterGrid::new`]: an empty grid.
    fn default() -> Self {
        Self {
            params: HashMap::new(),
        }
    }
}
/// Cross-validated search over every combination of a [`HyperparameterGrid`].
///
/// NOTE(review): `fit` currently cross-validates `base_model` exactly once and
/// never applies a candidate parameter set to it. Every candidate therefore
/// receives the same score, and `best_params` is simply the first combination
/// produced by the grid.
pub struct GridSearchCV<T: SupervisedModel> {
    /// Model handed to cross-validation and cloned by `best_estimator`.
    pub base_model: T,
    /// Grid whose combinations form the candidate set.
    pub param_grid: HyperparameterGrid,
    /// Name of the metric looked up on each fold's metrics.
    pub scoring: String,
    /// Number of cross-validation folds; `fit` rejects values below 2.
    pub cv: usize,
    /// Requested parallelism. Stored only; no method in this impl reads it.
    pub n_jobs: Option<usize>,
    /// Parameter combination recorded by the last successful `fit`.
    pub best_params: Option<HashMap<String, String>>,
    /// Mean cross-validation score recorded by the last successful `fit`.
    pub best_score: Option<f64>,
    /// Frame with one `mean_test_score` row per candidate, set by `fit`.
    pub cv_results: Option<DataFrame>,
}

impl<T: SupervisedModel + Clone> GridSearchCV<T> {
    /// Creates an unfitted search; all result fields start as `None`.
    pub fn new(base_model: T, param_grid: HyperparameterGrid, scoring: &str, cv: usize) -> Self {
        GridSearchCV {
            base_model,
            param_grid,
            scoring: scoring.to_string(),
            cv,
            n_jobs: None,
            best_params: None,
            best_score: None,
            cv_results: None,
        }
    }

    /// Builder-style setter for `n_jobs`.
    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }

    /// Runs the search on `data`, using column `target` as the label.
    ///
    /// On success `best_params`, `best_score` and `cv_results` are all updated;
    /// on failure none of them is touched.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidValue` if `target` is not a column of `data`, or if no
    ///   cross-validation fold reports the metric named by `scoring`.
    /// * `Error::InvalidInput` if `cv < 2` or the grid yields no combinations.
    /// * Any error returned by cross-validation or by building the result frame.
    pub fn fit(&mut self, data: &DataFrame, target: &str) -> Result<()> {
        if !data.has_column(target) {
            return Err(Error::InvalidValue(format!(
                "Target column '{}' not found",
                target
            )));
        }
        if self.cv < 2 {
            return Err(Error::InvalidInput(
                "Number of CV folds must be at least 2".into(),
            ));
        }

        let param_combinations = self.param_grid.parameter_combinations();
        let first_combination = match param_combinations.first() {
            Some(combo) => combo.clone(),
            None => {
                return Err(Error::InvalidInput(
                    "No parameter combinations to search".into(),
                ))
            }
        };

        let fold_metrics = self.base_model.cross_validate(data, target, self.cv)?;
        let metric_name = self.scoring.as_str();
        let scores: Vec<f64> = fold_metrics
            .iter()
            .filter_map(|m| m.get_metric(metric_name).copied())
            .collect();
        // A metric that no fold reports used to be recorded as a score of 0.0,
        // which cannot be told apart from a genuine result. Surface it instead.
        if scores.is_empty() {
            return Err(Error::InvalidValue(format!(
                "Scoring metric '{}' was not reported by any cross-validation fold",
                metric_name
            )));
        }
        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;

        // One row per candidate; all rows share the single measured score.
        let mean_scores: Vec<f64> = vec![mean_score; param_combinations.len()];
        let mut result_df = DataFrame::new();
        result_df.add_column(
            "mean_test_score".to_string(),
            Series::new(mean_scores, Some("mean_test_score".to_string()))?,
        )?;

        // Publish results only after every fallible step has succeeded, so a
        // failed fit never leaves the searcher half-updated.
        self.best_params = Some(first_combination);
        self.best_score = Some(mean_score);
        self.cv_results = Some(result_df);
        Ok(())
    }

    /// Returns a clone of `base_model` once the search has been fitted.
    ///
    /// # Errors
    ///
    /// `Error::InvalidValue` if `fit` has not completed successfully yet.
    pub fn best_estimator(&self) -> Result<T> {
        if self.best_params.is_none() {
            return Err(Error::InvalidValue("Grid search not fitted".into()));
        }
        Ok(self.base_model.clone())
    }
}
pub struct RandomizedSearchCV<T: SupervisedModel> {
pub base_model: T,
pub param_grid: HyperparameterGrid,
pub n_iter: usize,
pub scoring: String,
pub cv: usize,
pub random_seed: Option<u64>,
pub n_jobs: Option<usize>,
pub best_params: Option<HashMap<String, String>>,
pub best_score: Option<f64>,
pub cv_results: Option<DataFrame>,
}
impl<T: SupervisedModel + Clone> RandomizedSearchCV<T> {
pub fn new(
base_model: T,
param_grid: HyperparameterGrid,
n_iter: usize,
scoring: &str,
cv: usize,
) -> Self {
RandomizedSearchCV {
base_model,
param_grid,
n_iter,
scoring: scoring.to_string(),
cv,
random_seed: None,
n_jobs: None,
best_params: None,
best_score: None,
cv_results: None,
}
}
pub fn with_random_seed(mut self, seed: u64) -> Self {
self.random_seed = Some(seed);
self
}
pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
self.n_jobs = Some(n_jobs);
self
}
pub fn fit(&mut self, data: &DataFrame, target: &str) -> Result<()> {
if !data.has_column(target) {
return Err(Error::InvalidValue(format!(
"Target column '{}' not found",
target
)));
}
let all_combinations = self.param_grid.parameter_combinations();
let n_to_try = self.n_iter.min(all_combinations.len());
let selected_combos: Vec<HashMap<String, String>> = if n_to_try >= all_combinations.len() {
all_combinations.clone()
} else {
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use scirs2_core::random::SliceRandom;
let mut rng: StdRng = match self.random_seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::seed_from_u64(scirs2_core::random::random::<u64>()),
};
let mut indices: Vec<usize> = (0..all_combinations.len()).collect();
indices.shuffle(&mut rng);
indices[..n_to_try]
.iter()
.map(|&i| all_combinations[i].clone())
.collect()
};
let effective_cv = self.cv.max(2);
let fold_metrics = self.base_model.cross_validate(data, target, effective_cv)?;
let metric_name = self.scoring.as_str();
let scores: Vec<f64> = fold_metrics
.iter()
.filter_map(|m| m.get_metric(metric_name).copied())
.collect();
let mean_score = if scores.is_empty() {
0.0
} else {
scores.iter().sum::<f64>() / scores.len() as f64
};
self.best_params = Some(selected_combos.first().cloned().unwrap_or_default());
self.best_score = Some(mean_score);
let mut result_df = DataFrame::new();
let mean_scores: Vec<f64> = vec![mean_score; selected_combos.len()];
result_df.add_column(
"mean_test_score".to_string(),
Series::new(mean_scores, Some("mean_test_score".to_string()))?,
)?;
self.cv_results = Some(result_df);
Ok(())
}
pub fn best_estimator(&self) -> Result<T> {
if self.best_params.is_none() {
return Err(Error::InvalidValue("Randomized search not fitted".into()));
}
Ok(self.base_model.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataframe::DataFrame;
    use crate::ml::models::linear::LinearRegression;
    use crate::series::Series;

    /// Builds an `n`-row frame with columns `x = 0, 1, ...` and `y = 2x + 1`.
    fn make_linear_df(n: usize) -> DataFrame {
        let xs: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let ys: Vec<f64> = xs.iter().map(|&v| 2.0 * v + 1.0).collect();

        let mut frame = DataFrame::new();
        // Columns are added in the order x, then y.
        for (name, values, add_msg) in [("x", xs, "add x"), ("y", ys, "add y")] {
            let series = Series::new(values, Some(name.to_string())).expect("Series::new");
            frame.add_column(name.to_string(), series).expect(add_msg);
        }
        frame
    }

    /// Two parameters with two values each must expand to four full combinations.
    #[test]
    fn test_cartesian_product() {
        let mut grid = HyperparameterGrid::new();
        grid.add_param("a", vec!["1", "2"])
            .add_param("b", vec!["x", "y"]);

        let combos = grid.parameter_combinations();
        assert_eq!(
            combos.len(),
            4,
            "2x2 Cartesian product must yield exactly 4 combinations"
        );
        for combo in combos.iter() {
            assert!(combo.contains_key("a"), "combo missing key 'a'");
            assert!(combo.contains_key("b"), "combo missing key 'b'");
        }
    }

    /// A grid without parameters expands to a single empty combination.
    #[test]
    fn test_cartesian_empty() {
        let combos = HyperparameterGrid::new().parameter_combinations();
        assert_eq!(
            combos.len(),
            1,
            "empty grid must return exactly one (empty) combination"
        );
        assert!(combos[0].is_empty(), "the single combination must be empty");
    }

    /// Fitting on exactly linear data must record a positive score and results.
    #[test]
    fn test_gridsearch_cv_real() {
        let frame = make_linear_df(10);
        let estimator = LinearRegression::new();
        let empty_grid = HyperparameterGrid::new();
        let mut search = GridSearchCV::new(estimator, empty_grid, "r2", 2);

        search
            .fit(&frame, "y")
            .expect("GridSearchCV::fit should succeed");

        let best_score = search
            .best_score
            .expect("best_score must be set after fit");
        assert!(
            best_score > 0.0,
            "best_score must be a real positive CV score, got {}",
            best_score
        );
        assert!(
            search.cv_results.is_some(),
            "cv_results must be set after fit"
        );
    }
}