use super::Pruner;
use crate::sampler::CompletedTrial;
use crate::types::{Direction, TrialState};
/// Pruner that stops a trial when its latest intermediate value is worse
/// than a configurable percentile of the values that completed trials
/// reported at the same step.
///
/// Built with [`PercentilePruner::new`] and tuned through the
/// `n_warmup_steps` / `n_min_trials` builder methods.
pub struct PercentilePruner {
/// Percentile used as the pruning threshold; validated by `new` to lie in
/// the open interval `(0.0, 100.0)`.
percentile: f64,
/// Steps strictly below this number are never pruned (defaults to `0`).
n_warmup_steps: u64,
/// Minimum number of reference values from completed trials required
/// before any pruning decision is made (defaults to `1`, never `0`).
n_min_trials: usize,
/// Optimization direction; selects which side of the threshold counts as
/// "worse" (above it for `Minimize`, below it for `Maximize`).
direction: Direction,
}
impl PercentilePruner {
    /// Creates a pruner that uses the `percentile`-th percentile of completed
    /// trials' intermediate values as its threshold.
    ///
    /// The new pruner has no warm-up phase (`n_warmup_steps == 0`) and needs
    /// only a single reference value (`n_min_trials == 1`); use the builder
    /// methods below to change either.
    ///
    /// # Panics
    ///
    /// Panics if `percentile` is not strictly between `0.0` and `100.0`
    /// (this includes NaN).
    #[must_use]
    pub fn new(percentile: f64, direction: Direction) -> Self {
        // Written as a positive range test so that NaN fails the check.
        let in_open_range = percentile > 0.0 && percentile < 100.0;
        assert!(
            in_open_range,
            "percentile must be in (0.0, 100.0), got {percentile}"
        );
        Self {
            direction,
            percentile,
            n_min_trials: 1,
            n_warmup_steps: 0,
        }
    }

    /// Sets the number of initial steps during which no trial is pruned.
    ///
    /// Consumes the pruner and returns the updated one (builder style).
    #[must_use]
    pub fn n_warmup_steps(self, n: u64) -> Self {
        Self {
            n_warmup_steps: n,
            ..self
        }
    }

    /// Sets how many reference values from completed trials must exist at a
    /// step before pruning is considered.
    ///
    /// Consumes the pruner and returns the updated one (builder style).
    ///
    /// # Panics
    ///
    /// Panics if `n` is `0`.
    #[must_use]
    pub fn n_min_trials(self, n: usize) -> Self {
        assert!(n >= 1, "n_min_trials must be >= 1, got {n}");
        Self {
            n_min_trials: n,
            ..self
        }
    }
}
impl Pruner for PercentilePruner {
    /// Decides whether the running trial should be stopped at `step`.
    ///
    /// The most recently reported intermediate value (the last element of
    /// `intermediate_values`, whatever step it was recorded at) is compared
    /// with the configured percentile of the values that trials in state
    /// `TrialState::Complete` reported at exactly `step`.
    ///
    /// Returns `false` (never prune) when:
    /// - `step` is still inside the warm-up phase,
    /// - the trial has not reported any intermediate value yet, or
    /// - fewer than `n_min_trials` usable reference values exist at `step`.
    ///
    /// Otherwise the trial is pruned when its value is strictly greater than
    /// the threshold (`Direction::Minimize`) or strictly smaller than it
    /// (`Direction::Maximize`). A NaN current value is never pruned because
    /// every comparison with NaN is false.
    ///
    /// NaN reference values are ignored: they carry no ordering information
    /// and would otherwise make the percentile threshold arbitrary.
    ///
    /// The same percentile is used as the cut-off for both directions, so a
    /// lower percentile prunes more trials when minimizing but fewer when
    /// maximizing. NOTE(review): confirm this asymmetry is intended.
    fn should_prune(
        &self,
        _trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool {
        // Warm-up: give every trial a chance before judging it.
        if step < self.n_warmup_steps {
            return false;
        }
        let Some(&(_, current_value)) = intermediate_values.last() else {
            return false;
        };
        let mut values_at_step: Vec<f64> = completed_trials
            .iter()
            .filter(|t| t.state == TrialState::Complete)
            .filter_map(|t| {
                t.intermediate_values
                    .iter()
                    .find(|(s, _)| *s == step)
                    .map(|(_, v)| *v)
            })
            // Drop NaN so the percentile is computed over comparable values
            // only; NaN cannot be ordered against the rest of the sample.
            .filter(|v| !v.is_nan())
            .collect();
        // `n_min_trials` is always >= 1, so this also guards against calling
        // `compute_percentile` with an empty slice.
        if values_at_step.len() < self.n_min_trials {
            return false;
        }
        let threshold = compute_percentile(&mut values_at_step, self.percentile);
        match self.direction {
            Direction::Minimize => current_value > threshold,
            Direction::Maximize => current_value < threshold,
        }
    }
}
/// Returns the `percentile`-th percentile of `values`, linearly interpolating
/// between the two closest ranks.
///
/// `values` is sorted in place (ascending, using [`f64::total_cmp`]) as a
/// side effect. `total_cmp` is a true total order, so NaN inputs cannot
/// produce an unspecified ordering: positive NaN sorts after every number and
/// negative NaN before every number.
///
/// `percentile` is clamped to `[0.0, 100.0]`, so out-of-range requests yield
/// the minimum or maximum element instead of indexing past the end of the
/// slice. A NaN `percentile` selects the first (smallest) element.
///
/// # Panics
///
/// Panics if `values` is empty.
// The casts convert between a slice index and a fractional rank. The rank is
// clamped to `[0, len - 1]` below, so sign loss and truncation cannot occur,
// and precision loss only matters for slices longer than 2^53 elements.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
pub(crate) fn compute_percentile(values: &mut [f64], percentile: f64) -> f64 {
    assert!(!values.is_empty(), "compute_percentile: empty input");
    values.sort_unstable_by(f64::total_cmp);
    let len = values.len();
    if len == 1 {
        return values[0];
    }
    // Fractional position of the requested percentile within the sorted data.
    let rank = percentile.clamp(0.0, 100.0) / 100.0 * (len - 1) as f64;
    let lower = rank.floor() as usize;
    let upper = rank.ceil() as usize;
    if lower == upper {
        // The rank falls exactly on an element; no interpolation needed.
        values[lower]
    } else {
        let frac = rank - lower as f64;
        values[lower] * (1.0 - frac) + values[upper] * frac
    }
}
// Unit tests for `compute_percentile` and for `PercentilePruner`'s
// construction-time validation and pruning decisions.
#[cfg(test)]
mod tests {
use super::*;
// Odd-length input: the median is the middle element after sorting.
#[test]
fn compute_percentile_median_odd() {
let val = compute_percentile(&mut [3.0, 1.0, 2.0], 50.0);
assert!((val - 2.0).abs() < f64::EPSILON);
}
// Even-length input: the median is interpolated halfway between 2.0 and 3.0.
#[test]
fn compute_percentile_median_even() {
let val = compute_percentile(&mut [4.0, 1.0, 3.0, 2.0], 50.0);
assert!((val - 2.5).abs() < f64::EPSILON);
}
// Rank 0.25 * 3 = 0.75, i.e. three quarters of the way from 1.0 to 2.0.
#[test]
fn compute_percentile_25() {
let val = compute_percentile(&mut [4.0, 1.0, 3.0, 2.0], 25.0);
assert!((val - 1.75).abs() < f64::EPSILON);
}
// Rank 0.75 * 3 = 2.25, i.e. one quarter of the way from 3.0 to 4.0.
#[test]
fn compute_percentile_75() {
let val = compute_percentile(&mut [4.0, 1.0, 3.0, 2.0], 75.0);
assert!((val - 3.25).abs() < f64::EPSILON);
}
// A single element is returned as-is regardless of the percentile.
#[test]
fn compute_percentile_single() {
let val = compute_percentile(&mut [5.0], 50.0);
assert!((val - 5.0).abs() < f64::EPSILON);
}
// The percentile range is open: exactly 0.0 must be rejected.
#[test]
#[should_panic(expected = "percentile must be in (0.0, 100.0)")]
fn new_rejects_zero() {
let _ = PercentilePruner::new(0.0, Direction::Minimize);
}
// The percentile range is open: exactly 100.0 must be rejected.
#[test]
#[should_panic(expected = "percentile must be in (0.0, 100.0)")]
fn new_rejects_hundred() {
let _ = PercentilePruner::new(100.0, Direction::Minimize);
}
// Test helper: builds a `CompletedTrial` through `with_intermediate_values`
// carrying the given `(step, value)` pairs. Every map argument is empty and
// the scalar argument is 0.0 (presumably the final objective value — confirm
// against the constructor). The pruning tests below rely on the resulting
// trial being counted, so the constructor presumably yields
// `TrialState::Complete`.
fn make_completed_trial(id: u64, values: &[(u64, f64)]) -> CompletedTrial {
use std::collections::HashMap;
use crate::parameter::ParamId;
CompletedTrial::with_intermediate_values(
id,
HashMap::<ParamId, crate::parameter::ParamValue>::new(),
HashMap::new(),
HashMap::new(),
0.0,
values.to_vec(),
HashMap::new(),
)
}
// Reference values at step 1 are 2.0, 4.0 and 6.0, so the 50th percentile is
// 4.0: a current value of 5.0 is pruned when minimizing, 3.0 is not.
#[test]
fn percentile_50_matches_median_behavior() {
let pruner = PercentilePruner::new(50.0, Direction::Minimize);
let completed = vec![
make_completed_trial(0, &[(0, 1.0), (1, 2.0)]),
make_completed_trial(1, &[(0, 3.0), (1, 4.0)]),
make_completed_trial(2, &[(0, 5.0), (1, 6.0)]),
];
assert!(pruner.should_prune(3, 1, &[(0, 3.0), (1, 5.0)], &completed));
assert!(!pruner.should_prune(3, 1, &[(0, 3.0), (1, 3.0)], &completed));
}
// With values 1..=4 the 25th percentile is 1.75 and the 75th is 3.25, so a
// current value of 2.5 is pruned by the former but kept by the latter
// (minimization).
#[test]
fn percentile_25_is_more_aggressive() {
let pruner_25 = PercentilePruner::new(25.0, Direction::Minimize);
let pruner_75 = PercentilePruner::new(75.0, Direction::Minimize);
let completed = vec![
make_completed_trial(0, &[(0, 1.0)]),
make_completed_trial(1, &[(0, 2.0)]),
make_completed_trial(2, &[(0, 3.0)]),
make_completed_trial(3, &[(0, 4.0)]),
];
assert!(pruner_25.should_prune(4, 0, &[(0, 2.5)], &completed));
assert!(!pruner_75.should_prune(4, 0, &[(0, 2.5)], &completed));
}
// Step 3 is below the 5 warm-up steps, so even a very bad value is kept.
#[test]
fn warmup_prevents_pruning() {
let pruner = PercentilePruner::new(50.0, Direction::Minimize).n_warmup_steps(5);
let completed = vec![make_completed_trial(0, &[(0, 1.0)])];
assert!(!pruner.should_prune(1, 3, &[(3, 100.0)], &completed));
}
// Only two reference trials exist but five are required, so nothing is pruned.
#[test]
fn n_min_trials_prevents_pruning() {
let pruner = PercentilePruner::new(50.0, Direction::Minimize).n_min_trials(5);
let completed = vec![
make_completed_trial(0, &[(0, 1.0)]),
make_completed_trial(1, &[(0, 2.0)]),
];
assert!(!pruner.should_prune(2, 0, &[(0, 100.0)], &completed));
}
// When maximizing, values below the threshold (median 3.0) are pruned and
// values above it are kept.
#[test]
fn maximize_direction() {
let pruner = PercentilePruner::new(50.0, Direction::Maximize);
let completed = vec![
make_completed_trial(0, &[(0, 1.0)]),
make_completed_trial(1, &[(0, 3.0)]),
make_completed_trial(2, &[(0, 5.0)]),
];
assert!(pruner.should_prune(3, 0, &[(0, 2.0)], &completed));
assert!(!pruner.should_prune(3, 0, &[(0, 4.0)], &completed));
}
// Percentiles close to the ends of the allowed range: the 1st percentile of
// [1, 2, 3, 100] is 1.03 (1.5 is pruned), the 99th is 97.09 (50.0 is kept).
#[test]
fn near_boundary_percentiles() {
let pruner_low = PercentilePruner::new(1.0, Direction::Minimize);
let pruner_high = PercentilePruner::new(99.0, Direction::Minimize);
let completed = vec![
make_completed_trial(0, &[(0, 1.0)]),
make_completed_trial(1, &[(0, 2.0)]),
make_completed_trial(2, &[(0, 3.0)]),
make_completed_trial(3, &[(0, 100.0)]),
];
assert!(pruner_low.should_prune(4, 0, &[(0, 1.5)], &completed));
assert!(!pruner_high.should_prune(4, 0, &[(0, 50.0)], &completed));
}
}