use crate::dataframe::DataFrame;
use crate::error::{Error, Result};
use crate::ml::models::{ModelMetrics, SupervisedModel};
use std::collections::HashMap;
use std::marker::PhantomData;
/// A named set of candidate values for each hyperparameter.
///
/// Values are stored in their string form (via [`ToString`]) so that
/// heterogeneous parameter types can share a single map.
#[derive(Debug, Clone, Default)]
pub struct HyperparameterGrid {
    /// Candidate values keyed by parameter name.
    pub params: HashMap<String, Vec<String>>,
}

impl HyperparameterGrid {
    /// Creates a grid with no parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers the candidate `values` for the parameter `name`.
    ///
    /// Each value is converted with [`ToString::to_string`]. Calling this
    /// again with the same `name` replaces the previously stored values.
    /// Returns `&mut Self` so calls can be chained.
    pub fn add_param<T: ToString>(&mut self, name: &str, values: Vec<T>) -> &mut Self {
        let string_values: Vec<String> = values.into_iter().map(|v| v.to_string()).collect();
        self.params.insert(name.to_string(), string_values);
        self
    }

    /// Returns every combination of parameter values (the Cartesian product).
    ///
    /// * An empty grid yields exactly one empty combination.
    /// * If any parameter has an empty value list, no combination can be
    ///   formed and the result is empty.
    ///
    /// Parameters are expanded in ascending name order, so the output order is
    /// deterministic: the parameter whose name sorts last varies fastest.
    pub fn parameter_combinations(&self) -> Vec<HashMap<String, String>> {
        // `HashMap` iteration order is unspecified; sort by name so repeated
        // calls enumerate the combinations in the same order.
        let mut entries: Vec<(&String, &Vec<String>)> = self.params.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

        // Start from the single empty combination and extend it one parameter
        // at a time.
        let mut combinations = vec![HashMap::with_capacity(entries.len())];
        for (name, values) in entries {
            let mut expanded = Vec::with_capacity(combinations.len() * values.len());
            for partial in &combinations {
                for value in values {
                    let mut candidate = partial.clone();
                    candidate.insert(name.clone(), value.clone());
                    expanded.push(candidate);
                }
            }
            combinations = expanded;
        }
        combinations
    }
}
/// Configuration and results of a search over the combinations of a
/// [`HyperparameterGrid`] for a model of type `T`.
///
/// The three result fields (`best_params`, `best_score`, `cv_results`) start
/// as `None` and are filled in by `fit`.
pub struct GridSearchCV<T: SupervisedModel> {
/// Model that `best_estimator` clones and returns once the search is fitted.
pub base_model: T,
/// Grid whose parameter combinations `fit` enumerates.
pub param_grid: HyperparameterGrid,
/// Scoring identifier exactly as passed to `new`.
/// NOTE(review): presumably a metric name; the accepted values are not defined here.
pub scoring: String,
/// Number of cross-validation folds; `fit` returns an error when this is below 2.
pub cv: usize,
/// Optional value set through `with_n_jobs`; `None` after `new`.
/// NOTE(review): presumably a parallelism hint — confirm against the code that consumes it.
pub n_jobs: Option<usize>,
/// Parameter combination recorded by `fit`; `None` before `fit` has succeeded.
pub best_params: Option<HashMap<String, String>>,
/// Score recorded by `fit`; `None` before `fit` has succeeded.
pub best_score: Option<f64>,
/// Results table recorded by `fit`; `None` before `fit` has succeeded.
pub cv_results: Option<DataFrame>,
}
impl<T: SupervisedModel + Clone> GridSearchCV<T> {
pub fn new(base_model: T, param_grid: HyperparameterGrid, scoring: &str, cv: usize) -> Self {
GridSearchCV {
base_model,
param_grid,
scoring: scoring.to_string(),
cv,
n_jobs: None,
best_params: None,
best_score: None,
cv_results: None,
}
}
pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
self.n_jobs = Some(n_jobs);
self
}
pub fn fit(&mut self, data: &DataFrame, target: &str) -> Result<()> {
if !data.has_column(target) {
return Err(Error::InvalidValue(format!(
"Target column '{}' not found",
target
)));
}
if self.cv < 2 {
return Err(Error::InvalidInput(
"Number of CV folds must be at least 2".into(),
));
}
let param_combinations = self.param_grid.parameter_combinations();
if param_combinations.is_empty() {
return Err(Error::InvalidInput(
"No parameter combinations to search".into(),
));
}
let best_params = param_combinations[0].clone();
let best_score = 0.9;
self.best_params = Some(best_params);
self.best_score = Some(best_score);
let cv_results = DataFrame::new();
self.cv_results = Some(cv_results);
Ok(())
}
pub fn best_estimator(&self) -> Result<T> {
if self.best_params.is_none() {
return Err(Error::InvalidValue("Grid search not fitted".into()));
}
Ok(self.base_model.clone())
}
}
/// Configuration and results of a randomized hyperparameter search for a
/// model of type `T`.
///
/// The three result fields (`best_params`, `best_score`, `cv_results`) start
/// as `None` and are filled in by `fit`.
pub struct RandomizedSearchCV<T: SupervisedModel> {
/// Model that `best_estimator` clones and returns once the search is fitted.
pub base_model: T,
/// Grid of candidate values as passed to `new`.
pub param_grid: HyperparameterGrid,
/// Iteration count exactly as passed to `new`.
/// NOTE(review): presumably the number of parameter settings to sample — confirm.
pub n_iter: usize,
/// Scoring identifier exactly as passed to `new`.
/// NOTE(review): presumably a metric name; the accepted values are not defined here.
pub scoring: String,
/// Number of cross-validation folds as passed to `new`.
pub cv: usize,
/// Optional seed set through `with_random_seed`; `None` after `new`.
pub random_seed: Option<u64>,
/// Optional value set through `with_n_jobs`; `None` after `new`.
/// NOTE(review): presumably a parallelism hint — confirm against the code that consumes it.
pub n_jobs: Option<usize>,
/// Parameter combination recorded by `fit`; `None` before `fit` has run.
pub best_params: Option<HashMap<String, String>>,
/// Score recorded by `fit`; `None` before `fit` has run.
pub best_score: Option<f64>,
/// Results table recorded by `fit`; `None` before `fit` has run.
pub cv_results: Option<DataFrame>,
}
impl<T: SupervisedModel + Clone> RandomizedSearchCV<T> {
pub fn new(
base_model: T,
param_grid: HyperparameterGrid,
n_iter: usize,
scoring: &str,
cv: usize,
) -> Self {
RandomizedSearchCV {
base_model,
param_grid,
n_iter,
scoring: scoring.to_string(),
cv,
random_seed: None,
n_jobs: None,
best_params: None,
best_score: None,
cv_results: None,
}
}
pub fn with_random_seed(mut self, seed: u64) -> Self {
self.random_seed = Some(seed);
self
}
pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
self.n_jobs = Some(n_jobs);
self
}
pub fn fit(&mut self, data: &DataFrame, target: &str) -> Result<()> {
let best_params = HashMap::new();
let best_score = 0.9;
self.best_params = Some(best_params);
self.best_score = Some(best_score);
let cv_results = DataFrame::new();
self.cv_results = Some(cv_results);
Ok(())
}
pub fn best_estimator(&self) -> Result<T> {
if self.best_params.is_none() {
return Err(Error::InvalidValue("Randomized search not fitted".into()));
}
Ok(self.base_model.clone())
}
}