// ghostflow_ml/hyperparameter_optimization.rs

//! Hyperparameter Optimization
//!
//! Advanced algorithms for finding optimal hyperparameters.

use rand::prelude::*;
use std::collections::HashMap;
7
/// Parameter space definition
///
/// Describes the domain a single hyperparameter can be sampled from.
#[derive(Clone, Debug)]
pub enum ParameterSpace {
    /// Real-valued range [min, max]. With `log_scale`, values are sampled
    /// uniformly in log space (requires min and max > 0, since `ln` is taken).
    Continuous { min: f32, max: f32, log_scale: bool },
    /// Integer range, inclusive on both ends.
    Integer { min: i32, max: i32 },
    /// Finite set of named choices, sampled uniformly.
    Categorical { choices: Vec<String> },
}
15
/// Hyperparameter configuration: maps parameter name -> sampled value.
pub type Configuration = HashMap<String, ParameterValue>;
18
/// A single sampled hyperparameter value, tagged by parameter kind.
#[derive(Clone, Debug)]
pub enum ParameterValue {
    /// Value drawn from a `Continuous` space.
    Float(f32),
    /// Value drawn from an `Integer` space.
    Int(i32),
    /// Value drawn from a `Categorical` space.
    String(String),
}
25
/// Bayesian Optimization using Gaussian Process
/// 
/// Efficiently searches hyperparameter space by building a probabilistic model
/// of the objective function.
pub struct BayesianOptimization {
    /// Number of model-guided iterations after the random warm-up phase.
    pub n_iterations: usize,
    /// Number of purely random warm-up evaluations.
    pub n_initial_points: usize,
    /// Strategy used to score candidate configurations.
    pub acquisition_function: AcquisitionFunction,
    pub xi: f32,  // Exploration-exploitation trade-off
    pub kappa: f32,  // For UCB acquisition
    
    // Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
    // All (configuration, score) pairs evaluated so far.
    observations: Vec<(Configuration, f32)>,
}
40
/// Acquisition function used to score candidate configurations.
///
/// Public enums should be debuggable and comparable, so `Debug`, `PartialEq`
/// and `Eq` are derived in addition to the original `Clone, Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AcquisitionFunction {
    /// EI: expected amount by which a candidate improves on the best score.
    ExpectedImprovement,
    /// PI: probability that a candidate improves on the best score.
    ProbabilityOfImprovement,
    /// UCB: mean + kappa * std; optimistic under uncertainty.
    UpperConfidenceBound,
}
47
48impl BayesianOptimization {
49    pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
50        Self {
51            n_iterations: 50,
52            n_initial_points: 10,
53            acquisition_function: AcquisitionFunction::ExpectedImprovement,
54            xi: 0.01,
55            kappa: 2.576,
56            parameter_space,
57            observations: Vec::new(),
58        }
59    }
60
61    pub fn n_iterations(mut self, n: usize) -> Self {
62        self.n_iterations = n;
63        self
64    }
65
66    pub fn n_initial_points(mut self, n: usize) -> Self {
67        self.n_initial_points = n;
68        self
69    }
70
71    /// Optimize a black-box function
72    pub fn optimize<F>(&mut self, objective: F) -> (Configuration, f32)
73    where
74        F: Fn(&Configuration) -> f32,
75    {
76        let mut rng = thread_rng();
77
78        // Initial random sampling
79        for _ in 0..self.n_initial_points {
80            let config = self.sample_random(&mut rng);
81            let score = objective(&config);
82            self.observations.push((config, score));
83        }
84
85        // Bayesian optimization loop
86        for _ in 0..self.n_iterations {
87            let next_config = self.suggest_next();
88            let score = objective(&next_config);
89            self.observations.push((next_config, score));
90        }
91
92        // Return best configuration
93        self.observations
94            .iter()
95            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
96            .unwrap()
97            .clone()
98    }
99
100    fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
101        let mut config = HashMap::new();
102
103        for (name, space) in &self.parameter_space {
104            let value = match space {
105                ParameterSpace::Continuous { min, max, log_scale } => {
106                    let val = if *log_scale {
107                        let log_min = min.ln();
108                        let log_max = max.ln();
109                        (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
110                    } else {
111                        rng.gen::<f32>() * (max - min) + min
112                    };
113                    ParameterValue::Float(val)
114                }
115                ParameterSpace::Integer { min, max } => {
116                    let val = rng.gen_range(*min..=*max);
117                    ParameterValue::Int(val)
118                }
119                ParameterSpace::Categorical { choices } => {
120                    let idx = rng.gen_range(0..choices.len());
121                    ParameterValue::String(choices[idx].clone())
122                }
123            };
124            config.insert(name.clone(), value);
125        }
126
127        config
128    }
129
130    fn suggest_next(&self) -> Configuration {
131        let mut rng = thread_rng();
132        let mut best_config = self.sample_random(&mut rng);
133        let mut best_acquisition = f32::NEG_INFINITY;
134
135        // Sample candidates and evaluate acquisition function
136        for _ in 0..100 {
137            let config = self.sample_random(&mut rng);
138            let acquisition = self.evaluate_acquisition(&config);
139
140            if acquisition > best_acquisition {
141                best_acquisition = acquisition;
142                best_config = config;
143            }
144        }
145
146        best_config
147    }
148
149    fn evaluate_acquisition(&self, config: &Configuration) -> f32 {
150        // Simplified acquisition function (in practice, would use GP)
151        let (mean, std) = self.predict_gp(config);
152        
153        match self.acquisition_function {
154            AcquisitionFunction::ExpectedImprovement => {
155                let best_y = self.observations.iter()
156                    .map(|(_, y)| *y)
157                    .max_by(|a, b| a.partial_cmp(b).unwrap())
158                    .unwrap_or(0.0);
159                
160                let z = (mean - best_y - self.xi) / (std + 1e-9);
161                let ei = (mean - best_y - self.xi) * self.normal_cdf(z) + std * self.normal_pdf(z);
162                ei
163            }
164            AcquisitionFunction::ProbabilityOfImprovement => {
165                let best_y = self.observations.iter()
166                    .map(|(_, y)| *y)
167                    .max_by(|a, b| a.partial_cmp(b).unwrap())
168                    .unwrap_or(0.0);
169                
170                let z = (mean - best_y - self.xi) / (std + 1e-9);
171                self.normal_cdf(z)
172            }
173            AcquisitionFunction::UpperConfidenceBound => {
174                mean + self.kappa * std
175            }
176        }
177    }
178
179    fn predict_gp(&self, _config: &Configuration) -> (f32, f32) {
180        // Simplified GP prediction (in practice, would use proper GP)
181        // Returns (mean, std)
182        
183        if self.observations.is_empty() {
184            return (0.0, 1.0);
185        }
186
187        // Simple average as mean, std based on variance
188        let mean: f32 = self.observations.iter().map(|(_, y)| y).sum::<f32>() / self.observations.len() as f32;
189        let variance: f32 = self.observations.iter()
190            .map(|(_, y)| (y - mean).powi(2))
191            .sum::<f32>() / self.observations.len() as f32;
192        let std = variance.sqrt();
193
194        (mean, std.max(0.1))
195    }
196
197    fn normal_cdf(&self, x: f32) -> f32 {
198        0.5 * (1.0 + self.erf(x / 2.0_f32.sqrt()))
199    }
200
201    fn normal_pdf(&self, x: f32) -> f32 {
202        (-0.5 * x * x).exp() / (2.0 * std::f32::consts::PI).sqrt()
203    }
204
205    fn erf(&self, x: f32) -> f32 {
206        // Approximation of error function
207        let a1 = 0.254829592;
208        let a2 = -0.284496736;
209        let a3 = 1.421413741;
210        let a4 = -1.453152027;
211        let a5 = 1.061405429;
212        let p = 0.3275911;
213
214        let sign = if x < 0.0 { -1.0 } else { 1.0 };
215        let x = x.abs();
216
217        let t = 1.0 / (1.0 + p * x);
218        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
219
220        sign * y
221    }
222}
223
/// Random Search
/// 
/// Simple but effective baseline for hyperparameter optimization.
pub struct RandomSearch {
    /// Number of configurations to sample and evaluate.
    pub n_iterations: usize,
    // Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
}
231
232impl RandomSearch {
233    pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
234        Self {
235            n_iterations: 100,
236            parameter_space,
237        }
238    }
239
240    pub fn n_iterations(mut self, n: usize) -> Self {
241        self.n_iterations = n;
242        self
243    }
244
245    pub fn optimize<F>(&self, objective: F) -> (Configuration, f32)
246    where
247        F: Fn(&Configuration) -> f32,
248    {
249        let mut rng = thread_rng();
250        let mut best_config = self.sample_random(&mut rng);
251        let mut best_score = objective(&best_config);
252
253        for _ in 1..self.n_iterations {
254            let config = self.sample_random(&mut rng);
255            let score = objective(&config);
256
257            if score > best_score {
258                best_score = score;
259                best_config = config;
260            }
261        }
262
263        (best_config, best_score)
264    }
265
266    fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
267        let mut config = HashMap::new();
268
269        for (name, space) in &self.parameter_space {
270            let value = match space {
271                ParameterSpace::Continuous { min, max, log_scale } => {
272                    let val = if *log_scale {
273                        let log_min = min.ln();
274                        let log_max = max.ln();
275                        (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
276                    } else {
277                        rng.gen::<f32>() * (max - min) + min
278                    };
279                    ParameterValue::Float(val)
280                }
281                ParameterSpace::Integer { min, max } => {
282                    let val = rng.gen_range(*min..=*max);
283                    ParameterValue::Int(val)
284                }
285                ParameterSpace::Categorical { choices } => {
286                    let idx = rng.gen_range(0..choices.len());
287                    ParameterValue::String(choices[idx].clone())
288                }
289            };
290            config.insert(name.clone(), value);
291        }
292
293        config
294    }
295}
296
/// Grid Search
/// 
/// Exhaustive search over specified parameter values.
pub struct GridSearch {
    // Explicit candidate values per parameter; the search space is their
    // Cartesian product.
    parameter_grid: HashMap<String, Vec<ParameterValue>>,
}
303
304impl GridSearch {
305    pub fn new(parameter_grid: HashMap<String, Vec<ParameterValue>>) -> Self {
306        Self { parameter_grid }
307    }
308
309    pub fn optimize<F>(&self, objective: F) -> (Configuration, f32)
310    where
311        F: Fn(&Configuration) -> f32,
312    {
313        let configurations = self.generate_configurations();
314        
315        let mut best_config = configurations[0].clone();
316        let mut best_score = objective(&best_config);
317
318        for config in configurations.iter().skip(1) {
319            let score = objective(config);
320            if score > best_score {
321                best_score = score;
322                best_config = config.clone();
323            }
324        }
325
326        (best_config, best_score)
327    }
328
329    fn generate_configurations(&self) -> Vec<Configuration> {
330        let mut configurations = vec![HashMap::new()];
331
332        for (name, values) in &self.parameter_grid {
333            let mut new_configurations = Vec::new();
334
335            for config in &configurations {
336                for value in values {
337                    let mut new_config = config.clone();
338                    new_config.insert(name.clone(), value.clone());
339                    new_configurations.push(new_config);
340                }
341            }
342
343            configurations = new_configurations;
344        }
345
346        configurations
347    }
348}
349
/// Hyperband
/// 
/// Adaptive resource allocation and early-stopping algorithm.
/// Efficiently allocates resources to promising configurations.
pub struct Hyperband {
    /// Maximum budget (e.g. training iterations) for a single configuration.
    pub max_iter: usize,
    /// Downsampling rate between successive-halving rungs (typically 3).
    pub eta: usize,
    // Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
}
359
360impl Hyperband {
361    pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
362        Self {
363            max_iter: 81,  // Maximum iterations per configuration
364            eta: 3,        // Downsampling rate
365            parameter_space,
366        }
367    }
368
369    pub fn max_iter(mut self, max_iter: usize) -> Self {
370        self.max_iter = max_iter;
371        self
372    }
373
374    pub fn eta(mut self, eta: usize) -> Self {
375        self.eta = eta;
376        self
377    }
378
379    /// Optimize with early stopping
380    /// 
381    /// The objective function receives (config, budget) and returns score
382    pub fn optimize<F>(&self, objective: F) -> (Configuration, f32)
383    where
384        F: Fn(&Configuration, usize) -> f32,
385    {
386        let mut rng = thread_rng();
387        let s_max = (self.max_iter as f32).log(self.eta as f32).floor() as usize;
388        let b = (s_max + 1) * self.max_iter;
389
390        let mut best_config = None;
391        let mut best_score = f32::NEG_INFINITY;
392
393        // Successive halving with different resource allocations
394        for s in (0..=s_max).rev() {
395            let n = ((b as f32 / self.max_iter as f32 / (s + 1) as f32) * (self.eta as f32).powi(s as i32)).ceil() as usize;
396            let r = self.max_iter * (self.eta as f32).powi(-(s as i32)) as usize;
397
398            // Generate n random configurations
399            let mut configs: Vec<(Configuration, f32)> = (0..n)
400                .map(|_| {
401                    let config = self.sample_random(&mut rng);
402                    let score = objective(&config, r);
403                    (config, score)
404                })
405                .collect();
406
407            // Successive halving
408            for i in 0..=s {
409                let n_i = (n as f32 * (self.eta as f32).powi(-(i as i32))).floor() as usize;
410                let r_i = r * (self.eta as f32).powi(i as i32) as usize;
411
412                // Evaluate all configurations with budget r_i
413                for (config, score) in configs.iter_mut() {
414                    *score = objective(config, r_i);
415                }
416
417                // Sort by score and keep top n_i / eta
418                configs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
419                let keep = (n_i as f32 / self.eta as f32).ceil() as usize;
420                configs.truncate(keep.min(configs.len()));
421            }
422
423            // Update best configuration
424            if let Some((config, score)) = configs.first() {
425                if *score > best_score {
426                    best_score = *score;
427                    best_config = Some(config.clone());
428                }
429            }
430        }
431
432        (best_config.unwrap(), best_score)
433    }
434
435    fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
436        let mut config = HashMap::new();
437
438        for (name, space) in &self.parameter_space {
439            let value = match space {
440                ParameterSpace::Continuous { min, max, log_scale } => {
441                    let val = if *log_scale {
442                        let log_min = min.ln();
443                        let log_max = max.ln();
444                        (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
445                    } else {
446                        rng.gen::<f32>() * (max - min) + min
447                    };
448                    ParameterValue::Float(val)
449                }
450                ParameterSpace::Integer { min, max } => {
451                    let val = rng.gen_range(*min..=*max);
452                    ParameterValue::Int(val)
453                }
454                ParameterSpace::Categorical { choices } => {
455                    let idx = rng.gen_range(0..choices.len());
456                    ParameterValue::String(choices[idx].clone())
457                }
458            };
459            config.insert(name.clone(), value);
460        }
461
462        config
463    }
464}
465
/// BOHB (Bayesian Optimization and HyperBand)
/// 
/// Combines Bayesian optimization with Hyperband's adaptive resource allocation.
/// Uses a tree-structured Parzen estimator (TPE) for configuration selection.
pub struct BOHB {
    /// Maximum budget for a single configuration.
    pub max_iter: usize,
    /// Downsampling rate between successive-halving rungs.
    pub eta: usize,
    /// Minimum number of observations before TPE replaces random sampling.
    pub min_points_in_model: usize,
    /// Percentage of best observations that form the "good" TPE density.
    pub top_n_percent: usize,
    /// Divisor of the parameter range used as the KDE jitter bandwidth.
    pub bandwidth_factor: f32,
    // Search domain, keyed by parameter name.
    parameter_space: HashMap<String, ParameterSpace>,
    observations: Vec<(Configuration, usize, f32)>,  // (config, budget, score)
}
479
480impl BOHB {
481    pub fn new(parameter_space: HashMap<String, ParameterSpace>) -> Self {
482        Self {
483            max_iter: 81,
484            eta: 3,
485            min_points_in_model: 10,
486            top_n_percent: 15,
487            bandwidth_factor: 3.0,
488            parameter_space,
489            observations: Vec::new(),
490        }
491    }
492
493    pub fn max_iter(mut self, max_iter: usize) -> Self {
494        self.max_iter = max_iter;
495        self
496    }
497
498    pub fn eta(mut self, eta: usize) -> Self {
499        self.eta = eta;
500        self
501    }
502
503    /// Optimize using BOHB
504    pub fn optimize<F>(&mut self, objective: F) -> (Configuration, f32)
505    where
506        F: Fn(&Configuration, usize) -> f32,
507    {
508        let mut rng = thread_rng();
509        let s_max = (self.max_iter as f32).log(self.eta as f32).floor() as usize;
510        let b = (s_max + 1) * self.max_iter;
511
512        let mut best_config = None;
513        let mut best_score = f32::NEG_INFINITY;
514
515        for s in (0..=s_max).rev() {
516            let n = ((b as f32 / self.max_iter as f32 / (s + 1) as f32) * (self.eta as f32).powi(s as i32)).ceil() as usize;
517            let r = self.max_iter * (self.eta as f32).powi(-(s as i32)) as usize;
518
519            // Generate configurations using TPE or random sampling
520            let mut configs: Vec<(Configuration, f32)> = (0..n)
521                .map(|_| {
522                    let config = if self.observations.len() >= self.min_points_in_model {
523                        self.sample_tpe(&mut rng)
524                    } else {
525                        self.sample_random(&mut rng)
526                    };
527                    let score = objective(&config, r);
528                    self.observations.push((config.clone(), r, score));
529                    (config, score)
530                })
531                .collect();
532
533            // Successive halving
534            for i in 0..=s {
535                let n_i = (n as f32 * (self.eta as f32).powi(-(i as i32))).floor() as usize;
536                let r_i = r * (self.eta as f32).powi(i as i32) as usize;
537
538                for (config, score) in configs.iter_mut() {
539                    *score = objective(config, r_i);
540                    self.observations.push((config.clone(), r_i, *score));
541                }
542
543                configs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
544                let keep = (n_i as f32 / self.eta as f32).ceil() as usize;
545                configs.truncate(keep.min(configs.len()));
546            }
547
548            if let Some((config, score)) = configs.first() {
549                if *score > best_score {
550                    best_score = *score;
551                    best_config = Some(config.clone());
552                }
553            }
554        }
555
556        (best_config.unwrap(), best_score)
557    }
558
559    fn sample_tpe(&self, rng: &mut ThreadRng) -> Configuration {
560        // Tree-structured Parzen Estimator sampling
561        // Split observations into good and bad based on top_n_percent
562        let mut sorted_obs = self.observations.clone();
563        sorted_obs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
564
565        let split_idx = (sorted_obs.len() * self.top_n_percent / 100).max(1);
566        let good_obs: Vec<_> = sorted_obs.iter().take(split_idx).collect();
567        let bad_obs: Vec<_> = sorted_obs.iter().skip(split_idx).collect();
568
569        // Sample from good distribution
570        let mut config = HashMap::new();
571
572        for (name, space) in &self.parameter_space {
573            let value = match space {
574                ParameterSpace::Continuous { min, max, log_scale } => {
575                    // Build KDE from good observations
576                    let good_values: Vec<f32> = good_obs
577                        .iter()
578                        .filter_map(|(c, _, _)| {
579                            if let Some(ParameterValue::Float(v)) = c.get(name) {
580                                Some(*v)
581                            } else {
582                                None
583                            }
584                        })
585                        .collect();
586
587                    let val = if !good_values.is_empty() {
588                        // Sample from KDE
589                        let idx = rng.gen_range(0..good_values.len());
590                        let base = good_values[idx];
591                        let bandwidth = (max - min) / self.bandwidth_factor;
592                        let noise = rng.gen::<f32>() * bandwidth - bandwidth / 2.0;
593                        (base + noise).clamp(*min, *max)
594                    } else {
595                        // Fallback to random
596                        if *log_scale {
597                            let log_min = min.ln();
598                            let log_max = max.ln();
599                            (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
600                        } else {
601                            rng.gen::<f32>() * (max - min) + min
602                        }
603                    };
604                    ParameterValue::Float(val)
605                }
606                ParameterSpace::Integer { min, max } => {
607                    let good_values: Vec<i32> = good_obs
608                        .iter()
609                        .filter_map(|(c, _, _)| {
610                            if let Some(ParameterValue::Int(v)) = c.get(name) {
611                                Some(*v)
612                            } else {
613                                None
614                            }
615                        })
616                        .collect();
617
618                    let val = if !good_values.is_empty() {
619                        let idx = rng.gen_range(0..good_values.len());
620                        good_values[idx]
621                    } else {
622                        rng.gen_range(*min..=*max)
623                    };
624                    ParameterValue::Int(val)
625                }
626                ParameterSpace::Categorical { choices } => {
627                    let good_values: Vec<String> = good_obs
628                        .iter()
629                        .filter_map(|(c, _, _)| {
630                            if let Some(ParameterValue::String(v)) = c.get(name) {
631                                Some(v.clone())
632                            } else {
633                                None
634                            }
635                        })
636                        .collect();
637
638                    let val = if !good_values.is_empty() {
639                        let idx = rng.gen_range(0..good_values.len());
640                        good_values[idx].clone()
641                    } else {
642                        let idx = rng.gen_range(0..choices.len());
643                        choices[idx].clone()
644                    };
645                    ParameterValue::String(val)
646                }
647            };
648            config.insert(name.clone(), value);
649        }
650
651        config
652    }
653
654    fn sample_random(&self, rng: &mut ThreadRng) -> Configuration {
655        let mut config = HashMap::new();
656
657        for (name, space) in &self.parameter_space {
658            let value = match space {
659                ParameterSpace::Continuous { min, max, log_scale } => {
660                    let val = if *log_scale {
661                        let log_min = min.ln();
662                        let log_max = max.ln();
663                        (rng.gen::<f32>() * (log_max - log_min) + log_min).exp()
664                    } else {
665                        rng.gen::<f32>() * (max - min) + min
666                    };
667                    ParameterValue::Float(val)
668                }
669                ParameterSpace::Integer { min, max } => {
670                    let val = rng.gen_range(*min..=*max);
671                    ParameterValue::Int(val)
672                }
673                ParameterSpace::Categorical { choices } => {
674                    let idx = rng.gen_range(0..choices.len());
675                    ParameterValue::String(choices[idx].clone())
676                }
677            };
678            config.insert(name.clone(), value);
679        }
680
681        config
682    }
683}
684
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_search() {
        // Two-parameter space: log-scaled float plus an integer.
        let mut param_space = HashMap::new();
        param_space.insert(
            "learning_rate".to_string(),
            ParameterSpace::Continuous { min: 0.001, max: 0.1, log_scale: true },
        );
        param_space.insert(
            "n_estimators".to_string(),
            ParameterSpace::Integer { min: 10, max: 100 },
        );

        let rs = RandomSearch::new(param_space).n_iterations(10);

        let (best_config, best_score) = rs.optimize(|config| {
            // Dummy objective function
            match config.get("learning_rate") {
                Some(ParameterValue::Float(lr)) => *lr * 10.0,
                _ => 0.0,
            }
        });

        // learning_rate > 0.001 guarantees a strictly positive best score.
        assert!(best_score > 0.0);
        assert!(best_config.contains_key("learning_rate"));
    }

    #[test]
    fn test_grid_search() {
        // 2 x 2 grid -> four configurations evaluated exhaustively.
        let mut param_grid = HashMap::new();
        param_grid.insert(
            "param1".to_string(),
            vec![ParameterValue::Float(0.1), ParameterValue::Float(0.2)],
        );
        param_grid.insert(
            "param2".to_string(),
            vec![ParameterValue::Int(10), ParameterValue::Int(20)],
        );

        let gs = GridSearch::new(param_grid);

        let (best_config, _) = gs.optimize(|config| {
            match (config.get("param1"), config.get("param2")) {
                (Some(ParameterValue::Float(p1)), Some(ParameterValue::Int(p2))) => {
                    p1 * (*p2 as f32)
                }
                _ => 0.0,
            }
        });

        // The winner must carry a value for every grid dimension.
        assert!(best_config.contains_key("param1"));
        assert!(best_config.contains_key("param2"));
    }

    #[test]
    fn test_hyperband() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "learning_rate".to_string(),
            ParameterSpace::Continuous { min: 0.001, max: 0.1, log_scale: true },
        );
        param_space.insert(
            "n_layers".to_string(),
            ParameterSpace::Integer { min: 1, max: 5 },
        );

        let hb = Hyperband::new(param_space)
            .max_iter(27)
            .eta(3);

        let (best_config, best_score) = hb.optimize(|config, budget| {
            // Simulate training with budget (number of iterations)
            let lr = match config.get("learning_rate") {
                Some(ParameterValue::Float(v)) => *v,
                _ => 0.01,
            };
            let n_layers = match config.get("n_layers") {
                Some(ParameterValue::Int(v)) => *v,
                _ => 2,
            };

            // Score improves with budget and depends on hyperparameters
            let base_score = lr * 10.0 + n_layers as f32;
            base_score * (budget as f32).sqrt() / 10.0
        });

        assert!(best_score > 0.0);
        assert!(best_config.contains_key("learning_rate"));
        assert!(best_config.contains_key("n_layers"));
    }

    #[test]
    fn test_bohb() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "learning_rate".to_string(),
            ParameterSpace::Continuous { min: 0.001, max: 0.1, log_scale: true },
        );
        param_space.insert(
            "batch_size".to_string(),
            ParameterSpace::Integer { min: 16, max: 128 },
        );

        let mut bohb = BOHB::new(param_space)
            .max_iter(27)
            .eta(3);

        let (best_config, best_score) = bohb.optimize(|config, budget| {
            let lr = match config.get("learning_rate") {
                Some(ParameterValue::Float(v)) => *v,
                _ => 0.01,
            };
            let batch_size = match config.get("batch_size") {
                Some(ParameterValue::Int(v)) => *v,
                _ => 32,
            };

            // Simulate validation score
            let base_score = (lr * 100.0).ln() + (batch_size as f32 / 32.0);
            base_score * (budget as f32).sqrt() / 5.0
        });

        assert!(best_score > 0.0);
        assert!(best_config.contains_key("learning_rate"));
        assert!(best_config.contains_key("batch_size"));
    }

    #[test]
    fn test_bohb_tpe_sampling() {
        // Single continuous parameter so the TPE model is easy to verify.
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            ParameterSpace::Continuous { min: -5.0, max: 5.0, log_scale: false },
        );

        let mut bohb = BOHB::new(param_space)
            .max_iter(9)
            .eta(3);

        // Optimize a simple quadratic function
        let (best_config, best_score) = bohb.optimize(|config, _budget| {
            let x = match config.get("x") {
                Some(ParameterValue::Float(v)) => *v,
                _ => 0.0,
            };
            // Maximize -(x-2)^2, optimum at x=2
            -(x - 2.0).powi(2)
        });

        // Should find value close to 2
        if let Some(ParameterValue::Float(x)) = best_config.get("x") {
            assert!((x - 2.0).abs() < 1.0, "Expected x close to 2, got {}", x);
        }
        assert!(best_score > -2.0);
    }
}
844
845
846