use rand::Rng as _;
use crate::core::candidate::Candidate;
use crate::core::evaluation::Evaluation;
use crate::core::objective::ObjectiveSpace;
use crate::core::population::Population;
use crate::core::problem::Problem;
use crate::core::result::OptimizationResult;
use crate::core::rng::rng_from_seed;
use crate::metrics::hypervolume::hypervolume_nd_from_evaluations;
use crate::pareto::front::{best_candidate, pareto_front};
use crate::pareto::sort::non_dominated_sort;
use crate::traits::{Initializer, Optimizer, Variation};
/// Configuration for [`SmsEmoa`].
///
/// Derives `PartialEq` so configurations can be compared directly, e.g. in tests
/// or when checking whether a stored configuration changed.
#[derive(Debug, Clone, PartialEq)]
pub struct SmsEmoaConfig {
    /// Number of candidates kept in the population; `SmsEmoa::run` panics if it is zero.
    pub population_size: usize,
    /// Number of steady-state iterations; each one creates and evaluates exactly one child.
    pub generations: usize,
    /// Reference point used for hypervolume contributions. Its length must equal the
    /// number of objectives of the problem (`SmsEmoa::run` asserts this).
    pub reference_point: Vec<f64>,
    /// Seed for the optimizer's random number generator.
    pub seed: u64,
}

impl Default for SmsEmoaConfig {
    /// Population of 100, 1 000 generations, reference point `[11.0, 11.0]`, seed 42.
    ///
    /// The default reference point has two coordinates, so it only matches
    /// two-objective problems; override it for any other objective count.
    fn default() -> Self {
        Self {
            population_size: 100,
            generations: 1_000,
            reference_point: vec![11.0, 11.0],
            seed: 42,
        }
    }
}
/// SMS-EMOA optimizer: a run configuration bundled with the operators it uses.
///
/// `I` seeds the initial population and `V` breeds children from parent
/// decisions; the concrete trait requirements are imposed by the `Optimizer`
/// implementation, not by the struct itself.
#[derive(Debug, Clone)]
pub struct SmsEmoa<I, V> {
    /// Run parameters (population size, generations, reference point, seed).
    pub config: SmsEmoaConfig,
    /// Operator that produces the initial decisions.
    pub initializer: I,
    /// Operator that produces child decisions from parents.
    pub variation: V,
}

impl<I, V> SmsEmoa<I, V> {
    /// Creates an optimizer from its configuration and operators.
    pub fn new(config: SmsEmoaConfig, initializer: I, variation: V) -> Self {
        SmsEmoa {
            config,
            initializer,
            variation,
        }
    }
}
impl<P, I, V> Optimizer<P> for SmsEmoa<I, V>
where
    P: Problem + Sync,
    P::Decision: Send,
    I: Initializer<P::Decision>,
    V: Variation<P::Decision>,
{
    /// Runs the steady-state SMS-EMOA loop.
    ///
    /// The initial population comes from the initializer and is evaluated once.
    /// Each generation then does the following:
    /// - picks two parents by independent random indices, so the same parent may be drawn twice;
    /// - evaluates the first child returned by the variation operator and adds it to the pool;
    /// - removes the member chosen by `pick_drop_index`, which keeps the population size constant.
    ///
    /// The returned result records the total number of evaluations, including
    /// the initial population, and the configured number of generations.
    ///
    /// # Panics
    ///
    /// - if `config.population_size` is zero;
    /// - if `config.reference_point.len()` differs from the number of objectives;
    /// - if the initializer returns no decisions;
    /// - if the variation operator returns no children.
    fn run(&mut self, problem: &P) -> OptimizationResult<P::Decision> {
        assert!(
            self.config.population_size > 0,
            "SmsEmoa population_size must be > 0"
        );
        let n = self.config.population_size;
        let objectives = problem.objectives();
        assert_eq!(
            self.config.reference_point.len(),
            objectives.len(),
            "SmsEmoa reference_point.len() must equal number of objectives",
        );
        // Borrow rather than clone. `config` is a field disjoint from the
        // `initializer` and `variation` fields used mutably below.
        let reference = &self.config.reference_point;
        let generations = self.config.generations;
        let mut rng = rng_from_seed(self.config.seed);

        let mut population: Vec<Candidate<P::Decision>> = self
            .initializer
            .initialize(n, &mut rng)
            .into_iter()
            .map(|d| {
                let e = problem.evaluate(&d);
                Candidate::new(d, e)
            })
            .collect();
        // Parent selection samples from `0..population.len()`. An empty population
        // would otherwise fail there with an opaque "empty range" panic.
        assert!(
            !population.is_empty(),
            "SmsEmoa initializer returned no decisions"
        );
        let mut evaluations = population.len();

        for _ in 0..generations {
            let p1 = rng.random_range(0..population.len());
            let p2 = rng.random_range(0..population.len());
            let parents = vec![
                population[p1].decision.clone(),
                population[p2].decision.clone(),
            ];
            // Only the first child is used: SMS-EMOA adds one offspring per step.
            let child_decision = self
                .variation
                .vary(&parents, &mut rng)
                .into_iter()
                .next()
                .expect("SmsEmoa variation returned no children");
            let child_eval = problem.evaluate(&child_decision);
            evaluations += 1;
            population.push(Candidate::new(child_decision, child_eval));

            // The pool now holds n + 1 members; dropping one restores the size.
            // Order is irrelevant, so an O(1) swap_remove is enough.
            let drop_idx = pick_drop_index(&population, &objectives, reference);
            population.swap_remove(drop_idx);
        }

        let front = pareto_front(&population, &objectives);
        let best = best_candidate(&population, &objectives);
        OptimizationResult::new(
            Population::new(population),
            front,
            best,
            evaluations,
            generations,
        )
    }
}
/// Chooses which member of `pool` SMS-EMOA should discard, returning its index in `pool`.
///
/// The candidate is taken from the last (worst) front produced by
/// [`non_dominated_sort`]:
/// - if that front has a single member, that member is returned;
/// - otherwise the member with the smallest exclusive hypervolume contribution is
///   returned. The contribution is the front's hypervolume minus the hypervolume
///   without that member, measured against `reference`.
///
/// Ties resolve to the earliest member of the front.
///
/// # Panics
///
/// If `non_dominated_sort` returns no fronts, which is expected only for an empty `pool`.
fn pick_drop_index<D>(
    pool: &[Candidate<D>],
    objectives: &ObjectiveSpace,
    reference: &[f64],
) -> usize {
    let fronts = non_dominated_sort(pool, objectives);
    let worst_front = fronts
        .last()
        .expect("non_dominated_sort must return at least one front for non-empty pool");
    if worst_front.len() == 1 {
        return worst_front[0];
    }
    let evals: Vec<&Evaluation> = worst_front.iter().map(|&i| &pool[i].evaluation).collect();
    let total_hv = hypervolume_nd_from_evaluations(&evals, objectives, reference);

    // One scratch buffer is reused for every leave-one-out subset instead of
    // allocating a new Vec per front member. Element order matches the original
    // front order with member `k` removed.
    let mut without: Vec<&Evaluation> = Vec::with_capacity(evals.len().saturating_sub(1));
    let mut worst_idx_in_front = 0;
    let mut min_contrib = f64::INFINITY;
    for k in 0..evals.len() {
        without.clear();
        without.extend(
            evals
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != k)
                .map(|(_, e)| *e),
        );
        let hv_without = hypervolume_nd_from_evaluations(&without, objectives, reference);
        let contrib = total_hv - hv_without;
        // Strict `<` keeps the first minimum, so ties go to the earliest member.
        if contrib < min_contrib {
            min_contrib = contrib;
            worst_idx_in_front = k;
        }
    }
    worst_front[worst_idx_in_front]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operators::{
        CompositeVariation, PolynomialMutation, RealBounds, SimulatedBinaryCrossover,
    };
    use crate::tests_support::SchafferN1;

    /// The concrete optimizer type exercised by every test in this module.
    type TestOptimizer =
        SmsEmoa<RealBounds, CompositeVariation<SimulatedBinaryCrossover, PolynomialMutation>>;

    /// Assembles an optimizer over one real variable bounded by `(lo, hi)`.
    /// It pairs SBX crossover with polynomial mutation, using the same operator
    /// parameters as every test here.
    fn build(lo: f64, hi: f64, config: SmsEmoaConfig) -> TestOptimizer {
        let bounds = vec![(lo, hi)];
        let crossover = SimulatedBinaryCrossover::new(bounds.clone(), 15.0, 0.5);
        let mutation = PolynomialMutation::new(bounds.clone(), 20.0, 1.0);
        SmsEmoa::new(
            config,
            RealBounds::new(bounds),
            CompositeVariation {
                crossover,
                mutation,
            },
        )
    }

    /// Standard optimizer on `[-5, 5]`: 20 individuals for 100 generations.
    fn make_optimizer(seed: u64) -> TestOptimizer {
        build(
            -5.0,
            5.0,
            SmsEmoaConfig {
                population_size: 20,
                generations: 100,
                reference_point: vec![30.0, 30.0],
                seed,
            },
        )
    }

    #[test]
    fn produces_pareto_front() {
        let result = make_optimizer(1).run(&SchafferN1);
        assert_eq!(result.population.len(), 20);
        assert!(!result.pareto_front.is_empty());
    }

    #[test]
    fn deterministic_with_same_seed() {
        // Final-front objective vectors from a freshly built optimizer seeded with `seed`.
        let front_objectives = |seed: u64| -> Vec<Vec<f64>> {
            let result = make_optimizer(seed).run(&SchafferN1);
            result
                .pareto_front
                .iter()
                .map(|c| c.evaluation.objectives.clone())
                .collect()
        };
        assert_eq!(front_objectives(99), front_objectives(99));
    }

    #[test]
    #[should_panic(expected = "population_size must be > 0")]
    fn zero_population_size_panics() {
        let mut opt = build(
            0.0,
            1.0,
            SmsEmoaConfig {
                population_size: 0,
                generations: 1,
                reference_point: vec![1.0, 1.0],
                seed: 0,
            },
        );
        let _ = opt.run(&SchafferN1);
    }

    #[test]
    #[should_panic(expected = "reference_point.len() must equal number of objectives")]
    fn dim_mismatch_panics() {
        let mut opt = build(
            0.0,
            1.0,
            SmsEmoaConfig {
                population_size: 4,
                generations: 1,
                reference_point: vec![1.0, 1.0, 1.0],
                seed: 0,
            },
        );
        let _ = opt.run(&SchafferN1);
    }
}