use super::Pruner;
use crate::sampler::CompletedTrial;
use crate::types::{Direction, TrialState};
/// Successive-halving pruner.
///
/// Rungs sit at steps `min_resource * reduction_factor^k` (capped at `max_resource`). At each
/// rung, only roughly the best `1 / reduction_factor` fraction of trials keep running.
///
/// Construct it with `SuccessiveHalvingPruner::new()` (or `Default`) and the builder methods.
pub struct SuccessiveHalvingPruner {
    /// Step of the first rung. The builder rejects zero.
    min_resource: u64,
    /// Largest step that can be a rung. Trials whose current rung is at or above it are not
    /// pruned.
    max_resource: u64,
    /// Geometric spacing between rungs and inverse survival fraction (eta). The builder requires
    /// it to be at least 2.
    reduction_factor: u64,
    /// How many of the lowest rungs are skipped before pruning can start.
    min_early_stopping_rate: u64,
    /// Whether lower (`Minimize`) or higher (`Maximize`) intermediate values are better.
    direction: Direction,
}
impl SuccessiveHalvingPruner {
    /// Creates a pruner with the default configuration.
    ///
    /// The defaults are `min_resource = 1`, `max_resource = 81`, `reduction_factor = 3`,
    /// `min_early_stopping_rate = 0` and `Direction::Minimize`. This gives rungs at steps 1, 3,
    /// 9, 27 and 81.
    #[must_use]
    pub fn new() -> Self {
        Self {
            min_resource: 1,
            max_resource: 81,
            reduction_factor: 3,
            min_early_stopping_rate: 0,
            direction: Direction::Minimize,
        }
    }

    /// Sets the step of the first rung.
    ///
    /// # Panics
    ///
    /// Panics if `r` is zero.
    #[must_use]
    pub fn min_resource(mut self, r: u64) -> Self {
        assert!(r > 0, "min_resource must be > 0, got {r}");
        self.min_resource = r;
        self
    }

    /// Sets the largest step that can be a rung.
    ///
    /// A trial whose current rung is at or above this step is not pruned. If this value is
    /// below `min_resource`, there are no rungs and nothing is ever pruned.
    ///
    /// # Panics
    ///
    /// Panics if `r` is zero.
    #[must_use]
    pub fn max_resource(mut self, r: u64) -> Self {
        assert!(r > 0, "max_resource must be > 0, got {r}");
        self.max_resource = r;
        self
    }

    /// Sets the reduction factor `eta`.
    ///
    /// Rungs are spaced by a factor of `eta`, and roughly the best `1 / eta` of the trials at a
    /// rung survive it. At least `eta` finished trials must have reported at a rung before
    /// anything is pruned there.
    ///
    /// # Panics
    ///
    /// Panics if `eta < 2`.
    #[must_use]
    pub fn reduction_factor(mut self, eta: u64) -> Self {
        assert!(eta >= 2, "reduction_factor must be >= 2, got {eta}");
        self.reduction_factor = eta;
        self
    }

    /// Skips the lowest `n` rungs, so the first rung becomes
    /// `min_resource * reduction_factor^n`.
    #[must_use]
    pub fn min_early_stopping_rate(mut self, n: u64) -> Self {
        self.min_early_stopping_rate = n;
        self
    }

    /// Sets whether lower (`Minimize`) or higher (`Maximize`) values are better.
    #[must_use]
    pub fn direction(mut self, d: Direction) -> Self {
        self.direction = d;
        self
    }

    /// Returns the rung steps in ascending order.
    ///
    /// A step `min_resource * reduction_factor^k` is included for every
    /// `k >= min_early_stopping_rate` whose step does not exceed `max_resource`.
    fn rung_steps(&self) -> Vec<u64> {
        let eta = self.reduction_factor;
        let mut steps = Vec::new();
        let mut rung: u32 = 0;
        // Use checked arithmetic throughout. A step that overflows u64 would exceed any
        // `max_resource`, so the loop ends there. Saturating instead would push `u64::MAX`
        // repeatedly when `max_resource == u64::MAX`.
        while let Some(step) = eta
            .checked_pow(rung)
            .and_then(|power| self.min_resource.checked_mul(power))
        {
            if step > self.max_resource {
                break;
            }
            if u64::from(rung) >= self.min_early_stopping_rate {
                steps.push(step);
            }
            rung += 1;
        }
        steps
    }
}
impl Default for SuccessiveHalvingPruner {
fn default() -> Self {
Self::new()
}
}
impl Pruner for SuccessiveHalvingPruner {
    /// Returns `true` when the trial should be stopped at `step`.
    ///
    /// The trial is judged at the highest rung it has reached, i.e. the largest rung step
    /// `<= step`. The value compared is the one reported exactly at that rung or, failing that,
    /// the one reported at the latest step before it.
    ///
    /// Nothing is pruned in these cases:
    /// - before the first rung;
    /// - when the current rung is at or above `max_resource`;
    /// - when no value at or before the rung has been reported.
    fn should_prune(
        &self,
        _trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool {
        let rungs = self.rung_steps();
        // Rung steps are produced in ascending order, so the last one `<= step` is the highest
        // rung the trial has reached.
        let Some(&rung_step) = rungs.iter().rev().find(|&&r| r <= step) else {
            return false;
        };
        if rung_step >= self.max_resource {
            return false;
        }
        let Some(current_value) = intermediate_values
            .iter()
            .find(|&&(s, _)| s == rung_step)
            .or_else(|| {
                // Pick the latest report before the rung by step number rather than by list
                // position, so the result does not depend on the reports being sorted.
                intermediate_values
                    .iter()
                    .filter(|&&(s, _)| s <= rung_step)
                    .max_by_key(|&&(s, _)| s)
            })
            .map(|&(_, v)| v)
        else {
            return false;
        };
        self.is_pruned_at_rung(current_value, rung_step, completed_trials)
    }
}
impl SuccessiveHalvingPruner {
    /// Decides whether a trial reporting `current_value` at `rung_step` falls outside the top
    /// `1 / reduction_factor` fraction of trials that reached the same rung.
    ///
    /// Competitors are the `Complete` and `Pruned` trials in `completed_trials` that reported a
    /// value at exactly `rung_step`. The pool is their values plus `current_value`, and the best
    /// `ceil(pool_len / reduction_factor)` of the pool survive. The trial is pruned only if it
    /// is strictly worse than the last survivor, so ties survive.
    ///
    /// - A NaN `current_value` is always pruned, since a diverged trial carries no useful
    ///   signal.
    /// - NaN competitor values are ranked as the worst possible result.
    /// - With fewer than `reduction_factor` competitors there is not enough evidence yet, so
    ///   the trial is kept.
    ///
    /// Relies on `reduction_factor >= 1`; the builder enforces `>= 2`.
    fn is_pruned_at_rung(
        &self,
        current_value: f64,
        rung_step: u64,
        completed_trials: &[CompletedTrial],
    ) -> bool {
        use core::cmp::Ordering;

        // Without this check, the strict comparisons at the end would never flag a NaN trial.
        if current_value.is_nan() {
            return true;
        }
        // Where `usize` is narrower than `u64`, saturate instead of truncating. An unreachable
        // competitor count simply disables pruning.
        let eta = usize::try_from(self.reduction_factor).unwrap_or(usize::MAX);
        let mut values_at_rung: Vec<f64> = completed_trials
            .iter()
            .filter(|t| t.state == TrialState::Complete || t.state == TrialState::Pruned)
            .filter_map(|t| {
                t.intermediate_values
                    .iter()
                    .find(|(s, _)| *s == rung_step)
                    .map(|(_, v)| *v)
            })
            .collect();
        if values_at_rung.len() < eta {
            return false;
        }
        values_at_rung.push(current_value);

        // Total order with "better" first and NaN last. Slice selection/sorting may panic on
        // comparators that are not total orders, which the NaN-as-Equal trick is not.
        let rank = |a: &f64, b: &f64| -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (false, false) => {
                    // Neither operand is NaN, so `partial_cmp` always returns `Some`.
                    let ascending = a.partial_cmp(b).unwrap_or(Ordering::Equal);
                    match self.direction {
                        Direction::Minimize => ascending,
                        Direction::Maximize => ascending.reverse(),
                    }
                }
            }
        };

        // The pool holds at least eta + 1 values, so n_keep is in 1..=len and the index is in
        // bounds.
        let n_keep = values_at_rung.len().div_ceil(eta);
        let threshold_idx = n_keep.saturating_sub(1);
        let threshold = *values_at_rung
            .select_nth_unstable_by(threshold_idx, |a, b| rank(a, b))
            .1;
        rank(&current_value, &threshold) == Ordering::Greater
    }
}
// Unit tests for rung generation, the pruning decision at a rung, and builder validation.
#[cfg(test)]
// The fixtures below convert small u64 counters with `(i + 1) as f64`.
#[allow(clippy::cast_precision_loss)]
mod tests {
use super::*;
// Builds a `CompletedTrial` carrying the given (step, value) intermediate reports.
// The other arguments are empty maps plus a 0.0 scalar (presumably the objective value;
// confirm against `CompletedTrial::with_intermediate_values`).
// NOTE(review): the tests rely on these trials counting as competitors, so their state is
// presumably `TrialState::Complete` — confirm.
fn make_trial(id: u64, values: &[(u64, f64)]) -> CompletedTrial {
use std::collections::HashMap;
use crate::parameter::ParamId;
CompletedTrial::with_intermediate_values(
id,
HashMap::<ParamId, crate::parameter::ParamValue>::new(),
HashMap::new(),
HashMap::new(),
0.0,
values.to_vec(),
HashMap::new(),
)
}
// Same as `make_trial`, but the trial is marked `TrialState::Pruned`.
fn make_pruned_trial(id: u64, values: &[(u64, f64)]) -> CompletedTrial {
let mut t = make_trial(id, values);
t.state = TrialState::Pruned;
t
}
// Defaults: 1 * 3^k for every k whose step stays within max_resource = 81.
#[test]
fn rung_steps_default() {
let pruner = SuccessiveHalvingPruner::new();
let rungs = pruner.rung_steps();
assert_eq!(rungs, vec![1, 3, 9, 27, 81]);
}
// 2 * 2^k up to max_resource = 32.
#[test]
fn rung_steps_custom() {
let pruner = SuccessiveHalvingPruner::new()
.min_resource(2)
.max_resource(32)
.reduction_factor(2);
let rungs = pruner.rung_steps();
assert_eq!(rungs, vec![2, 4, 8, 16, 32]);
}
// min_early_stopping_rate = 2 drops the two lowest rungs (1 and 3).
#[test]
fn rung_steps_with_early_stopping_rate() {
let pruner = SuccessiveHalvingPruner::new().min_early_stopping_rate(2);
let rungs = pruner.rung_steps();
assert_eq!(rungs, vec![9, 27, 81]);
}
// Rungs are 10, 30 and 90. Step 5 has not reached the first rung, so nothing is judged.
#[test]
fn no_prune_before_first_rung() {
let pruner = SuccessiveHalvingPruner::new()
.min_resource(10)
.max_resource(100)
.reduction_factor(3);
let completed = vec![
make_trial(0, &[(5, 1.0)]),
make_trial(1, &[(5, 2.0)]),
make_trial(2, &[(5, 3.0)]),
];
assert!(!pruner.should_prune(3, 5, &[(5, 100.0)], &completed));
}
// Only two competitors at rung 1, fewer than reduction_factor (3), so the trial is kept even
// though its value is the worst.
#[test]
fn no_prune_with_single_trial() {
let pruner = SuccessiveHalvingPruner::new();
let completed = vec![make_trial(0, &[(1, 5.0)]), make_trial(1, &[(1, 3.0)])];
assert!(!pruner.should_prune(2, 1, &[(1, 10.0)], &completed));
}
// 9 competitors (1.0..=9.0) plus the current trial make 10 values, of which the best
// ceil(10 / 3) = 4 survive. 3.0 ties the 4th-best value and survives; 5.0 is worse than the
// 4th-best value (4.0) and is pruned.
#[test]
fn prune_worst_trials_at_rung() {
let pruner = SuccessiveHalvingPruner::new().direction(Direction::Minimize);
let completed: Vec<_> = (0..9)
.map(|i| make_trial(i, &[(1, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(9, 1, &[(1, 3.0)], &completed));
assert!(pruner.should_prune(9, 1, &[(1, 5.0)], &completed));
}
// 6 competitors plus the current trial make 7 values; ceil(7 / 3) = 3 survive.
#[test]
fn top_fraction_survives() {
let pruner = SuccessiveHalvingPruner::new().direction(Direction::Minimize);
let completed: Vec<_> = (0..6)
.map(|i| make_trial(i, &[(1, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(6, 1, &[(1, 2.0)], &completed));
assert!(!pruner.should_prune(6, 1, &[(1, 3.0)], &completed));
assert!(pruner.should_prune(6, 1, &[(1, 4.0)], &completed));
}
// Same field as above, but higher values are better.
#[test]
fn maximize_direction() {
let pruner = SuccessiveHalvingPruner::new().direction(Direction::Maximize);
let completed: Vec<_> = (0..6)
.map(|i| make_trial(i, &[(1, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(6, 1, &[(1, 5.0)], &completed));
assert!(!pruner.should_prune(6, 1, &[(1, 4.0)], &completed));
assert!(pruner.should_prune(6, 1, &[(1, 3.0)], &completed));
}
// eta = 2: 4 competitors plus the current trial make 5 values; ceil(5 / 2) = 3 survive.
#[test]
fn reduction_factor_2() {
let pruner = SuccessiveHalvingPruner::new()
.reduction_factor(2)
.min_resource(1)
.max_resource(16)
.direction(Direction::Minimize);
assert_eq!(pruner.rung_steps(), vec![1, 2, 4, 8, 16]);
let completed: Vec<_> = (0..4)
.map(|i| make_trial(i, &[(1, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(4, 1, &[(1, 3.0)], &completed));
assert!(pruner.should_prune(4, 1, &[(1, 4.0)], &completed));
}
// eta = 4: 12 competitors plus the current trial make 13 values; ceil(13 / 4) = 4 survive.
#[test]
fn reduction_factor_4() {
let pruner = SuccessiveHalvingPruner::new()
.reduction_factor(4)
.min_resource(1)
.max_resource(64)
.direction(Direction::Minimize);
assert_eq!(pruner.rung_steps(), vec![1, 4, 16, 64]);
let completed: Vec<_> = (0..12)
.map(|i| make_trial(i, &[(1, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(12, 1, &[(1, 4.0)], &completed));
assert!(pruner.should_prune(12, 1, &[(1, 5.0)], &completed));
}
// At step 5 the highest rung reached is 3 (rungs 1, 3, 9, ...), so the values reported at
// step 3 are compared.
#[test]
fn non_contiguous_steps() {
let pruner = SuccessiveHalvingPruner::new().direction(Direction::Minimize);
let completed: Vec<_> = (0..6)
.map(|i| make_trial(i, &[(3, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(6, 5, &[(3, 2.0)], &completed));
assert!(pruner.should_prune(6, 5, &[(3, 5.0)], &completed));
}
// Rung 81 equals max_resource, so no pruning happens there even for the worst value.
#[test]
fn no_prune_at_max_resource() {
let pruner = SuccessiveHalvingPruner::new();
let completed: Vec<_> = (0..9)
.map(|i| make_trial(i, &[(81, (i + 1) as f64)]))
.collect();
assert!(!pruner.should_prune(9, 81, &[(81, 100.0)], &completed));
}
// Pruned trials count as competitors: 5 competitors plus the current trial make 6 values,
// and ceil(6 / 3) = 2 survive.
#[test]
fn includes_pruned_trials_in_comparison() {
let pruner = SuccessiveHalvingPruner::new().direction(Direction::Minimize);
let completed = vec![
make_trial(0, &[(1, 1.0)]),
make_trial(1, &[(1, 2.0)]),
make_pruned_trial(2, &[(1, 8.0)]),
make_pruned_trial(3, &[(1, 9.0)]),
make_pruned_trial(4, &[(1, 10.0)]),
];
assert!(!pruner.should_prune(5, 1, &[(1, 2.0)], &completed));
assert!(pruner.should_prune(5, 1, &[(1, 3.0)], &completed));
}
// Builder validation: each of these configurations panics with the quoted message.
#[test]
#[should_panic(expected = "min_resource must be > 0")]
fn rejects_zero_min_resource() {
let _ = SuccessiveHalvingPruner::new().min_resource(0);
}
#[test]
#[should_panic(expected = "max_resource must be > 0")]
fn rejects_zero_max_resource() {
let _ = SuccessiveHalvingPruner::new().max_resource(0);
}
#[test]
#[should_panic(expected = "reduction_factor must be >= 2")]
fn rejects_reduction_factor_one() {
let _ = SuccessiveHalvingPruner::new().reduction_factor(1);
}
}