use core::sync::atomic::{AtomicU64, Ordering};
use std::collections::HashMap;
use std::sync::Mutex;
use super::Pruner;
use crate::sampler::CompletedTrial;
use crate::types::{Direction, TrialState};
/// Hyperband pruner: trials are spread round-robin over several successive-halving
/// brackets, and at every rung below the full budget a trial is pruned when it falls
/// outside the best `1 / reduction_factor` share of its bracket at that rung.
pub struct HyperbandPruner {
    // Configuration, set through the builder methods.
    /// Whether lower (`Minimize`) or higher (`Maximize`) values are better.
    direction: Direction,
    /// Spacing factor `eta` between consecutive rungs; must be >= 2.
    reduction_factor: u64,
    /// Smallest resource unit; with `max_resource` it fixes the bracket count.
    min_resource: u64,
    /// Full resource budget; no rung is placed beyond it.
    max_resource: u64,

    // Bookkeeping shared across `should_prune` calls.
    /// Bracket index recorded for each trial id the first time it is seen.
    trial_brackets: Mutex<HashMap<u64, usize>>,
    /// Round-robin counter that picks the bracket for the next new trial.
    next_bracket: AtomicU64,
}
impl HyperbandPruner {
    /// Creates a pruner with `min_resource = 1`, `max_resource = 81`,
    /// `reduction_factor = 3`, minimizing the objective.
    #[must_use]
    pub fn new() -> Self {
        Self {
            min_resource: 1,
            max_resource: 81,
            reduction_factor: 3,
            direction: Direction::Minimize,
            trial_brackets: Mutex::new(HashMap::new()),
            next_bracket: AtomicU64::new(0),
        }
    }

    /// Sets the minimum resource. Together with `max_resource` and
    /// `reduction_factor` it determines how many brackets exist.
    ///
    /// # Panics
    ///
    /// Panics if `r` is zero.
    #[must_use]
    pub fn min_resource(mut self, r: u64) -> Self {
        assert!(r > 0, "min_resource must be > 0, got {r}");
        self.min_resource = r;
        self
    }

    /// Sets the maximum resource (full step budget). No rung is placed above it.
    ///
    /// # Panics
    ///
    /// Panics if `r` is zero.
    #[must_use]
    pub fn max_resource(mut self, r: u64) -> Self {
        assert!(r > 0, "max_resource must be > 0, got {r}");
        self.max_resource = r;
        self
    }

    /// Sets the reduction factor `eta`. Rungs are spaced by a factor of `eta`, and
    /// roughly the best `1 / eta` of the values at a rung survive it.
    ///
    /// # Panics
    ///
    /// Panics if `eta < 2`.
    #[must_use]
    pub fn reduction_factor(mut self, eta: u64) -> Self {
        assert!(eta >= 2, "reduction_factor must be >= 2, got {eta}");
        self.reduction_factor = eta;
        self
    }

    /// Sets whether lower (`Minimize`) or higher (`Maximize`) values are better.
    #[must_use]
    pub fn direction(mut self, d: Direction) -> Self {
        self.direction = d;
        self
    }

    /// Largest `s` such that `min_resource * reduction_factor^s <= max_resource`,
    /// i.e. `floor(log_eta(max_resource / min_resource))`. Returns 0 when no `s >= 1`
    /// fits, including when `min_resource > max_resource`. There are `s_max + 1` brackets.
    fn s_max(&self) -> u64 {
        // Exact integer arithmetic on purpose: the floating-point form
        // `ln(max / min) / ln(eta)` under-counts exact powers (for 1000 and 10 it
        // evaluates to 2.9999999999999996, so the floor would drop a bracket).
        let mut s = 0;
        let mut resource = self.min_resource;
        while let Some(next) = resource.checked_mul(self.reduction_factor) {
            if next > self.max_resource {
                break;
            }
            resource = next;
            s += 1;
        }
        s
    }

    /// Rung steps for `bracket`, in ascending order.
    ///
    /// Bracket `b` starts at `ceil(max_resource / eta^(s_max - b))`. Brackets past
    /// `s_max` behave like `s_max`. Each following rung is `eta` times the previous
    /// one, stopping before `max_resource` would be exceeded. Lower-index brackets
    /// therefore start earlier and have more rungs.
    fn rung_steps_for_bracket(&self, bracket: usize) -> Vec<u64> {
        let eta = self.reduction_factor;
        let bracket = u64::try_from(bracket).unwrap_or(u64::MAX);
        let exponent = self.s_max().saturating_sub(bracket);
        // eta^exponent <= max_resource / min_resource by definition of `s_max`, so the
        // power fits in u64; the fallback only guards against misuse.
        let divisor = u32::try_from(exponent)
            .ok()
            .and_then(|e| eta.checked_pow(e))
            .unwrap_or(u64::MAX);
        // Integer ceil(max_resource / divisor); always >= 1 because max_resource >= 1.
        let first_rung =
            self.max_resource / divisor + u64::from(self.max_resource % divisor != 0);
        // Multiplying stops on overflow, so a saturated value is never emitted repeatedly.
        core::iter::successors(Some(first_rung), |&rung| rung.checked_mul(eta))
            .take_while(|&rung| rung <= self.max_resource)
            .collect()
    }

    /// Returns the bracket of `trial_id`. The first time an id is seen it receives
    /// the next bracket in round-robin order. The choice is recorded, so later calls
    /// for the same id return the same bracket.
    #[allow(clippy::cast_possible_truncation)] // value is < n_brackets <= 65, fits any usize
    fn assign_bracket(&self, trial_id: u64) -> usize {
        let n_brackets = self.s_max() + 1;
        let mut map = self.trial_brackets.lock().expect("lock poisoned");
        *map.entry(trial_id).or_insert_with(|| {
            let idx = self.next_bracket.fetch_add(1, Ordering::Relaxed);
            // Reduce in u64 before narrowing so the round-robin order does not depend
            // on the platform's pointer width.
            (idx % n_brackets) as usize
        })
    }
}
impl Default for HyperbandPruner {
fn default() -> Self {
Self::new()
}
}
impl Pruner for HyperbandPruner {
    /// Decides whether `trial_id`, now at `step`, should be stopped.
    ///
    /// The trial is assigned a bracket on first sight. The latest rung of that bracket
    /// at or below `step` is then located, and the trial's value there is ranked
    /// against peers of the same bracket.
    ///
    /// The trial is not pruned in these cases:
    /// - before the bracket's first rung;
    /// - at the final rung (`>= max_resource`);
    /// - when it has not reported any value at or before that rung.
    fn should_prune(
        &self,
        trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool {
        let bracket = self.assign_bracket(trial_id);
        let rungs = self.rung_steps_for_bracket(bracket);
        // Rungs are ascending, so the last one <= step is the rung most recently reached.
        let Some(&rung_step) = rungs.iter().rev().find(|&&r| r <= step) else {
            return false;
        };
        // The full-budget rung is never a pruning point.
        if rung_step >= self.max_resource {
            return false;
        }
        // Prefer the value reported exactly at the rung. Otherwise fall back to the last
        // entry (in slice order) reported at or before it.
        let reported = intermediate_values
            .iter()
            .find(|&&(s, _)| s == rung_step)
            .or_else(|| intermediate_values.iter().rev().find(|&&(s, _)| s <= rung_step));
        let Some(&(_, current_value)) = reported else {
            return false;
        };
        self.is_pruned_at_rung(current_value, rung_step, bracket, completed_trials)
    }
}
impl HyperbandPruner {
    /// Returns `true` if `current_value` is strictly worse than the last value kept at
    /// `rung_step`, where the best `ceil(n / reduction_factor)` of the `n` values
    /// (peers plus the current trial) are kept.
    ///
    /// Peers are entries of `completed_trials` that:
    /// - are in state `Complete` or `Pruned`;
    /// - were assigned to `bracket`;
    /// - reported a value at exactly `rung_step`.
    ///
    /// With fewer than `reduction_factor` peers the rung is under-populated and nothing
    /// is pruned. A tie with the last kept value survives.
    ///
    /// NaN ranks worse than every non-NaN value. A diverged metric is never "better",
    /// and the sort comparator stays a total order. Recent std sorts may panic on a
    /// comparator that is not one.
    fn is_pruned_at_rung(
        &self,
        current_value: f64,
        rung_step: u64,
        bracket: usize,
        completed_trials: &[CompletedTrial],
    ) -> bool {
        // A factor that does not fit in usize can never be reached by a peer count.
        let eta = usize::try_from(self.reduction_factor).unwrap_or(usize::MAX);
        let mut values_at_rung: Vec<f64> = {
            // Hold the lock only while reading bracket assignments.
            let map = self.trial_brackets.lock().expect("lock poisoned");
            completed_trials
                .iter()
                .filter(|t| t.state == TrialState::Complete || t.state == TrialState::Pruned)
                .filter(|t| map.get(&t.id).copied() == Some(bracket))
                .filter_map(|t| {
                    t.intermediate_values
                        .iter()
                        .find(|(s, _)| *s == rung_step)
                        .map(|(_, v)| *v)
                })
                .collect()
        };
        if values_at_rung.len() < eta {
            return false;
        }
        values_at_rung.push(current_value);
        // Best value first, NaN last.
        values_at_rung.sort_unstable_by(|a, b| self.cmp_better_first(*a, *b));
        // ceil(len / eta) survivors; >= 1 because the vector holds at least the current value.
        let len = values_at_rung.len();
        let n_keep = len / eta + usize::from(len % eta != 0);
        let threshold = values_at_rung[n_keep - 1];
        self.cmp_better_first(current_value, threshold) == core::cmp::Ordering::Greater
    }

    /// Orders two values so that the better one (per `direction`) compares `Less`.
    /// NaN compares worse than any non-NaN value and equal to another NaN. Non-NaN values
    /// compare with IEEE semantics, so `-0.0` and `0.0` are equal.
    fn cmp_better_first(&self, a: f64, b: f64) -> core::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => core::cmp::Ordering::Equal,
            (true, false) => core::cmp::Ordering::Greater,
            (false, true) => core::cmp::Ordering::Less,
            (false, false) => {
                // Neither is NaN, so `partial_cmp` always returns `Some`.
                let ascending = a.partial_cmp(&b).unwrap_or(core::cmp::Ordering::Equal);
                match self.direction {
                    Direction::Minimize => ascending,
                    Direction::Maximize => ascending.reverse(),
                }
            }
        }
    }
}
// Unit tests. Most scenarios rely on `assign_bracket` handing out brackets
// round-robin in first-seen order on a fresh pruner (5 brackets with the
// defaults), so the order of `assign_bracket` calls below is significant.
#[cfg(test)]
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::*;

    // Builds a trial with the given `(step, value)` intermediate values.
    // NOTE(review): assumes `with_intermediate_values` produces a state the pruner
    // counts as a peer (presumably `TrialState::Complete`) — TODO confirm.
    fn make_trial(id: u64, values: &[(u64, f64)]) -> CompletedTrial {
        use std::collections::HashMap;
        use crate::parameter::ParamId;
        CompletedTrial::with_intermediate_values(
            id,
            HashMap::<ParamId, crate::parameter::ParamValue>::new(),
            HashMap::new(),
            HashMap::new(),
            0.0,
            values.to_vec(),
            HashMap::new(),
        )
    }

    // Same as `make_trial`, but with the state forced to `Pruned`.
    fn make_pruned_trial(id: u64, values: &[(u64, f64)]) -> CompletedTrial {
        let mut t = make_trial(id, values);
        t.state = TrialState::Pruned;
        t
    }

    // Defaults: floor(log_3(81 / 1)) = 4.
    #[test]
    fn s_max_default() {
        let pruner = HyperbandPruner::new();
        assert_eq!(pruner.s_max(), 4);
    }

    // floor(log_2(16 / 1)) = 4.
    #[test]
    fn s_max_custom() {
        let pruner = HyperbandPruner::new()
            .min_resource(1)
            .max_resource(16)
            .reduction_factor(2);
        assert_eq!(pruner.s_max(), 4);
    }

    // The pruner uses s_max + 1 brackets.
    #[test]
    fn bracket_count() {
        let pruner = HyperbandPruner::new();
        assert_eq!(pruner.s_max() + 1, 5);
    }

    // Bracket 0 starts at 81 / 3^4 = 1 and climbs by a factor of 3.
    #[test]
    fn rung_steps_bracket_0_default() {
        let pruner = HyperbandPruner::new();
        assert_eq!(pruner.rung_steps_for_bracket(0), vec![1, 3, 9, 27, 81]);
    }

    // Bracket 2 starts at 81 / 3^2 = 9.
    #[test]
    fn rung_steps_bracket_2_default() {
        let pruner = HyperbandPruner::new();
        assert_eq!(pruner.rung_steps_for_bracket(2), vec![9, 27, 81]);
    }

    // The last bracket has a single rung at the full budget.
    #[test]
    fn rung_steps_bracket_4_default() {
        let pruner = HyperbandPruner::new();
        assert_eq!(pruner.rung_steps_for_bracket(4), vec![81]);
    }

    // Same rung layout with eta = 2 and a budget of 16.
    #[test]
    fn rung_steps_eta2() {
        let pruner = HyperbandPruner::new()
            .min_resource(1)
            .max_resource(16)
            .reduction_factor(2);
        assert_eq!(pruner.rung_steps_for_bracket(0), vec![1, 2, 4, 8, 16]);
        assert_eq!(pruner.rung_steps_for_bracket(2), vec![4, 8, 16]);
        assert_eq!(pruner.rung_steps_for_bracket(4), vec![16]);
    }

    // New ids get brackets 0, 1, 2, 3, 4, 0 in first-seen order; repeated
    // lookups (the last two asserts) return the recorded bracket.
    #[test]
    fn round_robin_bracket_assignment() {
        let pruner = HyperbandPruner::new();
        assert_eq!(pruner.assign_bracket(100), 0);
        assert_eq!(pruner.assign_bracket(101), 1);
        assert_eq!(pruner.assign_bracket(102), 2);
        assert_eq!(pruner.assign_bracket(103), 3);
        assert_eq!(pruner.assign_bracket(104), 4);
        assert_eq!(pruner.assign_bracket(105), 0);
        assert_eq!(pruner.assign_bracket(100), 0);
        assert_eq!(pruner.assign_bracket(103), 3);
    }

    // Trial 0 is assigned first (bracket 0, first rung at step 1); at step 0
    // no rung has been reached, so a terrible value is still not pruned.
    #[test]
    fn no_prune_before_first_rung() {
        let pruner = HyperbandPruner::new().direction(Direction::Minimize);
        pruner.assign_bracket(0);
        let mut completed = Vec::new();
        for i in 1..=9 {
            pruner.assign_bracket(i);
            completed.push(make_trial(i, &[(1, i as f64)]));
        }
        assert!(!pruner.should_prune(0, 0, &[(0, 100.0)], &completed));
    }

    // Trials 0..9 take brackets 0..4, 0..3, so trial 9 lands in bracket 4,
    // whose only rung is 81 = max_resource, which is never a pruning point.
    #[test]
    fn no_prune_at_max_resource() {
        let pruner = HyperbandPruner::new().direction(Direction::Minimize);
        let mut completed = Vec::new();
        for i in 0..9 {
            pruner.assign_bracket(i);
            completed.push(make_trial(i, &[(81, (i + 1) as f64)]));
        }
        let trial_id = 9;
        pruner.assign_bracket(trial_id);
        assert!(!pruner.should_prune(trial_id, 81, &[(81, 100.0)], &completed));
    }

    // Ids 0..25 are assigned in order, so multiples of 5 (and then id 25)
    // are in bracket 0. Peers 0, 5, 10 report 1, 2, 3 at step 1. With the
    // current value there are 4 values and ceil(4 / 3) = 2 are kept, so the
    // cut-off is the 2nd best: 2.0 ties it (kept) and 3.0 is worse (pruned).
    #[test]
    fn prune_worst_in_bracket_minimize() {
        let pruner = HyperbandPruner::new().direction(Direction::Minimize);
        let bracket_0_ids: Vec<u64> = (0..5).map(|i| i * 5).collect();
        for i in 0..25 {
            pruner.assign_bracket(i);
        }
        let completed: Vec<_> = bracket_0_ids
            .iter()
            .take(3)
            .enumerate()
            .map(|(idx, &id)| make_trial(id, &[(1, (idx + 1) as f64)]))
            .collect();
        let test_id = 25;
        pruner.assign_bracket(test_id);
        assert_eq!(pruner.assign_bracket(test_id), 0);
        assert!(!pruner.should_prune(test_id, 1, &[(1, 2.0)], &completed));
        assert!(pruner.should_prune(test_id, 1, &[(1, 3.0)], &completed));
    }

    // Mirror of the minimize case: higher is better, the cut-off is the
    // 2nd-best value 2.0, so 2.0 survives and 0.5 is pruned.
    #[test]
    fn prune_worst_in_bracket_maximize() {
        let pruner = HyperbandPruner::new().direction(Direction::Maximize);
        for i in 0..25 {
            pruner.assign_bracket(i);
        }
        let completed: Vec<_> = [0u64, 5, 10]
            .iter()
            .enumerate()
            .map(|(idx, &id)| make_trial(id, &[(1, (idx + 1) as f64)]))
            .collect();
        let test_id = 25;
        pruner.assign_bracket(test_id);
        assert!(!pruner.should_prune(test_id, 1, &[(1, 2.0)], &completed));
        assert!(pruner.should_prune(test_id, 1, &[(1, 0.5)], &completed));
    }

    // Lower-index brackets start pruning earlier and have more rungs.
    #[test]
    fn different_brackets_have_different_aggressiveness() {
        let pruner = HyperbandPruner::new()
            .min_resource(1)
            .max_resource(81)
            .reduction_factor(3)
            .direction(Direction::Minimize);
        let rungs_0 = pruner.rung_steps_for_bracket(0);
        let rungs_2 = pruner.rung_steps_for_bracket(2);
        let rungs_4 = pruner.rung_steps_for_bracket(4);
        assert!(rungs_0.len() > rungs_2.len());
        assert_eq!(rungs_4.len(), 1);
        assert!(rungs_0[0] < rungs_2[0]);
    }

    // Bracket-1 trials (ids 1, 6, 11) score far better, but they must not
    // influence trial 25 in bracket 0, whose only peers score 100.0.
    #[test]
    fn trials_in_different_brackets_independent() {
        let pruner = HyperbandPruner::new().direction(Direction::Minimize);
        for i in 0..25 {
            pruner.assign_bracket(i);
        }
        let bracket_0_trials: Vec<_> = [0u64, 5, 10]
            .iter()
            .map(|&id| make_trial(id, &[(1, 100.0)]))
            .collect();
        let bracket_1_trials: Vec<_> = [1u64, 6, 11]
            .iter()
            .map(|&id| make_trial(id, &[(1, 1.0)]))
            .collect();
        let mut all_trials = bracket_0_trials;
        all_trials.extend(bracket_1_trials);
        let test_id = 25;
        pruner.assign_bracket(test_id);
        assert!(!pruner.should_prune(test_id, 1, &[(1, 50.0)], &all_trials));
    }

    // Pruned trials count as peers: peers 1, 8, 9 plus the current value,
    // cut-off at the 2nd best. 1.0, 5.0 and 6.0 are at or better than the
    // cut-off; 9.5 is worse than 8.0 and is pruned.
    #[test]
    fn includes_pruned_trials() {
        let pruner = HyperbandPruner::new().direction(Direction::Minimize);
        for i in 0..25 {
            pruner.assign_bracket(i);
        }
        let completed = vec![
            make_trial(0, &[(1, 1.0)]),
            make_pruned_trial(5, &[(1, 8.0)]),
            make_pruned_trial(10, &[(1, 9.0)]),
        ];
        let test_id = 25;
        pruner.assign_bracket(test_id);
        assert!(!pruner.should_prune(test_id, 1, &[(1, 1.0)], &completed));
        assert!(!pruner.should_prune(test_id, 1, &[(1, 5.0)], &completed));
        assert!(!pruner.should_prune(test_id, 1, &[(1, 6.0)], &completed));
        assert!(pruner.should_prune(test_id, 1, &[(1, 9.5)], &completed));
    }

    // Builder preconditions.
    #[test]
    #[should_panic(expected = "min_resource must be > 0")]
    fn rejects_zero_min_resource() {
        let _ = HyperbandPruner::new().min_resource(0);
    }

    #[test]
    #[should_panic(expected = "max_resource must be > 0")]
    fn rejects_zero_max_resource() {
        let _ = HyperbandPruner::new().max_resource(0);
    }

    #[test]
    #[should_panic(expected = "reduction_factor must be >= 2")]
    fn rejects_reduction_factor_one() {
        let _ = HyperbandPruner::new().reduction_factor(1);
    }
}