gene_evo/
continuous.rs

1//! Continuous Training.
2//! See the documentation for [`ContinuousTrainer`] for more details.
3
4use std::{
5    fmt,
6    ops::AddAssign,
7    random::RandomSource,
8    sync::{Arc, RwLock, mpmc, mpsc},
9    thread::{self, ScopedJoinHandle},
10};
11
12use crate::{
13    Gate, Genome, PopulationStats, TrainingReportStrategy, num_cpus, random_choice_weighted_mapped,
14    random_f32,
15};
16
17/// This is one of two basic genetic algorithms in this crate, the "continuous" strategy.
18///
19/// This is an alternative genetic algorithm implementation compared to the standard stochastic,
20/// generation-based strategy. This trainer runs its training
21/// continuously, nonstop, with no definable "break" in between generations. New genes are constantly
22/// being reproduced, crossbred, mutated, evaluated, and ranked while the trainer runs, using a multithreaded
23/// pool of workers.
24///
25/// As it's a more nonstandard training strategy, the allowed criteria for
26/// determining when to print population reports and when to end training are more flexible than in the
27/// typical evolutionary trainer, allowing the user to define exactly what criteria they want to
28/// pay attention to during the training process. You can access these more detailed and complex
29/// controls by calling the [`ContinuousTrainer::train_custom`] function. If you would like a simpler
30/// interface with simple, default training criteria and reporting, just call [`ContinuousTrainer::train`].
31pub struct ContinuousTrainer<'scope, G> {
32    /// A collection of all the genes in the population and
33    /// their fitness score, sorted descending by fitness.
34    ///
35    /// Because of the continuous nature of the trainer,
36    /// the collection is behind an `Arc<RwLock<_>>` combo.
37    pub gene_pool: Arc<RwLock<Vec<(G, f32)>>>,
38
39    /// Count of the total number of children reproduced.
40    pub children_created: usize,
41
42    /// The mutation rate of newly reproduced children.
43    pub mutation_rate: f32,
44
45    /// The proportion of newly reproduced children that are created as a result of
46    /// crossbreeding vs mutations.
47    ///
48    /// Higher = more crossbreeding, lower = more mutations.
49    /// Set to 1 to only create new children via crossbreeding, and 0 to only create new children
50    /// via mutation.
51    pub reproduction_type_proportion: f32,
52    work_submission: mpmc::Sender<G>,
53    #[allow(unused)]
54    worker_pool: Vec<ScopedJoinHandle<'scope, ()>>,
55    #[allow(unused)]
56    receiver_thread: ScopedJoinHandle<'scope, ()>,
57    population_size: usize,
58    in_flight: Gate<usize>,
59}
60
61impl<'scope, G> ContinuousTrainer<'scope, G> {
62    /// Construct a new trainer with a given population size, mutation rate, and reproduction type proportion.
63    ///
64    /// A reference to a [`thread::Scope`] must be passed in order
65    /// to spawn the child worker threads for the lifetime of the trainer.
66    pub fn new(
67        population_size: usize,
68        mutation_rate: f32,
69        reproduction_type_proportion: f32,
70        scope: &'scope thread::Scope<'scope, '_>,
71    ) -> Self
72    where
73        G: Genome + 'scope + Send + Sync,
74    {
75        let in_flight = Gate::new(0);
76        let (work_submission, inbox) = mpmc::sync_channel(0);
77        let (outbox, work_reception) = mpsc::channel();
78        let gene_pool = Arc::new(RwLock::new(Vec::new()));
79        let worker_pool = (0..num_cpus())
80            .map(|_| {
81                let inbox = inbox.clone();
82                let outbox = outbox.clone();
83                scope.spawn(move || Self::worker_thread(inbox, outbox))
84            })
85            .collect();
86        let receiver_thread = {
87            let gene_pool = gene_pool.clone();
88            let in_flight = in_flight.clone();
89            scope.spawn(move || {
90                Self::work_receiver_thread(
91                    work_reception,
92                    gene_pool,
93                    population_size,
94                    in_flight,
95                )
96            })
97        };
98        Self {
99            gene_pool,
100            work_submission,
101            worker_pool,
102            receiver_thread,
103            mutation_rate,
104            population_size,
105            in_flight,
106            children_created: 0,
107            reproduction_type_proportion,
108        }
109    }
110
111    fn worker_thread(inbox: mpmc::Receiver<G>, outbox: mpsc::Sender<(G, f32)>)
112    where
113        G: Genome,
114    {
115        for gene in inbox {
116            let fitness = gene.fitness();
117            outbox.send((gene, fitness)).unwrap();
118        }
119    }
120
121    fn work_receiver_thread(
122        work_reception: mpsc::Receiver<(G, f32)>,
123        gene_pool: Arc<RwLock<Vec<(G, f32)>>>,
124        max_population_size: usize,
125        in_flight: Gate<usize>,
126    ) {
127        for (gene, score) in work_reception {
128            let mut gene_pool = gene_pool.write().unwrap();
129            let insert_index = gene_pool.binary_search_by(|x| score.total_cmp(&x.1));
130            let insert_index = match insert_index {
131                Ok(i) => i,
132                Err(i) => i,
133            };
134            gene_pool.insert(insert_index, (gene, score));
135            if gene_pool.len() > max_population_size {
136                gene_pool.drain(max_population_size..);
137            }
138            in_flight.update(|x| *x = x.saturating_sub(1));
139        }
140    }
141
142    /// Submit a new genome to the worker pool to be evaluated for its fitness and
143    /// ranked among the population. 
144    /// 
145    /// Used internally by the training process, should
146    /// typically not be called directly unless the user knows what they're doing.
147    pub fn submit_job(&mut self, gene: G) {
148        self.children_created += 1;
149        self.in_flight.update(|x| x.add_assign(1));
150        self.work_submission.send(gene).unwrap();
151    }
152
153    /// Seed the population with new genes up to the current population cap.
154    /// 
155    /// This is called automatically at the start of training, so should typically
156    /// not need to be called directly.
157    /// 
158    /// A [`RandomSource`] must be passed as a source of randomness
159    /// for generating the initial population.
160    pub fn seed<R>(&mut self, rng: &mut R)
161    where
162        R: RandomSource,
163        G: Genome,
164    {
165        let current_gene_pool_size = self.gene_pool.read().unwrap().len();
166        for _ in current_gene_pool_size..self.population_size {
167            self.submit_job(G::generate(rng));
168        }
169    }
170
171    /// Begin training, finishing once `num_children` children have been
172    /// reproduced, ranked for fitness, and introduced into the population.
173    /// 
174    /// A [`RandomSource`] must be passed as a source of randomness
175    /// for mutating genes to produce new offspring.
176    pub fn train<R>(&mut self, num_children: usize, rng: &mut R) -> G
177    where
178        R: RandomSource,
179        G: Clone + Genome + Send + Sync + 'scope,
180    {
181        self.train_custom(
182            |x| x.child_count <= num_children,
183            Some(default_reporting_strategy(self.population_size)),
184            rng,
185        )
186    }
187
188    /// Begin training with detailed custom parameters. 
189    /// 
190    /// Instead of a specific child
191    /// count cutoff point, a function `train_criteria` is passed in, which takes in an
192    /// instance of [`TrainingCriteriaMetrics`] and outputs a `bool`. This allows greater
193    /// control over exactly what criteria to finish training under.
194    ///
195    /// Additionally, the user may pass a `reporting_strategy`, which determines the conditions
196    /// and method under which periodic statistical reporting of the population is performed.
197    /// Pass `None` to disable reporting entirely, otherwise pass `Some` with an instance of a
198    /// [`TrainingReportStrategy`] to define the two methods necessary to manage reporting.
199    /// To mimic the default reporting strategy, pass the result of [`default_reporting_strategy()`] wrapped
200    /// in `Some()`.
201    /// 
202    /// A [`RandomSource`] must be passed as a source of randomness
203    /// for mutating genes to produce new offspring.
204    pub fn train_custom<R>(
205        &mut self,
206        mut train_criteria: impl FnMut(TrainingCriteriaMetrics) -> bool,
207        mut reporting_strategy: Option<
208            TrainingReportStrategy<
209                impl FnMut(TrainingCriteriaMetrics) -> bool,
210                impl FnMut(TrainingStats),
211            >,
212        >,
213        rng: &mut R,
214    ) -> G
215    where
216        R: RandomSource,
217        G: Clone + Genome + Send + Sync + 'scope,
218    {
219        self.seed(rng);
220        self.in_flight.wait_while(|x| *x > 0);
221        loop {
222            let new_child = {
223                let gene_pool = self.gene_pool.read().unwrap();
224                let min_fitness = gene_pool
225                    .iter()
226                    .map(|x| x.1)
227                    .min_by(|a, b| a.total_cmp(b))
228                    .unwrap();
229                let should_crossbreed = random_f32(rng) < self.reproduction_type_proportion;
230                let mut choose_parent =
231                    || random_choice_weighted_mapped(&gene_pool, rng, |x| x - min_fitness);
232                if should_crossbreed {
233                    // same mother & father could potentially be chosen; this is acceptable,
234                    // as it is unlikely with a sufficiently large populations, and the only
235                    // outcome is a gene that is identical to the singular parent, equivalent
236                    // to a gene chosen for mutation with mutation rate 0.
237                    let mother = (choose_parent)();
238                    let father = (choose_parent)();
239                    mother.crossbreed(father, rng)
240                } else {
241                    let mut new_child = (choose_parent)().clone();
242                    new_child.mutate(self.mutation_rate, rng);
243                    new_child
244                }
245            };
246            self.submit_job(new_child);
247
248            let metrics = self.metrics();
249            if let Some(reporting_strategy) = &mut reporting_strategy {
250                if (reporting_strategy.should_report)(metrics) {
251                    (reporting_strategy.report_callback)(self.stats())
252                }
253            }
254            if !(train_criteria)(metrics) {
255                break;
256            }
257        }
258        self.in_flight.wait_while(|x| *x > 0);
259        self.gene_pool.read().unwrap().first().unwrap().0.clone()
260    }
261
262    /// Generate training criteria metrics for the current state of this trainer.
263    /// 
264    /// This is a strict subset of the data available in an instance of [`TrainingStats`]
265    /// returned from calling [`ContinuousTrainer::stats`]. However, these
266    /// metrics were chosen specifically for their computation efficiency, and thus can be
267    /// re-evaluated frequently with minimal cost. These metrics are used both to determine
268    /// whether or not to continue training, and whether or not to display a report about
269    /// training progress.
270    pub fn metrics(&self) -> TrainingCriteriaMetrics {
271        let gene_pool = self.gene_pool.read().unwrap();
272        TrainingCriteriaMetrics {
273            max_fitness: gene_pool.first().unwrap().1,
274            min_fitness: gene_pool.last().unwrap().1,
275            median_fitness: gene_pool[gene_pool.len() / 2].1,
276            child_count: self.children_created,
277        }
278    }
279
280    /// Generate population stats for the current state of this trainer.
281    /// 
282    /// This function is called whenever the reporting strategy is asked
283    /// to produce a report about the current population, but it may also be called
284    /// manually here.
285    pub fn stats(&self) -> TrainingStats {
286        TrainingStats {
287            population_stats: self.gene_pool.read().unwrap().iter().map(|x| x.1).collect(),
288            child_count: self.children_created,
289        }
290    }
291}
292
/// A collection of relevant, cheap-to-compute metrics that
/// can be used to decide whether or not to continue training.
///
/// The fitness values only reflect genes that have already been evaluated and ranked into
/// the population; genes still being evaluated are not included.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrainingCriteriaMetrics {
    /// Maximum fitness of the population.
    pub max_fitness: f32,

    /// Minimum fitness of the population.
    pub min_fitness: f32,

    /// Median fitness of the population.
    ///
    /// For an even population size this is the lower of the two middle values.
    pub median_fitness: f32,

    /// Total number of children that have been reproduced and submitted for evaluation,
    /// including the initial seed population count.
    ///
    /// Children are counted when they are submitted. The count may therefore include children
    /// still being evaluated, as well as children that were culled from the population.
    pub child_count: usize,
}
311
312/// A collection of statistics about the population as a whole.
313/// 
314/// Relatively more expensive to compute than training metrics, so
315/// should be computed infrequently.
316#[derive(Clone, Copy, Debug)]
317pub struct TrainingStats {
318    /// A collection of standard population stats: see [`PopulationStats`]
319    /// for more information
320    pub population_stats: PopulationStats,
321
322    /// Total number of children that have been
323    /// reproduced and introduced into the population,
324    /// including the initial seed population count.
325    /// 
326    /// Same as [`TrainingCriteriaMetrics::child_count`].
327    pub child_count: usize,
328}
329
330impl fmt::Display for TrainingStats {
331    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
332        write!(f, "child #{} {}", self.child_count, self.population_stats)
333    }
334}
335
336/// Returns a default reporting strategy which logs population
337/// statistics to the console every `n` children reproduced.
338/// 
339/// Used by [`ContinuousTrainer::train`].
340pub fn default_reporting_strategy(
341    n: usize,
342) -> TrainingReportStrategy<impl FnMut(TrainingCriteriaMetrics) -> bool, impl FnMut(TrainingStats)>
343{
344    TrainingReportStrategy {
345        should_report: move |m: TrainingCriteriaMetrics| m.child_count % n == 0,
346        report_callback: |s| println!("{s}"),
347    }
348}