entrenar/optim/hpo/
grid.rs1use std::collections::HashMap;
4
5use super::types::{HyperparameterSpace, ParameterDomain, ParameterValue};
6
/// Exhaustive grid search over a [`HyperparameterSpace`].
///
/// Every parameter in the space is expanded into a finite list of candidate
/// values, and [`GridSearch::configurations`] returns the cartesian product of
/// those lists. Continuous parameters contribute `n_points` values each.
/// Discrete parameters contribute every integer in their inclusive range, and
/// categorical parameters contribute every choice, regardless of `n_points`.
#[derive(Debug, Clone)]
pub struct GridSearch {
    /// Search space whose parameters are enumerated.
    space: HyperparameterSpace,
    /// Number of grid points per continuous parameter.
    ///
    /// [`GridSearch::new`] raises values below 2 to 2. Crate-internal code that
    /// assigns this field directly bypasses that clamp.
    pub(crate) n_points: usize,
}
14
15fn domain_grid_values(domain: &ParameterDomain, n_points: usize) -> Vec<ParameterValue> {
17 match domain {
18 ParameterDomain::Continuous { low, high, log_scale } => {
19 let divisor = (n_points - 1) as f64;
20 if *log_scale {
21 let log_low = low.max(f64::MIN_POSITIVE).ln();
22 let log_high = high.max(f64::MIN_POSITIVE).ln();
23 (0..n_points)
24 .map(|i| {
25 let t = i as f64 / divisor;
26 ParameterValue::Float((log_low + t * (log_high - log_low)).exp())
27 })
28 .collect()
29 } else {
30 (0..n_points)
31 .map(|i| {
32 let t = i as f64 / divisor;
33 ParameterValue::Float(low + t * (high - low))
34 })
35 .collect()
36 }
37 }
38 ParameterDomain::Discrete { low, high } => {
39 (*low..=*high).map(ParameterValue::Int).collect()
40 }
41 ParameterDomain::Categorical { choices } => {
42 choices.iter().map(|c| ParameterValue::Categorical(c.clone())).collect()
43 }
44 }
45}
46
47impl GridSearch {
48 pub fn new(space: HyperparameterSpace, n_points: usize) -> Self {
50 Self { space, n_points: n_points.max(2) }
51 }
52
53 pub fn configurations(&self) -> Vec<HashMap<String, ParameterValue>> {
55 let param_values: Vec<(String, Vec<ParameterValue>)> = self
56 .space
57 .iter()
58 .map(|(name, domain)| (name.clone(), domain_grid_values(domain, self.n_points)))
59 .collect();
60
61 Self::cartesian_product(¶m_values)
63 }
64
65 fn cartesian_product(
66 param_values: &[(String, Vec<ParameterValue>)],
67 ) -> Vec<HashMap<String, ParameterValue>> {
68 if param_values.is_empty() {
69 return vec![HashMap::new()];
70 }
71
72 let (name, values) = ¶m_values[0];
73 let rest = param_values.get(1..).unwrap_or_default();
74 let rest_configs = Self::cartesian_product(rest);
75
76 values
77 .iter()
78 .flat_map(|v| {
79 rest_configs.iter().map(move |config| {
80 let mut new_config = config.clone();
81 new_config.insert(name.clone(), v.clone());
82 new_config
83 })
84 })
85 .collect()
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Reads the `lr` value of every configuration as a float, in grid order.
    fn lr_values(configs: &[HashMap<String, ParameterValue>]) -> Vec<f64> {
        configs
            .iter()
            .map(|c| {
                c.get("lr")
                    .expect("key should exist")
                    .as_float()
                    .expect("lr should be a float")
            })
            .collect()
    }

    #[test]
    fn test_grid_search_new() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 5);
        assert_eq!(grid.n_points, 5);
    }

    #[test]
    fn test_grid_search_empty_space() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        // The cartesian product of zero axes is a single, empty configuration.
        assert_eq!(configs.len(), 1);
        assert!(configs[0].is_empty());
    }

    #[test]
    fn test_grid_search_single_param() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });

        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        assert_eq!(configs.len(), 5);

        let values = lr_values(&configs);
        assert!((values[0] - 0.0).abs() < 1e-10);
        assert!((values[2] - 0.5).abs() < 1e-10);
        assert!((values[4] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_grid_search_multiple_params() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });
        space.add(
            "act",
            ParameterDomain::Categorical { choices: vec!["relu".to_string(), "gelu".to_string()] },
        );

        let grid = GridSearch::new(space, 3);
        let configs = grid.configurations();
        assert_eq!(configs.len(), 6);

        // Every configuration assigns both parameters.
        assert!(configs.iter().all(|c| c.len() == 2));
        // Each activation is paired with each of the 3 learning rates.
        for act in ["relu", "gelu"] {
            let count = configs
                .iter()
                .filter(|c| {
                    matches!(c.get("act"), Some(ParameterValue::Categorical(s)) if s == act)
                })
                .count();
            assert_eq!(count, 3);
        }
    }

    #[test]
    fn test_grid_search_discrete() {
        let mut space = HyperparameterSpace::new();
        space.add("batch_size", ParameterDomain::Discrete { low: 8, high: 10 });

        let grid = GridSearch::new(space, 5);
        let configs = grid.configurations();
        // Discrete domains enumerate the whole inclusive range, independent of n_points.
        assert_eq!(configs.len(), 3);
        for expected in [8, 9, 10] {
            assert!(configs.iter().any(|c| {
                matches!(c.get("batch_size"), Some(ParameterValue::Int(v)) if *v == expected)
            }));
        }
    }

    #[test]
    fn test_grid_search_log_scale() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 1e-4, high: 1e-1, log_scale: true });

        let grid = GridSearch::new(space, 4);
        let configs = grid.configurations();

        let values = lr_values(&configs);
        assert_eq!(values.len(), 4);
        assert!(values[0] < 1e-3);
        assert!(values[3] > 1e-2);
        // Log spacing: consecutive points differ by a constant factor (10 here).
        for pair in values.windows(2) {
            assert!((pair[1] / pair[0] - 10.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_grid_search_min_n_points() {
        let space = HyperparameterSpace::new();
        let grid = GridSearch::new(space, 1);
        assert_eq!(grid.n_points, 2);
    }
}
177
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // A single continuous parameter yields exactly one configuration per grid point.
        #[test]
        fn prop_grid_search_size(n_points in 2usize..10) {
            let unit_interval = ParameterDomain::Continuous {
                low: 0.0,
                high: 1.0,
                log_scale: false,
            };
            let mut space = HyperparameterSpace::new();
            space.add("x", unit_interval);

            let total = GridSearch::new(space, n_points).configurations().len();
            prop_assert_eq!(total, n_points);
        }
    }
}