1#![doc = include_str!("../README.md")]
2#![feature(random)]
3#![feature(mpmc_channel)]
4
5use std::{
6 fmt,
7 random::{Random, RandomSource},
8 sync::{Arc, Condvar, Mutex},
9 thread::available_parallelism,
10};
11
12#[allow(unused_imports)]
14use crate::{
15 continuous::ContinuousTrainer,
16 inertial::{InertialGenome, InertialTrainer},
17 stochastic::StochasticTrainer,
18};
19
20pub mod continuous;
21pub mod inertial;
22pub mod stochastic;
23
24pub trait Genome {
28 fn generate<R>(rng: &mut R) -> Self
30 where
31 R: RandomSource;
32
33 fn mutate<R>(&mut self, mutation_rate: f32, rng: &mut R)
38 where
39 R: RandomSource;
40
41 fn crossbreed<R>(&self, other: &Self, rng: &mut R) -> Self
43 where
44 R: RandomSource;
45
46 fn fitness(&self) -> f32;
52}
53
/// Summary statistics over the fitness scores of a population.
#[derive(Clone, Copy, Debug)]
pub struct PopulationStats {
    /// Highest fitness score.
    pub max_fitness: f32,

    /// Lowest fitness score.
    pub min_fitness: f32,

    /// Arithmetic mean of the fitness scores.
    pub mean_fitness: f32,

    /// Median fitness score. For an even number of scores this is the mean of
    /// the two middle scores.
    pub median_fitness: f32,
}

impl FromIterator<f32> for PopulationStats {
    /// Computes the statistics of the collected fitness scores.
    ///
    /// Scores are ordered with [`f32::total_cmp`], so NaN scores participate
    /// in the ordering (positive NaNs sort last, negative NaNs first) and can
    /// surface as the min/max.
    ///
    /// An empty iterator yields `NaN` in every field rather than panicking,
    /// since none of the statistics is defined for an empty population.
    fn from_iter<T: IntoIterator<Item = f32>>(iter: T) -> Self {
        let mut scores: Vec<f32> = iter.into_iter().collect();
        if scores.is_empty() {
            return PopulationStats {
                max_fitness: f32::NAN,
                min_fitness: f32::NAN,
                mean_fitness: f32::NAN,
                median_fitness: f32::NAN,
            };
        }

        // Equal elements under `total_cmp` are bit-identical, so an unstable
        // sort gives the same result as a stable one.
        scores.sort_unstable_by(f32::total_cmp);

        let len = scores.len();
        let min_fitness = scores[0];
        let max_fitness = scores[len - 1];
        let mean_fitness = scores.iter().sum::<f32>() / len as f32;

        let mid = len / 2;
        let median_fitness = if len % 2 == 0 {
            // Even count: average the two middle scores.
            (scores[mid - 1] + scores[mid]) / 2.0
        } else {
            scores[mid]
        };

        PopulationStats {
            max_fitness,
            min_fitness,
            mean_fitness,
            median_fitness,
        }
    }
}

impl fmt::Display for PopulationStats {
    /// Formats the statistics on one line, each value with four decimals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "min={:.4} max={:.4} mean={:.4} median={:.4}",
            self.min_fitness, self.max_fitness, self.mean_fitness, self.median_fitness
        )
    }
}
97
/// Pairs a "should we report?" value with a "deliver the report" value.
///
/// NOTE(review): `F` and `C` are unconstrained here; their expected call
/// signatures are presumably fixed by the trainer modules that consume this
/// struct — confirm there.
pub struct TrainingReportStrategy<F, C> {
    /// Decides whether a report should be produced.
    ///
    /// NOTE(review): presumably a predicate evaluated once per generation or
    /// training step — confirm against the trainers.
    pub should_report: F,

    /// Delivers a report when `should_report` says so.
    ///
    /// NOTE(review): presumably receives the current `PopulationStats` —
    /// confirm against the trainers.
    pub report_callback: C,
}
118
119fn random_f32<R>(rng: &mut R) -> f32
120where
121 R: RandomSource,
122{
123 u32::random(rng) as f32 / u32::MAX as f32
124}
125
/// Returns the platform's suggested degree of parallelism, or `1` when it
/// cannot be queried.
fn num_cpus() -> usize {
    match available_parallelism() {
        Ok(threads) => threads.get(),
        Err(_) => 1,
    }
}
129
130fn random_choice_weighted<'a, T, R>(weights: &'a [(T, f32)], rng: &mut R) -> &'a T
131where
132 R: RandomSource,
133{
134 random_choice_weighted_mapped(weights, rng, |x| x)
135}
136
137fn random_choice_weighted_mapped<'a, T, R>(
138 weights: &'a [(T, f32)],
139 rng: &mut R,
140 weight_map: impl Fn(f32) -> f32,
141) -> &'a T
142where
143 R: RandomSource,
144{
145 let total: f32 = weights.iter().map(|x| (weight_map)(x.1)).sum();
146 let mut n = random_f32(rng) * total;
147 for (value, weight) in weights {
148 let weight = (weight_map)(*weight);
149 if n <= weight {
150 return value;
151 }
152 n -= weight;
153 }
154 &weights.last().unwrap().0
155}
156
157fn random_choice_weighted_mapped_by_key<'a, T, R>(
158 elems: &'a [T],
159 rng: &mut R,
160 weight_map: impl Fn(f32) -> f32,
161 key: impl Fn(&T) -> f32,
162) -> &'a T
163where
164 R: RandomSource,
165{
166 let total: f32 = elems.iter().map(|x| (weight_map)(key(x))).sum();
167 let mut n = random_f32(rng) * total;
168 for value in elems {
169 let weight = (weight_map)(key(value));
170 if n <= weight {
171 return value;
172 }
173 n -= weight;
174 }
175 &elems.last().unwrap()
176}
177
/// Shared state paired with a condition variable, letting threads block until
/// the state satisfies a predicate.
///
/// Clones share the same underlying state; cloning only bumps an `Arc`.
struct Gate<S>(Arc<(Mutex<S>, Condvar)>);

// Implemented by hand because `#[derive(Clone)]` would require `S: Clone`,
// even though cloning a gate never clones the state itself.
impl<S> Clone for Gate<S> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}

impl<S> Gate<S> {
    /// Creates a gate holding `initial_state`.
    fn new(initial_state: S) -> Self {
        Self(Arc::new((Mutex::new(initial_state), Condvar::new())))
    }

    /// Mutates the shared state while holding the lock. It then wakes every
    /// thread waiting on the gate so each one re-checks its condition.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned, i.e. another thread panicked while
    /// holding it.
    fn update(&self, updater: impl FnOnce(&mut S)) {
        let (lock, signal) = &*self.0;
        let mut state = lock.lock().unwrap();
        updater(&mut state);
        signal.notify_all();
    }

    /// Blocks the calling thread for as long as `condition` returns `true`
    /// for the shared state.
    ///
    /// Returns immediately if the condition is already `false`. After each
    /// wakeup, spurious or not, the condition is re-checked under the lock.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn wait_while(&self, mut condition: impl FnMut(&S) -> bool) {
        let (lock, signal) = &*self.0;
        let state = lock.lock().unwrap();
        // The returned guard is dropped right away: callers only need to know
        // the condition became false, not to keep holding the state.
        let _state = signal
            .wait_while(state, |state| condition(&*state))
            .unwrap();
    }
}
199
/// Returns the `(min, max)` of the scores yielded by `iter`, or `None` if it
/// yields nothing.
///
/// NaN scores are ignored via [`f32::min`]/[`f32::max`], so the result does
/// not depend on where a NaN appears in the sequence. It is `(NaN, NaN)` only
/// when every score is NaN.
fn bounds(iter: impl Iterator<Item = f32>) -> Option<(f32, f32)> {
    iter.fold(None, |acc, score| match acc {
        None => Some((score, score)),
        Some((min, max)) => Some((min.min(score), max.max(score))),
    })
}