use crate::core::math::{
ComponentMulAssign, MatTransposeVec, MatVec, MatrixFromDiagonal, MatrixIdentity, NormSquared,
RankOneUpdate, SampleStandardNormal, ScaleInPlace, ScaledAdd, SymmetricEigen, VectorLen,
};
use crate::core::problem::{CostFunction, Problem};
use crate::core::rng::{ChaCha8Rng, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::BasicPopulationState;
use crate::core::termination::TerminationReason;
/// Covariance Matrix Adaptation Evolution Strategy (CMA-ES) solver.
///
/// `V` is the parameter-vector type and `M` the matrix type used for the
/// covariance and its eigenbasis. Build with [`CmaEs::new`] and refine with the
/// `with_*` builders; the working state is only created when `init` runs.
pub struct CmaEs<V, M> {
    /// Working state; `None` until `init` has run.
    state: Option<Working<V, M>>,
    /// Mean of the initial search distribution.
    initial_mean: V,
    /// Initial global step size (checked to be > 0 by `new`).
    initial_sigma: f64,
    /// Seed for the internal RNG.
    seed: u64,
    /// Population size chosen via `with_lambda`; `None` means `default_lambda`.
    lambda_override: Option<usize>,
    /// Per-coordinate initial standard deviations chosen via `with_stds`.
    stds_override: Option<V>,
    /// Convergence threshold chosen via `with_tol_x`; `None` derives one from
    /// `initial_sigma` and the largest initial std.
    tol_x_override: Option<f64>,
}
/// Mutable CMA-ES state together with the strategy constants that are derived
/// once from the dimension and population size.
pub(crate) struct Working<V, M> {
    // ---- Search distribution -------------------------------------------
    /// Current mean of the search distribution.
    pub(crate) m: V,
    /// Current global step size.
    pub(crate) sigma: f64,
    /// Covariance matrix.
    c: M,
    /// Eigenbasis of `c` (as returned by `try_eigh`); identity initially.
    pub(crate) b: M,
    /// Per-axis scale factors: square roots of the eigenvalues of `c`.
    d: V,
    /// Element-wise reciprocal of `d`.
    pub(crate) d_inv: V,
    /// Evolution path driving step-size adaptation.
    p_sigma: V,
    /// Evolution path driving the rank-one covariance update.
    p_c: V,

    // ---- Sizes ----------------------------------------------------------
    /// Search-space dimension.
    pub(crate) n: usize,
    /// Population size (lambda).
    lambda: usize,
    /// Number of positively weighted parents (`lambda / 2`).
    mu: usize,

    // ---- Recombination weights -----------------------------------------
    /// One weight per rank; positive for the best `mu`, non-positive after.
    weights: Vec<f64>,
    /// Variance-effective selection mass of the positive weights.
    mu_eff: f64,
    /// Sum of all weights (positive and negative).
    sum_w: f64,

    // ---- Learning rates and damping ------------------------------------
    c_sigma: f64,
    d_sigma: f64,
    c_c: f64,
    c_1: f64,
    c_mu: f64,

    // ---- Thresholds -----------------------------------------------------
    /// Approximation of E||N(0, I)|| in `n` dimensions.
    expected_norm: f64,
    /// Threshold used to decide the h_sigma stall indicator.
    h_sigma_threshold: f64,
    /// Terminate once `sigma * max(d)` drops below this value.
    tol_x: f64,

    // ---- Bookkeeping ----------------------------------------------------
    /// Deterministically seeded sampler.
    rng: ChaCha8Rng,
    /// Number of completed `next_iter` updates.
    generation: u64,
}
impl<V, M> CmaEs<V, M> {
    /// Creates a CMA-ES solver centred on `initial_mean` with global step size
    /// `initial_sigma`, seeding its internal RNG with `seed`.
    ///
    /// # Panics
    /// Panics unless `initial_sigma` is finite and strictly positive. NaN fails
    /// the `> 0` test. An infinite step size would make every sampled candidate
    /// infinite, so it is rejected as well.
    pub fn new(initial_mean: V, initial_sigma: f64, seed: u64) -> Self {
        assert!(
            initial_sigma.is_finite() && initial_sigma > 0.0,
            "CmaEs requires a finite initial_sigma > 0, got {}",
            initial_sigma
        );
        Self {
            initial_mean,
            initial_sigma,
            lambda_override: None,
            seed,
            tol_x_override: None,
            stds_override: None,
            state: None,
        }
    }

    /// Overrides the population size (default: [`CmaEs::default_lambda`]).
    ///
    /// # Panics
    /// Panics if `lambda < 4`.
    #[must_use]
    pub fn with_lambda(mut self, lambda: usize) -> Self {
        assert!(
            lambda >= 4,
            "CmaEs requires lambda >= 4, got {} (Hansen 2016 footnote 30: \
             smaller populations have strong adverse effects on performance)",
            lambda
        );
        self.lambda_override = Some(lambda);
        self
    }

    /// Overrides the convergence threshold on `sigma * max(d)`.
    #[must_use]
    pub fn with_tol_x(mut self, tol_x: f64) -> Self {
        self.tol_x_override = Some(tol_x);
        self
    }

    /// Default population size `4 + floor(3 ln n)` for dimension `n`.
    pub fn default_lambda(n: usize) -> usize {
        4 + (3.0 * (n as f64).ln()).floor() as usize
    }

    /// Working state, available once `init` has run.
    pub(crate) fn working(&self) -> Option<&Working<V, M>> {
        self.state.as_ref()
    }
}
impl<V, M> CmaEs<V, M>
where
    V: VectorLen + std::ops::Index<usize, Output = f64>,
{
    /// Sets per-coordinate initial standard deviations. With this option the
    /// initial covariance becomes `diag(stds^2)` instead of the identity.
    ///
    /// # Panics
    /// Panics if `stds` and the initial mean have different lengths. It also
    /// panics if any entry is not finite and strictly positive.
    #[must_use]
    pub fn with_stds(mut self, stds: V) -> Self {
        let n = self.initial_mean.vec_len();
        assert_eq!(
            stds.vec_len(),
            n,
            "CmaEs::with_stds requires stds.len() == initial_mean.len(), got {} vs {}",
            stds.vec_len(),
            n
        );
        for i in 0..n {
            let s = stds[i];
            // Infinities are rejected as well as NaN and non-positive values.
            // An infinite std puts `inf` on the covariance diagonal. It also makes
            // the derived default tol_x infinite, so the run would stop at once.
            assert!(
                s.is_finite() && s > 0.0,
                "CmaEs::with_stds requires every std to be finite and > 0, got stds[{}] = {}",
                i,
                s
            );
        }
        self.stds_override = Some(stds);
        self
    }
}
/// Series approximation of `E||N(0, I_n)||`:
/// `sqrt(n) * (1 - 1/(4n) + 1/(21 n^2))`.
pub(crate) fn expected_norm_n01(n: usize) -> f64 {
    let dim = n as f64;
    let correction = 1.0 - 1.0 / (4.0 * dim) + 1.0 / (21.0 * dim * dim);
    dim.sqrt() * correction
}
/// Default recombination weights for a population of `lambda` in dimension `n`.
///
/// Raw weights are `ln((lambda + 1) / 2) - ln(rank)`. The first
/// `mu = lambda / 2` ranks are normalised to sum to 1. The remaining ranks are
/// non-positive "active" weights, rescaled so that they sum to `-alpha_neg`.
/// Here `alpha_neg` is the smallest of three bounds:
/// - `1 + c_1 / c_mu`
/// - `1 + 2 mu_eff_neg / (mu_eff + 2)`
/// - `(1 - c_1 - c_mu) / (n c_mu)`
///
/// Returns `(weights, mu_eff, sum_w)`:
/// - `weights`: one weight per rank, in rank order.
/// - `mu_eff`: the variance-effective mass of the positive part.
/// - `sum_w`: the sum of all weights.
pub(crate) fn compute_weights(
    n: usize,
    lambda: usize,
    c_1: f64,
    c_mu: f64,
) -> (Vec<f64>, f64, f64) {
    let mu = lambda / 2;
    let log_mid = ((lambda as f64 + 1.0) / 2.0).ln();
    let raw: Vec<f64> = (1..=lambda)
        .map(|rank| log_mid - (rank as f64).ln())
        .collect();
    let (best, worst) = raw.split_at(mu);

    let sum_of = |xs: &[f64]| -> f64 { xs.iter().sum() };
    let sum_sq_of = |xs: &[f64]| -> f64 { xs.iter().map(|x| x * x).sum() };

    // Positive part: selection mass of the mu best ranks.
    let best_sum = sum_of(best);
    let mu_eff = best_sum.powi(2) / sum_sq_of(best);

    // Negative part: its own effective mass feeds one of the alpha bounds.
    let worst_sum = sum_of(worst);
    let worst_sum_sq = sum_sq_of(worst);
    let mu_eff_neg = if worst_sum_sq > 0.0 {
        worst_sum.powi(2) / worst_sum_sq
    } else {
        0.0
    };

    let bound_cov = 1.0 + c_1 / c_mu;
    let bound_mueff = 1.0 + 2.0 * mu_eff_neg / (mu_eff + 2.0);
    let bound_posdef = (1.0 - c_1 - c_mu) / (n as f64 * c_mu);
    let alpha_neg = bound_cov.min(bound_mueff).min(bound_posdef);

    // The worst raw weights are all <= 0, so their absolute sum is -worst_sum.
    let worst_abs_sum = -worst_sum;
    let weights: Vec<f64> = raw
        .iter()
        .enumerate()
        .map(|(rank, &r)| {
            if rank < mu {
                r / best_sum
            } else if worst_abs_sum > 0.0 {
                alpha_neg * r / worst_abs_sum
            } else {
                0.0
            }
        })
        .collect();

    let sum_w: f64 = weights.iter().sum();
    (weights, mu_eff, sum_w)
}
impl<V, M> CmaEs<V, M>
where
    V: VectorLen + Clone + ComponentMulAssign + std::ops::Index<usize, Output = f64>,
    M: MatrixIdentity + MatrixFromDiagonal<V>,
{
    /// Derives all strategy constants from the dimension `n` and the population
    /// size, and builds the initial `Working` state.
    ///
    /// `p_sigma`, `p_c`, `d` and `d_inv` are cloned from `initial_mean` only to get
    /// vectors of the right length. Their contents are NOT meaningful yet:
    /// `init` zeroes the two paths and overwrites `d` / `d_inv` before use.
    ///
    /// # Panics
    /// Panics if the initial mean is empty.
    fn build_working(&self) -> Working<V, M> {
        let n = self.initial_mean.vec_len();
        assert!(n >= 1, "CmaEs requires the initial mean to be non-empty");
        let lambda = self
            .lambda_override
            .unwrap_or_else(|| Self::default_lambda(n));
        let mu = lambda / 2;
        let alpha_cov = 2.0;
        // Same raw weights and positive-part mu_eff formula as `compute_weights`.
        // They are needed here first because c_1 / c_mu depend on mu_eff, and
        // `compute_weights` in turn takes c_1 / c_mu as inputs.
        let raw: Vec<f64> = (1..=lambda)
            .map(|i| ((lambda as f64 + 1.0) / 2.0).ln() - (i as f64).ln())
            .collect();
        let sum_pos: f64 = raw[..mu].iter().sum();
        let mu_eff_provisional = sum_pos.powi(2) / raw[..mu].iter().map(|w| w * w).sum::<f64>();
        // Rank-one learning rate.
        let c_1 = alpha_cov / ((n as f64 + 1.3).powi(2) + mu_eff_provisional);
        // Rank-mu learning rate, capped so that c_1 + c_mu <= 1.
        let c_mu_unbounded = alpha_cov * (mu_eff_provisional - 2.0 + 1.0 / mu_eff_provisional)
            / ((n as f64 + 2.0).powi(2) + alpha_cov * mu_eff_provisional / 2.0);
        let c_mu = (1.0 - c_1).min(c_mu_unbounded);
        let (weights, mu_eff, sum_w) = compute_weights(n, lambda, c_1, c_mu);
        // Step-size path learning rate and damping.
        let c_sigma = (mu_eff + 2.0) / (n as f64 + mu_eff + 5.0);
        let d_sigma = {
            let inner = ((mu_eff - 1.0) / (n as f64 + 1.0)).sqrt() - 1.0;
            1.0 + 2.0 * inner.max(0.0) + c_sigma
        };
        // Covariance path learning rate.
        let c_c = (4.0 + mu_eff / n as f64) / (n as f64 + 4.0 + 2.0 * mu_eff / n as f64);
        let expected_norm = expected_norm_n01(n);
        // next_iter compares the bias-corrected |p_sigma| against this threshold
        // to decide h_sigma.
        let h_sigma_threshold = (1.4 + 2.0 / (n as f64 + 1.0)) * expected_norm;
        // Largest initial std, or 1.0 when no per-coordinate stds were given.
        let max_std = self
            .stds_override
            .as_ref()
            .map(|s| (0..n).map(|i| s[i]).fold(0.0_f64, f64::max))
            .unwrap_or(1.0);
        // Default tolerance is relative to the initial overall scale.
        let tol_x = self
            .tol_x_override
            .unwrap_or(1e-12 * self.initial_sigma * max_std);
        // Initial covariance: diag(stds^2) if stds were given, else the identity.
        let c = match self.stds_override.as_ref() {
            Some(stds) => {
                let mut sq = stds.clone();
                sq.component_mul_assign(stds);
                M::from_diagonal(&sq)
            }
            None => M::identity(n),
        };
        Working {
            n,
            lambda,
            mu,
            weights,
            mu_eff,
            sum_w,
            c_sigma,
            d_sigma,
            c_c,
            c_1,
            c_mu,
            expected_norm,
            h_sigma_threshold,
            tol_x,
            m: self.initial_mean.clone(),
            sigma: self.initial_sigma,
            // Placeholders (see doc comment): overwritten by `init`.
            p_sigma: self.initial_mean.clone(),
            p_c: self.initial_mean.clone(),
            c,
            b: M::identity(n),
            d: self.initial_mean.clone(),
            d_inv: self.initial_mean.clone(),
            rng: ChaCha8Rng::seed_from_u64(self.seed),
            generation: 0,
        }
    }
}
/// Sorts `candidates` and `costs` together by ascending cost.
///
/// NaN costs (of either sign) are treated as worse than every number and are
/// placed last. Equal costs keep their original relative order, because the
/// sort is stable.
///
/// The comparator must be a genuine total order. Treating NaN as "equal to
/// everything" breaks transitivity, which makes the result arbitrary, and
/// since Rust 1.81 `sort_by` may panic on such comparators.
///
/// # Panics
/// Panics if the two slices have different lengths.
pub(crate) fn sort_population_ascending<V>(candidates: &mut [V], costs: &mut [f64]) {
    let n = candidates.len();
    assert_eq!(
        n,
        costs.len(),
        "sort_population_ascending requires one cost per candidate"
    );
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&i, &j| {
        let (a, b) = (costs[i], costs[j]);
        match (a.is_nan(), b.is_nan()) {
            // Both numeric: partial_cmp cannot fail here.
            (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
            (false, true) => std::cmp::Ordering::Less,
            (true, false) => std::cmp::Ordering::Greater,
            (true, true) => std::cmp::Ordering::Equal,
        }
    });
    apply_permutation(candidates, &idx);
    apply_permutation(costs, &idx);
}
/// Reorders `slice` in place so that afterwards `slice[k]` holds the element
/// that was at `idx[k]`. `idx` must be a permutation of `0..slice.len()`.
///
/// Works cycle by cycle with swaps, tracking finished positions in `done`.
fn apply_permutation<T>(slice: &mut [T], idx: &[usize]) {
    let len = slice.len();
    let mut done = vec![false; len];
    for head in 0..len {
        if done[head] {
            continue;
        }
        if idx[head] == head {
            // Fixed point: nothing to move.
            done[head] = true;
            continue;
        }
        // Walk the cycle through `head`. Each swap pulls the wanted element into
        // `pos` and pushes the original head element one step along the cycle.
        let mut pos = head;
        while idx[pos] != head {
            let target = idx[pos];
            done[pos] = true;
            slice.swap(pos, target);
            pos = target;
        }
        done[pos] = true;
    }
}
/// CMA-ES as a population-based [`Solver`]. Each generation performs:
/// - weighted recombination of the mean;
/// - cumulative step-size adaptation;
/// - a rank-one plus weighted rank-mu covariance update, with negative
///   ("active") weights for the worse half.
///
/// Both `init` and `next_iter` leave `state` holding the lambda freshly sampled
/// candidates, sorted by ascending cost. `next_iter` relies on receiving the
/// population in exactly that form.
impl<P, V, M> Solver<P, BasicPopulationState<V>> for CmaEs<V, M>
where
    P: CostFunction<Param = V, Output = f64>,
    V: VectorLen
        + Clone
        + ScaledAdd<f64>
        + ScaleInPlace
        + ComponentMulAssign
        + NormSquared
        + SampleStandardNormal
        + std::ops::Index<usize, Output = f64>
        + std::ops::IndexMut<usize, Output = f64>,
    M: MatrixIdentity
        + MatrixFromDiagonal<V>
        + MatVec<V>
        + MatTransposeVec<V>
        + ScaleInPlace
        + RankOneUpdate<V>
        + SymmetricEigen<V>
        + Clone,
{
    type Error = P::Error;

    /// Builds the working state and evaluates the first population.
    ///
    /// If the working state already exists, `state` is returned untouched.
    ///
    /// # Errors
    /// Propagates the first error returned by `problem.cost`. In that case the
    /// working state is not stored.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicPopulationState<V>,
    ) -> Result<BasicPopulationState<V>, Self::Error> {
        if self.state.is_some() {
            return Ok(state);
        }
        let mut w = self.build_working();
        w.p_sigma.scale_in_place(0.0);
        w.p_c.scale_in_place(0.0);
        if let Some(stds) = self.stds_override.as_ref() {
            for i in 0..w.n {
                w.d[i] = stds[i];
                w.d_inv[i] = 1.0 / stds[i];
            }
        } else {
            for i in 0..w.n {
                w.d[i] = 1.0;
                w.d_inv[i] = 1.0;
            }
        }
        let anisotropic = self.stds_override.is_some();
        state.candidates.clear();
        state.costs.clear();
        for _k in 0..w.lambda {
            let z_k = V::sample_standard_normal(&w.m, &mut w.rng);
            let mut x_k = w.m.clone();
            if anisotropic {
                // x = m + sigma * B (d ∘ z)
                let mut bd_z = z_k;
                bd_z.component_mul_assign(&w.d);
                let bd_z = w.b.matvec(&bd_z);
                x_k.scaled_add(w.sigma, &bd_z);
            } else {
                // Isotropic start: B = I and d = 1, so x = m + sigma * z.
                x_k.scaled_add(w.sigma, &z_k);
            }
            let cost = problem.cost(&x_k)?;
            state.candidates.push(x_k);
            state.costs.push(cost);
        }
        sort_population_ascending(&mut state.candidates, &mut state.costs);
        self.state = Some(w);
        Ok(state)
    }

    /// Updates the distribution from the sorted population in `state`, then
    /// samples and evaluates a new population.
    ///
    /// Returns `TerminationReason::SolverFailed` if the eigendecomposition of
    /// the covariance fails.
    ///
    /// # Errors
    /// Propagates the first error returned by `problem.cost`.
    ///
    /// # Panics
    /// Panics if `init` has not run.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicPopulationState<V>,
    ) -> Result<(BasicPopulationState<V>, Option<TerminationReason>), Self::Error> {
        let w = self
            .state
            .as_mut()
            .expect("CmaEs::init must run before next_iter");
        w.generation += 1;

        // Steps of the (cost-sorted) candidates from the current mean, in units of sigma.
        let y_sorted: Vec<V> = state
            .candidates
            .iter()
            .map(|x| {
                let mut y = x.clone();
                y.scaled_add(-1.0, &w.m);
                y.scale_in_place(1.0 / w.sigma);
                y
            })
            .collect();

        // Weighted recombination of the mu best steps, then move the mean.
        let mut y_w = w.m.clone();
        y_w.scale_in_place(0.0);
        for (i, y_i) in y_sorted.iter().enumerate().take(w.mu) {
            y_w.scaled_add(w.weights[i], y_i);
        }
        w.m.scaled_add(w.sigma, &y_w);

        // Step-size path, using C^{-1/2} y_w = B D^{-1} B^T y_w (old B, D).
        let mut bt_y_w = w.b.mat_transpose_vec(&y_w);
        bt_y_w.component_mul_assign(&w.d_inv);
        let c_invsqrt_y_w = w.b.matvec(&bt_y_w);
        w.p_sigma.scale_in_place(1.0 - w.c_sigma);
        let coef_sigma = (w.c_sigma * (2.0 - w.c_sigma) * w.mu_eff).sqrt();
        w.p_sigma.scaled_add(coef_sigma, &c_invsqrt_y_w);

        // Cumulative step-size adaptation.
        let p_sigma_norm = w.p_sigma.norm_squared().sqrt();
        let log_factor = (w.c_sigma / w.d_sigma) * (p_sigma_norm / w.expected_norm - 1.0);
        w.sigma *= log_factor.exp();

        // h_sigma stalls the p_c update while |p_sigma| is unusually large.
        // After k updates from zero, p_sigma has expected squared length
        // n * (1 - (1 - c_sigma)^(2k)), so |p_sigma| is normalised by the root
        // of that factor. `generation` was incremented above and therefore
        // already equals k; using k + 1 here would overestimate the expected
        // length in the first generations.
        let exponent = i32::try_from(w.generation.saturating_mul(2)).unwrap_or(i32::MAX);
        let denom = (1.0 - (1.0 - w.c_sigma).powi(exponent)).sqrt();
        let h_sigma = if p_sigma_norm / denom < w.h_sigma_threshold {
            1.0
        } else {
            0.0
        };

        // Covariance path.
        w.p_c.scale_in_place(1.0 - w.c_c);
        let coef_c = h_sigma * (w.c_c * (2.0 - w.c_c) * w.mu_eff).sqrt();
        w.p_c.scaled_add(coef_c, &y_w);

        // Covariance update: decay, rank-one term, weighted rank-mu term.
        let delta_h = (1.0 - h_sigma) * w.c_c * (2.0 - w.c_c);
        let c_scale = 1.0 + w.c_1 * delta_h - w.c_1 - w.c_mu * w.sum_w;
        w.c.scale_in_place(c_scale);
        w.c.rank_one_update(w.c_1, &w.p_c);
        for (i, y_i) in y_sorted.iter().enumerate() {
            let w_i = w.weights[i];
            // Negative weights are rescaled by n / |C^{-1/2} y_i|^2 (old B, D).
            let w_i_o = if w_i >= 0.0 {
                w_i
            } else {
                let mut bt_y = w.b.mat_transpose_vec(y_i);
                bt_y.component_mul_assign(&w.d_inv);
                let cinv_norm_sq = bt_y.norm_squared();
                if cinv_norm_sq > 0.0 {
                    w_i * (w.n as f64) / cinv_norm_sq
                } else {
                    0.0
                }
            };
            if w_i_o != 0.0 {
                w.c.rank_one_update(w.c_mu * w_i_o, y_i);
            }
        }
        // The steps are no longer needed; release them before the decomposition.
        drop(y_sorted);

        // Refresh B and D from the new covariance (eigenvalues clamped to stay positive).
        let (b_new, eigs) = match w.c.try_eigh() {
            Ok(pair) => pair,
            Err(_) => return Ok((state, Some(TerminationReason::SolverFailed))),
        };
        w.b = b_new;
        for i in 0..w.n {
            let lam = eigs[i].max(1e-30);
            let s = lam.sqrt();
            w.d[i] = s;
            w.d_inv[i] = 1.0 / s;
        }

        // Sample and evaluate the next population: x = m + sigma * B (d ∘ z).
        state.candidates.clear();
        state.costs.clear();
        for _k in 0..w.lambda {
            let z_k = V::sample_standard_normal(&w.m, &mut w.rng);
            let mut bd_z = z_k;
            bd_z.component_mul_assign(&w.d);
            let bd_z = w.b.matvec(&bd_z);
            let mut x_k = w.m.clone();
            x_k.scaled_add(w.sigma, &bd_z);
            let cost = problem.cost(&x_k)?;
            state.candidates.push(x_k);
            state.costs.push(cost);
        }
        sort_population_ascending(&mut state.candidates, &mut state.costs);
        Ok((state, None))
    }

    /// Reports convergence once the largest axis length `sigma * max(d)` of the
    /// search distribution falls below `tol_x`. Returns `None` before `init`.
    fn terminate(&self, _state: &BasicPopulationState<V>) -> Option<TerminationReason> {
        let w = self.state.as_ref()?;
        let max_d = w.d_iter_max();
        if w.sigma * max_d < w.tol_x {
            return Some(TerminationReason::SolverConverged);
        }
        None
    }
}
impl<V, M> Working<V, M>
where
    V: std::ops::Index<usize, Output = f64> + VectorLen,
{
    /// Largest of the first `n` entries of `d`, floored at 0.0.
    ///
    /// A strict `>` comparison is used, so NaN entries never replace the
    /// running maximum.
    fn d_iter_max(&self) -> f64 {
        (0..self.n)
            .map(|i| self.d[i])
            .fold(0.0_f64, |best, v| if v > best { v } else { best })
    }
}