1use std::collections::HashSet;
15
16use ordered_float::OrderedFloat;
17
18use super::generator::VariationGenerator;
19use super::search_space::SearchSpace;
20use super::snapshot::ConfigSnapshot;
21use super::types::{Variation, VariationValue};
22
23pub struct GridStep {
59 search_space: SearchSpace,
60 current_param: usize,
61 current_step: usize,
62}
63
64impl GridStep {
65 #[must_use]
76 pub fn new(search_space: SearchSpace) -> Self {
77 Self {
78 search_space,
79 current_param: 0,
80 current_step: 0,
81 }
82 }
83}
84
85impl VariationGenerator for GridStep {
86 fn next(
87 &mut self,
88 _baseline: &ConfigSnapshot,
89 visited: &HashSet<Variation>,
90 ) -> Option<Variation> {
91 while self.current_param < self.search_space.parameters.len() {
92 let range = &self.search_space.parameters[self.current_param];
93 let step = range
94 .step()
95 .unwrap_or_else(|| (range.max() - range.min()) / 20.0);
96 if step <= 0.0 {
97 self.current_param += 1;
98 self.current_step = 0;
99 continue;
100 }
101
102 #[allow(clippy::cast_precision_loss)]
103 let raw = range.min() + step * self.current_step as f64;
104
105 if raw > range.max() + f64::EPSILON {
106 self.current_param += 1;
107 self.current_step = 0;
108 continue;
109 }
110
111 self.current_step += 1;
112
113 let value = range.quantize(raw);
115
116 let variation = Variation {
117 parameter: range.kind(),
118 value: VariationValue::Float(OrderedFloat(value)),
119 };
120
121 if !visited.contains(&variation) {
122 return Some(variation);
123 }
124 }
125 None
126 }
127
128 fn name(&self) -> &'static str {
129 "grid"
130 }
131}
132
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    /// Builds a search space holding only `Temperature` over `[min, max]` with an
    /// explicit `step`; the default sits at the midpoint of the range.
    fn single_param_space(min: f64, max: f64, step: f64) -> SearchSpace {
        let default = (min + max) / 2.0;
        SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::Temperature, min, max, Some(step), default)
                    .unwrap(),
            ],
        }
    }

    /// Runs `generator` to exhaustion, marking every produced variation as visited,
    /// and returns the variations in the order they were produced.
    fn drain(generator: &mut GridStep) -> Vec<Variation> {
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        let mut produced = vec![];
        while let Some(v) = generator.next(&baseline, &visited) {
            visited.insert(v.clone());
            produced.push(v);
        }
        produced
    }

    /// Same as [`drain`], but keeps only the numeric values.
    fn drain_values(generator: &mut GridStep) -> Vec<f64> {
        drain(generator).iter().map(|v| v.value.as_f64()).collect()
    }

    #[test]
    fn grid_step_produces_values_in_range() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let values = drain_values(&mut generator);
        assert_eq!(values.len(), 3, "0.0, 0.5, 1.0");
        for v in &values {
            assert!(*v >= 0.0 && *v <= 1.0);
        }
    }

    #[test]
    fn grid_step_skips_visited() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        let first = generator.next(&baseline, &visited).unwrap();
        assert!(
            (first.value.as_f64() - 0.5).abs() < 1e-10,
            "expected 0.5, got {}",
            first.value.as_f64()
        );
    }

    #[test]
    fn grid_step_returns_none_when_exhausted() {
        // step (1.0) is wider than the range, so only `min` fits on the grid.
        let mut generator = GridStep::new(single_param_space(0.0, 0.5, 1.0));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        generator.next(&baseline, &visited).unwrap();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_multiple_params() {
        let space = SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::Temperature, 0.0, 0.5, Some(0.5), 0.0).unwrap(),
                ParameterRange::new(ParameterKind::TopP, 0.5, 1.0, Some(0.5), 0.5).unwrap(),
            ],
        };
        let mut generator = GridStep::new(space);
        let results = drain(&mut generator);
        assert_eq!(results.len(), 4);

        let temp_count = results
            .iter()
            .filter(|v| v.parameter == ParameterKind::Temperature)
            .count();
        let top_p_count = results
            .iter()
            .filter(|v| v.parameter == ParameterKind::TopP)
            .count();
        assert_eq!(temp_count, 2);
        assert_eq!(top_p_count, 2);
    }

    #[test]
    fn grid_step_quantizes_to_avoid_fp_drift() {
        // 0.1 is not exactly representable, so raw `0.1 * k` values drift
        // (e.g. 0.30000000000000004); every emitted value must still be a clean
        // multiple of 0.1.
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.1));
        let values = drain_values(&mut generator);

        // Guard against a vacuous pass: the grid must actually cover 0.0..=1.0.
        assert_eq!(values.len(), 11, "0.0, 0.1, ..., 1.0");
        assert!(values[0].abs() < 1e-10, "first value {}", values[0]);
        assert!(
            (values[values.len() - 1] - 1.0).abs() < 1e-10,
            "last value {}",
            values[values.len() - 1]
        );

        for v in &values {
            let rounded = (v * 10.0).round() / 10.0;
            assert!(
                (v - rounded).abs() < 1e-10,
                "value {v} is not a clean multiple of 0.1"
            );
        }
    }

    #[test]
    fn grid_step_empty_space_returns_none() {
        let mut generator = GridStep::new(SearchSpace { parameters: vec![] });
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_none_step_uses_fallback() {
        // With no explicit step the range is split into 20 intervals: 21 points.
        let space = SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::Temperature, 0.0, 1.0, None, 0.5).unwrap(),
            ],
        };
        let mut generator = GridStep::new(space);
        let count = drain(&mut generator).len();
        assert_eq!(
            count, 21,
            "expected 21 steps for step=None with DEFAULT_STEPS=20"
        );
    }

    #[test]
    fn grid_step_name() {
        let generator = GridStep::new(SearchSpace::default());
        assert_eq!(generator.name(), "grid");
    }
}