pub mod functions;
mod history;
mod matrix;
mod mode;
pub mod objective_function;
pub mod options;
pub mod parameters;
#[cfg(feature = "plotters")]
pub mod plotting;
pub mod restart;
mod sampling;
mod state;
pub mod termination;
mod utils;
pub use nalgebra::DVector;
pub use crate::functions::*;
pub use crate::history::MAX_HISTORY_LENGTH;
pub use crate::mode::Mode;
pub use crate::objective_function::{ObjectiveFunction, ParallelObjectiveFunction};
pub use crate::options::CMAESOptions;
pub use crate::parameters::Weights;
#[cfg(feature = "plotters")]
pub use crate::plotting::PlotOptions;
pub use crate::termination::TerminationReason;
use std::f64;
use std::time::{Duration, Instant};
use crate::history::History;
use crate::matrix::SquareMatrix;
use crate::options::InvalidOptionsError;
use crate::parameters::Parameters;
#[cfg(feature = "plotters")]
use crate::plotting::Plot;
use crate::sampling::{EvaluatedPoint, InvalidFunctionValueError, Sampler};
use crate::state::State;
use crate::termination::TerminationCheck;
/// A single evaluated candidate: a point in the search space paired with the objective
/// function value computed for it.
#[derive(Debug, Clone)]
pub struct Individual {
    /// Location of the candidate in the search space.
    pub point: DVector<f64>,
    /// Objective function value at `point`.
    pub value: f64,
}

impl Individual {
    /// Pairs an evaluated `point` with its objective `value`.
    fn new(point: DVector<f64>, value: f64) -> Self {
        Individual {
            point,
            value,
        }
    }
}
/// Summary of a finished run, produced when the optimizer stops.
#[derive(Debug, Clone)]
pub struct TerminationData {
    /// Best individual of the most recent generation (`None` if no generation was recorded).
    pub current_best: Option<Individual>,
    /// Best individual observed over the whole run (`None` if no generation was recorded).
    pub overall_best: Option<Individual>,
    /// Distribution mean at the moment the run stopped.
    pub final_mean: DVector<f64>,
    /// All termination criteria that triggered the stop.
    pub reasons: Vec<TerminationReason>,
}
/// A CMA-ES optimizer bound to an objective function of type `F`.
///
/// Instances are created from a [`CMAESOptions`] value (see [`CMAES::new`]).
pub struct CMAES<F> {
    /// Generates and evaluates candidate points; owns the objective function and tracks the
    /// number of function evaluations.
    sampler: Sampler<F>,
    /// Algorithm parameters derived from the options and the chosen seed.
    parameters: Parameters,
    /// Evolving algorithm state (mean, step size, covariance, generation counter).
    state: State,
    /// Record of the best individuals, used for reporting and termination checks.
    history: History,
    /// Plot data; `Some` only when plotting was enabled in the options.
    #[cfg(feature = "plotters")]
    plot: Option<Plot>,
    /// Evaluation gap between printed progress rows; `None` disables all printing.
    print_gap_evals: Option<usize>,
    /// Evaluation count used to schedule the next progress row and to avoid reprinting the
    /// final row.
    last_print_evals: usize,
    /// Construction timestamp, used for elapsed-time queries and passed to the termination
    /// check.
    time_created: Instant,
}
impl<F> CMAES<F> {
    /// Validates `options` and creates a new optimizer for `objective_function`.
    ///
    /// A random seed is drawn when `options.seed` is `None`. If plotting is enabled, the
    /// initial state is recorded as the first plot point; if `options.print_gap_evals` is set,
    /// the run header is printed immediately via [`CMAES::print_initial_info`].
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidOptionsError`] when:
    /// - the initial mean has zero dimensions (`Dimensions`),
    /// - the population size is below 2 (`PopulationSize`),
    /// - the initial step size is rejected by `options::is_initial_step_size_valid`
    ///   (`InitialStepSize`),
    /// - `cm` is not a normal float in `(0, 1]` (`Cm`).
    pub fn new(objective_function: F, options: CMAESOptions) -> Result<Self, InvalidOptionsError> {
        let dimensions = options.initial_mean.len();
        if dimensions == 0 {
            return Err(InvalidOptionsError::Dimensions);
        }
        if options.population_size < 2 {
            return Err(InvalidOptionsError::PopulationSize);
        }
        if !options::is_initial_step_size_valid(options.initial_step_size) {
            return Err(InvalidOptionsError::InitialStepSize);
        }
        // `is_normal` rejects zero, subnormals, infinities and NaN before the range check.
        if !options.cm.is_normal() || options.cm <= 0.0 || options.cm > 1.0 {
            return Err(InvalidOptionsError::Cm);
        }

        // The same seed is handed to both the sampler and the parameters, so the seed that was
        // actually used is available for reporting.
        let seed = options.seed.unwrap_or_else(rand::random);
        let sampler = Sampler::new(
            dimensions,
            options.population_size,
            objective_function,
            seed,
        );
        let parameters = Parameters::from_options(&options, seed);
        let state = State::new(options.initial_mean, options.initial_step_size);
        let history = History::new();
        #[cfg(feature = "plotters")]
        let plot = options
            .plot_options
            .map(|o| Plot::new(dimensions, o, options.mode));

        let cmaes = Self {
            sampler,
            parameters,
            state,
            history,
            #[cfg(feature = "plotters")]
            plot,
            print_gap_evals: options.print_gap_evals,
            last_print_evals: 0,
            time_created: Instant::now(),
        };

        // Only the plotting path needs mutable access; rebinding under the same cfg avoids an
        // `unused_mut` warning when the feature is disabled.
        #[cfg(feature = "plotters")]
        let mut cmaes = cmaes;
        #[cfg(feature = "plotters")]
        cmaes.add_plot_point();

        if cmaes.print_gap_evals.is_some() {
            cmaes.print_initial_info();
        }

        Ok(cmaes)
    }

    /// Final bookkeeping shared by `run` and `run_parallel`: records a last plot point (when
    /// plotting is enabled) and prints the final summary (when printing is enabled).
    fn run_internal(&mut self, result: &TerminationData) {
        #[cfg(feature = "plotters")]
        self.add_plot_point();

        if self.print_gap_evals.is_some() {
            self.print_final_info(&result.reasons);
        }
    }

    /// Updates the best-individual history with a freshly evaluated generation.
    fn sample_internal(&mut self, individuals: &[EvaluatedPoint]) {
        self.history.update(self.parameters.mode(), individuals);
    }

    /// Advances the algorithm state with one evaluated generation, then handles plotting,
    /// progress printing and the termination checks.
    ///
    /// Returns `Some` if the state update fails (reported as `PosDefCov`) or any termination
    /// criterion is met, and `None` if the run should continue.
    fn next_internal(&mut self, individuals: &[EvaluatedPoint]) -> Option<TerminationData> {
        if self
            .state
            .update(self.sampler.function_evals(), &self.parameters, individuals)
            .is_err()
        {
            return Some(self.get_termination_data(vec![TerminationReason::PosDefCov]));
        }

        #[cfg(feature = "plotters")]
        if let Some(ref plot) = self.plot {
            // Always record the first generation; afterwards follow the plot's own schedule.
            if self.state.generation() <= 1
                || self.sampler.function_evals() >= plot.get_next_data_point_evals()
            {
                self.add_plot_point();
            }
        }

        if let Some(gap_evals) = self.print_gap_evals {
            let evals = self.sampler.function_evals();
            // The first few generations are always printed; after that a row is printed once
            // `gap_evals` evaluations have passed since the previous row. Every printed row
            // updates `last_print_evals`, so `print_final_info` never repeats the last row
            // (previously an early-generation row was printed twice if the run ended there).
            if evals >= self.last_print_evals + gap_evals || self.state.generation() < 4 {
                self.print_info();
                self.last_print_evals = evals;
            }
        }

        let termination_reasons = TerminationCheck {
            current_function_evals: self.sampler.function_evals(),
            time_created: self.time_created,
            parameters: &self.parameters,
            state: &self.state,
            history: &self.history,
            individuals,
        }
        .check_termination_criteria();

        if termination_reasons.is_empty() {
            None
        } else {
            Some(self.get_termination_data(termination_reasons))
        }
    }

    /// Consumes the optimizer and returns the objective function it was built with.
    pub fn into_objective_function(self) -> F {
        self.sampler.into_objective_function()
    }

    /// Returns the algorithm parameters in use.
    pub fn parameters(&self) -> &Parameters {
        &self.parameters
    }

    /// Returns the current generation number.
    pub fn generation(&self) -> usize {
        self.state.generation()
    }

    /// Returns the number of objective function evaluations performed so far.
    pub fn function_evals(&self) -> usize {
        self.sampler.function_evals()
    }

    /// Returns the current distribution mean.
    pub fn mean(&self) -> &DVector<f64> {
        self.state.mean()
    }

    /// Returns the current covariance matrix.
    pub fn covariance_matrix(&self) -> &SquareMatrix<f64> {
        self.state.cov()
    }

    /// Returns the squared diagonal entries of the state's `cov_sqrt_eigenvalues` matrix, i.e.
    /// the covariance eigenvalues (assuming that matrix stores their square roots).
    pub fn eigenvalues(&self) -> DVector<f64> {
        self.state
            .cov_sqrt_eigenvalues()
            .diagonal()
            .map(|x| x.powi(2))
    }

    /// Returns the current step size (sigma).
    pub fn sigma(&self) -> f64 {
        self.state.sigma()
    }

    /// Returns the current axis ratio as reported by the algorithm state.
    pub fn axis_ratio(&self) -> f64 {
        self.state.axis_ratio()
    }

    /// Returns the best individual of the most recent generation, or `None` before the first
    /// generation has been evaluated.
    pub fn current_best_individual(&self) -> Option<&Individual> {
        self.history.current_best_individual()
    }

    /// Returns the best individual seen so far, or `None` before the first generation has
    /// been evaluated.
    pub fn overall_best_individual(&self) -> Option<&Individual> {
        self.history.overall_best_individual()
    }

    /// Returns the instant at which this optimizer was constructed.
    pub fn time_created(&self) -> Instant {
        self.time_created
    }

    /// Returns the wall-clock time elapsed since construction.
    pub fn elapsed(&self) -> Duration {
        self.time_created.elapsed()
    }

    /// Returns the plot, if plotting was enabled.
    #[cfg(feature = "plotters")]
    pub fn get_plot(&self) -> Option<&Plot> {
        self.plot.as_ref()
    }

    /// Returns a mutable reference to the plot, if plotting was enabled.
    #[cfg(feature = "plotters")]
    pub fn get_mut_plot(&mut self) -> Option<&mut Plot> {
        self.plot.as_mut()
    }

    /// Number of generations between eigendecomposition updates:
    /// `ceil(evals_per_eigen_update / lambda)`.
    #[doc(hidden)]
    pub fn generations_per_eigen_update(&self) -> usize {
        (self.state.evals_per_eigen_update(&self.parameters) as f64
            / self.parameters.lambda() as f64)
            .ceil() as usize
    }

    /// Snapshots the current/overall best individuals and the current mean together with the
    /// given termination `reasons`.
    fn get_termination_data(&self, reasons: Vec<TerminationReason>) -> TerminationData {
        TerminationData {
            current_best: self.current_best_individual().cloned(),
            overall_best: self.overall_best_individual().cloned(),
            final_mean: self.state.mean().clone(),
            reasons,
        }
    }

    /// Records a plot data point for the current state, if plotting is enabled.
    #[cfg(feature = "plotters")]
    pub fn add_plot_point(&mut self) {
        if let Some(ref mut plot) = self.plot {
            plot.add_data_point(self.sampler.function_evals(), &self.state, &self.history);
        }
    }

    /// Prints the algorithm variant (`aCMA-ES` for negative weights, otherwise `CMA-ES`), the
    /// dimension, lambda and seed, followed by the column header used by [`CMAES::print_info`].
    pub fn print_initial_info(&self) {
        let params = &self.parameters;
        let variant = match params.weights_setting() {
            Weights::Positive | Weights::Uniform => "CMA-ES",
            Weights::Negative => "aCMA-ES",
        };
        println!(
            "{} with dimension={}, lambda={}, seed={}",
            variant,
            params.dim(),
            params.lambda(),
            params.seed()
        );

        let title_string = format!(
            "{:^7} | {:^7} | {:^19} | {:^10} | {:^10} | {:^10} | {:^10}",
            "Gen #", "f evals", "Best function value", "Axis Ratio", "Sigma", "Min std", "Max std",
        );
        println!("{}", title_string);
        println!("{}", "-".repeat(title_string.chars().count()));
    }

    /// Prints one status row: generation, evaluation count, current best value, axis ratio,
    /// sigma, and the smallest/largest standard deviation along the coordinate axes
    /// (`sigma * sqrt(min/max covariance diagonal)`).
    pub fn print_info(&self) {
        let generations = format!("{:7}", self.state.generation());
        let evals = format!("{:7}", self.sampler.function_evals());
        // Blank column when no generation has been evaluated yet; built lazily so the
        // placeholder is only allocated when it is actually needed.
        let best_function_value = self
            .current_best_individual()
            .map(|x| utils::format_num(x.value, 19))
            .unwrap_or_else(|| format!("{:19}", ""));
        let step_size = self.state.sigma();
        let axis_ratio = utils::format_num(self.axis_ratio(), 11);
        let sigma = utils::format_num(step_size, 11);
        let cov_diag = self.state.cov().diagonal();
        let min_std = utils::format_num(step_size * cov_diag.min().sqrt(), 11);
        let max_std = utils::format_num(step_size * cov_diag.max().sqrt(), 11);
        println!(
            "{} | {} | {} |{} |{} |{} |{}",
            generations, evals, best_function_value, axis_ratio, sigma, min_std, max_std
        );
    }

    /// Prints the final summary: a last status row (unless one was already printed at the
    /// current evaluation count), the termination reasons, the current and overall best values
    /// (when both are available) and the final mean.
    pub fn print_final_info(&self, termination_reasons: &[TerminationReason]) {
        if self.sampler.function_evals() != self.last_print_evals {
            self.print_info();
        }

        let reasons_str = termination_reasons
            .iter()
            .map(|r| format!("`{}`", r))
            .collect::<Vec<_>>()
            .join(", ");
        println!("Terminated with reason(s): {}", reasons_str);

        let current_best = self.current_best_individual();
        let overall_best = self.overall_best_individual();
        if let (Some(current), Some(overall)) = (current_best, overall_best) {
            println!("Current best function value: {:e}", current.value);
            println!("Overall best function value: {:e}", overall.value);
        }

        println!("Final mean: {}", self.state.mean());
    }
}
impl<F: ObjectiveFunction> CMAES<F> {
    /// Drives the optimizer by calling [`CMAES::next`] until it reports termination, then
    /// performs the final plotting/printing bookkeeping.
    ///
    /// Returns the data describing why and where the run stopped.
    pub fn run(&mut self) -> TerminationData {
        let data = loop {
            match self.next() {
                Some(finished) => break finished,
                None => continue,
            }
        };
        self.run_internal(&data);
        data
    }

    /// Samples and evaluates one generation through `Sampler::sample`, then records it in the
    /// history.
    ///
    /// # Errors
    ///
    /// Propagates the sampler's [`InvalidFunctionValueError`].
    fn sample(&mut self) -> Result<Vec<EvaluatedPoint>, InvalidFunctionValueError> {
        let mode = self.parameters.mode();
        let parallel_update = self.parameters.parallel_update();
        self.sampler
            .sample(&self.state, mode, parallel_update)
            .map(|points| {
                self.sample_internal(&points);
                points
            })
    }

    /// Runs a single generation.
    ///
    /// Returns `None` while the run should continue and `Some` once it has terminated. An
    /// invalid objective value terminates the run with `InvalidFunctionValue`.
    // Deliberately an inherent method rather than `Iterator::next`: `None` means "keep going"
    // and `Some` carries the final result, the opposite of iterator semantics.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn next(&mut self) -> Option<TerminationData> {
        match self.sample() {
            Ok(points) => self.next_internal(&points),
            Err(_) => Some(
                self.get_termination_data(vec![TerminationReason::InvalidFunctionValue]),
            ),
        }
    }
}
impl<F: ParallelObjectiveFunction> CMAES<F> {
    /// Parallel counterpart of `run`: calls [`CMAES::next_parallel`] until it reports
    /// termination, then performs the final plotting/printing bookkeeping.
    ///
    /// Returns the data describing why and where the run stopped.
    pub fn run_parallel(&mut self) -> TerminationData {
        let result = loop {
            if let Some(data) = self.next_parallel() {
                break data;
            }
        };
        self.run_internal(&result);
        result
    }

    /// Samples and evaluates one generation through `Sampler::sample_parallel`, then records
    /// it in the history.
    ///
    /// # Errors
    ///
    /// Propagates the sampler's [`InvalidFunctionValueError`].
    fn sample_parallel(&mut self) -> Result<Vec<EvaluatedPoint>, InvalidFunctionValueError> {
        let individuals = self.sampler.sample_parallel(
            &self.state,
            self.parameters.mode(),
            self.parameters.parallel_update(),
        )?;
        self.sample_internal(&individuals);
        Ok(individuals)
    }

    /// Parallel counterpart of `next`: runs a single generation.
    ///
    /// Returns `None` while the run should continue and `Some` once it has terminated. An
    /// invalid objective value terminates the run with `InvalidFunctionValue`.
    // `#[must_use]` mirrors the serial `next`: silently dropping a `Some` would lose the
    // termination data and keep stepping an already-finished run.
    #[must_use]
    pub fn next_parallel(&mut self) -> Option<TerminationData> {
        let individuals = match self.sample_parallel() {
            Ok(x) => x,
            Err(_) => {
                return Some(
                    self.get_termination_data(vec![TerminationReason::InvalidFunctionValue]),
                );
            }
        };
        self.next_internal(&individuals)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constant objective function: every point evaluates to `0.0`.
    fn dummy_function(_: &DVector<f64>) -> f64 {
        0.0
    }

    #[test]
    fn test_get_best_individuals() {
        let mut cmaes = CMAESOptions::new(vec![0.0; 10], 1.0)
            .build(dummy_function)
            .unwrap();

        // No generation has been evaluated yet.
        assert!(cmaes.current_best_individual().is_none());
        assert!(cmaes.overall_best_individual().is_none());

        let _ = cmaes.next();

        assert!(cmaes.current_best_individual().is_some());
        assert!(cmaes.overall_best_individual().is_some());
    }

    #[test]
    fn test_immediate_termination() {
        let function = |_: &DVector<f64>| f64::NAN;
        let mut cmaes = CMAESOptions::new(vec![0.0; 10], 1.0)
            .build(function)
            .unwrap();

        let result = cmaes.run();

        assert_eq!(
            vec![TerminationReason::InvalidFunctionValue],
            result.reasons
        );
    }

    // `PlotOptions`, `enable_plot` and `get_plot` only exist with the `plotters` feature; the
    // test must be gated the same way or the test build fails when the feature is disabled.
    #[cfg(feature = "plotters")]
    #[test]
    fn test_run_final_plot() {
        let evals_per_plot_point = 100;
        let mut cmaes = CMAESOptions::new(vec![0.0; 10], 1.0)
            .enable_plot(PlotOptions::new(evals_per_plot_point, false))
            .max_generations(1)
            .build(dummy_function)
            .unwrap();

        // One point is recorded at construction...
        assert_eq!(cmaes.get_plot().unwrap().len(), 1);

        let _ = cmaes.run();

        // ...and the run adds exactly one more.
        assert_eq!(cmaes.get_plot().unwrap().len(), 2);
    }

    #[test]
    fn test_generations_per_eigen_update() {
        let cmaes_3 = CMAESOptions::new(vec![0.0; 3], 1.0)
            .build(dummy_function)
            .unwrap();
        let cmaes_10 = CMAESOptions::new(vec![0.0; 10], 1.0)
            .build(dummy_function)
            .unwrap();
        let cmaes_30 = CMAESOptions::new(vec![0.0; 30], 1.0)
            .build(dummy_function)
            .unwrap();

        assert_eq!(2, cmaes_3.generations_per_eigen_update());
        assert_eq!(2, cmaes_10.generations_per_eigen_update());
        assert_eq!(3, cmaes_30.generations_per_eigen_update());
    }

    #[test]
    fn test_evals_per_eigen_update() {
        let cmaes_3 = CMAESOptions::new(vec![0.0; 3], 1.0)
            .build(dummy_function)
            .unwrap();
        let cmaes_10 = CMAESOptions::new(vec![0.0; 10], 1.0)
            .build(dummy_function)
            .unwrap();
        let cmaes_30 = CMAESOptions::new(vec![0.0; 30], 1.0)
            .build(dummy_function)
            .unwrap();

        assert_eq!(
            8,
            cmaes_3.state.evals_per_eigen_update(cmaes_3.parameters())
        );
        assert_eq!(
            15,
            cmaes_10.state.evals_per_eigen_update(cmaes_10.parameters())
        );
        assert_eq!(
            34,
            cmaes_30.state.evals_per_eigen_update(cmaes_30.parameters())
        );
    }
}