use cmaes::{CMAESOptions, Mode, ObjectiveFunction, TerminationReason};
use nalgebra::DVector;
use std::thread;
use std::time::Duration;
const TEST_REPETITIONS: usize = 100;
const MAX_GENERATIONS: usize = 1000;
/// Runs the optimizer [`TEST_REPETITIONS`] times and validates every termination reason.
///
/// Each reason reported by each run is passed to `check_reason`; the ones it rejects are
/// collected, and the test fails (listing them) when more than `max_mismatches` were seen.
/// Runs whose options leave `max_generations` unset are capped at [`MAX_GENERATIONS`].
///
/// # Panics
///
/// Panics if `options.build` returns an error, or if the number of rejected reasons exceeds
/// `max_mismatches`.
fn run_test<F: ObjectiveFunction + Clone + 'static, R: Fn(TerminationReason) -> bool>(
    objective_function: F,
    options: CMAESOptions,
    check_reason: R,
    max_mismatches: usize,
) {
    let rejected: Vec<TerminationReason> = (0..TEST_REPETITIONS)
        .flat_map(|_| {
            let mut run_options = options.clone();
            // Keep every run bounded even if the test did not pick a generation limit.
            run_options.max_generations.get_or_insert(MAX_GENERATIONS);
            let mut state = run_options.build(objective_function.clone()).unwrap();
            let outcome = state.run();
            outcome.reasons
        })
        .filter(|&reason| !check_reason(reason))
        .collect();

    assert!(
        rejected.len() <= max_mismatches,
        "exceeded {} mismatches: {:?}",
        max_mismatches,
        rejected
    );
}
/// A function-evaluation budget of 100 must be the reason every run stops.
#[test]
fn test_max_function_evals() {
    fn sphere(x: &DVector<f64>) -> f64 {
        x.magnitude().powi(2)
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1.0).max_function_evals(100);
    run_test(
        sphere,
        options,
        |reason| matches!(reason, TerminationReason::MaxFunctionEvals),
        0,
    );
}
/// A limit of 20 generations must be the reason every run stops.
#[test]
fn test_max_generations() {
    fn sphere(x: &DVector<f64>) -> f64 {
        x.magnitude().powi(2)
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1.0).max_generations(20);
    run_test(
        sphere,
        options,
        |reason| matches!(reason, TerminationReason::MaxGenerations),
        0,
    );
}
/// With each evaluation sleeping 1 ms and a 1 ms time budget, every run must stop on `MaxTime`.
#[test]
fn test_max_time() {
    // The sleep makes each evaluation at least as long as the whole time budget.
    fn slow_sphere(x: &DVector<f64>) -> f64 {
        thread::sleep(Duration::from_millis(1));
        x.magnitude().powi(2)
    }
    let time_budget = Duration::from_millis(1);
    let options = CMAESOptions::new(vec![5.0; 2], 1.0).max_time(time_budget);
    run_test(
        slow_sphere,
        options,
        |reason| matches!(reason, TerminationReason::MaxTime),
        0,
    );
}
/// Reaching the target function value must end every run, in both minimize and maximize mode.
///
/// The maximize case uses the negated objective and the negated target.
#[test]
fn test_fun_target() {
    fn sphere(x: &DVector<f64>) -> f64 {
        x.magnitude().powi(2)
    }
    fn negated_sphere(x: &DVector<f64>) -> f64 {
        -sphere(x)
    }
    let target: f64 = 1e-12;

    let minimize = CMAESOptions::new(vec![5.0; 2], 1.0).fun_target(target);
    run_test(
        sphere,
        minimize,
        |reason| matches!(reason, TerminationReason::FunTarget),
        0,
    );

    let maximize = CMAESOptions::new(vec![5.0; 2], 1.0)
        .mode(Mode::Maximize)
        .fun_target(-target);
    run_test(
        negated_sphere,
        maximize,
        |reason| matches!(reason, TerminationReason::FunTarget),
        0,
    );
}
/// With `tol_fun_hist` set to 0.0, every run on a shifted sphere must stop on `TolFun`.
#[test]
fn test_tol_fun() {
    // The constant offset keeps the optimum value at 1.0 rather than 0.0.
    fn shifted_sphere(x: &DVector<f64>) -> f64 {
        1.0 + x.magnitude().powi(2)
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1.0).tol_fun_hist(0.0);
    run_test(
        shifted_sphere,
        options,
        |reason| matches!(reason, TerminationReason::TolFun),
        0,
    );
}
/// Starting far from the optimum (1e6 per coordinate) with `tol_fun_rel(1e-12)`, every run
/// must stop on `TolFunRel`.
#[test]
fn test_tol_fun_rel() {
    fn shifted_sphere(x: &DVector<f64>) -> f64 {
        1.0 + x.magnitude().powi(2)
    }
    let far_start = vec![1e6; 2];
    let options = CMAESOptions::new(far_start, 1.0).tol_fun_rel(1e-12);
    run_test(
        shifted_sphere,
        options,
        |reason| matches!(reason, TerminationReason::TolFunRel),
        0,
    );
}
/// On the square root of the norm, with default options, every run must stop on `TolX`.
#[test]
fn test_tol_x() {
    fn root_norm(x: &DVector<f64>) -> f64 {
        x.magnitude().sqrt()
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1.0);
    run_test(
        root_norm,
        options,
        |reason| matches!(reason, TerminationReason::TolX),
        0,
    );
}
/// With `tol_fun` set to 0.0, a norm clamped below at 1e-6 must end every run on `TolFunHist`.
#[test]
fn test_tol_fun_hist() {
    // Clamping creates a flat floor at 1e-6 around the origin.
    fn clamped_norm(x: &DVector<f64>) -> f64 {
        x.magnitude().max(1e-6)
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1.0).tol_fun(0.0);
    run_test(
        clamped_norm,
        options,
        |reason| matches!(reason, TerminationReason::TolFunHist),
        0,
    );
}
/// A noisy objective must end every run on `TolStagnation`, in both minimize and maximize mode.
///
/// The maximize case uses the negated objective.
#[test]
fn test_tol_stagnation() {
    // Random noise scaled by 10 is added on top of the norm, so the best value presumably
    // stops improving in a stable way.
    fn noisy(x: &DVector<f64>) -> f64 {
        1.0 + x.magnitude() + rand::random::<f64>() * 1e1
    }
    fn negated_noisy(x: &DVector<f64>) -> f64 {
        -noisy(x)
    }
    // Capture-free closure: it is `Copy`, so it can be handed to both runs.
    let is_stagnation = |r: TerminationReason| matches!(r, TerminationReason::TolStagnation);

    let minimize = CMAESOptions::new(vec![5.0; 2], 1.0).tol_stagnation(20);
    run_test(noisy, minimize, is_stagnation, 0);

    let maximize = CMAESOptions::new(vec![5.0; 2], 1.0)
        .mode(Mode::Maximize)
        .tol_stagnation(20);
    run_test(negated_noisy, maximize, is_stagnation, 0);
}
/// Starting at 1e3 per coordinate with a tiny initial step size (1e-9), every run must stop on
/// `TolXUp`.
#[test]
fn test_tol_x_up() {
    fn sum_of_squares(x: &DVector<f64>) -> f64 {
        x[0].powi(2) + x[1].powi(2)
    }
    let tiny_sigma = 1e-9;
    let options = CMAESOptions::new(vec![1e3; 2], tiny_sigma);
    run_test(
        sum_of_squares,
        options,
        |reason| matches!(reason, TerminationReason::TolXUp),
        0,
    );
}
/// Runs the shared "no effect" scenario and validates its termination reasons.
///
/// The objective is `1e-8 + |2*x0 - x1|^1.5 + (2 - x1)^2`, started at `[5, 5]` with step
/// size 4. `tol_fun` and `tol_fun_hist` are set to 0.0 and `tol_x` to 1e-16. Presumably this
/// keeps those criteria from firing before a no-effect criterion does (TODO confirm against
/// the cmaes option docs).
///
/// At most one reason rejected by `check_reason` is tolerated across all repetitions.
///
/// `check_reason` is moved into [`run_test`] exactly once, so it does not need to be `Clone`.
fn run_test_no_effect<F: Fn(TerminationReason) -> bool>(check_reason: F) {
    let function =
        |x: &DVector<f64>| 1e-8 + (2.0 * x[0] - x[1]).abs().powf(1.5) + (2.0 - x[1]).powi(2);
    run_test(
        function,
        CMAESOptions::new(vec![5.0; 2], 4.0)
            .tol_fun(0.0)
            .tol_fun_hist(0.0)
            .tol_x(1e-16),
        check_reason,
        1,
    );
}
/// In the no-effect scenario, every run must stop on either `NoEffectAxis` or `NoEffectCoord`.
#[test]
fn test_no_effect() {
    let is_no_effect = |r: TerminationReason| {
        matches!(
            r,
            TerminationReason::NoEffectAxis | TerminationReason::NoEffectCoord
        )
    };
    run_test_no_effect(is_no_effect);
}
/// Accepting only `NoEffectAxis` is expected to fail.
///
/// The test expects the mismatch panic to list `NoEffectCoord`, which means the scenario also
/// terminates via `NoEffectCoord`.
#[test]
#[should_panic(expected = "NoEffectCoord")]
fn test_no_effect_axis() {
    let axis_only = |r: TerminationReason| matches!(r, TerminationReason::NoEffectAxis);
    run_test_no_effect(axis_only);
}
/// Accepting only `NoEffectCoord` is expected to fail.
///
/// The test expects the mismatch panic to list `NoEffectAxis`, which means the scenario also
/// terminates via `NoEffectAxis`.
#[test]
#[should_panic(expected = "NoEffectAxis")]
fn test_no_effect_coord() {
    let coord_only = |r: TerminationReason| matches!(r, TerminationReason::NoEffectCoord);
    run_test_no_effect(coord_only);
}
/// An objective whose second coordinate matters only through a 1e-14-scaled term should end
/// runs on `TolConditionCov`.
///
/// One other reason is tolerated across all repetitions.
#[test]
fn test_tol_condition_cov() {
    fn lopsided(x: &DVector<f64>) -> f64 {
        0.1 + x[0].abs().powi(2) - (x[1] * 1e-14).abs().sqrt()
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1e3).tol_x(1e-12);
    let is_condition = |r: TerminationReason| matches!(r, TerminationReason::TolConditionCov);
    run_test(lopsided, options, is_condition, 1);
}
/// An objective that produces NaN must end every run on `InvalidFunctionValue`.
///
/// `f64::sqrt` of a negative coordinate yields NaN.
#[test]
fn test_invalid_function_value() {
    fn root_sum(x: &DVector<f64>) -> f64 {
        x[0].sqrt() + x[1].sqrt()
    }
    let options = CMAESOptions::new(vec![5.0; 2], 1.0);
    run_test(
        root_sum,
        options,
        |reason| matches!(reason, TerminationReason::InvalidFunctionValue),
        0,
    );
}