1use std::collections::HashSet;
7
8use ordered_float::OrderedFloat;
9use rand::Rng as _;
10use rand::SeedableRng as _;
11use rand::rngs::SmallRng;
12
13use super::error::EvalError;
14use super::generator::VariationGenerator;
15use super::search_space::SearchSpace;
16use super::snapshot::ConfigSnapshot;
17use super::types::{Variation, VariationValue};
18
/// Upper bound on the sampling attempts a single `next` call makes before giving up.
const MAX_RETRIES: usize = 1_000;

/// Number of equal slices a parameter range is divided into when it declares no explicit step.
const DEFAULT_STEPS: f64 = 20.0;
27
28pub struct Neighborhood {
37 search_space: SearchSpace,
38 radius: f64,
39 rng: SmallRng,
40}
41
42impl Neighborhood {
43 pub fn new(search_space: SearchSpace, radius: f64, seed: u64) -> Result<Self, EvalError> {
49 if !radius.is_finite() || radius <= 0.0 {
50 return Err(EvalError::InvalidRadius { radius });
51 }
52 Ok(Self {
53 search_space,
54 radius,
55 rng: SmallRng::seed_from_u64(seed),
56 })
57 }
58}
59
60impl VariationGenerator for Neighborhood {
61 fn next(
62 &mut self,
63 baseline: &ConfigSnapshot,
64 visited: &HashSet<Variation>,
65 ) -> Option<Variation> {
66 if self.search_space.parameters.is_empty() {
67 return None;
68 }
69 for _ in 0..MAX_RETRIES {
70 let idx = self.rng.gen_range(0..self.search_space.parameters.len());
71 let range = &self.search_space.parameters[idx];
72 let current = baseline.get(range.kind);
73 let step = range
75 .step
76 .unwrap_or_else(|| (range.max - range.min) / DEFAULT_STEPS);
77 let delta = self.rng.gen_range(-self.radius..=self.radius) * step;
78 if delta.abs() < f64::EPSILON {
80 continue;
81 }
82 let raw = current + delta;
83 let value = range.quantize(range.clamp(raw));
84 if (value - current).abs() < f64::EPSILON {
86 continue;
87 }
88 let variation = Variation {
89 parameter: range.kind,
90 value: VariationValue::Float(OrderedFloat(value)),
91 };
92 if !visited.contains(&variation) {
93 return Some(variation);
94 }
95 }
96 None
97 }
98
99 fn name(&self) -> &'static str {
100 "neighborhood"
101 }
102}
103
#[cfg(test)]
mod tests {
    // Test-only lint relaxations; they apply to this module alone.
    #![allow(
        clippy::collapsible_if,
        clippy::field_reassign_with_default,
        clippy::manual_midpoint,
        clippy::manual_range_contains
    )]

    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    /// Builds a search space holding a single parameter with an explicit step.
    fn make_space(kind: ParameterKind, min: f64, max: f64, step: f64) -> SearchSpace {
        SearchSpace {
            parameters: vec![ParameterRange {
                kind,
                min,
                max,
                step: Some(step),
                default: f64::midpoint(min, max),
            }],
        }
    }

    #[test]
    fn neighborhood_produces_values_in_range() {
        let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
        let mut generator = Neighborhood::new(space, 1.0, 42).unwrap();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let values: Vec<_> = (0..20)
            .filter_map(|_| generator.next(&baseline, &visited))
            .map(|v| v.value.as_f64())
            .collect();
        // Without this the range check below would pass vacuously on an empty result.
        assert!(!values.is_empty(), "should produce at least one variation");
        for val in values {
            assert!((0.0..=2.0).contains(&val), "out of range: {val}");
        }
    }

    #[test]
    fn neighborhood_is_deterministic_with_same_seed() {
        let space = SearchSpace::default();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let mut gen1 = Neighborhood::new(space.clone(), 1.0, 99).unwrap();
        let mut gen2 = Neighborhood::new(space, 1.0, 99).unwrap();
        let v1 = gen1.next(&baseline, &visited);
        let v2 = gen2.next(&baseline, &visited);
        assert_eq!(v1, v2, "same seed must produce same first variation");
    }

    #[test]
    fn neighborhood_skips_visited() {
        // Degenerate range (min == max): 0.5 is presumably the only value the generator can
        // propose, and it is already marked as visited.
        let space = make_space(ParameterKind::Temperature, 0.5, 0.5, 0.1);
        let mut generator = Neighborhood::new(space, 1.0, 0).unwrap();
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.5)),
        });
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn neighborhood_empty_space_returns_none() {
        let mut generator = Neighborhood::new(SearchSpace { parameters: vec![] }, 1.0, 0).unwrap();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn neighborhood_zero_radius_returns_error() {
        let result = Neighborhood::new(SearchSpace::default(), 0.0, 0);
        assert!(result.is_err(), "zero radius must be rejected");
    }

    #[test]
    fn neighborhood_negative_radius_returns_error() {
        let result = Neighborhood::new(SearchSpace::default(), -1.0, 0);
        assert!(result.is_err(), "negative radius must be rejected");
    }

    #[test]
    fn neighborhood_nan_radius_returns_error() {
        let result = Neighborhood::new(SearchSpace::default(), f64::NAN, 0);
        assert!(result.is_err(), "NaN radius must be rejected");
    }

    #[test]
    fn neighborhood_step_none_uses_default_steps() {
        // No explicit step: the generator has to derive one from the range width.
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 2.0,
                step: None,
                default: 1.0,
            }],
        };
        let mut generator = Neighborhood::new(space, 1.0, 77).unwrap();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        // `any` stops at the first successful proposal, like an early `break`.
        let got_any = (0..50).any(|_| generator.next(&baseline, &visited).is_some());
        assert!(
            got_any,
            "should produce at least one variation for continuous parameter"
        );
    }

    #[test]
    fn neighborhood_quantizes_perturbed_values() {
        let space = make_space(ParameterKind::TopP, 0.1, 1.0, 0.05);
        let mut generator = Neighborhood::new(space, 2.0, 11).unwrap();
        let mut baseline = ConfigSnapshot::default();
        baseline.top_p = 0.5;
        let visited = HashSet::new();
        let values: Vec<_> = (0..30)
            .filter_map(|_| generator.next(&baseline, &visited))
            .map(|v| v.value.as_f64())
            .collect();
        // Without this the grid check below would pass vacuously on an empty result.
        assert!(!values.is_empty(), "should produce at least one variation");
        for val in values {
            // Distance from the range minimum, measured in steps, must be a whole number.
            let steps = (val - 0.1) / 0.05;
            assert!(
                (steps - steps.round()).abs() < 1e-10,
                "value {val} is not on the 0.05-step grid anchored at 0.1"
            );
        }
    }

    #[test]
    fn neighborhood_name() {
        let generator = Neighborhood::new(SearchSpace::default(), 1.0, 0).unwrap();
        assert_eq!(generator.name(), "neighborhood");
    }

    #[test]
    fn neighborhood_perturbs_around_baseline() {
        // Radius 1.0 with step 0.1 limits every offset to ±0.1 around the baseline.
        // NOTE(review): the bounds below assume the default snapshot's temperature is 0.7.
        let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
        let mut generator = Neighborhood::new(space, 1.0, 55).unwrap();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let temp_values: Vec<_> = (0..50)
            .filter_map(|_| generator.next(&baseline, &visited))
            .filter(|v| v.parameter == ParameterKind::Temperature)
            .map(|v| v.value.as_f64())
            .collect();
        assert!(
            !temp_values.is_empty(),
            "should produce temperature variations"
        );
        for val in &temp_values {
            assert!(
                *val >= 0.6 - 1e-10 && *val <= 0.8 + 1e-10,
                "value {val} not within ±1 step of 0.7"
            );
        }
    }
}