mod bipop;
mod ipop;
mod local;
pub mod options;
mod strategy;
pub use bipop::BIPOP;
pub use ipop::IPOP;
pub use local::Local;
pub use options::RestartOptions;
pub use strategy::RestartStrategy;
use nalgebra::DVector;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaChaRng;
use std::fmt::{self, Debug};
use std::ops::RangeInclusive;
use std::time::{Duration, Instant};
use crate::parameters::Parameters;
use crate::{
utils, CMAESOptions, Individual, Mode, ObjectiveFunction, ParallelObjectiveFunction,
TerminationData, TerminationReason, CMAES,
};
use options::{InvalidRestartOptionsError, InvalidRestartStrategyOptionsError};
use strategy::{RestartControl, Strategy};
/// Base initial step size passed to `CMAESOptions::new` for every run started by a
/// [`Restarter`], together with a freshly sampled initial mean.
const DEFAULT_INITIAL_STEP_SIZE: f64 = 0.5;
/// Why a [`Restarter`] stopped launching new runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RestartTerminationReason {
    /// The restart strategy reported that its maximum number of runs was reached (or is zero).
    MaxRuns,
    /// A run terminated with `TerminationReason::FunTarget`.
    FunTarget,
    /// The total number of function evaluations across all runs reached `max_function_evals`.
    /// This is checked only between runs, so the total may exceed the limit.
    MaxFunctionEvals,
    /// The elapsed wall-clock time reached `max_time`. This is checked only between runs.
    MaxTime,
    /// A run terminated with `TerminationReason::InvalidFunctionValue`.
    InvalidFunctionValue,
}
impl fmt::Display for RestartTerminationReason {
    /// Formats the reason exactly like its `Debug` representation (the variant name).
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        <Self as Debug>::fmt(self, formatter)
    }
}
/// Aggregated outcome of a complete restart procedure, returned by the `Restarter::run*`
/// methods.
#[derive(Clone, Debug)]
pub struct RestartResults {
    /// Best individual found across all runs, or `None` if no run produced one (e.g. when no run
    /// was started).
    pub best: Option<Individual>,
    /// Why the restart procedure stopped.
    pub reason: RestartTerminationReason,
    /// Total number of objective function evaluations summed over all runs.
    pub function_evals: usize,
    /// Number of runs that were started and completed.
    pub runs: usize,
}
impl RestartResults {
fn print_results(&self) {
println!(
"Terminated in {} f-evals with reason: `{}`",
self.function_evals, self.reason
);
if let Some(ref best) = self.best {
println!("Best function value: {:e}", best.value);
println!("Best point: {}", best.point);
}
}
}
/// Drives repeated CMA-ES runs according to a [`RestartStrategy`] until a global termination
/// criterion is met. Built from [`RestartOptions`].
#[derive(Clone, Debug)]
pub struct Restarter {
    /// Strategy deciding the configuration of each run and when to stop.
    strategy: RestartStrategy,
    /// Problem dimension; also the length of each sampled initial mean.
    dimensions: usize,
    /// Minimize or maximize; forwarded to every run and used to compare best individuals.
    mode: Mode,
    /// Forwarded to each run's `CMAESOptions::parallel_update`.
    parallel_update: bool,
    /// Range from which every coordinate of each run's initial mean is sampled.
    search_range: RangeInclusive<f64>,
    /// Target function value forwarded to each run.
    fun_target: Option<f64>,
    /// Global limit on function evaluations, checked between runs.
    max_function_evals: Option<usize>,
    /// Global wall-clock limit, checked between runs.
    max_time: Option<Duration>,
    /// Per-run limit on function evaluations.
    max_function_evals_per_run: Option<usize>,
    /// Per-run limit on generations.
    max_generations_per_run: Option<usize>,
    /// Whether progress information is printed to stdout.
    print_info: bool,
    /// Seed used to initialize `rng`; exposed through `Restarter::seed`.
    seed: u64,
    /// Source of initial means, per-run seeds, and strategy randomness.
    rng: ChaChaRng,
    /// Best individual seen across all completed runs.
    overall_best: Option<Individual>,
}
impl Restarter {
pub fn new(options: RestartOptions) -> Result<Self, InvalidRestartOptionsError> {
let seed = options.seed.unwrap_or_else(rand::random);
if options.dimensions == 0 {
Err(InvalidRestartOptionsError::Dimensions)
} else if options.search_range.end() - options.search_range.start() == 0.0 {
Err(InvalidRestartOptionsError::SearchRange)
} else {
Ok(Self {
strategy: options.strategy,
dimensions: options.dimensions,
mode: options.mode,
parallel_update: options.parallel_update,
search_range: options.search_range,
fun_target: options.fun_target,
max_function_evals: options.max_function_evals,
max_time: options.max_time,
max_function_evals_per_run: options.max_function_evals_per_run,
max_generations_per_run: options.max_generations_per_run,
print_info: options.enable_printing,
seed,
rng: ChaChaRng::seed_from_u64(seed),
overall_best: None,
})
}
}
pub fn seed(&self) -> u64 {
self.seed
}
pub fn run<F, G>(self, get_objective_function: G) -> RestartResults
where
F: ObjectiveFunction,
G: FnMut() -> F,
{
self.run_internal(get_objective_function, false, |state| state.run())
}
pub fn run_with_reuse<F: ObjectiveFunction>(self, objective_function: F) -> RestartResults {
let mut function = Some(objective_function);
self.run_internal(|| function.take().unwrap(), true, |state| state.run())
}
pub fn run_parallel<F, G>(self, get_objective_function: G) -> RestartResults
where
F: ParallelObjectiveFunction,
G: FnMut() -> F,
{
self.run_internal(get_objective_function, false, |state| state.run_parallel())
}
pub fn run_parallel_with_reuse<F: ParallelObjectiveFunction>(
self,
objective_function: F,
) -> RestartResults {
let mut function = Some(objective_function);
self.run_internal(
|| function.take().unwrap(),
true,
|state| state.run_parallel(),
)
}
fn run_internal<F, G, R>(
mut self,
mut get_objective_function: G,
reuse_objective_function: bool,
run: R,
) -> RestartResults
where
G: FnMut() -> F,
R: Copy + Fn(&mut CMAES<F>) -> TerminationData,
{
if self.print_info {
self.print_initial_info();
}
let time_started = Instant::now();
let reason;
let mut function_evals = 0;
let mut runs = 0;
let mut objective_function = None;
loop {
if self.strategy.has_zero_max_runs() {
reason = RestartTerminationReason::MaxRuns;
break;
}
if let Some(max_function_evals) = self.max_function_evals {
if function_evals >= max_function_evals {
reason = RestartTerminationReason::MaxFunctionEvals;
break;
}
}
if let Some(max_time) = self.max_time {
if time_started.elapsed() >= max_time {
reason = RestartTerminationReason::MaxTime;
break;
}
}
let initial_mean = self.generate_initial_mean();
let seed = self.rng.gen();
let mut options = CMAESOptions::new(initial_mean, DEFAULT_INITIAL_STEP_SIZE)
.mode(self.mode)
.parallel_update(self.parallel_update)
.seed(seed);
options.max_function_evals = self.max_function_evals_per_run;
options.max_generations = self.max_generations_per_run;
options.fun_target = self.fun_target;
let search_range_size = (self.search_range.end() - self.search_range.start()).abs();
let function = objective_function
.take()
.unwrap_or_else(&mut get_objective_function);
let run_with_print = |cmaes: &mut CMAES<F>| {
if self.print_info {
print_run_info(runs + 1, cmaes.parameters());
}
run(cmaes)
};
let (final_state, reasons, control) = self.strategy.next_run(
options,
search_range_size,
function,
run_with_print,
&mut self.rng,
);
if self.print_info {
print_run_results(runs + 1, &final_state, &reasons);
}
function_evals += final_state.function_evals();
runs += 1;
if reasons
.iter()
.any(|&r| r == TerminationReason::InvalidFunctionValue)
{
reason = RestartTerminationReason::InvalidFunctionValue;
break;
}
if let Some(best) = final_state.overall_best_individual().cloned() {
self.update_best_individual(best);
}
if reasons.iter().any(|&r| r == TerminationReason::FunTarget) {
reason = RestartTerminationReason::FunTarget;
break;
}
match control {
RestartControl::Continue => (),
RestartControl::MaxRunsReached => {
reason = RestartTerminationReason::MaxRuns;
break;
}
}
if reuse_objective_function {
objective_function = Some(final_state.into_objective_function());
}
}
let results = RestartResults {
best: self.overall_best,
reason,
function_evals,
runs,
};
if self.print_info {
results.print_results();
}
results
}
fn update_best_individual(&mut self, individual: Individual) {
match self.overall_best {
Some(ref mut current_best) => {
if self.mode.is_better(individual.value, current_best.value) {
*current_best = individual;
}
}
None => self.overall_best = Some(individual),
}
}
fn print_initial_info(&self) {
let algorithm_name = self.strategy.get_algorithm_name();
let parameters_str = self
.strategy
.get_parameters_as_strings()
.iter()
.map(|(name, value)| format!("{}={}", name, value))
.collect::<Vec<_>>()
.join(", ");
let search_range_str = format!(
"[{}, {}]",
self.search_range.start(),
self.search_range.end()
);
println!(
"{} with dimension={}, search_range={}, {}, seed={}",
algorithm_name, self.dimensions, search_range_str, parameters_str, self.seed
);
}
fn generate_initial_mean(&mut self) -> DVector<f64> {
DVector::from_iterator(
self.dimensions,
(0..self.dimensions).map(|_| self.rng.gen_range(self.search_range.clone())),
)
}
}
/// Prints the parameters of run number `run` (1-based, as passed by the caller) before the run
/// starts.
fn print_run_info(run: usize, params: &Parameters) {
    let sigma0 = utils::format_num(params.initial_sigma(), 12);
    println!(
        "Run {} parameters: lambda={}, sigma0={}, seed={}",
        run,
        params.lambda(),
        sigma0,
        params.seed()
    );
}
/// Prints the outcome of a finished run: its best function value (or `None`), the termination
/// reasons as a backtick-quoted list, and the number of function evaluations it used.
fn print_run_results<F>(run: usize, cmaes: &CMAES<F>, termination_reasons: &[TerminationReason]) {
    let best_value_str = match cmaes.overall_best_individual() {
        Some(best) => utils::format_num(best.value, 19),
        None => "None".into(),
    };

    let mut reasons_str = String::new();
    for (index, reason) in termination_reasons.iter().enumerate() {
        if index > 0 {
            reasons_str.push_str(", ");
        }
        reasons_str.push('`');
        reasons_str.push_str(&reason.to_string());
        reasons_str.push('`');
    }

    println!(
        "Run {} results: best f-val={}, termination_reasons=[{}], f-evals={}",
        run,
        best_value_str,
        reasons_str,
        cmaes.function_evals()
    );
}
#[cfg(test)]
mod tests {
    //! Unit tests for the restart driver: termination reasons, objective-function reuse,
    //! best-individual tracking, and reproducibility with a fixed seed.

    use assert_approx_eq::assert_approx_eq;
    use std::thread;
    use super::*;

    // Objective used by most tests: the magnitude (Euclidean norm) of `x`.
    fn dummy_function(x: &DVector<f64>) -> f64 {
        x.magnitude()
    }

    // Each strategy, with default limits, should finish through its run limit after at least
    // one run and report a best individual.
    #[test]
    fn test_run() {
        let strategies = [
            RestartStrategy::Local(Local::new(10, None).unwrap()),
            RestartStrategy::IPOP(Default::default()),
            RestartStrategy::BIPOP(Default::default()),
        ];
        for s in strategies {
            let results = RestartOptions::new(1, -1.0..=1.0, s)
                .build()
                .unwrap()
                .run(|| dummy_function);
            assert!(results.runs > 0);
            assert!(results.function_evals > 0);
            assert!(results.best.is_some());
            assert_eq!(RestartTerminationReason::MaxRuns, results.reason);
        }
    }

    // With reuse, every reported evaluation must be a call of the single closure, so the
    // closure's own call counter equals the reported evaluation count.
    #[test]
    fn test_run_with_reuse() {
        let mut counter = 0;
        let function = |x: &DVector<f64>| {
            counter += 1;
            x.magnitude()
        };
        let max_runs = 10;
        let strategy = RestartStrategy::Local(Local::new(max_runs, None).unwrap());
        let dim = 1;
        let results = RestartOptions::new(dim, -1.0..=1.0, strategy)
            .build()
            .unwrap()
            .run_with_reuse(function);
        assert_eq!(results.function_evals, counter);
    }

    // Without reuse, the factory must be called once per run. The per-run evaluation limit of 0
    // presumably keeps each run minimal; only the factory call count is asserted.
    #[test]
    fn test_run_no_reuse() {
        let mut counter = 0;
        let max_runs = 10;
        let strategy = RestartStrategy::Local(Local::new(max_runs, None).unwrap());
        let get_objective_function = || {
            counter += 1;
            dummy_function
        };
        let dim = 1;
        let _ = RestartOptions::new(dim, -1.0..=1.0, strategy)
            .max_function_evals_per_run(0)
            .build()
            .unwrap()
            .run(get_objective_function);
        assert_eq!(max_runs, counter);
    }

    // A strategy allowing zero runs terminates before any run starts.
    #[test]
    fn test_zero_max_runs() {
        let strategy = RestartStrategy::Local(Local::new(0, None).unwrap());
        let results = RestartOptions::new(1, -1.0..=1.0, strategy)
            .build()
            .unwrap()
            .run(|| dummy_function);
        assert_eq!(0, results.runs);
        assert_eq!(0, results.function_evals);
        assert_eq!(RestartTerminationReason::MaxRuns, results.reason);
        assert!(results.best.is_none());
    }

    // A zero global evaluation budget terminates before the first run.
    #[test]
    fn test_zero_max_function_evals() {
        let strategy = RestartStrategy::Local(Local::new(10, None).unwrap());
        let results = RestartOptions::new(1, -1.0..=1.0, strategy)
            .max_function_evals(0)
            .build()
            .unwrap()
            .run(|| dummy_function);
        assert_eq!(0, results.runs);
        assert_eq!(0, results.function_evals);
        assert_eq!(RestartTerminationReason::MaxFunctionEvals, results.reason);
        assert!(results.best.is_none());
    }

    // The global budget is only checked between runs, so the total may overshoot it (hence
    // `>=`). The exact run count depends on how many evaluations each run uses.
    #[test]
    fn test_max_function_evals() {
        let strategy = RestartStrategy::Local(Local::new(10, None).unwrap());
        let results = RestartOptions::new(1, -1.0..=1.0, strategy)
            .max_function_evals_per_run(101)
            .max_function_evals(500)
            .build()
            .unwrap()
            .run(|| dummy_function);
        assert_eq!(5, results.runs);
        assert!(results.function_evals >= 500);
        assert_eq!(RestartTerminationReason::MaxFunctionEvals, results.reason);
        assert!(results.best.is_some());
    }

    // Every evaluation sleeps 10 ms, so the 100 ms limit should be exceeded and detected at a
    // run boundary. This test is timing-dependent.
    #[test]
    fn test_max_time() {
        let function = |_: &DVector<f64>| {
            thread::sleep(Duration::from_millis(10));
            0.0
        };
        let strategy = RestartStrategy::Local(Local::new(10, None).unwrap());
        let results = RestartOptions::new(1, -1.0..=1.0, strategy)
            .max_time(Duration::from_millis(100))
            .build()
            .unwrap()
            .run(|| function);
        assert_eq!(RestartTerminationReason::MaxTime, results.reason);
    }

    // A NaN-valued objective aborts the restart with `InvalidFunctionValue` and leaves no best
    // individual.
    #[test]
    fn test_invalid_function_value() {
        let function = |_: &DVector<f64>| f64::NAN;
        let strategy = RestartStrategy::Local(Local::new(10, None).unwrap());
        let results = RestartOptions::new(1, -1.0..=1.0, strategy)
            .max_time(Duration::from_millis(100))
            .build()
            .unwrap()
            .run(|| function);
        assert!(results.best.is_none());
        assert_eq!(
            RestartTerminationReason::InvalidFunctionValue,
            results.reason
        );
    }

    // When maximizing, the constant 1.0 should already satisfy the target of -1.0, so the
    // restart should stop with `FunTarget`.
    #[test]
    fn test_fun_target_maximize() {
        let function = |_: &DVector<f64>| 1.0;
        let strategy = RestartStrategy::BIPOP(Default::default());
        let results = RestartOptions::new(1, -1.0..=1.0, strategy)
            .mode(Mode::Maximize)
            .fun_target(-1.0)
            .build()
            .unwrap()
            .run(|| function);
        assert_eq!(RestartTerminationReason::FunTarget, results.reason,);
    }

    // Feeds an individual with `new_value` into `update_best_individual`, then checks that the
    // retained overall best has value `expected`.
    fn update_and_test(restarter: &mut Restarter, new_value: f64, expected: f64) {
        restarter.update_best_individual(Individual::new(vec![0.0; 4].into(), new_value));
        assert_eq!(expected, restarter.overall_best.clone().unwrap().value);
    }

    // When minimizing, only lower values replace the overall best.
    #[test]
    fn test_update_best_individual_minimize() {
        let strategy = RestartStrategy::Local(Local::new(5, None).unwrap());
        let mut restarter = RestartOptions::new(1, -1.0..=1.0, strategy)
            .mode(Mode::Minimize)
            .build()
            .unwrap();
        assert!(restarter.overall_best.is_none());
        update_and_test(&mut restarter, 1.0, 1.0);
        update_and_test(&mut restarter, 2.0, 1.0);
        update_and_test(&mut restarter, 0.0, 0.0);
    }

    // When maximizing, only higher values replace the overall best.
    #[test]
    fn test_update_best_individual_maximize() {
        let strategy = RestartStrategy::Local(Local::new(5, None).unwrap());
        let mut restarter = RestartOptions::new(1, -1.0..=1.0, strategy)
            .mode(Mode::Maximize)
            .build()
            .unwrap();
        assert!(restarter.overall_best.is_none());
        update_and_test(&mut restarter, 1.0, 1.0);
        update_and_test(&mut restarter, 0.0, 1.0);
        update_and_test(&mut restarter, 2.0, 2.0);
    }

    // Regression test: a fixed seed must reproduce these exact numbers. Any change to how the
    // restarter (or CMA-ES) consumes random numbers will break it.
    #[test]
    fn test_fixed_seed() {
        let function = |x: &DVector<f64>| 1e-8 + (x[0] - 2.0).powi(2) + (x[1] - 1.0).powi(2);
        let strategy = RestartStrategy::Local(Local::new(10, None).unwrap());
        let seed = 96674803299116567;
        let results = RestartOptions::new(2, -1.0..=1.0, strategy)
            .seed(seed)
            .build()
            .unwrap()
            .run(|| function);
        assert_eq!(10, results.runs);
        assert_eq!(5790, results.function_evals);
        assert_eq!(RestartTerminationReason::MaxRuns, results.reason);
        let best = results.best.unwrap();
        let eps = 1e-12;
        assert_approx_eq!(1.0000000002890303e-8, best.value, eps);
        assert_approx_eq!(2.00000000075140, best.point[0], eps);
        assert_approx_eq!(0.9999999989799438, best.point[1], eps);
    }
}