use std::marker::PhantomData;
use crate::core::math::{
ComponentMulAssign, MatTransposeVec, MatVec, MatrixFromDiagonal, MatrixIdentity, NormSquared,
RankOneUpdate, SampleStandardNormal, Scalar, ScaleInPlace, ScaledAdd, SymmetricEigen,
VectorLen,
};
use crate::core::problem::{CostFunction, Problem};
use crate::core::rng::{ChaCha8Rng, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::CmaEsState;
use crate::core::termination::TerminationReason;
/// Covariance Matrix Adaptation Evolution Strategy (CMA-ES) solver.
///
/// `V` is the vector type, `M` the matrix type and `F` the scalar type
/// (defaulting to `f64`). Strategy constants are not known at construction
/// time; they are derived from the dimension of the mean when the solver is
/// initialised.
pub struct CmaEs<V, M, F = f64> {
    /// Seeded random generator that drives every offspring sample.
    rng: ChaCha8Rng,
    /// Population size requested through `with_lambda`; `None` selects the default.
    lambda_override: Option<usize>,
    /// Strategy constants; `None` until the solver has been initialised.
    constants: Option<CmaConstants<F>>,
    /// Binds the vector and matrix type parameters, which are not stored directly.
    _marker: PhantomData<(V, M)>,
}
/// Strategy constants derived once from the dimension `n` and the population
/// size `lambda`. `F` is the scalar type and defaults to `f64`.
pub(crate) struct CmaConstants<F = f64> {
    // Population layout.
    /// Problem dimension (length of the mean vector).
    pub(crate) n: usize,
    /// Number of offspring sampled per generation.
    pub(crate) lambda: usize,
    /// Number of leading (best-ranked) offspring used in the mean update, `lambda / 2`.
    pub(crate) mu: usize,

    // Recombination weights.
    /// One weight per ranked offspring: positive for the first `mu`, non-positive after.
    pub(crate) weights: Vec<F>,
    /// Variance-effective selection mass of the positive weights.
    pub(crate) mu_eff: F,
    /// Sum of all weights, including the negative ones.
    pub(crate) sum_w: F,

    // Step-size control.
    /// Learning rate of the step-size evolution path `p_sigma`.
    pub(crate) c_sigma: F,
    /// Damping applied to the logarithmic step-size update.
    pub(crate) d_sigma: F,
    /// Approximation of the expected norm of an `n`-dimensional standard normal vector.
    pub(crate) expected_norm: F,
    /// Bound on the normalised `|p_sigma|` above which the `p_c` update is stalled.
    pub(crate) h_sigma_threshold: F,

    // Covariance adaptation.
    /// Learning rate of the covariance evolution path `p_c`.
    pub(crate) c_c: F,
    /// Learning rate of the rank-one covariance update.
    pub(crate) c_1: F,
    /// Learning rate of the rank-mu covariance update.
    pub(crate) c_mu: F,
}
impl<V, M, F: Scalar> CmaEs<V, M, F> {
    /// Creates a solver whose random stream is seeded with `seed`.
    ///
    /// Until [`CmaEs::with_lambda`] is called, the population size falls back
    /// to [`CmaEs::default_lambda`] for the problem's dimension.
    pub fn new(seed: u64) -> Self {
        let rng = ChaCha8Rng::seed_from_u64(seed);
        Self {
            rng,
            constants: None,
            lambda_override: None,
            _marker: PhantomData,
        }
    }

    /// Overrides the population size (offspring per generation).
    ///
    /// # Panics
    /// Panics if `lambda < 4`.
    pub fn with_lambda(self, lambda: usize) -> Self {
        if lambda < 4 {
            panic!(
                "CmaEs requires lambda >= 4, got {} (Hansen 2016 footnote 30: \
                 smaller populations have strong adverse effects on performance)",
                lambda
            );
        }
        Self {
            lambda_override: Some(lambda),
            ..self
        }
    }

    /// Default population size `4 + floor(3 * ln(n))` for an `n`-dimensional problem.
    pub fn default_lambda(n: usize) -> usize {
        let log_term = (n as f64).ln() * 3.0;
        log_term.floor() as usize + 4
    }
}
/// Approximates the expected Euclidean norm of an `n`-dimensional standard
/// normal vector as `sqrt(n) * (1 - 1/(4n) + 1/(21 n^2))`.
///
/// `n` must be non-zero; otherwise the correction terms divide by zero.
pub(crate) fn expected_norm_n01<F: Scalar>(n: usize) -> F {
    let dim = F::from_usize(n).unwrap();
    let unit = F::one();
    let first_order = unit / (F::from_f64(4.0).unwrap() * dim);
    let second_order = unit / (F::from_f64(21.0).unwrap() * dim * dim);
    let correction = unit - first_order + second_order;
    dim.sqrt() * correction
}
/// Computes the active-CMA recombination weights for `lambda` ranked offspring.
///
/// Raw log-rank weights are `w'_i = ln((lambda + 1) / 2) - ln(i)` for
/// `i = 1..=lambda`. The best `lambda / 2` raw weights are normalised to sum
/// to one. The remaining ones are scaled so that their sum is `-alpha_neg`,
/// where `alpha_neg` is the smallest of `1 + c_1 / c_mu`,
/// `1 + 2 mu_eff_neg / (mu_eff + 2)` and `(1 - c_1 - c_mu) / (n c_mu)`.
///
/// Returns `(weights, mu_eff, sum_w)`, where `mu_eff` is the
/// variance-effective mass of the positive raw weights and `sum_w` is the
/// sum of all final weights.
pub(crate) fn compute_weights<F: Scalar>(
    n: usize,
    lambda: usize,
    c_1: F,
    c_mu: F,
) -> (Vec<F>, F, F) {
    let zero = F::zero();
    let one = F::one();
    let two = F::from_f64(2.0).unwrap();

    let log_half = ((F::from_usize(lambda).unwrap() + one) / two).ln();
    let raw: Vec<F> = (1..=lambda)
        .map(|rank| log_half - F::from_usize(rank).unwrap().ln())
        .collect();
    // `best` receives positive weights; `worst` the (active) negative ones.
    let (best, worst) = raw.split_at(lambda / 2);
    let squared_sum = |ws: &[F]| -> F { ws.iter().map(|w| *w * *w).sum() };

    let best_sum: F = best.iter().copied().sum();
    let mu_eff = best_sum * best_sum / squared_sum(best);

    let worst_sum: F = worst.iter().copied().sum();
    let worst_sq = squared_sum(worst);
    let mu_eff_neg = if worst_sq > zero {
        worst_sum * worst_sum / worst_sq
    } else {
        zero
    };

    let dim = F::from_usize(n).unwrap();
    let alpha_neg = (one + c_1 / c_mu)
        .min(one + two * mu_eff_neg / (mu_eff + two))
        .min((one - c_1 - c_mu) / (dim * c_mu));

    let worst_abs_sum: F = worst.iter().map(|w| -*w).sum();
    let positive = best.iter().map(|&w| w / best_sum);
    let negative = worst.iter().map(|&w| {
        if worst_abs_sum > zero {
            alpha_neg * w / worst_abs_sum
        } else {
            zero
        }
    });
    let weights: Vec<F> = positive.chain(negative).collect();
    let sum_w: F = weights.iter().copied().sum();
    (weights, mu_eff, sum_w)
}
/// Derives every CMA-ES strategy constant for dimension `n` and population `lambda`.
///
/// The learning rates `c_1` and `c_mu` must be known before the final weights
/// can be computed, yet both depend on `mu_eff`. So `mu_eff` is first
/// evaluated from the positive raw log-rank weights (the same quantity
/// `compute_weights` returns) and used to size the learning rates.
pub(crate) fn compute_constants<F: Scalar>(n: usize, lambda: usize) -> CmaConstants<F> {
    let zero = F::zero();
    let one = F::one();
    let two = F::from_f64(2.0).unwrap();
    let four = F::from_f64(4.0).unwrap();
    let dim = F::from_usize(n).unwrap();
    let mu = lambda / 2;
    let alpha_cov = two;

    // Provisional mu_eff from the positive raw weights.
    let log_half = ((F::from_usize(lambda).unwrap() + one) / two).ln();
    let raw: Vec<F> = (1..=lambda)
        .map(|rank| log_half - F::from_usize(rank).unwrap().ln())
        .collect();
    let head = &raw[..mu];
    let head_sum: F = head.iter().copied().sum();
    let head_sq: F = head.iter().map(|w| *w * *w).sum();
    let mu_eff_est = head_sum * head_sum / head_sq;

    // Covariance learning rates.
    let shifted = dim + F::from_f64(1.3).unwrap();
    let c_1 = alpha_cov / (shifted * shifted + mu_eff_est);
    let dim_plus_two = dim + two;
    let c_mu_raw = alpha_cov * (mu_eff_est - two + one / mu_eff_est)
        / (dim_plus_two * dim_plus_two + alpha_cov * mu_eff_est / two);
    let c_mu = (one - c_1).min(c_mu_raw);

    let (weights, mu_eff, sum_w) = compute_weights::<F>(n, lambda, c_1, c_mu);

    // Step-size control and path constants.
    let c_sigma = (mu_eff + two) / (dim + mu_eff + F::from_f64(5.0).unwrap());
    let damping_excess = ((mu_eff - one) / (dim + one)).sqrt() - one;
    let d_sigma = one + two * damping_excess.max(zero) + c_sigma;
    let c_c = (four + mu_eff / dim) / (dim + four + two * mu_eff / dim);
    let expected_norm = expected_norm_n01::<F>(n);
    let h_sigma_threshold = (F::from_f64(1.4).unwrap() + two / (dim + one)) * expected_norm;

    CmaConstants {
        n,
        lambda,
        mu,
        weights,
        mu_eff,
        sum_w,
        c_sigma,
        d_sigma,
        c_c,
        c_1,
        c_mu,
        expected_norm,
        h_sigma_threshold,
    }
}
/// Replaces `state.candidates` with `lambda` freshly sampled offspring.
///
/// Each offspring is `m + sigma * B (d ∘ z)`. Here `z` comes from
/// `V::sample_standard_normal` (with `state.m` passed as the template
/// argument), `d` is scaled component-wise into it, and the result is
/// mapped through `state.b`.
pub(crate) fn sample_generation<V, M, F>(
    state: &mut CmaEsState<V, M, F>,
    lambda: usize,
    rng: &mut ChaCha8Rng,
) where
    F: Scalar,
    V: VectorLen + Clone + ScaledAdd<F> + ComponentMulAssign + SampleStandardNormal,
    M: MatVec<V>,
{
    let step_size = state.sigma;
    state.candidates.clear();
    for _ in 0..lambda {
        let mut scaled_noise = V::sample_standard_normal(&state.m, rng);
        scaled_noise.component_mul_assign(&state.d);
        let direction = state.b.matvec(&scaled_noise);
        let mut offspring = state.m.clone();
        offspring.scaled_add(step_size, &direction);
        state.candidates.push(offspring);
    }
}
/// Reorders `candidates` and `costs` together so that `costs` is ascending.
///
/// The sort is stable, so candidates with equal cost keep their sampling
/// order. Costs that are not comparable even with themselves (for floats:
/// NaN) rank after every comparable cost. This keeps the comparator a true
/// total order, which `sort_by` requires and may panic without. It also
/// ensures a failed (NaN) evaluation can never be ranked best and receive
/// the largest positive recombination weight.
///
/// In debug builds, panics if the two slices differ in length.
pub(crate) fn sort_population_ascending<V, F: PartialOrd>(candidates: &mut [V], costs: &mut [F]) {
    let n = candidates.len();
    debug_assert_eq!(n, costs.len());
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&i, &j| cmp_unordered_last(&costs[i], &costs[j]));
    apply_permutation(candidates, &idx);
    apply_permutation(costs, &idx);
}

/// Total-order comparison that places self-incomparable values (e.g. NaN) last.
///
/// Comparable pairs use `partial_cmp`. When a pair is incomparable, a value
/// that is incomparable with itself is "greater", and two such values are
/// equal.
fn cmp_unordered_last<F: PartialOrd>(a: &F, b: &F) -> std::cmp::Ordering {
    a.partial_cmp(b).unwrap_or_else(|| {
        let a_unordered = a.partial_cmp(a).is_none();
        let b_unordered = b.partial_cmp(b).is_none();
        a_unordered.cmp(&b_unordered)
    })
}

/// Permutes `slice` in place so that afterwards `slice[k]` holds the element
/// previously at `slice[idx[k]]`.
///
/// `idx` must be a permutation of `0..slice.len()`. Each cycle of the
/// permutation is resolved with swaps, so no element is cloned and no extra
/// buffer of `T` is needed.
fn apply_permutation<T>(slice: &mut [T], idx: &[usize]) {
    let mut visited = vec![false; slice.len()];
    for start in 0..slice.len() {
        if visited[start] || idx[start] == start {
            visited[start] = true;
            continue;
        }
        // Walk the cycle start -> idx[start] -> ..., pulling each target
        // element into place one swap at a time.
        let mut current = start;
        loop {
            let next = idx[current];
            visited[current] = true;
            if next == start {
                break;
            }
            slice.swap(current, next);
            current = next;
        }
    }
}
impl<P, V, M, F> Solver<P, CmaEsState<V, M, F>> for CmaEs<V, M, F>
where
    F: Scalar + crate::core::parallel::MaybeSend,
    P: CostFunction<Param = V, Output = F> + crate::core::parallel::MaybeSync,
    P::Error: crate::core::parallel::MaybeSend,
    V: VectorLen
        + Clone
        + ScaledAdd<F>
        + ScaleInPlace<F>
        + ComponentMulAssign
        + NormSquared<F>
        + SampleStandardNormal
        + crate::core::parallel::MaybeSync
        + std::ops::Index<usize, Output = F>
        + std::ops::IndexMut<usize, Output = F>,
    M: MatrixIdentity
        + MatrixFromDiagonal<V>
        + MatVec<V>
        + MatTransposeVec<V>
        + ScaleInPlace<F>
        + RankOneUpdate<V, F>
        + SymmetricEigen<V>
        + Clone,
{
    type Error = P::Error;

    /// Prepares the solver for `state` and evaluates the first population.
    ///
    /// Strategy constants are computed when none exist yet, or when the
    /// cached ones were derived for a different dimension. This way a solver
    /// value reused on a differently sized problem does not index with stale
    /// sizes. If `state` has no candidates yet, `lambda` offspring are
    /// sampled, evaluated and sorted by ascending cost, and the cost at the
    /// mean is stored in `state.m_cost`.
    ///
    /// # Panics
    /// Panics if the mean `state.m` is empty.
    ///
    /// # Errors
    /// Propagates any error returned by the cost function.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: CmaEsState<V, M, F>,
    ) -> Result<CmaEsState<V, M, F>, Self::Error> {
        let n = state.m.vec_len();
        assert!(n >= 1, "CmaEs requires a non-empty mean");
        let reusable = matches!(&self.constants, Some(k) if k.n == n);
        if !reusable {
            let lambda = self
                .lambda_override
                .unwrap_or_else(|| Self::default_lambda(n));
            self.constants = Some(compute_constants::<F>(n, lambda));
        }
        let lambda = self
            .constants
            .as_ref()
            .map(|k| k.lambda)
            .expect("CmaEs constants are initialised above");
        if state.candidates.is_empty() {
            sample_generation(&mut state, lambda, &mut self.rng);
            state.costs = problem.cost_batch(&state.candidates)?;
            sort_population_ascending(&mut state.candidates, &mut state.costs);
            let m_cost = problem.cost(&state.m)?;
            state.m_cost = Some(m_cost);
        }
        Ok(state)
    }

    /// Performs one CMA-ES generation update, then samples the next population.
    ///
    /// Expects `state.candidates` and `state.costs` to hold the current
    /// population sorted by ascending cost, as left by `init` or the previous
    /// call. The update proceeds in this order:
    /// 1. recombine the mean from the `mu` best steps `y_i = (x_i - m) / sigma`;
    /// 2. update the step-size path `p_sigma` (using `C^{-1/2} = B D^{-1} B^T`)
    ///    and the step size `sigma`;
    /// 3. update the covariance path `p_c`, stalled when `h_sigma == 0`;
    /// 4. apply the rank-one and weighted rank-mu updates to `C`, rescaling
    ///    negative weights by `n / |C^{-1/2} y_i|^2`;
    /// 5. eigendecompose `C` into `B` and `d` (eigenvalues floored at 1e-30);
    /// 6. sample, evaluate and sort the next population, and re-evaluate the
    ///    cost at the new mean.
    ///
    /// Returns `Some(TerminationReason::SolverFailed)` if the
    /// eigendecomposition fails.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    ///
    /// # Errors
    /// Propagates any error returned by the cost function.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: CmaEsState<V, M, F>,
    ) -> Result<(CmaEsState<V, M, F>, Option<TerminationReason>), Self::Error> {
        let k = self
            .constants
            .as_ref()
            .expect("CmaEs::init must run before next_iter");
        state.generation += 1;
        let one = F::one();
        let two = F::from_f64(2.0).unwrap();
        let zero = F::zero();

        // Steps of the ranked offspring, measured from the *old* mean and step size.
        let inv_sigma = one / state.sigma;
        let y_sorted: Vec<V> = state
            .candidates
            .iter()
            .map(|x| {
                let mut y = x.clone();
                y.scaled_add(-one, &state.m);
                y.scale_in_place(inv_sigma);
                y
            })
            .collect();

        // Weighted recombination of the `mu` best steps (positive weights only).
        let mut y_w = state.m.clone();
        y_w.scale_in_place(zero);
        for (i, y_i) in y_sorted.iter().enumerate().take(k.mu) {
            y_w.scaled_add(k.weights[i], y_i);
        }
        state.m.scaled_add(state.sigma, &y_w);

        // Step-size path: p_sigma <- (1 - c_sigma) p_sigma + sqrt(...) C^{-1/2} y_w.
        let mut bt_y_w = state.b.mat_transpose_vec(&y_w);
        bt_y_w.component_mul_assign(&state.d_inv);
        let c_invsqrt_y_w = state.b.matvec(&bt_y_w);
        state.p_sigma.scale_in_place(one - k.c_sigma);
        let coef_sigma = (k.c_sigma * (two - k.c_sigma) * k.mu_eff).sqrt();
        state.p_sigma.scaled_add(coef_sigma, &c_invsqrt_y_w);
        let p_sigma_norm = state.p_sigma.norm_squared().sqrt();
        let log_factor = (k.c_sigma / k.d_sigma) * (p_sigma_norm / k.expected_norm - one);
        state.sigma = state.sigma * log_factor.exp();

        // h_sigma compares |p_sigma^{(g+1)}| / sqrt(1 - (1 - c_sigma)^{2(g+1)}) with the
        // threshold, where g indexes the generation whose offspring were just ranked.
        // `state.generation` was already advanced above, so it *is* g + 1 (assuming the
        // counter starts at 0 before the first update). Clamp instead of wrapping for
        // absurdly long runs; the power has long since underflowed to 0 by then.
        let g_plus_one =
            <i32 as std::convert::TryFrom<_>>::try_from(state.generation).unwrap_or(i32::MAX);
        let exponent = g_plus_one.saturating_mul(2);
        let denom = (one - (one - k.c_sigma).powi(exponent)).sqrt();
        let h_sigma = if p_sigma_norm / denom < k.h_sigma_threshold {
            one
        } else {
            zero
        };

        // Covariance path; stalled (no y_w contribution) when h_sigma == 0.
        state.p_c.scale_in_place(one - k.c_c);
        let coef_c = h_sigma * (k.c_c * (two - k.c_c) * k.mu_eff).sqrt();
        state.p_c.scaled_add(coef_c, &y_w);

        // Covariance: decay, rank-one update, then weighted rank-mu update.
        let delta_h = (one - h_sigma) * k.c_c * (two - k.c_c);
        let c_scale = one + k.c_1 * delta_h - k.c_1 - k.c_mu * k.sum_w;
        state.c.scale_in_place(c_scale);
        state.c.rank_one_update(k.c_1, &state.p_c);
        let n_f = F::from_usize(k.n).unwrap();
        for (i, y_i) in y_sorted.iter().enumerate() {
            let w_i = k.weights[i];
            // Negative (active) weights are rescaled by n / |C^{-1/2} y_i|^2, using the
            // still-unchanged B and D of the distribution the y_i were drawn from.
            let w_i_o = if w_i >= zero {
                w_i
            } else {
                let mut bt_y = state.b.mat_transpose_vec(y_i);
                bt_y.component_mul_assign(&state.d_inv);
                let cinv_norm_sq = bt_y.norm_squared();
                if cinv_norm_sq > zero {
                    w_i * n_f / cinv_norm_sq
                } else {
                    zero
                }
            };
            if w_i_o != zero {
                state.c.rank_one_update(k.c_mu * w_i_o, y_i);
            }
        }
        // Release the step buffers before the eigendecomposition.
        drop(y_sorted);

        let (b_new, eigs) = match state.c.try_eigh() {
            Ok(pair) => pair,
            Err(_) => return Ok((state, Some(TerminationReason::SolverFailed))),
        };
        state.b = b_new;
        // Floor eigenvalues so that d and d_inv stay finite and positive.
        let eig_floor = F::from_f64(1e-30).unwrap();
        for i in 0..k.n {
            let lam = eigs[i].max(eig_floor);
            let s = lam.sqrt();
            state.d[i] = s;
            state.d_inv[i] = one / s;
        }

        let lambda = k.lambda;
        sample_generation(&mut state, lambda, &mut self.rng);
        state.costs = problem.cost_batch(&state.candidates)?;
        sort_population_ascending(&mut state.candidates, &mut state.costs);
        let m_cost = problem.cost(&state.m)?;
        state.m_cost = Some(m_cost);
        Ok((state, None))
    }
}