use std::marker::PhantomData;
use crate::core::constraint::BoxConstrained;
use crate::core::executor::run_loop;
use crate::core::math::{
ComponentMulAssign, MatTransposeVec, MatVec, MatrixIdentity, NormSquared, RankOneUpdate,
SampleStandardNormal, SampleUniformBox, ScaleInPlace, ScaledAdd, SymmetricEigen, VectorLen,
};
use crate::core::problem::CostFunction;
use crate::core::rng::{ChaCha8Rng, Rng, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::{BasicPopulationState, PopulationState, State};
use crate::core::termination::{MaxCostEvals, TerminationCriterion, TerminationReason};
use crate::solver::cma_es::CmaEs;
use crate::solver::ssga::{
bga_mutate_in_place, blx_alpha_crossover, nam_select, replace_worst_if_better,
};
/// One local-search chain slot: a `CmaEs` solver instance paired with the
/// `BasicPopulationState` it was last run with, kept so the run can be resumed.
type ChainSlot<V, M> = (CmaEs<V, M>, BasicPopulationState<V>);
/// Solver state for `MaLsChCma`.
///
/// The five `Vec` fields are parallel arrays indexed by population slot; they are
/// always permuted together so that slot `i` refers to the same candidate in each.
pub struct MaLsChState<V, M> {
/// Candidate parameter vectors, one per population slot.
pub(crate) candidates: Vec<V>,
/// Cost of each candidate (same indexing as `candidates`).
pub(crate) costs: Vec<f64>,
/// Per-slot local-search chain; `None` until a local search has run on the slot,
/// and reset to `None` when the slot's candidate is replaced.
pub(crate) cma_chains: Vec<Option<ChainSlot<V, M>>>,
/// Per-slot cost bookkeeping compared against the current cost when deciding
/// local-search eligibility; `f64::INFINITY` until a local search has run on the slot.
pub(crate) last_ls_cost: Vec<f64>,
/// Number of local-search runs applied to each slot; reset to 0 on replacement.
pub(crate) ls_application_count: Vec<u32>,
/// Outer iteration counter.
iter: u64,
/// Total number of cost-function evaluations accounted to this state.
cost_evals: u64,
}
impl<V, M> MaLsChState<V, M> {
    /// Returns how many local-search runs have been recorded for population slot `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not a valid slot index.
    pub fn ls_application_count(&self, i: usize) -> u32 {
        let counts = &self.ls_application_count;
        counts[i]
    }
}
/// `State` view of the memetic population: counters plus the candidate in slot 0.
impl<V, M> State for MaLsChState<V, M> {
    type Param = V;
    type Float = f64;

    // --- read-only accessors -------------------------------------------------

    /// Number of outer iterations recorded so far.
    fn iter(&self) -> u64 {
        self.iter
    }

    /// Number of cost-function evaluations recorded so far.
    fn cost_evals(&self) -> u64 {
        self.cost_evals
    }

    /// Parameter vector stored in slot 0 (the lowest-cost candidate whenever the
    /// population has just been sorted by cost). Panics on an empty population.
    fn param(&self) -> &V {
        let population = &self.candidates;
        &population[0]
    }

    /// Cost stored in slot 0. Panics on an empty population.
    fn cost(&self) -> f64 {
        let costs = &self.costs;
        costs[0]
    }

    // --- counters ------------------------------------------------------------

    /// Advances the iteration counter by one.
    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    /// Adds `by` to the evaluation counter.
    fn increment_cost_evals(&mut self, by: u64) {
        self.cost_evals += by;
    }
}
/// Exposes the whole population and its costs as parallel slices.
impl<V, M> PopulationState for MaLsChState<V, M> {
    /// All candidate parameter vectors, in slot order.
    fn candidates(&self) -> &[V] {
        self.candidates.as_slice()
    }

    /// Cost of every candidate, in the same slot order as `candidates`.
    fn costs(&self) -> &[f64] {
        self.costs.as_slice()
    }
}
impl<V, M> MaLsChState<V, M> {
    /// Creates an empty state: no candidates and both counters at zero.
    ///
    /// The population itself is filled in later by the solver's `init`.
    pub fn new() -> Self {
        let (iter, cost_evals) = (0, 0);
        Self {
            iter,
            cost_evals,
            candidates: vec![],
            costs: vec![],
            cma_chains: vec![],
            last_ls_cost: vec![],
            ls_application_count: vec![],
        }
    }
}
impl<V, M> Default for MaLsChState<V, M> {
fn default() -> Self {
Self::new()
}
}
/// Memetic solver configuration: a steady-state genetic algorithm combined with
/// `CmaEs` local-search runs that are stored per candidate and can be resumed.
pub struct MaLsChCma<V, M> {
/// Number of candidates sampled and evaluated in `init`.
pop_size: usize,
/// `alpha` argument forwarded to `blx_alpha_crossover`.
blx_alpha: f64,
/// Pool size forwarded to `nam_select` for parent selection.
nam_pool: usize,
/// Mutation probability forwarded to `bga_mutate_in_place`.
mutation_prob: f64,
/// Range fraction forwarded to `bga_mutate_in_place`.
bga_range_fraction: f64,
/// Cost-evaluation budget (`MaxCostEvals`) of each local-search run.
ls_intensity: u64,
/// Minimum value of `last_ls_cost - cost` for a slot that already has a chain to
/// be eligible for another local-search run.
ls_improvement_threshold: f64,
/// Cost evaluations spent in the genetic phase of each iteration;
/// falls back to `ls_intensity` when `None`.
nfrec: Option<u64>,
/// Optional population size override passed to `CmaEs::with_lambda`;
/// `CmaEs::default_lambda` is used when `None`.
inner_lambda: Option<usize>,
/// Initial step size used when the nearest-neighbour estimate is unavailable or not positive.
initial_sigma_fallback: f64,
/// Seed for the solver's `ChaCha8Rng`.
seed: u64,
/// Random generator, created in `init`; `None` before `init` has run.
rng: Option<ChaCha8Rng>,
/// Ties the solver to its vector/matrix types without storing values of them.
_phantom: PhantomData<(V, M)>,
}
/// Construction and builder-style configuration.
///
/// Every `with_*` method consumes the solver, validates its argument with an
/// assertion (panicking on violation) and returns the updated solver.
impl<V, M> MaLsChCma<V, M> {
    /// Creates a solver with the default hyper-parameters and the given RNG seed.
    pub fn new(seed: u64) -> Self {
        Self {
            seed,
            rng: None,
            _phantom: PhantomData,
            // Genetic-algorithm phase.
            pop_size: 60,
            nam_pool: 4,
            blx_alpha: 0.5,
            mutation_prob: 0.125,
            bga_range_fraction: 0.1,
            nfrec: None,
            // Local-search phase.
            ls_intensity: 300,
            ls_improvement_threshold: 1e-8,
            inner_lambda: None,
            initial_sigma_fallback: 1.0,
        }
    }

    /// Sets the population size.
    ///
    /// # Panics
    ///
    /// Panics if `pop_size` is smaller than the currently configured `nam_pool`.
    pub fn with_pop_size(self, pop_size: usize) -> Self {
        assert!(
            pop_size >= self.nam_pool,
            "MaLsChCma requires pop_size >= nam_pool (got pop_size={}, nam_pool={})",
            pop_size,
            self.nam_pool
        );
        Self { pop_size, ..self }
    }

    /// Sets the BLX-alpha crossover parameter.
    ///
    /// # Panics
    ///
    /// Panics unless `alpha >= 0`.
    pub fn with_blx_alpha(self, alpha: f64) -> Self {
        assert!(alpha >= 0.0, "blx_alpha must be >= 0, got {}", alpha);
        Self {
            blx_alpha: alpha,
            ..self
        }
    }

    /// Sets the pool size used for parent selection.
    ///
    /// # Panics
    ///
    /// Panics unless `pool >= 2`.
    pub fn with_nam_pool(self, pool: usize) -> Self {
        assert!(pool >= 2, "nam_pool must be >= 2, got {}", pool);
        Self {
            nam_pool: pool,
            ..self
        }
    }

    /// Sets the mutation probability.
    ///
    /// # Panics
    ///
    /// Panics unless `p` lies in `[0, 1]`.
    pub fn with_mutation_prob(self, p: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&p),
            "mutation_prob must be in [0, 1], got {}",
            p
        );
        Self {
            mutation_prob: p,
            ..self
        }
    }

    /// Sets the range fraction passed to the BGA mutation.
    ///
    /// # Panics
    ///
    /// Panics unless `f > 0`.
    pub fn with_bga_range_fraction(self, f: f64) -> Self {
        assert!(f > 0.0, "bga_range_fraction must be > 0, got {}", f);
        Self {
            bga_range_fraction: f,
            ..self
        }
    }

    /// Sets the cost-evaluation budget of each local-search run.
    ///
    /// # Panics
    ///
    /// Panics unless `istr >= 1`.
    pub fn with_ls_intensity(self, istr: u64) -> Self {
        assert!(istr >= 1, "ls_intensity must be >= 1, got {}", istr);
        Self {
            ls_intensity: istr,
            ..self
        }
    }

    /// Sets the improvement threshold used for local-search eligibility.
    ///
    /// # Panics
    ///
    /// Panics unless `delta >= 0`.
    pub fn with_ls_improvement_threshold(self, delta: f64) -> Self {
        assert!(
            delta >= 0.0,
            "ls_improvement_threshold must be >= 0, got {}",
            delta
        );
        Self {
            ls_improvement_threshold: delta,
            ..self
        }
    }

    /// Sets the number of cost evaluations spent in the genetic phase per iteration.
    ///
    /// # Panics
    ///
    /// Panics unless `n >= 1`.
    pub fn with_nfrec(self, n: u64) -> Self {
        assert!(n >= 1, "nfrec must be >= 1, got {}", n);
        Self {
            nfrec: Some(n),
            ..self
        }
    }

    /// Overrides the population size of the inner `CmaEs` runs.
    ///
    /// # Panics
    ///
    /// Panics unless `lambda >= 4`.
    pub fn with_inner_lambda(self, lambda: usize) -> Self {
        assert!(lambda >= 4, "inner_lambda must be >= 4, got {}", lambda);
        Self {
            inner_lambda: Some(lambda),
            ..self
        }
    }

    /// Sets the step size used when no positive nearest-neighbour estimate exists.
    ///
    /// # Panics
    ///
    /// Panics unless `sigma > 0`.
    pub fn with_initial_sigma_fallback(self, sigma: f64) -> Self {
        assert!(
            sigma > 0.0,
            "initial_sigma_fallback must be > 0, got {}",
            sigma
        );
        Self {
            initial_sigma_fallback: sigma,
            ..self
        }
    }
}
/// Half the Euclidean distance from `candidates[i]` to its nearest other candidate.
///
/// Returns `None` when there are fewer than two candidates (no neighbour exists).
/// The result may be `0.0` when another candidate coincides with `candidates[i]`.
///
/// # Panics
///
/// Panics if there are at least two candidates and `i` is out of bounds.
fn sigma_init_for<V>(candidates: &[V], i: usize) -> Option<f64>
where
    V: Clone + ScaledAdd<f64> + NormSquared,
{
    if candidates.len() < 2 {
        return None;
    }
    let center = &candidates[i];
    let nearest_sq = candidates
        .iter()
        .enumerate()
        .filter(|&(j, _)| j != i)
        .map(|(_, other)| {
            // center - other, computed on a scratch copy.
            let mut delta = center.clone();
            delta.scaled_add(-1.0, other);
            delta.norm_squared()
        })
        // Strict `<` keeps the running minimum when a distance is NaN.
        .fold(f64::INFINITY, |best, d_sq| if d_sq < best { d_sq } else { best });
    Some(0.5 * nearest_sq.sqrt())
}
impl<P, V, M> Solver<P, MaLsChState<V, M>> for MaLsChCma<V, M>
where
P: CostFunction<Param = V, Output = f64> + BoxConstrained<Param = V>,
V: VectorLen
+ Clone
+ SampleUniformBox
+ SampleStandardNormal
+ ScaledAdd<f64>
+ ScaleInPlace
+ ComponentMulAssign
+ NormSquared
+ std::ops::Index<usize, Output = f64>
+ std::ops::IndexMut<usize, Output = f64>,
M: MatrixIdentity
+ MatVec<V>
+ MatTransposeVec<V>
+ ScaleInPlace
+ RankOneUpdate<V>
+ SymmetricEigen<V>
+ Clone,
{
fn init(&mut self, problem: &P, mut state: MaLsChState<V, M>) -> MaLsChState<V, M> {
let lo = problem.lower();
let hi = problem.upper();
let mut rng = ChaCha8Rng::seed_from_u64(self.seed);
state.candidates.clear();
state.costs.clear();
state.cma_chains.clear();
state.last_ls_cost.clear();
state.ls_application_count.clear();
for _ in 0..self.pop_size {
let x = V::sample_uniform_box(lo, hi, &mut rng);
let c = problem.cost(&x);
state.candidates.push(x);
state.costs.push(c);
state.cma_chains.push(None);
state.last_ls_cost.push(f64::INFINITY);
state.ls_application_count.push(0);
}
state.cost_evals += self.pop_size as u64;
sort_parallel_arrays(&mut state);
self.rng = Some(rng);
state
}
fn next_iter(
&mut self,
problem: &P,
mut state: MaLsChState<V, M>,
) -> (MaLsChState<V, M>, Option<TerminationReason>) {
let rng = self
.rng
.as_mut()
.expect("MaLsChCma::init must run before next_iter");
let lo = problem.lower();
let hi = problem.upper();
let nfrec = self.nfrec.unwrap_or(self.ls_intensity);
let evals_at_phase_start = state.cost_evals;
while state.cost_evals - evals_at_phase_start < nfrec {
let (p1, p2) = nam_select(&state.candidates, self.nam_pool, rng);
let mut child = blx_alpha_crossover(
&state.candidates[p1],
&state.candidates[p2],
self.blx_alpha,
lo,
hi,
rng,
);
bga_mutate_in_place(
&mut child,
lo,
hi,
self.mutation_prob,
self.bga_range_fraction,
rng,
);
let c_child = problem.cost(&child);
state.cost_evals += 1;
if let Some(replaced_idx) =
replace_worst_if_better(&mut state.candidates, &mut state.costs, child, c_child)
{
state.cma_chains[replaced_idx] = None;
state.last_ls_cost[replaced_idx] = f64::INFINITY;
state.ls_application_count[replaced_idx] = 0;
}
}
sort_parallel_arrays(&mut state);
let mut c_ls: Option<usize> = None;
let mut best_cost_in_s_ls = f64::INFINITY;
for i in 0..state.candidates.len() {
let eligible = state.cma_chains[i].is_none()
|| (state.last_ls_cost[i] - state.costs[i] >= self.ls_improvement_threshold);
if eligible && state.costs[i] < best_cost_in_s_ls {
best_cost_in_s_ls = state.costs[i];
c_ls = Some(i);
}
}
let c_ls = c_ls.unwrap_or(0);
let (mut cma, inner_state) = match state.cma_chains[c_ls].take() {
Some((cma, inner_state)) => {
let mut s = inner_state;
s.cost_evals = 0;
s.iter = 0;
(cma, s)
}
None => {
let n = state.candidates[c_ls].vec_len();
let sigma_init = sigma_init_for(&state.candidates, c_ls)
.filter(|s| *s > 0.0)
.unwrap_or(self.initial_sigma_fallback);
let derived_seed = rng.random::<u64>();
let mut cma =
CmaEs::<V, M>::new(state.candidates[c_ls].clone(), sigma_init, derived_seed);
if let Some(lam) = self.inner_lambda {
cma = cma.with_lambda(lam);
}
let lambda = self
.inner_lambda
.unwrap_or(CmaEs::<V, M>::default_lambda(n));
let inner_state = BasicPopulationState::<V>::with_size(lambda);
(cma, inner_state)
}
};
let mut criteria: Vec<Box<dyn TerminationCriterion<BasicPopulationState<V>>>> =
vec![Box::new(MaxCostEvals(self.ls_intensity))];
let inner_result = run_loop(problem, inner_state, &mut cma, &mut criteria, u64::MAX);
state.cost_evals += inner_result.state.cost_evals();
if inner_result.reason.is_failure() {
return (state, Some(inner_result.reason));
}
let new_cost = inner_result.cost();
let new_param = inner_result.param().clone();
if new_cost < state.costs[c_ls] {
state.candidates[c_ls] = new_param;
state.costs[c_ls] = new_cost;
}
state.last_ls_cost[c_ls] = state.costs[c_ls];
state.ls_application_count[c_ls] = state.ls_application_count[c_ls].saturating_add(1);
state.cma_chains[c_ls] = Some((cma, inner_result.state));
sort_parallel_arrays(&mut state);
(state, None)
}
}
/// Sorts the population by ascending cost, permuting all five parallel arrays
/// (`candidates`, `costs`, `cma_chains`, `last_ls_cost`, `ls_application_count`)
/// with the same permutation so each slot keeps its own bookkeeping.
///
/// The sort is stable. NaN costs are ordered after every numeric cost, so a failed
/// evaluation can never end up in slot 0 ahead of a valid candidate.
fn sort_parallel_arrays<V, M>(state: &mut MaLsChState<V, M>) {
    let n = state.candidates.len();
    debug_assert_eq!(n, state.costs.len());
    debug_assert_eq!(n, state.cma_chains.len());
    debug_assert_eq!(n, state.last_ls_cost.len());
    debug_assert_eq!(n, state.ls_application_count.len());

    // idx[k] = current slot of the candidate that must end up in slot k.
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&i, &j| {
        let (a, b) = (state.costs[i], state.costs[j]);
        // `partial_cmp` is `None` only when a NaN is involved. Treating that as
        // `Equal` is not a total order, and `sort_by` may panic or return an
        // arbitrary order in that case. Rank NaN above every number instead
        // (`false < true`); two NaNs compare equal.
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    });

    apply_permutation::<V>(&mut state.candidates, &idx);
    apply_permutation::<f64>(&mut state.costs, &idx);
    apply_permutation::<Option<ChainSlot<V, M>>>(&mut state.cma_chains, &idx);
    apply_permutation::<f64>(&mut state.last_ls_cost, &idx);
    apply_permutation::<u32>(&mut state.ls_application_count, &idx);
}
/// Rearranges `slice` in place so that afterwards `slice[k]` holds the element that
/// was at `slice[idx[k]]` before the call.
///
/// `idx` is expected to be a permutation of `0..slice.len()`. Each cycle of the
/// permutation is walked once using swaps, so no element is cloned.
fn apply_permutation<T>(slice: &mut [T], idx: &[usize]) {
    let len = slice.len();
    let mut done = vec![false; len];
    for head in 0..len {
        if done[head] {
            continue;
        }
        done[head] = true;
        // Walk the cycle that starts at `head`; a fixed point ends immediately.
        let mut pos = head;
        let mut src = idx[pos];
        while src != head {
            // Pull the wanted element into `pos`; the displaced one travels along
            // the cycle and lands in its final place at the last step.
            slice.swap(pos, src);
            pos = src;
            done[pos] = true;
            src = idx[pos];
        }
    }
}