Skip to main content

entrenar/finetune/
tune_searchers.rs

1//! Searcher and scheduler implementations for classification HPO.
2//!
3//! Provides concrete searcher/scheduler types used by `ClassifyTuner`.
4//!
5//! - **Searchers**: `TpeSearcher`, `GridSearcher`, `RandomSearcher`
6//! - **Schedulers**: `AshaScheduler`, `MedianScheduler`, `NoScheduler`
7
8use std::collections::HashMap;
9
10use crate::optim::{
11    GridSearch, HyperparameterSpace, ParameterValue, TPEOptimizer, Trial, TrialStatus,
12};
13
14// ═══════════════════════════════════════════════════════════════════════
15// Traits
16// ═══════════════════════════════════════════════════════════════════════
17
18/// Search strategy for suggesting hyperparameter configurations.
19///
20/// Wraps existing TPEOptimizer, GridSearch, and random sampling.
21pub trait TuneSearcher {
22    /// Suggest the next trial configuration to evaluate.
23    fn suggest(&mut self) -> crate::Result<Trial>;
24
25    /// Record a completed trial's score.
26    fn record(&mut self, trial: Trial, score: f64, epochs: usize);
27
28    /// Get the best trial so far (lowest score).
29    fn best(&self) -> Option<&Trial>;
30}
31
/// Early-stopping policy: decides whether a running trial should be pruned.
///
/// Implemented in this module by [`AshaScheduler`], [`MedianScheduler`] and
/// [`NoScheduler`]. Implementations treat a lower validation loss as better.
pub trait TuneScheduler {
    /// Should this trial be stopped early?
    ///
    /// # Arguments
    /// * `trial_id` - Current trial index
    /// * `epoch` - Current epoch (0-indexed)
    /// * `val_loss` - Current validation loss (lower is better)
    ///
    /// Returns `true` when the trial should be pruned now.
    fn should_stop(&self, trial_id: usize, epoch: usize, val_loss: f64) -> bool;
}
44
45// ═══════════════════════════════════════════════════════════════════════
46// Searcher implementations
47// ═══════════════════════════════════════════════════════════════════════
48
49/// TPE-based searcher (Bayesian optimization).
50pub struct TpeSearcher {
51    optimizer: TPEOptimizer,
52}
53
54impl TpeSearcher {
55    /// Create a TPE searcher with the given search space.
56    pub fn new(space: HyperparameterSpace, n_startup: usize) -> Self {
57        let optimizer = TPEOptimizer::new(space).with_startup(n_startup);
58        Self { optimizer }
59    }
60}
61
62impl TuneSearcher for TpeSearcher {
63    fn suggest(&mut self) -> crate::Result<Trial> {
64        self.optimizer
65            .suggest()
66            .map_err(|e| crate::Error::ConfigError(format!("TPE suggest failed: {e}")))
67    }
68
69    fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
70        self.optimizer.record(trial, score, epochs);
71    }
72
73    fn best(&self) -> Option<&Trial> {
74        self.optimizer.best_trial()
75    }
76}
77
78/// Grid-based searcher (exhaustive).
79pub struct GridSearcher {
80    configs: Vec<HashMap<String, ParameterValue>>,
81    trials: Vec<Trial>,
82    next_idx: usize,
83}
84
85impl GridSearcher {
86    /// Create a grid searcher with the given search space.
87    pub fn new(space: HyperparameterSpace, n_points: usize) -> Self {
88        let grid = GridSearch::new(space, n_points);
89        let configs = grid.configurations();
90        Self { configs, trials: Vec::new(), next_idx: 0 }
91    }
92}
93
94impl TuneSearcher for GridSearcher {
95    fn suggest(&mut self) -> crate::Result<Trial> {
96        if self.next_idx >= self.configs.len() {
97            return Err(crate::Error::ConfigError(
98                "Grid search exhausted all configurations".to_string(),
99            ));
100        }
101        let config = self.configs[self.next_idx].clone();
102        let trial = Trial::new(self.next_idx, config);
103        self.next_idx += 1;
104        Ok(trial)
105    }
106
107    fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
108        let mut trial = trial;
109        trial.complete(score, epochs);
110        self.trials.push(trial);
111    }
112
113    fn best(&self) -> Option<&Trial> {
114        self.trials
115            .iter()
116            .filter(|t| t.status == TrialStatus::Completed)
117            .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
118    }
119}
120
121/// Random searcher (uniform sampling).
122pub struct RandomSearcher {
123    space: HyperparameterSpace,
124    trials: Vec<Trial>,
125    next_id: usize,
126}
127
128impl RandomSearcher {
129    /// Create a random searcher with the given search space.
130    pub fn new(space: HyperparameterSpace) -> Self {
131        Self { space, trials: Vec::new(), next_id: 0 }
132    }
133}
134
135impl TuneSearcher for RandomSearcher {
136    fn suggest(&mut self) -> crate::Result<Trial> {
137        if self.space.is_empty() {
138            return Err(crate::Error::ConfigError("Empty search space".to_string()));
139        }
140        let mut rng = rand::rng();
141        let config = self.space.sample_random(&mut rng);
142        let trial = Trial::new(self.next_id, config);
143        self.next_id += 1;
144        Ok(trial)
145    }
146
147    fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
148        let mut trial = trial;
149        trial.complete(score, epochs);
150        self.trials.push(trial);
151    }
152
153    fn best(&self) -> Option<&Trial> {
154        self.trials
155            .iter()
156            .filter(|t| t.status == TrialStatus::Completed)
157            .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
158    }
159}
160
161// ═══════════════════════════════════════════════════════════════════════
162// Scheduler implementations
163// ═══════════════════════════════════════════════════════════════════════
164
165/// ASHA-style scheduler: stops trials whose val_loss exceeds the median at the same epoch.
166pub struct AshaScheduler {
167    /// Grace period: minimum epochs before pruning is eligible.
168    grace_period: usize,
169    /// Reduction factor (keep top 1/eta at each rung).
170    reduction_factor: f64,
171    /// Recorded metrics per trial: trial_id → vec of val_loss per epoch.
172    history: Vec<Vec<f64>>,
173}
174
175impl AshaScheduler {
176    /// Create an ASHA scheduler.
177    pub fn new(grace_period: usize, reduction_factor: f64) -> Self {
178        Self { grace_period, reduction_factor: reduction_factor.max(2.0), history: Vec::new() }
179    }
180
181    /// Record a metric for a trial at a given epoch.
182    pub fn record_metric(&mut self, trial_id: usize, _epoch: usize, val_loss: f64) {
183        while self.history.len() <= trial_id {
184            self.history.push(Vec::new());
185        }
186        self.history[trial_id].push(val_loss);
187    }
188}
189
190impl TuneScheduler for AshaScheduler {
191    fn should_stop(&self, _trial_id: usize, epoch: usize, val_loss: f64) -> bool {
192        if epoch < self.grace_period {
193            return false;
194        }
195
196        // Collect all completed trials' val_loss at this epoch
197        let mut losses_at_epoch: Vec<f64> =
198            self.history.iter().filter_map(|h| h.get(epoch).copied()).collect();
199
200        if losses_at_epoch.is_empty() {
201            return false;
202        }
203
204        losses_at_epoch.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
205
206        // Keep top 1/eta — prune if val_loss is above the cutoff
207        let keep_fraction = 1.0 / self.reduction_factor;
208        let cutoff_idx = ((losses_at_epoch.len() as f64 * keep_fraction).ceil() as usize).max(1);
209        if cutoff_idx >= losses_at_epoch.len() {
210            return false;
211        }
212        let cutoff_val = losses_at_epoch[cutoff_idx];
213        val_loss > cutoff_val
214    }
215}
216
217/// Median scheduler: prunes trials whose metric is worse than the median.
218pub struct MedianScheduler {
219    /// Minimum epochs before pruning.
220    n_warmup: usize,
221    /// All recorded metrics: trial_id → vec of val_loss per epoch.
222    history: Vec<Vec<f64>>,
223}
224
225impl MedianScheduler {
226    /// Create a median scheduler.
227    pub fn new(n_warmup: usize) -> Self {
228        Self { n_warmup, history: Vec::new() }
229    }
230
231    /// Record a metric for a trial at a given epoch.
232    pub fn record_metric(&mut self, trial_id: usize, _epoch: usize, val_loss: f64) {
233        while self.history.len() <= trial_id {
234            self.history.push(Vec::new());
235        }
236        self.history[trial_id].push(val_loss);
237    }
238}
239
240impl TuneScheduler for MedianScheduler {
241    fn should_stop(&self, _trial_id: usize, epoch: usize, val_loss: f64) -> bool {
242        if epoch < self.n_warmup {
243            return false;
244        }
245
246        let mut losses: Vec<f64> =
247            self.history.iter().filter_map(|h| h.get(epoch).copied()).collect();
248
249        if losses.len() < 2 {
250            return false;
251        }
252
253        losses.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
254        let median = losses[losses.len() / 2];
255        val_loss > median
256    }
257}
258
259/// No-op scheduler (never prunes).
260pub struct NoScheduler;
261
262impl TuneScheduler for NoScheduler {
263    fn should_stop(&self, _trial_id: usize, _epoch: usize, _val_loss: f64) -> bool {
264        false
265    }
266}