Skip to main content

entrenar/optim/hpo/
grid.rs

1//! Grid search for hyperparameter optimization
2
3use std::collections::HashMap;
4
5use super::types::{HyperparameterSpace, ParameterDomain, ParameterValue};
6
/// Exhaustive grid search over a [`HyperparameterSpace`].
///
/// Every parameter in the space is expanded into a finite list of candidate values, and the
/// search enumerates the cartesian product of those lists (see
/// [`GridSearch::configurations`]). Continuous parameters are sampled at `n_points` evenly
/// spaced points (evenly in log space when the domain is log-scaled). Discrete parameters
/// contribute every integer in their range. Categorical parameters contribute every choice.
#[derive(Debug, Clone)]
pub struct GridSearch {
    /// The search space whose parameters are enumerated.
    space: HyperparameterSpace,
    /// Grid points per continuous parameter. [`GridSearch::new`] clamps this to at least 2.
    /// Discrete and categorical parameters ignore it.
    pub(crate) n_points: usize,
}
14
15/// Generate grid values for a single parameter domain.
16fn domain_grid_values(domain: &ParameterDomain, n_points: usize) -> Vec<ParameterValue> {
17    match domain {
18        ParameterDomain::Continuous { low, high, log_scale } => {
19            let divisor = (n_points - 1) as f64;
20            if *log_scale {
21                let log_low = low.max(f64::MIN_POSITIVE).ln();
22                let log_high = high.max(f64::MIN_POSITIVE).ln();
23                (0..n_points)
24                    .map(|i| {
25                        let t = i as f64 / divisor;
26                        ParameterValue::Float((log_low + t * (log_high - log_low)).exp())
27                    })
28                    .collect()
29            } else {
30                (0..n_points)
31                    .map(|i| {
32                        let t = i as f64 / divisor;
33                        ParameterValue::Float(low + t * (high - low))
34                    })
35                    .collect()
36            }
37        }
38        ParameterDomain::Discrete { low, high } => {
39            (*low..=*high).map(ParameterValue::Int).collect()
40        }
41        ParameterDomain::Categorical { choices } => {
42            choices.iter().map(|c| ParameterValue::Categorical(c.clone())).collect()
43        }
44    }
45}
46
47impl GridSearch {
48    /// Create new grid search
49    pub fn new(space: HyperparameterSpace, n_points: usize) -> Self {
50        Self { space, n_points: n_points.max(2) }
51    }
52
53    /// Generate all grid configurations
54    pub fn configurations(&self) -> Vec<HashMap<String, ParameterValue>> {
55        let param_values: Vec<(String, Vec<ParameterValue>)> = self
56            .space
57            .iter()
58            .map(|(name, domain)| (name.clone(), domain_grid_values(domain, self.n_points)))
59            .collect();
60
61        // Generate cartesian product
62        Self::cartesian_product(&param_values)
63    }
64
65    fn cartesian_product(
66        param_values: &[(String, Vec<ParameterValue>)],
67    ) -> Vec<HashMap<String, ParameterValue>> {
68        if param_values.is_empty() {
69            return vec![HashMap::new()];
70        }
71
72        let (name, values) = &param_values[0];
73        let rest = param_values.get(1..).unwrap_or_default();
74        let rest_configs = Self::cartesian_product(rest);
75
76        values
77            .iter()
78            .flat_map(|v| {
79                rest_configs.iter().map(move |config| {
80                    let mut new_config = config.clone();
81                    new_config.insert(name.clone(), v.clone());
82                    new_config
83                })
84            })
85            .collect()
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Read the float stored under `key` in a generated configuration.
    fn float_at(config: &HashMap<String, ParameterValue>, key: &str) -> f64 {
        config.get(key).expect("key should exist").as_float().expect("value should be a float")
    }

    #[test]
    fn test_grid_search_new() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 5);
        assert_eq!(grid.n_points, 5);
    }

    #[test]
    fn test_grid_search_empty_space() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        // The product over zero parameters is a single, empty configuration.
        assert_eq!(configs.len(), 1);
        assert!(configs[0].is_empty());
    }

    #[test]
    fn test_grid_search_single_param() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });

        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        assert_eq!(configs.len(), 5);

        // Check the values are evenly spaced across [0, 1], not just that the endpoints match.
        let values: Vec<f64> = configs.iter().map(|c| float_at(c, "lr")).collect();
        let expected = [0.0, 0.25, 0.5, 0.75, 1.0];
        for (v, e) in values.iter().zip(expected.iter()) {
            assert!((v - e).abs() < 1e-10, "got {v}, expected {e}");
        }
    }

    #[test]
    fn test_grid_search_multiple_params() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });
        space.add(
            "act",
            ParameterDomain::Categorical { choices: vec!["relu".to_string(), "gelu".to_string()] },
        );

        let grid = GridSearch::new(space, 3);
        let configs = grid.configurations();
        // 3 lr values * 2 activation functions = 6
        assert_eq!(configs.len(), 6);

        // Every (lr, act) combination must appear exactly once.
        let mut pairs: Vec<(f64, String)> = configs
            .iter()
            .map(|c| {
                let act = match c.get("act") {
                    Some(ParameterValue::Categorical(name)) => name.clone(),
                    _ => panic!("act should be a categorical value"),
                };
                (float_at(c, "lr"), act)
            })
            .collect();
        pairs.sort_by(|a, b| a.0.total_cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
        pairs.dedup();
        assert_eq!(pairs.len(), 6, "every (lr, act) combination should appear exactly once");
    }

    #[test]
    fn test_grid_search_discrete() {
        let mut space = HyperparameterSpace::new();
        space.add("batch_size", ParameterDomain::Discrete { low: 8, high: 10 });

        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        // Discrete domains enumerate every integer in range; n_points does not apply.
        let values: Vec<_> = configs
            .iter()
            .map(|c| match c.get("batch_size") {
                Some(ParameterValue::Int(v)) => *v,
                _ => panic!("batch_size should be an integer value"),
            })
            .collect();
        assert_eq!(values, vec![8, 9, 10]);
    }

    #[test]
    fn test_grid_search_log_scale() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 1e-4, high: 1e-1, log_scale: true });

        let grid = GridSearch::new(space, 4);
        let configs = grid.configurations();

        let values: Vec<f64> = configs.iter().map(|c| float_at(c, "lr")).collect();

        // Log spacing over four decades gives 1e-4, 1e-3, 1e-2, 1e-1. Compare relatively,
        // so linear spacing over the same range would fail.
        let expected = [1e-4, 1e-3, 1e-2, 1e-1];
        assert_eq!(values.len(), expected.len());
        for (v, e) in values.iter().zip(expected.iter()) {
            assert!((v / e - 1.0).abs() < 1e-9, "got {v}, expected about {e}");
        }
    }

    #[test]
    fn test_grid_search_min_n_points() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 1); // Should be clamped to 2
        assert_eq!(grid.n_points, 2);

        let grid = GridSearch::new(HyperparameterSpace::new(), 0); // Also clamped to 2
        assert_eq!(grid.n_points, 2);
    }
}
177
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // A single continuous parameter yields exactly `n_points` configurations. Their
        // values are strictly increasing and stay within the domain bounds [0, 1].
        #[test]
        fn prop_grid_search_size(n_points in 2usize..10) {
            let mut space = HyperparameterSpace::new();
            space.add("x", ParameterDomain::Continuous {
                low: 0.0,
                high: 1.0,
                log_scale: false,
            });

            let grid = GridSearch::new(space, n_points);
            let configs = grid.configurations();
            prop_assert_eq!(configs.len(), n_points);

            let values: Vec<f64> = configs
                .iter()
                .map(|c| {
                    c.get("x")
                        .expect("key should exist")
                        .as_float()
                        .expect("value should be a float")
                })
                .collect();
            prop_assert!(values.windows(2).all(|w| w[0] < w[1]));
            prop_assert!(values.iter().all(|v| (0.0..=1.0).contains(v)));
        }

        // The number of configurations is the product of the per-parameter grid sizes.
        #[test]
        fn prop_grid_search_product_size(n_points in 2usize..10, n_choices in 1usize..5) {
            let mut space = HyperparameterSpace::new();
            space.add("x", ParameterDomain::Continuous {
                low: 0.0,
                high: 1.0,
                log_scale: false,
            });
            space.add("c", ParameterDomain::Categorical {
                choices: (0..n_choices).map(|i| format!("c{i}")).collect(),
            });

            let grid = GridSearch::new(space, n_points);
            prop_assert_eq!(grid.configurations().len(), n_points * n_choices);
        }
    }
}