Skip to main content

math_audio_optimisation/
cmaes.rs

1//! Covariance Matrix Adaptation Evolution Strategy (CMA-ES).
2//!
3//! This is a pure-Rust, bounded, full-covariance CMA-ES implementation for
4//! continuous black-box minimisation. Internally the search runs in a
5//! normalised `[0, 1]^n` box so parameter scales such as log-frequency, Q, and
6//! gain can share one covariance matrix without manual preconditioning.
7
8use nalgebra::{DMatrix, DVector, SymmetricEigen};
9use ndarray::Array1;
10use rand::rngs::StdRng;
11use rand::{Rng, SeedableRng};
12use rayon::prelude::*;
13
14use crate::CallbackAction;
15use crate::error::{DEError, Result};
16use crate::parallel_eval::ParallelConfig;
17
18/// Per-generation callback payload for [`cma_es`].
19pub struct CmaEsIntermediate {
20    /// Current best parameter vector in the original bounded coordinates.
21    pub x: Array1<f64>,
22    /// Current best objective value.
23    pub fun: f64,
24    /// Current generation index.
25    pub iter: usize,
26    /// Number of objective evaluations consumed so far.
27    pub nfev: usize,
28    /// Current global step size in normalised coordinates.
29    pub sigma: f64,
30}
31
32/// Callback type used by [`CmaEsConfig`].
33pub type CmaEsCallback = Box<dyn FnMut(&CmaEsIntermediate) -> CallbackAction + Send>;
34
35/// Configuration for [`cma_es`].
36pub struct CmaEsConfig {
37    /// `(lower, upper)` bounds per parameter.
38    pub bounds: Vec<(f64, f64)>,
39    /// Optional initial mean. Values outside bounds are clipped.
40    pub x0: Option<Array1<f64>>,
41    /// Initial step size in normalised `[0, 1]` coordinates.
42    ///
43    /// `None` uses `0.3`, the standard broad-search default for bounded
44    /// CMA-ES. Smaller values are appropriate for local refinement.
45    pub sigma0: Option<f64>,
46    /// Offspring population size. `0` uses `4 + floor(3 ln(n))`.
47    pub lambda: usize,
48    /// Parent count. `0` uses `lambda / 2`.
49    pub mu: usize,
50    /// Maximum objective evaluations.
51    pub maxeval: usize,
52    /// Optional RNG seed for deterministic runs.
53    pub seed: Option<u64>,
54    /// Stop after this many generations with improvement below [`Self::f_tol`].
55    pub stagnation_window: usize,
56    /// Objective-improvement tolerance for stagnation detection.
57    pub f_tol: f64,
58    /// Stop once the best objective is at or below this value.
59    pub target_f: f64,
60    /// Optional per-generation callback. Returning [`CallbackAction::Stop`]
61    /// terminates the run early and returns the best point seen so far.
62    pub callback: Option<CmaEsCallback>,
63    /// Parallel evaluation configuration for offspring fitness calls.
64    pub parallel: ParallelConfig,
65}
66
67impl Default for CmaEsConfig {
68    fn default() -> Self {
69        Self {
70            bounds: Vec::new(),
71            x0: None,
72            sigma0: None,
73            lambda: 0,
74            mu: 0,
75            maxeval: 10_000,
76            seed: None,
77            stagnation_window: 80,
78            f_tol: 1e-10,
79            target_f: f64::NEG_INFINITY,
80            callback: None,
81            parallel: ParallelConfig::default(),
82        }
83    }
84}
85
86/// Result of a [`cma_es`] run.
87#[derive(Clone)]
88pub struct CmaEsReport {
89    /// Best parameter vector found.
90    pub x: Array1<f64>,
91    /// Objective value at [`Self::x`].
92    pub fun: f64,
93    /// Whether the run met a convergence/target/callback stop condition before
94    /// exhausting the evaluation budget.
95    pub success: bool,
96    /// Human-readable termination message.
97    pub message: String,
98    /// Objective evaluations consumed.
99    pub nfev: usize,
100    /// Generations completed.
101    pub nit: usize,
102    /// Final global step size in normalised coordinates.
103    pub sigma: f64,
104}
105
106impl std::fmt::Debug for CmaEsReport {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("CmaEsReport")
109            .field("x_len", &self.x.len())
110            .field("fun", &self.fun)
111            .field("success", &self.success)
112            .field("message", &self.message)
113            .field("nfev", &self.nfev)
114            .field("nit", &self.nit)
115            .field("sigma", &self.sigma)
116            .finish()
117    }
118}
119
120#[derive(Clone)]
121struct Candidate {
122    y: DVector<f64>,
123    fun: f64,
124}
125
126struct Sample {
127    y: DVector<f64>,
128    x: Array1<f64>,
129}
130
131/// Minimise `f` with bounded full-covariance CMA-ES.
132///
133/// The objective receives parameters in the original coordinate system. Bounds
134/// are handled by clipping sampled normalised points before evaluation.
135pub fn cma_es<F>(f: &F, mut config: CmaEsConfig) -> Result<CmaEsReport>
136where
137    F: Fn(&Array1<f64>) -> f64 + Sync,
138{
139    let n = config.bounds.len();
140    if n == 0 {
141        return Err(DEError::BoundsMismatch {
142            lower_len: 0,
143            upper_len: 0,
144        });
145    }
146    for (i, (lo, hi)) in config.bounds.iter().enumerate() {
147        if lo > hi {
148            return Err(DEError::InvalidBounds {
149                index: i,
150                lower: *lo,
151                upper: *hi,
152            });
153        }
154    }
155    if let Some(ref x0) = config.x0
156        && x0.len() != n
157    {
158        return Err(DEError::X0DimensionMismatch {
159            expected: n,
160            got: x0.len(),
161        });
162    }
163
164    let lambda = if config.lambda == 0 {
165        (4.0 + (3.0 * (n as f64).ln()).floor()).max(4.0) as usize
166    } else {
167        config.lambda
168    };
169    if lambda < 2 {
170        return Err(DEError::PopulationTooSmall { pop_size: lambda });
171    }
172    let mu = if config.mu == 0 {
173        lambda / 2
174    } else {
175        config.mu.min(lambda)
176    }
177    .max(1);
178
179    let weights = recombination_weights(mu);
180    let mueff = 1.0 / weights.iter().map(|w| w * w).sum::<f64>();
181    let n_f = n as f64;
182
183    let cc = (4.0 + mueff / n_f) / (n_f + 4.0 + 2.0 * mueff / n_f);
184    let cs = (mueff + 2.0) / (n_f + mueff + 5.0);
185    let c1 = 2.0 / ((n_f + 1.3).powi(2) + mueff);
186    let cmu = (1.0 - c1).min(2.0 * (mueff - 2.0 + 1.0 / mueff) / ((n_f + 2.0).powi(2) + mueff));
187    let damps = 1.0 + 2.0 * ((mueff - 1.0) / (n_f + 1.0)).sqrt().max(1.0) - 2.0 + cs;
188    let chi_n = n_f.sqrt() * (1.0 - 1.0 / (4.0 * n_f) + 1.0 / (21.0 * n_f * n_f));
189
190    let mut mean = initial_mean(&config);
191    let mut sigma = config.sigma0.unwrap_or(0.3).clamp(1e-12, 2.0);
192    let mut covariance = DMatrix::<f64>::identity(n, n);
193    let mut b = DMatrix::<f64>::identity(n, n);
194    let mut d = DVector::<f64>::from_element(n, 1.0);
195    let mut invsqrt_c = DMatrix::<f64>::identity(n, n);
196    let mut pc = DVector::<f64>::zeros(n);
197    let mut ps = DVector::<f64>::zeros(n);
198
199    let mut rng: StdRng = match config.seed {
200        Some(s) => StdRng::seed_from_u64(s),
201        None => {
202            let mut thread_rng = rand::rng();
203            StdRng::from_rng(&mut thread_rng)
204        }
205    };
206
207    let initial_x = denormalise(&mean, &config.bounds);
208    let initial_fun = finite_or_infinity(f(&initial_x));
209    let mut best_x = initial_x;
210    let mut best_fun = initial_fun;
211    let mut nfev = 1usize;
212    let mut nit = 0usize;
213    let mut last_improvement_fun = best_fun;
214    let mut stagnation_counter = 0usize;
215    let mut message = String::from("maximum evaluations reached");
216    let mut success = false;
217
218    if let Some(n) = config.parallel.num_threads {
219        let _ = rayon::ThreadPoolBuilder::new()
220            .num_threads(n)
221            .build_global();
222    }
223
224    while nfev < config.maxeval {
225        let old_mean = mean.clone();
226        let transform = &b * DMatrix::<f64>::from_diagonal(&d);
227        let eval_budget = (config.maxeval - nfev).min(lambda);
228        let mut samples: Vec<Sample> = Vec::with_capacity(eval_budget);
229
230        for _ in 0..eval_budget {
231            let z = standard_normal_vector(n, &mut rng);
232            let step = &transform * z;
233            let y = clamp_unit_vector(&(old_mean.clone() + step * sigma));
234            let x = denormalise(&y, &config.bounds);
235            samples.push(Sample { y, x });
236        }
237
238        let mut candidates: Vec<Candidate> = if config.parallel.enabled && samples.len() >= 4 {
239            samples
240                .par_iter()
241                .map(|sample| Candidate {
242                    y: sample.y.clone(),
243                    fun: finite_or_infinity(f(&sample.x)),
244                })
245                .collect()
246        } else {
247            samples
248                .iter()
249                .map(|sample| Candidate {
250                    y: sample.y.clone(),
251                    fun: finite_or_infinity(f(&sample.x)),
252                })
253                .collect()
254        };
255        nfev += candidates.len();
256
257        for (sample, candidate) in samples.iter().zip(candidates.iter()) {
258            if candidate.fun < best_fun {
259                best_fun = candidate.fun;
260                best_x = sample.x.clone();
261            }
262        }
263
264        if candidates.is_empty() {
265            break;
266        }
267        candidates.sort_by(|a, b| a.fun.total_cmp(&b.fun));
268
269        mean.fill(0.0);
270        for i in 0..mu.min(candidates.len()) {
271            mean += candidates[i].y.clone() * weights[i];
272        }
273        mean = clamp_unit_vector(&mean);
274
275        let y_w = (&mean - &old_mean) / sigma.max(1e-30);
276        ps = ps * (1.0 - cs) + (&invsqrt_c * &y_w) * (cs * (2.0 - cs) * mueff).sqrt();
277        let norm_ps = ps.norm();
278        let hsig_den = (1.0 - (1.0 - cs).powi(2 * (nit as i32 + 1))).sqrt() * chi_n;
279        let hsig = if hsig_den > 0.0 {
280            norm_ps / hsig_den < 1.4 + 2.0 / (n_f + 1.0)
281        } else {
282            true
283        };
284        pc *= 1.0 - cc;
285        if hsig {
286            pc += y_w.clone() * (cc * (2.0 - cc) * mueff).sqrt();
287        }
288
289        let mut rank_mu = DMatrix::<f64>::zeros(n, n);
290        for i in 0..mu.min(candidates.len()) {
291            let y_i = (&candidates[i].y - &old_mean) / sigma.max(1e-30);
292            rank_mu += (&y_i * y_i.transpose()) * weights[i];
293        }
294
295        let hsig_correction = if hsig { 0.0 } else { c1 * cc * (2.0 - cc) };
296        covariance = covariance * (1.0 - c1 - cmu + hsig_correction)
297            + (&pc * pc.transpose()) * c1
298            + rank_mu * cmu;
299        symmetrise_and_regularise(&mut covariance);
300
301        sigma *= ((cs / damps) * (norm_ps / chi_n - 1.0)).exp();
302        sigma = sigma.clamp(1e-14, 10.0);
303
304        let eig = SymmetricEigen::new(covariance.clone());
305        b = eig.eigenvectors;
306        d = eig.eigenvalues.map(|v| v.max(1e-30).sqrt());
307        let inv_d = d.map(|v| 1.0 / v.max(1e-30));
308        invsqrt_c = &b * DMatrix::<f64>::from_diagonal(&inv_d) * b.transpose();
309
310        nit += 1;
311        if (last_improvement_fun - best_fun).abs() <= config.f_tol {
312            stagnation_counter += 1;
313        } else {
314            stagnation_counter = 0;
315            last_improvement_fun = best_fun;
316        }
317
318        if let Some(ref mut callback) = config.callback {
319            let intermediate = CmaEsIntermediate {
320                x: best_x.clone(),
321                fun: best_fun,
322                iter: nit,
323                nfev,
324                sigma,
325            };
326            if matches!(callback(&intermediate), CallbackAction::Stop) {
327                success = true;
328                message = String::from("stopped by callback");
329                break;
330            }
331        }
332
333        if best_fun <= config.target_f {
334            success = true;
335            message = format!("target_f reached: {:.6e}", best_fun);
336            break;
337        }
338        if config.stagnation_window > 0 && stagnation_counter >= config.stagnation_window {
339            success = true;
340            message = format!(
341                "stagnated for {} generations below f_tol={:.3e}",
342                config.stagnation_window, config.f_tol
343            );
344            break;
345        }
346        if sigma < 1e-12 {
347            success = true;
348            message = String::from("step size collapsed");
349            break;
350        }
351    }
352
353    Ok(CmaEsReport {
354        x: best_x,
355        fun: best_fun,
356        success,
357        message,
358        nfev,
359        nit,
360        sigma,
361    })
362}
363
/// Returns `mu` positive, strictly decreasing recombination weights that sum
/// to one.
///
/// The raw weight of the parent ranked `i` (1-based) is
/// `ln(mu + 0.5) - ln(i)`; the raw weights are then divided by their total.
/// `mu == 0` yields an empty vector.
fn recombination_weights(mu: usize) -> Vec<f64> {
    let log_top = (mu as f64 + 0.5).ln();
    let raw: Vec<f64> = (1..=mu)
        .map(|rank| log_top - (rank as f64).ln())
        .collect();
    let total: f64 = raw.iter().sum();
    raw.into_iter().map(|w| w / total).collect()
}
375
376fn initial_mean(config: &CmaEsConfig) -> DVector<f64> {
377    if let Some(ref x0) = config.x0 {
378        let mut y = DVector::<f64>::zeros(config.bounds.len());
379        for (i, (lo, hi)) in config.bounds.iter().enumerate() {
380            let span = hi - lo;
381            y[i] = if span > 0.0 {
382                ((x0[i].clamp(*lo, *hi) - lo) / span).clamp(0.0, 1.0)
383            } else {
384                0.5
385            };
386        }
387        y
388    } else {
389        DVector::<f64>::from_element(config.bounds.len(), 0.5)
390    }
391}
392
393fn denormalise(y: &DVector<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
394    let mut x = Vec::with_capacity(bounds.len());
395    for (i, (lo, hi)) in bounds.iter().enumerate() {
396        x.push(lo + y[i].clamp(0.0, 1.0) * (hi - lo));
397    }
398    Array1::from(x)
399}
400
401fn clamp_unit_vector(y: &DVector<f64>) -> DVector<f64> {
402    y.map(|v| v.clamp(0.0, 1.0))
403}
404
405fn standard_normal_vector<R: Rng + ?Sized>(n: usize, rng: &mut R) -> DVector<f64> {
406    let mut out = DVector::<f64>::zeros(n);
407    let mut i = 0usize;
408    while i < n {
409        let u1 = rng.random::<f64>().max(f64::MIN_POSITIVE);
410        let u2 = rng.random::<f64>();
411        let radius = (-2.0 * u1.ln()).sqrt();
412        let theta = 2.0 * std::f64::consts::PI * u2;
413        out[i] = radius * theta.cos();
414        if i + 1 < n {
415            out[i + 1] = radius * theta.sin();
416        }
417        i += 2;
418    }
419    out
420}
421
/// Passes finite objective values through and maps everything else (NaN and
/// both infinities) to `+inf`, so such points rank last when minimising.
fn finite_or_infinity(v: f64) -> f64 {
    if v.is_nan() || v.is_infinite() {
        return f64::INFINITY;
    }
    v
}
425
426fn symmetrise_and_regularise(c: &mut DMatrix<f64>) {
427    let n = c.nrows();
428    for i in 0..n {
429        for j in 0..i {
430            let v = 0.5 * (c[(i, j)] + c[(j, i)]);
431            c[(i, j)] = v;
432            c[(j, i)] = v;
433        }
434        c[(i, i)] = c[(i, i)].max(1e-30);
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// A seeded run on the 4-D sphere function must get close to the origin.
    #[test]
    fn cma_es_converges_on_sphere() {
        let config = CmaEsConfig {
            bounds: vec![(-5.0, 5.0); 4],
            maxeval: 5_000,
            seed: Some(42),
            target_f: 1e-10,
            ..Default::default()
        };
        let objective = |x: &Array1<f64>| x.iter().map(|&xi| xi * xi).sum::<f64>();

        let report = cma_es(&objective, config).expect("CMA-ES should run");

        assert!(
            report.fun < 1e-6,
            "CMA-ES should converge near origin, got {}",
            report.fun
        );
    }

    /// A seeded run on a 2-D quadratic rotated by 45 degrees, with one axis
    /// weighted 1000 times more than the other, must still reach the minimum.
    #[test]
    fn cma_es_handles_coupled_rotated_quadratic() {
        let config = CmaEsConfig {
            bounds: vec![(-3.0, 3.0), (-3.0, 3.0)],
            maxeval: 4_000,
            seed: Some(7),
            target_f: 1e-9,
            ..Default::default()
        };
        let root_two = 2.0_f64.sqrt();
        let objective = move |x: &Array1<f64>| {
            let u = (x[0] + x[1]) / root_two;
            let v = (x[0] - x[1]) / root_two;
            1_000.0 * u * u + v * v
        };

        let report = cma_es(&objective, config).expect("CMA-ES should run");

        assert!(
            report.fun < 1e-5,
            "CMA-ES should solve rotated ill-conditioned quadratic, got {}",
            report.fun
        );
    }
}
489}