use crate::error::{NumRs2Error, Result};
use super::config::CMAESConfig;
use super::state::CmaEsState;
use super::types::{CMAESResult, RngSource, TerminationReason};
/// Minimise the objective `f` with CMA-ES, starting from `x0`.
///
/// When `config.enable_restarts` is set, the optimisation uses the IPOP
/// restart strategy. Otherwise it performs exactly one CMA-ES run.
///
/// # Errors
///
/// Propagates any error produced by the underlying run(s).
pub fn cma_es<F>(f: F, x0: &[f64], config: CMAESConfig) -> Result<CMAESResult>
where
    F: Fn(&[f64]) -> f64,
{
    if !config.enable_restarts {
        return cma_es_single_run(&f, x0, &config);
    }
    cma_es_ipop(&f, x0, &config)
}
/// Run a single (non-restarting) CMA-ES optimisation from `x0`.
///
/// Each generation:
/// 1. Samples a population from the current search distribution.
/// 2. Repairs samples into `config.bounds` when bounds are configured.
/// 3. Evaluates the samples through `CmaEsState::evaluate`, which receives
///    the bounds and the penalty coefficient.
/// 4. Ranks the samples best-first and records the best-so-far value in
///    the history.
/// 5. Either terminates, or updates the mean, step size, covariance and
///    eigendecomposition.
///
/// The run reports `success` only for `FunctionTolerance` or
/// `ParameterTolerance` termination.
///
/// # Errors
///
/// Propagates errors from `CmaEsState::new` and from
/// `CmaEsState::update_eigendecomposition`.
fn cma_es_single_run<F>(f: &F, x0: &[f64], config: &CMAESConfig) -> Result<CMAESResult>
where
    F: Fn(&[f64]) -> f64,
{
    let mut state = CmaEsState::new(x0, config)?;
    let f0 = state.evaluate(f, x0, &config.bounds, config.penalty_coefficient);
    state.best_f = f0;
    state.best_x = x0.to_vec();
    state.history.push(f0);

    loop {
        let mut population = state.sample_population();
        if let Some(ref bounds) = config.bounds {
            for x in &mut population {
                CmaEsState::repair_bounds(x, bounds);
            }
        }
        let fitness_values: Vec<f64> = population
            .iter()
            .map(|x| state.evaluate(f, x, &config.bounds, config.penalty_coefficient))
            .collect();

        // Rank candidates best-first, moving them instead of cloning each one.
        // NaN fitness is ordered after every other value, so it is never
        // selected ahead of a real value. This comparator is a genuine total
        // order. Mapping NaN to `Equal` is not transitive, and `sort_by` may
        // panic or produce an unspecified order when given such a comparator.
        let mut ranked: Vec<(f64, Vec<f64>)> =
            fitness_values.into_iter().zip(population).collect();
        ranked.sort_by(|(fa, _), (fb, _)| match (fa.is_nan(), fb.is_nan()) {
            (false, false) => fa.total_cmp(fb),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        });
        let (sorted_fitness, sorted_pop): (Vec<f64>, Vec<Vec<f64>>) =
            ranked.into_iter().unzip();

        // A real value must be able to replace a NaN best-so-far (e.g. when
        // f(x0) was NaN). With a plain `<`, a NaN best would block every
        // later improvement.
        let candidate = sorted_fitness[0];
        if !candidate.is_nan() && (state.best_f.is_nan() || candidate < state.best_f) {
            state.best_f = candidate;
            state.best_x = sorted_pop[0].clone();
        }
        state.history.push(state.best_f);

        if let Some(reason) = state.check_termination(config, &sorted_fitness) {
            let success = matches!(
                reason,
                TerminationReason::FunctionTolerance | TerminationReason::ParameterTolerance
            );
            // Computed before fields are moved out of `state` below.
            let final_condition = state.condition_number();
            let history = state.history;
            return Ok(CMAESResult {
                x: state.best_x,
                fun: state.best_f,
                nit: state.generation,
                nfev: state.function_evaluations,
                success,
                message: reason.to_string(),
                history,
                final_sigma: state.sigma,
                final_condition_number: final_condition,
                restarts: 0,
                termination_reason: reason,
            });
        }

        let old_mean = state.update_mean(&sorted_pop);
        let y_w = state.compute_weighted_step(&old_mean);
        state.update_step_size(&y_w);
        state.update_covariance(&y_w, &sorted_pop, &old_mean);
        state.update_eigendecomposition()?;
        state.generation += 1;
    }
}
/// Run CMA-ES with IPOP-style restarts (increasing population size).
///
/// The function performs up to `config.max_restarts + 1` single runs:
/// - Run 0 starts at `x0`. Later runs start at a Gaussian perturbation of
///   `x0` with standard deviation `config.sigma0`, clamped to the bounds.
/// - After every unsuccessful run, the population size is multiplied by
///   `config.restart_pop_multiplier` and rounded up.
/// - When `config.seed` is set, run `i` is seeded with `seed + 1000 * i`.
///
/// The loop stops at the first run that reports `success`. The returned
/// result is the best run so far (lowest `fun`), which may be an earlier run
/// than the one that converged. `nfev` is summed over all runs.
///
/// If no run succeeds, the best result is returned with
/// `NoImprovementAfterRestarts` and `success = false`.
///
/// # Errors
///
/// Propagates errors from any single run. Also returns `ComputationError`
/// if no result is available, which should not occur because the loop
/// always executes at least once.
fn cma_es_ipop<F>(f: &F, x0: &[f64], config: &CMAESConfig) -> Result<CMAESResult>
where
    F: Fn(&[f64]) -> f64,
{
    let mut best_result: Option<CMAESResult> = None;
    let mut total_fevals = 0;
    let mut current_lambda = config.effective_lambda(x0.len());
    let mut restart_rng = RngSource::create(config.seed.map(|s| s.wrapping_add(42)));

    for restart in 0..=config.max_restarts {
        let mut run_config = config.clone();
        run_config.population_size = current_lambda;
        run_config.enable_restarts = false;
        if let Some(base_seed) = config.seed {
            run_config.seed = Some(base_seed.wrapping_add(restart as u64 * 1000));
        }

        let start_point = if restart == 0 {
            x0.to_vec()
        } else {
            perturb_start_point(x0, config.sigma0, &config.bounds, &mut restart_rng)
        };

        let result = cma_es_single_run(f, &start_point, &run_config)?;
        total_fevals += result.nfev;

        // Read the flag before `result` may be moved into `best_result`,
        // which avoids cloning the whole result (including its history).
        let converged = result.success;

        // A real value must be able to replace a NaN best. With a plain `<`,
        // a NaN first result could never be superseded.
        let improves = match &best_result {
            None => true,
            Some(prev) => !result.fun.is_nan() && (prev.fun.is_nan() || result.fun < prev.fun),
        };
        if improves {
            best_result = Some(result);
        }

        if converged {
            let mut final_result = best_result.ok_or_else(|| {
                NumRs2Error::ComputationError("No result available after convergence".to_string())
            })?;
            final_result.nfev = total_fevals;
            // Run 0 is the initial run, so run index `restart` means
            // `restart` restarts have been performed.
            final_result.restarts = restart;
            return Ok(final_result);
        }

        current_lambda = (current_lambda as f64 * config.restart_pop_multiplier).ceil() as usize;
    }

    match best_result {
        Some(mut result) => {
            result.nfev = total_fevals;
            // All `max_restarts + 1` runs executed. The first was the initial
            // run, so exactly `max_restarts` restarts happened. This matches
            // the counting used on the convergence path above.
            result.restarts = config.max_restarts;
            result.termination_reason = TerminationReason::NoImprovementAfterRestarts;
            result.success = false;
            result.message = TerminationReason::NoImprovementAfterRestarts.to_string();
            Ok(result)
        }
        None => Err(NumRs2Error::ComputationError(
            "CMA-ES IPOP produced no results".to_string(),
        )),
    }
}
/// Build a restart point by adding independent Gaussian noise to each
/// coordinate of `x0`.
///
/// Each coordinate receives noise with standard deviation `sigma`, drawn in
/// coordinate order from `rng`. When `bounds` is present, each coordinate
/// that has a matching `(lower, upper)` pair is clamped into it.
/// Coordinates without a matching bound pair are left unclamped.
fn perturb_start_point(
    x0: &[f64],
    sigma: f64,
    bounds: &Option<Vec<(f64, f64)>>,
    rng: &mut RngSource,
) -> Vec<f64> {
    let mut perturbed = Vec::with_capacity(x0.len());
    for &coord in x0 {
        perturbed.push(coord + rng.sample_normal_with_std(sigma));
    }
    if let Some(box_bounds) = bounds {
        for (value, &(lower, upper)) in perturbed.iter_mut().zip(box_bounds.iter()) {
            *value = (*value).clamp(lower, upper);
        }
    }
    perturbed
}