//! Pruner trait and implementations for trial pruning.
//!
//! Pruners decide whether to stop (prune) a trial early based on its
//! intermediate values, usually compared against other trials. This lets an
//! optimizer discard unpromising trials before they complete, saving compute.
//!
//! # How pruning works
//!
//! During optimization, each trial reports intermediate values at discrete
//! steps (e.g., validation loss after each training epoch). A pruner inspects
//! these values, typically comparing them against completed trials, to decide
//! whether the current trial should be stopped early.
//!
//! The typical flow is:
//!
//! 1. Call [`Trial::report`](crate::Trial::report) to record an intermediate value.
//! 2. Call [`Trial::should_prune`](crate::Trial::should_prune) to check the pruner's decision.
//! 3. If the pruner says prune, return [`TrialPruned`](crate::TrialPruned) from the objective.
//!
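//! The same three steps can be simulated without a `Study` by calling a
//! pruner directly. This is only a sketch: `DivergencePruner` and the loss
//! values are hypothetical, and the loop stands in for an objective rather
//! than using the real `Trial` API.
//!
//! ```
//! use optimizer::pruner::Pruner;
//! use optimizer::sampler::CompletedTrial;
//!
//! // Hypothetical pruner: stop once the loss rises above 1.0.
//! struct DivergencePruner;
//!
//! impl Pruner for DivergencePruner {
//!     fn should_prune(
//!         &self,
//!         _trial_id: u64,
//!         _step: u64,
//!         intermediate_values: &[(u64, f64)],
//!         _completed_trials: &[CompletedTrial],
//!     ) -> bool {
//!         intermediate_values.last().is_some_and(|&(_, v)| v > 1.0)
//!     }
//! }
//!
//! let pruner = DivergencePruner;
//! let losses = [0.9, 0.7, 0.8, 1.4, 2.0]; // hypothetical per-epoch losses
//! let mut reported: Vec<(u64, f64)> = Vec::new();
//! let mut pruned_at = None;
//!
//! for (step, &loss) in losses.iter().enumerate() {
//!     let step = step as u64;
//!     reported.push((step, loss)); // 1. record the intermediate value
//!     // 2. ask the pruner; this toy run has no completed trials
//!     if pruner.should_prune(0, step, &reported, &[]) {
//!         pruned_at = Some(step); // 3. the objective would return TrialPruned here
//!         break;
//!     }
//! }
//! assert_eq!(pruned_at, Some(3));
//! ```
//!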
//! # Available pruners
//!
//! | Pruner | Algorithm | Best for |
//! |--------|-----------|----------|
//! | [`MedianPruner`] | Prune if worse than the median of completed trials at the same step | General-purpose default |
//! | [`PercentilePruner`] | Prune if worse than a configurable percentile | Tunable aggressiveness |
//! | [`ThresholdPruner`] | Prune outside fixed bounds | Known divergence limits |
//! | [`PatientPruner`] | Require N consecutive prune signals from a wrapped pruner | Noisy intermediate values |
//! | [`SuccessiveHalvingPruner`] | Keep top 1/η fraction at each rung | Budget-aware pruning |
//! | [`HyperbandPruner`] | Multiple successive-halving brackets with different budgets | Robust to budget choice |
//! | [`WilcoxonPruner`] | Statistical signed-rank test vs. best trial | Statistically rigorous pruning of noisy values |
//! | [`NopPruner`] | Never prune | Disabling pruning explicitly |
//!
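//! The median rule in the first row can be sketched as a free function.
//! `worse_than_median` is a hypothetical helper. It assumes minimization and
//! that the caller has already gathered the values other trials reported at
//! the same step. [`MedianPruner`] itself may handle ties, missing steps, and
//! maximization differently.
//!
//! ```
//! // Hypothetical helper: true if `current` is worse (higher, assuming
//! // minimization) than the median of `peers`, the values other trials
//! // reported at the same step.
//! fn worse_than_median(current: f64, peers: &[f64]) -> bool {
//!     if peers.is_empty() {
//!         return false; // no baseline yet, so never prune
//!     }
//!     let mut sorted = peers.to_vec();
//!     sorted.sort_by(|a, b| a.total_cmp(b));
//!     let mid = sorted.len() / 2;
//!     let median = if sorted.len() % 2 == 0 {
//!         (sorted[mid - 1] + sorted[mid]) / 2.0
//!     } else {
//!         sorted[mid]
//!     };
//!     current > median
//! }
//!
//! assert!(worse_than_median(0.9, &[0.4, 0.5, 0.8])); // median is 0.5
//! assert!(!worse_than_median(0.3, &[0.4, 0.5, 0.8]));
//! assert!(!worse_than_median(0.45, &[0.4, 0.6])); // even count: median is 0.5
//! assert!(!worse_than_median(0.9, &[])); // empty history never prunes
//! ```
//!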
//! # When to use pruning
//!
//! Pruning is most beneficial when:
//!
//! - The objective function has a natural notion of "steps" (e.g., training epochs)
//! - Early steps are informative about final performance
//! - Trials are expensive enough that stopping bad ones early saves significant time
//!
//! Start with [`MedianPruner`] for most use cases. Switch to [`WilcoxonPruner`]
//! if your intermediate values are noisy, or to [`HyperbandPruner`] if you want
//! automatic budget scheduling.
//!
//! # Stateful vs stateless pruners
//!
//! **Stateless** pruners make their decision purely from the arguments passed
//! to [`Pruner::should_prune`]; they hold no mutable per-trial state.
//! [`MedianPruner`], [`PercentilePruner`], [`ThresholdPruner`],
//! [`WilcoxonPruner`], and [`NopPruner`] are all stateless.
//!
//! **Stateful** pruners track information across calls. [`PatientPruner`]
//! uses `Mutex<HashMap<u64, u64>>` to count consecutive prune signals per
//! trial. [`HyperbandPruner`] uses `Mutex` and `AtomicU64` for bracket
//! assignment state. When writing a stateful pruner, wrap mutable state in a
//! `Mutex` and key it by `trial_id` to keep trials independent.
//!
//! # Cold start and warmup
//!
//! Two builder parameters control when pruning begins:
//!
//! - **`n_warmup_steps`**: skip pruning before step N *within a trial*,
//!   giving the objective time to stabilize.
//! - **`n_min_trials`**: require N completed trials before pruning any trial,
//!   ensuring a meaningful comparison baseline.
//!
//! See [`MedianPruner`] for the canonical implementation of both parameters.
//! Custom pruners should expose similar knobs when applicable.
//!
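//! A custom pruner can honor both knobs with two early returns before its
//! main rule. `GatedThresholdPruner` below is a hypothetical sketch, not part
//! of this crate, and it approximates "completed trials" with
//! `completed_trials.len()`.
//!
//! ```
//! use optimizer::pruner::Pruner;
//! use optimizer::sampler::CompletedTrial;
//!
//! // Hypothetical threshold pruner exposing both cold-start knobs.
//! struct GatedThresholdPruner {
//!     upper: f64,
//!     n_warmup_steps: u64,
//!     n_min_trials: usize,
//! }
//!
//! impl Pruner for GatedThresholdPruner {
//!     fn should_prune(
//!         &self,
//!         _trial_id: u64,
//!         step: u64,
//!         intermediate_values: &[(u64, f64)],
//!         completed_trials: &[CompletedTrial],
//!     ) -> bool {
//!         // Within-trial warmup: give the objective time to stabilize.
//!         if step < self.n_warmup_steps {
//!             return false;
//!         }
//!         // Cross-trial cold start: wait for a baseline of finished trials.
//!         if completed_trials.len() < self.n_min_trials {
//!             return false;
//!         }
//!         intermediate_values.last().is_some_and(|&(_, v)| v > self.upper)
//!     }
//! }
//!
//! let no_history: &[CompletedTrial] = &[];
//! let warm = GatedThresholdPruner { upper: 1.0, n_warmup_steps: 5, n_min_trials: 0 };
//! assert!(!warm.should_prune(0, 2, &[(2, 5.0)], no_history)); // still warming up
//! assert!(warm.should_prune(0, 5, &[(5, 5.0)], no_history)); // warmup over
//!
//! let cold = GatedThresholdPruner { upper: 1.0, n_warmup_steps: 0, n_min_trials: 3 };
//! assert!(!cold.should_prune(0, 5, &[(5, 5.0)], no_history)); // too few completed trials
//! ```
//!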
//! # Composing pruners
//!
//! [`PatientPruner`] demonstrates the decorator pattern: it wraps any
//! `Box<dyn Pruner>` and adds patience logic on top. Custom pruners can use
//! the same pattern to layer multiple pruning conditions, for example
//! combining a statistical test with a hard threshold.
//!
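//! The following is a minimal sketch of that layering. It uses two
//! hypothetical pruners: `NonFinite`, which prunes on NaN or infinity, and
//! `Above`, which enforces a hard upper limit. `AnyOf`, also hypothetical,
//! wraps both as `Box<dyn Pruner>` and prunes when either one does. The
//! combination matters here because `NaN > limit` is always `false`.
//!
//! ```
//! use optimizer::pruner::Pruner;
//! use optimizer::sampler::CompletedTrial;
//!
//! // Hypothetical layer: prune when the latest value is NaN or infinite.
//! struct NonFinite;
//!
//! impl Pruner for NonFinite {
//!     fn should_prune(&self, _: u64, _: u64, values: &[(u64, f64)], _: &[CompletedTrial]) -> bool {
//!         values.last().is_some_and(|&(_, v)| !v.is_finite())
//!     }
//! }
//!
//! // Hypothetical layer: prune when the latest value exceeds a fixed limit.
//! struct Above(f64);
//!
//! impl Pruner for Above {
//!     fn should_prune(&self, _: u64, _: u64, values: &[(u64, f64)], _: &[CompletedTrial]) -> bool {
//!         values.last().is_some_and(|&(_, v)| v > self.0)
//!     }
//! }
//!
//! // Decorator: prune if either wrapped pruner says so. `||` short-circuits,
//! // so a stateful `second` is not consulted when `first` already prunes.
//! struct AnyOf {
//!     first: Box<dyn Pruner>,
//!     second: Box<dyn Pruner>,
//! }
//!
//! impl Pruner for AnyOf {
//!     fn should_prune(
//!         &self,
//!         trial_id: u64,
//!         step: u64,
//!         values: &[(u64, f64)],
//!         completed: &[CompletedTrial],
//!     ) -> bool {
//!         self.first.should_prune(trial_id, step, values, completed)
//!             || self.second.should_prune(trial_id, step, values, completed)
//!     }
//! }
//!
//! let combined = AnyOf { first: Box::new(NonFinite), second: Box::new(Above(1.0)) };
//! assert!(combined.should_prune(0, 3, &[(3, f64::NAN)], &[])); // caught by NonFinite
//! assert!(combined.should_prune(0, 3, &[(3, 2.0)], &[])); // caught by Above
//! assert!(!combined.should_prune(0, 3, &[(3, 0.5)], &[])); // neither fires
//! ```
//!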
//! # Thread safety
//!
//! The [`Pruner`] trait requires `Send + Sync`.
//! [`Study`](crate::Study) stores the pruner as `Arc<dyn Pruner>`, so
//! multiple threads may call [`Pruner::should_prune`] concurrently.
//!
//! - **Stateless pruners** whose fields are plain data (numbers, flags)
//!   satisfy `Send + Sync` automatically.
//! - **Stateful pruners** should use `std::sync::Mutex` or
//!   `parking_lot::Mutex` to protect mutable state, keyed by `trial_id`.
//!
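//! The sketch below shares one stateful pruner across threads as described
//! above: behind an `Arc<dyn Pruner>`, with state in a `Mutex` keyed by
//! `trial_id`. `CheckBudget` is hypothetical. It prunes a trial on its
//! `limit`-th check so the per-trial counts are easy to verify.
//!
//! ```
//! use std::collections::HashMap;
//! use std::sync::{Arc, Mutex};
//! use std::thread;
//! use optimizer::pruner::Pruner;
//! use optimizer::sampler::CompletedTrial;
//!
//! // Hypothetical stateful pruner: prune a trial on its `limit`-th check.
//! struct CheckBudget {
//!     limit: u64,
//!     checks: Mutex<HashMap<u64, u64>>, // keyed by trial_id
//! }
//!
//! impl Pruner for CheckBudget {
//!     fn should_prune(&self, trial_id: u64, _: u64, _: &[(u64, f64)], _: &[CompletedTrial]) -> bool {
//!         let mut checks = self.checks.lock().unwrap();
//!         let count = checks.entry(trial_id).or_insert(0);
//!         *count += 1;
//!         *count >= self.limit
//!     }
//! }
//!
//! // One shared pruner, held the same way a Study holds it.
//! let pruner: Arc<dyn Pruner> = Arc::new(CheckBudget {
//!     limit: 3,
//!     checks: Mutex::new(HashMap::new()),
//! });
//!
//! let handles: Vec<_> = (0..4u64)
//!     .map(|trial_id| {
//!         let pruner = Arc::clone(&pruner);
//!         // Each thread drives its own trial concurrently with the others.
//!         thread::spawn(move || {
//!             (0..3u64)
//!                 .map(|step| pruner.should_prune(trial_id, step, &[], &[]))
//!                 .collect::<Vec<bool>>()
//!         })
//!     })
//!     .collect();
//!
//! // Counts are per trial, so every thread sees the same decision sequence.
//! for handle in handles {
//!     assert_eq!(handle.join().unwrap(), vec![false, false, true]);
//! }
//! ```
//!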
//! # Testing custom pruners
//!
//! Recommended test categories:
//!
//! 1. **Never-prune baseline**: empty history and early steps should not
//!    prune.
//! 2. **Known-prune scenario**: a clearly worse trial should be pruned.
//! 3. **Known-keep scenario**: a well-performing trial should survive.
//! 4. **Warmup respected**: pruning must be suppressed during warmup steps
//!    and while the minimum trial count has not been reached.
//! 5. **Per-trial independence**: stateful pruners must not leak state
//!    between different `trial_id` values.
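//!
//! The sketch below runs one hypothetical stateful pruner, `StreakPruner`,
//! through categories 1 to 5 as plain assertions. In a real crate each case
//! would typically live in its own `#[test]` function. `StreakPruner` has no
//! minimum-trial knob, so only the warmup half of category 4 applies.
//!
//! ```
//! use std::collections::HashMap;
//! use std::sync::Mutex;
//! use optimizer::pruner::Pruner;
//! use optimizer::sampler::CompletedTrial;
//!
//! // Hypothetical pruner under test: prunes once a trial's value has been
//! // above `upper` for `patience` consecutive checks after warmup.
//! struct StreakPruner {
//!     upper: f64,
//!     patience: u64,
//!     n_warmup_steps: u64,
//!     streaks: Mutex<HashMap<u64, u64>>, // keyed by trial_id
//! }
//!
//! impl Pruner for StreakPruner {
//!     fn should_prune(
//!         &self,
//!         trial_id: u64,
//!         step: u64,
//!         intermediate_values: &[(u64, f64)],
//!         _completed_trials: &[CompletedTrial],
//!     ) -> bool {
//!         if step < self.n_warmup_steps {
//!             return false;
//!         }
//!         let Some(&(_, v)) = intermediate_values.last() else {
//!             return false;
//!         };
//!         let mut streaks = self.streaks.lock().unwrap();
//!         let streak = streaks.entry(trial_id).or_insert(0);
//!         *streak = if v > self.upper { *streak + 1 } else { 0 };
//!         *streak >= self.patience
//!     }
//! }
//!
//! let new_pruner = || StreakPruner {
//!     upper: 1.0,
//!     patience: 2,
//!     n_warmup_steps: 2,
//!     streaks: Mutex::new(HashMap::new()),
//! };
//!
//! // 1. Never-prune baseline: no values reported yet.
//! assert!(!new_pruner().should_prune(0, 5, &[], &[]));
//!
//! // 2. Known-prune: two consecutive bad values after warmup.
//! let p = new_pruner();
//! assert!(!p.should_prune(0, 2, &[(2, 5.0)], &[]));
//! assert!(p.should_prune(0, 3, &[(2, 5.0), (3, 5.0)], &[]));
//!
//! // 3. Known-keep: good values are never pruned.
//! let p = new_pruner();
//! assert!(!p.should_prune(0, 2, &[(2, 0.1)], &[]));
//! assert!(!p.should_prune(0, 3, &[(2, 0.1), (3, 0.1)], &[]));
//!
//! // 4. Warmup respected: bad values before step 2 are ignored.
//! let p = new_pruner();
//! assert!(!p.should_prune(0, 0, &[(0, 5.0)], &[]));
//! assert!(!p.should_prune(0, 1, &[(0, 5.0), (1, 5.0)], &[]));
//!
//! // 5. Per-trial independence: trial 1's streak must not count for trial 2.
//! let p = new_pruner();
//! assert!(!p.should_prune(1, 2, &[(2, 5.0)], &[]));
//! assert!(!p.should_prune(2, 3, &[(3, 5.0)], &[]));
//! ```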

mod hyperband;
mod median;
mod nop;
mod patient;
pub(crate) mod percentile;
mod successive_halving;
mod threshold;
mod wilcoxon;

pub use hyperband::HyperbandPruner;
pub use median::MedianPruner;
pub use nop::NopPruner;
pub use patient::PatientPruner;
pub use percentile::PercentilePruner;
pub use successive_halving::SuccessiveHalvingPruner;
pub use threshold::ThresholdPruner;
pub use wilcoxon::WilcoxonPruner;

use crate::sampler::CompletedTrial;

/// Trait for pluggable trial pruning strategies.
///
/// Pruners are consulted after each intermediate value is reported to
/// decide whether the trial should be stopped early. The trait requires
/// `Send + Sync` to support concurrent and async optimization.
///
/// # Implementing a custom pruner
///
/// ```
/// use optimizer::pruner::Pruner;
/// use optimizer::sampler::CompletedTrial;
///
/// struct MyPruner {
///     threshold: f64,
/// }
///
/// impl Pruner for MyPruner {
///     fn should_prune(
///         &self,
///         _trial_id: u64,
///         _step: u64,
///         intermediate_values: &[(u64, f64)],
///         _completed_trials: &[CompletedTrial],
///     ) -> bool {
///         // Prune if the latest value exceeds the threshold
///         intermediate_values
///             .last()
///             .is_some_and(|&(_, v)| v > self.threshold)
///     }
/// }
/// ```
///
/// A stateful pruner that tracks per-trial state with a `Mutex`:
///
/// ```
/// use std::collections::hash_map::Entry;
/// use std::collections::HashMap;
/// use std::sync::Mutex;
/// use optimizer::pruner::Pruner;
/// use optimizer::sampler::CompletedTrial;
///
/// /// Prune once the value has failed to improve (assuming minimization)
/// /// for `max_stale` consecutive reports.
/// struct StalePruner {
///     max_stale: u64,
///     // Per-trial: (previous_value, consecutive_stale_count)
///     state: Mutex<HashMap<u64, (f64, u64)>>,
/// }
///
/// impl StalePruner {
///     fn new(max_stale: u64) -> Self {
///         Self { max_stale, state: Mutex::new(HashMap::new()) }
///     }
/// }
///
/// impl Pruner for StalePruner {
///     fn should_prune(
///         &self,
///         trial_id: u64,
///         _step: u64,
///         intermediate_values: &[(u64, f64)],
///         _completed_trials: &[CompletedTrial],
///     ) -> bool {
///         let Some(&(_, current)) = intermediate_values.last() else {
///             return false;
///         };
///         let mut state = self.state.lock().unwrap();
///         match state.entry(trial_id) {
///             // First report for this trial: nothing to compare against yet,
///             // so it must not count as a stale step.
///             Entry::Vacant(slot) => {
///                 slot.insert((current, 0));
///                 false
///             }
///             Entry::Occupied(mut slot) => {
///                 let (previous, stale) = slot.get_mut();
///                 // A value not strictly lower than the previous one is stale.
///                 if current >= *previous {
///                     *stale += 1;
///                 } else {
///                     *stale = 0;
///                 }
///                 *previous = current;
///                 *stale >= self.max_stale
///             }
///         }
///     }
/// }
/// ```
///
/// See the [module-level documentation](self) for a comprehensive guide
/// covering warmup, composition, thread safety, and testing.
pub trait Pruner: Send + Sync {
    /// Decide whether to prune a trial at the given step.
    ///
    /// Returns `true` if the trial should be stopped early.
    ///
    /// # Arguments
    ///
    /// * `trial_id` - The current trial's ID.
    /// * `step` - The step at which the intermediate value was reported.
    /// * `intermediate_values` - All `(step, value)` pairs reported so far for this trial.
    /// * `completed_trials` - History of all completed trials (for comparison).
    fn should_prune(
        &self,
        trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool;
}