use crate::algorithms::archive::Archive;
use crate::{Evaluator, Evolver, Genotype, Phenotype};
use rand::Rng;
use rand::prelude::SeedableRng;
use rand_pcg::Pcg64;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Plain-data mirror of `CvtMapElites` used as the deserialization target.
///
/// Its six field names match the six fields written by the hand-written
/// `Serialize` impl for `CvtMapElites`. The `Deserialize` impl for
/// `CvtMapElites` reads this struct first, validates its contents, and only
/// then builds the real evolver from it.
#[derive(Serialize, Deserialize)]
#[serde(bound = "G: Genotype")]
struct CvtMapElitesData<G: Genotype> {
// Archive cells by key; passed as the first argument to `Archive::from_raw`.
archive: BTreeMap<Vec<usize>, Phenotype<G>>,
// Key list passed as the second argument to `Archive::from_raw`.
archive_keys_vec: Vec<Vec<usize>>,
// Centroid coordinates; checked with `validate_centroids` after loading.
centroids: Vec<Vec<f32>>,
// Mutation rate, copied into the evolver unchanged.
mutation_rate: f32,
// Batch size; deserialization rejects a value of 0.
batch_size: usize,
// Saved random generator state, copied into the evolver unchanged.
rng: Pcg64,
}
/// Evolver that keeps one archive entry per centroid cell.
///
/// Candidates are evaluated to a fitness, objectives and a descriptor; the
/// descriptor is mapped to the index of its nearest centroid and the candidate
/// is offered to the archive under that single-index key.
pub struct CvtMapElites<G: Genotype> {
// Stored phenotypes, keyed by a one-element vector holding a centroid index.
archive: Archive<G>,
// Centroid coordinates; constructors validate them to be non-empty,
// equal-length and finite, or generate them with `lloyd_centroids`.
centroids: Vec<Vec<f32>>,
// Rate handed to `Genotype::mutate` for every offspring.
mutation_rate: f32,
// Number of offspring produced per `step`; kept greater than 0.
batch_size: usize,
// Generator used for parent sampling and for deriving per-offspring seeds.
rng: Pcg64,
}
/// Hand-written serialization that flattens the archive into two fields.
///
/// The six fields written here carry the same names as the fields of
/// `CvtMapElitesData`, which is what the `Deserialize` impl reads back.
impl<G: Genotype> Serialize for CvtMapElites<G> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
// The declared field count (6) matches the six `serialize_field` calls below.
let mut state = serializer.serialize_struct("CvtMapElites", 6)?;
// The archive is written as its cells plus its key list rather than as one value.
state.serialize_field("archive", self.archive.cells())?;
state.serialize_field("archive_keys_vec", self.archive.keys_vec())?;
state.serialize_field("centroids", &self.centroids)?;
state.serialize_field("mutation_rate", &self.mutation_rate)?;
state.serialize_field("batch_size", &self.batch_size)?;
state.serialize_field("rng", &self.rng)?;
state.end()
}
}
impl<'de, G: Genotype> Deserialize<'de> for CvtMapElites<G> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let data = CvtMapElitesData::<G>::deserialize(deserializer)?;
validate_centroids(&data.centroids).map_err(D::Error::custom)?;
if data.batch_size == 0 {
return Err(D::Error::custom("batch_size must be greater than 0"));
}
let archive = Archive::from_raw(data.archive, data.archive_keys_vec);
archive.validate().map_err(D::Error::custom)?;
for key in archive.keys() {
if key.len() != 1 || key[0] >= data.centroids.len() {
return Err(D::Error::custom(
"archive key references a centroid index out of range",
));
}
}
Ok(Self {
archive,
centroids: data.centroids,
mutation_rate: data.mutation_rate,
batch_size: data.batch_size,
rng: data.rng,
})
}
}
/// Checks that a centroid list is usable.
///
/// # Errors
///
/// Returns a static message when the list is empty, when the first centroid
/// has no coordinates, or, scanning centroids in order, when a centroid's
/// length differs from the first one's or it holds a non-finite value. For a
/// single centroid the length check is applied before the finiteness check.
fn validate_centroids(centroids: &[Vec<f32>]) -> Result<(), &'static str> {
    let dim = match centroids.first() {
        None => return Err("centroids must be non-empty"),
        Some(first) => first.len(),
    };
    if dim == 0 {
        return Err("centroids must have at least one dimension");
    }
    centroids.iter().try_for_each(|centroid| {
        if centroid.len() != dim {
            Err("all centroids must have the same dimension")
        } else if !centroid.iter().all(|value| value.is_finite()) {
            Err("centroids must contain only finite values")
        } else {
            Ok(())
        }
    })
}
impl<G: Genotype> CvtMapElites<G> {
    /// Creates an evolver over caller-supplied centroids with an empty archive.
    ///
    /// The internal random generator is seeded from `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `centroids` fails `validate_centroids` (empty list,
    /// zero-length centroid, mixed lengths, or a non-finite value) or if
    /// `batch_size` is 0.
    pub fn new(
        centroids: Vec<Vec<f32>>,
        mutation_rate: f32,
        batch_size: usize,
        seed: u64,
    ) -> Self {
        validate_centroids(&centroids).expect("invalid centroids");
        assert!(batch_size > 0, "batch_size must be greater than 0");
        Self {
            archive: Archive::new(),
            centroids,
            mutation_rate,
            batch_size,
            rng: Pcg64::seed_from_u64(seed),
        }
    }

    /// Creates an evolver whose centroids are generated by `lloyd_centroids`.
    ///
    /// A generator seeded from `seed` first produces the centroids and is then
    /// kept, in its advanced state, as the evolver's own generator.
    ///
    /// # Panics
    ///
    /// Panics if any of `num_centroids`, `descriptor_dim`, `num_samples`,
    /// `lloyd_iters` or `batch_size` is 0.
    pub fn with_lloyd(
        num_centroids: usize,
        descriptor_dim: usize,
        num_samples: usize,
        lloyd_iters: usize,
        mutation_rate: f32,
        batch_size: usize,
        seed: u64,
    ) -> Self {
        assert!(num_centroids > 0, "num_centroids must be greater than 0");
        assert!(descriptor_dim > 0, "descriptor_dim must be greater than 0");
        assert!(num_samples > 0, "num_samples must be greater than 0");
        assert!(lloyd_iters > 0, "lloyd_iters must be greater than 0");
        assert!(batch_size > 0, "batch_size must be greater than 0");
        let mut rng = Pcg64::seed_from_u64(seed);
        let centroids = lloyd_centroids(
            num_centroids,
            descriptor_dim,
            num_samples,
            lloyd_iters,
            &mut rng,
        );
        Self {
            archive: Archive::new(),
            centroids,
            mutation_rate,
            batch_size,
            rng,
        }
    }

    /// Returns the centroid coordinates.
    pub fn centroids(&self) -> &[Vec<f32>] {
        &self.centroids
    }

    /// Returns the number of centroids.
    pub fn num_centroids(&self) -> usize {
        self.centroids.len()
    }

    /// Returns the mutation rate passed to `Genotype::mutate`.
    pub fn mutation_rate(&self) -> f32 {
        self.mutation_rate
    }

    /// Replaces the mutation rate; the value is stored without validation.
    pub fn set_mutation_rate(&mut self, rate: f32) {
        self.mutation_rate = rate;
    }

    /// Returns the number of offspring produced per step.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Replaces the batch size.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0.
    pub fn set_batch_size(&mut self, size: usize) {
        assert!(size > 0, "batch_size must be greater than 0");
        self.batch_size = size;
    }

    /// Returns the number of entries currently stored in the archive.
    pub fn archive_len(&self) -> usize {
        self.archive.len()
    }

    /// Looks up the archive entry stored under the key `[centroid_idx]`.
    pub fn archive_get(&self, centroid_idx: usize) -> Option<&Phenotype<G>> {
        self.archive.get(&[centroid_idx])
    }

    /// Iterates over the archive's keys.
    pub fn archive_keys(&self) -> impl Iterator<Item = &Vec<usize>> {
        self.archive.keys()
    }

    /// Iterates over the archive's `(key, phenotype)` pairs.
    pub fn archive_iter(&self) -> impl Iterator<Item = (&Vec<usize>, &Phenotype<G>)> {
        self.archive.iter()
    }

    /// Returns the archive's best phenotype by fitness, as reported by the archive.
    pub fn best_by_fitness(&self) -> Option<&Phenotype<G>> {
        self.archive.best_by_fitness()
    }

    /// Returns the number of archive entries divided by the number of centroids.
    pub fn coverage(&self) -> f64 {
        self.archive.len() as f64 / self.centroids.len() as f64
    }

    /// Returns the archive's QD score, as computed by the archive.
    pub fn qd_score(&self) -> f64 {
        self.archive.qd_score()
    }

    /// Writes the archive as CSV to `writer` by delegating to the archive.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by the archive's exporter.
    #[cfg(feature = "export")]
    pub fn export_csv<W: std::io::Write>(&self, writer: W) -> std::io::Result<()> {
        self.archive.export_csv(writer)
    }

    /// Returns the index of the centroid nearest to `descriptor`.
    ///
    /// # Panics
    ///
    /// Panics if `descriptor` does not have the same length as the centroids.
    pub fn assign_to_centroid(&self, descriptor: &[f32]) -> usize {
        let dim = self.centroids[0].len();
        assert_eq!(
            descriptor.len(),
            dim,
            "descriptor dimension {} does not match centroid dimension {}",
            descriptor.len(),
            dim
        );
        nearest_centroid_index(descriptor, &self.centroids)
    }

    /// Evaluates each genotype in `initial` and offers it to the archive.
    ///
    /// A candidate is silently skipped when its fitness is NaN, when its
    /// descriptor contains a NaN, or when its descriptor length differs from
    /// the centroid dimension. Every other candidate is passed to the
    /// archive's `insert_if_better` under its nearest-centroid key.
    pub fn seed_population<E: Evaluator<G>>(&mut self, initial: Vec<G>, evaluator: &E) {
        let dim = self.centroids[0].len();
        for dna in initial {
            let (f, obj, desc) = evaluator.evaluate(&dna);
            // NaN values cannot be compared or placed, so the candidate is dropped.
            if f.is_nan() || desc.iter().any(|v| v.is_nan()) {
                continue;
            }
            // A descriptor of the wrong length cannot be matched against the centroids.
            if desc.len() != dim {
                continue;
            }
            let key = vec![nearest_centroid_index(&desc, &self.centroids)];
            self.archive.insert_if_better(
                key,
                Phenotype {
                    genotype: dna,
                    fitness: f,
                    objectives: obj,
                    descriptor: desc,
                },
            );
        }
    }
}
impl<G: Genotype> Evolver<G> for CvtMapElites<G> {
    /// Produces one batch of offspring and offers each to the archive.
    ///
    /// Does nothing while the archive is empty. Otherwise `batch_size`
    /// parents are sampled from the archive, each paired with a fresh `u64`
    /// seed drawn from the evolver's generator. Every parent copy is mutated
    /// with its own generator built from that seed, evaluated, and, unless
    /// its fitness or descriptor contains NaN or its descriptor has the wrong
    /// length, passed to `insert_if_better` under its nearest-centroid key.
    fn step<E: Evaluator<G>>(&mut self, evaluator: &E) {
        if self.archive.is_empty() {
            return;
        }
        let rate = self.mutation_rate;
        let descriptor_dim = self.centroids[0].len();

        // All sampling happens before any evaluation, and each offspring gets
        // its own seed, so the mutation stage never touches the shared generator.
        let mut picks: Vec<(G, u64)> = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            let cell: Vec<usize> = self
                .archive
                .sample_key(&mut self.rng)
                .expect("archive non-empty checked above")
                .clone();
            let child_seed = self.rng.random::<u64>();
            let parent = self
                .archive
                .get(&cell)
                .expect("sampled key exists in archive")
                .genotype
                .clone();
            picks.push((parent, child_seed));
        }

        // Shared by the parallel and the sequential variants below.
        let breed = |(mut child, child_seed): (G, u64)| {
            let mut child_rng = Pcg64::seed_from_u64(child_seed);
            child.mutate(&mut child_rng, rate);
            let (fitness, objectives, descriptor) = evaluator.evaluate(&child);
            (child, fitness, objectives, descriptor)
        };

        #[cfg(feature = "parallel")]
        let offspring: Vec<(G, f32, Vec<f32>, Vec<f32>)> =
            picks.into_par_iter().map(breed).collect();
        #[cfg(not(feature = "parallel"))]
        let offspring: Vec<(G, f32, Vec<f32>, Vec<f32>)> =
            picks.into_iter().map(breed).collect();

        for (child, fitness, objectives, descriptor) in offspring {
            let usable = !fitness.is_nan()
                && descriptor.len() == descriptor_dim
                && !descriptor.iter().any(|v| v.is_nan());
            if !usable {
                continue;
            }
            let cell = vec![nearest_centroid_index(&descriptor, &self.centroids)];
            self.archive.insert_if_better(
                cell,
                Phenotype {
                    genotype: child,
                    fitness,
                    objectives,
                    descriptor,
                },
            );
        }
    }

    /// Returns the archive's population slice by delegating to the archive.
    fn population(&mut self) -> &[Phenotype<G>] {
        self.archive.population()
    }
}
/// Sums the squared coordinate differences of `a` and `b`.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// extra coordinates of the longer one are ignored.
#[inline]
fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum()
}
/// Returns the index of the centroid with the smallest squared distance to `point`.
///
/// Ties are resolved in favour of the lowest index because a later centroid
/// only replaces the current best when it is strictly closer.
///
/// # Panics
///
/// Panics if `centroids` is empty.
#[inline]
fn nearest_centroid_index(point: &[f32], centroids: &[Vec<f32>]) -> usize {
    let mut best_idx = 0;
    let mut best_dist = squared_distance(point, &centroids[0]);
    // Centroid 0 is the initial best, so the scan starts at index 1.
    for (i, c) in centroids.iter().enumerate().skip(1) {
        let d = squared_distance(point, c);
        if d < best_dist {
            best_dist = d;
            best_idx = i;
        }
    }
    best_idx
}
/// Generates `num_centroids` centroids of dimension `dim` by repeated
/// sample-and-average rounds.
///
/// Centroids start at coordinates drawn with `rng.random::<f32>()`. Each of
/// the `iterations` rounds draws `num_samples` fresh points the same way,
/// assigns every point to its nearest centroid, and moves each centroid that
/// received at least one point to the mean of its points. A centroid that
/// received none keeps its previous position.
///
/// # Panics
///
/// Panics if `num_centroids` is 0 while `iterations` and `num_samples` are
/// both non-zero, because the nearest-centroid lookup needs at least one
/// centroid.
fn lloyd_centroids<R: Rng>(
    num_centroids: usize,
    dim: usize,
    num_samples: usize,
    iterations: usize,
    rng: &mut R,
) -> Vec<Vec<f32>> {
    let mut centroids: Vec<Vec<f32>> = (0..num_centroids)
        .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
        .collect();
    let mut sums: Vec<Vec<f32>> = vec![vec![0.0; dim]; num_centroids];
    let mut counts: Vec<usize> = vec![0; num_centroids];
    // One buffer is refilled for every sample instead of allocating per draw.
    let mut sample: Vec<f32> = vec![0.0; dim];

    for _ in 0..iterations {
        // Reset the per-round accumulators.
        for sum in &mut sums {
            sum.fill(0.0);
        }
        counts.fill(0);

        for _ in 0..num_samples {
            for coord in sample.iter_mut() {
                *coord = rng.random::<f32>();
            }
            let nearest = nearest_centroid_index(&sample, &centroids);
            for (acc, v) in sums[nearest].iter_mut().zip(sample.iter()) {
                *acc += v;
            }
            counts[nearest] += 1;
        }

        // Move every centroid that attracted samples to the mean of those samples.
        for ((centroid, sum), &count) in centroids.iter_mut().zip(&sums).zip(&counts) {
            if count > 0 {
                let n = count as f32;
                for (coord, total) in centroid.iter_mut().zip(sum) {
                    *coord = total / n;
                }
            }
        }
    }
    centroids
}