gene_evo/
lib.rs

1#![doc = include_str!("../README.md")]
2#![feature(random)]
3#![feature(mpmc_channel)]
4
5use std::{
6    fmt,
7    random::{Random, RandomSource},
8    sync::{Arc, Condvar, Mutex},
9    thread::available_parallelism,
10};
11
12// imports for documentation
13#[allow(unused_imports)]
14use crate::{
15    continuous::ContinuousTrainer,
16    inertial::{InertialGenome, InertialTrainer},
17    stochastic::StochasticTrainer,
18};
19
20pub mod continuous;
21pub mod inertial;
22pub mod stochastic;
23
24/// Represents a single Genome in a genetic algorithm.
25///
26/// Implement this trait for your type to evolve it using natural selection.
27pub trait Genome {
28    /// Generate a new instance of this genome from the given random source.
29    fn generate<R>(rng: &mut R) -> Self
30    where
31        R: RandomSource;
32
33    /// Mutate this genome by some amount using the given random source.
34    /// 
35    /// The `mutation_rate` parameter should influence how much mutation to
36    /// perform on the genome, with 0 meaning no mutation, ie. the genome is unchanged.
37    fn mutate<R>(&mut self, mutation_rate: f32, rng: &mut R)
38    where
39        R: RandomSource;
40
41    /// Yield a crossbred child gene of two separate parent genes.
42    fn crossbreed<R>(&self, other: &Self, rng: &mut R) -> Self
43    where
44        R: RandomSource;
45
46    /// Evaluate this genome for its 'fitness' score. 
47    /// 
48    /// Higher fitness scores will lead to a higher survival rate.
49    ///
50    /// Negative fitness scores are valid.
51    fn fitness(&self) -> f32;
52}
53
/// A collection of standard population statistics that can be used
/// for progress reporting.
///
/// Build one by collecting an iterator of `f32` fitness scores, eg.
/// `scores.into_iter().collect::<PopulationStats>()`.
#[derive(Clone, Copy, Debug)]
pub struct PopulationStats {
    /// Maximum fitness of the population.
    pub max_fitness: f32,

    /// Minimum fitness of the population.
    pub min_fitness: f32,

    /// Mean fitness of the population.
    pub mean_fitness: f32,

    /// Median fitness of the population.
    ///
    /// For an even number of scores this is the mean of the two
    /// central scores.
    pub median_fitness: f32,
}

impl FromIterator<f32> for PopulationStats {
    /// Summarise a sequence of fitness scores.
    ///
    /// Scores are ordered with [`f32::total_cmp`] before the minimum,
    /// maximum and median are read off.
    ///
    /// An empty sequence has no meaningful statistics, so every field is
    /// set to [`f32::NAN`] in that case instead of panicking.
    fn from_iter<T: IntoIterator<Item = f32>>(iter: T) -> Self {
        let mut scores: Vec<f32> = iter.into_iter().collect();
        // Equal floats are indistinguishable, so stability is not needed.
        scores.sort_unstable_by(f32::total_cmp);

        let (Some(&min_fitness), Some(&max_fitness)) = (scores.first(), scores.last()) else {
            return PopulationStats {
                min_fitness: f32::NAN,
                max_fitness: f32::NAN,
                mean_fitness: f32::NAN,
                median_fitness: f32::NAN,
            };
        };

        let count = scores.len();
        let mean_fitness = scores.iter().sum::<f32>() / count as f32;

        let mid = count / 2;
        let median_fitness = if count % 2 == 0 {
            // Even count: there is no single middle score, so average the
            // two scores on either side of the midpoint.
            (scores[mid - 1] + scores[mid]) / 2.0
        } else {
            scores[mid]
        };

        PopulationStats {
            min_fitness,
            max_fitness,
            mean_fitness,
            median_fitness,
        }
    }
}

impl fmt::Display for PopulationStats {
    /// Write the statistics on a single line as
    /// `min=… max=… mean=… median=…`, each with four decimal places.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "min={:.4} max={:.4} mean={:.4} median={:.4}",
            self.min_fitness, self.max_fitness, self.mean_fitness, self.median_fitness
        )
    }
}
97
/// A pair of callbacks that together describe how periodic
/// progress updates are reported.
///
/// [`TrainingReportStrategy::should_report`] decides whether a
/// report is due at this moment, while
/// [`TrainingReportStrategy::report_callback`] carries out the
/// reporting itself.
pub struct TrainingReportStrategy<F, C> {
    /// Callback that decides whether a report should be
    /// generated at this moment.
    pub should_report: F,

    /// Callback invoked when a report should be displayed
    /// or logged in some form.
    ///
    /// It is only called if
    /// [`TrainingReportStrategy::should_report`] returns `true`.
    pub report_callback: C,
}
118
119fn random_f32<R>(rng: &mut R) -> f32
120where
121    R: RandomSource,
122{
123    u32::random(rng) as f32 / u32::MAX as f32
124}
125
/// Amount of parallelism available to the program, as reported by
/// [`available_parallelism`].
///
/// Falls back to 1 when that query returns an error.
fn num_cpus() -> usize {
    match available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
129
130fn random_choice_weighted<'a, T, R>(weights: &'a [(T, f32)], rng: &mut R) -> &'a T
131where
132    R: RandomSource,
133{
134    random_choice_weighted_mapped(weights, rng, |x| x)
135}
136
137fn random_choice_weighted_mapped<'a, T, R>(
138    weights: &'a [(T, f32)],
139    rng: &mut R,
140    weight_map: impl Fn(f32) -> f32,
141) -> &'a T
142where
143    R: RandomSource,
144{
145    let total: f32 = weights.iter().map(|x| (weight_map)(x.1)).sum();
146    let mut n = random_f32(rng) * total;
147    for (value, weight) in weights {
148        let weight = (weight_map)(*weight);
149        if n <= weight {
150            return value;
151        }
152        n -= weight;
153    }
154    &weights.last().unwrap().0
155}
156
157fn random_choice_weighted_mapped_by_key<'a, T, R>(
158    elems: &'a [T],
159    rng: &mut R,
160    weight_map: impl Fn(f32) -> f32,
161    key: impl Fn(&T) -> f32,
162) -> &'a T
163where
164    R: RandomSource,
165{
166    let total: f32 = elems.iter().map(|x| (weight_map)(key(x))).sum();
167    let mut n = random_f32(rng) * total;
168    for value in elems {
169        let weight = (weight_map)(key(value));
170        if n <= weight {
171            return value;
172        }
173        n -= weight;
174    }
175    &elems.last().unwrap()
176}
177
/// A piece of state of type `S` that can be updated and waited on
/// through any of its handles.
///
/// The state sits behind a [`Mutex`] paired with a [`Condvar`], both
/// held in an [`Arc`], so every clone of a `Gate` refers to the same
/// state.
struct Gate<S>(Arc<(Mutex<S>, Condvar)>);

// Written by hand rather than derived: `#[derive(Clone)]` would demand
// `S: Clone`, although cloning a gate only clones the `Arc` handle.
impl<S> Clone for Gate<S> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}

impl<S> Gate<S> {
    /// Create a gate whose state starts out as `initial_state`.
    fn new(initial_state: S) -> Self {
        Self(Arc::new((Mutex::new(initial_state), Condvar::new())))
    }

    /// Run `updater` on the state while holding the lock, then wake
    /// every thread currently blocked in [`Gate::wait_while`].
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn update(&self, updater: impl FnOnce(&mut S)) {
        let (lock, changed) = &*self.0;
        let mut state = lock.lock().unwrap();
        updater(&mut state);
        changed.notify_all();
    }

    /// Block the calling thread for as long as `condition` returns
    /// `true` for the state.
    ///
    /// `condition` is checked once immediately and again after each
    /// wake-up; if it is already `false`, this returns without blocking.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn wait_while(&self, mut condition: impl FnMut(&S) -> bool) {
        let (lock, changed) = &*self.0;
        let state = lock.lock().unwrap();
        // The guard handed back once the condition clears is not needed;
        // dropping it releases the lock straight away.
        drop(changed.wait_while(state, |state| condition(state)).unwrap());
    }
}
199
/// Find the smallest and largest score produced by `iter`.
///
/// Returns `Some((min, max))`, or `None` when `iter` yields nothing.
/// The first score seeds both bounds; each later score replaces the
/// minimum if it is strictly smaller, or else the maximum if it is
/// strictly larger.
fn bounds(iter: impl Iterator<Item = f32>) -> Option<(f32, f32)> {
    iter.fold(None, |so_far, score| match so_far {
        None => Some((score, score)),
        Some((lowest, highest)) => Some(if score < lowest {
            (score, highest)
        } else if score > highest {
            (lowest, score)
        } else {
            (lowest, highest)
        }),
    })
}