aprender/automl/
tuner.rs

1//! AutoTuner for automatic hyperparameter optimization.
2//!
3//! Provides high-level API for tuning any model that implements the Estimator trait.
4
5use std::time::{Duration, Instant};
6
7use crate::automl::params::ParamKey;
8use crate::automl::search::{SearchSpace, SearchStrategy, Trial, TrialResult};
9
10/// Callback trait for monitoring optimization progress.
11pub trait Callback<P: ParamKey> {
12    /// Called at the start of optimization.
13    fn on_start(&mut self, _space: &SearchSpace<P>) {}
14
15    /// Called before each trial.
16    fn on_trial_start(&mut self, _trial_num: usize, _trial: &Trial<P>) {}
17
18    /// Called after each trial with results.
19    fn on_trial_end(&mut self, _trial_num: usize, _result: &TrialResult<P>) {}
20
21    /// Called at the end of optimization.
22    fn on_end(&mut self, _best: Option<&TrialResult<P>>) {}
23
24    /// Return true to stop optimization early.
25    fn should_stop(&self) -> bool {
26        false
27    }
28}
29
/// Callback that prints each trial's score and parameters to stdout.
///
/// The `Default` instance is silent; use [`ProgressCallback::verbose`] to
/// turn printing on.
#[derive(Default, Debug)]
pub struct ProgressCallback {
    /// Whether per-trial and summary lines are printed.
    verbose: bool,
}

impl ProgressCallback {
    /// Build a callback that prints every trial and the final best result.
    #[must_use]
    pub fn verbose() -> Self {
        let verbose = true;
        ProgressCallback { verbose }
    }
}
43
44impl<P: ParamKey> Callback<P> for ProgressCallback {
45    fn on_trial_end(&mut self, trial_num: usize, result: &TrialResult<P>) {
46        if self.verbose {
47            println!(
48                "Trial {:>3}: score={:.4} params={}",
49                trial_num, result.score, result.trial
50            );
51        }
52    }
53
54    fn on_end(&mut self, best: Option<&TrialResult<P>>) {
55        if self.verbose {
56            if let Some(b) = best {
57                println!("\nBest: score={:.4} params={}", b.score, b.trial);
58            }
59        }
60    }
61}
62
/// Callback that halts optimization once scores stop improving.
///
/// A trial counts as an improvement only if its score beats the best score
/// seen so far by more than `min_delta` (default `1e-4`). Once `patience`
/// consecutive trials fail to improve, the callback asks the tuner to stop.
/// NOTE(review): with `patience == 0` the stop condition already holds before
/// any trial runs, so no trials are evaluated — confirm this is intended.
#[derive(Debug)]
pub struct EarlyStopping {
    /// Number of consecutive non-improving trials tolerated.
    patience: usize,
    /// Minimum score gain that counts as an improvement.
    min_delta: f64,
    /// Length of the current run of non-improving trials.
    trials_without_improvement: usize,
    /// Best score observed so far; starts at negative infinity.
    best_score: f64,
}

impl EarlyStopping {
    /// Create an early-stopping callback tolerating `patience` non-improving trials.
    #[must_use]
    pub fn new(patience: usize) -> Self {
        let min_delta = 1e-4;
        Self {
            patience,
            min_delta,
            trials_without_improvement: 0,
            best_score: f64::NEG_INFINITY,
        }
    }

    /// Override the minimum score gain required to reset the patience counter.
    #[must_use]
    pub fn min_delta(self, delta: f64) -> Self {
        Self {
            min_delta: delta,
            ..self
        }
    }
}
91
92impl<P: ParamKey> Callback<P> for EarlyStopping {
93    fn on_trial_end(&mut self, _trial_num: usize, result: &TrialResult<P>) {
94        if result.score > self.best_score + self.min_delta {
95            self.best_score = result.score;
96            self.trials_without_improvement = 0;
97        } else {
98            self.trials_without_improvement += 1;
99        }
100    }
101
102    fn should_stop(&self) -> bool {
103        self.trials_without_improvement >= self.patience
104    }
105}
106
/// Callback that stops optimization once a wall-clock budget is spent.
///
/// The clock starts when the tuner calls `on_start`. Before that,
/// [`TimeBudget::elapsed`] reports zero and [`TimeBudget::remaining`]
/// reports the whole budget.
#[derive(Debug)]
pub struct TimeBudget {
    /// Total wall-clock time allowed.
    budget: Duration,
    /// Moment the run began; `None` until the clock is started.
    start: Option<Instant>,
}

impl TimeBudget {
    /// Create a time budget of `secs` seconds.
    #[must_use]
    pub fn seconds(secs: u64) -> Self {
        Self {
            budget: Duration::from_secs(secs),
            start: None,
        }
    }

    /// Create a time budget of `mins` minutes.
    ///
    /// Budgets too large to express in whole seconds saturate to `u64::MAX`
    /// seconds. Without saturation the multiplication would overflow, which
    /// panics in debug builds and wraps to a wrong budget in release builds.
    #[must_use]
    pub fn minutes(mins: u64) -> Self {
        Self::seconds(mins.saturating_mul(60))
    }

    /// Time elapsed since the clock started, or zero if it has not started.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.start.map_or(Duration::ZERO, |s| s.elapsed())
    }

    /// Budget left, clamped at zero once the budget is exhausted.
    #[must_use]
    pub fn remaining(&self) -> Duration {
        self.budget.saturating_sub(self.elapsed())
    }
}
142
143impl<P: ParamKey> Callback<P> for TimeBudget {
144    fn on_start(&mut self, _space: &SearchSpace<P>) {
145        self.start = Some(Instant::now());
146    }
147
148    fn should_stop(&self) -> bool {
149        self.elapsed() >= self.budget
150    }
151}
152
153/// Result of hyperparameter optimization.
154#[derive(Debug, Clone)]
155pub struct TuneResult<P: ParamKey> {
156    /// Best trial found.
157    pub best_trial: Trial<P>,
158    /// Best score achieved.
159    pub best_score: f64,
160    /// All trial results.
161    pub history: Vec<TrialResult<P>>,
162    /// Total optimization time.
163    pub elapsed: Duration,
164    /// Number of trials run.
165    pub n_trials: usize,
166}
167
168/// AutoTuner for hyperparameter optimization.
169///
170/// # Example
171///
172/// ```ignore
173/// use aprender::automl::{AutoTuner, RandomSearch, SearchSpace};
174/// use aprender::automl::params::RandomForestParam as RF;
175///
176/// let space = SearchSpace::new()
177///     .add(RF::NEstimators, 10..500)
178///     .add(RF::MaxDepth, 2..20);
179///
180/// let result = AutoTuner::new(RandomSearch::new(100))
181///     .time_limit_secs(60)
182///     .early_stopping(20)
183///     .maximize(space, |trial| {
184///         let n = trial.get_usize(&RF::NEstimators).unwrap_or(100);
185///         let d = trial.get_usize(&RF::MaxDepth).unwrap_or(5);
186///         // Return cross-validation score
187///         evaluate_model(n, d)
188///     });
189///
190/// println!("Best: {:?}", result.best_trial);
191/// ```
192#[allow(missing_debug_implementations)]
193pub struct AutoTuner<S, P: ParamKey> {
194    strategy: S,
195    callbacks: Vec<Box<dyn Callback<P>>>,
196    _phantom: std::marker::PhantomData<P>,
197}
198
199impl<S, P: ParamKey> AutoTuner<S, P>
200where
201    S: SearchStrategy<P>,
202{
203    /// Create new tuner with search strategy.
204    pub fn new(strategy: S) -> Self {
205        Self {
206            strategy,
207            callbacks: Vec::new(),
208            _phantom: std::marker::PhantomData,
209        }
210    }
211
212    /// Add time limit in seconds.
213    #[must_use]
214    pub fn time_limit_secs(mut self, secs: u64) -> Self {
215        self.callbacks.push(Box::new(TimeBudget::seconds(secs)));
216        self
217    }
218
219    /// Add time limit in minutes.
220    #[must_use]
221    pub fn time_limit_mins(mut self, mins: u64) -> Self {
222        self.callbacks.push(Box::new(TimeBudget::minutes(mins)));
223        self
224    }
225
226    /// Add early stopping with patience.
227    #[must_use]
228    pub fn early_stopping(mut self, patience: usize) -> Self {
229        self.callbacks.push(Box::new(EarlyStopping::new(patience)));
230        self
231    }
232
233    /// Add verbose progress logging.
234    #[must_use]
235    pub fn verbose(mut self) -> Self {
236        self.callbacks.push(Box::new(ProgressCallback::verbose()));
237        self
238    }
239
240    /// Add custom callback.
241    #[must_use]
242    pub fn callback(mut self, cb: impl Callback<P> + 'static) -> Self {
243        self.callbacks.push(Box::new(cb));
244        self
245    }
246
247    /// Run optimization to maximize objective.
248    pub fn maximize<F>(mut self, space: &SearchSpace<P>, mut objective: F) -> TuneResult<P>
249    where
250        F: FnMut(&Trial<P>) -> f64,
251    {
252        let start = Instant::now();
253
254        // Notify callbacks of start
255        for cb in &mut self.callbacks {
256            cb.on_start(space);
257        }
258
259        let mut history = Vec::new();
260        let mut best_score = f64::NEG_INFINITY;
261        let mut best_trial: Option<Trial<P>> = None;
262        let mut trial_num = 0;
263
264        loop {
265            // Check stopping conditions
266            if self.callbacks.iter().any(|cb| cb.should_stop()) {
267                break;
268            }
269
270            // Get next trial(s)
271            let trials = self.strategy.suggest(space, 1);
272            if trials.is_empty() {
273                break;
274            }
275
276            let trial = trials.into_iter().next().expect("should have trial");
277            trial_num += 1;
278
279            // Notify trial start
280            for cb in &mut self.callbacks {
281                cb.on_trial_start(trial_num, &trial);
282            }
283
284            // Evaluate
285            let score = objective(&trial);
286
287            let result = TrialResult {
288                trial: trial.clone(),
289                score,
290                metrics: std::collections::HashMap::new(),
291            };
292
293            // Update best
294            if score > best_score {
295                best_score = score;
296                best_trial = Some(trial);
297            }
298
299            // Notify trial end
300            for cb in &mut self.callbacks {
301                cb.on_trial_end(trial_num, &result);
302            }
303
304            // Update strategy with result
305            self.strategy.update(std::slice::from_ref(&result));
306
307            history.push(result);
308        }
309
310        // Notify end
311        let best_result = history.iter().max_by(|a, b| {
312            a.score
313                .partial_cmp(&b.score)
314                .unwrap_or(std::cmp::Ordering::Equal)
315        });
316        for cb in &mut self.callbacks {
317            cb.on_end(best_result);
318        }
319
320        TuneResult {
321            best_trial: best_trial.unwrap_or_else(|| Trial {
322                values: std::collections::HashMap::new(),
323            }),
324            best_score,
325            history,
326            elapsed: start.elapsed(),
327            n_trials: trial_num,
328        }
329    }
330
331    /// Run optimization to minimize objective.
332    pub fn minimize<F>(self, space: &SearchSpace<P>, mut objective: F) -> TuneResult<P>
333    where
334        F: FnMut(&Trial<P>) -> f64,
335    {
336        self.maximize(space, move |trial| -objective(trial))
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automl::params::RandomForestParam as RF;
    use crate::automl::{RandomSearch, SearchSpace};

    /// One-dimensional space over the number of trees, shared by several tests.
    fn trees_only_space() -> SearchSpace<RF> {
        SearchSpace::new().add(RF::NEstimators, 10..100)
    }

    #[test]
    fn test_auto_tuner_basic() {
        let space: SearchSpace<RF> = SearchSpace::new()
            .add(RF::NEstimators, 10..100)
            .add(RF::MaxDepth, 2..10);

        let result = AutoTuner::new(RandomSearch::new(10).with_seed(42)).maximize(&space, |t| {
            // Reward more trees and a depth close to 5.
            let trees = t.get_usize(&RF::NEstimators).unwrap_or(50) as f64;
            let depth = t.get_usize(&RF::MaxDepth).unwrap_or(5) as f64;
            (trees / 100.0) + (1.0 - (depth - 5.0).abs() / 5.0)
        });

        assert_eq!(result.n_trials, 10);
        assert!(result.best_score > 0.0);
    }

    #[test]
    fn test_early_stopping() {
        // A flat objective never improves after the first trial, so the tuner
        // should give up after `patience` further trials.
        let patience: usize = 3;
        let result = AutoTuner::new(RandomSearch::new(100))
            .early_stopping(patience)
            .maximize(&trees_only_space(), |_| 0.5);

        assert!(result.n_trials <= patience + 1);
    }

    #[test]
    fn test_time_budget() {
        let result = AutoTuner::new(RandomSearch::new(1000))
            .time_limit_secs(1)
            .maximize(&trees_only_space(), |_| {
                std::thread::sleep(Duration::from_millis(100));
                1.0
            });

        // The 1 s budget must cut the run well short of 1000 slow trials.
        assert!(result.elapsed.as_secs() <= 2);
        assert!(result.n_trials < 1000);
    }

    #[test]
    fn test_callbacks() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        /// Counts how many times `on_trial_end` fires.
        struct TrialCounter {
            seen: Arc<AtomicUsize>,
        }

        impl<P: ParamKey> Callback<P> for TrialCounter {
            fn on_trial_end(&mut self, _: usize, _: &TrialResult<P>) {
                self.seen.fetch_add(1, Ordering::SeqCst);
            }
        }

        let seen = Arc::new(AtomicUsize::new(0));

        let _ = AutoTuner::new(RandomSearch::new(5))
            .callback(TrialCounter {
                seen: Arc::clone(&seen),
            })
            .maximize(&trees_only_space(), |_| 1.0);

        assert_eq!(seen.load(Ordering::SeqCst), 5);
    }
}