Skip to main content

zeph_experiments/
grid.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Systematic grid sweep strategy for parameter variation.
//!
//! [`GridStep`] iterates each parameter through its discrete steps in order,
//! skipping variations that have already been visited. Every grid value of every
//! parameter is eventually tried, but only one parameter is varied per variation,
//! so this is a per-axis sweep rather than the full cross product of all
//! parameters. That makes it well-suited as a first-pass exploration before
//! switching to a [`Neighborhood`] or [`Random`] strategy.
//!
//! [`Neighborhood`]: crate::Neighborhood
//! [`Random`]: crate::Random

14use std::collections::HashSet;
15
16use ordered_float::OrderedFloat;
17
18use super::generator::VariationGenerator;
19use super::search_space::SearchSpace;
20use super::snapshot::ConfigSnapshot;
21use super::types::{Variation, VariationValue};
22
23/// Systematic grid sweep: iterate each parameter through its discrete steps, skip visited.
24///
25/// Parameters are swept one at a time. For each parameter, all grid points from
26/// `min` to `max` (with the configured `step`) are enumerated in order. Already-visited
27/// variations are skipped. When all steps for a parameter are exhausted, the next
28/// parameter is tried. Returns `None` when the full grid has been visited.
29///
30/// When a parameter has no discrete `step`, [`GridStep`] falls back to
31/// `(max - min) / 20` as the step size.
32///
33/// # Examples
34///
35/// ```rust
36/// use std::collections::HashSet;
37/// use zeph_experiments::{
38///     ConfigSnapshot, GridStep, ParameterKind, ParameterRange, SearchSpace, VariationGenerator,
39/// };
40///
41/// let space = SearchSpace {
42///     parameters: vec![
43///         ParameterRange::new(ParameterKind::Temperature, 0.0, 1.0, Some(0.5), 0.5).unwrap(),
44///     ],
45/// };
46/// let mut generator = GridStep::new(space);
47/// let baseline = ConfigSnapshot::default();
48/// let mut visited = HashSet::new();
49///
50/// // Produces 0.0, 0.5, 1.0 in order.
51/// let mut count = 0;
52/// while let Some(v) = generator.next(&baseline, &visited) {
53///     visited.insert(v);
54///     count += 1;
55/// }
56/// assert_eq!(count, 3);
57/// ```
58pub struct GridStep {
59    search_space: SearchSpace,
60    current_param: usize,
61    current_step: usize,
62}
63
64impl GridStep {
65    /// Create a new [`GridStep`] generator starting at the first grid point.
66    ///
67    /// # Examples
68    ///
69    /// ```rust
70    /// use zeph_experiments::{GridStep, SearchSpace, VariationGenerator};
71    ///
72    /// let generator = GridStep::new(SearchSpace::default());
73    /// assert_eq!(generator.name(), "grid");
74    /// ```
75    #[must_use]
76    pub fn new(search_space: SearchSpace) -> Self {
77        Self {
78            search_space,
79            current_param: 0,
80            current_step: 0,
81        }
82    }
83}
84
85impl VariationGenerator for GridStep {
86    fn next(
87        &mut self,
88        _baseline: &ConfigSnapshot,
89        visited: &HashSet<Variation>,
90    ) -> Option<Variation> {
91        while self.current_param < self.search_space.parameters.len() {
92            let range = &self.search_space.parameters[self.current_param];
93            let step = range
94                .step()
95                .unwrap_or_else(|| (range.max() - range.min()) / 20.0);
96            if step <= 0.0 {
97                self.current_param += 1;
98                self.current_step = 0;
99                continue;
100            }
101
102            #[allow(clippy::cast_precision_loss)]
103            let raw = range.min() + step * self.current_step as f64;
104
105            if raw > range.max() + f64::EPSILON {
106                self.current_param += 1;
107                self.current_step = 0;
108                continue;
109            }
110
111            self.current_step += 1;
112
113            // Quantize to avoid floating-point accumulation before deduplication.
114            let value = range.quantize(raw);
115
116            let variation = Variation {
117                parameter: range.kind(),
118                value: VariationValue::Float(OrderedFloat(value)),
119            };
120
121            if !visited.contains(&variation) {
122                return Some(variation);
123            }
124        }
125        None
126    }
127
128    fn name(&self) -> &'static str {
129        "grid"
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    /// Build a one-parameter (temperature) space. The default sits at the midpoint so
    /// it always satisfies `min <= default <= max`.
    fn single_param_space(min: f64, max: f64, step: f64) -> SearchSpace {
        let midpoint = (min + max) / 2.0;
        let temperature =
            ParameterRange::new(ParameterKind::Temperature, min, max, Some(step), midpoint)
                .unwrap();
        SearchSpace {
            parameters: vec![temperature],
        }
    }

    /// Run `generator` until it returns `None`, marking every produced variation as
    /// visited. Returns the variations in production order.
    fn sweep_all(generator: &mut GridStep) -> Vec<Variation> {
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        std::iter::from_fn(|| {
            let variation = generator.next(&baseline, &visited)?;
            visited.insert(variation.clone());
            Some(variation)
        })
        .collect()
    }

    #[test]
    fn grid_step_produces_values_in_range() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let values: Vec<f64> = sweep_all(&mut generator)
            .iter()
            .map(|variation| variation.value.as_f64())
            .collect();
        assert_eq!(values.len(), 3, "0.0, 0.5, 1.0");
        assert!(values.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn grid_step_skips_visited() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let baseline = ConfigSnapshot::default();
        let visited: HashSet<Variation> = std::iter::once(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        })
        .collect();
        let first = generator
            .next(&baseline, &visited)
            .expect("grid should still have unvisited points");
        let got = first.value.as_f64();
        assert!((got - 0.5).abs() < 1e-10, "expected 0.5, got {got}");
    }

    #[test]
    fn grid_step_returns_none_when_exhausted() {
        // step=1.0 over [0.0, 0.5]: the only grid point is 0.0.
        let mut generator = GridStep::new(single_param_space(0.0, 0.5, 1.0));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_some());
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_multiple_params() {
        let temperature =
            ParameterRange::new(ParameterKind::Temperature, 0.0, 0.5, Some(0.5), 0.0).unwrap();
        let top_p = ParameterRange::new(ParameterKind::TopP, 0.5, 1.0, Some(0.5), 0.5).unwrap();
        let mut generator = GridStep::new(SearchSpace {
            parameters: vec![temperature, top_p],
        });
        let results = sweep_all(&mut generator);

        // Temperature: 0.0, 0.5 — TopP: 0.5, 1.0
        assert_eq!(results.len(), 4);
        let count_of = |kind: ParameterKind| {
            results
                .iter()
                .filter(|variation| variation.parameter == kind)
                .count()
        };
        assert_eq!(count_of(ParameterKind::Temperature), 2);
        assert_eq!(count_of(ParameterKind::TopP), 2);
    }

    #[test]
    fn grid_step_quantizes_to_avoid_fp_drift() {
        // Computing 0.1 * 7 in f64 gives 0.7000000000000001; quantize must snap it to 0.7.
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.1));
        for variation in sweep_all(&mut generator) {
            let v = variation.value.as_f64();
            let rounded = (v * 10.0).round() / 10.0;
            assert!(
                (v - rounded).abs() < 1e-10,
                "value {v} is not a clean multiple of 0.1"
            );
        }
    }

    #[test]
    fn grid_step_empty_space_returns_none() {
        let mut generator = GridStep::new(SearchSpace { parameters: vec![] });
        assert!(sweep_all(&mut generator).is_empty());
    }

    #[test]
    fn grid_step_none_step_uses_fallback() {
        // Without an explicit step, GridStep uses (max - min) / 20.0, so [0.0, 1.0]
        // yields 21 points: 0.0, 0.05, ..., 1.0.
        let space = SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::Temperature, 0.0, 1.0, None, 0.5).unwrap(),
            ],
        };
        let count = sweep_all(&mut GridStep::new(space)).len();
        assert_eq!(
            count, 21,
            "expected 21 steps for step=None with DEFAULT_STEPS=20"
        );
    }

    #[test]
    fn grid_step_name() {
        assert_eq!(GridStep::new(SearchSpace::default()).name(), "grid");
    }
}