entrenar/finetune/
tune_searchers.rs1use std::collections::HashMap;
9
10use crate::optim::{
11 GridSearch, HyperparameterSpace, ParameterValue, TPEOptimizer, Trial, TrialStatus,
12};
13
14pub trait TuneSearcher {
22 fn suggest(&mut self) -> crate::Result<Trial>;
24
25 fn record(&mut self, trial: Trial, score: f64, epochs: usize);
27
28 fn best(&self) -> Option<&Trial>;
30}
31
/// An early-stopping policy consulted while a trial is training.
pub trait TuneScheduler {
    /// Decides whether the trial `trial_id` should be stopped.
    ///
    /// `epoch` is the epoch the trial has just reported and `val_loss` is the validation
    /// loss it observed there. Returns `true` to stop the trial, `false` to let it continue.
    fn should_stop(&self, trial_id: usize, epoch: usize, val_loss: f64) -> bool;
}
44
45pub struct TpeSearcher {
51 optimizer: TPEOptimizer,
52}
53
54impl TpeSearcher {
55 pub fn new(space: HyperparameterSpace, n_startup: usize) -> Self {
57 let optimizer = TPEOptimizer::new(space).with_startup(n_startup);
58 Self { optimizer }
59 }
60}
61
62impl TuneSearcher for TpeSearcher {
63 fn suggest(&mut self) -> crate::Result<Trial> {
64 self.optimizer
65 .suggest()
66 .map_err(|e| crate::Error::ConfigError(format!("TPE suggest failed: {e}")))
67 }
68
69 fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
70 self.optimizer.record(trial, score, epochs);
71 }
72
73 fn best(&self) -> Option<&Trial> {
74 self.optimizer.best_trial()
75 }
76}
77
78pub struct GridSearcher {
80 configs: Vec<HashMap<String, ParameterValue>>,
81 trials: Vec<Trial>,
82 next_idx: usize,
83}
84
85impl GridSearcher {
86 pub fn new(space: HyperparameterSpace, n_points: usize) -> Self {
88 let grid = GridSearch::new(space, n_points);
89 let configs = grid.configurations();
90 Self { configs, trials: Vec::new(), next_idx: 0 }
91 }
92}
93
94impl TuneSearcher for GridSearcher {
95 fn suggest(&mut self) -> crate::Result<Trial> {
96 if self.next_idx >= self.configs.len() {
97 return Err(crate::Error::ConfigError(
98 "Grid search exhausted all configurations".to_string(),
99 ));
100 }
101 let config = self.configs[self.next_idx].clone();
102 let trial = Trial::new(self.next_idx, config);
103 self.next_idx += 1;
104 Ok(trial)
105 }
106
107 fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
108 let mut trial = trial;
109 trial.complete(score, epochs);
110 self.trials.push(trial);
111 }
112
113 fn best(&self) -> Option<&Trial> {
114 self.trials
115 .iter()
116 .filter(|t| t.status == TrialStatus::Completed)
117 .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
118 }
119}
120
121pub struct RandomSearcher {
123 space: HyperparameterSpace,
124 trials: Vec<Trial>,
125 next_id: usize,
126}
127
128impl RandomSearcher {
129 pub fn new(space: HyperparameterSpace) -> Self {
131 Self { space, trials: Vec::new(), next_id: 0 }
132 }
133}
134
135impl TuneSearcher for RandomSearcher {
136 fn suggest(&mut self) -> crate::Result<Trial> {
137 if self.space.is_empty() {
138 return Err(crate::Error::ConfigError("Empty search space".to_string()));
139 }
140 let mut rng = rand::rng();
141 let config = self.space.sample_random(&mut rng);
142 let trial = Trial::new(self.next_id, config);
143 self.next_id += 1;
144 Ok(trial)
145 }
146
147 fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
148 let mut trial = trial;
149 trial.complete(score, epochs);
150 self.trials.push(trial);
151 }
152
153 fn best(&self) -> Option<&Trial> {
154 self.trials
155 .iter()
156 .filter(|t| t.status == TrialStatus::Completed)
157 .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
158 }
159}
160
/// Early-stopping scheduler that compares a trial's loss against the losses other trials
/// recorded at the same epoch and keeps roughly the best `1 / reduction_factor` of them.
pub struct AshaScheduler {
    /// Trials are never stopped at an epoch below this value.
    grace_period: usize,
    /// Inverse of the fraction of trials to keep; `new` clamps it to at least 2.0.
    reduction_factor: f64,
    /// Recorded validation losses: one row per trial id, in the order they were recorded.
    history: Vec<Vec<f64>>,
}
174
175impl AshaScheduler {
176 pub fn new(grace_period: usize, reduction_factor: f64) -> Self {
178 Self { grace_period, reduction_factor: reduction_factor.max(2.0), history: Vec::new() }
179 }
180
181 pub fn record_metric(&mut self, trial_id: usize, _epoch: usize, val_loss: f64) {
183 while self.history.len() <= trial_id {
184 self.history.push(Vec::new());
185 }
186 self.history[trial_id].push(val_loss);
187 }
188}
189
190impl TuneScheduler for AshaScheduler {
191 fn should_stop(&self, _trial_id: usize, epoch: usize, val_loss: f64) -> bool {
192 if epoch < self.grace_period {
193 return false;
194 }
195
196 let mut losses_at_epoch: Vec<f64> =
198 self.history.iter().filter_map(|h| h.get(epoch).copied()).collect();
199
200 if losses_at_epoch.is_empty() {
201 return false;
202 }
203
204 losses_at_epoch.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
205
206 let keep_fraction = 1.0 / self.reduction_factor;
208 let cutoff_idx = ((losses_at_epoch.len() as f64 * keep_fraction).ceil() as usize).max(1);
209 if cutoff_idx >= losses_at_epoch.len() {
210 return false;
211 }
212 let cutoff_val = losses_at_epoch[cutoff_idx];
213 val_loss > cutoff_val
214 }
215}
216
/// Early-stopping scheduler that compares a trial's loss against the median of the losses
/// other trials recorded at the same epoch.
pub struct MedianScheduler {
    /// Trials are never stopped at an epoch below this value.
    n_warmup: usize,
    /// Recorded validation losses: one row per trial id, in the order they were recorded.
    history: Vec<Vec<f64>>,
}
224
225impl MedianScheduler {
226 pub fn new(n_warmup: usize) -> Self {
228 Self { n_warmup, history: Vec::new() }
229 }
230
231 pub fn record_metric(&mut self, trial_id: usize, _epoch: usize, val_loss: f64) {
233 while self.history.len() <= trial_id {
234 self.history.push(Vec::new());
235 }
236 self.history[trial_id].push(val_loss);
237 }
238}
239
240impl TuneScheduler for MedianScheduler {
241 fn should_stop(&self, _trial_id: usize, epoch: usize, val_loss: f64) -> bool {
242 if epoch < self.n_warmup {
243 return false;
244 }
245
246 let mut losses: Vec<f64> =
247 self.history.iter().filter_map(|h| h.get(epoch).copied()).collect();
248
249 if losses.len() < 2 {
250 return false;
251 }
252
253 losses.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
254 let median = losses[losses.len() / 2];
255 val_loss > median
256 }
257}
258
/// A scheduler that carries no state and never stops a trial.
pub struct NoScheduler;
261
262impl TuneScheduler for NoScheduler {
263 fn should_stop(&self, _trial_id: usize, _epoch: usize, _val_loss: f64) -> bool {
264 false
265 }
266}