use std::collections::HashMap;
use crate::optim::{
GridSearch, HyperparameterSpace, ParameterValue, TPEOptimizer, Trial, TrialStatus,
};
/// A hyperparameter search strategy: proposes trials and is told how they scored.
pub trait TuneSearcher {
    /// Proposes the next trial to evaluate.
    ///
    /// # Errors
    /// Implementations return an error when no further trial can be produced
    /// (for example an exhausted grid or an empty search space).
    fn suggest(&mut self) -> crate::Result<Trial>;

    /// Reports that `trial` finished with `score` after training for `epochs` epochs.
    fn record(&mut self, trial: Trial, score: f64, epochs: usize);

    /// Returns the best trial recorded so far, or `None` when there is none.
    fn best(&self) -> Option<&Trial>;
}
/// An early-stopping policy for running trials.
pub trait TuneScheduler {
    /// Returns `true` when trial `trial_id`, having reported `val_loss` at `epoch`,
    /// should be stopped.
    fn should_stop(&self, trial_id: usize, epoch: usize, val_loss: f64) -> bool;
}
pub struct TpeSearcher {
optimizer: TPEOptimizer,
}
impl TpeSearcher {
pub fn new(space: HyperparameterSpace, n_startup: usize) -> Self {
let optimizer = TPEOptimizer::new(space).with_startup(n_startup);
Self { optimizer }
}
}
impl TuneSearcher for TpeSearcher {
fn suggest(&mut self) -> crate::Result<Trial> {
self.optimizer
.suggest()
.map_err(|e| crate::Error::ConfigError(format!("TPE suggest failed: {e}")))
}
fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
self.optimizer.record(trial, score, epochs);
}
fn best(&self) -> Option<&Trial> {
self.optimizer.best_trial()
}
}
pub struct GridSearcher {
configs: Vec<HashMap<String, ParameterValue>>,
trials: Vec<Trial>,
next_idx: usize,
}
impl GridSearcher {
pub fn new(space: HyperparameterSpace, n_points: usize) -> Self {
let grid = GridSearch::new(space, n_points);
let configs = grid.configurations();
Self { configs, trials: Vec::new(), next_idx: 0 }
}
}
impl TuneSearcher for GridSearcher {
fn suggest(&mut self) -> crate::Result<Trial> {
if self.next_idx >= self.configs.len() {
return Err(crate::Error::ConfigError(
"Grid search exhausted all configurations".to_string(),
));
}
let config = self.configs[self.next_idx].clone();
let trial = Trial::new(self.next_idx, config);
self.next_idx += 1;
Ok(trial)
}
fn record(&mut self, trial: Trial, score: f64, epochs: usize) {
let mut trial = trial;
trial.complete(score, epochs);
self.trials.push(trial);
}
fn best(&self) -> Option<&Trial> {
self.trials
.iter()
.filter(|t| t.status == TrialStatus::Completed)
.min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
}
}
/// Searcher that samples every configuration independently from the search space.
pub struct RandomSearcher {
    /// Space sampled on each `suggest` call.
    space: HyperparameterSpace,
    /// Trials reported back through `record`.
    trials: Vec<Trial>,
    /// Id assigned to the next suggested trial.
    next_id: usize,
}

impl RandomSearcher {
    /// Creates a random searcher over `space`; trial ids start at 0.
    pub fn new(space: HyperparameterSpace) -> Self {
        Self { space, trials: Vec::new(), next_id: 0 }
    }
}

impl TuneSearcher for RandomSearcher {
    /// Samples a fresh configuration, using an RNG obtained from `rand::rng()` per call.
    ///
    /// # Errors
    /// `ConfigError` if the search space is empty.
    fn suggest(&mut self) -> crate::Result<Trial> {
        if self.space.is_empty() {
            return Err(crate::Error::ConfigError("Empty search space".to_string()));
        }
        let mut rng = rand::rng();
        let config = self.space.sample_random(&mut rng);
        let trial = Trial::new(self.next_id, config);
        self.next_id += 1;
        Ok(trial)
    }

    /// Marks `trial` completed with `score`/`epochs` and stores it.
    fn record(&mut self, mut trial: Trial, score: f64, epochs: usize) {
        trial.complete(score, epochs);
        self.trials.push(trial);
    }

    /// Returns the completed trial with the lowest score, if any.
    ///
    /// NaN scores (e.g. a diverged run) are skipped. `min_by` keeps the current minimum
    /// whenever the comparator says `Equal`, and every comparison involving NaN falls
    /// back to `Equal`. Without the filter, a NaN recorded first would therefore be
    /// reported as the best trial.
    fn best(&self) -> Option<&Trial> {
        self.trials
            .iter()
            .filter(|t| t.status == TrialStatus::Completed)
            // A score that is not comparable even with itself is NaN.
            .filter(|t| t.score.partial_cmp(&t.score).is_some())
            .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
    }
}
/// Simplified ASHA-style scheduler that stops trials lagging behind their peers.
///
/// After `grace_period`, consider the trials that have a loss recorded for `epoch`.
/// A trial is stopped when its loss is strictly worse than the `k`-th best of those
/// losses, where `k = max(1, ceil(n / reduction_factor))`. Only the top
/// `1 / reduction_factor` fraction keeps running.
pub struct AshaScheduler {
    /// Epochs before which no trial is ever stopped.
    grace_period: usize,
    /// Roughly `1 / reduction_factor` of trials survive each check; at least 2.0.
    reduction_factor: f64,
    /// `history[trial_id]` holds that trial's losses in the order they were recorded.
    history: Vec<Vec<f64>>,
}

impl AshaScheduler {
    /// Creates a scheduler. `reduction_factor` values below 2.0 (or NaN) become 2.0.
    pub fn new(grace_period: usize, reduction_factor: f64) -> Self {
        Self { grace_period, reduction_factor: reduction_factor.max(2.0), history: Vec::new() }
    }

    /// Appends `val_loss` to the history of `trial_id`, growing the table as needed.
    ///
    /// `_epoch` is ignored: losses are stored in arrival order, and `should_stop` looks
    /// them up by position. NOTE(review): this assumes each trial reports exactly once
    /// per epoch starting at epoch 0 — confirm callers uphold this.
    pub fn record_metric(&mut self, trial_id: usize, _epoch: usize, val_loss: f64) {
        if self.history.len() <= trial_id {
            self.history.resize_with(trial_id + 1, Vec::new);
        }
        self.history[trial_id].push(val_loss);
    }
}

impl TuneScheduler for AshaScheduler {
    /// Returns `true` when `val_loss` at `epoch` falls outside the surviving fraction.
    ///
    /// NaN losses (diverged runs) rank as the worst possible value. A NaN `val_loss`
    /// is stopped whenever some trials are being cut at this epoch.
    fn should_stop(&self, _trial_id: usize, epoch: usize, val_loss: f64) -> bool {
        if epoch < self.grace_period {
            return false;
        }
        let mut losses_at_epoch: Vec<f64> = self
            .history
            .iter()
            .filter_map(|h| h.get(epoch).copied())
            // Map NaN to +inf so it sorts last regardless of the NaN's sign bit.
            .map(|loss| if loss.is_nan() { f64::INFINITY } else { loss })
            .collect();
        let n = losses_at_epoch.len();
        // Number of trials allowed to continue at this epoch.
        let keep = ((n as f64 / self.reduction_factor).ceil() as usize).max(1);
        if keep >= n {
            // Everyone survives (this also covers an empty population).
            return false;
        }
        losses_at_epoch.sort_unstable_by(f64::total_cmp);
        // The keep-th best loss sits at index keep - 1; anything strictly worse is cut.
        let cutoff = losses_at_epoch[keep - 1];
        val_loss.is_nan() || val_loss > cutoff
    }
}
/// Median-stopping scheduler.
///
/// After `n_warmup` epochs, a trial is stopped when its loss at `epoch` is strictly
/// worse than the median loss of all trials that have a loss recorded for that epoch.
/// That population includes the trial itself if its loss was already recorded.
pub struct MedianScheduler {
    /// Epochs before which no trial is ever stopped.
    n_warmup: usize,
    /// `history[trial_id]` holds that trial's losses in the order they were recorded.
    history: Vec<Vec<f64>>,
}

impl MedianScheduler {
    /// Creates a scheduler that never stops trials before epoch `n_warmup`.
    pub fn new(n_warmup: usize) -> Self {
        Self { n_warmup, history: Vec::new() }
    }

    /// Appends `val_loss` to the history of `trial_id`, growing the table as needed.
    ///
    /// `_epoch` is ignored: losses are stored in arrival order, and `should_stop` looks
    /// them up by position. NOTE(review): this assumes each trial reports exactly once
    /// per epoch starting at epoch 0 — confirm callers uphold this.
    pub fn record_metric(&mut self, trial_id: usize, _epoch: usize, val_loss: f64) {
        if self.history.len() <= trial_id {
            self.history.resize_with(trial_id + 1, Vec::new);
        }
        self.history[trial_id].push(val_loss);
    }
}

impl TuneScheduler for MedianScheduler {
    /// Returns `true` when `val_loss` is worse than the median loss at `epoch`.
    ///
    /// Requires at least two recorded losses at `epoch`. For an even count the median
    /// is the midpoint of the two middle values. NaN losses rank as the worst possible
    /// value, and a NaN `val_loss` is always stopped once a median exists.
    fn should_stop(&self, _trial_id: usize, epoch: usize, val_loss: f64) -> bool {
        if epoch < self.n_warmup {
            return false;
        }
        let mut losses: Vec<f64> = self
            .history
            .iter()
            .filter_map(|h| h.get(epoch).copied())
            // Map NaN to +inf so it sorts last regardless of the NaN's sign bit.
            .map(|loss| if loss.is_nan() { f64::INFINITY } else { loss })
            .collect();
        if losses.len() < 2 {
            return false;
        }
        losses.sort_unstable_by(f64::total_cmp);
        let mid = losses.len() / 2;
        let median = if losses.len() % 2 == 0 {
            // Halve before adding so two huge finite losses cannot overflow to +inf.
            losses[mid - 1] / 2.0 + losses[mid] / 2.0
        } else {
            losses[mid]
        };
        val_loss.is_nan() || val_loss > median
    }
}
/// Scheduler that never requests early stopping.
pub struct NoScheduler;

impl TuneScheduler for NoScheduler {
    /// Always `false`, whatever the trial, epoch, or loss.
    fn should_stop(&self, _: usize, _: usize, _: f64) -> bool {
        false
    }
}