gene_evo/continuous.rs
1//! Continuous Training.
2//! See the documentation for [`ContinuousTrainer`] for more details.
3
4use std::{
5 fmt,
6 ops::AddAssign,
7 random::RandomSource,
8 sync::{Arc, RwLock, mpmc, mpsc},
9 thread::{self, ScopedJoinHandle},
10};
11
12use crate::{
13 Gate, Genome, PopulationStats, TrainingReportStrategy, num_cpus, random_choice_weighted_mapped,
14 random_f32,
15};
16
17/// This is one of two basic genetic algorithms in this crate, the "continuous" strategy.
18///
19/// This is an alternative genetic algorithm implementation compared to the standard stochastic,
20/// generation-based strategy. This trainer runs its training
21/// continuously, nonstop, with no definable "break" in between generations. New genes are constantly
22/// being reproduced, crossbred, mutated, evaluated, and ranked while the trainer runs, using a multithreaded
23/// pool of workers.
24///
25/// As it's a more nonstandard training strategy, the allowed criteria for
26/// determining when to print population reports and when to end training are more flexible than in the
27/// typical evolutionary trainer, allowing the user to define exactly what criteria they want to
28/// pay attention to during the training process. You can access these more detailed and complex
29/// controls by calling the [`ContinuousTrainer::train_custom`] function. If you would like a simpler
30/// interface with simple, default training criteria and reporting, just call [`ContinuousTrainer::train`].
31pub struct ContinuousTrainer<'scope, G> {
32 /// A collection of all the genes in the population and
33 /// their fitness score, sorted descending by fitness.
34 ///
35 /// Because of the continuous nature of the trainer,
36 /// the collection is behind an `Arc<RwLock<_>>` combo.
37 pub gene_pool: Arc<RwLock<Vec<(G, f32)>>>,
38
39 /// Count of the total number of children reproduced.
40 pub children_created: usize,
41
42 /// The mutation rate of newly reproduced children.
43 pub mutation_rate: f32,
44
45 /// The proportion of newly reproduced children that are created as a result of
46 /// crossbreeding vs mutations.
47 ///
48 /// Higher = more crossbreeding, lower = more mutations.
49 /// Set to 1 to only create new children via crossbreeding, and 0 to only create new children
50 /// via mutation.
51 pub reproduction_type_proportion: f32,
52 work_submission: mpmc::Sender<G>,
53 #[allow(unused)]
54 worker_pool: Vec<ScopedJoinHandle<'scope, ()>>,
55 #[allow(unused)]
56 receiver_thread: ScopedJoinHandle<'scope, ()>,
57 population_size: usize,
58 in_flight: Gate<usize>,
59}
60
61impl<'scope, G> ContinuousTrainer<'scope, G> {
62 /// Construct a new trainer with a given population size, mutation rate, and reproduction type proportion.
63 ///
64 /// A reference to a [`thread::Scope`] must be passed in order
65 /// to spawn the child worker threads for the lifetime of the trainer.
66 pub fn new(
67 population_size: usize,
68 mutation_rate: f32,
69 reproduction_type_proportion: f32,
70 scope: &'scope thread::Scope<'scope, '_>,
71 ) -> Self
72 where
73 G: Genome + 'scope + Send + Sync,
74 {
75 let in_flight = Gate::new(0);
76 let (work_submission, inbox) = mpmc::sync_channel(0);
77 let (outbox, work_reception) = mpsc::channel();
78 let gene_pool = Arc::new(RwLock::new(Vec::new()));
79 let worker_pool = (0..num_cpus())
80 .map(|_| {
81 let inbox = inbox.clone();
82 let outbox = outbox.clone();
83 scope.spawn(move || Self::worker_thread(inbox, outbox))
84 })
85 .collect();
86 let receiver_thread = {
87 let gene_pool = gene_pool.clone();
88 let in_flight = in_flight.clone();
89 scope.spawn(move || {
90 Self::work_receiver_thread(
91 work_reception,
92 gene_pool,
93 population_size,
94 in_flight,
95 )
96 })
97 };
98 Self {
99 gene_pool,
100 work_submission,
101 worker_pool,
102 receiver_thread,
103 mutation_rate,
104 population_size,
105 in_flight,
106 children_created: 0,
107 reproduction_type_proportion,
108 }
109 }
110
111 fn worker_thread(inbox: mpmc::Receiver<G>, outbox: mpsc::Sender<(G, f32)>)
112 where
113 G: Genome,
114 {
115 for gene in inbox {
116 let fitness = gene.fitness();
117 outbox.send((gene, fitness)).unwrap();
118 }
119 }
120
121 fn work_receiver_thread(
122 work_reception: mpsc::Receiver<(G, f32)>,
123 gene_pool: Arc<RwLock<Vec<(G, f32)>>>,
124 max_population_size: usize,
125 in_flight: Gate<usize>,
126 ) {
127 for (gene, score) in work_reception {
128 let mut gene_pool = gene_pool.write().unwrap();
129 let insert_index = gene_pool.binary_search_by(|x| score.total_cmp(&x.1));
130 let insert_index = match insert_index {
131 Ok(i) => i,
132 Err(i) => i,
133 };
134 gene_pool.insert(insert_index, (gene, score));
135 if gene_pool.len() > max_population_size {
136 gene_pool.drain(max_population_size..);
137 }
138 in_flight.update(|x| *x = x.saturating_sub(1));
139 }
140 }
141
142 /// Submit a new genome to the worker pool to be evaluated for its fitness and
143 /// ranked among the population.
144 ///
145 /// Used internally by the training process, should
146 /// typically not be called directly unless the user knows what they're doing.
147 pub fn submit_job(&mut self, gene: G) {
148 self.children_created += 1;
149 self.in_flight.update(|x| x.add_assign(1));
150 self.work_submission.send(gene).unwrap();
151 }
152
153 /// Seed the population with new genes up to the current population cap.
154 ///
155 /// This is called automatically at the start of training, so should typically
156 /// not need to be called directly.
157 ///
158 /// A [`RandomSource`] must be passed as a source of randomness
159 /// for generating the initial population.
160 pub fn seed<R>(&mut self, rng: &mut R)
161 where
162 R: RandomSource,
163 G: Genome,
164 {
165 let current_gene_pool_size = self.gene_pool.read().unwrap().len();
166 for _ in current_gene_pool_size..self.population_size {
167 self.submit_job(G::generate(rng));
168 }
169 }
170
171 /// Begin training, finishing once `num_children` children have been
172 /// reproduced, ranked for fitness, and introduced into the population.
173 ///
174 /// A [`RandomSource`] must be passed as a source of randomness
175 /// for mutating genes to produce new offspring.
176 pub fn train<R>(&mut self, num_children: usize, rng: &mut R) -> G
177 where
178 R: RandomSource,
179 G: Clone + Genome + Send + Sync + 'scope,
180 {
181 self.train_custom(
182 |x| x.child_count <= num_children,
183 Some(default_reporting_strategy(self.population_size)),
184 rng,
185 )
186 }
187
188 /// Begin training with detailed custom parameters.
189 ///
190 /// Instead of a specific child
191 /// count cutoff point, a function `train_criteria` is passed in, which takes in an
192 /// instance of [`TrainingCriteriaMetrics`] and outputs a `bool`. This allows greater
193 /// control over exactly what criteria to finish training under.
194 ///
195 /// Additionally, the user may pass a `reporting_strategy`, which determines the conditions
196 /// and method under which periodic statistical reporting of the population is performed.
197 /// Pass `None` to disable reporting entirely, otherwise pass `Some` with an instance of a
198 /// [`TrainingReportStrategy`] to define the two methods necessary to manage reporting.
199 /// To mimic the default reporting strategy, pass the result of [`default_reporting_strategy()`] wrapped
200 /// in `Some()`.
201 ///
202 /// A [`RandomSource`] must be passed as a source of randomness
203 /// for mutating genes to produce new offspring.
204 pub fn train_custom<R>(
205 &mut self,
206 mut train_criteria: impl FnMut(TrainingCriteriaMetrics) -> bool,
207 mut reporting_strategy: Option<
208 TrainingReportStrategy<
209 impl FnMut(TrainingCriteriaMetrics) -> bool,
210 impl FnMut(TrainingStats),
211 >,
212 >,
213 rng: &mut R,
214 ) -> G
215 where
216 R: RandomSource,
217 G: Clone + Genome + Send + Sync + 'scope,
218 {
219 self.seed(rng);
220 self.in_flight.wait_while(|x| *x > 0);
221 loop {
222 let new_child = {
223 let gene_pool = self.gene_pool.read().unwrap();
224 let min_fitness = gene_pool
225 .iter()
226 .map(|x| x.1)
227 .min_by(|a, b| a.total_cmp(b))
228 .unwrap();
229 let should_crossbreed = random_f32(rng) < self.reproduction_type_proportion;
230 let mut choose_parent =
231 || random_choice_weighted_mapped(&gene_pool, rng, |x| x - min_fitness);
232 if should_crossbreed {
233 // same mother & father could potentially be chosen; this is acceptable,
234 // as it is unlikely with a sufficiently large populations, and the only
235 // outcome is a gene that is identical to the singular parent, equivalent
236 // to a gene chosen for mutation with mutation rate 0.
237 let mother = (choose_parent)();
238 let father = (choose_parent)();
239 mother.crossbreed(father, rng)
240 } else {
241 let mut new_child = (choose_parent)().clone();
242 new_child.mutate(self.mutation_rate, rng);
243 new_child
244 }
245 };
246 self.submit_job(new_child);
247
248 let metrics = self.metrics();
249 if let Some(reporting_strategy) = &mut reporting_strategy {
250 if (reporting_strategy.should_report)(metrics) {
251 (reporting_strategy.report_callback)(self.stats())
252 }
253 }
254 if !(train_criteria)(metrics) {
255 break;
256 }
257 }
258 self.in_flight.wait_while(|x| *x > 0);
259 self.gene_pool.read().unwrap().first().unwrap().0.clone()
260 }
261
262 /// Generate training criteria metrics for the current state of this trainer.
263 ///
264 /// This is a strict subset of the data available in an instance of [`TrainingStats`]
265 /// returned from calling [`ContinuousTrainer::stats`]. However, these
266 /// metrics were chosen specifically for their computation efficiency, and thus can be
267 /// re-evaluated frequently with minimal cost. These metrics are used both to determine
268 /// whether or not to continue training, and whether or not to display a report about
269 /// training progress.
270 pub fn metrics(&self) -> TrainingCriteriaMetrics {
271 let gene_pool = self.gene_pool.read().unwrap();
272 TrainingCriteriaMetrics {
273 max_fitness: gene_pool.first().unwrap().1,
274 min_fitness: gene_pool.last().unwrap().1,
275 median_fitness: gene_pool[gene_pool.len() / 2].1,
276 child_count: self.children_created,
277 }
278 }
279
280 /// Generate population stats for the current state of this trainer.
281 ///
282 /// This function is called whenever the reporting strategy is asked
283 /// to produce a report about the current population, but it may also be called
284 /// manually here.
285 pub fn stats(&self) -> TrainingStats {
286 TrainingStats {
287 population_stats: self.gene_pool.read().unwrap().iter().map(|x| x.1).collect(),
288 child_count: self.children_created,
289 }
290 }
291}
292
/// Cheap-to-compute figures describing the current population, used to
/// decide whether training should carry on.
#[derive(Clone, Copy, Debug)]
pub struct TrainingCriteriaMetrics {
    /// Highest fitness found in the population.
    pub max_fitness: f32,

    /// Lowest fitness found in the population.
    pub min_fitness: f32,

    /// Fitness of the gene in the middle of the ranked population.
    pub median_fitness: f32,

    /// Running total of children reproduced and introduced
    /// into the population; the initial seed population
    /// counts towards this total.
    pub child_count: usize,
}
311
312/// A collection of statistics about the population as a whole.
313///
314/// Relatively more expensive to compute than training metrics, so
315/// should be computed infrequently.
316#[derive(Clone, Copy, Debug)]
317pub struct TrainingStats {
318 /// A collection of standard population stats: see [`PopulationStats`]
319 /// for more information
320 pub population_stats: PopulationStats,
321
322 /// Total number of children that have been
323 /// reproduced and introduced into the population,
324 /// including the initial seed population count.
325 ///
326 /// Same as [`TrainingCriteriaMetrics::child_count`].
327 pub child_count: usize,
328}
329
330impl fmt::Display for TrainingStats {
331 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
332 write!(f, "child #{} {}", self.child_count, self.population_stats)
333 }
334}
335
336/// Returns a default reporting strategy which logs population
337/// statistics to the console every `n` children reproduced.
338///
339/// Used by [`ContinuousTrainer::train`].
340pub fn default_reporting_strategy(
341 n: usize,
342) -> TrainingReportStrategy<impl FnMut(TrainingCriteriaMetrics) -> bool, impl FnMut(TrainingStats)>
343{
344 TrainingReportStrategy {
345 should_report: move |m: TrainingCriteriaMetrics| m.child_count % n == 0,
346 report_callback: |s| println!("{s}"),
347 }
348}