use super::Pruner;
use super::percentile::compute_percentile;
use crate::sampler::CompletedTrial;
use crate::types::{Direction, TrialState};
/// Pruner that stops a trial whose latest intermediate value is worse than the
/// median of the values completed trials reported at the same step.
///
/// Build one with [`MedianPruner::new`] and tune it with the builder methods.
pub struct MedianPruner {
    /// How many completed trials must have a value at the current step before
    /// a pruning decision is made. The builder keeps this at 1 or more.
    n_min_trials: usize,
    /// Steps strictly below this threshold are never pruned.
    n_warmup_steps: u64,
    /// Whether lower (`Minimize`) or higher (`Maximize`) values are better.
    direction: Direction,
}
impl MedianPruner {
    /// Creates a median pruner that optimizes in the given `direction`.
    ///
    /// The defaults are no warm-up (`n_warmup_steps = 0`) and at least one
    /// completed trial (`n_min_trials = 1`).
    #[must_use]
    pub fn new(direction: Direction) -> Self {
        Self {
            n_min_trials: 1,
            n_warmup_steps: 0,
            direction,
        }
    }

    /// Sets the number of initial steps during which pruning is disabled.
    ///
    /// A trial is never pruned while `step < n`.
    #[must_use]
    pub fn n_warmup_steps(self, n: u64) -> Self {
        Self {
            n_warmup_steps: n,
            ..self
        }
    }

    /// Sets how many completed trials must have reported a value at a step
    /// before that step's median is used for pruning.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    #[must_use]
    pub fn n_min_trials(self, n: usize) -> Self {
        assert!(n > 0, "n_min_trials must be >= 1, got {n}");
        Self {
            n_min_trials: n,
            ..self
        }
    }
}
impl Pruner for MedianPruner {
    /// Decides whether the current trial should be stopped early.
    ///
    /// The trial's most recently reported intermediate value is compared with
    /// the median of the values that completed trials (state `Complete`)
    /// reported at the same `step`. The trial is pruned when it is strictly
    /// worse than that median for the configured [`Direction`].
    ///
    /// Returns `false` (keep running) when:
    /// - `step` is still inside the warm-up window (`step < n_warmup_steps`),
    /// - the trial has not reported any intermediate value yet, or
    /// - fewer than `n_min_trials` completed trials have a non-NaN value at
    ///   `step`.
    ///
    /// Returns `true` when, after warm-up, the trial's latest value is NaN.
    fn should_prune(
        &self,
        _trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool {
        if step < self.n_warmup_steps {
            return false;
        }
        // NOTE(review): this compares the trial's *latest* report, which
        // assumes callers report the value for `step` before asking to prune.
        let Some(&(_, current_value)) = intermediate_values.last() else {
            return false;
        };
        // Every comparison with NaN is false, so without this check a
        // diverged trial would never be pruned.
        if current_value.is_nan() {
            return true;
        }
        let mut values_at_step: Vec<f64> = completed_trials
            .iter()
            .filter(|t| t.state == TrialState::Complete)
            .filter_map(|t| {
                t.intermediate_values
                    .iter()
                    .find(|(s, _)| *s == step)
                    .map(|(_, v)| *v)
            })
            // NaN peers would corrupt the median, so they are ignored and do
            // not count toward `n_min_trials`.
            .filter(|v| !v.is_nan())
            .collect();
        // `n_min_trials` is always >= 1 (set in `new`, asserted in the
        // builder), so an empty vector never reaches `compute_percentile`.
        if values_at_step.len() < self.n_min_trials {
            return false;
        }
        let median = compute_percentile(&mut values_at_step, 50.0);
        match self.direction {
            Direction::Minimize => current_value > median,
            Direction::Maximize => current_value < median,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_median_odd() {
        assert!((compute_percentile(&mut [3.0, 1.0, 2.0], 50.0) - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_median_even() {
        assert!((compute_percentile(&mut [4.0, 1.0, 3.0, 2.0], 50.0) - 2.5).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_median_single() {
        assert!((compute_percentile(&mut [5.0], 50.0) - 5.0).abs() < f64::EPSILON);
    }

    // Steps below the warm-up threshold must never be pruned, even for a
    // terrible value.
    #[test]
    fn no_prune_during_warmup() {
        let pruner = MedianPruner::new(Direction::Minimize).n_warmup_steps(5);
        assert!(!pruner.should_prune(0, 4, &[(4, 1.0e9)], &[]));
    }

    // A trial that has not reported anything yet has nothing to compare.
    #[test]
    fn no_prune_without_intermediate_values() {
        let pruner = MedianPruner::new(Direction::Minimize);
        assert!(!pruner.should_prune(0, 0, &[], &[]));
    }

    // With no completed peers, the default `n_min_trials = 1` is not met.
    #[test]
    fn no_prune_without_completed_trials() {
        for direction in [Direction::Minimize, Direction::Maximize] {
            let pruner = MedianPruner::new(direction);
            assert!(!pruner.should_prune(0, 3, &[(3, 42.0)], &[]));
        }
    }

    #[test]
    #[should_panic(expected = "n_min_trials must be >= 1")]
    fn n_min_trials_rejects_zero() {
        let _ = MedianPruner::new(Direction::Maximize).n_min_trials(0);
    }
}