1use numra_core::Scalar;
11use numra_linalg::{DenseMatrix, Matrix};
12use rand::rngs::SmallRng;
13use rand::SeedableRng;
14
15use crate::error::OptimError;
16use crate::types::{IterationRecord, OptimResult, OptimStatus};
17
18#[derive(Clone, Debug)]
20pub struct CmaEsOptions<S: Scalar> {
21 pub population_size: Option<usize>,
23 pub sigma0: S,
25 pub max_iter: usize,
27 pub tol_f: S,
29 pub tol_sigma: S,
31 pub seed: u64,
33 pub verbose: bool,
35}
36
37impl<S: Scalar> Default for CmaEsOptions<S> {
38 fn default() -> Self {
39 Self {
40 population_size: None,
41 sigma0: S::HALF,
42 max_iter: 10_000,
43 tol_f: S::from_f64(1e-12),
44 tol_sigma: S::from_f64(1e-12),
45 seed: 42,
46 verbose: false,
47 }
48 }
49}
50
51#[allow(clippy::needless_range_loop)]
52pub fn cmaes_minimize<S, F>(
64 f: F,
65 x0: &[S],
66 opts: &CmaEsOptions<S>,
67) -> Result<OptimResult<S>, OptimError>
68where
69 S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
70 F: Fn(&[S]) -> S,
71{
72 let start = std::time::Instant::now();
73 let n = x0.len();
74 if n == 0 {
75 return Err(OptimError::DimensionMismatch {
76 expected: 1,
77 actual: 0,
78 });
79 }
80 let nf = n as f64;
81
82 let lambda = opts
84 .population_size
85 .unwrap_or((4.0 + (3.0 * nf.ln()).floor()) as usize);
86 let lambda = lambda.max(4); let mu = lambda / 2; let mut weights = Vec::with_capacity(mu);
91 let log_mu_half = (mu as f64 + 0.5).ln();
92 for i in 1..=mu {
93 weights.push(log_mu_half - (i as f64).ln());
94 }
95 let w_sum: f64 = weights.iter().sum();
96 for w in weights.iter_mut() {
97 *w /= w_sum;
98 }
99 let w_sq_sum: f64 = weights.iter().map(|w| w * w).sum();
100 let mu_eff = 1.0 / w_sq_sum;
101
102 let cc = (4.0 + mu_eff / nf) / (nf + 4.0 + 2.0 * mu_eff / nf);
104 let cs = (mu_eff + 2.0) / (nf + mu_eff + 5.0);
105 let c1 = 2.0 / ((nf + 1.3).powi(2) + mu_eff);
106 let cmu_raw = 2.0 * (mu_eff - 2.0 + 1.0 / mu_eff) / ((nf + 2.0).powi(2) + mu_eff);
107 let cmu = cmu_raw.min(1.0 - c1);
108 let damps = 1.0 + 2.0 * (0.0_f64).max(((mu_eff - 1.0) / (nf + 1.0)).sqrt() - 1.0) + cs;
109 let chi_n = nf.sqrt() * (1.0 - 1.0 / (4.0 * nf) + 1.0 / (21.0 * nf * nf));
110
111 let mut mean: Vec<S> = x0.to_vec();
113 let mut sigma = opts.sigma0;
114
115 let mut c_mat = DenseMatrix::<S>::zeros(n, n);
117 for i in 0..n {
118 c_mat.set(i, i, S::ONE);
119 }
120
121 let mut p_sigma = vec![S::ZERO; n]; let mut p_c = vec![S::ZERO; n]; let mut bd_mat = DenseMatrix::<S>::zeros(n, n);
128 for i in 0..n {
129 bd_mat.set(i, i, S::ONE);
130 }
131 let mut d_diag = vec![S::ONE; n]; let mut inv_sqrt_diag = vec![S::ONE; n]; let mut rng = SmallRng::seed_from_u64(opts.seed);
135 let mut n_feval = 0_usize;
136 let mut history: Vec<IterationRecord<S>> = Vec::new();
137 let mut converged = false;
138 let mut iterations = 0;
139 let mut best_x = x0.to_vec();
140 let mut best_f = f(x0);
141 n_feval += 1;
142
143 let mut eigen_update_gen: usize = 0;
144
145 for gen in 0..opts.max_iter {
146 iterations = gen + 1;
147
148 let mut population: Vec<Vec<S>> = Vec::with_capacity(lambda);
150 let mut z_vectors: Vec<Vec<S>> = Vec::with_capacity(lambda);
151
152 for _ in 0..lambda {
153 let z: Vec<S> = (0..n).map(|_| sample_standard_normal(&mut rng)).collect();
155
156 let mut x = vec![S::ZERO; n];
158 for i in 0..n {
159 let mut val = S::ZERO;
160 for j in 0..n {
161 val += bd_mat.get(i, j) * d_diag[j].sqrt() * z[j];
162 }
163 x[i] = mean[i] + sigma * val;
164 }
165
166 z_vectors.push(z);
167 population.push(x);
168 }
169
170 let mut fitness: Vec<(usize, S)> = population
172 .iter()
173 .enumerate()
174 .map(|(i, x)| (i, f(x)))
175 .collect();
176 n_feval += lambda;
177
178 fitness.sort_by(|a, b| a.1.to_f64().partial_cmp(&b.1.to_f64()).unwrap());
180
181 if fitness[0].1 < best_f {
183 best_f = fitness[0].1;
184 best_x = population[fitness[0].0].clone();
185 }
186
187 if opts.verbose && gen % 50 == 0 {
188 eprintln!(
189 "CMA-ES gen {}: best_f={:.6e}, sigma={:.4e}",
190 gen,
191 best_f.to_f64(),
192 sigma.to_f64()
193 );
194 }
195
196 history.push(IterationRecord {
197 iteration: gen,
198 objective: best_f,
199 gradient_norm: sigma,
200 step_size: sigma,
201 constraint_violation: S::ZERO,
202 });
203
204 let f_best_gen = fitness[0].1;
206 let f_worst_gen = fitness[lambda - 1].1;
207 if (f_worst_gen - f_best_gen).abs() < opts.tol_f && sigma < opts.tol_sigma {
208 converged = true;
209 break;
210 }
211
212 let old_mean = mean.clone();
214 for j in 0..n {
215 mean[j] = S::ZERO;
216 }
217 for i in 0..mu {
218 let idx = fitness[i].0;
219 let w_i = S::from_f64(weights[i]);
220 for j in 0..n {
221 mean[j] += w_i * population[idx][j];
222 }
223 }
224
225 let mean_shift: Vec<S> = (0..n).map(|j| (mean[j] - old_mean[j]) / sigma).collect();
228
229 let mut c_inv_sqrt_shift = vec![S::ZERO; n];
231 let mut temp = vec![S::ZERO; n];
233 for i in 0..n {
234 let mut val = S::ZERO;
235 for j in 0..n {
236 val += bd_mat.get(j, i) * mean_shift[j]; }
238 temp[i] = val;
239 }
240 for i in 0..n {
242 temp[i] *= inv_sqrt_diag[i];
243 }
244 for i in 0..n {
246 let mut val = S::ZERO;
247 for j in 0..n {
248 val += bd_mat.get(i, j) * temp[j];
249 }
250 c_inv_sqrt_shift[i] = val;
251 }
252
253 let cs_factor = S::from_f64((cs * (2.0 - cs) * mu_eff).sqrt());
254 let one_minus_cs = S::from_f64(1.0 - cs);
255 for i in 0..n {
256 p_sigma[i] = one_minus_cs * p_sigma[i] + cs_factor * c_inv_sqrt_shift[i];
257 }
258
259 let ps_norm: f64 = p_sigma
261 .iter()
262 .map(|&v| v.to_f64() * v.to_f64())
263 .sum::<f64>()
264 .sqrt();
265
266 let gen_factor = 1.0 - (1.0 - cs).powi((2 * (gen + 1)) as i32);
268 let threshold = (1.4 + 2.0 / (nf + 1.0)) * chi_n * gen_factor.sqrt();
269 let h_sigma: f64 = if ps_norm < threshold { 1.0 } else { 0.0 };
270
271 let cc_factor = S::from_f64(h_sigma * (cc * (2.0 - cc) * mu_eff).sqrt());
273 let one_minus_cc = S::from_f64(1.0 - cc);
274 for i in 0..n {
275 p_c[i] = one_minus_cc * p_c[i] + cc_factor * mean_shift[i];
276 }
277
278 let delta_h = (1.0 - h_sigma) * cc * (2.0 - cc);
282 let c_scale = S::from_f64(1.0 - c1 - cmu + c1 * delta_h);
283 let c1_s = S::from_f64(c1);
284 let cmu_s = S::from_f64(cmu);
285
286 for i in 0..n {
287 for j in 0..=i {
288 let mut val = c_scale * c_mat.get(i, j);
289 val += c1_s * p_c[i] * p_c[j];
290 let mut rank_mu = S::ZERO;
292 for k in 0..mu {
293 let idx = fitness[k].0;
294 let di = (population[idx][i] - old_mean[i]) / sigma;
295 let dj = (population[idx][j] - old_mean[j]) / sigma;
296 rank_mu += S::from_f64(weights[k]) * di * dj;
297 }
298 val += cmu_s * rank_mu;
299 c_mat.set(i, j, val);
300 c_mat.set(j, i, val);
301 }
302 }
303
304 sigma *= S::from_f64(((cs / damps) * (ps_norm / chi_n - 1.0)).exp());
306
307 let eigen_interval = (n / 10).max(1);
309 if gen - eigen_update_gen >= eigen_interval {
310 eigen_update_gen = gen;
311 update_eigen(&c_mat, n, &mut bd_mat, &mut d_diag, &mut inv_sqrt_diag);
312 }
313 }
314
315 let (status, message) = if converged {
316 (
317 OptimStatus::GradientConverged,
318 format!("CMA-ES converged after {} generations", iterations),
319 )
320 } else {
321 (
322 OptimStatus::MaxIterations,
323 format!(
324 "CMA-ES: max generations ({}) reached, best f = {:.6e}",
325 opts.max_iter,
326 best_f.to_f64()
327 ),
328 )
329 };
330
331 Ok(OptimResult {
332 x: best_x,
333 f: best_f,
334 grad: Vec::new(),
335 iterations,
336 n_feval,
337 n_geval: 0,
338 converged,
339 message,
340 status,
341 history,
342 lambda_eq: Vec::new(),
343 lambda_ineq: Vec::new(),
344 active_bounds: Vec::new(),
345 constraint_violation: S::ZERO,
346 wall_time_secs: 0.0,
347 pareto: None,
348 sensitivity: None,
349 }
350 .with_wall_time(start))
351}
352
353fn update_eigen<S>(
356 c_mat: &DenseMatrix<S>,
357 n: usize,
358 bd_mat: &mut DenseMatrix<S>,
359 d_diag: &mut [S],
360 inv_sqrt_diag: &mut [S],
361) where
362 S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
363{
364 match c_mat.eigh() {
366 Ok(eig) => {
367 let eigenvalues = eig.eigenvalues();
368 let eigenvectors = eig.eigenvectors();
369
370 for i in 0..n {
371 let ev = eigenvalues[i];
372 d_diag[i] = if ev > S::from_f64(1e-20) {
374 ev
375 } else {
376 S::from_f64(1e-20)
377 };
378 inv_sqrt_diag[i] = S::ONE / d_diag[i].sqrt();
379 }
380
381 for i in 0..n {
383 for j in 0..n {
384 bd_mat.set(i, j, eigenvectors.get(i, j));
385 }
386 }
387 }
388 Err(_) => {
389 for i in 0..n {
391 d_diag[i] = S::ONE;
392 inv_sqrt_diag[i] = S::ONE;
393 for j in 0..n {
394 bd_mat.set(i, j, if i == j { S::ONE } else { S::ZERO });
395 }
396 }
397 }
398 }
399}
400
401fn sample_standard_normal<S: Scalar>(rng: &mut SmallRng) -> S {
403 use rand::Rng;
404 let u1: f64 = rng.gen::<f64>().max(1e-300);
405 let u2: f64 = rng.gen::<f64>();
406 S::from_f64((-2.0 * u1.ln()).sqrt() * (core::f64::consts::TAU * u2).cos())
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squares; global minimum 0 at the origin.
    fn sphere(x: &[f64]) -> f64 {
        x.iter().map(|xi| xi * xi).sum::<f64>()
    }

    /// Two-dimensional Rosenbrock valley; global minimum 0 at (1, 1).
    fn rosenbrock(x: &[f64]) -> f64 {
        (1.0 - x[0]).powi(2) + 100.0 * (x[1] - x[0] * x[0]).powi(2)
    }

    /// Rastrigin function; global minimum 0 at the origin with many local minima.
    fn rastrigin(x: &[f64]) -> f64 {
        let n = x.len() as f64;
        let ripple: f64 = x
            .iter()
            .map(|xi| xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos())
            .sum();
        10.0 * n + ripple
    }

    #[test]
    fn test_cmaes_sphere() {
        let opts = CmaEsOptions {
            max_iter: 2000,
            ..Default::default()
        };
        let result = cmaes_minimize(sphere, &[5.0, 3.0, -2.0], &opts).unwrap();
        assert!(result.f < 1e-6, "f={}", result.f);
        for xi in result.x.iter().copied() {
            assert!(xi.abs() < 1e-3, "xi={}", xi);
        }
    }

    #[test]
    fn test_cmaes_rosenbrock() {
        let opts = CmaEsOptions {
            sigma0: 1.0,
            max_iter: 5000,
            ..Default::default()
        };
        let result = cmaes_minimize(rosenbrock, &[-1.0, 1.0], &opts).unwrap();
        assert!(result.f < 0.01, "f={}", result.f);
    }

    #[test]
    fn test_cmaes_rastrigin() {
        let opts = CmaEsOptions {
            sigma0: 2.0,
            max_iter: 5000,
            ..Default::default()
        };
        let result = cmaes_minimize(rastrigin, &[2.0, -2.0], &opts).unwrap();
        assert!(result.f < 2.0, "f={}", result.f);
    }

    #[test]
    fn test_cmaes_1d() {
        let shifted_square = |x: &[f64]| (x[0] - 7.0).powi(2);
        let result = cmaes_minimize(shifted_square, &[0.0], &CmaEsOptions::default()).unwrap();
        assert!((result.x[0] - 7.0).abs() < 0.1, "x={}", result.x[0]);
    }

    #[test]
    fn test_cmaes_deterministic() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let opts = CmaEsOptions::default();
        let (r1, r2) = (
            cmaes_minimize(f, &[3.0, 4.0], &opts).unwrap(),
            cmaes_minimize(f, &[3.0, 4.0], &opts).unwrap(),
        );
        assert_eq!(r1.x, r2.x);
        assert_eq!(r1.f, r2.f);
    }
}