use std::collections::HashMap;
use super::types::{HyperparameterSpace, ParameterValue};
/// Hyperband hyperparameter-optimisation scheduler.
///
/// Bundles the search `space` to sample configurations from, the maximum
/// per-configuration resource `max_iter`, and the reduction factor `eta` used to
/// build successive-halving brackets.
#[derive(Clone, Debug)]
pub struct HyperbandScheduler {
    /// Maximum per-configuration resource (`R` in Hyperband terms) used to size brackets.
    pub(crate) max_iter: usize,
    /// Reduction factor between successive rungs; `with_eta` keeps it at or above `2.0`.
    pub(crate) eta: f64,
    /// Search space that `generate_configs` samples from.
    space: HyperparameterSpace,
}
impl HyperbandScheduler {
    /// Creates a scheduler over `space` with maximum per-configuration resource
    /// `max_iter` and the default reduction factor `eta = 3.0`.
    pub fn new(space: HyperparameterSpace, max_iter: usize) -> Self {
        Self { max_iter, eta: 3.0, space }
    }

    /// Replaces the reduction factor `eta`.
    ///
    /// Values below `2.0` are clamped to `2.0`. `NaN` also becomes `2.0`, because
    /// `f64::max` returns the non-NaN operand.
    #[must_use]
    pub fn with_eta(mut self, eta: f64) -> Self {
        self.eta = eta.max(2.0);
        self
    }

    /// Index of the most exploratory bracket: `floor(log_eta(max_iter))`.
    ///
    /// Equivalently, this is the largest `s` with `eta^s <= max_iter`, and `0` when
    /// `max_iter < eta`.
    ///
    /// It is computed by exact repeated multiplication rather than `f64::log`.
    /// `log` evaluates `ln(x) / ln(base)`, and that rounding can land just below an
    /// exact integer (e.g. `243f64.log(3.0)` can yield `4.999…`). `floor` would then
    /// silently drop a whole bracket.
    pub fn s_max(&self) -> usize {
        let eta = self.eta;
        // `new`/`with_eta` keep `eta >= 2`. This guard only stops a crate-internal
        // write of a non-growing factor from looping forever.
        if eta.is_nan() || eta <= 1.0 {
            return 0;
        }
        let limit = self.max_iter as f64;
        let mut s = 0;
        let mut next_rung = eta; // invariant: next_rung == eta^(s + 1)
        while next_rung <= limit {
            s += 1;
            next_rung *= eta;
        }
        s
    }

    /// Hyperband per-bracket resource budget `B = (s_max + 1) * max_iter`.
    pub fn budget(&self) -> usize {
        (self.s_max() + 1) * self.max_iter
    }

    /// Successive-halving schedule for bracket `s`, as `(n_i, r_i)` pairs.
    ///
    /// There is one pair per rung `i = 0..=s`: `n_i` configurations, each run with
    /// `r_i` resource units. The formulas are:
    ///
    /// - `n = ceil((s_max + 1) * eta^s / (s + 1))`
    /// - `n_i = floor(n / eta^i)`
    /// - `r_i = floor(max_iter / eta^(s - i))`
    ///
    /// The final rung therefore always receives the full `max_iter`. Both entries are
    /// raised to at least `1`. Returns an empty vector when `s > s_max()`.
    pub fn bracket(&self, s: usize) -> Vec<(usize, usize)> {
        let s_max = self.s_max();
        if s > s_max {
            return Vec::new();
        }
        let eta = self.eta;
        let max_iter = self.max_iter as f64;
        // `budget() / max_iter` is exactly `s_max + 1`. Using it directly avoids a
        // 0/0 NaN when `max_iter == 0`. Multiplying before dividing keeps integral
        // `eta` values exact.
        let n = ((s_max + 1) as f64 * eta.powi(s as i32) / (s + 1) as f64).ceil();
        (0..=s)
            .map(|i| {
                let n_i = (n / eta.powi(i as i32)).floor() as usize;
                // Derive r_i from max_iter in floating point. Truncating
                // r = max_iter / eta^s to an integer first shrank every rung
                // (e.g. max_iter = 100, s = 4 gave a final rung of 81, not 100).
                let r_i = (max_iter / eta.powi((s - i) as i32)).floor() as usize;
                (n_i.max(1), r_i.max(1))
            })
            .collect()
    }

    /// Draws `n` configurations from the search space.
    ///
    /// Sampling goes through `HyperparameterSpace::sample_random`, using the RNG
    /// returned by `rand::rng()`.
    pub fn generate_configs(&self, n: usize) -> Vec<HashMap<String, ParameterValue>> {
        let mut rng = rand::rng();
        (0..n).map(|_| self.space.sample_random(&mut rng)).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::hpo::types::ParameterDomain;

    /// Builds a scheduler over an empty search space with the default `eta`.
    fn scheduler(max_iter: usize) -> HyperbandScheduler {
        HyperbandScheduler::new(HyperparameterSpace::new(), max_iter)
    }

    #[test]
    fn test_hyperband_new() {
        let sched = scheduler(81);
        assert_eq!(sched.max_iter, 81);
        assert!((sched.eta - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyperband_s_max() {
        assert_eq!(scheduler(81).s_max(), 4);
    }

    #[test]
    fn test_hyperband_budget() {
        assert_eq!(scheduler(81).budget(), 405);
    }

    #[test]
    fn test_hyperband_bracket() {
        let rungs = scheduler(81).bracket(4);
        assert!(!rungs.is_empty());
        let &(n_first, r_first) = rungs.first().expect("collection should not be empty");
        let &(n_last, r_last) = rungs.last().expect("collection should not be empty");
        // Later rungs keep fewer configurations but give each one more resource.
        assert!(n_first >= n_last);
        assert!(r_first <= r_last);
    }

    #[test]
    fn test_hyperband_generate_configs() {
        let mut space = HyperparameterSpace::new();
        space.add("lr", ParameterDomain::Continuous { low: 0.0, high: 1.0, log_scale: false });
        let configs = HyperbandScheduler::new(space, 81).generate_configs(10);
        assert_eq!(configs.len(), 10);
    }

    #[test]
    fn test_hyperband_with_eta() {
        let sched = scheduler(81).with_eta(4.0);
        assert!((sched.eta - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_hyperband_bracket_invalid_s() {
        let rungs = scheduler(81).bracket(100);
        assert!(rungs.is_empty());
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // Every bracket index in `0..=s_max` must yield at least one rung.
        #[test]
        fn prop_hyperband_bracket_nonempty(max_iter in 9usize..243, eta in 2.0f64..5.0) {
            let scheduler = HyperbandScheduler::new(HyperparameterSpace::new(), max_iter)
                .with_eta(eta);
            for bracket_index in 0..=scheduler.s_max() {
                let rungs = scheduler.bracket(bracket_index);
                prop_assert!(!rungs.is_empty());
            }
        }
    }
}