use derive_more::Display;
use eyre::{eyre, Result};
use crate::cfg::{Cfg, Crossover, Duplicates, Mutation, Replacement, Selection, Survival};
use crate::eval::{Evaluator, Genome, Mem};
use crate::gen::species::SpeciesId;
use crate::gen::unevaluated::UnevaluatedGen;
use crate::ops::mutation::{mutate_lognorm, mutate_normal, mutate_rate};
use crate::ops::sampling::{multi_rws, rws, sus};
use crate::run::runner::RandGenome;
/// A generation whose members have already been evaluated.
///
/// When constructed through [`EvaluatedGen::new`], `mems` is sorted by
/// `base_fitness` in descending order, so `mems[0]` is the best member.
///
/// The derived `Display` prints the population size and the first member
/// (`self.mems[0]`).
///
/// # Panics
///
/// Formatting with `Display` panics if `mems` is empty, because the format
/// arguments index `self.mems[0]` unconditionally.
#[derive(Display, Clone, PartialOrd, PartialEq)]
#[display(fmt = "pop: {}, best: {}", "mems.len()", "self.mems[0]")]
pub struct EvaluatedGen<G: Genome> {
/// Members of the generation (best first when built via `new`).
pub mems: Vec<Mem<G>>,
}
impl<G: Genome> EvaluatedGen<G> {
    /// Builds an evaluated generation, sorting members from best (highest
    /// `base_fitness`) to worst.
    ///
    /// # Panics
    ///
    /// Panics if two `base_fitness` values cannot be compared with `partial_cmp`.
    #[must_use]
    pub fn new(mut mems: Vec<Mem<G>>) -> Self {
        // Compare `b` against `a` to get descending order: best member first.
        mems.sort_unstable_by(|a, b| {
            b.base_fitness
                .partial_cmp(&a.base_fitness)
                .expect("base_fitness values must be comparable")
        });
        Self { mems }
    }

    /// Returns all members (best first when built via [`EvaluatedGen::new`]).
    #[must_use]
    pub fn mems(&self) -> &[Mem<G>] {
        &self.mems
    }

    /// Returns clones of the members of species `n`, in their order within `mems`.
    #[must_use]
    pub fn species_mems(&self, n: SpeciesId) -> Vec<Mem<G>> {
        self.species_iter(n).cloned().collect()
    }

    /// Returns the distinct species ids present in this generation, sorted ascending.
    #[must_use]
    pub fn species(&self) -> Vec<SpeciesId> {
        let mut species: Vec<_> = self.mems.iter().map(|mem| mem.species).collect();
        species.sort_unstable();
        species.dedup();
        species
    }

    /// Iterates over the members of species `id` without cloning them.
    fn species_iter(&self, id: SpeciesId) -> impl Iterator<Item = &Mem<G>> + '_ {
        self.mems.iter().filter(move |mem| mem.species == id)
    }

    /// Learning rate used for adaptive operator weights: `1 / sqrt(len(mems))`.
    fn learning_rate(&self) -> f64 {
        1.0 / (self.mems.len() as f64).sqrt()
    }

    /// Selects the members copied unchanged into the next generation.
    ///
    /// - `TopProportion(prop)` keeps the first `ceil(pop_size * prop)` members.
    /// - `SpeciesTopProportion(prop)` keeps the first
    ///   `ceil(pop_size * prop / number_of_species)` members of each species.
    fn survivors(&self, survival: Survival, cfg: &Cfg) -> Vec<Mem<G>> {
        match survival {
            Survival::TopProportion(prop) => {
                let num = (cfg.pop_size as f64 * prop).ceil() as usize;
                self.mems.iter().take(num).cloned().collect()
            }
            Survival::SpeciesTopProportion(prop) => {
                let species = self.species();
                // `species` is empty only when `mems` is; the quota then saturates
                // but is never used because there is nothing to iterate.
                let per_species =
                    (cfg.pop_size as f64 * prop / species.len() as f64).ceil() as usize;
                // Take before cloning so only the surviving members are copied.
                species
                    .into_iter()
                    .flat_map(|id| self.species_iter(id).take(per_species).cloned())
                    .collect()
            }
        }
    }

    /// Samples two parents (as clones), weighting members by `selection_fitness`
    /// and using `sus` or `multi_rws` depending on `selection`.
    fn selection(&self, selection: Selection) -> [Mem<G>; 2] {
        let fitnesses = self.mems.iter().map(|v| v.selection_fitness).collect::<Vec<_>>();
        let idxs = match selection {
            Selection::Sus => sus(&fitnesses, 2),
            Selection::Roulette => multi_rws(&fitnesses, 2),
        };
        [self.mems[idxs[0]].clone(), self.mems[idxs[1]].clone()]
    }

    /// Validates operator weights: exactly `l` of them, each non-negative.
    ///
    /// # Errors
    ///
    /// Returns an error if the length differs from `l` or any weight is
    /// negative or NaN.
    fn check_weights(weights: &[f64], l: usize) -> Result<()> {
        if weights.len() != l {
            return Err(eyre!("number of fixed weights {} doesn't match {}", weights.len(), l));
        }
        // NaN is rejected too: `NaN < 0.0` is false, so a plain negativity
        // check would let it through to the sampler.
        if let Some(&v) = weights.iter().find(|&&v| v.is_nan() || v < 0.0) {
            return Err(eyre!("weights must all be non-negative: {}", v));
        }
        Ok(())
    }

    /// Sets or adapts the crossover weights of both parents, then applies one
    /// crossover operator (sampled from `s1`'s weights) to their genomes.
    ///
    /// # Errors
    ///
    /// Returns an error if either weight vector is invalid, or if no operator
    /// can be sampled (e.g. every weight is zero).
    fn crossover<E: Evaluator<Genome = G>>(
        &self,
        crossover: &Crossover,
        eval: &E,
        s1: &mut Mem<G>,
        s2: &mut Mem<G>,
    ) -> Result<()> {
        match crossover {
            Crossover::Fixed(rates) => {
                s1.params.crossover = rates.clone();
                s2.params.crossover = rates.clone();
            }
            Crossover::Adaptive => {
                let lrate = self.learning_rate();
                mutate_rate(&mut s1.params.crossover, 1.0, |v| mutate_normal(v, lrate).max(0.0));
                mutate_rate(&mut s2.params.crossover, 1.0, |v| mutate_normal(v, lrate).max(0.0));
            }
        };
        Self::check_weights(&s1.params.crossover, E::NUM_CROSSOVER)?;
        Self::check_weights(&s2.params.crossover, E::NUM_CROSSOVER)?;
        // Adaptive weights are clamped at 0, so they can all reach zero and
        // leave nothing to sample; report that instead of panicking.
        let idx = rws(&s1.params.crossover).ok_or_else(|| {
            eyre!("could not sample a crossover operator from weights {:?}", s1.params.crossover)
        })?;
        eval.crossover(&mut s1.genome, &mut s2.genome, idx);
        Ok(())
    }

    /// Sets or adapts the mutation rates of `s`, then applies every mutation
    /// operator to its genome with the corresponding rate.
    ///
    /// # Errors
    ///
    /// Returns an error if the mutation rates are invalid.
    fn mutation<E: Evaluator<Genome = G>>(
        &self,
        mutation: &Mutation,
        eval: &E,
        s: &mut Mem<G>,
    ) -> Result<()> {
        match mutation {
            Mutation::Fixed(rates) => {
                s.params.mutation = rates.clone();
            }
            Mutation::Adaptive => {
                let lrate = self.learning_rate();
                mutate_rate(&mut s.params.mutation, 1.0, |v| {
                    mutate_lognorm(v, lrate).clamp(0.0, 1.0)
                });
            }
        };
        Self::check_weights(&s.params.mutation, E::NUM_MUTATION)?;
        for (idx, &rate) in s.params.mutation.iter().enumerate() {
            eval.mutate(&mut s.genome, rate, idx);
        }
        Ok(())
    }

    /// Breeds the next (unevaluated) generation.
    ///
    /// The new generation is built in three parts:
    ///
    /// 1. Survivors chosen by `cfg.survival` are kept.
    /// 2. If `stagnant`, a proportion of the remaining slots (per `cfg.replacement`)
    ///    is filled with fresh genomes from `genfn`.
    /// 3. The rest are children produced by selection, crossover and mutation.
    ///
    /// With `Duplicates::DisallowDuplicates`, duplicate genomes are removed and the
    /// population refilled, for a bounded number of rounds. The result may
    /// therefore hold fewer than `cfg.pop_size` members.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following holds:
    ///
    /// - operator weights are invalid;
    /// - no crossover operator can be sampled;
    /// - children are needed but this generation is empty.
    ///
    /// # Panics
    ///
    /// Panics if genomes cannot be compared while removing duplicates.
    pub fn next_gen<E: Evaluator<Genome = G>>(
        &self,
        genfn: &mut (dyn RandGenome<G> + '_),
        stagnant: bool,
        cfg: &Cfg,
        eval: &E,
    ) -> Result<UnevaluatedGen<G>> {
        // Number of breed-then-deduplicate rounds when duplicates are disallowed.
        const NUM_TRIES: usize = 3;

        let mut new_mems = self.survivors(cfg.survival, cfg);
        // `reserve` is relative to the current length; only the free slots are needed.
        new_mems.reserve(cfg.pop_size.saturating_sub(new_mems.len()));
        if stagnant {
            let num = match cfg.replacement {
                Replacement::ReplaceChildren(prop) => {
                    let remaining = cfg.pop_size as f64 - new_mems.len() as f64;
                    (prop * remaining).ceil().max(0.0) as usize
                }
            };
            for _ in 0..num {
                new_mems.push(Mem::new::<E>((*genfn)(), cfg));
            }
        }
        let disallow_duplicates = cfg.duplicates == Duplicates::DisallowDuplicates;
        for _ in 0..NUM_TRIES {
            self.breed(&mut new_mems, cfg, eval)?;
            if !disallow_duplicates {
                // Nothing was removed, so further rounds could not add anything.
                break;
            }
            new_mems.sort_unstable_by(|a, b| {
                a.genome.partial_cmp(&b.genome).expect("genomes must be comparable")
            });
            new_mems.dedup_by(|a, b| a.genome.eq(&b.genome));
        }
        Ok(UnevaluatedGen::new(new_mems))
    }

    /// Appends children to `mems` until it holds exactly `cfg.pop_size` members
    /// (no-op if it already holds that many or more).
    ///
    /// # Errors
    ///
    /// Returns an error if children are needed but this generation is empty, or
    /// if crossover/mutation fails.
    fn breed<E: Evaluator<Genome = G>>(
        &self,
        mems: &mut Vec<Mem<G>>,
        cfg: &Cfg,
        eval: &E,
    ) -> Result<()> {
        if mems.len() < cfg.pop_size && self.mems.is_empty() {
            return Err(eyre!("cannot breed children from an empty generation"));
        }
        while mems.len() < cfg.pop_size {
            let [mut s1, mut s2] = self.selection(cfg.selection);
            self.crossover(&cfg.crossover, eval, &mut s1, &mut s2)?;
            self.mutation(&cfg.mutation, eval, &mut s1)?;
            self.mutation(&cfg.mutation, eval, &mut s2)?;
            mems.push(s1);
            // Keep the second child only if there is room, so an odd number of
            // free slots does not push the population one past `pop_size`.
            if mems.len() < cfg.pop_size {
                mems.push(s2);
            }
        }
        Ok(())
    }
}