1use rand::prelude::*;
6use std::collections::HashMap;
7
/// Describes the domain of a single tunable hyperparameter.
#[derive(Clone, Debug)]
pub enum ParameterSpace {
    /// Real-valued range [min, max]; with `log_scale`, sampling is uniform in
    /// log space (assumes min > 0 — `ln()` of a non-positive bound is NaN;
    /// TODO confirm callers guarantee this).
    Continuous { min: f32, max: f32, log_scale: bool },
    /// Inclusive integer range [min, max].
    Integer { min: i32, max: i32 },
    /// Finite set of named choices, sampled uniformly.
    Categorical { choices: Vec<String> },
}
15
/// A concrete assignment of one sampled value per parameter name.
pub type Configuration = HashMap<String, ParameterValue>;
18
/// A sampled value for a single parameter.
#[derive(Clone, Debug)]
pub enum ParameterValue {
    /// Value drawn from a `Continuous` space.
    Float(f32),
    /// Value drawn from an `Integer` space.
    Int(i32),
    /// Choice drawn from a `Categorical` space.
    String(String),
}
25
/// Sequential model-based optimizer: random warm-up evaluations followed by
/// acquisition-guided sampling over a simple surrogate model.
pub struct BayesianOptimization {
    /// Number of acquisition-guided evaluations after warm-up.
    pub n_iterations: usize,
    /// Number of purely random warm-up evaluations.
    pub n_initial_points: usize,
    /// Strategy used to score candidate configurations.
    pub acquisition_function: AcquisitionFunction,
    /// Exploration margin used by EI and PI.
    pub xi: f32,
    /// Exploration weight used by UCB.
    pub kappa: f32,
    /// Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
    /// Every (configuration, score) pair evaluated so far.
    observations: Vec<(Configuration, f32)>,
}
40
/// Acquisition strategies for choosing the next point to evaluate.
#[derive(Clone, Copy)]
pub enum AcquisitionFunction {
    /// EI: expected amount by which a candidate improves on the best score.
    ExpectedImprovement,
    /// PI: probability that a candidate improves on the best score.
    ProbabilityOfImprovement,
    /// UCB: mean + kappa * std.
    UpperConfidenceBound,
}
47
48impl BayesianOptimization {
49 pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
50 Self {
51 n_iterations: 50,
52 n_initial_points: 10,
53 acquisition_function: AcquisitionFunction::ExpectedImprovement,
54 xi: 0.01,
55 kappa: 2.576,
56 parameter_space,
57 observations: Vec::new(),
58 }
59 }
60
61 pub fn n_iterations(mut self, n: usize) -> Self {
62 self.n_iterations = n;
63 self
64 }
65
66 pub fn n_initial_points(mut self, n: usize) -> Self {
67 self.n_initial_points = n;
68 self
69 }
70
71 pub fn optimize<F>(&mut self, objective: F) -> (Configuration, f32)
73 where
74 F: Fn(&Configuration) -> f32,
75 {
76 let mut rng = thread_rng();
77
78 for _ in 0..self.n_initial_points {
80 let config = self.sample_random(&mut rng);
81 let score = objective(&config);
82 self.observations.push((config, score));
83 }
84
85 for _ in 0..self.n_iterations {
87 let next_config = self.suggest_next();
88 let score = objective(&next_config);
89 self.observations.push((next_config, score));
90 }
91
92 self.observations
94 .iter()
95 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
96 .unwrap()
97 .clone()
98 }
99
100 fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
101 let mut config = HashMap::new();
102
103 for (name, space) in &self.parameter_space {
104 let value = match space {
105 ParameterSpace::Continuous { min, max, log_scale } => {
106 let val = if *log_scale {
107 let log_min = min.ln();
108 let log_max = max.ln();
109 (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
110 } else {
111 rng.gen::<f32>() * (max - min) + min
112 };
113 ParameterValue::Float(val)
114 }
115 ParameterSpace::Integer { min, max } => {
116 let val = rng.gen_range(*min..=*max);
117 ParameterValue::Int(val)
118 }
119 ParameterSpace::Categorical { choices } => {
120 let idx = rng.gen_range(0..choices.len());
121 ParameterValue::String(choices[idx].clone())
122 }
123 };
124 config.insert(name.clone(), value);
125 }
126
127 config
128 }
129
130 fn suggest_next(&self) -> Configuration {
131 let mut rng = thread_rng();
132 let mut best_config = self.sample_random(&mut rng);
133 let mut best_acquisition = f32::NEG_INFINITY;
134
135 for _ in 0..100 {
137 let config = self.sample_random(&mut rng);
138 let acquisition = self.evaluate_acquisition(&config);
139
140 if acquisition > best_acquisition {
141 best_acquisition = acquisition;
142 best_config = config;
143 }
144 }
145
146 best_config
147 }
148
149 fn evaluate_acquisition(&self, config: &Configuration) -> f32 {
150 let (mean, std) = self.predict_gp(config);
152
153 match self.acquisition_function {
154 AcquisitionFunction::ExpectedImprovement => {
155 let best_y = self.observations.iter()
156 .map(|(_, y)| *y)
157 .max_by(|a, b| a.partial_cmp(b).unwrap())
158 .unwrap_or(0.0);
159
160 let z = (mean - best_y - self.xi) / (std + 1e-9);
161 let ei = (mean - best_y - self.xi) * self.normal_cdf(z) + std * self.normal_pdf(z);
162 ei
163 }
164 AcquisitionFunction::ProbabilityOfImprovement => {
165 let best_y = self.observations.iter()
166 .map(|(_, y)| *y)
167 .max_by(|a, b| a.partial_cmp(b).unwrap())
168 .unwrap_or(0.0);
169
170 let z = (mean - best_y - self.xi) / (std + 1e-9);
171 self.normal_cdf(z)
172 }
173 AcquisitionFunction::UpperConfidenceBound => {
174 mean + self.kappa * std
175 }
176 }
177 }
178
179 fn predict_gp(&self, _config: &Configuration) -> (f32, f32) {
180 if self.observations.is_empty() {
184 return (0.0, 1.0);
185 }
186
187 let mean: f32 = self.observations.iter().map(|(_, y)| y).sum::<f32>() / self.observations.len() as f32;
189 let variance: f32 = self.observations.iter()
190 .map(|(_, y)| (y - mean).powi(2))
191 .sum::<f32>() / self.observations.len() as f32;
192 let std = variance.sqrt();
193
194 (mean, std.max(0.1))
195 }
196
197 fn normal_cdf(&self, x: f32) -> f32 {
198 0.5 * (1.0 + self.erf(x / 2.0_f32.sqrt()))
199 }
200
201 fn normal_pdf(&self, x: f32) -> f32 {
202 (-0.5 * x * x).exp() / (2.0 * std::f32::consts::PI).sqrt()
203 }
204
205 fn erf(&self, x: f32) -> f32 {
206 let a1 = 0.254829592;
208 let a2 = -0.284496736;
209 let a3 = 1.421413741;
210 let a4 = -1.453152027;
211 let a5 = 1.061405429;
212 let p = 0.3275911;
213
214 let sign = if x < 0.0 { -1.0 } else { 1.0 };
215 let x = x.abs();
216
217 let t = 1.0 / (1.0 + p * x);
218 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
219
220 sign * y
221 }
222}
223
/// Uniform random-sampling baseline optimizer.
pub struct RandomSearch {
    /// Number of random configurations to evaluate.
    pub n_iterations: usize,
    /// Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
}
231
232impl RandomSearch {
233 pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
234 Self {
235 n_iterations: 100,
236 parameter_space,
237 }
238 }
239
240 pub fn n_iterations(mut self, n: usize) -> Self {
241 self.n_iterations = n;
242 self
243 }
244
245 pub fn optimize<F>(&self, objective: F) -> (Configuration, f32)
246 where
247 F: Fn(&Configuration) -> f32,
248 {
249 let mut rng = thread_rng();
250 let mut best_config = self.sample_random(&mut rng);
251 let mut best_score = objective(&best_config);
252
253 for _ in 1..self.n_iterations {
254 let config = self.sample_random(&mut rng);
255 let score = objective(&config);
256
257 if score > best_score {
258 best_score = score;
259 best_config = config;
260 }
261 }
262
263 (best_config, best_score)
264 }
265
266 fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
267 let mut config = HashMap::new();
268
269 for (name, space) in &self.parameter_space {
270 let value = match space {
271 ParameterSpace::Continuous { min, max, log_scale } => {
272 let val = if *log_scale {
273 let log_min = min.ln();
274 let log_max = max.ln();
275 (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
276 } else {
277 rng.gen::<f32>() * (max - min) + min
278 };
279 ParameterValue::Float(val)
280 }
281 ParameterSpace::Integer { min, max } => {
282 let val = rng.gen_range(*min..=*max);
283 ParameterValue::Int(val)
284 }
285 ParameterSpace::Categorical { choices } => {
286 let idx = rng.gen_range(0..choices.len());
287 ParameterValue::String(choices[idx].clone())
288 }
289 };
290 config.insert(name.clone(), value);
291 }
292
293 config
294 }
295}
296
/// Exhaustive search over the Cartesian product of explicit value lists.
pub struct GridSearch {
    /// Candidate values per parameter name; the search space is their
    /// Cartesian product.
    parameter_grid: HashMap<String, Vec<ParameterValue>>,
}
303
304impl GridSearch {
305 pub fn new(parameter_grid: HashMap<String, Vec<ParameterValue>>) -> Self {
306 Self { parameter_grid }
307 }
308
309 pub fn optimize<F>(&self, objective: F) -> (Configuration, f32)
310 where
311 F: Fn(&Configuration) -> f32,
312 {
313 let configurations = self.generate_configurations();
314
315 let mut best_config = configurations[0].clone();
316 let mut best_score = objective(&best_config);
317
318 for config in configurations.iter().skip(1) {
319 let score = objective(config);
320 if score > best_score {
321 best_score = score;
322 best_config = config.clone();
323 }
324 }
325
326 (best_config, best_score)
327 }
328
329 fn generate_configurations(&self) -> Vec<Configuration> {
330 let mut configurations = vec![HashMap::new()];
331
332 for (name, values) in &self.parameter_grid {
333 let mut new_configurations = Vec::new();
334
335 for config in &configurations {
336 for value in values {
337 let mut new_config = config.clone();
338 new_config.insert(name.clone(), value.clone());
339 new_configurations.push(new_config);
340 }
341 }
342
343 configurations = new_configurations;
344 }
345
346 configurations
347 }
348}
349
/// Hyperband: bandit-based optimizer that runs several brackets of
/// successive halving with different exploration/budget trade-offs.
pub struct Hyperband {
    /// Maximum budget a single configuration may receive.
    pub max_iter: usize,
    /// Downsampling factor between successive-halving rungs.
    pub eta: usize,
    /// Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
}
359
360impl Hyperband {
361 pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
362 Self {
363 max_iter: 81, eta: 3, parameter_space,
366 }
367 }
368
369 pub fn max_iter(mut self, max_iter: usize) -> Self {
370 self.max_iter = max_iter;
371 self
372 }
373
374 pub fn eta(mut self, eta: usize) -> Self {
375 self.eta = eta;
376 self
377 }
378
379 pub fn optimize<F>(&self, objective: F) -> (Configuration, f32)
383 where
384 F: Fn(&Configuration, usize) -> f32,
385 {
386 let mut rng = thread_rng();
387 let s_max = (self.max_iter as f32).log(self.eta as f32).floor() as usize;
388 let b = (s_max + 1) * self.max_iter;
389
390 let mut best_config = None;
391 let mut best_score = f32::NEG_INFINITY;
392
393 for s in (0..=s_max).rev() {
395 let n = ((b as f32 / self.max_iter as f32 / (s + 1) as f32) * (self.eta as f32).powi(s as i32)).ceil() as usize;
396 let r = self.max_iter * (self.eta as f32).powi(-(s as i32)) as usize;
397
398 let mut configs: Vec<(Configuration, f32)> = (0..n)
400 .map(|_| {
401 let config = self.sample_random(&mut rng);
402 let score = objective(&config, r);
403 (config, score)
404 })
405 .collect();
406
407 for i in 0..=s {
409 let n_i = (n as f32 * (self.eta as f32).powi(-(i as i32))).floor() as usize;
410 let r_i = r * (self.eta as f32).powi(i as i32) as usize;
411
412 for (config, score) in configs.iter_mut() {
414 *score = objective(config, r_i);
415 }
416
417 configs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
419 let keep = (n_i as f32 / self.eta as f32).ceil() as usize;
420 configs.truncate(keep.min(configs.len()));
421 }
422
423 if let Some((config, score)) = configs.first() {
425 if *score > best_score {
426 best_score = *score;
427 best_config = Some(config.clone());
428 }
429 }
430 }
431
432 (best_config.unwrap(), best_score)
433 }
434
435 fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
436 let mut config = HashMap::new();
437
438 for (name, space) in &self.parameter_space {
439 let value = match space {
440 ParameterSpace::Continuous { min, max, log_scale } => {
441 let val = if *log_scale {
442 let log_min = min.ln();
443 let log_max = max.ln();
444 (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
445 } else {
446 rng.gen::<f32>() * (max - min) + min
447 };
448 ParameterValue::Float(val)
449 }
450 ParameterSpace::Integer { min, max } => {
451 let val = rng.gen_range(*min..=*max);
452 ParameterValue::Int(val)
453 }
454 ParameterSpace::Categorical { choices } => {
455 let idx = rng.gen_range(0..choices.len());
456 ParameterValue::String(choices[idx].clone())
457 }
458 };
459 config.insert(name.clone(), value);
460 }
461
462 config
463 }
464}
465
/// BOHB: Hyperband bracket scheduling combined with a TPE-style model that
/// proposes new configurations near previously good observations.
pub struct BOHB {
    /// Maximum budget a single configuration may receive.
    pub max_iter: usize,
    /// Downsampling factor between successive-halving rungs.
    pub eta: usize,
    /// Minimum observations required before model-based sampling kicks in.
    pub min_points_in_model: usize,
    /// Percentage of best observations treated as the "good" set.
    pub top_n_percent: usize,
    /// Divisor of the parameter range used as perturbation bandwidth.
    pub bandwidth_factor: f32,
    /// Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
    /// Every (configuration, budget, score) triple evaluated so far.
    observations: Vec<(Configuration, usize, f32)>, }
479
480impl BOHB {
481 pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
482 Self {
483 max_iter: 81,
484 eta: 3,
485 min_points_in_model: 10,
486 top_n_percent: 15,
487 bandwidth_factor: 3.0,
488 parameter_space,
489 observations: Vec::new(),
490 }
491 }
492
493 pub fn max_iter(mut self, max_iter: usize) -> Self {
494 self.max_iter = max_iter;
495 self
496 }
497
498 pub fn eta(mut self, eta: usize) -> Self {
499 self.eta = eta;
500 self
501 }
502
503 pub fn optimize<F>(&mut self, objective: F) -> (Configuration, f32)
505 where
506 F: Fn(&Configuration, usize) -> f32,
507 {
508 let mut rng = thread_rng();
509 let s_max = (self.max_iter as f32).log(self.eta as f32).floor() as usize;
510 let b = (s_max + 1) * self.max_iter;
511
512 let mut best_config = None;
513 let mut best_score = f32::NEG_INFINITY;
514
515 for s in (0..=s_max).rev() {
516 let n = ((b as f32 / self.max_iter as f32 / (s + 1) as f32) * (self.eta as f32).powi(s as i32)).ceil() as usize;
517 let r = self.max_iter * (self.eta as f32).powi(-(s as i32)) as usize;
518
519 let mut configs: Vec<(Configuration, f32)> = (0..n)
521 .map(|_| {
522 let config = if self.observations.len() >= self.min_points_in_model {
523 self.sample_tpe(&mut rng)
524 } else {
525 self.sample_random(&mut rng)
526 };
527 let score = objective(&config, r);
528 self.observations.push((config.clone(), r, score));
529 (config, score)
530 })
531 .collect();
532
533 for i in 0..=s {
535 let n_i = (n as f32 * (self.eta as f32).powi(-(i as i32))).floor() as usize;
536 let r_i = r * (self.eta as f32).powi(i as i32) as usize;
537
538 for (config, score) in configs.iter_mut() {
539 *score = objective(config, r_i);
540 self.observations.push((config.clone(), r_i, *score));
541 }
542
543 configs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
544 let keep = (n_i as f32 / self.eta as f32).ceil() as usize;
545 configs.truncate(keep.min(configs.len()));
546 }
547
548 if let Some((config, score)) = configs.first() {
549 if *score > best_score {
550 best_score = *score;
551 best_config = Some(config.clone());
552 }
553 }
554 }
555
556 (best_config.unwrap(), best_score)
557 }
558
559 fn sample_tpe(&self, rng: &mut ThreadRng) -> Configuration {
560 let mut sorted_obs = self.observations.clone();
563 sorted_obs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
564
565 let split_idx = (sorted_obs.len() * self.top_n_percent / 100).max(1);
566 let good_obs: Vec<_> = sorted_obs.iter().take(split_idx).collect();
567 let bad_obs: Vec<_> = sorted_obs.iter().skip(split_idx).collect();
568
569 let mut config = HashMap::new();
571
572 for (name, space) in &self.parameter_space {
573 let value = match space {
574 ParameterSpace::Continuous { min, max, log_scale } => {
575 let good_values: Vec<f32> = good_obs
577 .iter()
578 .filter_map(|(c, _, _)| {
579 if let Some(ParameterValue::Float(v)) = c.get(name) {
580 Some(*v)
581 } else {
582 None
583 }
584 })
585 .collect();
586
587 let val = if !good_values.is_empty() {
588 let idx = rng.gen_range(0..good_values.len());
590 let base = good_values[idx];
591 let bandwidth = (max - min) / self.bandwidth_factor;
592 let noise = rng.gen::<f32>() * bandwidth - bandwidth / 2.0;
593 (base + noise).clamp(*min, *max)
594 } else {
595 if *log_scale {
597 let log_min = min.ln();
598 let log_max = max.ln();
599 (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
600 } else {
601 rng.gen::<f32>() * (max - min) + min
602 }
603 };
604 ParameterValue::Float(val)
605 }
606 ParameterSpace::Integer { min, max } => {
607 let good_values: Vec<i32> = good_obs
608 .iter()
609 .filter_map(|(c, _, _)| {
610 if let Some(ParameterValue::Int(v)) = c.get(name) {
611 Some(*v)
612 } else {
613 None
614 }
615 })
616 .collect();
617
618 let val = if !good_values.is_empty() {
619 let idx = rng.gen_range(0..good_values.len());
620 good_values[idx]
621 } else {
622 rng.gen_range(*min..=*max)
623 };
624 ParameterValue::Int(val)
625 }
626 ParameterSpace::Categorical { choices } => {
627 let good_values: Vec<String> = good_obs
628 .iter()
629 .filter_map(|(c, _, _)| {
630 if let Some(ParameterValue::String(v)) = c.get(name) {
631 Some(v.clone())
632 } else {
633 None
634 }
635 })
636 .collect();
637
638 let val = if !good_values.is_empty() {
639 let idx = rng.gen_range(0..good_values.len());
640 good_values[idx].clone()
641 } else {
642 let idx = rng.gen_range(0..choices.len());
643 choices[idx].clone()
644 };
645 ParameterValue::String(val)
646 }
647 };
648 config.insert(name.clone(), value);
649 }
650
651 config
652 }
653
654 fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
655 let mut config = HashMap::new();
656
657 for (name, space) in &self.parameter_space {
658 let value = match space {
659 ParameterSpace::Continuous { min, max, log_scale } => {
660 let val = if *log_scale {
661 let log_min = min.ln();
662 let log_max = max.ln();
663 (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
664 } else {
665 rng.gen::<f32>() * (max - min) + min
666 };
667 ParameterValue::Float(val)
668 }
669 ParameterSpace::Integer { min, max } => {
670 let val = rng.gen_range(*min..=*max);
671 ParameterValue::Int(val)
672 }
673 ParameterSpace::Categorical { choices } => {
674 let idx = rng.gen_range(0..choices.len());
675 ParameterValue::String(choices[idx].clone())
676 }
677 };
678 config.insert(name.clone(), value);
679 }
680
681 config
682 }
683}
684
#[cfg(test)]
mod tests {
    use super::*;

    /// Random search over a small mixed space: the winner must score
    /// positively and carry the declared parameter.
    #[test]
    fn test_random_search() {
        let mut space = HashMap::new();
        space.insert(
            "learning_rate".to_string(),
            ParameterSpace::Continuous { min: 0.001, max: 0.1, log_scale: true },
        );
        space.insert(
            "n_estimators".to_string(),
            ParameterSpace::Integer { min: 10, max: 100 },
        );

        let searcher = RandomSearch::new(space).n_iterations(10);

        let (winner, top_score) = searcher.optimize(|cfg| match cfg.get("learning_rate") {
            Some(ParameterValue::Float(lr)) => *lr * 10.0,
            _ => 0.0,
        });

        assert!(top_score > 0.0);
        assert!(winner.contains_key("learning_rate"));
    }

    /// Grid search over a 2x2 grid: the result must contain both keys.
    #[test]
    fn test_grid_search() {
        let mut grid = HashMap::new();
        grid.insert(
            "param1".to_string(),
            vec![ParameterValue::Float(0.1), ParameterValue::Float(0.2)],
        );
        grid.insert(
            "param2".to_string(),
            vec![ParameterValue::Int(10), ParameterValue::Int(20)],
        );

        let searcher = GridSearch::new(grid);

        let (winner, _) = searcher.optimize(|cfg| {
            match (cfg.get("param1"), cfg.get("param2")) {
                (Some(ParameterValue::Float(a)), Some(ParameterValue::Int(b))) => {
                    a * (*b as f32)
                }
                _ => 0.0,
            }
        });

        assert!(winner.contains_key("param1"));
        assert!(winner.contains_key("param2"));
    }

    /// Hyperband with a budget-aware synthetic objective.
    #[test]
    fn test_hyperband() {
        let mut space = HashMap::new();
        space.insert(
            "learning_rate".to_string(),
            ParameterSpace::Continuous { min: 0.001, max: 0.1, log_scale: true },
        );
        space.insert(
            "n_layers".to_string(),
            ParameterSpace::Integer { min: 1, max: 5 },
        );

        let searcher = Hyperband::new(space).max_iter(27).eta(3);

        let (winner, top_score) = searcher.optimize(|cfg, budget| {
            let lr = match cfg.get("learning_rate") {
                Some(ParameterValue::Float(v)) => *v,
                _ => 0.01,
            };
            let layers = match cfg.get("n_layers") {
                Some(ParameterValue::Int(v)) => *v,
                _ => 2,
            };

            // Higher lr / more layers / larger budget => higher score.
            (lr * 10.0 + layers as f32) * (budget as f32).sqrt() / 10.0
        });

        assert!(top_score > 0.0);
        assert!(winner.contains_key("learning_rate"));
        assert!(winner.contains_key("n_layers"));
    }

    /// BOHB end-to-end on a two-parameter space.
    #[test]
    fn test_bohb() {
        let mut space = HashMap::new();
        space.insert(
            "learning_rate".to_string(),
            ParameterSpace::Continuous { min: 0.001, max: 0.1, log_scale: true },
        );
        space.insert(
            "batch_size".to_string(),
            ParameterSpace::Integer { min: 16, max: 128 },
        );

        let mut searcher = BOHB::new(space).max_iter(27).eta(3);

        let (winner, top_score) = searcher.optimize(|cfg, budget| {
            let lr = match cfg.get("learning_rate") {
                Some(ParameterValue::Float(v)) => *v,
                _ => 0.01,
            };
            let batch = match cfg.get("batch_size") {
                Some(ParameterValue::Int(v)) => *v,
                _ => 32,
            };

            ((lr * 100.0).ln() + (batch as f32 / 32.0)) * (budget as f32).sqrt() / 5.0
        });

        assert!(top_score > 0.0);
        assert!(winner.contains_key("learning_rate"));
        assert!(winner.contains_key("batch_size"));
    }

    /// The TPE sampler should steer a 1-D quadratic toward its optimum.
    #[test]
    fn test_bohb_tpe_sampling() {
        let mut space = HashMap::new();
        space.insert(
            "x".to_string(),
            ParameterSpace::Continuous { min: -5.0, max: 5.0, log_scale: false },
        );

        let mut searcher = BOHB::new(space).max_iter(9).eta(3);

        let (winner, top_score) = searcher.optimize(|cfg, _budget| {
            let x = match cfg.get("x") {
                Some(ParameterValue::Float(v)) => *v,
                _ => 0.0,
            };
            // Concave quadratic with its maximum at x = 2.
            -(x - 2.0).powi(2)
        });

        if let Some(ParameterValue::Float(x)) = winner.get("x") {
            assert!((x - 2.0).abs() < 1.0, "Expected x close to 2, got {}", x);
        }
        assert!(top_score > -2.0);
    }
}
844
845
846