zeph_experiments/
random.rs1use std::collections::HashSet;
7use std::sync::Mutex;
8
9use ordered_float::OrderedFloat;
10use rand::Rng as _;
11use rand::SeedableRng as _;
12use rand::rngs::SmallRng;
13
14use super::generator::VariationGenerator;
15use super::search_space::SearchSpace;
16use super::snapshot::ConfigSnapshot;
17use super::types::{Variation, VariationValue};
18
/// Maximum number of sampling attempts made by a single `next` call before
/// it gives up and reports that no unvisited variation was found.
const MAX_RETRIES: usize = 1_000;
21
22pub struct Random {
34 search_space: SearchSpace,
35 rng: Mutex<SmallRng>,
36}
37
38impl Random {
39 #[must_use]
41 pub fn new(search_space: SearchSpace, seed: u64) -> Self {
42 Self {
43 search_space,
44 rng: Mutex::new(SmallRng::seed_from_u64(seed)),
45 }
46 }
47}
48
49impl VariationGenerator for Random {
50 fn next(
51 &mut self,
52 _baseline: &ConfigSnapshot,
53 visited: &HashSet<Variation>,
54 ) -> Option<Variation> {
55 if self.search_space.parameters.is_empty() {
56 return None;
57 }
58 let mut rng = self.rng.lock().expect("rng mutex poisoned");
59 for _ in 0..MAX_RETRIES {
60 let idx = rng.gen_range(0..self.search_space.parameters.len());
61 let range = &self.search_space.parameters[idx];
62 let raw: f64 = rng.gen_range(range.min..=range.max);
63 let value = range.quantize(raw);
64 let variation = Variation {
65 parameter: range.kind,
66 value: VariationValue::Float(OrderedFloat(value)),
67 };
68 if !visited.contains(&variation) {
69 return Some(variation);
70 }
71 }
72 None
73 }
74
75 fn name(&self) -> &'static str {
76 "random"
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    #[test]
    fn random_produces_values_in_range() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 1.0,
                step: Some(0.1),
                default: 0.5,
            }],
        };
        let mut generator = Random::new(space, 42);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..20 {
            // With nothing visited the first draw is always accepted, so a
            // `None` here would be a bug, not an exhausted search space.
            let v = generator
                .next(&baseline, &visited)
                .expect("empty visited set must always yield a variation");
            let val = v.value.as_f64();
            assert!((0.0..=1.0).contains(&val), "out of range: {val}");
        }
    }

    #[test]
    fn random_skips_visited() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.5,
                max: 0.5,
                step: Some(0.1),
                default: 0.5,
            }],
        };
        let mut generator = Random::new(space, 0);
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.5)),
        });
        let result = generator.next(&baseline, &visited);
        assert!(
            result.is_none(),
            "expected None when only option is already visited"
        );
    }

    #[test]
    fn random_returns_unvisited_alternative() {
        // Two grid points (0.0 and 0.1); the first is already visited, so the
        // generator must find the other one within its retry budget.
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 0.1,
                step: Some(0.1),
                default: 0.0,
            }],
        };
        let mut generator = Random::new(space, 0);
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        let v = generator
            .next(&baseline, &visited)
            .expect("an unvisited grid point remains");
        assert!(
            !visited.contains(&v),
            "generator returned an already-visited variation"
        );
    }

    #[test]
    fn random_empty_space_returns_none() {
        let mut generator = Random::new(SearchSpace { parameters: vec![] }, 0);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn random_is_deterministic_with_same_seed() {
        let space = SearchSpace::default();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let draw = |generator: &mut Random| -> Vec<Option<Variation>> {
            (0..5).map(|_| generator.next(&baseline, &visited)).collect()
        };
        let mut gen1 = Random::new(space.clone(), 123);
        let mut gen2 = Random::new(space, 123);
        assert_eq!(
            draw(&mut gen1),
            draw(&mut gen2),
            "same seed must produce the same sequence of variations"
        );
    }

    #[test]
    fn random_quantizes_sampled_values() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::TopP,
                min: 0.1,
                max: 1.0,
                step: Some(0.05),
                default: 0.9,
            }],
        };
        let mut generator = Random::new(space, 7);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..30 {
            let v = generator
                .next(&baseline, &visited)
                .expect("empty visited set must always yield a variation");
            let val = v.value.as_f64();
            // Distance from the grid anchor, measured in steps; must be an
            // integer up to floating-point noise.
            let steps = (val - 0.1) / 0.05;
            assert!(
                (steps - steps.round()).abs() < 1e-10,
                "value {val} is not on the 0.05-step grid anchored at 0.1"
            );
        }
    }

    #[test]
    fn random_name() {
        let generator = Random::new(SearchSpace::default(), 0);
        assert_eq!(generator.name(), "random");
    }

    #[test]
    fn random_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Random>();
    }
}