optimizer/pruner/mod.rs
//! Pruner trait and implementations for trial pruning.
//!
//! Pruners decide whether to stop (prune) a trial early based on its
//! intermediate values compared to other trials. This is useful for
//! discarding unpromising trials before they complete, saving compute.
//!
//! # How pruning works
//!
//! During optimization, each trial reports intermediate values at discrete
//! steps (e.g., validation loss after each training epoch). A pruner inspects
//! these values and compares them against completed trials to decide whether
//! the current trial should be stopped early.
//!
//! The typical flow is:
//!
//! 1. Call [`Trial::report`](crate::Trial::report) to record an intermediate value.
//! 2. Call [`Trial::should_prune`](crate::Trial::should_prune) to check the pruner's decision.
//! 3. If the pruner says prune, return [`TrialPruned`](crate::TrialPruned) from the objective.
//!
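//! The three steps can be sketched as a plain objective loop. This is a
//! self-contained illustration, not this crate's API: `FakeTrial`, its
//! methods, and `Pruned` are hypothetical stand-ins for
//! [`Trial`](crate::Trial) and [`TrialPruned`](crate::TrialPruned), whose
//! real signatures may differ, and a fixed "loss above 1.0" rule stands in
//! for a real pruner.
//!
//! ```rust
//! /// Hypothetical stand-in for the crate's `TrialPruned` signal.
//! #[derive(Debug, PartialEq)]
//! struct Pruned;
//!
//! /// Hypothetical stand-in for a trial that stores reported `(step, value)` pairs.
//! struct FakeTrial {
//!     intermediate: Vec<(u64, f64)>,
//! }
//!
//! impl FakeTrial {
//!     /// Step 1: record an intermediate value.
//!     fn report(&mut self, step: u64, value: f64) {
//!         self.intermediate.push((step, value));
//!     }
//!
//!     /// Step 2: ask for a decision (here, a fixed threshold on the latest value).
//!     fn should_prune(&self) -> bool {
//!         self.intermediate.last().is_some_and(|&(_, v)| v > 1.0)
//!     }
//! }
//!
//! /// Objective that reports one loss per epoch and stops early when told to.
//! fn objective(trial: &mut FakeTrial, losses: &[f64]) -> Result<f64, Pruned> {
//!     for (epoch, &loss) in losses.iter().enumerate() {
//!         trial.report(epoch as u64, loss);
//!         if trial.should_prune() {
//!             // Step 3: give up instead of running the remaining epochs.
//!             return Err(Pruned);
//!         }
//!     }
//!     Ok(losses.last().copied().unwrap_or(f64::INFINITY))
//! }
//! ```
//!
//! With losses `[0.9, 1.5, 0.3]`, `objective` returns `Err(Pruned)` after
//! reporting only two values, so the third epoch never runs.
//!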
//! # Available pruners
//!
//! | Pruner | Algorithm | Best for |
//! |--------|-----------|----------|
//! | [`MedianPruner`] | Prune below median at each step | General-purpose default |
//! | [`PercentilePruner`] | Prune below configurable percentile | Tunable aggressiveness |
//! | [`ThresholdPruner`] | Prune outside fixed bounds | Known divergence limits |
//! | [`PatientPruner`] | Require N consecutive prune signals | Noisy intermediate values |
//! | [`SuccessiveHalvingPruner`] | Keep top 1/η fraction at each rung | Budget-aware pruning |
//! | [`HyperbandPruner`] | Multiple SHA brackets with different budgets | Robust to budget choice |
//! | [`WilcoxonPruner`] | Statistical signed-rank test vs. best trial | Rigorous noisy pruning |
//! | [`NopPruner`] | Never prune | Disabling pruning explicitly |
//!
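//! As an illustration of the first row, the median rule can be written as a
//! pure function. This sketch assumes minimization (a value greater than the
//! median is worse) and that the completed trials' values at the current step
//! have already been collected into a slice; `median` and `median_rule` are
//! hypothetical helpers, and [`MedianPruner`] itself may treat ties, missing
//! steps, and warmup differently.
//!
//! ```rust
//! /// Median of a slice, or `None` when the slice is empty.
//! fn median(values: &[f64]) -> Option<f64> {
//!     if values.is_empty() {
//!         return None;
//!     }
//!     let mut sorted = values.to_vec();
//!     // `total_cmp` is a total order, so NaN cannot make the sort panic.
//!     sorted.sort_by(|a, b| a.total_cmp(b));
//!     let mid = sorted.len() / 2;
//!     Some(if sorted.len() % 2 == 0 {
//!         (sorted[mid - 1] + sorted[mid]) / 2.0
//!     } else {
//!         sorted[mid]
//!     })
//! }
//!
//! /// Median rule (minimization): prune when the current value at this step is
//! /// worse than the median of the other trials at the same step. An empty
//! /// history never prunes.
//! fn median_rule(current: f64, others_at_step: &[f64]) -> bool {
//!     median(others_at_step).is_some_and(|m| current > m)
//! }
//! ```
//!
//! For example, `median_rule(0.9, &[0.2, 0.5, 0.8])` is `true` because 0.9
//! is worse than the median 0.5.
//!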
//! # When to use pruning
//!
//! Pruning is most beneficial when:
//!
//! - The objective function has a natural notion of "steps" (e.g., training epochs)
//! - Early steps are informative about final performance
//! - Trials are expensive enough that stopping bad ones early saves significant time
//!
//! Start with [`MedianPruner`] for most use cases. Switch to [`WilcoxonPruner`]
//! if your intermediate values are noisy, or to [`HyperbandPruner`] if you want
//! automatic budget scheduling.
//!
//! # Stateful vs stateless pruners
//!
//! **Stateless** pruners make their decision purely from the arguments passed
//! to [`Pruner::should_prune`] — they hold no mutable per-trial state.
//! [`MedianPruner`], [`PercentilePruner`], [`ThresholdPruner`],
//! [`WilcoxonPruner`], and [`NopPruner`] are all stateless.
//!
//! **Stateful** pruners track information across calls. [`PatientPruner`]
//! uses `Mutex<HashMap<u64, u64>>` to count consecutive prune signals per
//! trial. [`HyperbandPruner`] uses `Mutex` and `AtomicU64` for bracket
//! assignment state. When writing a stateful pruner, wrap mutable state in a
//! `Mutex` and key it by `trial_id` to keep trials independent.
//!
//! # Cold start and warmup
//!
//! Two builder parameters control when pruning begins:
//!
//! - **`n_warmup_steps`** — skip pruning before step N *within a trial*,
//!   giving the objective time to stabilize.
//! - **`n_min_trials`** — require N completed trials before pruning any trial,
//!   ensuring a meaningful comparison baseline.
//!
//! See [`MedianPruner`] for the canonical implementation of both parameters.
//! Custom pruners should expose similar knobs when applicable.
//!
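//! Both knobs act as a gate that runs before any comparison. A minimal
//! sketch, assuming "before step N" means that steps strictly below
//! `n_warmup_steps` are exempt; the `Warmup` struct and its
//! `pruning_allowed` method are hypothetical, not part of this crate:
//!
//! ```rust
//! /// Hypothetical warmup settings mirroring the two knobs described above.
//! struct Warmup {
//!     n_warmup_steps: u64,
//!     n_min_trials: usize,
//! }
//!
//! impl Warmup {
//!     /// `true` only when the trial is past its warmup steps *and* enough
//!     /// trials have completed to form a comparison baseline.
//!     fn pruning_allowed(&self, step: u64, n_completed: usize) -> bool {
//!         step >= self.n_warmup_steps && n_completed >= self.n_min_trials
//!     }
//! }
//! ```
//!
//! A custom pruner would evaluate such a gate first and return `false`
//! whenever it fails.
//!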
//! # Composing pruners
//!
//! [`PatientPruner`] demonstrates the decorator pattern: it wraps any
//! `Box<dyn Pruner>` and adds patience logic on top. Custom pruners can use
//! the same pattern to layer multiple pruning conditions — for example,
//! combining a statistical test with a hard threshold.
//!
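//! A self-contained sketch of the decorator pattern. To avoid depending on
//! crate types, it uses a simplified stand-in trait, the hypothetical
//! `Decide`, whose method sees only the trial's `(step, value)` pairs rather
//! than the full [`Pruner::should_prune`] signature. The `AnyOf` wrapper
//! layers a hard threshold with a "no progress" rule (assuming
//! minimization); all of these types are illustrative:
//!
//! ```rust
//! /// Simplified stand-in for `Pruner`: decides from the values reported so far.
//! trait Decide: Send + Sync {
//!     fn should_prune(&self, values: &[(u64, f64)]) -> bool;
//! }
//!
//! /// Hard bound: prune once the latest value exceeds `max`.
//! struct Threshold {
//!     max: f64,
//! }
//!
//! impl Decide for Threshold {
//!     fn should_prune(&self, values: &[(u64, f64)]) -> bool {
//!         values.last().is_some_and(|&(_, v)| v > self.max)
//!     }
//! }
//!
//! /// Prune when the latest value is no better than the first one reported.
//! struct NoProgress;
//!
//! impl Decide for NoProgress {
//!     fn should_prune(&self, values: &[(u64, f64)]) -> bool {
//!         match (values.first(), values.last()) {
//!             (Some(&(s0, first)), Some(&(s1, last))) => s1 > s0 && last >= first,
//!             _ => false,
//!         }
//!     }
//! }
//!
//! /// Decorator: wraps any boxed rules and prunes if at least one of them would.
//! struct AnyOf {
//!     inner: Vec<Box<dyn Decide>>,
//! }
//!
//! impl Decide for AnyOf {
//!     fn should_prune(&self, values: &[(u64, f64)]) -> bool {
//!         self.inner.iter().any(|rule| rule.should_prune(values))
//!     }
//! }
//! ```
//!
//! Wrapping the real trait follows the same shape: store `Box<dyn Pruner>`
//! fields and forward every argument of [`Pruner::should_prune`] to them.
//!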
//! # Thread safety
//!
//! The [`Pruner`] trait requires `Send + Sync`.
//! [`Study`](crate::Study) stores the pruner as `Arc<dyn Pruner>`, so
//! multiple threads may call [`Pruner::should_prune`] concurrently.
//!
//! - **Stateless pruners** satisfy `Send + Sync` automatically, as long as
//!   all of their fields do (plain numeric configuration does).
//! - **Stateful pruners** should use `std::sync::Mutex` or
//!   `parking_lot::Mutex` to protect mutable state, keyed by `trial_id`.
//!
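//! The stateful case can be sketched with the standard library alone. Here
//! one pruner is shared across threads through `Arc`, mirroring how
//! [`Study`](crate::Study) holds its pruner, and a `Mutex<HashMap>` keyed by
//! trial ID keeps each trial's count separate. `CallBudget` and
//! `run_concurrently` are hypothetical, and `should_prune` takes only a
//! trial ID to keep the example free of crate types:
//!
//! ```rust
//! use std::collections::HashMap;
//! use std::sync::{Arc, Mutex};
//! use std::thread;
//!
//! /// Hypothetical stateful pruner: counts checks per trial and prunes a
//! /// trial from its `limit`-th check onward.
//! struct CallBudget {
//!     limit: u64,
//!     // Per-trial check counts, protected by a `Mutex` for concurrent access.
//!     calls: Mutex<HashMap<u64, u64>>,
//! }
//!
//! impl CallBudget {
//!     fn should_prune(&self, trial_id: u64) -> bool {
//!         let mut calls = self.calls.lock().unwrap();
//!         let n = calls.entry(trial_id).or_insert(0);
//!         *n += 1;
//!         *n >= self.limit
//!     }
//! }
//!
//! /// Runs `checks` decisions per trial, one thread per trial, all sharing
//! /// one pruner through `Arc`; returns the total number of prune signals.
//! fn run_concurrently(pruner: Arc<CallBudget>, trials: u64, checks: u64) -> usize {
//!     let handles: Vec<_> = (0..trials)
//!         .map(|trial_id| {
//!             let pruner = Arc::clone(&pruner);
//!             thread::spawn(move || {
//!                 (0..checks).filter(|_| pruner.should_prune(trial_id)).count()
//!             })
//!         })
//!         .collect();
//!     handles.into_iter().map(|h| h.join().unwrap()).sum()
//! }
//! ```
//!
//! Because each trial updates only its own map entry, the outcome does not
//! depend on thread interleaving: with `limit = 3` and 5 checks per trial,
//! every trial receives exactly 3 prune signals.
//!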
//! # Testing custom pruners
//!
//! Recommended test categories:
//!
//! 1. **Never-prune baseline** — empty history and early steps should not
//!    prune.
//! 2. **Known-prune scenario** — a clearly worse trial should be pruned.
//! 3. **Known-keep scenario** — a well-performing trial should survive.
//! 4. **Warmup respected** — pruning must be suppressed during warmup steps
//!    and while the minimum trial count has not been reached.
//! 5. **Per-trial independence** — stateful pruners must not leak state
//!    between different `trial_id` values.
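//!
//! The categories can be exercised against a small pruner. In this sketch,
//! `PatientThreshold` is a hypothetical pruner (minimization, a warmup step
//! count, and a per-trial streak counter) with a reduced `should_prune`
//! signature, and `check_categories` groups one set of assertions per
//! category; the minimum-trial half of category 4 is omitted because this
//! pruner ignores completed trials:
//!
//! ```rust
//! use std::collections::HashMap;
//! use std::sync::Mutex;
//!
//! /// Hypothetical pruner under test: from step `warmup` on, prune once
//! /// `patience` consecutive values have exceeded `max`.
//! struct PatientThreshold {
//!     warmup: u64,
//!     max: f64,
//!     patience: u64,
//!     // Per-trial count of consecutive values above `max`.
//!     streaks: Mutex<HashMap<u64, u64>>,
//! }
//!
//! impl PatientThreshold {
//!     fn new(warmup: u64, max: f64, patience: u64) -> Self {
//!         Self { warmup, max, patience, streaks: Mutex::new(HashMap::new()) }
//!     }
//!
//!     fn should_prune(&self, trial_id: u64, step: u64, values: &[(u64, f64)]) -> bool {
//!         // No decision before any value is reported or during warmup.
//!         let Some(&(_, v)) = values.last() else { return false };
//!         if step < self.warmup {
//!             return false;
//!         }
//!         let mut streaks = self.streaks.lock().unwrap();
//!         let streak = streaks.entry(trial_id).or_insert(0);
//!         *streak = if v > self.max { *streak + 1 } else { 0 };
//!         *streak >= self.patience
//!     }
//! }
//!
//! /// One group of assertions per category; in a real crate each group
//! /// would be its own `#[test]` function.
//! fn check_categories() {
//!     let p = PatientThreshold::new(2, 1.0, 1);
//!     // 1. Never-prune baseline: empty history.
//!     assert!(!p.should_prune(0, 5, &[]));
//!     // 4. Warmup respected: a terrible value during warmup is ignored.
//!     assert!(!p.should_prune(0, 0, &[(0, 99.0)]));
//!     // 2. Known prune: out of bounds after warmup.
//!     assert!(p.should_prune(0, 2, &[(0, 99.0), (2, 99.0)]));
//!     // 3. Known keep: within bounds after warmup.
//!     assert!(!p.should_prune(1, 2, &[(2, 0.5)]));
//!
//!     // 5. Per-trial independence: trial 1's streak must not count for trial 2.
//!     let p = PatientThreshold::new(0, 1.0, 2);
//!     assert!(!p.should_prune(1, 0, &[(0, 5.0)]));
//!     assert!(!p.should_prune(2, 0, &[(0, 5.0)]));
//!     assert!(p.should_prune(1, 1, &[(0, 5.0), (1, 5.0)]));
//! }
//! ```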

mod hyperband;
mod median;
mod nop;
mod patient;
pub(crate) mod percentile;
mod successive_halving;
mod threshold;
mod wilcoxon;

pub use hyperband::HyperbandPruner;
pub use median::MedianPruner;
pub use nop::NopPruner;
pub use patient::PatientPruner;
pub use percentile::PercentilePruner;
pub use successive_halving::SuccessiveHalvingPruner;
pub use threshold::ThresholdPruner;
pub use wilcoxon::WilcoxonPruner;

use crate::sampler::CompletedTrial;

/// Trait for pluggable trial pruning strategies.
///
/// Pruners are consulted after each intermediate value is reported to
/// decide whether the trial should be stopped early. The trait requires
/// `Send + Sync` to support concurrent and async optimization.
///
/// # Implementing a custom pruner
///
/// ```
/// use optimizer::pruner::Pruner;
/// use optimizer::sampler::CompletedTrial;
///
/// struct MyPruner {
///     threshold: f64,
/// }
///
/// impl Pruner for MyPruner {
///     fn should_prune(
///         &self,
///         _trial_id: u64,
///         _step: u64,
///         intermediate_values: &[(u64, f64)],
///         _completed_trials: &[CompletedTrial],
///     ) -> bool {
///         // Prune if the latest value exceeds the threshold
///         intermediate_values
///             .last()
///             .is_some_and(|&(_, v)| v > self.threshold)
///     }
/// }
/// ```
///
/// A stateful pruner that tracks per-trial state with a `Mutex`:
///
/// ```
/// use std::collections::HashMap;
/// use std::sync::Mutex;
/// use optimizer::pruner::Pruner;
/// use optimizer::sampler::CompletedTrial;
///
/// /// Prune after the value fails to improve (assuming minimization) for
/// /// `max_stale` consecutive steps.
/// struct StalePruner {
///     max_stale: u64,
///     // Per-trial: (previous_value, consecutive_stale_count)
///     state: Mutex<HashMap<u64, (f64, u64)>>,
/// }
///
/// impl StalePruner {
///     fn new(max_stale: u64) -> Self {
///         Self { max_stale, state: Mutex::new(HashMap::new()) }
///     }
/// }
///
/// impl Pruner for StalePruner {
///     fn should_prune(
///         &self,
///         trial_id: u64,
///         _step: u64,
///         intermediate_values: &[(u64, f64)],
///         _completed_trials: &[CompletedTrial],
///     ) -> bool {
///         let Some(&(_, current)) = intermediate_values.last() else {
///             return false;
///         };
///         let mut state = self.state.lock().unwrap();
///         match state.get_mut(&trial_id) {
///             Some((previous, stale)) => {
///                 // A value that did not decrease counts as stale.
///                 if current >= *previous {
///                     *stale += 1;
///                 } else {
///                     *stale = 0;
///                 }
///                 *previous = current;
///                 *stale >= self.max_stale
///             }
///             None => {
///                 // First report for this trial: nothing to compare against yet.
///                 state.insert(trial_id, (current, 0));
///                 false
///             }
///         }
///     }
/// }
/// ```
///
/// See the [module-level documentation](self) for a comprehensive guide
/// covering warmup, composition, thread safety, and testing.
pub trait Pruner: Send + Sync {
    /// Decide whether to prune a trial at the given step.
    ///
    /// Returns `true` if the trial should be stopped early.
    ///
    /// # Arguments
    ///
    /// * `trial_id` - The current trial's ID.
    /// * `step` - The step at which the intermediate value was reported.
    /// * `intermediate_values` - All `(step, value)` pairs reported so far for this trial.
    /// * `completed_trials` - History of all completed trials (for comparison).
    fn should_prune(
        &self,
        trial_id: u64,
        step: u64,
        intermediate_values: &[(u64, f64)],
        completed_trials: &[CompletedTrial],
    ) -> bool;
}