use std::collections::HashMap;
use std::sync::Mutex;
use super::Pruner;
use crate::sampler::CompletedTrial;
/// A pruner decorator that adds "patience" to another [`Pruner`].
///
/// The wrapped pruner is consulted on every call; its prune recommendation is
/// only passed through once it has been given `patience` times in a row for
/// the same trial. Any non-prune verdict resets that trial's streak.
pub struct PatientPruner {
    /// The wrapped pruner whose verdicts are being rate-limited.
    inner: Box<dyn Pruner>,
    /// Number of consecutive prune recommendations required before this
    /// pruner itself reports `true`.
    patience: u64,
    /// Per-trial streak of consecutive prune recommendations, keyed by trial
    /// id. Guarded by a mutex because `should_prune` only receives `&self`.
    consecutive_counts: Mutex<HashMap<u64, u64>>,
}
impl PatientPruner {
pub fn new(inner: impl Pruner + 'static, patience: u64) -> Self {
Self {
inner: Box::new(inner),
patience,
consecutive_counts: Mutex::new(HashMap::new()),
}
}
}
impl Pruner for PatientPruner {
    /// Reports whether `trial_id` should be pruned, honoring the inner
    /// pruner's recommendation only after it has been given `patience` times
    /// in a row for that trial.
    ///
    /// All arguments are forwarded unchanged to the inner pruner. Each prune
    /// recommendation extends the trial's streak; any non-prune verdict
    /// clears it. Once the streak has reached `patience`, this keeps
    /// returning `true` for as long as the inner pruner keeps recommending
    /// a prune.
    ///
    /// A `patience` of `0` behaves like `1`: the first prune recommendation
    /// is passed straight through.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    fn should_prune(
        &self,
        trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool {
        // Ask the inner pruner first so the mutex is never held across its call.
        let inner_says_prune =
            self.inner
                .should_prune(trial_id, step, intermediate_values, completed_trials);

        let mut counts = self.consecutive_counts.lock().expect("lock poisoned");

        if !inner_says_prune {
            // A missing entry already means "streak of zero", so drop the
            // entry instead of storing an explicit 0. This keeps the map from
            // accumulating one entry per trial that is merely checked.
            counts.remove(&trial_id);
            return false;
        }

        let streak = counts.entry(trial_id).or_insert(0);
        // Saturate rather than overflow; the comparison below is unaffected.
        *streak = streak.saturating_add(1);
        *streak >= self.patience
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pruner::ThresholdPruner;

    /// Inner-pruner stub that gives the same verdict on every call.
    struct ConstPruner(bool);

    impl Pruner for ConstPruner {
        fn should_prune(
            &self,
            _trial_id: u64,
            _step: u64,
            _intermediate_values: &[(u64, f64)],
            _completed_trials: &[CompletedTrial],
        ) -> bool {
            let ConstPruner(verdict) = self;
            *verdict
        }
    }

    /// Inner-pruner stub that replays a scripted list of verdicts, consuming
    /// one from the front on each call.
    struct SequencePruner(Mutex<Vec<bool>>);

    impl Pruner for SequencePruner {
        fn should_prune(
            &self,
            _trial_id: u64,
            _step: u64,
            _intermediate_values: &[(u64, f64)],
            _completed_trials: &[CompletedTrial],
        ) -> bool {
            let mut script = self.0.lock().expect("lock poisoned");
            script.remove(0)
        }
    }

    /// Invokes `pruner` for `trial_id` at `step` with a single dummy
    /// intermediate value and no completed trials.
    fn call(pruner: &PatientPruner, trial_id: u64, step: u64) -> bool {
        let values = [(step, 0.0)];
        pruner.should_prune(trial_id, step, &values, &[])
    }

    #[test]
    fn patience_1_behaves_like_inner() {
        let always = PatientPruner::new(ConstPruner(true), 1);
        assert!(call(&always, 0, 0));
        assert!(call(&always, 0, 1));

        let never = PatientPruner::new(ConstPruner(false), 1);
        assert!(!call(&never, 0, 0));
        assert!(!call(&never, 0, 1));
    }

    #[test]
    fn patience_3_requires_consecutive_recommendations() {
        let patient = PatientPruner::new(ConstPruner(true), 3);

        // Only the third consecutive recommendation is honored.
        let expected = [false, false, true];
        for (step, want) in expected.iter().enumerate() {
            assert_eq!(call(&patient, 0, step as u64), *want);
        }
    }

    #[test]
    fn counter_resets_on_no_prune() {
        let script = vec![true, true, false, true, true, true];
        let patient = PatientPruner::new(SequencePruner(Mutex::new(script)), 3);

        // The `false` at step 2 breaks the streak, so three fresh
        // recommendations (steps 3..=5) are needed before pruning.
        let expected = [false, false, false, false, false, true];
        for (step, want) in expected.iter().enumerate() {
            assert_eq!(call(&patient, 0, step as u64), *want);
        }
    }

    #[test]
    fn independent_per_trial() {
        let patient = PatientPruner::new(ConstPruner(true), 2);

        // Calls for different trials are interleaved; each trial keeps its
        // own streak.
        assert!(!call(&patient, 0, 0));
        assert!(!call(&patient, 1, 0));
        assert!(call(&patient, 0, 1));
        assert!(!call(&patient, 2, 0));
        assert!(call(&patient, 1, 1));
    }

    #[test]
    fn works_with_threshold_pruner() {
        let patient = PatientPruner::new(ThresholdPruner::new().upper(10.0), 2);

        // These assertions expect the threshold pruner to recommend pruning
        // at steps 1 and 2 but not at step 0, so with patience 2 the
        // wrapper first prunes at step 2.
        assert!(!patient.should_prune(0, 0, &[(0, 5.0)], &[]));
        assert!(!patient.should_prune(0, 1, &[(0, 5.0), (1, 15.0)], &[]));
        assert!(patient.should_prune(0, 2, &[(0, 5.0), (1, 15.0), (2, 20.0)], &[]));
    }
}