Skip to main content

entrenar/optim/hpo/
hyperband.rs

1//! Hyperband scheduler for efficient hyperparameter search
2//!
//! Based on Li et al. (2018), "Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization" (JMLR 18).
4
5use std::collections::HashMap;
6
7use super::types::{HyperparameterSpace, ParameterValue};
8
9/// Hyperband scheduler for efficient hyperparameter search
10///
11/// # Toyota Way: Muda (Waste Elimination)
12///
13/// Aggressive early stopping eliminates poorly performing configurations,
14/// focusing resources on promising candidates.
15#[derive(Debug, Clone)]
16pub struct HyperbandScheduler {
17    /// Maximum iterations per configuration
18    pub(crate) max_iter: usize,
19    /// Reduction factor (typically 3)
20    pub(crate) eta: f64,
21    /// Search space
22    space: HyperparameterSpace,
23}
24
25impl HyperbandScheduler {
26    /// Create a new Hyperband scheduler
27    pub fn new(space: HyperparameterSpace, max_iter: usize) -> Self {
28        Self { max_iter, eta: 3.0, space }
29    }
30
31    /// Set reduction factor
32    pub fn with_eta(mut self, eta: f64) -> Self {
33        self.eta = eta.max(2.0);
34        self
35    }
36
37    /// Get s_max (number of successive halving brackets)
38    pub fn s_max(&self) -> usize {
39        (self.max_iter as f64).log(self.eta).floor() as usize
40    }
41
42    /// Get total budget B
43    pub fn budget(&self) -> usize {
44        (self.s_max() + 1) * self.max_iter
45    }
46
47    /// Generate bracket configurations
48    ///
49    /// Returns Vec of (n_configs, n_iterations) for each rung in the bracket
50    pub fn bracket(&self, s: usize) -> Vec<(usize, usize)> {
51        let s_max = self.s_max();
52        if s > s_max {
53            return Vec::new();
54        }
55
56        let n = ((self.budget() as f64 / self.max_iter as f64)
57            * (self.eta.powi(s as i32) / (s + 1) as f64))
58            .ceil() as usize;
59        let r = self.max_iter / self.eta.powi(s as i32) as usize;
60
61        (0..=s)
62            .map(|i| {
63                let n_i = (n as f64 / self.eta.powi(i as i32)).floor() as usize;
64                let r_i = (r as f64 * self.eta.powi(i as i32)).floor() as usize;
65                (n_i.max(1), r_i.max(1))
66            })
67            .collect()
68    }
69
70    /// Generate all configurations for a bracket
71    pub fn generate_configs(&self, n: usize) -> Vec<HashMap<String, ParameterValue>> {
72        let mut rng = rand::rng();
73        (0..n).map(|_| self.space.sample_random(&mut rng)).collect()
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::hpo::types::ParameterDomain;

    /// Scheduler over an empty search space with `max_iter = 81` and the default `η = 3`.
    fn scheduler_81() -> HyperbandScheduler {
        HyperbandScheduler::new(HyperparameterSpace::new(), 81)
    }

    #[test]
    fn test_hyperband_new() {
        let scheduler = scheduler_81();
        assert_eq!(scheduler.max_iter, 81);
        assert!((scheduler.eta - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyperband_s_max() {
        // log_3(81) = 4
        assert_eq!(scheduler_81().s_max(), 4);
    }

    #[test]
    fn test_hyperband_budget() {
        // B = (s_max + 1) * max_iter = 5 * 81 = 405
        assert_eq!(scheduler_81().budget(), 405);
    }

    #[test]
    fn test_hyperband_bracket() {
        // The most aggressive bracket (s = 4) starts wide and cheap and ends
        // narrow and expensive.
        let rungs = scheduler_81().bracket(4);
        assert!(!rungs.is_empty());

        let &(configs_at_start, iters_at_start) =
            rungs.first().expect("collection should not be empty");
        let &(configs_at_end, iters_at_end) =
            rungs.last().expect("collection should not be empty");
        assert!(configs_at_start >= configs_at_end);
        assert!(iters_at_start <= iters_at_end);
    }

    #[test]
    fn test_hyperband_generate_configs() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });

        let sampled = HyperbandScheduler::new(space, 81).generate_configs(10);
        assert_eq!(sampled.len(), 10);
    }

    #[test]
    fn test_hyperband_with_eta() {
        let scheduler = scheduler_81().with_eta(4.0);
        assert!((scheduler.eta - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyperband_bracket_invalid_s() {
        // 100 is far beyond s_max = 4, so no rungs are produced.
        assert!(scheduler_81().bracket(100).is_empty());
    }
}
147
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // Every bracket index in 0..=s_max must produce at least one rung.
        #[test]
        fn prop_hyperband_bracket_nonempty(max_iter in 9usize..243, eta in 2.0f64..5.0) {
            let scheduler =
                HyperbandScheduler::new(HyperparameterSpace::new(), max_iter).with_eta(eta);
            for bracket_index in 0..=scheduler.s_max() {
                prop_assert!(!scheduler.bracket(bracket_index).is_empty());
            }
        }
    }
}
167}