use std::marker::PhantomData;
use crate::{Genotype, Phenotype};
/// A measure of how far apart two genotypes of type `G` are.
///
/// `Speciation::assign` calls `distance` with a candidate genotype and a
/// species representative, and treats the candidate as belonging to that
/// species when the returned value is strictly less than the current
/// threshold. Implementors must also be `Send + Sync`.
pub trait CompatibilityDistance<G: Genotype>: Send + Sync {
/// Returns the distance between genotypes `a` and `b` as an `f32`.
fn distance(&self, a: &G, b: &G) -> f32;
}
/// A group of population members that matched the same representative
/// genotype during `Speciation::assign`.
#[derive(Clone, Debug)]
pub struct Species<G: Genotype> {
/// Identifier taken from the owning `Speciation`'s counter at the moment the
/// species is created.
pub id: u64,
/// Genotype that candidates are compared against in `Speciation::assign`.
/// At the end of each `assign` call it is replaced with a clone of the
/// genotype of the first member.
pub representative: G,
/// Indices into the population slice that was passed to the most recent
/// `Speciation::assign` call.
pub member_indices: Vec<usize>,
}
impl<G> Species<G>
where
    G: Genotype,
{
    /// Returns how many population indices are currently recorded as members
    /// of this species.
    #[must_use]
    pub fn size(&self) -> usize {
        let members: &[usize] = &self.member_indices;
        members.len()
    }
}
/// Partitions a population into [`Species`] using a [`CompatibilityDistance`]
/// and a threshold that can be adjusted based on the number of species found.
pub struct Speciation<G: Genotype, D: CompatibilityDistance<G>> {
/// Species that had at least one member after the most recent `assign`.
species: Vec<Species<G>>,
/// A candidate joins a species when its distance to the representative is
/// strictly below this value.
threshold: f32,
/// Species count that `adjust_threshold` compares the current count against.
target_count: usize,
/// Amount by which `adjust_threshold` raises or lowers `threshold`.
threshold_step: f32,
/// Floor applied when `adjust_threshold` lowers `threshold`.
min_threshold: f32,
/// Distance measure used to compare candidates with representatives.
distance: D,
/// Identifier handed to the next newly created species; incremented after
/// each creation.
next_species_id: u64,
/// Zero-sized marker for the genotype parameter `G`.
_marker: PhantomData<fn() -> G>,
}
impl<G, D> Speciation<G, D>
where
    G: Genotype,
    D: CompatibilityDistance<G>,
{
    /// Step used by [`adjust_threshold`](Self::adjust_threshold) unless
    /// overridden with [`with_threshold_step`](Self::with_threshold_step).
    pub const DEFAULT_THRESHOLD_STEP: f32 = 0.3;

    /// Threshold floor used by [`adjust_threshold`](Self::adjust_threshold)
    /// unless overridden with [`with_min_threshold`](Self::with_min_threshold).
    pub const DEFAULT_MIN_THRESHOLD: f32 = 0.1;

    /// Creates a speciation manager with no species.
    ///
    /// * `distance` - measure used to compare genotypes with representatives.
    /// * `initial_threshold` - starting compatibility threshold.
    /// * `target_count` - species count the threshold adjustment compares against.
    ///
    /// The step and minimum threshold start at their `DEFAULT_*` values.
    #[must_use]
    pub fn new(distance: D, initial_threshold: f32, target_count: usize) -> Self {
        Self {
            distance,
            target_count,
            threshold: initial_threshold,
            threshold_step: Self::DEFAULT_THRESHOLD_STEP,
            min_threshold: Self::DEFAULT_MIN_THRESHOLD,
            next_species_id: 0,
            species: Vec::new(),
            _marker: PhantomData,
        }
    }

    /// Builder-style setter replacing the threshold adjustment step.
    #[must_use]
    pub fn with_threshold_step(self, step: f32) -> Self {
        Self {
            threshold_step: step,
            ..self
        }
    }

    /// Builder-style setter replacing the floor used when lowering the threshold.
    #[must_use]
    pub fn with_min_threshold(self, min: f32) -> Self {
        Self {
            min_threshold: min,
            ..self
        }
    }

    /// Reassigns every member of `population` to a species.
    ///
    /// Existing species keep their id and representative but start with an
    /// empty member list. Each phenotype, in slice order, joins the first
    /// species (in storage order) whose representative is at a distance
    /// strictly below the current threshold; if none qualifies, a new species
    /// is created with that phenotype's genotype as representative. Species
    /// left without members are removed, and every remaining species takes
    /// the genotype of its first member as its new representative.
    pub fn assign(&mut self, population: &[Phenotype<G>]) {
        self.species
            .iter_mut()
            .for_each(|group| group.member_indices.clear());

        for (index, candidate) in population.iter().enumerate() {
            // Borrow the read-only inputs separately so the species list can
            // be borrowed mutably at the same time.
            let metric = &self.distance;
            let limit = self.threshold;

            let home = self
                .species
                .iter_mut()
                .find(|group| metric.distance(&candidate.genotype, &group.representative) < limit);

            if let Some(group) = home {
                group.member_indices.push(index);
                continue;
            }

            // Nothing was close enough: this candidate founds a new species.
            let id = self.next_species_id;
            self.next_species_id += 1;
            self.species.push(Species {
                id,
                representative: candidate.genotype.clone(),
                member_indices: vec![index],
            });
        }

        self.species
            .retain(|group| !group.member_indices.is_empty());

        // Every remaining species is non-empty, so index 0 always exists.
        for group in self.species.iter_mut() {
            let first_member = group.member_indices[0];
            group.representative = population[first_member].genotype.clone();
        }
    }

    /// Divides the fitness of each listed member by the size of its species.
    ///
    /// Member indices that fall outside `population` are ignored.
    pub fn share_fitness(&self, population: &mut [Phenotype<G>]) {
        let occupied = self
            .species
            .iter()
            .filter(|group| !group.member_indices.is_empty());

        for group in occupied {
            let niche_size = group.member_indices.len() as f32;
            for &member in &group.member_indices {
                if let Some(individual) = population.get_mut(member) {
                    individual.fitness /= niche_size;
                }
            }
        }
    }

    /// Moves the threshold by one step according to the current species count.
    ///
    /// With more species than the target the threshold is raised by the step;
    /// with fewer it is lowered by the step but not below the minimum
    /// threshold; when the count equals the target nothing changes.
    pub fn adjust_threshold(&mut self) {
        use std::cmp::Ordering;

        match self.species.len().cmp(&self.target_count) {
            Ordering::Greater => self.threshold += self.threshold_step,
            Ordering::Less => {
                let lowered = self.threshold - self.threshold_step;
                self.threshold = lowered.max(self.min_threshold);
            }
            Ordering::Equal => {}
        }
    }

    /// Returns the species produced by the most recent [`assign`](Self::assign).
    #[must_use]
    pub fn species(&self) -> &[Species<G>] {
        self.species.as_slice()
    }

    /// Returns the current compatibility threshold.
    #[must_use]
    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    /// Returns the species count the threshold adjustment compares against.
    #[must_use]
    pub fn target_count(&self) -> usize {
        self.target_count
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use serde::{Deserialize, Serialize};

    /// One-dimensional genotype used as a fixture.
    #[derive(Clone, Serialize, Deserialize, Debug)]
    struct Scalar(f32);

    impl Genotype for Scalar {
        fn mutate<R: Rng>(&mut self, rng: &mut R, _rate: f32) {
            let jitter = rng.random::<f32>() - 0.5;
            self.0 += jitter;
        }

        fn crossover<R: Rng>(&self, other: &Self, _rng: &mut R) -> Self {
            let sum = self.0 + other.0;
            Scalar(sum * 0.5)
        }
    }

    /// Absolute difference between two `Scalar` values.
    struct Abs;

    impl CompatibilityDistance<Scalar> for Abs {
        fn distance(&self, a: &Scalar, b: &Scalar) -> f32 {
            let delta = a.0 - b.0;
            delta.abs()
        }
    }

    /// Builds one phenotype per value, each with fitness `1.0`.
    fn make_pop(values: &[f32]) -> Vec<Phenotype<Scalar>> {
        let mut pop = Vec::with_capacity(values.len());
        for &value in values {
            pop.push(Phenotype {
                genotype: Scalar(value),
                fitness: 1.0,
                objectives: vec![],
                descriptor: vec![],
            });
        }
        pop
    }

    /// Collects the ids of the current species in storage order.
    fn species_ids(spec: &Speciation<Scalar, Abs>) -> Vec<u64> {
        spec.species().iter().map(|group| group.id).collect()
    }

    #[test]
    fn assign_clusters_close_genotypes() {
        let population = make_pop(&[0.0, 0.05, 0.1, 5.0, 5.1]);
        let mut speciation = Speciation::new(Abs, 0.5, 2);
        speciation.assign(&population);
        assert_eq!(speciation.species().len(), 2);
    }

    #[test]
    fn fitness_sharing_divides_by_species_size() {
        let mut population = make_pop(&[0.0, 0.1, 0.2, 5.0]);
        population
            .iter_mut()
            .for_each(|individual| individual.fitness = 4.0);

        let mut speciation = Speciation::new(Abs, 0.5, 2);
        speciation.assign(&population);
        speciation.share_fitness(&mut population);

        // The first three members share one species; the last one is alone.
        for individual in &population[..3] {
            assert!((individual.fitness - 4.0 / 3.0).abs() < 1e-5);
        }
        assert!((population[3].fitness - 4.0).abs() < 1e-5);
    }

    #[test]
    fn threshold_increases_when_too_many_species() {
        let population = make_pop(&[0.0, 1.0, 2.0, 3.0, 4.0]);
        let mut speciation = Speciation::new(Abs, 0.5, 1).with_threshold_step(0.5);
        speciation.assign(&population);

        let before = speciation.threshold();
        speciation.adjust_threshold();
        let after = speciation.threshold();
        assert!(after > before);
    }

    #[test]
    fn threshold_decreases_when_too_few_species() {
        let population = make_pop(&[0.0, 0.05, 0.1]);
        let mut speciation = Speciation::new(Abs, 5.0, 5).with_threshold_step(0.5);
        speciation.assign(&population);

        let before = speciation.threshold();
        speciation.adjust_threshold();
        let after = speciation.threshold();
        assert!(after < before);
    }

    #[test]
    fn threshold_floors_at_min() {
        let population = make_pop(&[0.0]);
        let mut speciation = Speciation::new(Abs, 0.2, 99)
            .with_threshold_step(1.0)
            .with_min_threshold(0.1);
        speciation.assign(&population);
        speciation.adjust_threshold();
        assert!((speciation.threshold() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn species_ids_are_stable_across_generations() {
        let mut population = make_pop(&[0.0, 0.05, 5.0, 5.05]);
        let mut speciation = Speciation::new(Abs, 0.5, 2);

        speciation.assign(&population);
        let ids_gen0 = species_ids(&speciation);

        // Nudge every genotype slightly, staying well inside the threshold.
        for individual in &mut population {
            individual.genotype.0 += 0.01;
        }

        speciation.assign(&population);
        let ids_gen1 = species_ids(&speciation);

        assert_eq!(ids_gen0, ids_gen1);
    }

    #[test]
    fn empty_species_dropped() {
        let mut speciation = Speciation::new(Abs, 0.5, 2);

        let spread = make_pop(&[0.0, 5.0]);
        speciation.assign(&spread);
        assert_eq!(speciation.species().len(), 2);

        let clustered = make_pop(&[0.0, 0.05]);
        speciation.assign(&clustered);
        assert_eq!(speciation.species().len(), 1);
    }

    #[test]
    fn species_count_converges_toward_target() {
        let population = make_pop(&[
            0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
        ]);
        let target = 4;
        let mut speciation = Speciation::new(Abs, 0.5, target).with_threshold_step(0.1);

        let mut generation = 0;
        while generation < 100 {
            speciation.assign(&population);
            speciation.adjust_threshold();
            generation += 1;
        }

        let count = speciation.species().len();
        assert!(
            count.abs_diff(target) <= 1,
            "expected count near {target}, got {count}"
        );
    }
}