1use std::time::{Duration, Instant};
6
7use crate::automl::params::ParamKey;
8use crate::automl::search::{SearchSpace, SearchStrategy, Trial, TrialResult};
9
10pub trait Callback<P: ParamKey> {
12 fn on_start(&mut self, _space: &SearchSpace<P>) {}
14
15 fn on_trial_start(&mut self, _trial_num: usize, _trial: &Trial<P>) {}
17
18 fn on_trial_end(&mut self, _trial_num: usize, _result: &TrialResult<P>) {}
20
21 fn on_end(&mut self, _best: Option<&TrialResult<P>>) {}
23
24 fn should_stop(&self) -> bool {
26 false
27 }
28}
29
/// Callback that prints search progress to standard output.
///
/// The `Default` instance is silent; build one with
/// [`ProgressCallback::verbose`] to enable printing.
#[derive(Debug, Default)]
pub struct ProgressCallback {
    /// When `false`, nothing is printed.
    verbose: bool,
}

impl ProgressCallback {
    /// Creates a callback with printing enabled.
    #[must_use]
    pub fn verbose() -> Self {
        let verbose = true;
        Self { verbose }
    }
}
43
44impl<P: ParamKey> Callback<P> for ProgressCallback {
45 fn on_trial_end(&mut self, trial_num: usize, result: &TrialResult<P>) {
46 if self.verbose {
47 println!(
48 "Trial {:>3}: score={:.4} params={}",
49 trial_num, result.score, result.trial
50 );
51 }
52 }
53
54 fn on_end(&mut self, best: Option<&TrialResult<P>>) {
55 if self.verbose {
56 if let Some(b) = best {
57 println!("\nBest: score={:.4} params={}", b.score, b.trial);
58 }
59 }
60 }
61}
62
/// Stops a search after too many consecutive trials without improvement.
#[derive(Debug)]
pub struct EarlyStopping {
    /// Number of consecutive non-improving trials that triggers the stop.
    patience: usize,
    /// Margin by which a score must exceed the best one to count as an improvement.
    min_delta: f64,
    /// Consecutive trials observed since the last improvement.
    trials_without_improvement: usize,
    /// Highest score accepted as an improvement so far.
    best_score: f64,
}

impl EarlyStopping {
    /// Improvement margin used unless overridden with [`EarlyStopping::min_delta`].
    const DEFAULT_MIN_DELTA: f64 = 1e-4;

    /// Creates a stopper that tolerates `patience` consecutive stale trials.
    ///
    /// The best score starts at negative infinity and the improvement margin
    /// at `1e-4`.
    #[must_use]
    pub fn new(patience: usize) -> Self {
        Self {
            best_score: f64::NEG_INFINITY,
            trials_without_improvement: 0,
            min_delta: Self::DEFAULT_MIN_DELTA,
            patience,
        }
    }

    /// Builder-style setter replacing the improvement margin with `delta`.
    #[must_use]
    pub fn min_delta(self, delta: f64) -> Self {
        Self {
            min_delta: delta,
            ..self
        }
    }
}
91
92impl<P: ParamKey> Callback<P> for EarlyStopping {
93 fn on_trial_end(&mut self, _trial_num: usize, result: &TrialResult<P>) {
94 if result.score > self.best_score + self.min_delta {
95 self.best_score = result.score;
96 self.trials_without_improvement = 0;
97 } else {
98 self.trials_without_improvement += 1;
99 }
100 }
101
102 fn should_stop(&self) -> bool {
103 self.trials_without_improvement >= self.patience
104 }
105}
106
/// Wall-clock budget for a search.
///
/// The clock is not running after construction: `start` stays `None` until it
/// is set (the `Callback::on_start` implementation for this type does so).
#[derive(Debug)]
pub struct TimeBudget {
    /// Total time the search is allowed to run.
    budget: Duration,
    /// Moment the clock was started; `None` while it has not been started.
    start: Option<Instant>,
}

impl TimeBudget {
    /// Creates a budget of `secs` seconds with the clock not yet started.
    #[must_use]
    pub fn seconds(secs: u64) -> Self {
        Self {
            budget: Duration::from_secs(secs),
            start: None,
        }
    }

    /// Creates a budget of `mins` minutes with the clock not yet started.
    ///
    /// The conversion to seconds saturates at `u64::MAX` instead of
    /// overflowing, so an enormous `mins` yields the largest budget
    /// expressible in whole seconds rather than a panic (debug builds) or a
    /// silently wrapped, much shorter budget (release builds).
    #[must_use]
    pub fn minutes(mins: u64) -> Self {
        Self::seconds(mins.saturating_mul(60))
    }

    /// Time elapsed since the clock was started, or zero if it has not been.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.start
            .map_or(Duration::ZERO, |started| started.elapsed())
    }

    /// Time left in the budget, clamped at zero once the budget is used up.
    #[must_use]
    pub fn remaining(&self) -> Duration {
        self.budget.saturating_sub(self.elapsed())
    }
}
142
143impl<P: ParamKey> Callback<P> for TimeBudget {
144 fn on_start(&mut self, _space: &SearchSpace<P>) {
145 self.start = Some(Instant::now());
146 }
147
148 fn should_stop(&self) -> bool {
149 self.elapsed() >= self.budget
150 }
151}
152
153#[derive(Debug, Clone)]
155pub struct TuneResult<P: ParamKey> {
156 pub best_trial: Trial<P>,
158 pub best_score: f64,
160 pub history: Vec<TrialResult<P>>,
162 pub elapsed: Duration,
164 pub n_trials: usize,
166}
167
168#[allow(missing_debug_implementations)]
193pub struct AutoTuner<S, P: ParamKey> {
194 strategy: S,
195 callbacks: Vec<Box<dyn Callback<P>>>,
196 _phantom: std::marker::PhantomData<P>,
197}
198
199impl<S, P: ParamKey> AutoTuner<S, P>
200where
201 S: SearchStrategy<P>,
202{
203 pub fn new(strategy: S) -> Self {
205 Self {
206 strategy,
207 callbacks: Vec::new(),
208 _phantom: std::marker::PhantomData,
209 }
210 }
211
212 #[must_use]
214 pub fn time_limit_secs(mut self, secs: u64) -> Self {
215 self.callbacks.push(Box::new(TimeBudget::seconds(secs)));
216 self
217 }
218
219 #[must_use]
221 pub fn time_limit_mins(mut self, mins: u64) -> Self {
222 self.callbacks.push(Box::new(TimeBudget::minutes(mins)));
223 self
224 }
225
226 #[must_use]
228 pub fn early_stopping(mut self, patience: usize) -> Self {
229 self.callbacks.push(Box::new(EarlyStopping::new(patience)));
230 self
231 }
232
233 #[must_use]
235 pub fn verbose(mut self) -> Self {
236 self.callbacks.push(Box::new(ProgressCallback::verbose()));
237 self
238 }
239
240 #[must_use]
242 pub fn callback(mut self, cb: impl Callback<P> + 'static) -> Self {
243 self.callbacks.push(Box::new(cb));
244 self
245 }
246
247 pub fn maximize<F>(mut self, space: &SearchSpace<P>, mut objective: F) -> TuneResult<P>
249 where
250 F: FnMut(&Trial<P>) -> f64,
251 {
252 let start = Instant::now();
253
254 for cb in &mut self.callbacks {
256 cb.on_start(space);
257 }
258
259 let mut history = Vec::new();
260 let mut best_score = f64::NEG_INFINITY;
261 let mut best_trial: Option<Trial<P>> = None;
262 let mut trial_num = 0;
263
264 loop {
265 if self.callbacks.iter().any(|cb| cb.should_stop()) {
267 break;
268 }
269
270 let trials = self.strategy.suggest(space, 1);
272 if trials.is_empty() {
273 break;
274 }
275
276 let trial = trials.into_iter().next().expect("should have trial");
277 trial_num += 1;
278
279 for cb in &mut self.callbacks {
281 cb.on_trial_start(trial_num, &trial);
282 }
283
284 let score = objective(&trial);
286
287 let result = TrialResult {
288 trial: trial.clone(),
289 score,
290 metrics: std::collections::HashMap::new(),
291 };
292
293 if score > best_score {
295 best_score = score;
296 best_trial = Some(trial);
297 }
298
299 for cb in &mut self.callbacks {
301 cb.on_trial_end(trial_num, &result);
302 }
303
304 self.strategy.update(std::slice::from_ref(&result));
306
307 history.push(result);
308 }
309
310 let best_result = history.iter().max_by(|a, b| {
312 a.score
313 .partial_cmp(&b.score)
314 .unwrap_or(std::cmp::Ordering::Equal)
315 });
316 for cb in &mut self.callbacks {
317 cb.on_end(best_result);
318 }
319
320 TuneResult {
321 best_trial: best_trial.unwrap_or_else(|| Trial {
322 values: std::collections::HashMap::new(),
323 }),
324 best_score,
325 history,
326 elapsed: start.elapsed(),
327 n_trials: trial_num,
328 }
329 }
330
331 pub fn minimize<F>(self, space: &SearchSpace<P>, mut objective: F) -> TuneResult<P>
333 where
334 F: FnMut(&Trial<P>) -> f64,
335 {
336 self.maximize(space, move |trial| -objective(trial))
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automl::params::RandomForestParam as RF;
    use crate::automl::{RandomSearch, SearchSpace};

    /// Single-parameter space shared by the tests that ignore parameter values.
    fn estimators_space() -> SearchSpace<RF> {
        SearchSpace::new().add(RF::NEstimators, 10..100)
    }

    #[test]
    fn test_auto_tuner_basic() {
        let space: SearchSpace<RF> = SearchSpace::new()
            .add(RF::NEstimators, 10..100)
            .add(RF::MaxDepth, 2..10);

        let strategy = RandomSearch::new(10).with_seed(42);
        let result = AutoTuner::new(strategy).maximize(&space, |trial| {
            let estimators = trial.get_usize(&RF::NEstimators).unwrap_or(50);
            let depth = trial.get_usize(&RF::MaxDepth).unwrap_or(5);
            // Rewards more estimators and a depth close to 5.
            let depth_penalty = (depth as f64 - 5.0).abs() / 5.0;
            (estimators as f64 / 100.0) + (1.0 - depth_penalty)
        });

        assert_eq!(result.n_trials, 10);
        assert!(result.best_score > 0.0);
    }

    #[test]
    fn test_early_stopping() {
        let space = estimators_space();

        // A constant score improves only on the first trial, so the search
        // must give up after `patience` further stale trials.
        let result = AutoTuner::new(RandomSearch::new(100))
            .early_stopping(3)
            .maximize(&space, |_| 0.5);

        assert!(result.n_trials <= 4);
    }

    #[test]
    fn test_time_budget() {
        let space = estimators_space();

        // Each trial takes about 100 ms, so a one-second budget cuts the
        // search long before all 1000 trials have run.
        let result = AutoTuner::new(RandomSearch::new(1000))
            .time_limit_secs(1)
            .maximize(&space, |_| {
                std::thread::sleep(Duration::from_millis(100));
                1.0
            });

        assert!(result.elapsed.as_secs() <= 2);
        assert!(result.n_trials < 1000);
    }

    #[test]
    fn test_callbacks() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        /// Counts finished trials through a counter shared with the test body.
        struct CountingCallback {
            count: Arc<AtomicUsize>,
        }

        impl<P: ParamKey> Callback<P> for CountingCallback {
            fn on_trial_end(&mut self, _: usize, _: &TrialResult<P>) {
                self.count.fetch_add(1, Ordering::SeqCst);
            }
        }

        let finished = Arc::new(AtomicUsize::new(0));
        let space = estimators_space();
        let counter = CountingCallback {
            count: Arc::clone(&finished),
        };

        let _ = AutoTuner::new(RandomSearch::new(5))
            .callback(counter)
            .maximize(&space, |_| 1.0);

        assert_eq!(finished.load(Ordering::SeqCst), 5);
    }
}