use std::collections::VecDeque;
use crate::core::constraint::BoxConstrained;
use crate::core::math::{
ClampInPlace, ComponentMulAssign, MatDiagonal, MatTransposeVec, MatVec, MatrixIdentity,
NormSquared, RankOneUpdate, SampleStandardNormal, ScaleInPlace, ScaledAdd, SymmetricEigen,
VectorLen,
};
use crate::core::problem::CostFunction;
use crate::core::rng::{ChaCha8Rng, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::BasicPopulationState;
use crate::core::termination::TerminationReason;
use super::cma_es::{compute_weights, expected_norm_n01, sort_population_ascending};
/// Covariance-matrix-adaptation evolution strategy for box-constrained problems.
///
/// Holds the user-supplied configuration. The per-run working state is built from it
/// by the solver's `init` and kept in `state`. `V` is the parameter-vector type and
/// `M` the matrix type; the trait bounds they must satisfy are declared on the `impl`
/// blocks rather than on the struct.
pub struct BoundedCmaEs<V, M> {
/// Starting point of the search distribution; its length defines the problem dimension.
initial_mean: V,
/// Initial step size; `new` asserts that it is strictly positive.
initial_sigma: f64,
/// Population size chosen via `with_lambda`; `None` falls back to `default_lambda`.
lambda_override: Option<usize>,
/// Seed for the `ChaCha8Rng` that drives candidate sampling.
seed: u64,
/// Termination tolerance chosen via `with_tol_x`; `None` falls back to
/// `1e-12 * initial_sigma`.
tol_x_override: Option<f64>,
/// Working state of the current run; `None` until `init` has been called.
state: Option<Working<V, M>>,
}
/// Mutable per-run state of [`BoundedCmaEs`]: strategy constants derived once from the
/// configuration, the evolving search distribution, and the bookkeeping of the adaptive
/// boundary penalty.
pub(crate) struct Working<V, M> {
/// Problem dimension (length of the initial mean).
pub(crate) n: usize,
/// Population size: number of candidates sampled and evaluated per generation.
lambda: usize,
/// Number of best-ranked candidates recombined into the new mean (`lambda / 2`).
mu: usize,
/// Recombination weights from `compute_weights`, indexed by candidate rank.
/// Negative entries are rescaled before they enter the covariance update.
weights: Vec<f64>,
/// Second value returned by `compute_weights`; used to derive the learning rates and
/// to scale the evolution-path updates.
mu_eff: f64,
/// Third value returned by `compute_weights`; enters the covariance decay factor
/// as `c_mu * sum_w`.
sum_w: f64,
/// Learning rate of the step-size evolution path `p_sigma`.
c_sigma: f64,
/// Damping divisor of the step-size update.
d_sigma: f64,
/// Learning rate of the covariance evolution path `p_c`.
c_c: f64,
/// Coefficient of the rank-one covariance update.
c_1: f64,
/// Coefficient of the weighted (rank-mu) covariance update.
c_mu: f64,
/// Value of `expected_norm_n01(n)`; `|p_sigma|` is compared against it when adapting sigma.
expected_norm: f64,
/// Bound on the normalised `|p_sigma|` below which the `p_c` update is enabled.
h_sigma_threshold: f64,
/// Convergence tolerance: the solver terminates once `sigma * max(d)` drops below it.
tol_x: f64,
/// Damping exponent, `min(1, mu_eff / (10 n))`, applied to penalty-weight changes.
damp: f64,
/// Distance of the mean outside the box, in units of its per-coordinate standard
/// deviation, beyond which a coordinate's penalty weight is increased.
edist_threshold: f64,
/// Maximum number of entries kept in `hist`.
hist_cap: usize,
/// Mean of the search distribution.
pub(crate) m: V,
/// Global step size.
pub(crate) sigma: f64,
/// Evolution path used for step-size adaptation.
p_sigma: V,
/// Evolution path used for the rank-one covariance update.
p_c: V,
/// Covariance matrix of the search distribution (up to the factor `sigma^2`).
c: M,
/// First value returned by `c.try_eigh()`; applied to scaled normal draws when sampling
/// (presumably the eigenvector matrix of `c` — confirm against `SymmetricEigen`).
pub(crate) b: M,
/// Square roots of the values returned by `try_eigh`, floored at `1e-30` before the root.
d: V,
/// Component-wise reciprocal of `d`.
pub(crate) d_inv: V,
/// Random number generator seeded from the solver's `seed`.
rng: ChaCha8Rng,
/// Number of `next_iter` calls performed so far.
generation: u64,
/// Per-coordinate weights of the quadratic boundary penalty.
pub(crate) gamma: V,
/// Whether `gamma` has been initialised from the cost history; this happens the first
/// time the mean is found outside the box.
weights_initialized: bool,
/// Recent interquartile ranges of the raw costs, normalised by the mean coordinate
/// variance; newest entry at the front.
hist: VecDeque<f64>,
/// Unpenalised costs of the current population, in sampling order.
raw_costs: Vec<f64>,
}
impl<V, M> BoundedCmaEs<V, M> {
    /// Creates a solver that starts sampling around `initial_mean` with step size
    /// `initial_sigma`, drawing random numbers from a generator seeded with `seed`.
    ///
    /// The population size and the convergence tolerance keep their defaults until
    /// overridden through [`Self::with_lambda`] and [`Self::with_tol_x`].
    ///
    /// # Panics
    ///
    /// Panics unless `initial_sigma > 0` (a NaN step size is rejected as well).
    pub fn new(initial_mean: V, initial_sigma: f64, seed: u64) -> Self {
        assert!(
            initial_sigma > 0.0,
            "BoundedCmaEs requires initial_sigma > 0, got {}",
            initial_sigma
        );
        // No run has started yet, so there are no overrides and no working state.
        let (lambda_override, tol_x_override, state) = (None, None, None);
        Self {
            initial_mean,
            initial_sigma,
            lambda_override,
            seed,
            tol_x_override,
            state,
        }
    }

    /// Overrides the population size used for every generation.
    ///
    /// # Panics
    ///
    /// Panics if `lambda < 4`.
    pub fn with_lambda(self, lambda: usize) -> Self {
        assert!(
            lambda >= 4,
            "BoundedCmaEs requires lambda >= 4, got {} (Hansen 2016 footnote 30: \
             smaller populations have strong adverse effects on performance)",
            lambda
        );
        let mut configured = self;
        configured.lambda_override = Some(lambda);
        configured
    }

    /// Overrides the convergence tolerance compared against `sigma * max(d)`.
    pub fn with_tol_x(self, tol_x: f64) -> Self {
        let mut configured = self;
        configured.tol_x_override = Some(tol_x);
        configured
    }

    /// Default population size for dimension `n`: `4 + floor(3 ln n)`.
    pub fn default_lambda(n: usize) -> usize {
        let log_term = (3.0 * (n as f64).ln()).floor();
        // The float-to-integer cast saturates, so a negative term contributes nothing.
        4 + log_term as usize
    }

    /// Read-only view of the working state, or `None` before the solver was initialised.
    pub(crate) fn working(&self) -> Option<&Working<V, M>> {
        self.state.as_ref()
    }
}
impl<V, M> BoundedCmaEs<V, M>
where
    V: VectorLen + Clone + std::ops::IndexMut<usize, Output = f64>,
    M: MatrixIdentity,
{
    /// Derives all strategy constants from the configuration and allocates a fresh
    /// [`Working`] state.
    ///
    /// Vector-valued fields other than `m` and `gamma` (`p_sigma`, `p_c`, `d`, `d_inv`)
    /// are cloned from the initial mean purely to obtain buffers of the right length;
    /// the caller is expected to overwrite them before use. `gamma` is set to all ones,
    /// and both matrices start as the `n x n` identity.
    ///
    /// # Panics
    ///
    /// Panics if the initial mean is empty.
    fn build_working(&self) -> Working<V, M> {
        let n = self.initial_mean.vec_len();
        assert!(
            n >= 1,
            "BoundedCmaEs requires the initial mean to be non-empty"
        );
        let nf = n as f64;
        let lambda = self
            .lambda_override
            .unwrap_or_else(|| Self::default_lambda(n));
        let mu = lambda / 2;

        // Provisional log-rank weights: only used to estimate mu_eff so that c_1 and
        // c_mu can be fixed before `compute_weights` produces the final weights.
        let alpha_cov = 2.0;
        let raw: Vec<f64> = (1..=lambda)
            .map(|i| ((lambda as f64 + 1.0) / 2.0).ln() - (i as f64).ln())
            .collect();
        let sum_pos: f64 = raw[..mu].iter().sum();
        let mu_eff_provisional = sum_pos.powi(2) / raw[..mu].iter().map(|w| w * w).sum::<f64>();
        let c_1 = alpha_cov / ((nf + 1.3).powi(2) + mu_eff_provisional);
        let c_mu_unbounded = alpha_cov * (mu_eff_provisional - 2.0 + 1.0 / mu_eff_provisional)
            / ((nf + 2.0).powi(2) + alpha_cov * mu_eff_provisional / 2.0);
        let c_mu = (1.0 - c_1).min(c_mu_unbounded);
        let (weights, mu_eff, sum_w) = compute_weights(n, lambda, c_1, c_mu);

        // Step-size control constants.
        let c_sigma = (mu_eff + 2.0) / (nf + mu_eff + 5.0);
        let d_sigma = {
            let inner = ((mu_eff - 1.0) / (nf + 1.0)).sqrt() - 1.0;
            1.0 + 2.0 * inner.max(0.0) + c_sigma
        };
        let c_c = (4.0 + mu_eff / nf) / (nf + 4.0 + 2.0 * mu_eff / nf);
        let expected_norm = expected_norm_n01(n);
        let h_sigma_threshold = (1.4 + 2.0 / (nf + 1.0)) * expected_norm;
        let tol_x = self.tol_x_override.unwrap_or(1e-12 * self.initial_sigma);

        // Boundary-penalty constants.
        let damp = (mu_eff / (10.0 * nf)).min(1.0);
        // The mean must sit more than 3 * max(1, sqrt(n) / mu_eff) standard deviations
        // outside the box before a penalty weight grows. The clamp has to wrap the
        // whole ratio: clamping sqrt(n) alone is a no-op for n >= 1 and lets the
        // threshold shrink below 3 whenever mu_eff exceeds sqrt(n).
        let edist_threshold = 3.0 * (nf.sqrt() / mu_eff.max(f64::MIN_POSITIVE)).max(1.0);
        let hist_cap = 20 + (3 * n) / lambda;

        // Penalty weights start at one in every coordinate.
        let mut gamma = self.initial_mean.clone();
        for i in 0..n {
            gamma[i] = 1.0;
        }

        Working {
            n,
            lambda,
            mu,
            weights,
            mu_eff,
            sum_w,
            c_sigma,
            d_sigma,
            c_c,
            c_1,
            c_mu,
            expected_norm,
            h_sigma_threshold,
            tol_x,
            damp,
            edist_threshold,
            hist_cap,
            m: self.initial_mean.clone(),
            sigma: self.initial_sigma,
            p_sigma: self.initial_mean.clone(),
            p_c: self.initial_mean.clone(),
            c: M::identity(n),
            b: M::identity(n),
            d: self.initial_mean.clone(),
            d_inv: self.initial_mean.clone(),
            rng: ChaCha8Rng::seed_from_u64(self.seed),
            generation: 0,
            gamma,
            weights_initialized: false,
            hist: VecDeque::new(),
            raw_costs: Vec::with_capacity(lambda),
        }
    }
}
/// Evaluates `problem` at the feasible repair of `x` and adds a quadratic penalty for
/// the distance between `x` and that repair.
///
/// The repair is a clone of `x` passed through `clamp_in_place(lower, upper)`. The
/// penalty is the `gamma`-weighted sum of squared per-coordinate repair distances over
/// the first `n` coordinates, divided by `n`.
///
/// Returns `(raw, penalised)`: `raw` is the cost at the repaired point and
/// `penalised` is `raw` plus the penalty.
pub(crate) fn evaluate_with_penalty<P, V>(
    problem: &P,
    x: &V,
    lower: &V,
    upper: &V,
    gamma: &V,
    n: usize,
) -> (f64, f64)
where
    P: CostFunction<Param = V, Output = f64>,
    V: Clone + ClampInPlace + std::ops::Index<usize, Output = f64>,
{
    // The objective is only ever evaluated at the repaired point.
    let repaired = {
        let mut point = x.clone();
        point.clamp_in_place(lower, upper);
        point
    };
    let raw = problem.cost(&repaired);

    // Accumulate in index order so the floating-point result is reproducible.
    let weighted_sq = (0..n).fold(0.0, |acc, i| {
        let gap = x[i] - repaired[i];
        acc + gamma[i] * gap * gap
    });
    let mean_penalty = weighted_sq / n as f64;

    (raw, raw + mean_penalty)
}
/// Adapts the per-coordinate boundary-penalty weights `w.gamma` from the raw costs of
/// the most recent population.
///
/// Steps:
/// 1. Measure how far the mean lies outside the box in each coordinate, in units of
///    that coordinate's standard deviation `sigma * sqrt(C_ii)`.
/// 2. Push the interquartile range of the raw costs, normalised by the mean coordinate
///    variance, onto `w.hist` (bounded by `w.hist_cap`). An infinite range is replaced
///    by the largest value already in the history; zero, negative and NaN ranges are
///    ignored.
/// 3. The first time the mean violates a bound, set every weight to twice the median
///    of the history.
/// 4. Once initialised, increase the weights of coordinates whose violation exceeds
///    `w.edist_threshold`, and decay every weight that is above five times the median.
///
/// Does nothing when `w.raw_costs` is empty, and leaves `gamma` untouched while the
/// history is empty.
fn update_gamma<P, V, M>(w: &mut Working<V, M>, problem: &P)
where
    P: BoxConstrained<Param = V>,
    V: Clone
        + ClampInPlace
        + std::ops::Index<usize, Output = f64>
        + std::ops::IndexMut<usize, Output = f64>,
    M: MatDiagonal<V>,
{
    if w.raw_costs.is_empty() {
        return;
    }

    // Mean per-coordinate variance of the sampling distribution.
    let diag_c = w.c.diagonal();
    let mut mean_varis = 0.0;
    for i in 0..w.n {
        mean_varis += w.sigma * w.sigma * diag_c[i];
    }
    mean_varis /= w.n as f64;

    // Normalised distance of the mean from its feasible repair.
    let mut m_rep = w.m.clone();
    m_rep.clamp_in_place(problem.lower(), problem.upper());
    let mut dmean: Vec<f64> = Vec::with_capacity(w.n);
    let mut any_violation = false;
    for i in 0..w.n {
        let var_i = w.sigma * w.sigma * diag_c[i];
        let d = (w.m[i] - m_rep[i]) / var_i.sqrt();
        if d != 0.0 {
            any_violation = true;
        }
        dmean.push(d);
    }

    // Interquartile range of the raw costs. `total_cmp` is a genuine total order, so
    // NaN costs are parked at the ends instead of handing the sort an inconsistent
    // comparator (which leaves the order unspecified and may panic).
    let mut sorted = w.raw_costs.clone();
    sorted.sort_unstable_by(f64::total_cmp);
    let l = 1 + sorted.len();
    // Clamping is a no-op for the usual populations of >= 4 costs; it only prevents
    // out-of-bounds indexing for smaller inputs.
    let last = sorted.len() - 1;
    let upper_q = sorted[(3 * l / 4).min(last)];
    let lower_q = sorted[(l / 4).min(last)];
    let val = (upper_q - lower_q) / mean_varis;
    if val.is_finite() && val > 0.0 {
        w.hist.push_front(val);
    } else if val == f64::INFINITY && !w.hist.is_empty() {
        let max_hist = w.hist.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        w.hist.push_front(max_hist);
    }
    while w.hist.len() > w.hist_cap {
        w.hist.pop_back();
    }
    if w.hist.is_empty() {
        return;
    }

    // Upper median of the history.
    let mut hsorted: Vec<f64> = w.hist.iter().cloned().collect();
    hsorted.sort_unstable_by(f64::total_cmp);
    let dfit = hsorted[hsorted.len() / 2];

    if any_violation && !w.weights_initialized {
        let init_val = 2.0 * dfit;
        for i in 0..w.n {
            w.gamma[i] = init_val;
        }
        w.weights_initialized = true;
    }
    if w.weights_initialized {
        // Weights may grow faster than they can shrink.
        for (i, dmean_i) in dmean.iter().enumerate() {
            let edist_i = dmean_i.abs() - w.edist_threshold;
            if edist_i > 0.0 {
                let factor = ((edist_i / 3.0).tanh() / 2.0 * w.damp).exp();
                w.gamma[i] *= factor;
            }
        }
        // Pull overly large weights back towards a common level.
        let cap = 5.0 * dfit;
        let decay = (-w.damp / 3.0).exp();
        for i in 0..w.n {
            if w.gamma[i] > cap {
                w.gamma[i] *= decay;
            }
        }
    }
}
/// Solver interface for [`BoundedCmaEs`].
///
/// `state.costs` always holds penalised costs (raw cost at the repaired point plus the
/// boundary penalty). The raw costs are kept separately in the working state, where
/// they drive the adaptation of the penalty weights.
impl<P, V, M> Solver<P, BasicPopulationState<V>> for BoundedCmaEs<V, M>
where
    P: CostFunction<Param = V, Output = f64> + BoxConstrained,
    V: VectorLen
        + Clone
        + ScaledAdd<f64>
        + ScaleInPlace
        + ComponentMulAssign
        + ClampInPlace
        + NormSquared
        + SampleStandardNormal
        + std::ops::Index<usize, Output = f64>
        + std::ops::IndexMut<usize, Output = f64>,
    M: MatrixIdentity
        + MatVec<V>
        + MatTransposeVec<V>
        + MatDiagonal<V>
        + ScaleInPlace
        + RankOneUpdate<V>
        + SymmetricEigen<V>
        + Clone,
{
    /// Builds the working state and evaluates the first population.
    ///
    /// The initial mean is clamped into the box, then `lambda` candidates
    /// `m + sigma * z` are sampled and evaluated with the boundary penalty. The
    /// population is stored in `state`, reordered by `sort_population_ascending`.
    ///
    /// If the solver has already been initialised, `state` is returned unchanged.
    fn init(&mut self, problem: &P, mut state: BasicPopulationState<V>) -> BasicPopulationState<V> {
        if self.state.is_some() {
            return state;
        }
        let mut w = self.build_working();
        // `build_working` cloned the mean into these buffers just to get the right
        // length; give them their real starting values.
        w.p_sigma.scale_in_place(0.0);
        w.p_c.scale_in_place(0.0);
        for i in 0..w.n {
            w.d[i] = 1.0;
            w.d_inv[i] = 1.0;
        }
        w.m.clamp_in_place(problem.lower(), problem.upper());
        state.candidates.clear();
        state.costs.clear();
        w.raw_costs.clear();
        for _k in 0..w.lambda {
            // The covariance is still the identity, so no B * D transform is applied.
            let z_k = V::sample_standard_normal(&w.m, &mut w.rng);
            let mut x_k = w.m.clone();
            x_k.scaled_add(w.sigma, &z_k);
            let (raw, pen) = evaluate_with_penalty(
                problem,
                &x_k,
                problem.lower(),
                problem.upper(),
                &w.gamma,
                w.n,
            );
            state.candidates.push(x_k);
            state.costs.push(pen);
            w.raw_costs.push(raw);
        }
        state.cost_evals += w.lambda as u64;
        sort_population_ascending(&mut state.candidates, &mut state.costs);
        self.state = Some(w);
        state
    }

    /// Performs one generation: updates the mean, step size and covariance from the
    /// ranked population in `state`, adapts the penalty weights, then samples and
    /// evaluates the next population.
    ///
    /// Returns `TerminationReason::SolverFailed` (leaving the population in `state`
    /// untouched) when the eigendecomposition of the covariance matrix fails.
    ///
    /// # Panics
    ///
    /// Panics if `init` has not been called.
    fn next_iter(
        &mut self,
        problem: &P,
        mut state: BasicPopulationState<V>,
    ) -> (BasicPopulationState<V>, Option<TerminationReason>) {
        let w = self
            .state
            .as_mut()
            .expect("BoundedCmaEs::init must run before next_iter");
        w.generation += 1;

        // Steps y_i = (x_i - m) / sigma relative to the old mean and old step size, in
        // the ranking order of `state.candidates`.
        let y_sorted: Vec<V> = state
            .candidates
            .iter()
            .map(|x| {
                let mut y = x.clone();
                y.scaled_add(-1.0, &w.m);
                y.scale_in_place(1.0 / w.sigma);
                y
            })
            .collect();

        // Weighted recombination of the mu best steps moves the mean.
        let mut y_w = w.m.clone();
        y_w.scale_in_place(0.0);
        for (i, y_i) in y_sorted.iter().enumerate().take(w.mu) {
            y_w.scaled_add(w.weights[i], y_i);
        }
        w.m.scaled_add(w.sigma, &y_w);

        // Step-size path: accumulate B * D^-1 * B^T * y_w.
        let mut bt_y_w = w.b.mat_transpose_vec(&y_w);
        bt_y_w.component_mul_assign(&w.d_inv);
        let c_invsqrt_y_w = w.b.matvec(&bt_y_w);
        w.p_sigma.scale_in_place(1.0 - w.c_sigma);
        let coef_sigma = (w.c_sigma * (2.0 - w.c_sigma) * w.mu_eff).sqrt();
        w.p_sigma.scaled_add(coef_sigma, &c_invsqrt_y_w);
        let p_sigma_norm = w.p_sigma.norm_squared().sqrt();
        let log_factor = (w.c_sigma / w.d_sigma) * (p_sigma_norm / w.expected_norm - 1.0);
        w.sigma *= log_factor.exp();

        // Stall indicator h_sigma. The exponent 2 * (generation + 1) is saturated at
        // i32::MAX so that very long runs cannot overflow or truncate the cast; by
        // then the power term has long since vanished.
        let exponent = w
            .generation
            .saturating_add(1)
            .saturating_mul(2)
            .min(i32::MAX as u64) as i32;
        let denom = (1.0 - (1.0 - w.c_sigma).powi(exponent)).sqrt();
        let h_sigma = if p_sigma_norm / denom < w.h_sigma_threshold {
            1.0
        } else {
            0.0
        };

        // Covariance path and decay of the old covariance.
        w.p_c.scale_in_place(1.0 - w.c_c);
        let coef_c = h_sigma * (w.c_c * (2.0 - w.c_c) * w.mu_eff).sqrt();
        w.p_c.scaled_add(coef_c, &y_w);
        let delta_h = (1.0 - h_sigma) * w.c_c * (2.0 - w.c_c);
        let c_scale = 1.0 + w.c_1 * delta_h - w.c_1 - w.c_mu * w.sum_w;
        w.c.scale_in_place(c_scale);
        w.c.rank_one_update(w.c_1, &w.p_c);

        // Weighted update over the whole population. Negative weights are rescaled by
        // n / |D^-1 B^T y_i|^2, using the decomposition of the previous generation.
        for (i, y_i) in y_sorted.iter().enumerate() {
            let w_i = w.weights[i];
            let w_i_o = if w_i >= 0.0 {
                w_i
            } else {
                let mut bt_y = w.b.mat_transpose_vec(y_i);
                bt_y.component_mul_assign(&w.d_inv);
                let cinv_norm_sq = bt_y.norm_squared();
                if cinv_norm_sq > 0.0 {
                    w_i * (w.n as f64) / cinv_norm_sq
                } else {
                    0.0
                }
            };
            if w_i_o != 0.0 {
                w.c.rank_one_update(w.c_mu * w_i_o, y_i);
            }
        }
        // Release the step vectors before the eigendecomposition.
        drop(y_sorted);

        // Refresh B, D and D^-1 from the updated covariance.
        let (b_new, eigs) = match w.c.try_eigh() {
            Ok(pair) => pair,
            Err(_) => return (state, Some(TerminationReason::SolverFailed)),
        };
        w.b = b_new;
        for i in 0..w.n {
            // Floor the eigenvalues so that D stays strictly positive and invertible.
            let lam = eigs[i].max(1e-30);
            let s = lam.sqrt();
            w.d[i] = s;
            w.d_inv[i] = 1.0 / s;
        }

        // Penalty weights are adapted from the raw costs of the population just consumed.
        update_gamma(w, problem);

        // Sample and evaluate the next population: x = m + sigma * B * (D .* z).
        state.candidates.clear();
        state.costs.clear();
        w.raw_costs.clear();
        for _k in 0..w.lambda {
            let z_k = V::sample_standard_normal(&w.m, &mut w.rng);
            let mut bd_z = z_k;
            bd_z.component_mul_assign(&w.d);
            let bd_z = w.b.matvec(&bd_z);
            let mut x_k = w.m.clone();
            x_k.scaled_add(w.sigma, &bd_z);
            let (raw, pen) = evaluate_with_penalty(
                problem,
                &x_k,
                problem.lower(),
                problem.upper(),
                &w.gamma,
                w.n,
            );
            state.candidates.push(x_k);
            state.costs.push(pen);
            w.raw_costs.push(raw);
        }
        state.cost_evals += w.lambda as u64;
        sort_population_ascending(&mut state.candidates, &mut state.costs);
        (state, None)
    }

    /// Reports convergence once the largest axis of the sampling distribution,
    /// `sigma * max(d)`, has shrunk below `tol_x`.
    ///
    /// Returns `None` while the solver is uninitialised or has not yet converged.
    fn terminate(&self, _state: &BasicPopulationState<V>) -> Option<TerminationReason> {
        let w = self.state.as_ref()?;
        let mut max_d = 0.0_f64;
        for i in 0..w.n {
            let v = w.d[i];
            if v > max_d {
                max_d = v;
            }
        }
        if w.sigma * max_d < w.tol_x {
            return Some(TerminationReason::SolverConverged);
        }
        None
    }
}