use numra_core::Scalar;
use numra_linalg::{DenseMatrix, Matrix};
use rand::rngs::SmallRng;
use rand::SeedableRng;
use crate::error::OptimError;
use crate::types::{IterationRecord, OptimResult, OptimStatus};
/// Tunable settings consumed by [`cmaes_minimize`].
#[derive(Clone, Debug)]
pub struct CmaEsOptions<S: Scalar> {
/// Number of candidates sampled per generation. When `None`,
/// `cmaes_minimize` uses `4 + floor(3 * ln(n))` for an `n`-dimensional
/// problem; any value is raised to at least 4.
pub population_size: Option<usize>,
/// Initial value of the global step size `sigma`.
pub sigma0: S,
/// Maximum number of generations to run.
pub max_iter: usize,
/// Threshold on the absolute difference between the worst and best
/// objective value within one generation; used together with `tol_sigma`.
pub tol_f: S,
/// Threshold on the step size `sigma`. The run is flagged as converged
/// only when both the `tol_f` and `tol_sigma` conditions hold.
pub tol_sigma: S,
/// Seed for the `SmallRng` that drives sampling.
pub seed: u64,
/// When `true`, a progress line is written to stderr every 50 generations.
pub verbose: bool,
}
impl<S: Scalar> Default for CmaEsOptions<S> {
fn default() -> Self {
Self {
population_size: None,
sigma0: S::HALF,
max_iter: 10_000,
tol_f: S::from_f64(1e-12),
tol_sigma: S::from_f64(1e-12),
seed: 42,
verbose: false,
}
}
}
/// Minimizes `f` with the Covariance Matrix Adaptation Evolution Strategy
/// (CMA-ES), a derivative-free stochastic search.
///
/// Each generation samples `lambda` candidates around the current mean,
/// ranks them by objective value, recombines the best `lambda / 2` into a
/// new mean, and adapts the step size `sigma` and the covariance matrix from
/// the evolution paths. Candidates whose objective value is NaN are ranked
/// as the worst of their generation instead of aborting the run.
///
/// # Arguments
///
/// * `f` - objective function to minimize.
/// * `x0` - starting point; it is evaluated once and seeds the best-so-far
///   solution.
/// * `opts` - algorithm settings, see [`CmaEsOptions`].
///
/// # Returns
///
/// An [`OptimResult`] holding the best point and value seen over the whole
/// run. `converged` is `true` only when, within one generation, the spread
/// between worst and best objective value is below `opts.tol_f` **and**
/// `sigma` is below `opts.tol_sigma`; otherwise the run stops after
/// `opts.max_iter` generations. Every history record stores the current
/// `sigma` in both `gradient_norm` and `step_size`.
///
/// # Errors
///
/// Returns [`OptimError::DimensionMismatch`] when `x0` is empty.
// Index loops are kept on purpose: they mirror the matrix/vector algebra.
#[allow(clippy::needless_range_loop)]
pub fn cmaes_minimize<S, F>(
    f: F,
    x0: &[S],
    opts: &CmaEsOptions<S>,
) -> Result<OptimResult<S>, OptimError>
where
    S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
    F: Fn(&[S]) -> S,
{
    let start = std::time::Instant::now();
    let n = x0.len();
    if n == 0 {
        return Err(OptimError::DimensionMismatch {
            expected: 1,
            actual: 0,
        });
    }
    let nf = n as f64;

    // --- Strategy parameters -------------------------------------------
    // Population size (at least 4) and number of parents.
    let lambda = opts
        .population_size
        .unwrap_or_else(|| (4.0 + (3.0 * nf.ln()).floor()) as usize)
        .max(4);
    let mu = lambda / 2;

    // Log-decreasing recombination weights, normalized to sum to one.
    let log_mu_half = (mu as f64 + 0.5).ln();
    let mut weights: Vec<f64> = (1..=mu)
        .map(|i| log_mu_half - (i as f64).ln())
        .collect();
    let w_sum: f64 = weights.iter().sum();
    for w in weights.iter_mut() {
        *w /= w_sum;
    }
    let w_sq_sum: f64 = weights.iter().map(|w| w * w).sum();
    // Variance-effective selection mass.
    let mu_eff = 1.0 / w_sq_sum;
    // Weights converted once to the working scalar type.
    let weights_s: Vec<S> = weights.iter().map(|&w| S::from_f64(w)).collect();

    // Learning rates for the evolution paths and the covariance update.
    let cc = (4.0 + mu_eff / nf) / (nf + 4.0 + 2.0 * mu_eff / nf);
    let cs = (mu_eff + 2.0) / (nf + mu_eff + 5.0);
    let c1 = 2.0 / ((nf + 1.3).powi(2) + mu_eff);
    let cmu_raw = 2.0 * (mu_eff - 2.0 + 1.0 / mu_eff) / ((nf + 2.0).powi(2) + mu_eff);
    let cmu = cmu_raw.min(1.0 - c1);
    // Damping for the step-size update.
    let damps = 1.0 + 2.0 * (0.0_f64).max(((mu_eff - 1.0) / (nf + 1.0)).sqrt() - 1.0) + cs;
    // Approximation of the expected norm of an n-dimensional standard normal vector.
    let chi_n = nf.sqrt() * (1.0 - 1.0 / (4.0 * nf) + 1.0 / (21.0 * nf * nf));

    // --- Dynamic state ---------------------------------------------------
    let mut mean: Vec<S> = x0.to_vec();
    let mut sigma = opts.sigma0;
    // Covariance matrix, initialized to the identity.
    let mut c_mat = DenseMatrix::<S>::zeros(n, n);
    for i in 0..n {
        c_mat.set(i, i, S::ONE);
    }
    // Evolution paths for the step size and the covariance matrix.
    let mut p_sigma = vec![S::ZERO; n];
    let mut p_c = vec![S::ZERO; n];
    // Eigenvectors (columns) of the covariance matrix, refreshed lazily.
    let mut bd_mat = DenseMatrix::<S>::zeros(n, n);
    for i in 0..n {
        bd_mat.set(i, i, S::ONE);
    }
    // Eigenvalues of the covariance matrix and their inverse square roots.
    let mut d_diag = vec![S::ONE; n];
    let mut inv_sqrt_diag = vec![S::ONE; n];

    let mut rng = SmallRng::seed_from_u64(opts.seed);
    let mut n_feval = 0_usize;
    let mut history: Vec<IterationRecord<S>> = Vec::new();
    let mut converged = false;
    let mut iterations = 0;
    let mut best_x = x0.to_vec();
    let mut best_f = f(x0);
    n_feval += 1;
    let mut eigen_update_gen: usize = 0;

    for generation in 0..opts.max_iter {
        iterations = generation + 1;

        // Sample lambda candidates: x = mean + sigma * B * sqrt(D) * z.
        let mut population: Vec<Vec<S>> = Vec::with_capacity(lambda);
        for _ in 0..lambda {
            let z: Vec<S> = (0..n).map(|_| sample_standard_normal(&mut rng)).collect();
            let mut x = vec![S::ZERO; n];
            for i in 0..n {
                let mut val = S::ZERO;
                for j in 0..n {
                    val += bd_mat.get(i, j) * d_diag[j].sqrt() * z[j];
                }
                x[i] = mean[i] + sigma * val;
            }
            population.push(x);
        }

        // Evaluate and rank ascending by objective value.
        let mut fitness: Vec<(usize, S)> = population
            .iter()
            .enumerate()
            .map(|(i, x)| (i, f(x)))
            .collect();
        n_feval += lambda;
        // NaN compares as greater than every number (and equal to other
        // NaNs), so a misbehaving objective ranks last instead of panicking.
        fitness.sort_by(|a, b| {
            let (fa, fb) = (a.1.to_f64(), b.1.to_f64());
            fa.partial_cmp(&fb)
                .unwrap_or_else(|| fa.is_nan().cmp(&fb.is_nan()))
        });

        // Track the best point ever seen. A NaN incumbent (e.g. f(x0) was
        // NaN) is replaced by the first non-NaN candidate, since `<` against
        // NaN is always false.
        let f_best_gen = fitness[0].1;
        let incumbent_is_nan = best_f.to_f64().is_nan();
        if f_best_gen < best_f || (incumbent_is_nan && !f_best_gen.to_f64().is_nan()) {
            best_f = f_best_gen;
            best_x = population[fitness[0].0].clone();
        }

        if opts.verbose && generation % 50 == 0 {
            eprintln!(
                "CMA-ES gen {}: best_f={:.6e}, sigma={:.4e}",
                generation,
                best_f.to_f64(),
                sigma.to_f64()
            );
        }
        history.push(IterationRecord {
            iteration: generation,
            objective: best_f,
            gradient_norm: sigma,
            step_size: sigma,
            constraint_violation: S::ZERO,
        });

        // Stop when the generation is flat AND the step size has collapsed.
        let f_worst_gen = fitness[lambda - 1].1;
        if (f_worst_gen - f_best_gen).abs() < opts.tol_f && sigma < opts.tol_sigma {
            converged = true;
            break;
        }

        // Recombination: new mean is the weighted average of the mu best.
        let old_mean = mean.clone();
        for j in 0..n {
            mean[j] = S::ZERO;
        }
        for (rank, &w_i) in weights_s.iter().enumerate() {
            let idx = fitness[rank].0;
            for j in 0..n {
                mean[j] += w_i * population[idx][j];
            }
        }
        let mean_shift: Vec<S> = (0..n).map(|j| (mean[j] - old_mean[j]) / sigma).collect();

        // C^{-1/2} * mean_shift, computed as B * D^{-1/2} * B^T * mean_shift.
        let mut temp = vec![S::ZERO; n];
        for i in 0..n {
            let mut val = S::ZERO;
            for j in 0..n {
                val += bd_mat.get(j, i) * mean_shift[j];
            }
            temp[i] = val * inv_sqrt_diag[i];
        }
        let mut c_inv_sqrt_shift = vec![S::ZERO; n];
        for i in 0..n {
            let mut val = S::ZERO;
            for j in 0..n {
                val += bd_mat.get(i, j) * temp[j];
            }
            c_inv_sqrt_shift[i] = val;
        }

        // Step-size evolution path.
        let cs_factor = S::from_f64((cs * (2.0 - cs) * mu_eff).sqrt());
        let one_minus_cs = S::from_f64(1.0 - cs);
        for i in 0..n {
            p_sigma[i] = one_minus_cs * p_sigma[i] + cs_factor * c_inv_sqrt_shift[i];
        }
        let ps_norm: f64 = p_sigma
            .iter()
            .map(|&v| v.to_f64() * v.to_f64())
            .sum::<f64>()
            .sqrt();

        // Stall indicator h_sigma: pauses the p_c update when p_sigma is
        // unusually long. The exponent saturates at i32::MAX so a huge
        // `max_iter` cannot wrap into a negative power.
        let exponent = generation
            .saturating_add(1)
            .saturating_mul(2)
            .min(i32::MAX as usize) as i32;
        let gen_factor = 1.0 - (1.0 - cs).powi(exponent);
        let threshold = (1.4 + 2.0 / (nf + 1.0)) * chi_n * gen_factor.sqrt();
        let h_sigma: f64 = if ps_norm < threshold { 1.0 } else { 0.0 };

        // Covariance evolution path.
        let cc_factor = S::from_f64(h_sigma * (cc * (2.0 - cc) * mu_eff).sqrt());
        let one_minus_cc = S::from_f64(1.0 - cc);
        for i in 0..n {
            p_c[i] = one_minus_cc * p_c[i] + cc_factor * mean_shift[i];
        }

        // Covariance update: decay + rank-one (p_c) + rank-mu (selected steps).
        // delta_h compensates the variance lost when h_sigma == 0.
        let delta_h = (1.0 - h_sigma) * cc * (2.0 - cc);
        let c_scale = S::from_f64(1.0 - c1 - cmu + c1 * delta_h);
        let c1_s = S::from_f64(c1);
        let cmu_s = S::from_f64(cmu);
        // Normalized steps of the mu best candidates, computed once instead
        // of once per matrix entry. Uses the pre-update sigma.
        let steps: Vec<Vec<S>> = fitness[..mu]
            .iter()
            .map(|&(idx, _)| {
                (0..n)
                    .map(|j| (population[idx][j] - old_mean[j]) / sigma)
                    .collect()
            })
            .collect();
        for i in 0..n {
            for j in 0..=i {
                let mut rank_mu = S::ZERO;
                for (w, y) in weights_s.iter().zip(&steps) {
                    rank_mu += *w * y[i] * y[j];
                }
                let val =
                    c_scale * c_mat.get(i, j) + c1_s * p_c[i] * p_c[j] + cmu_s * rank_mu;
                // Only the lower triangle is computed; mirror it to keep C symmetric.
                c_mat.set(i, j, val);
                c_mat.set(j, i, val);
            }
        }

        // Step-size adaptation from the length of p_sigma relative to chi_n.
        sigma *= S::from_f64(((cs / damps) * (ps_norm / chi_n - 1.0)).exp());

        // Refresh the eigendecomposition only every few generations.
        let eigen_interval = (n / 10).max(1);
        if generation - eigen_update_gen >= eigen_interval {
            eigen_update_gen = generation;
            update_eigen(&c_mat, n, &mut bd_mat, &mut d_diag, &mut inv_sqrt_diag);
        }
    }

    let (status, message) = if converged {
        (
            OptimStatus::GradientConverged,
            format!("CMA-ES converged after {} generations", iterations),
        )
    } else {
        (
            OptimStatus::MaxIterations,
            format!(
                "CMA-ES: max generations ({}) reached, best f = {:.6e}",
                opts.max_iter,
                best_f.to_f64()
            ),
        )
    };
    Ok(OptimResult {
        x: best_x,
        f: best_f,
        grad: Vec::new(),
        iterations,
        n_feval,
        n_geval: 0,
        converged,
        message,
        status,
        history,
        lambda_eq: Vec::new(),
        lambda_ineq: Vec::new(),
        active_bounds: Vec::new(),
        constraint_violation: S::ZERO,
        wall_time_secs: 0.0,
        pareto: None,
        sensitivity: None,
    }
    .with_wall_time(start))
}
/// Refreshes the cached eigendecomposition of the covariance matrix.
///
/// On success, `d_diag` receives the eigenvalues of `c_mat` (floored at
/// `1e-20`), `inv_sqrt_diag` their inverse square roots, and `bd_mat` the
/// eigenvectors. If the decomposition fails, all three outputs are reset to
/// the identity / ones so sampling can continue.
fn update_eigen<S>(
    c_mat: &DenseMatrix<S>,
    n: usize,
    bd_mat: &mut DenseMatrix<S>,
    d_diag: &mut [S],
    inv_sqrt_diag: &mut [S],
) where
    S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
{
    let decomposition = match c_mat.eigh() {
        Ok(decomposition) => decomposition,
        Err(_) => {
            // Fallback: behave as if the covariance were the identity.
            for row in 0..n {
                d_diag[row] = S::ONE;
                inv_sqrt_diag[row] = S::ONE;
                for col in 0..n {
                    let entry = if row == col { S::ONE } else { S::ZERO };
                    bd_mat.set(row, col, entry);
                }
            }
            return;
        }
    };

    let spectrum = decomposition.eigenvalues();
    let basis = decomposition.eigenvectors();

    for k in 0..n {
        let raw = spectrum[k];
        // Floor tiny or non-positive eigenvalues so the square root and the
        // reciprocal below stay finite.
        let floored = if raw > S::from_f64(1e-20) {
            raw
        } else {
            S::from_f64(1e-20)
        };
        d_diag[k] = floored;
        inv_sqrt_diag[k] = S::ONE / floored.sqrt();
    }

    for row in 0..n {
        for col in 0..n {
            bd_mat.set(row, col, basis.get(row, col));
        }
    }
}
/// Draws one standard-normal sample via the Box–Muller transform, using
/// only the cosine branch (two uniform draws per returned value).
fn sample_standard_normal<S: Scalar>(rng: &mut SmallRng) -> S {
    use core::f64::consts::TAU;
    use rand::Rng;

    // The first uniform is clamped away from zero so its logarithm is finite.
    let radial_uniform: f64 = rng.gen::<f64>().max(1e-300);
    let angular_uniform: f64 = rng.gen::<f64>();

    let radius = (-2.0 * radial_uniform.ln()).sqrt();
    let angle = TAU * angular_uniform;
    S::from_f64(radius * angle.cos())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 3-D sphere function: the optimizer should reach the origin.
    #[test]
    fn test_cmaes_sphere() {
        let sphere = |x: &[f64]| x.iter().map(|xi| xi * xi).sum::<f64>();
        let start = [5.0, 3.0, -2.0];
        let opts = CmaEsOptions {
            max_iter: 2000,
            ..Default::default()
        };

        let result = cmaes_minimize(sphere, &start, &opts).unwrap();

        assert!(result.f < 1e-6, "f={}", result.f);
        for xi in result.x.iter().copied() {
            assert!(xi.abs() < 1e-3, "xi={}", xi);
        }
    }

    /// 2-D Rosenbrock function with a unit initial step size.
    #[test]
    fn test_cmaes_rosenbrock() {
        let rosenbrock =
            |x: &[f64]| (1.0 - x[0]).powi(2) + 100.0 * (x[1] - x[0] * x[0]).powi(2);
        let start = [-1.0, 1.0];
        let opts = CmaEsOptions {
            sigma0: 1.0,
            max_iter: 5000,
            ..Default::default()
        };

        let result = cmaes_minimize(rosenbrock, &start, &opts).unwrap();

        assert!(result.f < 0.01, "f={}", result.f);
    }

    /// 2-D Rastrigin function (multimodal); only a loose bound is required.
    #[test]
    fn test_cmaes_rastrigin() {
        let rastrigin = |x: &[f64]| {
            let n = x.len() as f64;
            10.0 * n
                + x.iter()
                    .map(|xi| xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos())
                    .sum::<f64>()
        };
        let start = [2.0, -2.0];
        let opts = CmaEsOptions {
            sigma0: 2.0,
            max_iter: 5000,
            ..Default::default()
        };

        let result = cmaes_minimize(rastrigin, &start, &opts).unwrap();

        assert!(result.f < 2.0, "f={}", result.f);
    }

    /// One-dimensional shifted parabola with default options.
    #[test]
    fn test_cmaes_1d() {
        let parabola = |x: &[f64]| (x[0] - 7.0).powi(2);
        let start = [0.0];
        let opts = CmaEsOptions::default();

        let result = cmaes_minimize(parabola, &start, &opts).unwrap();

        assert!((result.x[0] - 7.0).abs() < 0.1, "x={}", result.x[0]);
    }

    /// Two runs with the same (default) seed must give identical results.
    #[test]
    fn test_cmaes_deterministic() {
        let bowl = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let start = [3.0, 4.0];

        let first = cmaes_minimize(bowl, &start, &CmaEsOptions::default()).unwrap();
        let second = cmaes_minimize(bowl, &start, &CmaEsOptions::default()).unwrap();

        assert_eq!(first.x, second.x);
        assert_eq!(first.f, second.f);
    }
}