entrenar/optim/hpo/
hyperband.rs1use std::collections::HashMap;
6
7use super::types::{HyperparameterSpace, ParameterValue};
8
9#[derive(Debug, Clone)]
16pub struct HyperbandScheduler {
17 pub(crate) max_iter: usize,
19 pub(crate) eta: f64,
21 space: HyperparameterSpace,
23}
24
25impl HyperbandScheduler {
26 pub fn new(space: HyperparameterSpace, max_iter: usize) -> Self {
28 Self { max_iter, eta: 3.0, space }
29 }
30
31 pub fn with_eta(mut self, eta: f64) -> Self {
33 self.eta = eta.max(2.0);
34 self
35 }
36
37 pub fn s_max(&self) -> usize {
39 (self.max_iter as f64).log(self.eta).floor() as usize
40 }
41
42 pub fn budget(&self) -> usize {
44 (self.s_max() + 1) * self.max_iter
45 }
46
47 pub fn bracket(&self, s: usize) -> Vec<(usize, usize)> {
51 let s_max = self.s_max();
52 if s > s_max {
53 return Vec::new();
54 }
55
56 let n = ((self.budget() as f64 / self.max_iter as f64)
57 * (self.eta.powi(s as i32) / (s + 1) as f64))
58 .ceil() as usize;
59 let r = self.max_iter / self.eta.powi(s as i32) as usize;
60
61 (0..=s)
62 .map(|i| {
63 let n_i = (n as f64 / self.eta.powi(i as i32)).floor() as usize;
64 let r_i = (r as f64 * self.eta.powi(i as i32)).floor() as usize;
65 (n_i.max(1), r_i.max(1))
66 })
67 .collect()
68 }
69
70 pub fn generate_configs(&self, n: usize) -> Vec<HashMap<String, ParameterValue>> {
72 let mut rng = rand::rng();
73 (0..n).map(|_| self.space.sample_random(&mut rng)).collect()
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::hpo::types::ParameterDomain;

    /// Builds a scheduler over an empty search space with `max_iter = 81`.
    fn scheduler_81() -> HyperbandScheduler {
        HyperbandScheduler::new(HyperparameterSpace::new(), 81)
    }

    #[test]
    fn test_hyperband_new() {
        let hb = scheduler_81();
        assert_eq!(hb.max_iter, 81);
        // `new` installs the default downsampling rate.
        assert!((hb.eta - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyperband_s_max() {
        // 81 = 3^4
        assert_eq!(scheduler_81().s_max(), 4);
    }

    #[test]
    fn test_hyperband_budget() {
        // (s_max + 1) * max_iter = 5 * 81
        assert_eq!(scheduler_81().budget(), 405);
    }

    #[test]
    fn test_hyperband_bracket() {
        let rungs = scheduler_81().bracket(4);
        assert!(!rungs.is_empty());

        let &(n_first, r_first) = rungs.first().expect("collection should not be empty");
        let &(n_last, r_last) = rungs.last().expect("collection should not be empty");
        // Later rungs keep fewer configurations but give each one more resource.
        assert!(n_first >= n_last);
        assert!(r_first <= r_last);
    }

    #[test]
    fn test_hyperband_generate_configs() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });

        let configs = HyperbandScheduler::new(space, 81).generate_configs(10);
        assert_eq!(configs.len(), 10);
    }

    #[test]
    fn test_hyperband_with_eta() {
        let hb = scheduler_81().with_eta(4.0);
        assert!((hb.eta - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyperband_bracket_invalid_s() {
        // s = 100 is far beyond s_max, so no rungs are produced.
        assert!(scheduler_81().bracket(100).is_empty());
    }
}
147
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // Every bracket index in 0..=s_max yields at least one rung.
        #[test]
        fn prop_hyperband_bracket_nonempty(max_iter in 9usize..243, eta in 2.0f64..5.0) {
            let space = HyperparameterSpace::new();
            let hb = HyperbandScheduler::new(space, max_iter).with_eta(eta);
            let s_max = hb.s_max();
            for s in 0..=s_max {
                let bracket = hb.bracket(s);
                prop_assert!(!bracket.is_empty());
            }
        }

        // Bracket `s` has exactly `s + 1` rungs, all counts are at least 1, the
        // number of configurations never grows from one rung to the next and the
        // per-configuration resource never shrinks.
        #[test]
        fn prop_hyperband_bracket_shape(max_iter in 1usize..1000, eta in 2.0f64..5.0) {
            let hb = HyperbandScheduler::new(HyperparameterSpace::new(), max_iter).with_eta(eta);
            for s in 0..=hb.s_max() {
                let bracket = hb.bracket(s);
                prop_assert_eq!(bracket.len(), s + 1);
                prop_assert!(bracket.iter().all(|&(n, r)| n >= 1 && r >= 1));
                for pair in bracket.windows(2) {
                    let ((n_prev, r_prev), (n_next, r_next)) = (pair[0], pair[1]);
                    prop_assert!(n_next <= n_prev);
                    prop_assert!(r_next >= r_prev);
                }
            }
        }
    }
}