use crate::core::constraint::BoxConstrained;
use crate::core::math::{NormSquared, SampleUniformBox, ScaledAdd, VectorLen};
use crate::core::problem::CostFunction;
use crate::core::rng::{ChaCha8Rng, Rng, RngExt, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::BasicPopulationState;
use crate::core::termination::TerminationReason;
use crate::solver::cma_es::sort_population_ascending;
/// Steady-state genetic algorithm (SSGA) for box-constrained, real-valued problems.
///
/// Every `next_iter` call breeds `offspring_per_step` children. Each child is produced by
/// negative assortative mating (NAM) parent selection, BLX-α crossover clipped to the
/// problem's box and per-coordinate BGA mutation. It then replaces the current worst
/// individual only if its cost is strictly lower.
///
/// Build with [`Ssga::new`] and the `with_*` methods; each setter panics on an invalid value.
pub struct Ssga {
    /// Number of individuals kept in the population.
    pop_size: usize,
    /// BLX-α expansion factor applied to each parent interval.
    blx_alpha: f64,
    /// NAM pool size: the first parent plus `nam_pool - 1` mate candidates.
    nam_pool: usize,
    /// Per-coordinate probability of applying BGA mutation.
    mutation_prob: f64,
    /// BGA mutation range as a fraction of each coordinate's box width.
    bga_range_fraction: f64,
    /// Children bred and evaluated per `next_iter` call.
    offspring_per_step: usize,
    /// Seed from which `init` (re)creates the RNG.
    seed: u64,
    /// RNG created by `init`; `None` until then.
    rng: Option<ChaCha8Rng>,
}

impl Ssga {
    /// Creates a solver with the default configuration and the given RNG seed.
    ///
    /// Defaults: `pop_size = 60`, `blx_alpha = 0.5`, `nam_pool = 4`,
    /// `mutation_prob = 0.125`, `bga_range_fraction = 0.1`, `offspring_per_step = 2`.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            pop_size: 60,
            blx_alpha: 0.5,
            nam_pool: 4,
            mutation_prob: 0.125,
            bga_range_fraction: 0.1,
            offspring_per_step: 2,
            seed,
            rng: None,
        }
    }

    /// Sets the population size.
    ///
    /// # Panics
    ///
    /// Panics if `pop_size` is smaller than the `nam_pool` configured at the time of this
    /// call. When raising both values, call [`Ssga::with_nam_pool`] first.
    #[must_use]
    pub fn with_pop_size(mut self, pop_size: usize) -> Self {
        assert!(
            pop_size >= self.nam_pool,
            "Ssga requires pop_size >= nam_pool (got pop_size={}, nam_pool={})",
            pop_size,
            self.nam_pool
        );
        self.pop_size = pop_size;
        self
    }

    /// Sets the BLX-α expansion factor.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` is negative or NaN.
    #[must_use]
    pub fn with_blx_alpha(mut self, alpha: f64) -> Self {
        assert!(alpha >= 0.0, "Ssga requires blx_alpha >= 0, got {}", alpha);
        self.blx_alpha = alpha;
        self
    }

    /// Sets the NAM pool size (the first parent plus `pool - 1` mate candidates).
    ///
    /// # Panics
    ///
    /// Panics if `pool < 2`. This setter does not compare against `pop_size`; only
    /// [`Ssga::with_pop_size`] enforces `pop_size >= nam_pool`.
    #[must_use]
    pub fn with_nam_pool(mut self, pool: usize) -> Self {
        assert!(pool >= 2, "Ssga requires nam_pool >= 2, got {}", pool);
        self.nam_pool = pool;
        self
    }

    /// Sets the per-coordinate BGA mutation probability.
    ///
    /// # Panics
    ///
    /// Panics if `p` is outside `[0, 1]` (NaN included).
    #[must_use]
    pub fn with_mutation_prob(mut self, p: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&p),
            "Ssga requires mutation_prob in [0, 1], got {}",
            p
        );
        self.mutation_prob = p;
        self
    }

    /// Sets the BGA mutation range as a fraction of each coordinate's box width.
    ///
    /// # Panics
    ///
    /// Panics unless `f` is finite and strictly positive. An infinite fraction yields an
    /// infinite mutation range. A mutation that draws no bits then computes `inf * 0.0 = NaN`,
    /// and every mutation that does draw bits is pushed straight onto a bound.
    #[must_use]
    pub fn with_bga_range_fraction(mut self, f: f64) -> Self {
        assert!(
            f.is_finite() && f > 0.0,
            "Ssga requires a finite bga_range_fraction > 0, got {}",
            f
        );
        self.bga_range_fraction = f;
        self
    }

    /// Sets how many children are bred and evaluated per `next_iter` call.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    #[must_use]
    pub fn with_offspring_per_step(mut self, n: usize) -> Self {
        assert!(n >= 1, "Ssga requires offspring_per_step >= 1, got {}", n);
        self.offspring_per_step = n;
        self
    }
}
/// BLX-α crossover clipped to the feasible box `[lower, upper]`.
///
/// Take parent values `a = p1[i]` and `b = p2[i]` for coordinate `i`, with `d = max - min`.
/// The child coordinate is drawn from `[min - α·d, max + α·d] ∩ [lower[i], upper[i]]`.
/// If that intersection is empty, the coordinate collapses to the midpoint of
/// `[lower[i], upper[i]]`.
/// The per-coordinate intervals are assembled into two vectors, and a single
/// `V::sample_uniform_box` call draws the child from them.
pub(crate) fn blx_alpha_crossover<V, R>(
    p1: &V,
    p2: &V,
    alpha: f64,
    lower: &V,
    upper: &V,
    rng: &mut R,
) -> V
where
    V: VectorLen
        + Clone
        + SampleUniformBox
        + std::ops::Index<usize, Output = f64>
        + std::ops::IndexMut<usize, Output = f64>,
    R: Rng + ?Sized,
{
    // Start from copies of the global bounds and overwrite every coordinate below.
    let mut box_lo = lower.clone();
    let mut box_hi = upper.clone();

    for i in 0..p1.vec_len() {
        let (a, b) = (p1[i], p2[i]);
        // Explicit comparison rather than `f64::min`/`f64::max`: with a NaN parent, this keeps
        // the spread NaN, so the interval below widens to the whole box.
        let (small, large) = if a < b { (a, b) } else { (b, a) };
        let spread = alpha * (large - small);
        let from = (small - spread).max(lower[i]);
        let to = (large + spread).min(upper[i]);

        let (from, to) = if to >= from {
            (from, to)
        } else {
            // Parent interval does not intersect the box: fall back to the box centre.
            let centre = 0.5 * (lower[i] + upper[i]);
            (centre, centre)
        };
        box_lo[i] = from;
        box_hi[i] = to;
    }

    V::sample_uniform_box(&box_lo, &box_hi, rng)
}
/// Negative assortative mating (NAM): picks two parent indices from `pop`.
///
/// Selection proceeds as follows:
/// - The first parent `p1` is drawn uniformly.
/// - `pool - 1` mate candidates are drawn uniformly, with replacement, from the *other*
///   individuals.
/// - The candidate with the largest squared Euclidean distance to `p1` becomes the second
///   parent. Ties and NaN distances keep the earlier draw.
///
/// Excluding `p1` from the mate draw guarantees `p1 != p2` whenever `pop.len() >= 2`, so BLX
/// crossover never degenerates into cloning a single parent. `pool` therefore counts the
/// first parent plus its candidates and is expected to be at least 2.
///
/// An empty `pop` is not supported: the first draw uses an empty range. With a single
/// individual, debug builds trip the assertion and release builds return `(0, 0)`.
pub(crate) fn nam_select<V, R>(pop: &[V], pool: usize, rng: &mut R) -> (usize, usize)
where
    V: Clone + ScaledAdd<f64> + NormSquared,
    R: Rng + ?Sized,
{
    debug_assert!(pop.len() >= 2, "nam_select needs at least 2 individuals");
    debug_assert!(pool >= 2, "nam_select needs nam_pool >= 2");
    let n = pop.len();
    let p1 = rng.random_range(0..n);
    if n < 2 {
        // No other individual exists to mate with.
        return (p1, p1);
    }

    // Squared Euclidean distance between the first parent and individual `j`.
    let dist_sq = |j: usize| {
        let mut diff = pop[p1].clone();
        diff.scaled_add(-1.0, &pop[j]);
        diff.norm_squared()
    };
    // Maps a uniform draw from `0..n - 1` onto `0..n` with `p1` skipped.
    let skip_p1 = |j: usize| if j >= p1 { j + 1 } else { j };

    let mut best = skip_p1(rng.random_range(0..n - 1));
    let mut best_d_sq = dist_sq(best);
    // `pool - 2` further candidates. `2..pool` is also safe (empty) for `pool < 2`, where
    // `pool - 1` would underflow.
    for _ in 2..pool {
        let c = skip_p1(rng.random_range(0..n - 1));
        let d_sq = dist_sq(c);
        if d_sq > best_d_sq {
            best = c;
            best_d_sq = d_sq;
        }
    }
    (p1, best)
}
/// BGA (Breeder Genetic Algorithm) mutation, applied in place and clipped to `[lower, upper]`.
///
/// Each coordinate `i` is mutated independently with probability `prob`. A mutated coordinate
/// moves by `±range_fraction · (upper[i] - lower[i]) · s`. The sign is a fair coin flip.
/// `s = Σ_{k=0}^{15} α_k 2^{-k}`, where each `α_k` is 1 with probability 1/16. The result
/// is clamped back into the box.
///
/// When no `α_k` is set, the step is exactly zero. It is written as a literal `0.0` rather
/// than `range · 0.0` so that an infinite range cannot turn the coordinate into NaN. An
/// infinite range arises from a huge box such as `±f64::MAX`, where `upper - lower`
/// overflows. In that case `inf · 0.0 = NaN`, and `f64::clamp` propagates NaN.
pub(crate) fn bga_mutate_in_place<V, R>(
    child: &mut V,
    lower: &V,
    upper: &V,
    prob: f64,
    range_fraction: f64,
    rng: &mut R,
) where
    V: VectorLen + std::ops::Index<usize, Output = f64> + std::ops::IndexMut<usize, Output = f64>,
    R: Rng + ?Sized,
{
    let n = child.vec_len();
    for i in 0..n {
        if rng.random::<f64>() >= prob {
            continue;
        }
        let sign = if rng.random::<f64>() < 0.5 { 1.0 } else { -1.0 };
        let rang = range_fraction * (upper[i] - lower[i]);
        // s = Σ_k α_k 2^{-k}, each α_k = 1 with probability 1/16.
        let mut s = 0.0;
        for k in 0..16 {
            if rng.random::<f64>() < 1.0 / 16.0 {
                s += (-(k as f64)).exp2();
            }
        }
        // No bit set means no move; never compute `rang * 0.0`, which is NaN for infinite `rang`.
        let step = if s == 0.0 { 0.0 } else { sign * rang * s };
        child[i] = (child[i] + step).clamp(lower[i], upper[i]);
    }
}
/// Replaces the worst individual with `child` if `child` is strictly better.
///
/// "Worst" is the highest cost, except that a NaN cost counts as worse than every number.
/// Among equally bad entries (including several NaNs), the lowest index is chosen. The child
/// is accepted if `c_child < worst_cost`, or if the worst cost is NaN and `c_child` is not.
///
/// Returns the overwritten index. Returns `None` when the child is rejected or the
/// population is empty.
///
/// `pop` and `costs` are parallel slices and must have the same length; this is checked in
/// debug builds.
pub(crate) fn replace_worst_if_better<V>(
    pop: &mut [V],
    costs: &mut [f64],
    child: V,
    c_child: f64,
) -> Option<usize> {
    debug_assert_eq!(
        pop.len(),
        costs.len(),
        "replace_worst_if_better: pop and costs lengths differ"
    );
    // `reduce` yields `None` for an empty population instead of panicking on `costs[0]`.
    let (worst_idx, worst_cost) = costs
        .iter()
        .copied()
        .enumerate()
        .reduce(|worst, cand| {
            // A NaN incumbent is never displaced, so the first NaN wins. Otherwise a NaN
            // candidate, or a strictly larger cost, takes over; ties keep the lower index.
            let cand_is_worse = !worst.1.is_nan() && (cand.1.is_nan() || cand.1 > worst.1);
            if cand_is_worse {
                cand
            } else {
                worst
            }
        })?;

    let accept = if worst_cost.is_nan() {
        !c_child.is_nan()
    } else {
        c_child < worst_cost
    };
    if !accept {
        return None;
    }
    pop[worst_idx] = child;
    costs[worst_idx] = c_child;
    Some(worst_idx)
}
impl<P, V> Solver<P, BasicPopulationState<V>> for Ssga
where
    P: CostFunction<Param = V, Output = f64> + BoxConstrained<Param = V>,
    V: VectorLen
        + Clone
        + SampleUniformBox
        + ScaledAdd<f64>
        + NormSquared
        + std::ops::Index<usize, Output = f64>
        + std::ops::IndexMut<usize, Output = f64>,
{
    /// Builds the initial population.
    ///
    /// The steps are:
    /// - Re-seed the RNG from `self.seed`, so repeated calls start from the same stream.
    /// - Discard whatever `state` already holds.
    /// - Sample `pop_size` points uniformly from the problem's box and evaluate each once,
    ///   counting every evaluation in `cost_evals`.
    /// - Order the population with `sort_population_ascending`.
    /// - Keep the RNG for later `next_iter` calls.
    fn init(&mut self, problem: &P, mut state: BasicPopulationState<V>) -> BasicPopulationState<V> {
        let (lower, upper) = (problem.lower(), problem.upper());
        let mut rng = ChaCha8Rng::seed_from_u64(self.seed);

        state.candidates.clear();
        state.costs.clear();
        for _member in 0..self.pop_size {
            let point = V::sample_uniform_box(lower, upper, &mut rng);
            let cost = problem.cost(&point);
            state.costs.push(cost);
            state.candidates.push(point);
            state.cost_evals += 1;
        }

        sort_population_ascending(&mut state.candidates, &mut state.costs);
        self.rng = Some(rng);
        state
    }

    /// Breeds `offspring_per_step` children, applying replace-worst after each one.
    ///
    /// For every child:
    /// - NAM picks the parents.
    /// - BLX-α recombines them inside the box.
    /// - BGA mutates the result.
    /// - The child is evaluated (one `cost_evals` increment).
    /// - It replaces the worst member only when strictly better.
    ///
    /// The population is re-sorted once at the end. This step never requests termination.
    ///
    /// # Panics
    ///
    /// Panics if called before `init`.
    fn next_iter(
        &mut self,
        problem: &P,
        mut state: BasicPopulationState<V>,
    ) -> (BasicPopulationState<V>, Option<TerminationReason>) {
        let rng = self
            .rng
            .as_mut()
            .expect("Ssga::init must run before next_iter");
        let (lower, upper) = (problem.lower(), problem.upper());

        for _child in 0..self.offspring_per_step {
            let (first, second) = nam_select(&state.candidates, self.nam_pool, rng);
            let (mother, father) = (&state.candidates[first], &state.candidates[second]);
            let mut offspring =
                blx_alpha_crossover(mother, father, self.blx_alpha, lower, upper, rng);
            bga_mutate_in_place(
                &mut offspring,
                lower,
                upper,
                self.mutation_prob,
                self.bga_range_fraction,
                rng,
            );

            let offspring_cost = problem.cost(&offspring);
            state.cost_evals += 1;
            let _replaced = replace_worst_if_better(
                &mut state.candidates,
                &mut state.costs,
                offspring,
                offspring_cost,
            );
        }

        sort_population_ascending(&mut state.candidates, &mut state.costs);
        (state, None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand_chacha::ChaCha8Rng;

    /// Parents 0 and 1 with α = 0.5 span [-0.5, 1.5]. The wide global box must not shrink
    /// that interval, and samples should come close to both ends.
    #[test]
    fn blx_alpha_samples_lie_in_expected_interval_unclipped() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let p1 = vec![0.0];
        let p2 = vec![1.0];
        let lo = vec![-10.0];
        let hi = vec![10.0];
        let mut min_seen = f64::INFINITY;
        let mut max_seen = f64::NEG_INFINITY;
        for _ in 0..10_000 {
            let c = blx_alpha_crossover(&p1, &p2, 0.5, &lo, &hi, &mut rng);
            assert!(
                (-0.5..=1.5).contains(&c[0]),
                "child {} out of [-0.5, 1.5]",
                c[0]
            );
            min_seen = min_seen.min(c[0]);
            max_seen = max_seen.max(c[0]);
        }
        assert!(min_seen < -0.4, "min {} not near -0.5", min_seen);
        assert!(max_seen > 1.4, "max {} not near 1.5", max_seen);
    }

    /// When the global box is tighter than the BLX interval, children stay inside the box.
    #[test]
    fn blx_alpha_clips_to_global_bounds() {
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let p1 = vec![0.0];
        let p2 = vec![1.0];
        let lo = vec![0.0];
        let hi = vec![1.0];
        for _ in 0..2_000 {
            let c = blx_alpha_crossover(&p1, &p2, 0.5, &lo, &hi, &mut rng);
            assert!(
                (0.0..=1.0).contains(&c[0]),
                "child {} outside global box [0, 1]",
                c[0]
            );
        }
    }

    /// Statistical check, not a deterministic one. With a single outlier (10.0) among points
    /// clustered near zero, NAM should often pair a clustered first parent with the outlier.
    #[test]
    fn nam_usually_pairs_with_the_farthest_individual() {
        let pop = vec![vec![0.0], vec![1.0], vec![2.0], vec![10.0]];
        let mut hits_farthest = 0;
        for seed in 0..200 {
            let mut rng = ChaCha8Rng::seed_from_u64(seed);
            let (p1, p2) = nam_select(&pop, 4, &mut rng);
            if p1 != 3 && p2 == 3 {
                hits_farthest += 1;
            }
        }
        assert!(
            hits_farthest > 80,
            "NAM rarely picks the farthest: {} / 200",
            hits_farthest
        );
    }

    /// Both returned parent indices must be valid positions in the population.
    #[test]
    fn nam_returns_indices_within_population() {
        let pop = vec![vec![0.0], vec![1.0], vec![2.0]];
        for seed in 0..50 {
            let mut rng = ChaCha8Rng::seed_from_u64(seed);
            let (p1, p2) = nam_select(&pop, 3, &mut rng);
            assert!(p1 < pop.len() && p2 < pop.len(), "({}, {})", p1, p2);
        }
    }

    /// A child with a cost equal to the worst cost is rejected; a strictly lower one wins.
    #[test]
    fn replace_worst_only_when_strictly_better() {
        let mut pop = vec![vec![0.0], vec![1.0], vec![2.0]];
        let mut costs = vec![0.0, 1.0, 2.0];
        let r = replace_worst_if_better(&mut pop, &mut costs, vec![5.0], 2.0);
        assert!(r.is_none());
        assert_eq!(pop[2], vec![2.0]);
        let r = replace_worst_if_better(&mut pop, &mut costs, vec![5.0], 1.5);
        assert_eq!(r, Some(2));
        assert_eq!(pop[2], vec![5.0]);
        assert_eq!(costs[2], 1.5);
    }

    /// A NaN cost is the worst possible, so any finite child replaces it.
    #[test]
    fn replace_worst_treats_nan_as_worse_than_any_finite() {
        let mut pop = vec![vec![0.0], vec![1.0]];
        let mut costs = vec![0.0, f64::NAN];
        let r = replace_worst_if_better(&mut pop, &mut costs, vec![5.0], 100.0);
        assert_eq!(r, Some(1));
        assert_eq!(costs[1], 100.0);
    }

    /// Among equally bad members, the one with the lowest index is replaced.
    #[test]
    fn replace_worst_prefers_lowest_index_on_ties() {
        let mut pop = vec![vec![0.0], vec![1.0], vec![2.0]];
        let mut costs = vec![3.0, 1.0, 3.0];
        let r = replace_worst_if_better(&mut pop, &mut costs, vec![9.0], 2.0);
        assert_eq!(r, Some(0));
        assert_eq!(pop[0], vec![9.0]);
        assert_eq!(costs, vec![2.0, 1.0, 3.0]);
    }

    /// A NaN child never replaces anything, not even a NaN member.
    #[test]
    fn replace_worst_rejects_nan_child() {
        let mut pop = vec![vec![0.0], vec![1.0]];
        let mut costs = vec![0.0, f64::NAN];
        let r = replace_worst_if_better(&mut pop, &mut costs, vec![5.0], f64::NAN);
        assert!(r.is_none());
        assert_eq!(pop[1], vec![1.0]);
    }

    /// With `prob = 0`, no coordinate is touched.
    #[test]
    fn bga_mutation_prob_zero_leaves_unchanged() {
        let mut rng = ChaCha8Rng::seed_from_u64(99);
        let mut child = vec![0.5, 0.5, 0.5];
        let lo = vec![0.0, 0.0, 0.0];
        let hi = vec![1.0, 1.0, 1.0];
        bga_mutate_in_place(&mut child, &lo, &hi, 0.0, 0.1, &mut rng);
        assert_eq!(child, vec![0.5, 0.5, 0.5]);
    }

    /// With `prob = 1`, every coordinate is mutated but still clamped into the box.
    #[test]
    fn bga_mutation_prob_one_respects_bounds() {
        let mut rng = ChaCha8Rng::seed_from_u64(99);
        let lo = vec![0.0; 8];
        let hi = vec![1.0; 8];
        for _ in 0..200 {
            let mut child = vec![0.5; 8];
            bga_mutate_in_place(&mut child, &lo, &hi, 1.0, 0.1, &mut rng);
            for (i, &v) in child.iter().enumerate() {
                assert!(
                    (lo[i]..=hi[i]).contains(&v),
                    "child[{}] = {} outside [0, 1]",
                    i,
                    v
                );
            }
        }
    }
}