use crate::{Evaluator, Evolver, Genotype, Phenotype};
use rand::prelude::SeedableRng;
use rand_pcg::Pcg64;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Plain-data mirror of [`Nsga2`] used as the deserialization target.
///
/// `Nsga2`'s `Deserialize` impl decodes into this struct first and then
/// validates it (non-zero `pop_size`, `population.len() == pop_size`,
/// `ranks` / `crowding_distances` one entry per individual, finite
/// `mutation_rate`) before building the real optimizer.
// NOTE(review): keep the field order identical to the order in which `Nsga2`'s
// `Serialize` impl writes fields; this is likely significant for sequence-based
// serde formats that decode struct fields positionally.
#[derive(Serialize, Deserialize)]
#[serde(bound = "G: Genotype")]
struct Nsga2Data<G: Genotype> {
population: Vec<Phenotype<G>>,
ranks: Vec<usize>,
crowding_distances: Vec<f32>,
pop_size: usize,
mutation_rate: f32,
rng: Pcg64,
}
/// NSGA-II multi-objective evolutionary optimizer.
///
/// Invariant (enforced when deserializing): `ranks` and `crowding_distances`
/// hold exactly one entry per individual in `population`, index-aligned.
pub struct Nsga2<G: Genotype> {
// Current parent population.
population: Vec<Phenotype<G>>,
// Non-domination front index of each individual (0 = non-dominated front).
ranks: Vec<usize>,
// Crowding distance of each individual, computed within its own front.
crowding_distances: Vec<f32>,
// Number of offspring bred, and survivors kept, per step.
pop_size: usize,
// Rate forwarded to `Genotype::mutate` for every offspring.
mutation_rate: f32,
// Seeded PRNG driving tournaments, crossover and mutation; it is serialized
// with the rest of the state so the random stream continues after a reload.
rng: Pcg64,
}
impl<G: Genotype> Serialize for Nsga2<G> {
    /// Writes the optimizer as a six-field struct named `Nsga2`.
    ///
    /// Fields are emitted in the same order `Nsga2Data` declares them:
    /// `population`, `ranks`, `crowding_distances`, `pop_size`,
    /// `mutation_rate`, `rng`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        const FIELD_COUNT: usize = 6;

        let mut out = serializer.serialize_struct("Nsga2", FIELD_COUNT)?;
        out.serialize_field("population", &self.population)?;
        out.serialize_field("ranks", &self.ranks)?;
        out.serialize_field("crowding_distances", &self.crowding_distances)?;
        out.serialize_field("pop_size", &self.pop_size)?;
        out.serialize_field("mutation_rate", &self.mutation_rate)?;
        out.serialize_field("rng", &self.rng)?;
        out.end()
    }
}
impl<'de, G: Genotype> Deserialize<'de> for Nsga2<G> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let data = Nsga2Data::<G>::deserialize(deserializer)?;
if data.pop_size == 0 {
return Err(D::Error::custom("pop_size must be greater than 0"));
}
if data.population.len() != data.pop_size {
return Err(D::Error::custom(format!(
"population length ({}) does not match pop_size ({})",
data.population.len(),
data.pop_size
)));
}
if data.ranks.len() != data.population.len() {
return Err(D::Error::custom(format!(
"ranks length ({}) does not match population length ({})",
data.ranks.len(),
data.population.len()
)));
}
if data.crowding_distances.len() != data.population.len() {
return Err(D::Error::custom(format!(
"crowding_distances length ({}) does not match population length ({})",
data.crowding_distances.len(),
data.population.len()
)));
}
if data.mutation_rate.is_nan() || data.mutation_rate.is_infinite() {
return Err(D::Error::custom("mutation_rate must be a finite number"));
}
Ok(Self {
population: data.population,
ranks: data.ranks,
crowding_distances: data.crowding_distances,
pop_size: data.pop_size,
mutation_rate: data.mutation_rate,
rng: data.rng,
})
}
}
impl<G: Genotype> Nsga2<G> {
    /// Returns the mutation rate passed to `Genotype::mutate` for each offspring.
    pub fn mutation_rate(&self) -> f32 {
        self.mutation_rate
    }

    /// Replaces the mutation rate used for subsequent offspring.
    ///
    /// The value is stored unchecked. Deserialization, however, rejects a
    /// non-finite `mutation_rate`, so a NaN or infinite rate set here will not
    /// survive a serialize/deserialize round trip.
    pub fn set_mutation_rate(&mut self, rate: f32) {
        self.mutation_rate = rate;
    }

    /// Returns the population size: offspring bred and survivors kept per step.
    pub fn pop_size(&self) -> usize {
        self.pop_size
    }
}
/// Bookkeeping record used during non-dominated sorting and environmental
/// selection.
///
/// It holds an index into the population slice being ranked, together with
/// that individual's front rank and crowding distance.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SortWrapper {
    /// Index of the individual in the population slice being ranked.
    pub index: usize,
    /// Non-domination front of the individual (0 = non-dominated front).
    pub rank: usize,
    /// Crowding distance within the individual's front (`f32::INFINITY` for boundary points).
    pub distance: f32,
}
impl<G: Genotype> Nsga2<G> {
    /// Creates an optimizer from an initial set of genotypes.
    ///
    /// `pop_size` is taken from `initial_pop.len()` and the PRNG is seeded from
    /// `seed`. The initial phenotypes are not evaluated here. They start with
    /// zero fitness and empty objective and descriptor vectors, so no individual
    /// dominates another: all land in front 0 with an infinite crowding distance
    /// until [`Evolver::step`] evaluates them.
    ///
    /// An empty `initial_pop` produces an optimizer whose `step` does nothing.
    /// Such a state (`pop_size == 0`) is rejected by deserialization.
    pub fn new(initial_pop: Vec<G>, mutation_rate: f32, seed: u64) -> Self {
        let pop_size = initial_pop.len();
        let population: Vec<_> = initial_pop
            .into_iter()
            .map(|g| Phenotype {
                genotype: g,
                fitness: 0.0,
                objectives: vec![],
                descriptor: vec![],
            })
            .collect();
        let (ranks, crowding_distances) = Self::calculate_ranks_and_distances(&population);
        Self {
            population,
            ranks,
            crowding_distances,
            pop_size,
            mutation_rate,
            rng: Pcg64::seed_from_u64(seed),
        }
    }

    /// Computes, for every individual of `population`, its front rank and its
    /// crowding distance within that front.
    ///
    /// Both returned vectors are index-aligned with `population`. An empty
    /// population yields two empty vectors.
    fn calculate_ranks_and_distances(population: &[Phenotype<G>]) -> (Vec<usize>, Vec<f32>) {
        if population.is_empty() {
            return (vec![], vec![]);
        }
        let fronts = Self::fast_non_dominated_sort(population);
        let mut ranks = vec![0; population.len()];
        let mut crowding_distances = vec![0.0; population.len()];
        for (rank, indices) in fronts.iter().enumerate() {
            let mut front_wrappers: Vec<_> = indices
                .iter()
                .map(|&i| SortWrapper {
                    index: i,
                    rank,
                    distance: 0.0,
                })
                .collect();
            Self::calculate_crowding_distance(&mut front_wrappers, population);
            for wrapper in front_wrappers {
                ranks[wrapper.index] = rank;
                crowding_distances[wrapper.index] = wrapper.distance;
            }
        }
        (ranks, crowding_distances)
    }

    /// Crowded-comparison binary tournament over the current population.
    ///
    /// Two indices are drawn uniformly at random and may coincide. The one with
    /// the lower front rank wins. On equal rank, the larger crowding distance
    /// wins, and an exact tie goes to the first draw. The population must be
    /// non-empty, because an empty range cannot be sampled.
    fn binary_tournament(&mut self) -> usize {
        use rand::Rng;
        let n = self.population.len();
        let i = self.rng.random_range(0..n);
        let j = self.rng.random_range(0..n);
        match self.ranks[i].cmp(&self.ranks[j]) {
            Ordering::Less => i,
            Ordering::Greater => j,
            Ordering::Equal => {
                match self.crowding_distances[i].total_cmp(&self.crowding_distances[j]) {
                    Ordering::Greater | Ordering::Equal => i,
                    Ordering::Less => j,
                }
            }
        }
    }

    /// Partitions `combined` into non-domination fronts (fast non-dominated sort).
    ///
    /// Returns a list of fronts, each a list of indices into `combined`.
    /// Front 0 holds the individuals dominated by nobody. Front `k + 1` holds
    /// those dominated only by members of fronts `0..=k`. Every index appears in
    /// exactly one front, and indices within front 0 are ascending. An empty
    /// input yields a single empty front.
    ///
    /// Each unordered pair is compared once, with at most two [`Self::dominates`]
    /// calls, because domination is asymmetric.
    pub fn fast_non_dominated_sort(combined: &[Phenotype<G>]) -> Vec<Vec<usize>> {
        let n = combined.len();
        // domination_count[i]: how many individuals dominate i.
        let mut domination_count = vec![0usize; n];
        // dominated_indices[i]: the individuals i dominates, in ascending order.
        let mut dominated_indices: Vec<Vec<usize>> = vec![vec![]; n];
        for i in 0..n {
            for j in (i + 1)..n {
                if Self::dominates(&combined[i], &combined[j]) {
                    dominated_indices[i].push(j);
                    domination_count[j] += 1;
                } else if Self::dominates(&combined[j], &combined[i]) {
                    dominated_indices[j].push(i);
                    domination_count[i] += 1;
                }
            }
        }
        let first_front: Vec<usize> = (0..n).filter(|&i| domination_count[i] == 0).collect();
        let mut fronts = vec![first_front];
        let mut curr = 0;
        while curr < fronts.len() && !fronts[curr].is_empty() {
            let mut next_front = vec![];
            for &i in &fronts[curr] {
                for &j in &dominated_indices[i] {
                    domination_count[j] -= 1;
                    if domination_count[j] == 0 {
                        next_front.push(j);
                    }
                }
            }
            if next_front.is_empty() {
                break;
            }
            fronts.push(next_front);
            curr += 1;
        }
        fronts
    }

    /// Assigns crowding distances to the members of one front.
    ///
    /// `front` entries index into `combined` and are reordered in place. Each
    /// `distance` is overwritten, so values passed in are ignored. Boundary
    /// points of every objective get `f32::INFINITY`. Interior points accumulate
    /// the normalised gap between their neighbours, summed over objectives.
    ///
    /// Every member is marked infinite when any of these holds:
    /// - the front has at most two members;
    /// - any member has no objectives;
    /// - members disagree on the number of objectives.
    pub fn calculate_crowding_distance(front: &mut [SortWrapper], combined: &[Phenotype<G>]) {
        let n = front.len();
        // With one or two members everyone is a boundary point.
        if n <= 2 {
            for ind in front.iter_mut() {
                ind.distance = f32::INFINITY;
            }
            return;
        }
        let (min_obj, max_obj) = front
            .iter()
            .map(|w| combined[w.index].objectives.len())
            .fold((usize::MAX, 0), |(lo, hi), len| (lo.min(len), hi.max(len)));
        // Spacing is undefined without objectives or with inconsistent objective
        // counts, so treat every member as a boundary point.
        if min_obj == 0 || min_obj != max_obj {
            for ind in front.iter_mut() {
                ind.distance = f32::INFINITY;
            }
            return;
        }
        // Start from zero so the result does not depend on caller-supplied values.
        for ind in front.iter_mut() {
            ind.distance = 0.0;
        }
        for m in 0..min_obj {
            // Stable sort: members with equal values keep their relative order.
            front.sort_by(|a, b| {
                combined[a.index].objectives[m].total_cmp(&combined[b.index].objectives[m])
            });
            let range =
                combined[front[n - 1].index].objectives[m] - combined[front[0].index].objectives[m];
            front[0].distance = f32::INFINITY;
            front[n - 1].distance = f32::INFINITY;
            // A zero (or NaN) range carries no spacing information for this objective.
            if range > 0.0 {
                for i in 1..(n - 1) {
                    if front[i].distance != f32::INFINITY {
                        front[i].distance += (combined[front[i + 1].index].objectives[m]
                            - combined[front[i - 1].index].objectives[m])
                            / range;
                    }
                }
            }
        }
    }

    /// Returns `true` if `a` Pareto-dominates `b`, with every objective maximised.
    ///
    /// `a` dominates `b` when every objective of `a` is `>=` the matching
    /// objective of `b` and at least one is strictly greater. Special cases:
    /// - Objective vectors of different lengths never dominate each other.
    /// - Empty objective vectors never dominate.
    /// - An `a` containing NaN never dominates.
    /// - Otherwise, a `b` containing NaN is always dominated.
    pub fn dominates(a: &Phenotype<G>, b: &Phenotype<G>) -> bool {
        if a.objectives.len() != b.objectives.len() {
            return false;
        }
        if a.objectives.iter().any(|v| v.is_nan()) {
            return false;
        }
        if b.objectives.iter().any(|v| v.is_nan()) {
            return true;
        }
        let mut better_in_any = false;
        for (oa, ob) in a.objectives.iter().zip(b.objectives.iter()) {
            if oa < ob {
                return false;
            }
            if oa > ob {
                better_in_any = true;
            }
        }
        better_in_any
    }
}
impl<G: Genotype> Evolver<G> for Nsga2<G> {
    /// Runs one NSGA-II generation.
    ///
    /// 1. Breed `pop_size` offspring. Each child comes from two
    ///    binary-tournament parents via crossover, then mutation at
    ///    `mutation_rate`.
    /// 2. Evaluate parents and offspring together with `evaluator`, in parallel
    ///    when the `parallel` feature is enabled. Parents are re-evaluated every
    ///    generation, which also covers an initial population that was never
    ///    evaluated.
    /// 3. Keep `pop_size` survivors, taking whole fronts in rank order. The
    ///    first front that does not fit is truncated by descending crowding
    ///    distance.
    /// 4. Store the survivors' ranks and distances.
    ///
    /// Does nothing when the population is empty.
    fn step<E: Evaluator<G>>(&mut self, evaluator: &E) {
        if self.population.is_empty() {
            return;
        }
        let mut offspring = Vec::with_capacity(self.pop_size);
        while offspring.len() < self.pop_size {
            let idx1 = self.binary_tournament();
            let idx2 = self.binary_tournament();
            let p1 = &self.population[idx1];
            let p2 = &self.population[idx2];
            let mut child_dna = p1.genotype.crossover(&p2.genotype, &mut self.rng);
            child_dna.mutate(&mut self.rng, self.mutation_rate);
            offspring.push(Phenotype {
                genotype: child_dna,
                fitness: 0.0,
                objectives: vec![],
                descriptor: vec![],
            });
        }
        // Parents are copied rather than moved so `self.population` stays intact
        // if evaluation panics.
        let mut combined = Vec::with_capacity(self.population.len() + offspring.len());
        combined.extend_from_slice(&self.population);
        combined.extend(offspring);
        #[cfg(feature = "parallel")]
        combined.par_iter_mut().for_each(|p| {
            let (fit, obj, desc) = evaluator.evaluate(&p.genotype);
            p.fitness = fit;
            p.objectives = obj;
            p.descriptor = desc;
        });
        #[cfg(not(feature = "parallel"))]
        for p in &mut combined {
            let (fit, obj, desc) = evaluator.evaluate(&p.genotype);
            p.fitness = fit;
            p.objectives = obj;
            p.descriptor = desc;
        }
        let fronts = Self::fast_non_dominated_sort(&combined);
        let mut next_gen: Vec<SortWrapper> = Vec::with_capacity(self.pop_size);
        for (rank, indices) in fronts.iter().enumerate() {
            let mut current_front: Vec<_> = indices
                .iter()
                .map(|&i| SortWrapper {
                    index: i,
                    rank,
                    distance: 0.0,
                })
                .collect();
            Self::calculate_crowding_distance(&mut current_front, &combined);
            if next_gen.len() + current_front.len() <= self.pop_size {
                next_gen.extend(current_front);
            } else {
                // Truncate the overflowing front, preferring less crowded members.
                current_front.sort_by(|a, b| b.distance.total_cmp(&a.distance));
                next_gen.extend(
                    current_front
                        .into_iter()
                        .take(self.pop_size - next_gen.len()),
                );
                break;
            }
        }
        self.ranks = next_gen.iter().map(|w| w.rank).collect();
        self.crowding_distances = next_gen.iter().map(|w| w.distance).collect();
        // Fronts list every index of `combined` exactly once, so survivors can be
        // moved out instead of deep-cloned.
        let mut slots: Vec<Option<Phenotype<G>>> = combined.into_iter().map(Some).collect();
        self.population = next_gen
            .iter()
            .map(|w| {
                slots[w.index]
                    .take()
                    .expect("non-dominated fronts list each index exactly once")
            })
            .collect();
    }

    fn population(&mut self) -> &[Phenotype<G>] {
        &self.population
    }
}