1use nalgebra::{DMatrix, DVector, SymmetricEigen};
9use ndarray::Array1;
10use rand::rngs::StdRng;
11use rand::{Rng, SeedableRng};
12use rayon::prelude::*;
13
14use crate::CallbackAction;
15use crate::error::{DEError, Result};
16use crate::parallel_eval::ParallelConfig;
17
18pub struct CmaEsIntermediate {
20 pub x: Array1<f64>,
22 pub fun: f64,
24 pub iter: usize,
26 pub nfev: usize,
28 pub sigma: f64,
30}
31
32pub type CmaEsCallback = Box<dyn FnMut(&CmaEsIntermediate) -> CallbackAction + Send>;
34
35pub struct CmaEsConfig {
37 pub bounds: Vec<(f64, f64)>,
39 pub x0: Option<Array1<f64>>,
41 pub sigma0: Option<f64>,
46 pub lambda: usize,
48 pub mu: usize,
50 pub maxeval: usize,
52 pub seed: Option<u64>,
54 pub stagnation_window: usize,
56 pub f_tol: f64,
58 pub target_f: f64,
60 pub callback: Option<CmaEsCallback>,
63 pub parallel: ParallelConfig,
65}
66
67impl Default for CmaEsConfig {
68 fn default() -> Self {
69 Self {
70 bounds: Vec::new(),
71 x0: None,
72 sigma0: None,
73 lambda: 0,
74 mu: 0,
75 maxeval: 10_000,
76 seed: None,
77 stagnation_window: 80,
78 f_tol: 1e-10,
79 target_f: f64::NEG_INFINITY,
80 callback: None,
81 parallel: ParallelConfig::default(),
82 }
83 }
84}
85
86#[derive(Clone)]
88pub struct CmaEsReport {
89 pub x: Array1<f64>,
91 pub fun: f64,
93 pub success: bool,
96 pub message: String,
98 pub nfev: usize,
100 pub nit: usize,
102 pub sigma: f64,
104}
105
106impl std::fmt::Debug for CmaEsReport {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("CmaEsReport")
109 .field("x_len", &self.x.len())
110 .field("fun", &self.fun)
111 .field("success", &self.success)
112 .field("message", &self.message)
113 .field("nfev", &self.nfev)
114 .field("nit", &self.nit)
115 .field("sigma", &self.sigma)
116 .finish()
117 }
118}
119
120#[derive(Clone)]
121struct Candidate {
122 y: DVector<f64>,
123 fun: f64,
124}
125
126struct Sample {
127 y: DVector<f64>,
128 x: Array1<f64>,
129}
130
131pub fn cma_es<F>(f: &F, mut config: CmaEsConfig) -> Result<CmaEsReport>
136where
137 F: Fn(&Array1<f64>) -> f64 + Sync,
138{
139 let n = config.bounds.len();
140 if n == 0 {
141 return Err(DEError::BoundsMismatch {
142 lower_len: 0,
143 upper_len: 0,
144 });
145 }
146 for (i, (lo, hi)) in config.bounds.iter().enumerate() {
147 if lo > hi {
148 return Err(DEError::InvalidBounds {
149 index: i,
150 lower: *lo,
151 upper: *hi,
152 });
153 }
154 }
155 if let Some(ref x0) = config.x0
156 && x0.len() != n
157 {
158 return Err(DEError::X0DimensionMismatch {
159 expected: n,
160 got: x0.len(),
161 });
162 }
163
164 let lambda = if config.lambda == 0 {
165 (4.0 + (3.0 * (n as f64).ln()).floor()).max(4.0) as usize
166 } else {
167 config.lambda
168 };
169 if lambda < 2 {
170 return Err(DEError::PopulationTooSmall { pop_size: lambda });
171 }
172 let mu = if config.mu == 0 {
173 lambda / 2
174 } else {
175 config.mu.min(lambda)
176 }
177 .max(1);
178
179 let weights = recombination_weights(mu);
180 let mueff = 1.0 / weights.iter().map(|w| w * w).sum::<f64>();
181 let n_f = n as f64;
182
183 let cc = (4.0 + mueff / n_f) / (n_f + 4.0 + 2.0 * mueff / n_f);
184 let cs = (mueff + 2.0) / (n_f + mueff + 5.0);
185 let c1 = 2.0 / ((n_f + 1.3).powi(2) + mueff);
186 let cmu = (1.0 - c1).min(2.0 * (mueff - 2.0 + 1.0 / mueff) / ((n_f + 2.0).powi(2) + mueff));
187 let damps = 1.0 + 2.0 * ((mueff - 1.0) / (n_f + 1.0)).sqrt().max(1.0) - 2.0 + cs;
188 let chi_n = n_f.sqrt() * (1.0 - 1.0 / (4.0 * n_f) + 1.0 / (21.0 * n_f * n_f));
189
190 let mut mean = initial_mean(&config);
191 let mut sigma = config.sigma0.unwrap_or(0.3).clamp(1e-12, 2.0);
192 let mut covariance = DMatrix::<f64>::identity(n, n);
193 let mut b = DMatrix::<f64>::identity(n, n);
194 let mut d = DVector::<f64>::from_element(n, 1.0);
195 let mut invsqrt_c = DMatrix::<f64>::identity(n, n);
196 let mut pc = DVector::<f64>::zeros(n);
197 let mut ps = DVector::<f64>::zeros(n);
198
199 let mut rng: StdRng = match config.seed {
200 Some(s) => StdRng::seed_from_u64(s),
201 None => {
202 let mut thread_rng = rand::rng();
203 StdRng::from_rng(&mut thread_rng)
204 }
205 };
206
207 let initial_x = denormalise(&mean, &config.bounds);
208 let initial_fun = finite_or_infinity(f(&initial_x));
209 let mut best_x = initial_x;
210 let mut best_fun = initial_fun;
211 let mut nfev = 1usize;
212 let mut nit = 0usize;
213 let mut last_improvement_fun = best_fun;
214 let mut stagnation_counter = 0usize;
215 let mut message = String::from("maximum evaluations reached");
216 let mut success = false;
217
218 if let Some(n) = config.parallel.num_threads {
219 let _ = rayon::ThreadPoolBuilder::new()
220 .num_threads(n)
221 .build_global();
222 }
223
224 while nfev < config.maxeval {
225 let old_mean = mean.clone();
226 let transform = &b * DMatrix::<f64>::from_diagonal(&d);
227 let eval_budget = (config.maxeval - nfev).min(lambda);
228 let mut samples: Vec<Sample> = Vec::with_capacity(eval_budget);
229
230 for _ in 0..eval_budget {
231 let z = standard_normal_vector(n, &mut rng);
232 let step = &transform * z;
233 let y = clamp_unit_vector(&(old_mean.clone() + step * sigma));
234 let x = denormalise(&y, &config.bounds);
235 samples.push(Sample { y, x });
236 }
237
238 let mut candidates: Vec<Candidate> = if config.parallel.enabled && samples.len() >= 4 {
239 samples
240 .par_iter()
241 .map(|sample| Candidate {
242 y: sample.y.clone(),
243 fun: finite_or_infinity(f(&sample.x)),
244 })
245 .collect()
246 } else {
247 samples
248 .iter()
249 .map(|sample| Candidate {
250 y: sample.y.clone(),
251 fun: finite_or_infinity(f(&sample.x)),
252 })
253 .collect()
254 };
255 nfev += candidates.len();
256
257 for (sample, candidate) in samples.iter().zip(candidates.iter()) {
258 if candidate.fun < best_fun {
259 best_fun = candidate.fun;
260 best_x = sample.x.clone();
261 }
262 }
263
264 if candidates.is_empty() {
265 break;
266 }
267 candidates.sort_by(|a, b| a.fun.total_cmp(&b.fun));
268
269 mean.fill(0.0);
270 for i in 0..mu.min(candidates.len()) {
271 mean += candidates[i].y.clone() * weights[i];
272 }
273 mean = clamp_unit_vector(&mean);
274
275 let y_w = (&mean - &old_mean) / sigma.max(1e-30);
276 ps = ps * (1.0 - cs) + (&invsqrt_c * &y_w) * (cs * (2.0 - cs) * mueff).sqrt();
277 let norm_ps = ps.norm();
278 let hsig_den = (1.0 - (1.0 - cs).powi(2 * (nit as i32 + 1))).sqrt() * chi_n;
279 let hsig = if hsig_den > 0.0 {
280 norm_ps / hsig_den < 1.4 + 2.0 / (n_f + 1.0)
281 } else {
282 true
283 };
284 pc *= 1.0 - cc;
285 if hsig {
286 pc += y_w.clone() * (cc * (2.0 - cc) * mueff).sqrt();
287 }
288
289 let mut rank_mu = DMatrix::<f64>::zeros(n, n);
290 for i in 0..mu.min(candidates.len()) {
291 let y_i = (&candidates[i].y - &old_mean) / sigma.max(1e-30);
292 rank_mu += (&y_i * y_i.transpose()) * weights[i];
293 }
294
295 let hsig_correction = if hsig { 0.0 } else { c1 * cc * (2.0 - cc) };
296 covariance = covariance * (1.0 - c1 - cmu + hsig_correction)
297 + (&pc * pc.transpose()) * c1
298 + rank_mu * cmu;
299 symmetrise_and_regularise(&mut covariance);
300
301 sigma *= ((cs / damps) * (norm_ps / chi_n - 1.0)).exp();
302 sigma = sigma.clamp(1e-14, 10.0);
303
304 let eig = SymmetricEigen::new(covariance.clone());
305 b = eig.eigenvectors;
306 d = eig.eigenvalues.map(|v| v.max(1e-30).sqrt());
307 let inv_d = d.map(|v| 1.0 / v.max(1e-30));
308 invsqrt_c = &b * DMatrix::<f64>::from_diagonal(&inv_d) * b.transpose();
309
310 nit += 1;
311 if (last_improvement_fun - best_fun).abs() <= config.f_tol {
312 stagnation_counter += 1;
313 } else {
314 stagnation_counter = 0;
315 last_improvement_fun = best_fun;
316 }
317
318 if let Some(ref mut callback) = config.callback {
319 let intermediate = CmaEsIntermediate {
320 x: best_x.clone(),
321 fun: best_fun,
322 iter: nit,
323 nfev,
324 sigma,
325 };
326 if matches!(callback(&intermediate), CallbackAction::Stop) {
327 success = true;
328 message = String::from("stopped by callback");
329 break;
330 }
331 }
332
333 if best_fun <= config.target_f {
334 success = true;
335 message = format!("target_f reached: {:.6e}", best_fun);
336 break;
337 }
338 if config.stagnation_window > 0 && stagnation_counter >= config.stagnation_window {
339 success = true;
340 message = format!(
341 "stagnated for {} generations below f_tol={:.3e}",
342 config.stagnation_window, config.f_tol
343 );
344 break;
345 }
346 if sigma < 1e-12 {
347 success = true;
348 message = String::from("step size collapsed");
349 break;
350 }
351 }
352
353 Ok(CmaEsReport {
354 x: best_x,
355 fun: best_fun,
356 success,
357 message,
358 nfev,
359 nit,
360 sigma,
361 })
362}
363
/// Returns `mu` positive, strictly decreasing recombination weights that sum
/// to one.
///
/// The raw weight of rank `i` (1-based) is `ln(mu + 0.5) - ln(i)`; the raw
/// weights are then divided by their total. `mu == 0` yields an empty vector.
fn recombination_weights(mu: usize) -> Vec<f64> {
    let ln_top = (mu as f64 + 0.5).ln();
    let raw: Vec<f64> = (1..=mu)
        .map(|rank| ln_top - (rank as f64).ln())
        .collect();
    let total: f64 = raw.iter().sum();
    raw.into_iter().map(|w| w / total).collect()
}
375
376fn initial_mean(config: &CmaEsConfig) -> DVector<f64> {
377 if let Some(ref x0) = config.x0 {
378 let mut y = DVector::<f64>::zeros(config.bounds.len());
379 for (i, (lo, hi)) in config.bounds.iter().enumerate() {
380 let span = hi - lo;
381 y[i] = if span > 0.0 {
382 ((x0[i].clamp(*lo, *hi) - lo) / span).clamp(0.0, 1.0)
383 } else {
384 0.5
385 };
386 }
387 y
388 } else {
389 DVector::<f64>::from_element(config.bounds.len(), 0.5)
390 }
391}
392
393fn denormalise(y: &DVector<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
394 let mut x = Vec::with_capacity(bounds.len());
395 for (i, (lo, hi)) in bounds.iter().enumerate() {
396 x.push(lo + y[i].clamp(0.0, 1.0) * (hi - lo));
397 }
398 Array1::from(x)
399}
400
401fn clamp_unit_vector(y: &DVector<f64>) -> DVector<f64> {
402 y.map(|v| v.clamp(0.0, 1.0))
403}
404
405fn standard_normal_vector<R: Rng + ?Sized>(n: usize, rng: &mut R) -> DVector<f64> {
406 let mut out = DVector::<f64>::zeros(n);
407 let mut i = 0usize;
408 while i < n {
409 let u1 = rng.random::<f64>().max(f64::MIN_POSITIVE);
410 let u2 = rng.random::<f64>();
411 let radius = (-2.0 * u1.ln()).sqrt();
412 let theta = 2.0 * std::f64::consts::PI * u2;
413 out[i] = radius * theta.cos();
414 if i + 1 < n {
415 out[i + 1] = radius * theta.sin();
416 }
417 i += 2;
418 }
419 out
420}
421
/// Passes finite values through and maps NaN and both infinities to `+inf`,
/// so that failed evaluations always rank worst.
fn finite_or_infinity(v: f64) -> f64 {
    if v.is_nan() || v.is_infinite() {
        return f64::INFINITY;
    }
    v
}
425
426fn symmetrise_and_regularise(c: &mut DMatrix<f64>) {
427 let n = c.nrows();
428 for i in 0..n {
429 for j in 0..i {
430 let v = 0.5 * (c[(i, j)] + c[(j, i)]);
431 c[(i, j)] = v;
432 c[(j, i)] = v;
433 }
434 c[(i, i)] = c[(i, i)].max(1e-30);
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squares; minimum of 0 at the origin.
    fn sphere(x: &Array1<f64>) -> f64 {
        x.iter().map(|&xi| xi * xi).sum::<f64>()
    }

    /// Two-dimensional quadratic whose principal axes are rotated by 45° and
    /// whose curvatures differ by a factor of 1000.
    fn rotated(x: &Array1<f64>) -> f64 {
        let u = (x[0] + x[1]) / 2.0_f64.sqrt();
        let v = (x[0] - x[1]) / 2.0_f64.sqrt();
        1_000.0 * u * u + v * v
    }

    #[test]
    fn cma_es_converges_on_sphere() {
        let config = CmaEsConfig {
            bounds: vec![(-5.0, 5.0); 4],
            maxeval: 5_000,
            seed: Some(42),
            target_f: 1e-10,
            ..Default::default()
        };

        let report = cma_es(&sphere, config).expect("CMA-ES should run");

        assert!(
            report.fun < 1e-6,
            "CMA-ES should converge near origin, got {}",
            report.fun
        );
    }

    #[test]
    fn cma_es_handles_coupled_rotated_quadratic() {
        let config = CmaEsConfig {
            bounds: vec![(-3.0, 3.0), (-3.0, 3.0)],
            maxeval: 4_000,
            seed: Some(7),
            target_f: 1e-9,
            ..Default::default()
        };

        let report = cma_es(&rotated, config).expect("CMA-ES should run");

        assert!(
            report.fun < 1e-5,
            "CMA-ES should solve rotated ill-conditioned quadratic, got {}",
            report.fun
        );
    }
}