use parking_lot::Mutex;
use super::genetic::{
self, Candidate, EvolutionaryState, Phase, advance_generation, auto_divisions,
collect_evaluated_generation, crossover, das_dennis, extract_trial_params,
generate_random_candidates, mutate, sample_from_candidate, sample_random,
};
use crate::distribution::Distribution;
use crate::multi_objective::MultiObjectiveTrial;
use crate::param::ParamValue;
use crate::pareto;
use crate::types::Direction;
pub struct Nsga3Sampler {
state: Mutex<Nsga3State>,
}
impl Nsga3Sampler {
#[must_use]
pub fn new() -> Self {
Self {
state: Mutex::new(Nsga3State::new(Nsga3Config::default(), None)),
}
}
#[must_use]
pub fn with_seed(seed: u64) -> Self {
Self {
state: Mutex::new(Nsga3State::new(Nsga3Config::default(), Some(seed))),
}
}
#[must_use]
pub fn builder() -> Nsga3SamplerBuilder {
Nsga3SamplerBuilder::default()
}
}
impl Default for Nsga3Sampler {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone, Default)]
pub struct Nsga3SamplerBuilder {
population_size: Option<usize>,
n_divisions: Option<usize>,
crossover_prob: Option<f64>,
crossover_eta: Option<f64>,
mutation_eta: Option<f64>,
seed: Option<u64>,
}
impl Nsga3SamplerBuilder {
#[must_use]
pub fn population_size(mut self, size: usize) -> Self {
self.population_size = Some(size);
self
}
#[must_use]
pub fn n_divisions(mut self, h: usize) -> Self {
self.n_divisions = Some(h);
self
}
#[must_use]
pub fn crossover_prob(mut self, prob: f64) -> Self {
self.crossover_prob = Some(prob);
self
}
#[must_use]
pub fn crossover_eta(mut self, eta: f64) -> Self {
self.crossover_eta = Some(eta);
self
}
#[must_use]
pub fn mutation_eta(mut self, eta: f64) -> Self {
self.mutation_eta = Some(eta);
self
}
#[must_use]
pub fn seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
#[must_use]
pub fn build(self) -> Nsga3Sampler {
let config = Nsga3Config {
user_population_size: self.population_size,
n_divisions: self.n_divisions,
crossover_prob: self.crossover_prob.unwrap_or(1.0),
crossover_eta: self.crossover_eta.unwrap_or(30.0),
mutation_eta: self.mutation_eta.unwrap_or(20.0),
};
Nsga3Sampler {
state: Mutex::new(Nsga3State::new(config, self.seed)),
}
}
}
/// Tunable settings for the NSGA-III sampler.
#[derive(Clone, Debug)]
struct Nsga3Config {
    /// Population size requested by the user; `None` means "one per reference point".
    user_population_size: Option<usize>,
    /// Reference-point division count; `None` means it is chosen automatically.
    n_divisions: Option<usize>,
    /// Crossover probability passed to `genetic::crossover`.
    crossover_prob: f64,
    /// `eta` parameter passed to `genetic::crossover`.
    crossover_eta: f64,
    /// `eta` parameter passed to `genetic::mutate`.
    mutation_eta: f64,
}

impl Default for Nsga3Config {
    fn default() -> Self {
        let (crossover_prob, crossover_eta, mutation_eta) = (1.0, 30.0, 20.0);
        Self {
            crossover_prob,
            crossover_eta,
            mutation_eta,
            n_divisions: None,
            user_population_size: None,
        }
    }
}
/// Mutable sampler state, guarded by the sampler's mutex.
struct Nsga3State {
    /// Generic genetic-algorithm bookkeeping from the `genetic` module.
    evo: EvolutionaryState,
    /// Settings captured at construction time.
    config: Nsga3Config,
    /// Reference directions produced by `das_dennis` during initialization.
    reference_points: Vec<Vec<f64>>,
    /// Per-objective minimum seen so far in minimization space; `+inf` right after initialization.
    ideal_point: Vec<f64>,
    /// Whether `initialize_nsga3` has run.
    initialized: bool,
}

impl Nsga3State {
    /// Creates uninitialized state; reference points and the ideal point stay
    /// empty until `initialize_nsga3` runs.
    fn new(config: Nsga3Config, seed: Option<u64>) -> Self {
        let evo = EvolutionaryState::new(seed);
        Self {
            evo,
            config,
            reference_points: vec![],
            ideal_point: vec![],
            initialized: false,
        }
    }
}
impl crate::multi_objective::MultiObjectiveSampler for Nsga3Sampler {
    /// Proposes a value for `distribution` on behalf of trial `trial_id`.
    ///
    /// In the discovery phase, `genetic::sample_discovery` is consulted first.
    /// When it returns `None`, the NSGA-III state is initialized and a random
    /// first generation is created. In the active phase, generation bookkeeping
    /// is advanced via `maybe_generate_new_generation`. In both cases the value
    /// is then taken from the trial's candidate with `sample_from_candidate`.
    fn sample(
        &self,
        distribution: &Distribution,
        trial_id: u64,
        history: &[MultiObjectiveTrial],
        directions: &[Direction],
    ) -> ParamValue {
        let mut guard = self.state.lock();
        let state = &mut *guard;
        match &state.evo.phase {
            Phase::Discovery => {
                if let Some(value) =
                    genetic::sample_discovery(&mut state.evo, distribution, trial_id)
                {
                    return value;
                }
                initialize_nsga3(state, directions);
                generate_random_candidates(&mut state.evo);
            }
            Phase::Active => maybe_generate_new_generation(state, history, directions),
        }
        sample_from_candidate(&mut state.evo, trial_id)
    }
}
/// Switches the sampler into its active NSGA-III phase.
///
/// Steps:
/// - Builds the reference points with `das_dennis`. The division count is
///   `n_divisions` if set; otherwise `auto_divisions` chooses it from the
///   objective count and the requested population size (100 when unset).
/// - Sets the population size to the requested size, or else the number of
///   reference points, never below 4.
/// - Resets the ideal point to `+inf` for every objective.
/// - Marks the state initialized.
fn initialize_nsga3(state: &mut Nsga3State, directions: &[Direction]) {
    let objective_count = directions.len();
    let requested_pop = state.config.user_population_size;
    let divisions = match state.config.n_divisions {
        Some(h) => h,
        None => auto_divisions(objective_count, requested_pop.unwrap_or(100)),
    };
    state.reference_points = das_dennis(objective_count, divisions);
    let population_size =
        core::cmp::max(requested_pop.unwrap_or(state.reference_points.len()), 4);
    state.ideal_point = vec![f64::INFINITY; objective_count];
    state.evo.population_size = population_size;
    state.evo.phase = Phase::Active;
    state.initialized = true;
}
/// Advances active-phase generation bookkeeping before a candidate is sampled.
///
/// If no candidates exist yet, a random generation is created. Otherwise, when
/// `collect_evaluated_generation` returns `Some` (presumably once the current
/// generation has been evaluated), NSGA-III offspring are bred and installed
/// with `advance_generation`. When it returns `None`, nothing changes.
fn maybe_generate_new_generation(
    state: &mut Nsga3State,
    history: &[MultiObjectiveTrial],
    directions: &[Direction],
) {
    if !state.evo.candidates.is_empty() {
        let Some(evaluated) = collect_evaluated_generation(&state.evo, history) else {
            return;
        };
        let next_generation = nsga3_generate_offspring(state, &evaluated, directions);
        advance_generation(&mut state.evo, next_generation);
    } else {
        generate_random_candidates(&mut state.evo);
    }
}
/// Maps objective values into a pure-minimization space by negating every
/// objective whose direction is `Maximize`.
///
/// `values` and `directions` are zipped, so the output is as long as the
/// shorter of the two.
fn to_minimize_space(values: &[f64], directions: &[Direction]) -> Vec<f64> {
    let mut out = Vec::with_capacity(values.len().min(directions.len()));
    for (value, direction) in values.iter().zip(directions.iter()) {
        out.push(match direction {
            Direction::Minimize => *value,
            Direction::Maximize => -*value,
        });
    }
    out
}
/// Lowers each component of `ideal` to the smallest value seen for that
/// objective across `normalized_values`.
///
/// Despite the parameter name, the caller passes minimization-space values that
/// have not been normalized. NaN values never replace an entry, because `NaN < x`
/// is false.
///
/// # Panics
/// Panics if a row is longer than `ideal`.
fn update_ideal_point(ideal: &mut [f64], normalized_values: &[Vec<f64>]) {
    normalized_values
        .iter()
        .flat_map(|row| row.iter().copied().enumerate())
        .for_each(|(obj, value)| {
            let best = &mut ideal[obj];
            if value < *best {
                *best = value;
            }
        });
}
/// Achievement scalarizing function: `max_i (point[i] - ideal[i]) / weight[i]`.
///
/// Weights below `1e-6` are raised to `1e-6` to avoid dividing by (near) zero.
/// The three slices are zipped, so the shortest one bounds the computation.
/// Empty input yields `f64::NEG_INFINITY`.
fn asf(point: &[f64], weight: &[f64], ideal: &[f64]) -> f64 {
    point
        .iter()
        .zip(weight)
        .zip(ideal)
        .map(|((&p, &w), &z)| {
            let w = if w < 1e-6 { 1e-6 } else { w };
            (p - z) / w
        })
        .fold(f64::NEG_INFINITY, f64::max)
}

/// Computes the per-objective intercepts used to normalize objectives in NSGA-III.
///
/// `normalized_values` are the population's objective vectors in minimization
/// space (the caller passes them before normalization). `ideal` is the current
/// ideal point.
///
/// For each objective, an extreme point is chosen: the member minimizing [`asf`]
/// with weight `1` on that objective and `1e-6` on the others. As in Deb & Jain
/// (2014), the intercepts are where the hyperplane through these extreme points
/// (translated by `ideal`) crosses each axis.
///
/// The hyperplane is degenerate when extreme points are duplicate or linearly
/// dependent, or when an intercept is non-finite or `<= 1e-10`. In that case each
/// intercept falls back to the objective's span `max - ideal`, with `1.0` used
/// for spans `<= 1e-10`.
///
/// Empty input (no points or no objectives) yields `1.0` for every objective.
fn find_intercepts(normalized_values: &[Vec<f64>], ideal: &[f64]) -> Vec<f64> {
    let n_obj = ideal.len();
    if normalized_values.is_empty() || n_obj == 0 {
        return vec![1.0; n_obj];
    }
    let extremes: Vec<&[f64]> = (0..n_obj)
        .map(|obj| {
            let mut weight = vec![1e-6; n_obj];
            weight[obj] = 1.0;
            let mut best_idx = 0;
            let mut best_asf = f64::INFINITY;
            for (i, vals) in normalized_values.iter().enumerate() {
                let a = asf(vals, &weight, ideal);
                if a < best_asf {
                    best_asf = a;
                    best_idx = i;
                }
            }
            normalized_values[best_idx].as_slice()
        })
        .collect();
    hyperplane_intercepts(&extremes, ideal)
        .unwrap_or_else(|| nadir_intercepts(normalized_values, ideal))
}

/// Returns the axis intercepts, measured from `ideal`, of the hyperplane through
/// `extremes` (one point per objective).
///
/// Returns `None` when the linear system is singular or an intercept is not a
/// finite value above `1e-10`.
fn hyperplane_intercepts(extremes: &[&[f64]], ideal: &[f64]) -> Option<Vec<f64>> {
    let n = ideal.len();
    if extremes.len() != n || extremes.iter().any(|e| e.len() < n) {
        return None;
    }
    // Augmented matrix [E - z | 1]. Solving (E - z) b = 1 gives the plane
    // b · (x - z) = 1, whose intercept on axis i is 1 / b_i.
    let mut m: Vec<Vec<f64>> = extremes
        .iter()
        .map(|e| {
            let mut row: Vec<f64> = e[..n].iter().zip(ideal).map(|(&v, &z)| v - z).collect();
            row.push(1.0);
            row
        })
        .collect();
    // Gauss–Jordan elimination with partial pivoting.
    for col in 0..n {
        let pivot = (col..n).fold(col, |best, r| {
            if m[r][col].abs() > m[best][col].abs() { r } else { best }
        });
        let pivot_abs = m[pivot][col].abs();
        if pivot_abs.is_nan() || pivot_abs <= 1e-12 {
            return None;
        }
        m.swap(col, pivot);
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = m[row][col] / m[col][col];
            for k in col..=n {
                let delta = factor * m[col][k];
                m[row][k] -= delta;
            }
        }
    }
    // Row i now reads `m[i][i] * b_i = m[i][n]`, so the intercept 1 / b_i is m[i][i] / m[i][n].
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            let intercept = row[i] / row[n];
            (intercept.is_finite() && intercept > 1e-10).then_some(intercept)
        })
        .collect()
}

/// Fallback intercepts: the per-objective span `max - ideal`, or `1.0` when
/// that span is `<= 1e-10`.
fn nadir_intercepts(values: &[Vec<f64>], ideal: &[f64]) -> Vec<f64> {
    (0..ideal.len())
        .map(|obj| {
            let max_val = values
                .iter()
                .map(|v| v[obj])
                .fold(f64::NEG_INFINITY, f64::max);
            let intercept = max_val - ideal[obj];
            if intercept > 1e-10 { intercept } else { 1.0 }
        })
        .collect()
}
/// Translates each objective vector by `ideal` and scales it by `intercepts`:
/// `(value - ideal[i]) / intercepts[i]`.
///
/// Intercepts `<= 1e-10` (or NaN) are replaced by `1.0`, so a degenerate axis is
/// only translated. Each row is zipped with `ideal` and `intercepts`, so output
/// rows are as long as the shortest of the three.
fn normalize_objectives(values: &[Vec<f64>], ideal: &[f64], intercepts: &[f64]) -> Vec<Vec<f64>> {
    let mut out = Vec::with_capacity(values.len());
    for row in values {
        let scaled: Vec<f64> = row
            .iter()
            .zip(ideal.iter().zip(intercepts))
            .map(|(&x, (&z, &span))| {
                let denom = if span > 1e-10 { span } else { 1.0 };
                (x - z) / denom
            })
            .collect();
        out.push(scaled);
    }
    out
}
/// Euclidean distance from `point` to the line through the origin along
/// `reference`.
///
/// Returns `f64::INFINITY` when `reference` has (near-)zero length
/// (squared norm `< 1e-30`).
fn perpendicular_distance(point: &[f64], reference: &[f64]) -> f64 {
    let norm_sq = reference.iter().map(|r| r * r).sum::<f64>();
    if norm_sq < 1e-30 {
        return f64::INFINITY;
    }
    // Scalar projection of `point` onto `reference`, in units of `reference`.
    let scale = point.iter().zip(reference).map(|(p, r)| p * r).sum::<f64>() / norm_sq;
    point
        .iter()
        .zip(reference)
        .map(|(p, r)| {
            let residual = p - scale * r;
            residual.powi(2)
        })
        .sum::<f64>()
        .sqrt()
}

/// Associates every point with the reference direction it is perpendicularly
/// closest to.
///
/// Returns `(reference_index, distance)` per point. On exact ties the lowest
/// index wins. With no reference points, every entry is `(0, f64::INFINITY)`.
fn associate_to_reference_points(
    normalized: &[Vec<f64>],
    reference_points: &[Vec<f64>],
) -> Vec<(usize, f64)> {
    let mut associations = Vec::with_capacity(normalized.len());
    for point in normalized {
        let mut nearest = (0, f64::INFINITY);
        for (j, rp) in reference_points.iter().enumerate() {
            let dist = perpendicular_distance(point, rp);
            if dist < nearest.1 {
                nearest = (j, dist);
            }
        }
        associations.push(nearest);
    }
    associations
}
/// Chooses up to `remaining` members of `last_front` by NSGA-III niche preservation.
///
/// `associations[i]` is `(reference_point_index, perpendicular_distance)` for
/// population member `i`. Niche counts start from the members in
/// `already_selected`. Each round:
/// 1. among reference points that still have an unpicked `last_front` member,
///    find the smallest niche count;
/// 2. pick one of the reference points with that count uniformly at random;
/// 3. if the count is zero, take its closest unpicked member (the first one on
///    exact distance ties); otherwise take a uniformly random unpicked member;
/// 4. mark that member as picked and increment the reference point's niche count.
///
/// Returns indices into `associations` (i.e. into the population). Fewer than
/// `remaining` are returned only if `last_front` runs out of candidates.
///
/// # Panics
/// Panics if an index in `already_selected` or `last_front` is out of range for
/// `associations`, or if an association's reference index is
/// `>= n_reference_points`.
fn niching_select(
rng: &mut fastrand::Rng,
associations: &[(usize, f64)],
already_selected: &[usize],
last_front: &[usize],
n_reference_points: usize,
remaining: usize,
) -> Vec<usize> {
// Niche count per reference point: how many already-selected members are associated with it.
let mut niche_count = vec![0_usize; n_reference_points];
for &idx in already_selected {
niche_count[associations[idx].0] += 1;
}
// Last-front members grouped by their associated reference point, with their distances.
let mut ref_candidates: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n_reference_points];
for &idx in last_front {
let (ref_idx, dist) = associations[idx];
ref_candidates[ref_idx].push((idx, dist));
}
let mut selected = Vec::with_capacity(remaining);
// Marks last-front members already picked during this call.
let mut excluded = vec![false; associations.len()];
for _ in 0..remaining {
// Only reference points with at least one unpicked candidate are eligible.
let min_count = (0..n_reference_points)
.filter(|&j| ref_candidates[j].iter().any(|&(idx, _)| !excluded[idx]))
.map(|j| niche_count[j])
.min();
// No eligible reference point left: the last front is exhausted.
let Some(min_count) = min_count else {
break;
};
let min_refs: Vec<usize> = (0..n_reference_points)
.filter(|&j| {
niche_count[j] == min_count
&& ref_candidates[j].iter().any(|&(idx, _)| !excluded[idx])
})
.collect();
// Defensive: cannot trigger, since min_count was taken over this same eligible set.
if min_refs.is_empty() {
break;
}
let chosen_ref = min_refs[rng.usize(0..min_refs.len())];
let available: Vec<(usize, f64)> = ref_candidates[chosen_ref]
.iter()
.filter(|&&(idx, _)| !excluded[idx])
.copied()
.collect();
// Defensive: chosen_ref was filtered to have at least one unpicked candidate.
if available.is_empty() {
continue;
}
// Empty niche: take the member closest to the reference line (min_by returns the
// first minimum; NaN distances compare as Equal). Otherwise pick at random.
// The unwrap is safe because `available` is non-empty here.
let chosen_idx = if min_count == 0 {
available
.iter()
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal))
.unwrap()
.0
} else {
available[rng.usize(0..available.len())].0
};
selected.push(chosen_idx);
excluded[chosen_idx] = true;
niche_count[chosen_ref] += 1;
}
selected
}
/// NSGA-III environmental selection: picks `population_size` parents from `population`.
///
/// Steps:
/// - Convert objectives to minimization form and sort the population into
///   non-dominated fronts. The constrained sort is used when any trial reports
///   constraints.
/// - Take whole fronts while they fit.
/// - Split the first front that does not fit with [`niching_select`]. Before that,
///   objectives are normalized against the ideal point and the intercepts, and
///   each trial is associated with its nearest reference point.
/// - If the whole population has fewer members than `population_size`, fill the
///   shortfall with uniformly random population indices (duplicates possible).
///
/// Normalization statistics (ideal-point update, intercepts) are computed over the
/// whole `population`, not just the fronts considered. `state.ideal_point` is only
/// updated when niching is needed.
///
/// Returns `(params, indices)`: the parameter vectors produced by
/// `extract_trial_params` for the chosen trials, and their indices into
/// `population`, in the same order.
///
/// # Panics
/// If `population` is empty while padding is required, `rng.usize(0..0)` is called,
/// and fastrand panics on an empty range. The caller in this file only calls this
/// with at least two trials.
fn nsga3_select(
state: &mut Nsga3State,
population: &[&MultiObjectiveTrial],
directions: &[Direction],
) -> (Vec<Vec<ParamValue>>, Vec<usize>) {
let pop_size = state.evo.population_size;
let n_obj = directions.len();
// All objectives expressed as "smaller is better".
let min_values: Vec<Vec<f64>> = population
.iter()
.map(|t| to_minimize_space(&t.values, directions))
.collect();
let constraints: Vec<Vec<f64>> = population.iter().map(|t| t.constraints.clone()).collect();
let has_constraints = constraints.iter().any(|c| !c.is_empty());
let fronts = if has_constraints {
pareto::fast_non_dominated_sort_constrained(
&min_values,
&vec![Direction::Minimize; n_obj],
&constraints,
)
} else {
pareto::fast_non_dominated_sort(&min_values, &vec![Direction::Minimize; n_obj])
};
let mut selected: Vec<usize> = Vec::with_capacity(pop_size);
// Index of the first front that did not fit entirely, if any.
let mut last_front_idx = None;
for (fi, front) in fronts.iter().enumerate() {
if selected.len() + front.len() <= pop_size {
selected.extend_from_slice(front);
} else {
last_front_idx = Some(fi);
break;
}
}
// Niching is only needed when a partially fitting front exists and slots remain.
if selected.len() < pop_size
&& let Some(lf_idx) = last_front_idx
{
let remaining = pop_size - selected.len();
update_ideal_point(&mut state.ideal_point, &min_values);
let intercepts = find_intercepts(&min_values, &state.ideal_point);
let normalized = normalize_objectives(&min_values, &state.ideal_point, &intercepts);
let associations = associate_to_reference_points(&normalized, &state.reference_points);
let last_front = &fronts[lf_idx];
let additional = niching_select(
&mut state.evo.rng,
&associations,
&selected,
last_front,
state.reference_points.len(),
remaining,
);
selected.extend(additional);
}
// Pad with random (possibly repeated) members when the population itself is too small.
let n = population.len();
while selected.len() < pop_size {
selected.push(state.evo.rng.usize(0..n));
}
let params = selected
.iter()
.map(|&idx| {
extract_trial_params(population[idx], &state.evo.dimensions, &mut state.evo.rng)
})
.collect();
(params, selected)
}
/// Binary tournament on Pareto rank.
///
/// Draws two indices uniformly from `0..n`, with replacement, so they may
/// coincide. Returns the index whose `ranks` entry is lower; ties go to the
/// first draw.
///
/// # Panics
/// Panics if `n == 0` (empty range), or if a drawn index is out of bounds
/// for `ranks`.
fn tournament_select_rank(rng: &mut fastrand::Rng, ranks: &[usize], n: usize) -> usize {
    let first = rng.usize(0..n);
    let second = rng.usize(0..n);
    match ranks[first].cmp(&ranks[second]) {
        core::cmp::Ordering::Greater => second,
        _ => first,
    }
}
/// Breeds the next generation of candidates from the evaluated `population`.
///
/// * With fewer than two evaluated trials there is nothing to recombine. In that
///   case `population_size` candidates are drawn, each parameter sampled with
///   `sample_random`.
/// * Otherwise parents are chosen by `nsga3_select`. Children are then produced
///   by binary tournament on Pareto rank, followed by `crossover` and `mutate`,
///   until `population_size` children exist. The second child of the last pair
///   is dropped when that count is odd.
///
/// The NSGA-III state is lazily initialized first if that has not happened yet.
fn nsga3_generate_offspring(
    state: &mut Nsga3State,
    population: &[&MultiObjectiveTrial],
    directions: &[Direction],
) -> Vec<Candidate> {
    if population.len() < 2 {
        let pop_size = state.evo.population_size;
        return (0..pop_size)
            .map(|_| {
                let params = state
                    .evo
                    .dimensions
                    .iter()
                    .map(|d| sample_random(&mut state.evo.rng, &d.distribution))
                    .collect();
                Candidate { params }
            })
            .collect();
    }
    if !state.initialized {
        initialize_nsga3(state, directions);
    }
    // Read the population size only after the lazy initialization above, which
    // may overwrite it. Reading it earlier could make the number of children
    // disagree with the number of parents `nsga3_select` returns.
    let pop_size = state.evo.population_size;
    let (parents, selected_indices) = nsga3_select(state, population, directions);

    // Pareto rank of every population member, sorted the same way as in
    // `nsga3_select` (constraint-aware when any trial reports constraints).
    // This stops infeasible trials from winning tournaments on objective values alone.
    let n_obj = directions.len();
    let min_values: Vec<Vec<f64>> = population
        .iter()
        .map(|t| to_minimize_space(&t.values, directions))
        .collect();
    let minimize = vec![Direction::Minimize; n_obj];
    let constraints: Vec<Vec<f64>> = population.iter().map(|t| t.constraints.clone()).collect();
    let fronts = if constraints.iter().any(|c| !c.is_empty()) {
        pareto::fast_non_dominated_sort_constrained(&min_values, &minimize, &constraints)
    } else {
        pareto::fast_non_dominated_sort(&min_values, &minimize)
    };
    let mut pop_rank = vec![0_usize; population.len()];
    for (front_rank, front) in fronts.iter().enumerate() {
        for &idx in front {
            if let Some(rank) = pop_rank.get_mut(idx) {
                *rank = front_rank;
            }
        }
    }
    // Rank of each selected parent; an index outside the population defaults to rank 0.
    let parent_ranks: Vec<usize> = selected_indices
        .iter()
        .map(|&idx| pop_rank.get(idx).copied().unwrap_or(0))
        .collect();

    let mut offspring = Vec::with_capacity(pop_size);
    while offspring.len() < pop_size {
        let p1 = tournament_select_rank(&mut state.evo.rng, &parent_ranks, parents.len());
        let p2 = tournament_select_rank(&mut state.evo.rng, &parent_ranks, parents.len());
        let (mut child1, mut child2) = crossover(
            &mut state.evo.rng,
            &parents[p1],
            &parents[p2],
            &state.evo.dimensions,
            state.config.crossover_prob,
            state.config.crossover_eta,
        );
        mutate(
            &mut state.evo.rng,
            &mut child1,
            &state.evo.dimensions,
            state.config.mutation_eta,
        );
        mutate(
            &mut state.evo.rng,
            &mut child2,
            &state.evo.dimensions,
            state.config.mutation_eta,
        );
        offspring.push(Candidate { params: child1 });
        if offspring.len() < pop_size {
            offspring.push(Candidate { params: child2 });
        }
    }
    offspring
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the tests below.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_perpendicular_distance() {
        // (1, 0) projects onto (0.5, 0.5) on the diagonal; the residual has length sqrt(0.5).
        let dist = perpendicular_distance(&[1.0, 0.0], &[1.0, 1.0]);
        assert!(approx_eq(dist, 0.5_f64.sqrt()));
    }

    #[test]
    fn test_perpendicular_distance_on_line() {
        // (2, 2) lies on the ray through (1, 1), so its distance is zero.
        assert!(perpendicular_distance(&[2.0, 2.0], &[1.0, 1.0]) < 1e-10);
    }

    #[test]
    fn test_normalize_objectives() {
        let rows = vec![vec![2.0, 4.0], vec![4.0, 2.0]];
        let normalized = normalize_objectives(&rows, &[1.0, 1.0], &[3.0, 3.0]);
        let first = &normalized[0];
        assert!(approx_eq(first[0], 1.0 / 3.0));
        assert!(approx_eq(first[1], 1.0));
    }
}