use std::collections::VecDeque;
use std::marker::PhantomData;
use crate::core::constraint::BoxConstraints;
use crate::core::math::{
ClampInPlace, ComponentMulAssign, MatDiagonal, MatTransposeVec, MatVec, MatrixFromDiagonal,
MatrixIdentity, NormSquared, RankOneUpdate, SampleStandardNormal, Scalar, ScaleInPlace,
ScaledAdd, SymmetricEigen, VectorLen,
};
use crate::core::problem::{CostFunction, Problem};
use crate::core::rng::{ChaCha8Rng, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::CmaEsState;
use crate::core::state::cma_es::BoundPenalty;
use crate::core::termination::TerminationReason;
use super::cma_es::{CmaConstants, compute_constants, sort_population_ascending};
/// CMA-ES solver for box-constrained problems.
///
/// Candidates are clamped into the box before the cost function is called,
/// and a per-coordinate weighted squared distance between the sampled point
/// and its clamped copy is added to the cost as a penalty.
///
/// `V` is the parameter vector type, `M` the matrix type and `F` the scalar
/// type (defaults to `f64`).
pub struct BoundedCmaEs<V, M, F = f64> {
/// Population size requested through `with_lambda`; when `None`, `init`
/// falls back to `default_lambda`.
lambda_override: Option<usize>,
/// Strategy constants; `None` until `init` computes them.
constants: Option<BoundedCmaConstants<F>>,
/// Random number generator, seeded in `new`, used to sample candidates.
rng: ChaCha8Rng,
/// Ties the otherwise unused `V` and `M` type parameters to the struct.
_marker: PhantomData<(V, M)>,
}
/// Constants used by [`BoundedCmaEs`]: the plain CMA-ES constants plus the
/// ones that steer the boundary-penalty weight adaptation.
pub(crate) struct BoundedCmaConstants<F = f64> {
/// Constants produced by `compute_constants` for the given `n` and `lambda`.
cma: CmaConstants<F>,
/// Damping exponent scaling how fast penalty weights grow and decay.
damp: F,
/// Amount, in per-coordinate standard deviations of the search
/// distribution, by which the mean must violate a bound before that
/// coordinate's penalty weight is increased.
edist_threshold: F,
/// Maximum number of entries kept in the penalty's cost-spread history.
hist_cap: usize,
}
impl<V, M, F: Scalar> BoundedCmaEs<V, M, F> {
    /// Creates a solver whose random number generator is seeded with `seed`.
    ///
    /// No population size is fixed yet and the strategy constants are left
    /// unset; both are resolved when the solver is initialised.
    pub fn new(seed: u64) -> Self {
        let rng = ChaCha8Rng::seed_from_u64(seed);
        Self {
            rng,
            constants: None,
            lambda_override: None,
            _marker: PhantomData,
        }
    }

    /// Fixes the population size to `lambda` instead of the default.
    ///
    /// # Panics
    ///
    /// Panics when `lambda` is smaller than 4.
    pub fn with_lambda(self, lambda: usize) -> Self {
        assert!(
            lambda >= 4,
            "BoundedCmaEs requires lambda >= 4, got {} (Hansen 2016 footnote 30: \
             smaller populations have strong adverse effects on performance)",
            lambda
        );
        Self {
            lambda_override: Some(lambda),
            ..self
        }
    }

    /// Default population size for an `n`-dimensional problem:
    /// `4 + floor(3 * ln(n))`.
    pub fn default_lambda(n: usize) -> usize {
        let log_term = (3.0 * (n as f64).ln()).floor();
        4 + log_term as usize
    }
}
/// Builds the constants for the bounded solver from the problem dimension
/// `n` and the population size `lambda`.
///
/// On top of the plain CMA-ES constants this derives:
///
/// * `damp = min(1, mu_eff / (10 * n))`,
/// * `edist_threshold = 3 * max(1, sqrt(n) / mu_eff)`,
/// * `hist_cap = 20 + 3 * n / lambda` (integer division).
///
/// # Panics
///
/// Panics if `lambda` is zero (integer division) or if a constant cannot be
/// represented in `F`.
fn compute_bounded_constants<F: Scalar>(n: usize, lambda: usize) -> BoundedCmaConstants<F> {
    let cma = compute_constants::<F>(n, lambda);
    let one = F::one();
    let n_f = F::from_usize(n).unwrap();
    let three = F::from_f64(3.0).unwrap();
    let ten = F::from_f64(10.0).unwrap();
    // Keeps the ratio below from dividing by an exact zero.
    let tiny = F::from_f64(f64::MIN_POSITIVE).unwrap();

    let damp = (cma.mu_eff / (ten * n_f)).min(one);

    // The lower bound of 1 must apply to the whole ratio sqrt(n) / mu_eff.
    // Clamping only sqrt(n) is a no-op for n >= 1 and lets the threshold drop
    // below three standard deviations whenever mu_eff exceeds sqrt(n).
    let edist_threshold = three * (n_f.sqrt() / cma.mu_eff.max(tiny)).max(one);

    let hist_cap = 20 + (3 * n) / lambda;

    BoundedCmaConstants {
        cma,
        damp,
        edist_threshold,
        hist_cap,
    }
}
/// Evaluates `x` on the clamped ("repaired") point and adds the boundary
/// penalty.
///
/// `x` is copied and clamped into `[lower, upper]`; the cost function is
/// called on that copy only. The penalty is
/// `sum_i gamma[i] * (x[i] - repaired[i])^2 / n` over the first `n`
/// coordinates.
///
/// Returns `(raw, penalized)`, where `raw` is the cost of the repaired point
/// and `penalized = raw + penalty`.
///
/// # Errors
///
/// Propagates any error returned by the cost function.
pub(crate) fn evaluate_with_penalty<P, V, F>(
    problem: &mut Problem<P>,
    x: &V,
    lower: &V,
    upper: &V,
    gamma: &V,
    n: usize,
) -> Result<(F, F), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: Clone + ClampInPlace + std::ops::Index<usize, Output = F>,
{
    let repaired = {
        let mut clamped = x.clone();
        clamped.clamp_in_place(lower, upper);
        clamped
    };
    let raw = problem.cost(&repaired)?;

    // Weighted squared distance between the sampled and the repaired point.
    let weighted_sq_dist = (0..n).fold(F::zero(), |acc, i| {
        let offset = x[i] - repaired[i];
        acc + gamma[i] * offset * offset
    });
    let penalty = weighted_sq_dist / F::from_usize(n).unwrap();

    Ok((raw, raw + penalty))
}
/// Adapts the per-coordinate boundary-penalty weights `gamma` stored in
/// `state.penalty`.
///
/// Steps, in order:
///
/// 1. Measure how far the mean lies outside the box, per coordinate, in
///    units of that coordinate's standard deviation `sigma * sqrt(C_ii)`.
/// 2. Take a spread of the latest raw costs (difference of two order
///    statistics near the quartiles), normalise it by the mean coordinate
///    variance and push it onto a bounded history.
/// 3. Use the (upper) median of that history, `dfit`, to initialise the
///    weights to `2 * dfit` the first time the mean violates a bound.
/// 4. Once initialised, grow the weight of every coordinate whose violation
///    exceeds `k.edist_threshold`, and shrink every weight above `5 * dfit`.
///
/// Does nothing when there are no raw costs yet, and stops after step 2 when
/// the history is still empty.
///
/// # Panics
///
/// Panics if `state.penalty` is `None`.
fn update_gamma<P, V, M, F>(
    state: &mut CmaEsState<V, M, F>,
    k: &BoundedCmaConstants<F>,
    problem: &P,
) where
    F: Scalar,
    P: BoxConstraints<Param = V>,
    V: Clone
        + ClampInPlace
        + std::ops::Index<usize, Output = F>
        + std::ops::IndexMut<usize, Output = F>,
    M: MatDiagonal<V>,
{
    let dim = k.cma.n;
    let pen = state
        .penalty
        .as_mut()
        .expect("BoundedCmaEs::init installs the penalty before next_iter");
    if pen.raw_costs.is_empty() {
        return;
    }

    let zero = F::zero();
    let two = F::from_f64(2.0).unwrap();
    let three = F::from_f64(3.0).unwrap();
    let five = F::from_f64(5.0).unwrap();
    let dim_f = F::from_usize(dim).unwrap();

    // Per-coordinate variance of the search distribution is sigma^2 * C_ii.
    let sigma_sq = state.sigma * state.sigma;
    let c_diag = state.c.diagonal();
    let mean_var = (0..dim).fold(zero, |acc, i| acc + sigma_sq * c_diag[i]) / dim_f;

    // Bound violation of the mean, in standard deviations per coordinate.
    let mean = &state.m;
    let mut mean_repaired = state.m.clone();
    mean_repaired.clamp_in_place(problem.lower(), problem.upper());
    let dmean: Vec<F> = (0..dim)
        .map(|i| (mean[i] - mean_repaired[i]) / (sigma_sq * c_diag[i]).sqrt())
        .collect();
    let any_violation = dmean.iter().any(|d| *d != zero);

    // Spread of the raw costs, normalised by the mean variance.
    let mut ranked = pen.raw_costs.clone();
    ranked.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let l = 1 + ranked.len();
    let spread = (ranked[3 * l / 4] - ranked[l / 4]) / mean_var;

    if spread.is_finite() && spread > zero {
        pen.hist.push_front(spread);
    } else if spread == F::infinity() && !pen.hist.is_empty() {
        // An infinite spread is recorded as the largest value seen so far.
        let largest = pen
            .hist
            .iter()
            .copied()
            .fold(F::neg_infinity(), |best, h| if h > best { h } else { best });
        pen.hist.push_front(largest);
    }
    // Zero, negative and NaN spreads are ignored. Oldest entries sit at the
    // back and are the ones dropped when the history is over capacity.
    pen.hist.truncate(k.hist_cap);

    if pen.hist.is_empty() {
        return;
    }
    let mut hist_sorted: Vec<F> = pen.hist.iter().copied().collect();
    hist_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let dfit = hist_sorted[hist_sorted.len() / 2];

    if any_violation && !pen.weights_initialized {
        let start = two * dfit;
        for i in 0..dim {
            pen.gamma[i] = start;
        }
        pen.weights_initialized = true;
    }
    if !pen.weights_initialized {
        return;
    }

    let cap = five * dfit;
    let decay = (-k.damp / three).exp();
    for (i, d) in dmean.iter().enumerate() {
        // Grow the weight where the mean is too far outside the box.
        let excess = d.abs() - k.edist_threshold;
        if excess > zero {
            let growth = ((excess / three).tanh() / two * k.damp).exp();
            pen.gamma[i] = pen.gamma[i] * growth;
        }
        // Pull overly large weights back towards a common level.
        if pen.gamma[i] > cap {
            pen.gamma[i] = pen.gamma[i] * decay;
        }
    }
}
/// Replaces the population in `state` with `lambda` freshly sampled
/// candidates and their costs.
///
/// Each candidate is `m + sigma * B * (d ∘ z)` with `z` drawn from
/// `V::sample_standard_normal`. It is evaluated through
/// [`evaluate_with_penalty`]: the penalized cost goes to `state.costs`, the
/// raw cost to the penalty's `raw_costs`. All three buffers are cleared
/// first, and the candidates are left in sampling order.
///
/// # Errors
///
/// Propagates the first cost-function error; the buffers then hold only the
/// candidates evaluated so far.
///
/// # Panics
///
/// Panics if `state.penalty` is `None`.
#[allow(clippy::too_many_arguments)]
fn sample_and_penalize<P, V, M, F>(
    state: &mut CmaEsState<V, M, F>,
    lambda: usize,
    n: usize,
    rng: &mut ChaCha8Rng,
    problem: &mut Problem<P>,
    lo: &V,
    hi: &V,
) -> Result<(), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: VectorLen
        + Clone
        + ScaledAdd<F>
        + ComponentMulAssign
        + ClampInPlace
        + SampleStandardNormal
        + std::ops::Index<usize, Output = F>,
    M: MatVec<V>,
{
    let CmaEsState {
        candidates,
        costs,
        m,
        d,
        b,
        sigma,
        penalty,
        ..
    } = state;
    let pen = penalty
        .as_mut()
        .expect("BoundedCmaEs::init installs the penalty before sampling");

    candidates.clear();
    costs.clear();
    pen.raw_costs.clear();

    for _ in 0..lambda {
        // Standard-normal draw, scaled per axis by `d`, then mapped by `b`.
        let mut step = V::sample_standard_normal(m, rng);
        step.component_mul_assign(d);
        let step = b.matvec(&step);

        let mut candidate = m.clone();
        candidate.scaled_add(*sigma, &step);

        let (raw, penalized) =
            evaluate_with_penalty(problem, &candidate, lo, hi, &pen.gamma, n)?;
        candidates.push(candidate);
        costs.push(penalized);
        pen.raw_costs.push(raw);
    }
    Ok(())
}
impl<P, V, M, F> Solver<P, CmaEsState<V, M, F>> for BoundedCmaEs<V, M, F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F> + BoxConstraints,
    V: VectorLen
        + Clone
        + ScaledAdd<F>
        + ScaleInPlace<F>
        + ComponentMulAssign
        + ClampInPlace
        + NormSquared<F>
        + SampleStandardNormal
        + std::ops::Index<usize, Output = F>
        + std::ops::IndexMut<usize, Output = F>,
    M: MatrixIdentity
        + MatrixFromDiagonal<V>
        + MatVec<V>
        + MatTransposeVec<V>
        + MatDiagonal<V>
        + ScaleInPlace<F>
        + RankOneUpdate<V, F>
        + SymmetricEigen<V>
        + Clone,
{
    type Error = P::Error;

    /// Prepares `state` for the first iteration.
    ///
    /// * Computes the strategy constants once, from the dimension of
    ///   `state.m` and the configured or default population size.
    /// * Installs a boundary penalty with all weights set to one if the
    ///   state has none.
    /// * If the state has no candidates yet: clamps the mean into the box,
    ///   samples and evaluates a first population, sorts it with
    ///   `sort_population_ascending`, and stores the penalized cost of the
    ///   mean in `state.m_cost`.
    ///
    /// # Errors
    ///
    /// Propagates cost-function errors.
    ///
    /// # Panics
    ///
    /// Panics if the mean is empty when the constants are first computed.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: CmaEsState<V, M, F>,
    ) -> Result<CmaEsState<V, M, F>, Self::Error> {
        if self.constants.is_none() {
            let dim = state.m.vec_len();
            assert!(dim >= 1, "BoundedCmaEs requires a non-empty mean");
            let population = match self.lambda_override {
                Some(requested) => requested,
                None => Self::default_lambda(dim),
            };
            self.constants = Some(compute_bounded_constants::<F>(dim, population));
        }
        let (n, lambda) = {
            let consts = self.constants.as_ref().unwrap();
            (consts.cma.n, consts.cma.lambda)
        };

        if state.penalty.is_none() {
            // Clone the mean only to obtain a vector of the right length.
            let mut gamma = state.m.clone();
            for i in 0..n {
                gamma[i] = F::one();
            }
            state.penalty = Some(BoundPenalty {
                gamma,
                weights_initialized: false,
                hist: VecDeque::new(),
                raw_costs: Vec::with_capacity(lambda),
            });
        }

        if state.candidates.is_empty() {
            let lo = problem.inner().lower().clone();
            let hi = problem.inner().upper().clone();
            state.m.clamp_in_place(&lo, &hi);

            sample_and_penalize(&mut state, lambda, n, &mut self.rng, problem, &lo, &hi)?;
            sort_population_ascending(&mut state.candidates, &mut state.costs);

            let gamma = &state.penalty.as_ref().unwrap().gamma;
            let (_raw, mean_cost) =
                evaluate_with_penalty(problem, &state.m, &lo, &hi, gamma, n)?;
            state.m_cost = Some(mean_cost);
        }
        Ok(state)
    }

    /// Runs one generation on the population currently held in `state`.
    ///
    /// In order: moves the mean along the weighted steps of the first `mu`
    /// candidates, updates the step-size path and `sigma`, updates the
    /// covariance path and matrix (rank-one and weighted rank-`mu` terms),
    /// re-decomposes the covariance matrix, adapts the boundary-penalty
    /// weights, and finally samples, evaluates and sorts the next population.
    ///
    /// Returns `TerminationReason::SolverFailed` when the eigendecomposition
    /// fails; otherwise no termination reason is reported from here.
    ///
    /// # Errors
    ///
    /// Propagates cost-function errors.
    ///
    /// # Panics
    ///
    /// Panics if `init` has not run.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: CmaEsState<V, M, F>,
    ) -> Result<(CmaEsState<V, M, F>, Option<TerminationReason>), Self::Error> {
        let consts = self
            .constants
            .as_ref()
            .expect("BoundedCmaEs::init must run before next_iter");
        let kc = &consts.cma;
        state.generation += 1;

        let zero = F::zero();
        let one = F::one();
        let two = F::from_f64(2.0).unwrap();

        // Normalised steps y_i = (x_i - m) / sigma, in the order the
        // candidates were left in by the previous sort.
        let inv_sigma = one / state.sigma;
        let steps: Vec<V> = state
            .candidates
            .iter()
            .map(|x| {
                let mut y = x.clone();
                y.scaled_add(-one, &state.m);
                y.scale_in_place(inv_sigma);
                y
            })
            .collect();

        // Weighted mean step over the first `mu` candidates.
        let mut step_mean = state.m.clone();
        step_mean.scale_in_place(zero);
        for (rank, y) in steps.iter().take(kc.mu).enumerate() {
            step_mean.scaled_add(kc.weights[rank], y);
        }
        state.m.scaled_add(state.sigma, &step_mean);

        // Step-size path: accumulate B * D^-1 * B^T * step_mean.
        let mut whitened = state.b.mat_transpose_vec(&step_mean);
        whitened.component_mul_assign(&state.d_inv);
        let whitened = state.b.matvec(&whitened);
        state.p_sigma.scale_in_place(one - kc.c_sigma);
        let sigma_gain = (kc.c_sigma * (two - kc.c_sigma) * kc.mu_eff).sqrt();
        state.p_sigma.scaled_add(sigma_gain, &whitened);

        let p_sigma_norm = state.p_sigma.norm_squared().sqrt();
        let log_step = (kc.c_sigma / kc.d_sigma) * (p_sigma_norm / kc.expected_norm - one);
        state.sigma = state.sigma * log_step.exp();

        // Stall indicator h_sigma.
        // NOTE(review): `generation` was already incremented above, so the
        // exponent here is 2 * (generation + 1). Confirm this matches the
        // generation indexing intended for the h_sigma normalisation.
        let exponent = 2 * ((state.generation + 1) as i32);
        let path_norm_scale = (one - (one - kc.c_sigma).powi(exponent)).sqrt();
        let h_sigma = if p_sigma_norm / path_norm_scale < kc.h_sigma_threshold {
            one
        } else {
            zero
        };

        // Covariance path.
        state.p_c.scale_in_place(one - kc.c_c);
        let c_gain = h_sigma * (kc.c_c * (two - kc.c_c) * kc.mu_eff).sqrt();
        state.p_c.scaled_add(c_gain, &step_mean);

        // Covariance matrix: decay, then rank-one, then rank-mu updates.
        let delta_h = (one - h_sigma) * kc.c_c * (two - kc.c_c);
        let decay = one + kc.c_1 * delta_h - kc.c_1 - kc.c_mu * kc.sum_w;
        state.c.scale_in_place(decay);
        state.c.rank_one_update(kc.c_1, &state.p_c);

        let n_f = F::from_usize(kc.n).unwrap();
        for (rank, y) in steps.iter().enumerate() {
            let weight = kc.weights[rank];
            // Negative weights are rescaled by n / ||D^-1 B^T y||^2, using
            // the decomposition from before this generation's update.
            let effective = if weight >= zero {
                weight
            } else {
                let mut rotated = state.b.mat_transpose_vec(y);
                rotated.component_mul_assign(&state.d_inv);
                let norm_sq = rotated.norm_squared();
                if norm_sq > zero {
                    weight * n_f / norm_sq
                } else {
                    zero
                }
            };
            if effective != zero {
                state.c.rank_one_update(kc.c_mu * effective, y);
            }
        }
        // The steps are no longer needed; release them before decomposing.
        drop(steps);

        let (basis, eigenvalues) = match state.c.try_eigh() {
            Ok((basis, eigenvalues)) => (basis, eigenvalues),
            Err(_) => return Ok((state, Some(TerminationReason::SolverFailed))),
        };
        state.b = basis;
        // Floor the eigenvalues so the axis lengths stay positive.
        let eig_floor = F::from_f64(1e-30).unwrap();
        for i in 0..kc.n {
            let axis_len = eigenvalues[i].max(eig_floor).sqrt();
            state.d[i] = axis_len;
            state.d_inv[i] = one / axis_len;
        }

        update_gamma(&mut state, consts, problem.inner());

        // Next population, evaluated with the freshly adapted weights.
        let n = kc.n;
        let lambda = kc.lambda;
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        sample_and_penalize(&mut state, lambda, n, &mut self.rng, problem, &lo, &hi)?;
        sort_population_ascending(&mut state.candidates, &mut state.costs);

        let gamma = &state.penalty.as_ref().unwrap().gamma;
        let (_raw, mean_cost) = evaluate_with_penalty(problem, &state.m, &lo, &hi, gamma, n)?;
        state.m_cost = Some(mean_cost);

        Ok((state, None))
    }
}