use std::time::{Duration, Instant};
use crate::automl::params::ParamKey;
use crate::automl::search::{SearchSpace, SearchStrategy, Trial, TrialResult};
/// Hooks invoked by [`AutoTuner`] over the lifetime of a search.
///
/// Every method has an empty default implementation, so implementors only
/// override the hooks they care about.
pub trait Callback<P: ParamKey> {
    /// Called once, before the first trial is requested.
    fn on_start(&mut self, _space: &SearchSpace<P>) {}

    /// Called after a trial is suggested and before the objective evaluates it.
    /// Trial numbers start at 1.
    fn on_trial_start(&mut self, _trial_num: usize, _trial: &Trial<P>) {}

    /// Called after a trial has been scored, before the strategy is updated.
    fn on_trial_end(&mut self, _trial_num: usize, _result: &TrialResult<P>) {}

    /// Called once after the search loop finishes, with the best result if any.
    fn on_end(&mut self, _best: Option<&TrialResult<P>>) {}

    /// Polled before every trial; if any callback returns `true` the search ends.
    /// The default never requests a stop.
    fn should_stop(&self) -> bool {
        false
    }
}
/// Callback that prints per-trial progress and the final best result to stdout.
///
/// The `Default` instance is silent; use [`ProgressCallback::verbose`] to
/// enable printing.
#[derive(Debug, Default)]
pub struct ProgressCallback {
    /// When `false`, the callback prints nothing.
    verbose: bool,
}
impl ProgressCallback {
    /// Returns a callback with printing enabled.
    #[must_use]
    pub fn verbose() -> Self {
        let verbose = true;
        Self { verbose }
    }
}
impl<P: ParamKey> Callback<P> for ProgressCallback {
    /// Prints one line per scored trial when verbose; otherwise does nothing.
    fn on_trial_end(&mut self, trial_num: usize, result: &TrialResult<P>) {
        if !self.verbose {
            return;
        }
        println!(
            "Trial {:>3}: score={:.4} params={}",
            trial_num, result.score, result.trial
        );
    }

    /// Prints the best result when verbose and a best result exists.
    fn on_end(&mut self, best: Option<&TrialResult<P>>) {
        if let (true, Some(winner)) = (self.verbose, best) {
            println!("\nBest: score={:.4} params={}", winner.score, winner.trial);
        }
    }
}
/// Callback that counts consecutive trials which fail to improve the best
/// score, so a search can be ended after `patience` such trials.
///
/// A trial counts as an improvement only when its score exceeds the best score
/// seen so far by more than `min_delta` (default `1e-4`).
#[derive(Debug)]
pub struct EarlyStopping {
    /// Number of consecutive non-improving trials tolerated.
    patience: usize,
    /// Minimum score gain that counts as an improvement.
    min_delta: f64,
    /// Consecutive non-improving trials observed so far.
    trials_without_improvement: usize,
    /// Best score observed so far; starts at negative infinity.
    best_score: f64,
}
impl EarlyStopping {
    /// Improvement threshold used by [`EarlyStopping::new`].
    const DEFAULT_MIN_DELTA: f64 = 1e-4;

    /// Creates an early-stopping callback with the given `patience` and the
    /// default `min_delta` of `1e-4`.
    #[must_use]
    pub fn new(patience: usize) -> Self {
        Self {
            best_score: f64::NEG_INFINITY,
            trials_without_improvement: 0,
            min_delta: Self::DEFAULT_MIN_DELTA,
            patience,
        }
    }

    /// Builder-style setter for the minimum improvement threshold.
    #[must_use]
    pub fn min_delta(self, delta: f64) -> Self {
        Self {
            min_delta: delta,
            ..self
        }
    }
}
impl<P: ParamKey> Callback<P> for EarlyStopping {
    /// Records whether `result` beat the best score by more than `min_delta`.
    ///
    /// An improvement updates the best score and resets the counter; anything
    /// else increments the counter.
    fn on_trial_end(&mut self, _trial_num: usize, result: &TrialResult<P>) {
        // A NaN score compares false here and so counts as non-improving.
        if result.score > self.best_score + self.min_delta {
            self.best_score = result.score;
            self.trials_without_improvement = 0;
        } else {
            self.trials_without_improvement += 1;
        }
    }

    /// Requests a stop once `patience` consecutive trials have failed to improve.
    ///
    /// `should_stop` is polled before every trial, including the first. A
    /// `patience` of `0` is therefore treated like `1`: stop after the first
    /// non-improving trial instead of before any trial has run.
    fn should_stop(&self) -> bool {
        self.trials_without_improvement >= self.patience.max(1)
    }
}
/// Callback that ends a search once a wall-clock budget has been spent.
///
/// The clock is unset until `start` is populated; until then
/// [`TimeBudget::elapsed`] reports zero.
#[derive(Debug)]
pub struct TimeBudget {
    /// Total wall-clock time allowed.
    budget: Duration,
    /// Instant the clock started, or `None` before it has started.
    start: Option<Instant>,
}
impl TimeBudget {
    /// Creates a budget of `secs` seconds.
    #[must_use]
    pub fn seconds(secs: u64) -> Self {
        Self {
            budget: Duration::from_secs(secs),
            start: None,
        }
    }

    /// Creates a budget of `mins` minutes.
    ///
    /// The conversion to seconds saturates at `u64::MAX` instead of overflowing.
    /// Huge values therefore give an effectively unlimited budget, rather than
    /// a debug-build panic or a silently wrapped, much smaller budget in release
    /// builds.
    #[must_use]
    pub fn minutes(mins: u64) -> Self {
        Self::seconds(mins.saturating_mul(60))
    }

    /// Time since the clock started, or [`Duration::ZERO`] if it has not started.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.start.map_or(Duration::ZERO, |s| s.elapsed())
    }

    /// Budget left, clamped at zero once it has been exceeded.
    #[must_use]
    pub fn remaining(&self) -> Duration {
        self.budget.saturating_sub(self.elapsed())
    }
}
impl<P: ParamKey> Callback<P> for TimeBudget {
    /// Starts (or restarts) the budget clock at the current instant.
    fn on_start(&mut self, _space: &SearchSpace<P>) {
        let now = Instant::now();
        self.start = Some(now);
    }

    /// Returns `true` once the elapsed time has reached the budget.
    fn should_stop(&self) -> bool {
        let used = self.elapsed();
        self.budget <= used
    }
}
/// Outcome of an [`AutoTuner`] run.
///
/// Scores are in "higher is better" orientation. For `AutoTuner::minimize`
/// they are the negated objective values.
#[derive(Debug, Clone)]
pub struct TuneResult<P: ParamKey> {
    /// Best-scoring trial. It has empty `values` if no trial scored above
    /// negative infinity.
    pub best_trial: Trial<P>,
    /// Score of `best_trial`, or negative infinity if there is none.
    pub best_score: f64,
    /// Every evaluated trial, in evaluation order.
    pub history: Vec<TrialResult<P>>,
    /// Wall-clock duration of the whole run.
    pub elapsed: Duration,
    /// Number of trials evaluated.
    pub n_trials: usize,
}
/// Drives a search strategy against an objective function, notifying the
/// registered [`Callback`]s along the way.
// `Callback` has no `Debug` supertrait, so `Box<dyn Callback<P>>` is not
// `Debug` and the impl cannot be derived; the lint is silenced deliberately.
#[allow(missing_debug_implementations)]
pub struct AutoTuner<S, P: ParamKey> {
    /// Source of candidate trials.
    strategy: S,
    /// Hooks invoked during the search, in the order they were registered.
    callbacks: Vec<Box<dyn Callback<P>>>,
    /// Type marker for `P`; holds no data.
    _phantom: std::marker::PhantomData<P>,
}
impl<S, P: ParamKey> AutoTuner<S, P>
where
S: SearchStrategy<P>,
{
pub fn new(strategy: S) -> Self {
Self {
strategy,
callbacks: Vec::new(),
_phantom: std::marker::PhantomData,
}
}
#[must_use]
pub fn time_limit_secs(mut self, secs: u64) -> Self {
self.callbacks.push(Box::new(TimeBudget::seconds(secs)));
self
}
#[must_use]
pub fn time_limit_mins(mut self, mins: u64) -> Self {
self.callbacks.push(Box::new(TimeBudget::minutes(mins)));
self
}
#[must_use]
pub fn early_stopping(mut self, patience: usize) -> Self {
self.callbacks.push(Box::new(EarlyStopping::new(patience)));
self
}
#[must_use]
pub fn verbose(mut self) -> Self {
self.callbacks.push(Box::new(ProgressCallback::verbose()));
self
}
#[must_use]
pub fn callback(mut self, cb: impl Callback<P> + 'static) -> Self {
self.callbacks.push(Box::new(cb));
self
}
pub fn maximize<F>(mut self, space: &SearchSpace<P>, mut objective: F) -> TuneResult<P>
where
F: FnMut(&Trial<P>) -> f64,
{
let start = Instant::now();
for cb in &mut self.callbacks {
cb.on_start(space);
}
let mut history = Vec::new();
let mut best_score = f64::NEG_INFINITY;
let mut best_trial: Option<Trial<P>> = None;
let mut trial_num = 0;
loop {
if self.callbacks.iter().any(|cb| cb.should_stop()) {
break;
}
let trials = self.strategy.suggest(space, 1);
if trials.is_empty() {
break;
}
let trial = trials.into_iter().next().expect("should have trial");
trial_num += 1;
for cb in &mut self.callbacks {
cb.on_trial_start(trial_num, &trial);
}
let score = objective(&trial);
let result = TrialResult {
trial: trial.clone(),
score,
metrics: std::collections::HashMap::new(),
};
if score > best_score {
best_score = score;
best_trial = Some(trial);
}
for cb in &mut self.callbacks {
cb.on_trial_end(trial_num, &result);
}
self.strategy.update(std::slice::from_ref(&result));
history.push(result);
}
let best_result = history.iter().max_by(|a, b| {
a.score
.partial_cmp(&b.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
for cb in &mut self.callbacks {
cb.on_end(best_result);
}
TuneResult {
best_trial: best_trial.unwrap_or_else(|| Trial {
values: std::collections::HashMap::new(),
}),
best_score,
history,
elapsed: start.elapsed(),
n_trials: trial_num,
}
}
pub fn minimize<F>(self, space: &SearchSpace<P>, mut objective: F) -> TuneResult<P>
where
F: FnMut(&Trial<P>) -> f64,
{
self.maximize(space, move |trial| -objective(trial))
}
}
#[cfg(test)]
#[path = "tuner_tests.rs"]
mod tests;