use std::collections::HashMap;
use super::types::{HyperparameterSpace, ParameterDomain, ParameterValue};
/// Exhaustive grid search over a [`HyperparameterSpace`].
///
/// Every parameter is expanded into a finite list of candidate values and the
/// configurations returned by [`GridSearch::configurations`] are the Cartesian
/// product of those lists.
#[derive(Debug, Clone)]
pub struct GridSearch {
    /// The parameters (and their domains) to enumerate.
    space: HyperparameterSpace,
    /// Number of evenly spaced samples (linear or log spacing) taken from each
    /// continuous domain; discrete and categorical domains are enumerated in
    /// full regardless of this value. `GridSearch::new` clamps it to at least 2.
    pub(crate) n_points: usize,
}
/// Expands a single parameter domain into the list of values the grid visits.
///
/// * `Continuous` — `n_points` evenly spaced samples from `low` to `high`
///   inclusive. With `log_scale` the spacing is uniform in log space and both
///   bounds are first clamped to at least `f64::MIN_POSITIVE` so the logarithm
///   is finite. The first and last samples are exactly the (clamped) bounds,
///   so interpolation rounding never perturbs them. `n_points == 1` yields only
///   the lower bound and `n_points == 0` yields no values.
/// * `Discrete` — every value in `low..=high`; `n_points` is ignored.
/// * `Categorical` — every choice, in declaration order; `n_points` is ignored.
fn domain_grid_values(domain: &ParameterDomain, n_points: usize) -> Vec<ParameterValue> {
    match domain {
        ParameterDomain::Continuous { low, high, log_scale } => {
            let (start, end) = if *log_scale {
                (low.max(f64::MIN_POSITIVE), high.max(f64::MIN_POSITIVE))
            } else {
                (*low, *high)
            };
            let (log_start, log_end) = (start.ln(), end.ln());
            let interpolate = |t: f64| {
                if *log_scale {
                    (log_start + t * (log_end - log_start)).exp()
                } else {
                    start + t * (end - start)
                }
            };
            // Index of the final sample. Clamping the divisor to 1 avoids the
            // usize underflow (n_points == 0) and the 0/0 NaN (n_points == 1).
            let last = n_points.saturating_sub(1);
            let divisor = last.max(1) as f64;
            (0..n_points)
                .map(|i| {
                    // Pin the endpoints: `start + 1.0 * (end - start)` and
                    // `exp(ln(end))` are not guaranteed to round back to `end`.
                    let value = if i == 0 {
                        start
                    } else if i == last {
                        end
                    } else {
                        interpolate(i as f64 / divisor)
                    };
                    ParameterValue::Float(value)
                })
                .collect()
        }
        ParameterDomain::Discrete { low, high } => {
            (*low..=*high).map(ParameterValue::Int).collect()
        }
        ParameterDomain::Categorical { choices } => {
            choices.iter().map(|c| ParameterValue::Categorical(c.clone())).collect()
        }
    }
}
impl GridSearch {
    /// Creates a grid search over `space`.
    ///
    /// `n_points` is the number of samples taken from each continuous domain.
    /// It is clamped to at least 2 so both bounds of every continuous domain
    /// are always visited.
    pub fn new(space: HyperparameterSpace, n_points: usize) -> Self {
        Self { space, n_points: n_points.max(2) }
    }

    /// Enumerates every configuration in the grid.
    ///
    /// Each parameter is expanded into its candidate values and the Cartesian
    /// product of those lists is returned, one `HashMap` per configuration.
    /// The parameter yielded first by the space's iterator varies slowest.
    /// An empty space yields a single empty configuration; a parameter with no
    /// candidate values makes the whole product empty.
    pub fn configurations(&self) -> Vec<HashMap<String, ParameterValue>> {
        let param_values: Vec<(String, Vec<ParameterValue>)> = self
            .space
            .iter()
            .map(|(name, domain)| (name.clone(), domain_grid_values(domain, self.n_points)))
            .collect();
        Self::cartesian_product(&param_values)
    }

    /// Recursively builds the Cartesian product of `param_values`.
    ///
    /// The first entry forms the outermost loop, so it varies slowest in the
    /// returned order.
    fn cartesian_product(
        param_values: &[(String, Vec<ParameterValue>)],
    ) -> Vec<HashMap<String, ParameterValue>> {
        // Base case: the product of zero sets is exactly one empty configuration.
        let ((name, values), rest) = match param_values.split_first() {
            Some(split) => split,
            None => return vec![HashMap::new()],
        };
        let rest_configs = Self::cartesian_product(rest);
        values
            .iter()
            .flat_map(|value| {
                rest_configs.iter().map(move |config| {
                    let mut extended = config.clone();
                    extended.insert(name.clone(), value.clone());
                    extended
                })
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the float stored under `key` in every configuration, in order.
    fn float_values(configs: &[HashMap<String, ParameterValue>], key: &str) -> Vec<f64> {
        configs
            .iter()
            .map(|c| c.get(key).expect("key should exist").as_float().expect("key should exist"))
            .collect()
    }

    #[test]
    fn test_grid_search_new() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 5);
        assert_eq!(grid.n_points, 5);
    }

    #[test]
    fn test_grid_search_empty_space() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        assert_eq!(configs.len(), 1);
        // The single configuration of an empty space assigns no parameters.
        assert!(configs[0].is_empty());
    }

    #[test]
    fn test_grid_search_single_param() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });
        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        assert_eq!(configs.len(), 5);
        let values = float_values(&configs, "lr");
        for (got, want) in values.iter().zip([0.0, 0.25, 0.5, 0.75, 1.0].iter()) {
            assert!((got - want).abs() < 1e-10, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_grid_search_multiple_params() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });
        space.add(
            "act",
            ParameterDomain::Categorical { choices: vec!["relu".to_string(), "gelu".to_string()] },
        );
        let grid = GridSearch::new(space, 3);
        let configs = grid.configurations();
        assert_eq!(configs.len(), 6);
        // Every configuration must assign both parameters and nothing else.
        assert!(configs
            .iter()
            .all(|c| c.len() == 2 && c.contains_key("lr") && c.contains_key("act")));
    }

    #[test]
    fn test_grid_search_discrete() {
        let mut space = HyperparameterSpace::new();
        space.add("batch_size", ParameterDomain::Discrete { low: 8, high: 10 });
        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        // Discrete domains are enumerated in full, independent of n_points.
        assert_eq!(configs.len(), 3);
        assert!(configs.iter().all(|c| c.contains_key("batch_size")));
    }

    #[test]
    fn test_grid_search_log_scale() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 1e-4, high: 1e-1, log_scale: true });
        let grid = GridSearch::new(space, 4);
        let configs = grid.configurations();
        let values = float_values(&configs, "lr");
        assert_eq!(values.len(), 4);
        // Log spacing over three decades with four points lands on each decade.
        for (got, want) in values.iter().zip([1e-4, 1e-3, 1e-2, 1e-1].iter()) {
            assert!((got / want - 1.0).abs() < 1e-9, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_grid_search_min_n_points() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 1);
        assert_eq!(grid.n_points, 2);
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // A lone continuous parameter contributes exactly `n_points` grid
        // values, so the product has exactly that many configurations.
        #[test]
        fn prop_grid_search_size(n_points in 2usize..10) {
            let unit_interval = ParameterDomain::Continuous {
                low: 0.0,
                high: 1.0,
                log_scale: false,
            };
            let mut space = HyperparameterSpace::new();
            space.add("x", unit_interval);
            let config_count = GridSearch::new(space, n_points).configurations().len();
            prop_assert_eq!(config_count, n_points);
        }
    }
}