use crate::error::{OptimizeError, OptimizeResult};
use super::types::{
sample_configs, ConfigSampler, EvaluationResult, MultiFidelityConfig, MultiFidelityResult,
};
/// One round of a successive-halving schedule.
///
/// Returned by `SuccessiveHalving::compute_schedule`, so it is exposed with the same
/// visibility as that method.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RoundSpec {
    /// Number of configurations evaluated in this round.
    pub n_configs: usize,
    /// Budget (fidelity) given to each configuration evaluated in this round.
    pub budget: f64,
}
/// Successive-halving driver.
///
/// It evaluates a sampled population of configurations over rounds of growing
/// budget. After each round it keeps roughly the best `1/eta` fraction.
#[derive(Clone, Debug)]
pub struct SuccessiveHalving {
    /// Settings (budgets, `eta`, `n_initial`); checked by `validate` in `new`.
    config: MultiFidelityConfig,
}
impl SuccessiveHalving {
    /// Creates a successive-halving runner from `config`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `MultiFidelityConfig::validate` reports for `config`.
    pub fn new(config: MultiFidelityConfig) -> OptimizeResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Computes the round schedule for this runner's configuration.
    ///
    /// The schedule starts from the effective initial population size at
    /// `min_budget`, and budgets are capped at `max_budget`.
    pub fn compute_schedule(&self) -> Vec<RoundSpec> {
        self.compute_schedule_with(
            self.effective_n_initial(),
            self.config.min_budget,
            self.config.max_budget,
        )
    }

    /// Builds the list of rounds starting with `n_initial` configurations at
    /// `start_budget`.
    ///
    /// After each round, the configuration count is divided by `eta` (integer
    /// division, floored at 1). The budget is multiplied by `eta` and capped at
    /// `max_budget`. Generation stops after the first round that has at most one
    /// configuration or whose budget has reached `max_budget`. The returned vector
    /// therefore always contains at least one round.
    pub(crate) fn compute_schedule_with(
        &self,
        n_initial: usize,
        start_budget: f64,
        max_budget: f64,
    ) -> Vec<RoundSpec> {
        // NOTE(review): assumes `eta >= 2`, presumably guaranteed by
        // `MultiFidelityConfig::validate`. With `eta == 1` and `n_initial > 1`, the
        // loop never terminates unless `start_budget >= max_budget`. With
        // `eta == 0`, the division below panics.
        let eta = self.config.eta;
        let mut schedule = Vec::new();
        let mut n = n_initial;
        let mut budget = start_budget;
        loop {
            schedule.push(RoundSpec {
                n_configs: n,
                budget,
            });
            if n <= 1 || budget >= max_budget {
                break;
            }
            n = (n / eta).max(1);
            budget = (budget * eta as f64).min(max_budget);
        }
        schedule
    }

    /// Returns the number of configurations sampled for the first round.
    ///
    /// This is `config.n_initial` when it is positive. Otherwise it is `eta`
    /// raised to `config.s_max()`, computed in floating point and truncated to
    /// `usize`.
    fn effective_n_initial(&self) -> usize {
        if self.config.n_initial > 0 {
            return self.config.n_initial;
        }
        let s_max = self.config.s_max();
        let eta = self.config.eta;
        (eta as f64).powi(s_max as i32) as usize
    }

    /// Runs a single successive-halving bracket that minimises
    /// `objective(config, budget)`.
    ///
    /// It uses the configured initial population size and `min_budget` as the
    /// first-round budget. See `run_with` for the procedure and result contents.
    ///
    /// # Errors
    ///
    /// Same as `run_with`.
    pub fn run<F>(
        &self,
        objective: &F,
        bounds: &[(f64, f64)],
        sampler: &ConfigSampler,
        rng_state: &mut u64,
    ) -> OptimizeResult<MultiFidelityResult>
    where
        F: Fn(&[f64], f64) -> OptimizeResult<f64>,
    {
        self.run_with(
            objective,
            bounds,
            sampler,
            rng_state,
            self.effective_n_initial(),
            self.config.min_budget,
        )
    }

    /// Runs one successive-halving bracket with an explicit initial population
    /// size and first-round budget.
    ///
    /// Configurations are drawn once with `sample_configs`. Each configuration's
    /// `config_id` is its index in that initial sample. Every round:
    /// 1. evaluates each surviving configuration at the round's budget;
    /// 2. sorts the results by objective (ascending, NaN last);
    /// 3. keeps the best `len / eta` (at least one) for the next round.
    ///
    /// The result's `best_config`/`best_objective` is the lowest objective across
    /// all evaluations, at whatever budget they were made. `total_budget_used` is
    /// the sum of the budgets of every evaluation.
    ///
    /// # Errors
    ///
    /// * `OptimizeError::InvalidParameter` if `bounds` is empty.
    /// * Any error returned by `objective`, propagated immediately.
    /// * `OptimizeError::ComputationError` if no evaluation was performed, for
    ///   example when the first round has no configurations.
    pub(crate) fn run_with<F>(
        &self,
        objective: &F,
        bounds: &[(f64, f64)],
        sampler: &ConfigSampler,
        rng_state: &mut u64,
        n_initial: usize,
        start_budget: f64,
    ) -> OptimizeResult<MultiFidelityResult>
    where
        F: Fn(&[f64], f64) -> OptimizeResult<f64>,
    {
        if bounds.is_empty() {
            return Err(OptimizeError::InvalidParameter(
                "bounds must not be empty".into(),
            ));
        }
        let schedule = self.compute_schedule_with(n_initial, start_budget, self.config.max_budget);
        let initial_n = schedule.first().map_or(n_initial, |r| r.n_configs);
        let mut configs: Vec<(usize, Vec<f64>)> =
            sample_configs(initial_n, bounds, sampler, rng_state)
                .into_iter()
                .enumerate()
                .collect();
        let mut all_evals: Vec<EvaluationResult> = Vec::new();
        let mut total_budget = 0.0;
        for round in &schedule {
            configs.truncate(round.n_configs);
            if configs.is_empty() {
                break;
            }
            let mut scored: Vec<(usize, Vec<f64>, f64)> = Vec::with_capacity(configs.len());
            // Drain so each config is moved into `scored`; only the copy kept in
            // the evaluation log needs a clone.
            for (id, cfg) in configs.drain(..) {
                let obj = objective(&cfg, round.budget)?;
                all_evals.push(EvaluationResult {
                    config_id: id,
                    config: cfg.clone(),
                    budget: round.budget,
                    objective: obj,
                });
                total_budget += round.budget;
                scored.push((id, cfg, obj));
            }
            scored.sort_by(|a, b| Self::compare_objectives(a.2, b.2));
            let keep = (scored.len() / self.config.eta).max(1);
            scored.truncate(keep);
            configs = scored.into_iter().map(|(id, cfg, _)| (id, cfg)).collect();
        }
        let best = all_evals
            .iter()
            .min_by(|a, b| Self::compare_objectives(a.objective, b.objective))
            .ok_or_else(|| OptimizeError::ComputationError("no evaluations performed".into()))?;
        Ok(MultiFidelityResult {
            best_config: best.config.clone(),
            best_objective: best.objective,
            total_budget_used: total_budget,
            evaluations: all_evals,
            n_brackets: 1,
        })
    }

    /// Total preorder on objective values, used both for ranking and for picking
    /// the best evaluation.
    ///
    /// Non-NaN values compare numerically. Every NaN ranks after all non-NaN
    /// values, and NaNs compare equal to each other.
    ///
    /// A bare `partial_cmp(..).unwrap_or(Equal)` is not a consistent order once a
    /// NaN appears. `slice::sort_by` may panic on such a comparator, and `min_by`
    /// would keep a leading NaN over every finite objective.
    fn compare_objectives(a: f64, b: f64) -> std::cmp::Ordering {
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Budget-independent objective with its minimum (0) at the origin.
    fn quadratic(x: &[f64], _budget: f64) -> OptimizeResult<f64> {
        Ok(x.iter().map(|xi| xi * xi).sum())
    }

    #[test]
    fn test_sh_finds_minimum() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 27,
        };
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let mut rng = 42u64;
        let result = sh
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run succeeds");
        assert!(
            result.best_objective < 10.0,
            "best objective {} should be < 10",
            result.best_objective
        );
    }

    #[test]
    fn test_budget_monotonically_increases() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 81,
        };
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let schedule = sh.compute_schedule();
        for window in schedule.windows(2) {
            assert!(
                window[1].budget >= window[0].budget,
                "budget must not decrease: {} -> {}",
                window[0].budget,
                window[1].budget
            );
        }
    }

    #[test]
    fn test_configs_decrease_each_round() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 81,
        };
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let schedule = sh.compute_schedule();
        for window in schedule.windows(2) {
            assert!(
                window[1].n_configs <= window[0].n_configs,
                "n_configs must not increase: {} -> {}",
                window[0].n_configs,
                window[1].n_configs
            );
        }
    }

    #[test]
    fn test_best_config_survives() {
        let cfg = MultiFidelityConfig {
            max_budget: 9.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 9,
        };
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let bounds = vec![(-5.0, 5.0)];
        let mut rng = 7u64;
        let result = sh
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        let best_id = result
            .evaluations
            .iter()
            .min_by(|a, b| {
                a.objective
                    .partial_cmp(&b.objective)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|e| e.config_id)
            .expect("should have at least one evaluation");
        let final_ids: Vec<usize> = result
            .evaluations
            .iter()
            .filter(|e| (e.budget - 9.0).abs() < 1e-9)
            .map(|e| e.config_id)
            .collect();
        assert!(
            !final_ids.is_empty(),
            "should have evaluations at max budget"
        );
        // `quadratic` ignores the budget, so the overall best configuration ranks
        // first in every round and must be among those evaluated at max budget.
        assert!(
            final_ids.contains(&best_id),
            "best config {} should reach the max budget (final ids: {:?})",
            best_id,
            final_ids
        );
    }

    #[test]
    fn test_schedule_correct_pairs() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 27,
        };
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let schedule = sh.compute_schedule();
        assert_eq!(schedule.len(), 4);
        assert_eq!(schedule[0].n_configs, 27);
        assert!((schedule[0].budget - 1.0).abs() < 1e-9);
        assert_eq!(schedule[1].n_configs, 9);
        assert!((schedule[1].budget - 3.0).abs() < 1e-9);
        assert_eq!(schedule[2].n_configs, 3);
        assert!((schedule[2].budget - 9.0).abs() < 1e-9);
        assert_eq!(schedule[3].n_configs, 1);
        assert!((schedule[3].budget - 27.0).abs() < 1e-9);
    }

    #[test]
    fn test_sh_with_lhs() {
        let cfg = MultiFidelityConfig {
            max_budget: 9.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 9,
        };
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let bounds = vec![(-1.0, 1.0), (-1.0, 1.0)];
        let mut rng = 99u64;
        let result = sh
            .run(
                &quadratic,
                &bounds,
                &ConfigSampler::LatinHypercube,
                &mut rng,
            )
            .expect("lhs run ok");
        assert!(result.best_objective < 2.0);
    }

    #[test]
    fn test_empty_bounds_error() {
        let cfg = MultiFidelityConfig::default();
        let sh = SuccessiveHalving::new(cfg).expect("valid config");
        let result = sh.run(&quadratic, &[], &ConfigSampler::Random, &mut 1u64);
        assert!(result.is_err());
    }
}