somatize_runtime/sampler/
bayesian.rs1use crate::sampler::{Sampler, hash_u64, pseudo_random, sample_float};
2use somatize_core::error::Result;
3use somatize_core::search::{SearchDimension, SearchSpace};
4use std::collections::HashMap;
5
6pub struct BayesianSampler {
15 n_trials: usize,
16 n_startup: usize,
17 seed: u64,
18 history: Vec<(HashMap<String, serde_json::Value>, f64)>,
20 gamma: f64,
22}
23
24impl BayesianSampler {
25 pub fn new(n_trials: usize, n_startup: usize, seed: Option<u64>) -> Self {
26 Self {
27 n_trials,
28 n_startup: n_startup.max(2),
29 seed: seed.unwrap_or(42),
30 history: Vec::new(),
31 gamma: 0.25, }
33 }
34
35 pub fn record(&mut self, params: HashMap<String, serde_json::Value>, metric: f64) {
37 self.history.push((params, metric));
38 }
39
40 fn sample_tpe(
42 &self,
43 space: &SearchSpace,
44 trial_index: usize,
45 ) -> HashMap<String, serde_json::Value> {
46 let mut sorted_history: Vec<(usize, f64)> = self
48 .history
49 .iter()
50 .enumerate()
51 .map(|(i, (_, v))| (i, *v))
52 .collect();
53 sorted_history.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
54
55 let n_good = (self.history.len() as f64 * self.gamma).ceil() as usize;
56 let n_good = n_good.max(1).min(self.history.len());
57 let good_indices: Vec<usize> = sorted_history[..n_good].iter().map(|(i, _)| *i).collect();
58
59 let mut params = HashMap::new();
60 for (dim_idx, dim) in space.active_dimensions().iter().enumerate() {
61 let rng_state = hash_u64(self.seed, trial_index as u64, dim_idx as u64);
62 let t = pseudo_random(rng_state);
63
64 let explore_prob = pseudo_random(hash_u64(
67 self.seed,
68 trial_index as u64,
69 dim_idx as u64 + 1000,
70 ));
71
72 let value = if explore_prob < 0.2 || good_indices.is_empty() {
73 self.sample_uniform(dim, t)
75 } else {
76 let good_idx = good_indices
78 [((t * good_indices.len() as f64) as usize).min(good_indices.len() - 1)];
79 let good_params = &self.history[good_idx].0;
80
81 if let Some(good_val) = good_params.get(dim.name()) {
82 self.sample_near(dim, good_val, rng_state)
83 } else {
84 self.sample_uniform(dim, t)
85 }
86 };
87
88 params.insert(dim.name().to_string(), value);
89 }
90
91 params
92 }
93
94 fn sample_uniform(&self, dim: &SearchDimension, t: f64) -> serde_json::Value {
95 match dim {
96 SearchDimension::Float {
97 low, high, scale, ..
98 } => {
99 serde_json::json!(sample_float(*low, *high, *scale, t))
100 }
101 SearchDimension::Int { low, high, .. } => {
102 let range = (*high - *low + 1) as f64;
103 let val = *low + (t * range).floor() as i64;
104 serde_json::json!(val.min(*high))
105 }
106 SearchDimension::Categorical { choices, .. } => {
107 let idx = (t * choices.len() as f64).floor() as usize;
108 choices[idx.min(choices.len() - 1)].clone()
109 }
110 _ => serde_json::Value::Null,
111 }
112 }
113
114 fn sample_near(
116 &self,
117 dim: &SearchDimension,
118 center: &serde_json::Value,
119 rng_state: u64,
120 ) -> serde_json::Value {
121 let t = pseudo_random(hash_u64(rng_state, 777, 0));
122 let perturbation = (pseudo_random(hash_u64(rng_state, 888, 0)) - 0.5) * 0.3;
123
124 match dim {
125 SearchDimension::Float { low, high, .. } => {
126 if let Some(center_val) = center.as_f64() {
127 let range = *high - *low;
128 let new_val = (center_val + perturbation * range).clamp(*low, *high);
129 serde_json::json!(new_val)
130 } else {
131 self.sample_uniform(dim, t)
132 }
133 }
134 SearchDimension::Int { low, high, .. } => {
135 if let Some(center_val) = center.as_i64() {
136 let range = (*high - *low) as f64;
137 let new_val = (center_val as f64 + perturbation * range).round() as i64;
138 serde_json::json!(new_val.clamp(*low, *high))
139 } else {
140 self.sample_uniform(dim, t)
141 }
142 }
143 SearchDimension::Categorical { choices, .. } => {
144 if perturbation.abs() < 0.1 {
146 center.clone()
147 } else {
148 let idx = (t * choices.len() as f64).floor() as usize;
149 choices[idx.min(choices.len() - 1)].clone()
150 }
151 }
152 _ => serde_json::Value::Null,
153 }
154 }
155}
156
157impl Sampler for BayesianSampler {
158 fn sample(
159 &mut self,
160 space: &SearchSpace,
161 trial_index: usize,
162 ) -> Result<Option<HashMap<String, serde_json::Value>>> {
163 if trial_index >= self.n_trials {
164 return Ok(None);
165 }
166
167 if trial_index < self.n_startup || self.history.is_empty() {
168 let mut params = HashMap::new();
170 for (i, dim) in space.active_dimensions().iter().enumerate() {
171 let rng_state = hash_u64(self.seed, trial_index as u64, i as u64);
172 let t = pseudo_random(rng_state);
173 params.insert(dim.name().to_string(), self.sample_uniform(dim, t));
174 }
175 Ok(Some(params))
176 } else {
177 Ok(Some(self.sample_tpe(space, trial_index)))
178 }
179 }
180
181 fn n_trials(&self) -> Option<usize> {
182 Some(self.n_trials)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::search::Scale;

    /// Two-dimensional space shared by the tests: log-scale float `lr` in
    /// `[0.001, 0.1]` and categorical `kernel` in {"rbf", "linear"}.
    fn sample_space() -> SearchSpace {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        });
        space.add(SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![serde_json::json!("rbf"), serde_json::json!("linear")],
        });
        space
    }

    /// Asserts that every sampled value lies inside `sample_space`'s domain.
    fn assert_in_space(params: &HashMap<String, serde_json::Value>) {
        let lr = params["lr"].as_f64().unwrap();
        assert!((0.001..=0.1).contains(&lr), "lr {} out of bounds", lr);
        let kernel = &params["kernel"];
        assert!(
            *kernel == serde_json::json!("rbf") || *kernel == serde_json::json!("linear"),
            "unexpected kernel {}",
            kernel
        );
    }

    #[test]
    fn startup_phase_is_random() {
        let mut sampler = BayesianSampler::new(20, 5, Some(42));
        let space = sample_space();

        let mut samples = Vec::new();
        for i in 0..5 {
            let params = sampler.sample(&space, i).unwrap().unwrap();
            assert!(params.contains_key("lr"));
            assert!(params.contains_key("kernel"));
            samples.push(params);
        }

        // Warm-up draws must not all collapse onto the same value.
        let lrs: Vec<f64> = samples.iter().map(|p| p["lr"].as_f64().unwrap()).collect();
        assert!(lrs.windows(2).any(|w| (w[0] - w[1]).abs() > 1e-10));
    }

    #[test]
    fn tpe_phase_after_recording_history() {
        let mut sampler = BayesianSampler::new(20, 3, Some(42));
        let space = sample_space();

        // Trials 0..3 are warm-up; trials 3 and 4 are already guided. The
        // metric peaks at lr = 0.01.
        for i in 0..5 {
            let params = sampler.sample(&space, i).unwrap().unwrap();
            let lr = params["lr"].as_f64().unwrap();
            let metric = 1.0 - (lr - 0.01).abs() * 10.0;
            sampler.record(params, metric);
        }

        // Every remaining guided trial must stay inside the declared domain.
        for i in 5..20 {
            let params = sampler.sample(&space, i).unwrap().unwrap();
            assert_in_space(&params);
        }
    }

    #[test]
    fn nan_metric_does_not_break_sampling() {
        let mut sampler = BayesianSampler::new(20, 2, Some(7));
        let space = sample_space();

        for i in 0..4 {
            let params = sampler.sample(&space, i).unwrap().unwrap();
            let metric = if i % 2 == 0 { f64::NAN } else { i as f64 };
            sampler.record(params, metric);
        }

        for i in 4..12 {
            let params = sampler.sample(&space, i).unwrap().unwrap();
            assert_in_space(&params);
        }
    }

    #[test]
    fn new_applies_defaults() {
        let sampler = BayesianSampler::new(10, 0, None);
        assert_eq!(sampler.n_startup, 2);
        assert_eq!(sampler.seed, 42);
        assert_eq!(sampler.n_trials(), Some(10));
    }

    #[test]
    fn respects_n_trials_limit() {
        let mut sampler = BayesianSampler::new(10, 3, Some(42));
        let space = sample_space();

        for i in 0..15 {
            let result = sampler.sample(&space, i).unwrap();
            if i < 10 {
                assert!(result.is_some());
            } else {
                assert!(result.is_none());
            }
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let space = sample_space();

        let mut s1 = BayesianSampler::new(10, 3, Some(42));
        let mut s2 = BayesianSampler::new(10, 3, Some(42));

        for i in 0..5 {
            let p1 = s1.sample(&space, i).unwrap().unwrap();
            let p2 = s2.sample(&space, i).unwrap().unwrap();
            assert_eq!(p1, p2);
        }
    }

    #[test]
    fn different_seeds_differ() {
        let space = sample_space();

        let mut s1 = BayesianSampler::new(10, 3, Some(42));
        let mut s2 = BayesianSampler::new(10, 3, Some(99));

        let p1 = s1.sample(&space, 0).unwrap().unwrap();
        let p2 = s2.sample(&space, 0).unwrap().unwrap();
        assert_ne!(p1["lr"], p2["lr"]);
    }
}