radiate_engines/steps/speciate.rs

1use radiate_core::{
2    Chromosome, Diversity, Ecosystem, EngineStep, Genotype, MetricSet, Objective, Species,
3    metric_names,
4    thread_pool::{ThreadPool, WaitGroup},
5};
6use std::sync::{Arc, Mutex, RwLock};
7
8pub struct SpeciateStep<C>
9where
10    C: Chromosome,
11{
12    pub(crate) threashold: f32,
13    pub(crate) objective: Objective,
14    pub(crate) diversity: Arc<dyn Diversity<C>>,
15    pub(crate) thread_pool: Arc<ThreadPool>,
16}
17
18impl<C> SpeciateStep<C>
19where
20    C: Chromosome + 'static,
21{
22    fn process_chunk(
23        chunk_range: std::ops::Range<usize>,
24        population: Arc<RwLock<Vec<(usize, Genotype<C>)>>>,
25        species_snapshot: Arc<Vec<Genotype<C>>>,
26        threshold: f32,
27        diversity: Arc<dyn Diversity<C>>,
28        assignments: Arc<Mutex<Vec<Option<usize>>>>,
29        distances: Arc<Mutex<Vec<f32>>>,
30    ) {
31        let mut inner_distances = Vec::new();
32        for (i, individual) in population.read().unwrap().iter().enumerate() {
33            let mut assigned = None;
34            for (idx, sp) in species_snapshot.iter().enumerate() {
35                let dist = diversity.measure(&individual.1, &sp);
36                inner_distances.push(dist);
37
38                if dist < threshold {
39                    assigned = Some(idx);
40                    break;
41                }
42            }
43
44            assignments.lock().unwrap()[chunk_range.start + i] = assigned;
45        }
46
47        distances.lock().unwrap().extend(inner_distances);
48    }
49}
50
51impl<C> EngineStep<C> for SpeciateStep<C>
52where
53    C: Chromosome + 'static,
54{
55    fn execute(
56        &mut self,
57        generation: usize,
58        metrics: &mut MetricSet,
59        ecosystem: &mut Ecosystem<C>,
60    ) {
61        ecosystem.generate_mascots();
62
63        let wg = WaitGroup::new();
64        let num_threads = self.thread_pool.num_workers();
65        let pop_len = ecosystem.population().len();
66        let chunk_size = (pop_len as f32 / num_threads as f32).ceil() as usize;
67        let mut chunked_members = Vec::new();
68
69        let species_snapshot = Arc::new(
70            ecosystem
71                .species_mascots()
72                .into_iter()
73                .map(|spec| spec.genotype().clone())
74                .collect::<Vec<Genotype<C>>>(),
75        );
76        let distances = Arc::new(Mutex::new(Vec::with_capacity(
77            pop_len * species_snapshot.len(),
78        )));
79        let assignments = Arc::new(Mutex::new(vec![None; pop_len]));
80
81        for chunk_start in (0..pop_len).step_by(chunk_size) {
82            let chunk_end = (chunk_start + chunk_size).min(pop_len);
83            let chunk_range = chunk_start..chunk_end;
84
85            let chunk_population = Arc::new(RwLock::new(
86                ecosystem
87                    .population_mut()
88                    .iter_mut()
89                    .enumerate()
90                    .skip(chunk_start)
91                    .take(chunk_size)
92                    .map(|(idx, pheno)| (idx, pheno.take_genotype()))
93                    .collect::<Vec<_>>(),
94            ));
95
96            let threshold = self.threashold;
97            let diversity = Arc::clone(&self.diversity);
98            let assignments = Arc::clone(&assignments);
99            let distances = Arc::clone(&distances);
100            let population = Arc::clone(&chunk_population);
101            let species_snapshot = Arc::clone(&species_snapshot);
102
103            self.thread_pool.group_submit(&wg, move || {
104                Self::process_chunk(
105                    chunk_range,
106                    population,
107                    species_snapshot,
108                    threshold,
109                    diversity,
110                    assignments,
111                    distances,
112                );
113            });
114
115            chunked_members.push(chunk_population);
116        }
117
118        wg.wait();
119
120        for chunks in chunked_members {
121            let mut chunks = chunks.write().unwrap();
122            let mut taken_genotypes = Vec::with_capacity(chunks.len());
123            std::mem::swap(&mut *chunks, &mut taken_genotypes);
124
125            for (idx, geno) in taken_genotypes {
126                ecosystem.get_phenotype_mut(idx).unwrap().set_genotype(geno);
127            }
128        }
129
130        let assignments = assignments.lock().unwrap();
131        let mut distances = distances.lock().unwrap();
132        for i in 0..ecosystem.population().len() {
133            if let Some(species_id) = assignments[i] {
134                ecosystem.add_species_member(species_id, i);
135            } else {
136                let genotype = ecosystem.get_genotype(i).unwrap();
137                let maybe_idx = ecosystem
138                    .species()
139                    .map(|specs| {
140                        for (species_idx, species) in specs.iter().enumerate() {
141                            let dist = self
142                                .diversity
143                                .measure(genotype, &species.mascot().genotype());
144
145                            distances.push(dist);
146
147                            if dist < self.threashold {
148                                return Some(species_idx);
149                            }
150                        }
151                        None
152                    })
153                    .flatten();
154
155                match maybe_idx {
156                    Some(idx) => ecosystem.add_species_member(idx, i),
157                    None => {
158                        if let Some(pheno) = ecosystem.get_phenotype(i) {
159                            let new_species = Species::new(generation, pheno);
160                            ecosystem.push_species(new_species);
161                        }
162                    }
163                }
164            }
165        }
166
167        let before_species = ecosystem.species().as_ref().map_or(0, |s| s.len());
168        ecosystem.species_mut().unwrap().retain(|s| s.len() > 0);
169        let after_species = ecosystem.species().unwrap().len();
170
171        metrics.upsert_distribution(metric_names::SPECIES_DISTANCE_DIST, &distances);
172        metrics.upsert_value(
173            metric_names::SPECIES_DIED,
174            (before_species - after_species) as f32,
175        );
176
177        ecosystem.fitness_share(&self.objective);
178    }
179}