Skip to main content

numra_optim/
cmaes.rs

1//! CMA-ES (Covariance Matrix Adaptation Evolution Strategy).
2//!
3//! A state-of-the-art derivative-free global optimizer for nonlinear,
4//! non-convex optimization in continuous domains.
5//!
6//! Author: Moussa Leblouba
7//! Date: 9 February 2026
8//! Modified: 2 May 2026
9
10use numra_core::Scalar;
11use numra_linalg::{DenseMatrix, Matrix};
12use rand::rngs::SmallRng;
13use rand::SeedableRng;
14
15use crate::error::OptimError;
16use crate::types::{IterationRecord, OptimResult, OptimStatus};
17
18/// Options for CMA-ES.
19#[derive(Clone, Debug)]
20pub struct CmaEsOptions<S: Scalar> {
21    /// Population size lambda. Default: `4 + floor(3 * ln(n))`.
22    pub population_size: Option<usize>,
23    /// Initial step size (sigma). Default: 0.5.
24    pub sigma0: S,
25    /// Maximum iterations (generations).
26    pub max_iter: usize,
27    /// Convergence tolerance on fitness spread.
28    pub tol_f: S,
29    /// Convergence tolerance on sigma.
30    pub tol_sigma: S,
31    /// Random seed.
32    pub seed: u64,
33    /// Print progress.
34    pub verbose: bool,
35}
36
37impl<S: Scalar> Default for CmaEsOptions<S> {
38    fn default() -> Self {
39        Self {
40            population_size: None,
41            sigma0: S::HALF,
42            max_iter: 10_000,
43            tol_f: S::from_f64(1e-12),
44            tol_sigma: S::from_f64(1e-12),
45            seed: 42,
46            verbose: false,
47        }
48    }
49}
50
51#[allow(clippy::needless_range_loop)]
52/// Minimize `f` using CMA-ES starting from `x0`.
53///
54/// CMA-ES samples a population from a multivariate normal distribution
55/// N(mean, sigma^2 * C), ranks by fitness, and updates the mean,
56/// step size sigma, and covariance matrix C.
57///
58/// # Arguments
59///
60/// * `f` - Objective function.
61/// * `x0` - Initial mean.
62/// * `opts` - Algorithm options.
63pub fn cmaes_minimize<S, F>(
64    f: F,
65    x0: &[S],
66    opts: &CmaEsOptions<S>,
67) -> Result<OptimResult<S>, OptimError>
68where
69    S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
70    F: Fn(&[S]) -> S,
71{
72    let start = std::time::Instant::now();
73    let n = x0.len();
74    if n == 0 {
75        return Err(OptimError::DimensionMismatch {
76            expected: 1,
77            actual: 0,
78        });
79    }
80    let nf = n as f64;
81
82    // Population size
83    let lambda = opts
84        .population_size
85        .unwrap_or((4.0 + (3.0 * nf.ln()).floor()) as usize);
86    let lambda = lambda.max(4); // at least 4
87    let mu = lambda / 2; // number of selected (parent) individuals
88
89    // Recombination weights: w_i = ln(mu + 0.5) - ln(i)  for i = 1..mu
90    let mut weights = Vec::with_capacity(mu);
91    let log_mu_half = (mu as f64 + 0.5).ln();
92    for i in 1..=mu {
93        weights.push(log_mu_half - (i as f64).ln());
94    }
95    let w_sum: f64 = weights.iter().sum();
96    for w in weights.iter_mut() {
97        *w /= w_sum;
98    }
99    let w_sq_sum: f64 = weights.iter().map(|w| w * w).sum();
100    let mu_eff = 1.0 / w_sq_sum;
101
102    // Learning rates
103    let cc = (4.0 + mu_eff / nf) / (nf + 4.0 + 2.0 * mu_eff / nf);
104    let cs = (mu_eff + 2.0) / (nf + mu_eff + 5.0);
105    let c1 = 2.0 / ((nf + 1.3).powi(2) + mu_eff);
106    let cmu_raw = 2.0 * (mu_eff - 2.0 + 1.0 / mu_eff) / ((nf + 2.0).powi(2) + mu_eff);
107    let cmu = cmu_raw.min(1.0 - c1);
108    let damps = 1.0 + 2.0 * (0.0_f64).max(((mu_eff - 1.0) / (nf + 1.0)).sqrt() - 1.0) + cs;
109    let chi_n = nf.sqrt() * (1.0 - 1.0 / (4.0 * nf) + 1.0 / (21.0 * nf * nf));
110
111    // State variables
112    let mut mean: Vec<S> = x0.to_vec();
113    let mut sigma = opts.sigma0;
114
115    // Covariance matrix C = I (stored as DenseMatrix)
116    let mut c_mat = DenseMatrix::<S>::zeros(n, n);
117    for i in 0..n {
118        c_mat.set(i, i, S::ONE);
119    }
120
121    // Evolution paths
122    let mut p_sigma = vec![S::ZERO; n]; // conjugate evolution path for sigma
123    let mut p_c = vec![S::ZERO; n]; // evolution path for C
124
125    // Eigendecomposition cache: C = B * D^2 * B^T
126    // B = eigenvectors, D = sqrt(eigenvalues)
127    let mut bd_mat = DenseMatrix::<S>::zeros(n, n);
128    for i in 0..n {
129        bd_mat.set(i, i, S::ONE);
130    }
131    let mut d_diag = vec![S::ONE; n]; // eigenvalues of C (not sqrt)
132    let mut inv_sqrt_diag = vec![S::ONE; n]; // 1/sqrt(eigenvalues)
133
134    let mut rng = SmallRng::seed_from_u64(opts.seed);
135    let mut n_feval = 0_usize;
136    let mut history: Vec<IterationRecord<S>> = Vec::new();
137    let mut converged = false;
138    let mut iterations = 0;
139    let mut best_x = x0.to_vec();
140    let mut best_f = f(x0);
141    n_feval += 1;
142
143    let mut eigen_update_gen: usize = 0;
144
145    for gen in 0..opts.max_iter {
146        iterations = gen + 1;
147
148        // Sample lambda offspring: x_k = mean + sigma * B * D * z_k
149        let mut population: Vec<Vec<S>> = Vec::with_capacity(lambda);
150        let mut z_vectors: Vec<Vec<S>> = Vec::with_capacity(lambda);
151
152        for _ in 0..lambda {
153            // Sample z ~ N(0, I)
154            let z: Vec<S> = (0..n).map(|_| sample_standard_normal(&mut rng)).collect();
155
156            // x = mean + sigma * B * D * z
157            let mut x = vec![S::ZERO; n];
158            for i in 0..n {
159                let mut val = S::ZERO;
160                for j in 0..n {
161                    val += bd_mat.get(i, j) * d_diag[j].sqrt() * z[j];
162                }
163                x[i] = mean[i] + sigma * val;
164            }
165
166            z_vectors.push(z);
167            population.push(x);
168        }
169
170        // Evaluate fitness
171        let mut fitness: Vec<(usize, S)> = population
172            .iter()
173            .enumerate()
174            .map(|(i, x)| (i, f(x)))
175            .collect();
176        n_feval += lambda;
177
178        // Sort by fitness (ascending = best first)
179        fitness.sort_by(|a, b| a.1.to_f64().partial_cmp(&b.1.to_f64()).unwrap());
180
181        // Track best ever
182        if fitness[0].1 < best_f {
183            best_f = fitness[0].1;
184            best_x = population[fitness[0].0].clone();
185        }
186
187        if opts.verbose && gen % 50 == 0 {
188            eprintln!(
189                "CMA-ES gen {}: best_f={:.6e}, sigma={:.4e}",
190                gen,
191                best_f.to_f64(),
192                sigma.to_f64()
193            );
194        }
195
196        history.push(IterationRecord {
197            iteration: gen,
198            objective: best_f,
199            gradient_norm: sigma,
200            step_size: sigma,
201            constraint_violation: S::ZERO,
202        });
203
204        // Check convergence
205        let f_best_gen = fitness[0].1;
206        let f_worst_gen = fitness[lambda - 1].1;
207        if (f_worst_gen - f_best_gen).abs() < opts.tol_f && sigma < opts.tol_sigma {
208            converged = true;
209            break;
210        }
211
212        // ─── Update mean ───
213        let old_mean = mean.clone();
214        for j in 0..n {
215            mean[j] = S::ZERO;
216        }
217        for i in 0..mu {
218            let idx = fitness[i].0;
219            let w_i = S::from_f64(weights[i]);
220            for j in 0..n {
221                mean[j] += w_i * population[idx][j];
222            }
223        }
224
225        // ─── Update evolution paths ───
226        // p_sigma = (1 - cs) * p_sigma + sqrt(cs * (2 - cs) * mu_eff) * C^{-1/2} * (mean - old_mean) / sigma
227        let mean_shift: Vec<S> = (0..n).map(|j| (mean[j] - old_mean[j]) / sigma).collect();
228
229        // C^{-1/2} * mean_shift = B * D^{-1} * B^T * mean_shift
230        let mut c_inv_sqrt_shift = vec![S::ZERO; n];
231        // temp = B^T * mean_shift
232        let mut temp = vec![S::ZERO; n];
233        for i in 0..n {
234            let mut val = S::ZERO;
235            for j in 0..n {
236                val += bd_mat.get(j, i) * mean_shift[j]; // B^T: row i = col i of B
237            }
238            temp[i] = val;
239        }
240        // temp2 = D^{-1} * temp
241        for i in 0..n {
242            temp[i] *= inv_sqrt_diag[i];
243        }
244        // c_inv_sqrt_shift = B * temp2
245        for i in 0..n {
246            let mut val = S::ZERO;
247            for j in 0..n {
248                val += bd_mat.get(i, j) * temp[j];
249            }
250            c_inv_sqrt_shift[i] = val;
251        }
252
253        let cs_factor = S::from_f64((cs * (2.0 - cs) * mu_eff).sqrt());
254        let one_minus_cs = S::from_f64(1.0 - cs);
255        for i in 0..n {
256            p_sigma[i] = one_minus_cs * p_sigma[i] + cs_factor * c_inv_sqrt_shift[i];
257        }
258
259        // ||p_sigma||
260        let ps_norm: f64 = p_sigma
261            .iter()
262            .map(|&v| v.to_f64() * v.to_f64())
263            .sum::<f64>()
264            .sqrt();
265
266        // h_sigma: stall indicator
267        let gen_factor = 1.0 - (1.0 - cs).powi((2 * (gen + 1)) as i32);
268        let threshold = (1.4 + 2.0 / (nf + 1.0)) * chi_n * gen_factor.sqrt();
269        let h_sigma: f64 = if ps_norm < threshold { 1.0 } else { 0.0 };
270
271        // p_c = (1 - cc) * p_c + h_sigma * sqrt(cc * (2 - cc) * mu_eff) * mean_shift
272        let cc_factor = S::from_f64(h_sigma * (cc * (2.0 - cc) * mu_eff).sqrt());
273        let one_minus_cc = S::from_f64(1.0 - cc);
274        for i in 0..n {
275            p_c[i] = one_minus_cc * p_c[i] + cc_factor * mean_shift[i];
276        }
277
278        // ─── Update covariance matrix ───
279        // C = (1 - c1 - cmu) * C + c1 * (p_c * p_c^T + delta(h_sigma) * C)
280        //   + cmu * sum_i w_i * (x_i - old_mean)*(x_i - old_mean)^T / sigma^2
281        let delta_h = (1.0 - h_sigma) * cc * (2.0 - cc);
282        let c_scale = S::from_f64(1.0 - c1 - cmu + c1 * delta_h);
283        let c1_s = S::from_f64(c1);
284        let cmu_s = S::from_f64(cmu);
285
286        for i in 0..n {
287            for j in 0..=i {
288                let mut val = c_scale * c_mat.get(i, j);
289                val += c1_s * p_c[i] * p_c[j];
290                // Rank-mu update
291                let mut rank_mu = S::ZERO;
292                for k in 0..mu {
293                    let idx = fitness[k].0;
294                    let di = (population[idx][i] - old_mean[i]) / sigma;
295                    let dj = (population[idx][j] - old_mean[j]) / sigma;
296                    rank_mu += S::from_f64(weights[k]) * di * dj;
297                }
298                val += cmu_s * rank_mu;
299                c_mat.set(i, j, val);
300                c_mat.set(j, i, val);
301            }
302        }
303
304        // ─── Update step size sigma ───
305        sigma *= S::from_f64(((cs / damps) * (ps_norm / chi_n - 1.0)).exp());
306
307        // ─── Eigendecomposition of C (every ~n/10 generations) ───
308        let eigen_interval = (n / 10).max(1);
309        if gen - eigen_update_gen >= eigen_interval {
310            eigen_update_gen = gen;
311            update_eigen(&c_mat, n, &mut bd_mat, &mut d_diag, &mut inv_sqrt_diag);
312        }
313    }
314
315    let (status, message) = if converged {
316        (
317            OptimStatus::GradientConverged,
318            format!("CMA-ES converged after {} generations", iterations),
319        )
320    } else {
321        (
322            OptimStatus::MaxIterations,
323            format!(
324                "CMA-ES: max generations ({}) reached, best f = {:.6e}",
325                opts.max_iter,
326                best_f.to_f64()
327            ),
328        )
329    };
330
331    Ok(OptimResult {
332        x: best_x,
333        f: best_f,
334        grad: Vec::new(),
335        iterations,
336        n_feval,
337        n_geval: 0,
338        converged,
339        message,
340        status,
341        history,
342        lambda_eq: Vec::new(),
343        lambda_ineq: Vec::new(),
344        active_bounds: Vec::new(),
345        constraint_violation: S::ZERO,
346        wall_time_secs: 0.0,
347        pareto: None,
348        sensitivity: None,
349    }
350    .with_wall_time(start))
351}
352
353/// Update eigendecomposition of the covariance matrix.
354/// C = B * diag(d) * B^T
355fn update_eigen<S>(
356    c_mat: &DenseMatrix<S>,
357    n: usize,
358    bd_mat: &mut DenseMatrix<S>,
359    d_diag: &mut [S],
360    inv_sqrt_diag: &mut [S],
361) where
362    S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
363{
364    // Use symmetric eigendecomposition
365    match c_mat.eigh() {
366        Ok(eig) => {
367            let eigenvalues = eig.eigenvalues();
368            let eigenvectors = eig.eigenvectors();
369
370            for i in 0..n {
371                let ev = eigenvalues[i];
372                // Clamp eigenvalues to small positive value for numerical stability
373                d_diag[i] = if ev > S::from_f64(1e-20) {
374                    ev
375                } else {
376                    S::from_f64(1e-20)
377                };
378                inv_sqrt_diag[i] = S::ONE / d_diag[i].sqrt();
379            }
380
381            // Copy eigenvectors to bd_mat
382            for i in 0..n {
383                for j in 0..n {
384                    bd_mat.set(i, j, eigenvectors.get(i, j));
385                }
386            }
387        }
388        Err(_) => {
389            // If eigendecomposition fails, reset to identity
390            for i in 0..n {
391                d_diag[i] = S::ONE;
392                inv_sqrt_diag[i] = S::ONE;
393                for j in 0..n {
394                    bd_mat.set(i, j, if i == j { S::ONE } else { S::ZERO });
395                }
396            }
397        }
398    }
399}
400
401/// Sample from standard normal using Box-Muller transform.
402fn sample_standard_normal<S: Scalar>(rng: &mut SmallRng) -> S {
403    use rand::Rng;
404    let u1: f64 = rng.gen::<f64>().max(1e-300);
405    let u2: f64 = rng.gen::<f64>();
406    S::from_f64((-2.0 * u1.ln()).sqrt() * (core::f64::consts::TAU * u2).cos())
407}
408
#[cfg(test)]
mod tests {
    //! Tests for `cmaes_minimize`: convergence on standard benchmarks,
    //! determinism, input validation, and population-size handling.
    use super::*;

    #[test]
    fn test_cmaes_sphere() {
        let result = cmaes_minimize(
            |x: &[f64]| x.iter().map(|xi| xi * xi).sum::<f64>(),
            &[5.0, 3.0, -2.0],
            &CmaEsOptions {
                max_iter: 2000,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(result.f < 1e-6, "f={}", result.f);
        for &xi in &result.x {
            assert!(xi.abs() < 1e-3, "xi={}", xi);
        }
    }

    #[test]
    fn test_cmaes_rosenbrock() {
        let result = cmaes_minimize(
            |x: &[f64]| (1.0 - x[0]).powi(2) + 100.0 * (x[1] - x[0] * x[0]).powi(2),
            &[-1.0, 1.0],
            &CmaEsOptions {
                sigma0: 1.0,
                max_iter: 5000,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(result.f < 0.01, "f={}", result.f);
    }

    #[test]
    fn test_cmaes_rastrigin() {
        // Global min at (0, 0) with f=0
        let result = cmaes_minimize(
            |x: &[f64]| {
                let n = x.len() as f64;
                10.0 * n
                    + x.iter()
                        .map(|xi| xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos())
                        .sum::<f64>()
            },
            &[2.0, -2.0],
            &CmaEsOptions {
                sigma0: 2.0,
                max_iter: 5000,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(result.f < 2.0, "f={}", result.f);
    }

    #[test]
    fn test_cmaes_1d() {
        let result = cmaes_minimize(
            |x: &[f64]| (x[0] - 7.0).powi(2),
            &[0.0],
            &CmaEsOptions::default(),
        )
        .unwrap();
        assert!((result.x[0] - 7.0).abs() < 0.1, "x={}", result.x[0]);
    }

    #[test]
    fn test_cmaes_deterministic() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let r1 = cmaes_minimize(f, &[3.0, 4.0], &CmaEsOptions::default()).unwrap();
        let r2 = cmaes_minimize(f, &[3.0, 4.0], &CmaEsOptions::default()).unwrap();
        assert_eq!(r1.x, r2.x);
        assert_eq!(r1.f, r2.f);
    }

    #[test]
    fn test_cmaes_empty_x0_is_rejected() {
        let empty: [f64; 0] = [];
        let result = cmaes_minimize(|_x: &[f64]| 0.0, &empty, &CmaEsOptions::default());
        assert!(matches!(
            result,
            Err(OptimError::DimensionMismatch {
                expected: 1,
                actual: 0
            })
        ));
    }

    #[test]
    fn test_cmaes_population_size_clamped_to_four() {
        let result = cmaes_minimize(
            |x: &[f64]| x[0] * x[0],
            &[1.0],
            &CmaEsOptions {
                population_size: Some(1),
                max_iter: 20,
                ..Default::default()
            },
        )
        .unwrap();
        // One evaluation of x0, then lambda = 4 evaluations per generation.
        assert_eq!(result.n_feval, 1 + 4 * result.iterations);
    }
}