use std::collections::HashMap;
use nalgebra::{DMatrix, DVector};
use parking_lot::Mutex;
use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::rng_util;
use crate::sampler::{CompletedTrial, Sampler};
use super::common::{from_internal, internal_bounds, sample_random};
/// Sampler based on CMA-ES (Covariance Matrix Adaptation Evolution Strategy).
///
/// Continuous parameters are optimized with CMA-ES; categorical parameters
/// are drawn uniformly at random for each candidate.
pub struct CmaEsSampler {
    // All optimizer state lives behind a lock so `sample` can take `&self`
    // while still mutating generation bookkeeping.
    state: Mutex<CmaEsState>,
}
impl CmaEsSampler {
    /// Creates a sampler with default settings and a randomly seeded RNG.
    #[must_use]
    pub fn new() -> Self {
        Self::builder().build()
    }

    /// Creates a sampler whose RNG is seeded deterministically with `seed`.
    #[must_use]
    pub fn with_seed(seed: u64) -> Self {
        Self::builder().seed(seed).build()
    }

    /// Returns a builder for configuring the initial step size, the
    /// population size, and the RNG seed.
    #[must_use]
    pub fn builder() -> CmaEsSamplerBuilder {
        CmaEsSamplerBuilder::new()
    }
}
impl Default for CmaEsSampler {
    /// Same as [`CmaEsSampler::new`]: an unseeded sampler with default options.
    fn default() -> Self {
        Self::new()
    }
}
/// Builder for [`CmaEsSampler`]; every option is optional and defaults to
/// a value derived from the search space at runtime.
#[derive(Debug, Clone, Default)]
pub struct CmaEsSamplerBuilder {
    /// Initial global step size; when `None` it is derived from the bounds.
    sigma0: Option<f64>,
    /// Population size (lambda); when `None` it is derived from the dimension count.
    population_size: Option<usize>,
    /// RNG seed; when `None` the RNG is seeded from entropy.
    seed: Option<u64>,
}
impl CmaEsSamplerBuilder {
    /// Starts a builder with every option unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the initial global step size (sigma).
    #[must_use]
    pub fn sigma0(self, sigma0: f64) -> Self {
        Self {
            sigma0: Some(sigma0),
            ..self
        }
    }

    /// Sets the population size (lambda) used per generation.
    #[must_use]
    pub fn population_size(self, population_size: usize) -> Self {
        Self {
            population_size: Some(population_size),
            ..self
        }
    }

    /// Sets a deterministic RNG seed.
    #[must_use]
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Consumes the builder and produces the configured sampler.
    #[must_use]
    pub fn build(self) -> CmaEsSampler {
        let Self {
            sigma0,
            population_size,
            seed,
        } = self;
        CmaEsSampler {
            state: Mutex::new(CmaEsState::new(sigma0, population_size, seed)),
        }
    }
}
/// Metadata about one search-space dimension, captured during the first
/// ("discovery") trial.
#[derive(Clone, Debug)]
struct DimensionInfo {
    /// The parameter's distribution as requested by the trial.
    distribution: Distribution,
    /// True for every non-categorical distribution.
    is_continuous: bool,
    /// Bounds in the internal (transformed) space, when the distribution has them.
    bounds: Option<(f64, f64)>,
}
/// One proposed parameter assignment within a generation.
#[derive(Clone, Debug)]
struct Candidate {
    /// Continuous coordinates in internal space; empty when the search
    /// space has no continuous dimensions.
    x: DVector<f64>,
    /// Chosen category index per categorical dimension, keyed by the
    /// dimension's index in `CmaEsState::dimensions`.
    categorical_values: HashMap<usize, usize>,
}
/// Cursor tracking which candidate a trial was assigned and how many
/// dimensions it has consumed so far.
#[derive(Clone, Debug)]
struct TrialProgress {
    /// Index into `CmaEsState::candidates`.
    candidate_idx: usize,
    /// Next dimension index this trial will sample.
    next_dim: usize,
}
/// Fixed CMA-ES strategy parameters, derived once from the dimension count
/// (and optionally a user-chosen population size).
#[derive(Clone, Debug)]
struct CmaEsConstants {
    /// Number of continuous dimensions.
    n: usize,
    /// Population size per generation.
    lambda: usize,
    /// Number of parents used for recombination (lambda / 2).
    mu: usize,
    /// Positive recombination weights, normalized to sum to 1.
    weights: Vec<f64>,
    /// Variance-effective selection mass, 1 / sum(weights^2).
    mu_eff: f64,
    /// Learning rate for the step-size evolution path.
    c_sigma: f64,
    /// Damping factor for step-size adaptation.
    d_sigma: f64,
    /// Learning rate for the covariance evolution path.
    c_c: f64,
    /// Learning rate for the rank-one covariance update.
    c_1: f64,
    /// Learning rate for the rank-mu covariance update.
    c_mu: f64,
    /// Approximation of E||N(0, I)||, used by step-size control.
    chi_n: f64,
}
/// Mutable CMA-ES optimizer state over the continuous subspace.
struct CmaEsAlgorithm {
    /// Current distribution mean (internal space).
    mean: DVector<f64>,
    /// Global step size.
    sigma: f64,
    /// Covariance matrix C.
    c: DMatrix<f64>,
    /// Evolution path for step-size (sigma) adaptation.
    p_sigma: DVector<f64>,
    /// Evolution path for the rank-one covariance update.
    p_c: DVector<f64>,
    /// Eigenvectors of C (columns), so that C = B * diag(d^2) * B^T.
    b: DMatrix<f64>,
    /// Square roots of C's eigenvalues.
    d: DVector<f64>,
    /// Cached C^{-1/2}, refreshed by `update_eigen`.
    inv_sqrt_c: DMatrix<f64>,
    /// Number of completed generation updates.
    generation: usize,
    /// Generation at which the eigendecomposition was last refreshed.
    last_eigen_generation: usize,
    constants: CmaEsConstants,
}
/// Sampler lifecycle: the first trial discovers the search space; every
/// later trial is served by the active CMA-ES optimizer.
enum Phase {
    Discovery,
    // Boxed to keep the enum small; the algorithm holds several matrices.
    Active(Box<CmaEsAlgorithm>),
}
/// All mutable sampler state, guarded by the `Mutex` in [`CmaEsSampler`].
struct CmaEsState {
    /// RNG used for every random draw (candidates, categoricals, fallbacks).
    rng: fastrand::Rng,
    /// User-supplied initial step size, forwarded to `CmaEsAlgorithm::new`.
    sigma0: Option<f64>,
    /// User-supplied population size, forwarded to `CmaEsConstants::new`.
    user_lambda: Option<usize>,
    phase: Phase,
    /// Search-space dimensions in the order the discovery trial requested them.
    dimensions: Vec<DimensionInfo>,
    /// Current generation's population.
    candidates: Vec<Candidate>,
    /// Per-trial sampling cursor, keyed by trial id.
    trial_progress: HashMap<u64, TrialProgress>,
    /// How many candidates have been handed out this generation.
    assigned_count: usize,
    /// Trial ids in assignment order; the first `lambda` form the generation.
    generation_trial_ids: Vec<u64>,
    /// Id of the trial performing discovery, if any.
    discovery_trial_id: Option<u64>,
}
impl CmaEsState {
    /// Builds the initial (discovery-phase) state. A `seed` makes the RNG
    /// deterministic; otherwise it is seeded from entropy.
    fn new(sigma0: Option<f64>, user_lambda: Option<usize>, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => fastrand::Rng::with_seed(s),
            None => fastrand::Rng::new(),
        };
        Self {
            rng,
            sigma0,
            user_lambda,
            phase: Phase::Discovery,
            dimensions: Vec::new(),
            candidates: Vec::new(),
            trial_progress: HashMap::new(),
            assigned_count: 0,
            generation_trial_ids: Vec::new(),
            discovery_trial_id: None,
        }
    }
}
impl CmaEsConstants {
    /// Derives the standard CMA-ES strategy parameters for an `n`-dimensional
    /// continuous subspace, optionally honoring a user-chosen population size.
    ///
    /// Formulas follow the conventional CMA-ES parameterization (Hansen's
    /// tutorial): default lambda = 4 + floor(3 ln n), mu = lambda / 2, and
    /// log-rank recombination weights normalized to sum to 1.
    ///
    /// NOTE(review): for n == 0 several constants (`c_c`, `chi_n`, ...) are
    /// NaN due to division by zero; the pure-categorical code path appears to
    /// consume only `lambda` and `n` — confirm before reusing the others.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    fn new(n: usize, user_lambda: Option<usize>) -> Self {
        let n_f = n as f64;
        // ln(0) = -inf is clamped by .max(0.0), so n == 0 still yields lambda = 4.
        let lambda = user_lambda.unwrap_or_else(|| 4 + (3.0 * n_f.ln()).max(0.0).floor() as usize);
        // A user-supplied lambda is also floored at 4 (minimum viable population).
        let lambda = lambda.max(4);
        let mu = lambda / 2;
        // midpoint(lambda, 1) == (lambda + 1) / 2, the reference rank of the
        // log-decreasing weighting scheme.
        let log_half_lambda = f64::midpoint(lambda as f64, 1.0).ln();
        let raw_weights: Vec<f64> = (0..mu)
            .map(|i| log_half_lambda - ((i + 1) as f64).ln())
            .collect();
        let w_sum: f64 = raw_weights.iter().sum();
        // Normalize so the weights sum to 1.
        let weights: Vec<f64> = raw_weights.iter().map(|w| w / w_sum).collect();
        let w_sq_sum: f64 = weights.iter().map(|w| w * w).sum();
        // Variance-effective selection mass.
        let mu_eff = 1.0 / w_sq_sum;
        let c_sigma = (mu_eff + 2.0) / (n_f + mu_eff + 5.0);
        let d_sigma = 1.0 + 2.0 * (((mu_eff - 1.0) / (n_f + 1.0)).sqrt() - 1.0).max(0.0) + c_sigma;
        let c_c = (4.0 + mu_eff / n_f) / (n_f + 4.0 + 2.0 * mu_eff / n_f);
        let c_1 = 2.0 / ((n_f + 1.3).powi(2) + mu_eff);
        let c_mu_raw = (2.0 * (mu_eff - 2.0 + 1.0 / mu_eff)) / ((n_f + 2.0).powi(2) + mu_eff);
        // Cap so that c_1 + c_mu <= 1 (total covariance learning rate).
        let c_mu = c_mu_raw.min(1.0 - c_1);
        // Series approximation of E||N(0, I_n)||.
        let chi_n = n_f.sqrt() * (1.0 - 1.0 / (4.0 * n_f) + 1.0 / (21.0 * n_f * n_f));
        Self {
            n,
            lambda,
            mu,
            weights,
            mu_eff,
            c_sigma,
            d_sigma,
            c_c,
            c_1,
            c_mu,
            chi_n,
        }
    }
}
impl CmaEsAlgorithm {
    /// Initializes CMA-ES for the discovered search space.
    ///
    /// The mean starts at the midpoint of every bounded continuous dimension
    /// and the covariance at identity. Absent an explicit `sigma0`, the step
    /// size defaults to a quarter of the average bound range.
    #[allow(clippy::cast_precision_loss)]
    fn new(dimensions: &[DimensionInfo], sigma0: Option<f64>, user_lambda: Option<usize>) -> Self {
        let n = dimensions.iter().filter(|d| d.is_continuous).count();
        let constants = CmaEsConstants::new(n, user_lambda);
        let mut mean = DVector::zeros(n);
        let mut total_range = 0.0;
        let mut ci = 0;
        for dim in dimensions {
            if dim.is_continuous {
                if let Some((lo, hi)) = dim.bounds {
                    mean[ci] = f64::midpoint(lo, hi);
                    total_range += hi - lo;
                }
                ci += 1;
            }
        }
        let sigma = sigma0.unwrap_or_else(|| {
            if n > 0 {
                let quarter_avg_range = (total_range / n as f64) / 4.0;
                // Bug fix: if no continuous dimension reports bounds,
                // `total_range` stays 0.0 and the previous default collapsed
                // sigma to 0, making every sampled point degenerate onto the
                // mean. Guard against zero or non-finite defaults.
                if quarter_avg_range.is_finite() && quarter_avg_range > 0.0 {
                    quarter_avg_range
                } else {
                    1.0
                }
            } else {
                1.0
            }
        });
        // C starts as the identity, so B = I, d = 1, and C^{-1/2} = I too.
        let c = DMatrix::identity(n, n);
        let p_sigma = DVector::zeros(n);
        let p_c = DVector::zeros(n);
        let b = DMatrix::identity(n, n);
        let d = DVector::from_element(n, 1.0);
        let inv_sqrt_c = DMatrix::identity(n, n);
        Self {
            mean,
            sigma,
            c,
            p_sigma,
            p_c,
            b,
            d,
            inv_sqrt_c,
            generation: 0,
            last_eigen_generation: 0,
            constants,
        }
    }

    /// Draws a full population of `lambda` candidates for one generation.
    fn generate_candidates(
        &self,
        rng: &mut fastrand::Rng,
        dimensions: &[DimensionInfo],
    ) -> Vec<Candidate> {
        let n = self.constants.n;
        let lambda = self.constants.lambda;
        let mut candidates = Vec::with_capacity(lambda);
        for _ in 0..lambda {
            let candidate = self.generate_single_candidate(rng, dimensions, n);
            candidates.push(candidate);
        }
        candidates
    }

    /// Draws one candidate: continuous coordinates come from the current
    /// search distribution; each categorical dimension gets an independent
    /// uniform draw.
    fn generate_single_candidate(
        &self,
        rng: &mut fastrand::Rng,
        dimensions: &[DimensionInfo],
        n: usize,
    ) -> Candidate {
        let x = self.sample_with_rejection(rng, dimensions, n);
        let mut categorical_values = HashMap::new();
        for (i, dim) in dimensions.iter().enumerate() {
            if !dim.is_continuous
                && let Distribution::Categorical(cat) = &dim.distribution
            {
                categorical_values.insert(i, rng.usize(0..cat.n_choices));
            }
        }
        Candidate {
            x,
            categorical_values,
        }
    }

    /// Samples x = mean + sigma * B * D * z with z ~ N(0, I), rejecting draws
    /// that violate the bounds. After 100 failed attempts, one final draw is
    /// clipped into the bounds instead, guaranteeing termination.
    fn sample_with_rejection(
        &self,
        rng: &mut fastrand::Rng,
        dimensions: &[DimensionInfo],
        n: usize,
    ) -> DVector<f64> {
        let max_attempts = 100;
        for _ in 0..max_attempts {
            let z = DVector::from_fn(n, |_, _| sample_standard_normal(rng));
            let x = &self.mean + self.sigma * (&self.b * self.d.component_mul(&z));
            if is_within_bounds(&x, dimensions) {
                return x;
            }
        }
        let z = DVector::from_fn(n, |_, _| sample_standard_normal(rng));
        let mut x = &self.mean + self.sigma * (&self.b * self.d.component_mul(&z));
        clip_to_bounds(&mut x, dimensions);
        x
    }

    /// Applies one CMA-ES generation update.
    ///
    /// `ranked_candidates` must be sorted best-first (ascending objective
    /// value for minimization). Updates, in order: recombined mean, step-size
    /// path `p_sigma`, stall indicator `h_sigma`, covariance path `p_c`, the
    /// covariance matrix (rank-one + rank-mu), and the global step size.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_possible_wrap
    )]
    fn update(&mut self, ranked_candidates: &[&DVector<f64>]) {
        let n = self.constants.n;
        let mu = self.constants.mu;
        let sigma = self.sigma;
        // Weighted recombination of the mu best candidates.
        let mut new_mean = DVector::zeros(n);
        for (i, &x) in ranked_candidates.iter().take(mu).enumerate() {
            new_mean += self.constants.weights[i] * x;
        }
        let mean_diff = &new_mean - &self.mean;
        // Cumulate the (whitened) mean shift into the step-size path.
        let inv_sqrt_c_times_diff = &self.inv_sqrt_c * &mean_diff / sigma;
        self.p_sigma = (1.0 - self.constants.c_sigma) * &self.p_sigma
            + (self.constants.c_sigma * (2.0 - self.constants.c_sigma) * self.constants.mu_eff)
                .sqrt()
                * &inv_sqrt_c_times_diff;
        let p_sigma_norm = self.p_sigma.norm();
        // Heaviside-style stall test: suppress the rank-one update while the
        // step-size path is unusually long.
        let threshold =
            (1.0 - (1.0 - self.constants.c_sigma).powi(2 * (self.generation as i32 + 1))).sqrt()
                * (1.4 + 2.0 / (n as f64 + 1.0))
                * self.constants.chi_n;
        let h_sigma = if p_sigma_norm < threshold { 1.0 } else { 0.0 };
        self.p_c = (1.0 - self.constants.c_c) * &self.p_c
            + h_sigma
                * (self.constants.c_c * (2.0 - self.constants.c_c) * self.constants.mu_eff).sqrt()
                * &mean_diff
                / sigma;
        // delta_h compensates the variance lost when h_sigma suppressed p_c.
        let delta_h = (1.0 - h_sigma) * self.constants.c_c * (2.0 - self.constants.c_c);
        let old_c_weight =
            1.0 - self.constants.c_1 - self.constants.c_mu + self.constants.c_1 * delta_h;
        let rank_one = self.constants.c_1 * &self.p_c * self.p_c.transpose();
        // Rank-mu update from the mu best steps (in the pre-update frame).
        let mut rank_mu = DMatrix::zeros(n, n);
        for (i, &x) in ranked_candidates.iter().take(mu).enumerate() {
            let y = (x - &self.mean) / sigma;
            rank_mu += self.constants.weights[i] * &y * y.transpose();
        }
        let rank_mu = self.constants.c_mu * rank_mu;
        self.c = old_c_weight * &self.c + rank_one + rank_mu;
        // Step-size control: grow sigma when the path is longer than expected
        // under random selection, shrink it otherwise.
        self.sigma *= ((self.constants.c_sigma / self.constants.d_sigma)
            * (p_sigma_norm / self.constants.chi_n - 1.0))
            .exp();
        // Clamp to keep sigma strictly positive and finite.
        self.sigma = self.sigma.clamp(1e-20, 1e10);
        self.mean = new_mean;
        self.generation += 1;
        // Lazy eigendecomposition: O(n^3) work amortized over generations.
        let eigen_interval = (n / 10).max(1);
        if self.generation - self.last_eigen_generation >= eigen_interval {
            self.update_eigen();
        }
    }

    /// Refreshes the cached eigendecomposition C = B * diag(d^2) * B^T and
    /// C^{-1/2}, flooring eigenvalues at 1e-20 for numerical safety.
    fn update_eigen(&mut self) {
        let n = self.constants.n;
        // Re-symmetrize first; accumulated rounding can break symmetry.
        self.c = (&self.c + self.c.transpose()) / 2.0;
        let eigen = self.c.clone().symmetric_eigen();
        let eigenvalues = &eigen.eigenvalues;
        let eigenvectors = &eigen.eigenvectors;
        let mut d_vec = DVector::zeros(n);
        for i in 0..n {
            d_vec[i] = eigenvalues[i].max(1e-20).sqrt();
        }
        self.b = eigenvectors.clone();
        self.d = d_vec;
        // C^{-1/2} = B * diag(1/d) * B^T.
        let d_inv = DVector::from_fn(n, |i, _| 1.0 / self.d[i]);
        let d_inv_diag = DMatrix::from_diagonal(&d_inv);
        self.inv_sqrt_c = &self.b * d_inv_diag * self.b.transpose();
        self.last_eigen_generation = self.generation;
    }
}
fn is_within_bounds(x: &DVector<f64>, dimensions: &[DimensionInfo]) -> bool {
let mut ci = 0;
for dim in dimensions {
if dim.is_continuous {
if let Some((lo, hi)) = dim.bounds
&& (x[ci] < lo || x[ci] > hi)
{
return false;
}
ci += 1;
}
}
true
}
/// Clamps every bounded continuous coordinate of `x` into its dimension's
/// bounds, in place. Unbounded dimensions are left untouched.
fn clip_to_bounds(x: &mut DVector<f64>, dimensions: &[DimensionInfo]) {
    let continuous = dimensions.iter().filter(|dim| dim.is_continuous);
    for (ci, dim) in continuous.enumerate() {
        if let Some((lo, hi)) = dim.bounds {
            x[ci] = x[ci].clamp(lo, hi);
        }
    }
}
/// Draws one standard-normal value via the Box-Muller transform.
///
/// `u1` is kept away from zero (lower bound `f64::EPSILON`) so `ln(u1)`
/// stays finite; only the cosine branch of the transform is used.
fn sample_standard_normal(rng: &mut fastrand::Rng) -> f64 {
    let u1: f64 = rng_util::f64_range(rng, f64::EPSILON, 1.0);
    let theta: f64 = rng_util::f64_range(rng, 0.0_f64, core::f64::consts::TAU);
    let radius = (-2.0 * u1.ln()).sqrt();
    radius * theta.cos()
}
impl Sampler for CmaEsSampler {
    /// Samples one parameter value for `trial_id`.
    ///
    /// The first trial "discovers" the search space: each distribution it
    /// requests is recorded and answered with a random draw. Every later
    /// trial is served from the active CMA-ES population.
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
    fn sample(
        &self,
        distribution: &Distribution,
        trial_id: u64,
        history: &[CompletedTrial],
    ) -> ParamValue {
        // All optimizer state sits behind one lock, so concurrent callers
        // are serialized for the duration of a single sample.
        let mut state = self.state.lock();
        match &state.phase {
            Phase::Discovery => sample_discovery(&mut state, distribution, trial_id),
            Phase::Active(_) => sample_active(&mut state, distribution, trial_id, history),
        }
    }
}
/// Handles a sample request while the search space is still being discovered.
///
/// The first trial records one `DimensionInfo` per requested distribution and
/// answers with random draws. As soon as a *different* trial id arrives, the
/// search space is frozen and CMA-ES takes over.
fn sample_discovery(
    state: &mut CmaEsState,
    distribution: &Distribution,
    trial_id: u64,
) -> ParamValue {
    match state.discovery_trial_id {
        // A new trial has started: discovery is complete.
        Some(prev_id) if trial_id != prev_id => {
            finalize_discovery(state);
            sample_active(state, distribution, trial_id, &[])
        }
        // Still inside the discovery trial (or its very first request):
        // record this dimension and answer randomly.
        _ => {
            state.discovery_trial_id = Some(trial_id);
            state.dimensions.push(DimensionInfo {
                distribution: distribution.clone(),
                is_continuous: !matches!(distribution, Distribution::Categorical(_)),
                bounds: internal_bounds(distribution),
            });
            sample_random(&mut state.rng, distribution)
        }
    }
}
/// Freezes the discovered search space, builds the CMA-ES optimizer, and
/// generates the first population.
fn finalize_discovery(state: &mut CmaEsState) {
    let algo = CmaEsAlgorithm::new(&state.dimensions, state.sigma0, state.user_lambda);
    let has_continuous = state.dimensions.iter().any(|d| d.is_continuous);
    state.candidates = if has_continuous {
        algo.generate_candidates(&mut state.rng, &state.dimensions)
    } else {
        // No continuous subspace: candidates are purely categorical draws.
        generate_pure_categorical_candidates(
            &mut state.rng,
            &state.dimensions,
            algo.constants.lambda,
        )
    };
    state.assigned_count = 0;
    state.generation_trial_ids.clear();
    state.trial_progress.clear();
    state.phase = Phase::Active(Box::new(algo));
}
fn generate_pure_categorical_candidates(
rng: &mut fastrand::Rng,
dimensions: &[DimensionInfo],
lambda: usize,
) -> Vec<Candidate> {
(0..lambda)
.map(|_| {
let mut categorical_values = HashMap::new();
for (i, dim) in dimensions.iter().enumerate() {
if let Distribution::Categorical(cat) = &dim.distribution {
categorical_values.insert(i, rng.usize(0..cat.n_choices));
}
}
Candidate {
x: DVector::zeros(0),
categorical_values,
}
})
.collect()
}
/// Serves a sample request from the active CMA-ES population.
///
/// Each trial is bound to one candidate and walks the discovered dimensions
/// in order; requests beyond the discovered dimensions, or beyond the
/// candidate's continuous vector, fall back to plain random sampling.
fn sample_active(
    state: &mut CmaEsState,
    distribution: &Distribution,
    trial_id: u64,
    history: &[CompletedTrial],
) -> ParamValue {
    // Advance to the next generation first if the current one is complete.
    maybe_update_generation(state, history);
    if !state.trial_progress.contains_key(&trial_id) {
        assign_candidate(state, trial_id);
    }
    let progress = state.trial_progress.get_mut(&trial_id).unwrap();
    let candidate_idx = progress.candidate_idx;
    let dim_idx = progress.next_dim;
    progress.next_dim += 1;
    // The trial asked for more parameters than discovery recorded.
    if dim_idx >= state.dimensions.len() {
        return sample_random(&mut state.rng, distribution);
    }
    let dim_info = &state.dimensions[dim_idx];
    let candidate = &state.candidates[candidate_idx];
    if !dim_info.is_continuous {
        return match candidate.categorical_values.get(&dim_idx) {
            Some(&cat_idx) => ParamValue::Categorical(cat_idx),
            None => sample_random(&mut state.rng, distribution),
        };
    }
    // Map the dimension index onto the candidate's continuous coordinate by
    // counting the continuous dimensions that precede it.
    let ci = state.dimensions[..dim_idx]
        .iter()
        .filter(|d| d.is_continuous)
        .count();
    if ci < candidate.x.len() {
        from_internal(candidate.x[ci], &dim_info.distribution)
    } else {
        sample_random(&mut state.rng, distribution)
    }
}
/// Binds `trial_id` to the next unassigned candidate, generating an extra
/// overflow candidate when the whole population is already taken.
fn assign_candidate(state: &mut CmaEsState, trial_id: u64) {
    let candidate_idx;
    if state.assigned_count < state.candidates.len() {
        candidate_idx = state.assigned_count;
        state.assigned_count += 1;
    } else {
        // More concurrent trials than lambda: grow the population on demand.
        let extra = generate_overflow_candidate(state);
        state.candidates.push(extra);
        candidate_idx = state.candidates.len() - 1;
        state.assigned_count = state.candidates.len();
    }
    state.trial_progress.insert(
        trial_id,
        TrialProgress {
            candidate_idx,
            next_dim: 0,
        },
    );
    state.generation_trial_ids.push(trial_id);
}
/// Creates one extra candidate when more trials run concurrently than the
/// population holds.
fn generate_overflow_candidate(state: &mut CmaEsState) -> Candidate {
    let n_continuous = state.dimensions.iter().filter(|d| d.is_continuous).count();
    if n_continuous == 0 {
        // Pure categorical space: one independent uniform draw per dimension.
        let rng = &mut state.rng;
        let categorical_values = state
            .dimensions
            .iter()
            .enumerate()
            .filter_map(|(i, dim)| match &dim.distribution {
                Distribution::Categorical(cat) => Some((i, rng.usize(0..cat.n_choices))),
                _ => None,
            })
            .collect();
        return Candidate {
            x: DVector::zeros(0),
            categorical_values,
        };
    }
    if let Phase::Active(algo) = &state.phase {
        algo.generate_single_candidate(&mut state.rng, &state.dimensions, n_continuous)
    } else {
        // Discovery phase never assigns candidates; this arm only keeps the
        // match total and returns a neutral placeholder.
        Candidate {
            x: DVector::zeros(n_continuous),
            categorical_values: HashMap::new(),
        }
    }
}
/// Advances to the next generation once the first `lambda` trials of the
/// current one have all completed: ranks their candidates by objective value
/// (ascending, i.e. minimization), applies the CMA-ES update, and generates
/// a fresh population.
fn maybe_update_generation(state: &mut CmaEsState, history: &[CompletedTrial]) {
    // Only meaningful once CMA-ES is active.
    let Phase::Active(algo) = &state.phase else {
        return;
    };
    let lambda = algo.constants.lambda;
    let n_continuous = algo.constants.n;
    // Not enough trials assigned yet to form a full generation.
    if state.generation_trial_ids.len() < lambda {
        return;
    }
    // The generation consists of the first lambda assigned trials; overflow
    // trials beyond lambda do not gate the update.
    let trial_ids: Vec<u64> = state
        .generation_trial_ids
        .iter()
        .take(lambda)
        .copied()
        .collect();
    // Map trial id -> objective value for quick completion lookup.
    let history_ids: HashMap<u64, f64> = history.iter().map(|t| (t.id, t.value)).collect();
    let all_completed = trial_ids.iter().all(|id| history_ids.contains_key(id));
    if !all_completed {
        return;
    }
    // Pure categorical space: there is nothing to adapt, just resample.
    if n_continuous == 0 {
        state.candidates =
            generate_pure_categorical_candidates(&mut state.rng, &state.dimensions, lambda);
        state.assigned_count = 0;
        state.generation_trial_ids.clear();
        state.trial_progress.clear();
        return;
    }
    // Pair each trial's candidate vector with its objective value.
    let mut ranked: Vec<(&DVector<f64>, f64)> = trial_ids
        .iter()
        .filter_map(|id| {
            let progress = state.trial_progress.get(id)?;
            let value = *history_ids.get(id)?;
            Some((&state.candidates[progress.candidate_idx].x, value))
        })
        .collect();
    // Best (lowest value) first; NaN comparisons fall back to Equal.
    ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal));
    let ranked_xs: Vec<&DVector<f64>> = ranked.iter().map(|(x, _)| *x).collect();
    // Re-borrow the phase mutably now that the immutable borrows above are
    // done; this cannot fail given the check at the top.
    let Phase::Active(algo) = &mut state.phase else {
        return;
    };
    algo.update(&ranked_xs);
    let new_candidates = algo.generate_candidates(&mut state.rng, &state.dimensions);
    state.candidates = new_candidates;
    state.assigned_count = 0;
    state.generation_trial_ids.clear();
    state.trial_progress.clear();
}