use crate::error::{from_yamakan, into_yamakan};
use crate::yamakan_utils::YamakanIdGen;
use kurobako_core::json::JsonRecipe;
use kurobako_core::num::OrderedFloat;
use kurobako_core::problem::ProblemSpec;
use kurobako_core::registry::FactoryRegistry;
use kurobako_core::rng::{ArcRng, Rng};
use kurobako_core::solver::{
BoxSolver, BoxSolverFactory, Capability, Solver, SolverFactory, SolverRecipe, SolverSpec,
SolverSpecBuilder,
};
use kurobako_core::trial::{EvaluatedTrial, IdGen, NextTrial, TrialId, Values};
use kurobako_core::{ErrorKind, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::f64;
use structopt::StructOpt;
use yamakan::optimizers::asha::{AshaOptimizer, AshaOptimizerBuilder};
use yamakan::{self, Budget, MfObs, MultiFidelityOptimizer, Obs, ObsId, Optimizer, Ranked};
/// Recipe of a solver that runs ASHA (Asynchronous Successive Halving Algorithm) on top of a
/// base solver.
#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
#[structopt(rename_all = "kebab-case")]
pub struct AshaSolverRecipe {
    /// Minimum budget (in steps) of a trial, as a fraction of the problem's last step.
    /// Ignored when `--min-step` is given.
    #[structopt(long, default_value = "0.01")]
    pub min_step_rate: f64,
    /// Minimum budget (in steps) of a trial. Takes precedence over `--min-step-rate`.
    #[structopt(long)]
    pub min_step: Option<u64>,
    /// Reduction factor of ASHA.
    #[structopt(long, default_value = "2")]
    pub reduction_factor: usize,
    /// Disables checkpointing in the ASHA optimizer.
    #[structopt(long)]
    pub without_checkpoint: bool,
    /// JSON recipe of the base solver that proposes the trials scheduled by ASHA.
    pub base_solver: JsonRecipe,
}
impl SolverRecipe for AshaSolverRecipe {
type Factory = AshaSolverFactory;
fn create_factory(&self, registry: &FactoryRegistry) -> Result<Self::Factory> {
let base = track!(registry.create_solver_factory_from_json(&self.base_solver))?;
Ok(AshaSolverFactory {
min_step_rate: self.min_step_rate,
min_step: self.min_step,
reduction_factor: self.reduction_factor,
without_checkpoint: self.without_checkpoint,
base,
})
}
}
/// Factory that creates [`AshaSolver`]s, each wrapping a solver built by a base solver factory.
#[derive(Debug)]
pub struct AshaSolverFactory {
    /// Fraction of the problem's last step used as the minimum budget when `min_step` is `None`.
    min_step_rate: f64,
    /// Explicit minimum budget; takes precedence over `min_step_rate` when set.
    min_step: Option<u64>,
    /// Reduction factor passed to `AshaOptimizerBuilder::reduction_factor`.
    reduction_factor: usize,
    /// Whether `AshaOptimizerBuilder::without_checkpoint` is called when building the optimizer.
    without_checkpoint: bool,
    /// Factory for the base solver whose trials ASHA schedules.
    base: BoxSolverFactory,
}
impl SolverFactory for AshaSolverFactory {
    type Solver = AshaSolver;

    /// Returns the base solver's specification, renamed to `"ASHA with <base name>"` and with
    /// the multi-objective capability removed.
    ///
    /// # Errors
    ///
    /// Propagates any error from the base factory's `specification`.
    fn specification(&self) -> Result<SolverSpec> {
        let mut base = track!(self.base.specification())?;
        // ASHA here is driven by a single scalar value (only the first objective is forwarded
        // in `AshaSolver::tell`), so multi-objective support cannot be inherited.
        base.capabilities.remove_capability(Capability::MultiObjective);
        let spec = SolverSpecBuilder::new(&format!("ASHA with {}", base.name))
            .attr(
                "version",
                &format!("kurobako_solvers={}", env!("CARGO_PKG_VERSION")),
            )
            .attr(
                "paper",
                "Li, Liam, et al. \"Massively parallel hyperparameter tuning.\" \
                 arXiv preprint arXiv:1810.05934 (2018).",
            )
            .capabilities(base.capabilities);
        Ok(spec.finish())
    }

    /// Creates an ASHA solver for `problem`, wrapping a freshly created base solver.
    ///
    /// The maximum budget is the problem's last step. The minimum budget is `min_step` when it
    /// is set. Otherwise it is `max_budget * min_step_rate` (truncated), clamped to at least 1.
    ///
    /// # Errors
    ///
    /// Fails if the base solver cannot be created or the ASHA builder rejects the
    /// configuration (e.g. an invalid reduction factor or budget range).
    fn create_solver(&self, rng: ArcRng, problem: &ProblemSpec) -> Result<Self::Solver> {
        let max_budget = problem.steps.last();
        let min_budget = self.min_step.unwrap_or_else(|| {
            // Truncation yields 0 whenever `max_budget * min_step_rate < 1` (e.g. the default
            // rate of 0.01 with fewer than 100 steps); a zero minimum budget is meaningless.
            ((max_budget as f64 * self.min_step_rate) as u64).max(1)
        });

        let base = track!(self.base.create_solver(rng.clone(), problem))?;
        let mut builder = AshaOptimizerBuilder::new();
        track!(builder
            .reduction_factor(self.reduction_factor)
            .map_err(from_yamakan))?;
        if self.without_checkpoint {
            builder.without_checkpoint();
        }
        let optimizer = track!(builder
            .finish(BaseOptimizer::new(max_budget, base), min_budget, max_budget)
            .map_err(from_yamakan))?;
        Ok(AshaSolver {
            optimizer,
            rng,
            trials: HashMap::new(),
            max_budget,
        })
    }
}
/// Solver that schedules trials with ASHA, using a base solver to propose the trials.
#[derive(Debug)]
pub struct AshaSolver {
    /// ASHA optimizer whose inner optimizer adapts the base solver (see `BaseOptimizer`).
    optimizer: AshaOptimizer<OrderedFloat<f64>, BaseOptimizer>,
    /// Random number generator passed to the optimizer on every `ask`.
    rng: ArcRng,
    /// Trials handed out by `ask` and not yet told, keyed by the ID returned to the caller.
    /// The stored `NextTrial` keeps the base solver's own trial ID.
    trials: HashMap<TrialId, NextTrial>,
    /// The problem's last step; reported as the budget amount in `tell`.
    max_budget: u64,
}
impl Solver for AshaSolver {
fn ask(&mut self, idg: &mut IdGen) -> Result<NextTrial> {
let mut idg = YamakanIdGen(idg);
let obs = track!(self
.optimizer
.ask(&mut self.rng, &mut idg)
.map_err(from_yamakan))?;
let mut trial = obs.param.clone();
trial.id = TrialId::new(obs.id.get());
self.trials.insert(trial.id, obs.param);
Ok(trial)
}
fn tell(&mut self, trial: EvaluatedTrial) -> Result<()> {
let param = track_assert_some!(self.trials.remove(&trial.id), ErrorKind::Bug);
let value = if trial.values.is_empty() {
OrderedFloat(f64::NAN)
} else {
OrderedFloat(trial.values[0])
};
let obs = MfObs {
id: ObsId::new(trial.id.get()),
budget: Budget {
amount: self.max_budget,
consumption: trial.current_step,
},
param,
value,
};
track!(self.optimizer.tell(obs).map_err(from_yamakan))
}
}
/// Adapter that exposes a kurobako [`BoxSolver`] as a yamakan [`Optimizer`], so that it can be
/// used as the inner optimizer of ASHA.
#[derive(Debug)]
struct BaseOptimizer {
    /// The problem's last step; `tell` converts ranks into steps relative to it.
    max_budget: u64,
    /// The wrapped base solver.
    solver: BoxSolver,
    /// Generator for the base solver's trial IDs, independent of the caller's ID space.
    idg: IdGen,
    /// Maps each base-solver trial ID to the observation ID assigned to it on first sight.
    idmap: HashMap<TrialId, ObsId>,
}
impl BaseOptimizer {
fn new(max_budget: u64, solver: BoxSolver) -> Self {
Self {
max_budget,
solver,
idg: IdGen::new(),
idmap: HashMap::new(),
}
}
}
impl Optimizer for BaseOptimizer {
    type Param = NextTrial;
    type Value = Ranked<OrderedFloat<f64>>;

    /// Asks the base solver for a trial and wraps it as an observation.
    ///
    /// The base solver's trial IDs are generated from `self.idg`. Each distinct base trial ID
    /// gets an observation ID from `idg` the first time it is seen, and that ID is reused if
    /// the same base trial is returned again.
    fn ask<R: Rng, G: yamakan::IdGen>(
        &mut self,
        _rng: R,
        mut idg: G,
    ) -> Result<Obs<Self::Param>, yamakan::Error> {
        let trial = track!(self.solver.ask(&mut self.idg).map_err(into_yamakan))?;
        // ID generation is fallible, so `entry().or_insert_with` cannot propagate its error;
        // look up once and insert only on a miss instead.
        let id = match self.idmap.get(&trial.id) {
            Some(&id) => id,
            None => {
                let id = track!(idg.generate())?;
                self.idmap.insert(trial.id, id);
                id
            }
        };
        Ok(Obs {
            id,
            param: trial,
            value: (),
        })
    }

    /// Forwards an observation's result to the base solver.
    ///
    /// A NaN value is reported as an empty `Values`. The rank is converted into the trial's
    /// `current_step` as `max_budget - rank`, saturating at 0.
    fn tell(&mut self, obs: Obs<Self::Param, Self::Value>) -> Result<(), yamakan::Error> {
        let value = obs.value.value.0;
        // NaN is how `AshaSolver::tell` encodes "no value", so map it back to empty values.
        let values = if value.is_nan() {
            Values::new(Vec::new())
        } else {
            Values::new(vec![value])
        };
        // Plain `-` would panic in debug builds (or wrap to a huge step in release builds)
        // if a rank ever exceeded `max_budget`.
        let current_step = self.max_budget.saturating_sub(obs.value.rank);
        let trial = EvaluatedTrial {
            id: obs.param.id,
            values,
            current_step,
        };
        track!(self.solver.tell(trial).map_err(into_yamakan))?;
        Ok(())
    }
}