use crate::{Evaluator, Evolver, Genotype, Phenotype};
use rand::prelude::SeedableRng;
use rand_pcg::Pcg64;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// On-disk form of [`MapElites`]: the persisted search state, without the derived
/// population cache.
///
/// Field names and order mirror the hand-written `Serialize` impl for `MapElites`;
/// keep the two in sync. The hand-written `Deserialize` impl for `MapElites` reads
/// this struct and validates it before rebuilding the cache.
#[derive(Serialize, Deserialize)]
#[serde(bound = "G: Genotype")]
struct MapElitesData<G: Genotype> {
    /// Elite for each occupied cell, keyed by its per-dimension bin indices.
    archive: BTreeMap<Vec<usize>, Phenotype<G>>,
    /// The keys of `archive`, used for random parent selection.
    archive_keys_vec: Vec<Vec<usize>>,
    /// Number of bins per descriptor dimension; rejected if 0 on load.
    resolution: usize,
    /// Value forwarded to `Genotype::mutate` for each offspring.
    mutation_rate: f32,
    /// Offspring produced per `step`; rejected if 0 on load.
    batch_size: usize,
    /// Selection RNG state, persisted so a resumed run continues the same stream.
    rng: Pcg64,
}
/// MAP-Elites quality-diversity evolver.
///
/// Keeps at most one elite per cell of a grid over the behavior descriptor. Each
/// descriptor component is clamped to `[0, 1]` and split into `resolution` bins.
pub struct MapElites<G: Genotype> {
    // Configuration.
    /// Bins per descriptor dimension; always greater than 0.
    resolution: usize,
    /// Forwarded to `Genotype::mutate` for every offspring.
    mutation_rate: f32,
    /// Offspring produced per `step`; always greater than 0.
    batch_size: usize,

    // Persisted search state.
    /// Drives parent selection and per-offspring mutation seeds.
    rng: Pcg64,
    /// Elite for each occupied cell, keyed by bin indices.
    archive: BTreeMap<Vec<usize>, Phenotype<G>>,
    /// Same keys as `archive`, in a `Vec` for O(1) random indexing.
    archive_keys_vec: Vec<Vec<usize>>,

    // Derived data, never serialized.
    /// Snapshot of `archive.values()` returned by `population()`.
    population_cache: Vec<Phenotype<G>>,
    /// `false` once `archive` changed after `population_cache` was built.
    cache_valid: bool,
}
impl<G: Genotype> Serialize for MapElites<G> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("MapElites", 6)?;
state.serialize_field("archive", &self.archive)?;
state.serialize_field("archive_keys_vec", &self.archive_keys_vec)?;
state.serialize_field("resolution", &self.resolution)?;
state.serialize_field("mutation_rate", &self.mutation_rate)?;
state.serialize_field("batch_size", &self.batch_size)?;
state.serialize_field("rng", &self.rng)?;
state.end()
}
}
impl<'de, G: Genotype> Deserialize<'de> for MapElites<G> {
    /// Deserializes the persisted state, validates its invariants and rebuilds the
    /// population cache from the archive.
    ///
    /// # Errors
    /// Returns a custom deserializer error if:
    /// * `resolution` or `batch_size` is 0,
    /// * `archive_keys_vec` names a key that is not in `archive`,
    /// * `archive_keys_vec` contains the same key more than once, or
    /// * `archive_keys_vec` and `archive` differ in length.
    ///
    /// Together the last three checks guarantee that `archive_keys_vec` lists every
    /// archive key exactly once. `step` picks parents by uniform index into that
    /// vec, so a duplicate would bias selection and hide another cell from it.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;
        use std::collections::BTreeSet;

        let data = MapElitesData::<G>::deserialize(deserializer)?;
        if data.resolution == 0 {
            return Err(D::Error::custom("resolution must be greater than 0"));
        }
        if data.batch_size == 0 {
            return Err(D::Error::custom("batch_size must be greater than 0"));
        }

        let mut seen: BTreeSet<&Vec<usize>> = BTreeSet::new();
        for key in &data.archive_keys_vec {
            if !data.archive.contains_key(key) {
                return Err(D::Error::custom(
                    "archive_keys_vec contains key not present in archive",
                ));
            }
            // Without this, equal lengths plus membership would still accept
            // e.g. keys [a, a] for archive {a, b}.
            if !seen.insert(key) {
                return Err(D::Error::custom("archive_keys_vec contains duplicate key"));
            }
        }
        if data.archive.len() != data.archive_keys_vec.len() {
            return Err(D::Error::custom(
                "archive_keys_vec length does not match archive size",
            ));
        }

        let population_cache: Vec<Phenotype<G>> = data.archive.values().cloned().collect();
        Ok(Self {
            archive: data.archive,
            archive_keys_vec: data.archive_keys_vec,
            population_cache,
            cache_valid: true,
            resolution: data.resolution,
            mutation_rate: data.mutation_rate,
            batch_size: data.batch_size,
            rng: data.rng,
        })
    }
}
impl<G: Genotype> MapElites<G> {
    /// Batch size assigned by [`MapElites::new`].
    const DEFAULT_BATCH_SIZE: usize = 64;

    /// Creates an empty archive.
    ///
    /// * `resolution` – bins per descriptor dimension. Each descriptor component is
    ///   clamped to `[0, 1]` and mapped to one of `resolution` bins.
    /// * `mutation_rate` – forwarded unchanged to `Genotype::mutate` for every offspring.
    /// * `seed` – seeds the internal `Pcg64` used for parent selection and
    ///   per-offspring mutation seeds.
    ///
    /// The batch size starts at 64; change it with [`MapElites::set_batch_size`].
    ///
    /// # Panics
    /// Panics if `resolution` is 0.
    pub fn new(resolution: usize, mutation_rate: f32, seed: u64) -> Self {
        assert!(resolution > 0, "resolution must be greater than 0");
        Self {
            archive: BTreeMap::new(),
            archive_keys_vec: Vec::new(),
            population_cache: Vec::new(),
            cache_valid: true,
            resolution,
            mutation_rate,
            batch_size: Self::DEFAULT_BATCH_SIZE,
            rng: Pcg64::seed_from_u64(seed),
        }
    }

    /// Number of bins per descriptor dimension.
    pub fn resolution(&self) -> usize {
        self.resolution
    }

    /// Mutation rate forwarded to `Genotype::mutate`.
    pub fn mutation_rate(&self) -> f32 {
        self.mutation_rate
    }

    /// Sets the mutation rate used by later `step` calls. The value is not validated.
    pub fn set_mutation_rate(&mut self, rate: f32) {
        self.mutation_rate = rate;
    }

    /// Number of offspring produced per `step`.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Sets the number of offspring produced per `step`.
    ///
    /// # Panics
    /// Panics if `size` is 0.
    pub fn set_batch_size(&mut self, size: usize) {
        assert!(size > 0, "batch_size must be greater than 0");
        self.batch_size = size;
    }

    /// Number of occupied cells.
    pub fn archive_len(&self) -> usize {
        self.archive.len()
    }

    /// Elite stored in cell `key`, if that cell is occupied.
    pub fn archive_get(&self, key: &[usize]) -> Option<&Phenotype<G>> {
        self.archive.get(key)
    }

    /// Keys of all occupied cells, in ascending lexicographic order.
    pub fn archive_keys(&self) -> impl Iterator<Item = &Vec<usize>> {
        self.archive.keys()
    }

    /// All `(cell key, elite)` pairs, in ascending key order.
    pub fn archive_iter(&self) -> impl Iterator<Item = (&Vec<usize>, &Phenotype<G>)> {
        self.archive.iter()
    }

    /// Elite with the highest fitness, or `None` if the archive is empty.
    ///
    /// NaN fitness ranks below every real value, so a NaN elite is returned only
    /// when every elite is NaN. Among equal maxima, the one with the greatest
    /// cell key wins.
    pub fn best_by_fitness(&self) -> Option<&Phenotype<G>> {
        use std::cmp::Ordering;
        // `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to everything.
        // That breaks transitivity, so `max_by` over [5, NaN, 3] returned 3.
        self.archive.values().max_by(|a, b| {
            match (a.fitness.is_nan(), b.fitness.is_nan()) {
                (false, false) => a
                    .fitness
                    .partial_cmp(&b.fitness)
                    .unwrap_or(Ordering::Equal),
                (true, false) => Ordering::Less,
                (false, true) => Ordering::Greater,
                (true, true) => Ordering::Equal,
            }
        })
    }

    /// Evaluates each genotype in `initial` and offers it to its descriptor cell.
    ///
    /// A candidate is stored if its cell is empty, if it strictly beats the current
    /// elite's fitness, or if the current elite's fitness is NaN. Candidates with a
    /// NaN fitness or any NaN descriptor component are discarded. The internal RNG
    /// is not used.
    pub fn seed_population<E: Evaluator<G>>(&mut self, initial: Vec<G>, evaluator: &E) {
        use std::collections::btree_map::Entry;

        for dna in initial {
            let (f, obj, desc) = evaluator.evaluate(&dna);
            // NaN values cannot be ranked or binned meaningfully.
            if f.is_nan() || desc.iter().any(|v| v.is_nan()) {
                continue;
            }
            let idx = self.map_to_index(&desc);
            let candidate = Phenotype {
                genotype: dna,
                fitness: f,
                objectives: obj,
                descriptor: desc,
            };
            // One map lookup instead of `get` + `contains_key` + `insert`.
            match self.archive.entry(idx) {
                Entry::Vacant(cell) => {
                    self.archive_keys_vec.push(cell.key().clone());
                    cell.insert(candidate);
                    self.cache_valid = false;
                }
                Entry::Occupied(mut cell) => {
                    let incumbent = cell.get().fitness;
                    if candidate.fitness > incumbent || incumbent.is_nan() {
                        cell.insert(candidate);
                        self.cache_valid = false;
                    }
                }
            }
        }
    }

    /// Rebuilds `population_cache` from the archive if it is stale.
    fn ensure_cache_valid(&mut self) {
        if !self.cache_valid {
            self.population_cache = self.archive.values().cloned().collect();
            self.cache_valid = true;
        }
    }

    /// Maps a descriptor to its cell key, one bin index per component.
    ///
    /// Components are clamped to `[0, 1]` first. A NaN component maps to bin 0.
    pub fn map_to_index(&self, descriptor: &[f32]) -> Vec<usize> {
        descriptor
            .iter()
            .map(|&v| self.map_single_dimension(v))
            .collect()
    }

    /// Allocation-free variant of [`MapElites::map_to_index`].
    ///
    /// Writes the bin indices into `buffer[..descriptor.len()]`; any extra
    /// elements of `buffer` are left untouched.
    ///
    /// # Panics
    /// Panics if `buffer` is shorter than `descriptor`.
    pub fn map_to_index_into(&self, descriptor: &[f32], buffer: &mut [usize]) {
        assert!(
            buffer.len() >= descriptor.len(),
            "buffer too small: {} < {}",
            buffer.len(),
            descriptor.len()
        );
        for (slot, &v) in buffer.iter_mut().zip(descriptor) {
            *slot = self.map_single_dimension(v);
        }
    }

    /// Bins one descriptor component into `0..resolution`.
    ///
    /// `clamp` passes NaN through; `NaN as usize` saturates to 0, so NaN lands in bin 0.
    /// The `min` keeps `v == 1.0` in the last bin instead of one past it.
    #[inline]
    fn map_single_dimension(&self, v: f32) -> usize {
        let scaled = v.clamp(0.0, 1.0) * self.resolution as f32;
        (scaled.floor() as usize).min(self.resolution - 1)
    }
}
impl<G: Genotype> Evolver<G> for MapElites<G> {
    /// Runs one generation: breeds `batch_size` offspring and offers them to the archive.
    ///
    /// Parents are drawn uniformly from the occupied cells, with replacement. Each
    /// offspring is a clone of its parent, mutated with its own `Pcg64` seeded from
    /// the internal RNG, then evaluated. With the `parallel` feature, mutation and
    /// evaluation run on rayon. All RNG draws happen sequentially beforehand, so
    /// results do not depend on scheduling.
    ///
    /// An offspring replaces its cell's elite under these conditions:
    /// * the cell is empty,
    /// * the offspring's fitness is strictly higher, or
    /// * the incumbent's fitness is NaN.
    ///
    /// Offspring with a NaN fitness or descriptor component are discarded. Does
    /// nothing if the archive is empty.
    fn step<E: Evaluator<G>>(&mut self, evaluator: &E) {
        use rand::Rng;

        // `random_range(0..0)` would panic; the key vec mirrors the archive.
        let num_keys = self.archive_keys_vec.len();
        if num_keys == 0 {
            return;
        }
        let mutation_rate = self.mutation_rate;

        // Per offspring: draw a key index, then a seed. This is the same RNG call
        // order as before, so seeded runs stay reproducible. The parent genotype is
        // cloned directly instead of cloning its key and looking it up again later.
        let mut jobs: Vec<(G, u64)> = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            let key_idx = self.rng.random_range(0..num_keys);
            let parent = self
                .archive
                .get(&self.archive_keys_vec[key_idx])
                .expect("every archive_keys_vec entry is a key of archive")
                .genotype
                .clone();
            let seed = self.rng.random::<u64>();
            jobs.push((parent, seed));
        }

        // Mutate a parent with its own deterministic RNG and evaluate the child.
        // The same closure is shared by the parallel and sequential paths.
        let breed = move |(mut dna, seed): (G, u64)| {
            let mut rng = Pcg64::seed_from_u64(seed);
            dna.mutate(&mut rng, mutation_rate);
            let (f, obj, desc) = evaluator.evaluate(&dna);
            (dna, f, obj, desc)
        };

        #[cfg(feature = "parallel")]
        let results: Vec<(G, f32, Vec<f32>, Vec<f32>)> = jobs.into_par_iter().map(breed).collect();
        #[cfg(not(feature = "parallel"))]
        let results: Vec<(G, f32, Vec<f32>, Vec<f32>)> = jobs.into_iter().map(breed).collect();

        let mut idx_buffer: Vec<usize> = Vec::new();
        for (dna, f, obj, desc) in results {
            // NaN values cannot be ranked or binned meaningfully.
            if f.is_nan() || desc.iter().any(|v| v.is_nan()) {
                continue;
            }
            idx_buffer.resize(desc.len(), 0);
            self.map_to_index_into(&desc, &mut idx_buffer);
            let offspring = Phenotype {
                genotype: dna,
                fitness: f,
                objectives: obj,
                descriptor: desc,
            };
            match self.archive.get_mut(&idx_buffer) {
                Some(elite) => {
                    // Replacing in place needs no key allocation.
                    if f > elite.fitness || elite.fitness.is_nan() {
                        *elite = offspring;
                        self.cache_valid = false;
                    }
                }
                None => {
                    // A new cell needs two owned keys: one for the vec, one for the map.
                    let key = idx_buffer.clone();
                    self.archive_keys_vec.push(key.clone());
                    self.archive.insert(key, offspring);
                    self.cache_valid = false;
                }
            }
        }
    }

    /// Returns every elite in ascending cell-key order.
    ///
    /// The cached snapshot is rebuilt first if the archive changed since the last call.
    fn population(&mut self) -> &[Phenotype<G>] {
        self.ensure_cache_valid();
        &self.population_cache
    }
}