use std::collections::HashMap;
use optimizer::Direction;
use optimizer::pruner::{MedianPruner, Pruner};
use optimizer::sampler::CompletedTrial;
// Test fixture: builds a `CompletedTrial` whose only interesting data are its
// id and its (step, value) intermediate curve. The four map-valued arguments
// are all left empty and the scalar argument is fixed at 0.0 (presumably the
// trial's final objective value — confirm against `with_intermediate_values`).
fn trial_with_values(id: u64, intermediate_values: Vec<(u64, f64)>) -> CompletedTrial {
    let final_value = 0.0;
    CompletedTrial::with_intermediate_values(
        id,
        HashMap::new(),
        HashMap::new(),
        HashMap::new(),
        final_value,
        intermediate_values,
        HashMap::new(),
    )
}
#[test]
fn prune_when_worse_than_median_minimize() {
    // At step 2 the completed trials reported 1.0, 2.0 and 3.0, so their
    // median is 2.0. When minimizing, the running trial's 2.5 is worse than
    // that median and is expected to be pruned.
    let pruner = MedianPruner::new(Direction::Minimize);
    let completed = vec![
        trial_with_values(0, vec![(0, 0.5), (1, 0.8), (2, 1.0)]),
        trial_with_values(1, vec![(0, 0.6), (1, 1.5), (2, 2.0)]),
        trial_with_values(2, vec![(0, 0.7), (1, 2.0), (2, 3.0)]),
    ];
    let current = vec![(0, 0.5), (1, 1.0), (2, 2.5)];
    assert!(pruner.should_prune(3, 2, &current, &completed));
}
#[test]
fn no_prune_when_better_than_median_minimize() {
    // Same history as the pruning case: the step-2 median is 2.0. When
    // minimizing, the running trial's 1.5 beats the median, so it must not be
    // pruned.
    let pruner = MedianPruner::new(Direction::Minimize);
    let completed = vec![
        trial_with_values(0, vec![(0, 0.5), (1, 0.8), (2, 1.0)]),
        trial_with_values(1, vec![(0, 0.6), (1, 1.5), (2, 2.0)]),
        trial_with_values(2, vec![(0, 0.7), (1, 2.0), (2, 3.0)]),
    ];
    let current = vec![(0, 0.5), (1, 1.0), (2, 1.5)];
    assert!(!pruner.should_prune(3, 2, &current, &completed));
}
#[test]
fn prune_when_worse_than_median_maximize() {
    // At step 1 the completed trials reported 5.0, 7.0 and 9.0, so the median
    // is 7.0. When maximizing, the running trial's 6.0 is below the median and
    // is expected to be pruned.
    let pruner = MedianPruner::new(Direction::Maximize);
    let completed = vec![
        trial_with_values(0, vec![(0, 3.0), (1, 5.0)]),
        trial_with_values(1, vec![(0, 4.0), (1, 7.0)]),
        trial_with_values(2, vec![(0, 5.0), (1, 9.0)]),
    ];
    let current = vec![(0, 4.0), (1, 6.0)];
    assert!(pruner.should_prune(3, 1, &current, &completed));
}
#[test]
fn no_prune_when_better_than_median_maximize() {
    // Step-1 median is 7.0. When maximizing, the running trial's 8.0 is above
    // the median, so it must not be pruned.
    let pruner = MedianPruner::new(Direction::Maximize);
    let completed = vec![
        trial_with_values(0, vec![(0, 3.0), (1, 5.0)]),
        trial_with_values(1, vec![(0, 4.0), (1, 7.0)]),
        trial_with_values(2, vec![(0, 5.0), (1, 9.0)]),
    ];
    let current = vec![(0, 4.0), (1, 8.0)];
    assert!(!pruner.should_prune(3, 1, &current, &completed));
}
#[test]
fn no_prune_during_warmup() {
    // With n_warmup_steps(5), step 2 is still inside the warm-up period. The
    // running trial's 100.0 is far worse than the completed trial's 1.0, but
    // it must not be pruned yet.
    let pruner = MedianPruner::new(Direction::Minimize).n_warmup_steps(5);
    let completed = vec![trial_with_values(0, vec![(0, 1.0), (1, 1.0), (2, 1.0)])];
    let current = vec![(0, 100.0), (1, 100.0), (2, 100.0)];
    assert!(!pruner.should_prune(1, 2, &current, &completed));
}
#[test]
fn prune_after_warmup() {
    // With n_warmup_steps(2), step 2 is already eligible for pruning.
    // Together with `no_prune_during_warmup`, this pins the warm-up boundary.
    // The running trial's 100.0 is worse than the median 1.0, so it is pruned.
    let pruner = MedianPruner::new(Direction::Minimize).n_warmup_steps(2);
    let completed = vec![trial_with_values(0, vec![(0, 1.0), (1, 1.0), (2, 1.0)])];
    let current = vec![(0, 100.0), (1, 100.0), (2, 100.0)];
    assert!(pruner.should_prune(1, 2, &current, &completed));
}
#[test]
fn no_prune_when_fewer_than_n_min_trials() {
    // Only two completed trials reported a step-0 value, fewer than the
    // n_min_trials(3) threshold. The median is too poorly supported, so even
    // an extreme 100.0 must not be pruned.
    let pruner = MedianPruner::new(Direction::Minimize).n_min_trials(3);
    let completed = vec![
        trial_with_values(0, vec![(0, 1.0)]),
        trial_with_values(1, vec![(0, 2.0)]),
    ];
    let current = vec![(0, 100.0)];
    assert!(!pruner.should_prune(2, 0, &current, &completed));
}
#[test]
fn prune_when_at_least_n_min_trials() {
    // Exactly n_min_trials(3) completed trials reported step 0 (1.0, 2.0,
    // 3.0 → median 2.0). The threshold is met, so the running trial's 5.0 is
    // pruned.
    let pruner = MedianPruner::new(Direction::Minimize).n_min_trials(3);
    let completed = vec![
        trial_with_values(0, vec![(0, 1.0)]),
        trial_with_values(1, vec![(0, 2.0)]),
        trial_with_values(2, vec![(0, 3.0)]),
    ];
    let current = vec![(0, 5.0)];
    assert!(pruner.should_prune(3, 0, &current, &completed));
}
#[test]
fn no_prune_when_no_completed_trials_at_step() {
    // The completed trials only reported step 0, while the running trial is
    // asking about step 5. No completed trial has a value at that step, so
    // there is no median to compare against and the trial must not be pruned.
    let pruner = MedianPruner::new(Direction::Minimize);
    let completed = vec![
        trial_with_values(0, vec![(0, 1.0)]),
        trial_with_values(1, vec![(0, 2.0)]),
    ];
    let current = vec![(0, 0.5), (5, 100.0)];
    assert!(!pruner.should_prune(2, 5, &current, &completed));
}
#[test]
fn correct_median_with_even_number_of_trials() {
    // Four step-0 values (1.0, 2.0, 3.0, 4.0) have median 2.5, the mean of the
    // two middle values. The probes 2.6 and 2.4 pin that definition:
    // - taking the lower middle (2.0) would prune 2.4;
    // - taking the upper middle (3.0) would keep 2.6.
    let pruner = MedianPruner::new(Direction::Minimize);
    let completed = vec![
        trial_with_values(0, vec![(0, 1.0)]),
        trial_with_values(1, vec![(0, 2.0)]),
        trial_with_values(2, vec![(0, 3.0)]),
        trial_with_values(3, vec![(0, 4.0)]),
    ];
    let current = vec![(0, 2.6)];
    assert!(pruner.should_prune(4, 0, &current, &completed));
    let current = vec![(0, 2.4)];
    assert!(!pruner.should_prune(4, 0, &current, &completed));
}
#[test]
fn correct_median_with_odd_number_of_trials() {
    // Five step-0 values (1.0..=5.0) have median 3.0, the single middle
    // element. When minimizing, 3.5 is worse and pruned; 2.5 is better and
    // kept.
    let pruner = MedianPruner::new(Direction::Minimize);
    let completed = vec![
        trial_with_values(0, vec![(0, 1.0)]),
        trial_with_values(1, vec![(0, 2.0)]),
        trial_with_values(2, vec![(0, 3.0)]),
        trial_with_values(3, vec![(0, 4.0)]),
        trial_with_values(4, vec![(0, 5.0)]),
    ];
    let current = vec![(0, 3.5)];
    assert!(pruner.should_prune(5, 0, &current, &completed));
    let current = vec![(0, 2.5)];
    assert!(!pruner.should_prune(5, 0, &current, &completed));
}
#[test]
fn works_with_non_contiguous_steps() {
    // Steps are sparse (0, 10, 100), so the pruner must match values by step
    // number rather than by position. At step 100 the completed trials
    // reported 3.0, 4.0 and 5.0 (median 4.0): 4.5 is pruned and 3.5 is kept.
    let pruner = MedianPruner::new(Direction::Minimize);
    let completed = vec![
        trial_with_values(0, vec![(0, 1.0), (10, 2.0), (100, 3.0)]),
        trial_with_values(1, vec![(0, 1.5), (10, 2.5), (100, 4.0)]),
        trial_with_values(2, vec![(0, 2.0), (10, 3.0), (100, 5.0)]),
    ];
    let current = vec![(0, 1.0), (10, 2.0), (100, 4.5)];
    assert!(pruner.should_prune(3, 100, &current, &completed));
    let current = vec![(0, 1.0), (10, 2.0), (100, 3.5)];
    assert!(!pruner.should_prune(3, 100, &current, &completed));
}
#[test]
fn no_prune_when_no_intermediate_values() {
    // The running trial hands over an empty list of intermediate values:
    // with nothing reported there is nothing to compare against the history,
    // so the pruner must let it continue.
    let history = vec![trial_with_values(0, vec![(0, 1.0)])];
    let pruner = MedianPruner::new(Direction::Minimize);
    let verdict = pruner.should_prune(1, 0, &[], &history);
    assert!(!verdict);
}
#[test]
fn pruned_trials_excluded_from_median() {
    use optimizer::TrialState;
    // The history holds one Pruned trial (step 0 = 0.1) and one regular trial
    // (step 0 = 5.0).
    // - If the pruned trial counted, the median would be (0.1 + 5.0) / 2 =
    //   2.55 and 3.0 would be pruned.
    // - Excluding it leaves a median of 5.0, so 3.0 is kept and only 6.0 is
    //   pruned.
    let pruner = MedianPruner::new(Direction::Minimize);
    let mut pruned = trial_with_values(0, vec![(0, 0.1)]);
    pruned.state = TrialState::Pruned;
    let completed = vec![pruned, trial_with_values(1, vec![(0, 5.0)])];
    let current = vec![(0, 3.0)];
    assert!(!pruner.should_prune(2, 0, &current, &completed));
    let current = vec![(0, 6.0)];
    assert!(pruner.should_prune(2, 0, &current, &completed));
}