use core::cmp::Ordering;
use super::Pruner;
use crate::sampler::CompletedTrial;
use crate::types::{Direction, TrialState};
/// Pruner that decides whether to stop a running trial by comparing its
/// intermediate values against those of the best completed trial with a
/// one-sided Wilcoxon signed-rank test.
///
/// Configure it with the builder-style methods below; each consumes the
/// pruner and returns the updated value.
pub struct WilcoxonPruner {
    /// Significance level: a trial is pruned when the test's p-value falls
    /// strictly below this threshold.
    p_value_threshold: f64,
    /// Steps strictly below this number are never pruned.
    n_warmup_steps: u64,
    /// Minimum number of completed trials required before pruning is considered.
    n_min_trials: usize,
    /// Optimization direction; selects which completed trial counts as "best"
    /// and which tail of the test is used.
    direction: Direction,
}

impl WilcoxonPruner {
    /// Creates a pruner for the given optimization `direction`.
    ///
    /// Defaults: `p_value_threshold = 0.05`, `n_warmup_steps = 0`,
    /// `n_min_trials = 1`.
    #[must_use]
    pub fn new(direction: Direction) -> Self {
        Self {
            direction,
            p_value_threshold: 0.05,
            n_warmup_steps: 0,
            n_min_trials: 1,
        }
    }

    /// Sets the p-value threshold below which a trial is pruned.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not strictly between `0.0` and `1.0` (this includes NaN).
    #[must_use]
    pub fn p_value_threshold(self, p: f64) -> Self {
        assert!(
            0.0 < p && p < 1.0,
            "p_value_threshold must be in (0.0, 1.0)"
        );
        Self {
            p_value_threshold: p,
            ..self
        }
    }

    /// Sets the number of warmup steps; steps below `n` are never pruned.
    #[must_use]
    pub fn n_warmup_steps(self, n: u64) -> Self {
        Self {
            n_warmup_steps: n,
            ..self
        }
    }

    /// Sets the minimum number of completed trials required before pruning.
    #[must_use]
    pub fn n_min_trials(self, n: usize) -> Self {
        Self {
            n_min_trials: n,
            ..self
        }
    }
}
impl Pruner for WilcoxonPruner {
    /// Returns `true` when the current trial should be stopped early.
    ///
    /// The decision proceeds in stages, each of which can veto pruning:
    ///
    /// 1. `step` must have reached `n_warmup_steps`.
    /// 2. At least `n_min_trials` trials in `completed_trials` must be in
    ///    state `TrialState::Complete`.
    /// 3. The best of those trials (lowest `value` when minimizing, highest
    ///    when maximizing) is taken as the reference.
    /// 4. At least 6 steps must be present in both the current trial's
    ///    `intermediate_values` and the reference's.
    ///
    /// The per-step differences `current - reference` are then fed to the
    /// signed-rank test, and the trial is pruned when the resulting p-value
    /// is strictly below `p_value_threshold`.
    fn should_prune(
        &self,
        _trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool {
        let in_warmup = step < self.n_warmup_steps;
        if in_warmup {
            return false;
        }

        // Only trials that ran to completion are eligible as a reference.
        let finished: Vec<&CompletedTrial> = completed_trials
            .iter()
            .filter(|trial| trial.state == TrialState::Complete)
            .collect();
        if finished.len() < self.n_min_trials {
            return false;
        }

        // Incomparable values (NaN) are treated as equal by the comparator.
        let candidates = finished.iter().copied();
        let reference = match self.direction {
            Direction::Minimize => candidates
                .min_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal)),
            Direction::Maximize => candidates
                .max_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal)),
        };
        let Some(reference) = reference else {
            return false;
        };

        // For every step the current trial reported, look up the first
        // reference entry with the same step and record `current - reference`.
        let differences: Vec<f64> = intermediate_values
            .iter()
            .filter_map(|&(current_step, current_value)| {
                reference
                    .intermediate_values
                    .iter()
                    .find(|(best_step, _)| *best_step == current_step)
                    .map(|&(_, best_value)| current_value - best_value)
            })
            .collect();
        if differences.len() < 6 {
            return false;
        }

        wilcoxon_signed_rank_test(&differences, self.direction) < self.p_value_threshold
    }
}
/// One-sided Wilcoxon signed-rank test using the normal approximation.
///
/// `differences` are paired differences (the caller in this file passes
/// `current - best`). The returned value is the approximate lower-tail
/// p-value of the rank sum that should be *small* when the current trial is
/// worse than the reference:
///
/// * `Direction::Minimize` uses the sum of ranks of negative differences,
/// * `Direction::Maximize` uses the sum of ranks of positive differences.
///
/// Zero differences are discarded, as are non-finite ones (NaN, ±infinity):
/// they cannot be ranked meaningfully and would break the ordering used for
/// sorting. If fewer than 6 usable differences remain, or the tie-adjusted
/// variance is not positive, the function returns `1.0` (no evidence).
fn wilcoxon_signed_rank_test(differences: &[f64], direction: Direction) -> f64 {
    // Smallest sample size for which the test is attempted.
    const MIN_USABLE: usize = 6;

    let usable: Vec<f64> = differences
        .iter()
        .copied()
        .filter(|d| d.is_finite() && *d != 0.0)
        .collect();
    let n = usable.len();
    if n < MIN_USABLE {
        return 1.0;
    }

    // (original index, |d|, d), ordered by magnitude. `total_cmp` is a total
    // order, so the sort can never observe an inconsistent comparator.
    let mut abs_ranked: Vec<(usize, f64, f64)> = usable
        .iter()
        .enumerate()
        .map(|(i, &d)| (i, d.abs(), d))
        .collect();
    abs_ranked.sort_by(|a, b| a.1.total_cmp(&b.1));

    // `ranks[k]` is the (tie-averaged) rank of `abs_ranked[k]`.
    let ranks = assign_ranks(&abs_ranked);

    let mut w_plus = 0.0_f64;
    let mut w_minus = 0.0_f64;
    for (&(_, _, diff), &rank) in abs_ranked.iter().zip(&ranks) {
        if diff > 0.0 {
            w_plus += rank;
        } else {
            w_minus += rank;
        }
    }
    let w = match direction {
        Direction::Minimize => w_minus,
        Direction::Maximize => w_plus,
    };

    // n is a small sample count, so the conversion is exact in practice.
    #[allow(clippy::cast_precision_loss)]
    let n_f = n as f64;
    let mean = n_f * (n_f + 1.0) / 4.0;
    let variance = n_f * (n_f + 1.0) * (2.0 * n_f + 1.0) / 24.0;
    let adjusted_variance = variance - compute_tie_correction(&ranks);
    if adjusted_variance <= 0.0 {
        return 1.0;
    }

    // Lower-tail probability P(W <= w): the continuity correction for a
    // one-sided lower-tail test is always +0.5, whichever side of the mean
    // `w` falls on.
    let z = (w + 0.5 - mean) / adjusted_variance.sqrt();
    normal_cdf(z)
}
/// Assigns 1-based ranks to entries that are already sorted by their second
/// tuple field (the absolute value), giving tied entries the average of the
/// ranks they span.
///
/// Two magnitudes are considered tied when they are exactly equal or differ
/// from the first element of the group by less than a small relative
/// tolerance (`100 * f64::EPSILON * max(anchor, 1.0)`).
///
/// The returned vector is parallel to `sorted`. The function always
/// terminates: every element belongs at least to its own group, so NaN or
/// infinite magnitudes (for which the tolerance comparison is false or
/// undefined) cannot stall the scan.
fn assign_ranks(sorted: &[(usize, f64, f64)]) -> Vec<f64> {
    let n = sorted.len();
    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        let anchor = sorted[start].1;
        let tolerance = f64::EPSILON * anchor.max(1.0) * 100.0;

        // The anchor is always part of its group, so begin scanning at the
        // next element; this guarantees progress even when `anchor - anchor`
        // is NaN (anchor is NaN or infinite).
        let mut end = start + 1;
        while end < n
            && (sorted[end].1 == anchor || (sorted[end].1 - anchor).abs() < tolerance)
        {
            end += 1;
        }

        // Ranks start+1 ..= end are shared; their mean is (start + 1 + end) / 2.
        #[allow(clippy::cast_precision_loss)]
        let avg_rank = (start + 1 + end) as f64 / 2.0;
        for slot in &mut ranks[start..end] {
            *slot = avg_rank;
        }
        start = end;
    }
    ranks
}
/// Computes the variance correction for tied ranks in the signed-rank test.
///
/// `ranks` is scanned for runs of consecutive entries that are equal to the
/// first entry of the run (within `f64::EPSILON`). Each run of length
/// `t > 1` contributes `t^3 - t`; the sum is divided by 48.
///
/// Equal ranks must be adjacent for them to be counted as one tie group.
/// A rank that does not even compare equal to itself (NaN) is treated as a
/// run of length one so the scan always advances.
fn compute_tie_correction(ranks: &[f64]) -> f64 {
    let mut tie_sum = 0.0_f64;
    let mut rest = ranks;
    while let Some(&anchor) = rest.first() {
        // Length of the run anchored at the head of `rest`. `.max(1)` covers
        // a NaN anchor, whose self-comparison fails and would otherwise give
        // a zero-length run and an endless loop.
        let run_len = rest
            .iter()
            .take_while(|&&rank| (rank - anchor).abs() < f64::EPSILON)
            .count()
            .max(1);
        if run_len > 1 {
            #[allow(clippy::cast_precision_loss)]
            let t = run_len as f64;
            tie_sum += t * t * t - t;
        }
        rest = &rest[run_len..];
    }
    tie_sum / 48.0
}
/// Cumulative distribution function of the standard normal distribution,
/// expressed through the complementary error function:
/// `Phi(x) = erfc(-x / sqrt(2)) / 2`.
fn normal_cdf(x: f64) -> f64 {
    let scaled = -x / core::f64::consts::SQRT_2;
    erfc(scaled) * 0.5
}
/// Polynomial approximation of the complementary error function.
///
/// Evaluates a degree-5 polynomial in `t = 1 / (1 + P * |x|)` multiplied by
/// `exp(-x^2)` for the non-negative half-line, and mirrors it with
/// `erfc(-x) = 2 - erfc(x)` for negative arguments.
fn erfc(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Coefficients below the leading one, from highest to lowest power of t.
    const LOWER_COEFFS: [f64; 4] = [
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];
    const LEADING_COEFF: f64 = 1.061_405_429;

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation, innermost term first.
    let nested = LOWER_COEFFS
        .iter()
        .fold(LEADING_COEFF, |acc, &coeff| coeff + t * acc);
    let tail = t * nested * (-(x * x)).exp();

    if x >= 0.0 {
        tail
    } else {
        2.0 - tail
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Builds a completed trial with the given final `value` and per-step
    /// `intermediate_values`; all map arguments are left empty.
    fn trial_with_values(
        id: u64,
        value: f64,
        intermediate_values: Vec<(u64, f64)>,
    ) -> CompletedTrial {
        CompletedTrial::with_intermediate_values(
            id,
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            value,
            intermediate_values,
            HashMap::new(),
        )
    }

    // A clearly worse trial is still kept while `step` is below the warmup limit.
    #[test]
    fn no_prune_during_warmup() {
        let pruner = WilcoxonPruner::new(Direction::Minimize).n_warmup_steps(10);
        let completed = vec![trial_with_values(
            0,
            0.1,
            (0..20).map(|s| (s, 0.1)).collect(),
        )];
        let current: Vec<(u64, f64)> = (0..8).map(|s| (s, 100.0)).collect();
        assert!(!pruner.should_prune(1, 7, &current, &completed));
    }

    // One completed trial is not enough when five are required.
    #[test]
    fn no_prune_with_insufficient_trials() {
        let pruner = WilcoxonPruner::new(Direction::Minimize).n_min_trials(5);
        let completed = vec![trial_with_values(
            0,
            0.1,
            (0..20).map(|s| (s, 0.1)).collect(),
        )];
        let current: Vec<(u64, f64)> = (0..10).map(|s| (s, 100.0)).collect();
        assert!(!pruner.should_prune(1, 9, &current, &completed));
    }

    // Only five shared steps: below the minimum of six paired observations.
    #[test]
    fn no_prune_with_fewer_than_6_pairs() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let completed = vec![trial_with_values(
            0,
            0.1,
            (0..5).map(|s| (s, 0.1)).collect(),
        )];
        let current: Vec<(u64, f64)> = (0..5).map(|s| (s, 100.0)).collect();
        assert!(!pruner.should_prune(1, 4, &current, &completed));
    }

    // Minimizing: the current trial is higher than the best at every step.
    #[test]
    fn prune_when_consistently_worse_minimize() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let best_values: Vec<(u64, f64)> = (0..20).map(|s| (s, 0.1)).collect();
        let completed = vec![trial_with_values(0, 0.1, best_values)];
        let current: Vec<(u64, f64)> = (0..20).map(|s| (s, 10.0)).collect();
        assert!(pruner.should_prune(1, 19, &current, &completed));
    }

    // Maximizing: the current trial is lower than the best at every step.
    #[test]
    fn prune_when_consistently_worse_maximize() {
        let pruner = WilcoxonPruner::new(Direction::Maximize);
        let best_values: Vec<(u64, f64)> = (0..20).map(|s| (s, 10.0)).collect();
        let completed = vec![trial_with_values(0, 10.0, best_values)];
        let current: Vec<(u64, f64)> = (0..20).map(|s| (s, 0.1)).collect();
        assert!(pruner.should_prune(1, 19, &current, &completed));
    }

    // Differences alternate in sign with equal magnitude, so neither trial
    // is systematically better.
    #[test]
    fn no_prune_when_statistically_similar() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let best_values: Vec<(u64, f64)> = (0..20_u64)
            .map(|s| {
                let noise = if s.is_multiple_of(2) { 0.01 } else { -0.01 };
                (s, 1.0 + noise)
            })
            .collect();
        let completed = vec![trial_with_values(0, 1.0, best_values)];
        let current: Vec<(u64, f64)> = (0..20_u64)
            .map(|s| {
                let noise = if s.is_multiple_of(2) { -0.01 } else { 0.01 };
                (s, 1.0 + noise)
            })
            .collect();
        assert!(!pruner.should_prune(1, 19, &current, &completed));
    }

    // The current trial equals the worse completed trial; it is pruned only
    // if the comparison is made against the lower-valued (best) one.
    #[test]
    fn selects_best_trial_minimize() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let completed = vec![
            trial_with_values(0, 0.1, (0..20).map(|s| (s, 0.1)).collect()),
            trial_with_values(1, 5.0, (0..20).map(|s| (s, 5.0)).collect()),
        ];
        let current: Vec<(u64, f64)> = (0..20).map(|s| (s, 5.0)).collect();
        assert!(pruner.should_prune(2, 19, &current, &completed));
    }

    // Mirror of the previous test for maximization.
    #[test]
    fn selects_best_trial_maximize() {
        let pruner = WilcoxonPruner::new(Direction::Maximize);
        let completed = vec![
            trial_with_values(0, 0.1, (0..20).map(|s| (s, 0.1)).collect()),
            trial_with_values(1, 10.0, (0..20).map(|s| (s, 10.0)).collect()),
        ];
        let current: Vec<(u64, f64)> = (0..20).map(|s| (s, 0.1)).collect();
        assert!(pruner.should_prune(2, 19, &current, &completed));
    }

    // A trial in state `Pruned` is not used as a reference.
    #[test]
    fn ignores_pruned_trials() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let mut trial = trial_with_values(0, 0.1, (0..20).map(|s| (s, 0.1)).collect());
        trial.state = TrialState::Pruned;
        let completed = vec![trial];
        let current: Vec<(u64, f64)> = (0..20).map(|s| (s, 100.0)).collect();
        assert!(!pruner.should_prune(1, 19, &current, &completed));
    }

    // If the lenient threshold does not prune, the strict one must not either.
    #[test]
    fn lower_p_value_is_more_conservative() {
        let strict = WilcoxonPruner::new(Direction::Minimize).p_value_threshold(0.001);
        let lenient = WilcoxonPruner::new(Direction::Minimize).p_value_threshold(0.1);
        let completed = vec![trial_with_values(
            0,
            0.1,
            (0..20).map(|s| (s, 0.1)).collect(),
        )];
        let current: Vec<(u64, f64)> = (0..20)
            .map(|s| if s < 15 { (s, 0.2) } else { (s, 0.15) })
            .collect();
        let lenient_prunes = lenient.should_prune(1, 19, &current, &completed);
        let strict_prunes = strict.should_prune(1, 19, &current, &completed);
        if !lenient_prunes {
            assert!(!strict_prunes);
        }
    }

    #[test]
    #[should_panic(expected = "p_value_threshold must be in (0.0, 1.0)")]
    fn panics_on_zero_p_value() {
        let _ = WilcoxonPruner::new(Direction::Minimize).p_value_threshold(0.0);
    }

    #[test]
    #[should_panic(expected = "p_value_threshold must be in (0.0, 1.0)")]
    fn panics_on_one_p_value() {
        let _ = WilcoxonPruner::new(Direction::Minimize).p_value_threshold(1.0);
    }

    // All differences positive: the negative rank sum is zero, which is
    // strong evidence when minimizing.
    #[test]
    fn correct_signed_rank_statistic() {
        let diffs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = wilcoxon_signed_rank_test(&diffs, Direction::Minimize);
        assert!(
            p < 0.05,
            "p-value {p} should be < 0.05 for all-positive diffs"
        );
    }

    // Perfectly mirrored differences put the statistic at its mean.
    #[test]
    fn symmetric_differences_not_significant() {
        let diffs = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0];
        let p = wilcoxon_signed_rank_test(&diffs, Direction::Minimize);
        assert!(p > 0.05, "p-value {p} should be > 0.05 for symmetric diffs");
    }

    #[test]
    fn normal_cdf_known_values() {
        assert!((normal_cdf(0.0) - 0.5).abs() < 1e-6);
        assert!(normal_cdf(-10.0) < 1e-6);
        assert!((normal_cdf(10.0) - 1.0).abs() < 1e-6);
        assert!((normal_cdf(-1.96) - 0.025).abs() < 0.001);
    }

    // Without any intermediate values there is nothing to pair.
    #[test]
    fn no_intermediate_values() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let completed = vec![trial_with_values(
            0,
            0.1,
            (0..20).map(|s| (s, 0.1)).collect(),
        )];
        assert!(!pruner.should_prune(1, 0, &[], &completed));
    }

    // Without completed trials there is no reference to compare against.
    #[test]
    fn no_completed_trials() {
        let pruner = WilcoxonPruner::new(Direction::Minimize);
        let current: Vec<(u64, f64)> = (0..20).map(|s| (s, 1.0)).collect();
        assert!(!pruner.should_prune(1, 19, &current, &[]));
    }
}