1use std::{iter::repeat_with, sync::Arc};
2
3use clap::{Args, Parser};
4use derivative::Derivative;
5use itertools::Itertools;
6use rand::{seq::IteratorRandom, Rng};
7use rayon::prelude::*;
8
9use crate::{
10 core::{
11 engines::{breed_engine::Breed, reset_engine::Reset},
12 environment::State,
13 },
14 utils::random::{generator, update_seed},
15};
16
17use super::{
18 fitness_engine::Fitness, freeze_engine::Freeze, generate_engine::Generate,
19 mutate_engine::Mutate, status_engine::Status,
20};
21use derive_builder::Builder;
22use serde::{de::DeserializeOwned, Deserialize, Serialize};
23use tracing::{debug, instrument, trace};
24
// Hyper-parameters of one evolutionary run. The same struct can be filled from the command
// line (clap `Parser`), through the derive_builder builder, or via serde; the `default`
// values below are mirrored between the builder and the CLI.
// NOTE(review): plain `//` comments are used on purpose — clap would turn `///` doc comments
// into `--help` text.
#[derive(Debug, Deserialize, Serialize, Builder, Copy, Derivative, Parser)]
#[command(author, version, about, long_about=None)]
#[command(propagate_version = true)]
#[derivative(Clone)]
pub struct HyperParameters<C>
where
    C: Core,
{
    // Score substituted for any trial whose fitness evaluation is not finite
    // (see `Core::eval_fitness`).
    #[builder(default = "0.")]
    #[arg(long, default_value = "0.")]
    pub default_fitness: f64,
    // Number of individuals generated for the initial population.
    #[builder(default = "100")]
    #[arg(long, default_value = "100")]
    pub population_size: usize,
    // Fraction of the ranked population removed each generation: `Core::survive` keeps
    // `floor((1 - gap) * n)` individuals.
    #[builder(default = "0.5")]
    #[arg(long, default_value = "0.5")]
    pub gap: f64,
    // Fraction of the freed population slots refilled with mutated copies of survivors.
    #[builder(default = "0.5")]
    #[arg(long, default_value = "0.5")]
    pub mutation_percent: f64,
    // Fraction of the freed population slots refilled with crossover children; whatever is
    // left after mutation and crossover is refilled with reset clones of survivors.
    #[builder(default = "0.5")]
    #[arg(long, default_value = "0.5")]
    pub crossover_percent: f64,
    // Last generation index produced by `CoreIter` (it yields generations 0..=n_generations).
    #[builder(default = "100")]
    #[arg(long, default_value = "100")]
    pub n_generations: usize,
    // Number of trial states generated once when the engine is created and reused to evaluate
    // every individual in every generation.
    #[builder(default = "100")]
    #[arg(long, default_value = "100")]
    pub n_trials: usize,
    // Forwarded to `update_seed` by `build_engine`.
    // NOTE(review): presumably `None` means an unseeded/entropy-based RNG — confirm in
    // `utils::random`.
    #[builder(default = "None")]
    #[arg(long)]
    pub seed: Option<u64>,
    // When set, `build_engine` asks rayon to size its global thread pool to this many threads.
    #[builder(default = "None")]
    #[arg(long)]
    pub n_threads: Option<usize>,
    // Problem-specific parameters, flattened into the CLI and passed to individual generation
    // and mutation.
    #[command(flatten)]
    pub program_parameters: C::ProgramParameters,
}
63
64pub struct CoreIter<C>
65where
66 C: Core,
67{
68 generation: usize,
69 next_population: Vec<C::Individual>,
70 params: HyperParameters<C>,
71 trials: Vec<C::State>,
72}
73
74impl<C> CoreIter<C>
75where
76 C: Core,
77{
78 #[instrument(skip_all, fields(
79 population_size = hp.population_size,
80 n_generations = hp.n_generations,
81 n_trials = hp.n_trials,
82 gap = hp.gap,
83 mutation_percent = hp.mutation_percent,
84 crossover_percent = hp.crossover_percent,
85 seed = ?hp.seed,
86 n_threads = ?hp.n_threads
87 ))]
88 pub fn new(hp: HyperParameters<C>) -> Self {
89 debug!("Initializing evolution engine");
90
91 let current_population = C::init_population(hp.program_parameters, hp.population_size);
92 trace!(
93 individuals = current_population.len(),
94 "Initial population generated"
95 );
96
97 let trials: Vec<C::State> = repeat_with(|| C::Generate::generate(()))
98 .take(hp.n_trials)
99 .collect_vec();
100 trace!(trials = trials.len(), "Trial environments generated");
101
102 Self {
103 generation: 0,
104 next_population: current_population,
105 params: hp,
106 trials,
107 }
108 }
109}
110
111impl<C> Iterator for CoreIter<C>
112where
113 C: Core,
114{
115 type Item = Vec<C::Individual>;
116
117 fn next(&mut self) -> Option<Self::Item> {
118 if self.generation > self.params.n_generations {
119 return None;
120 }
121
122 let mut population = self.next_population.clone();
123
124 C::eval_fitness(&mut population, &self.trials, self.params.default_fitness);
125 C::rank(&mut population);
126
127 assert!(population.iter().all(C::Status::evaluated));
128
129 let best_fitness = population.first().map(C::Status::get_fitness);
130 let median_fitness = population
131 .get(population.len() / 2)
132 .map(C::Status::get_fitness);
133 let worst_fitness = population.last().map(C::Status::get_fitness);
134
135 debug!(
136 generation = self.generation,
137 best_fitness = ?best_fitness,
138 median_fitness = ?median_fitness,
139 worst_fitness = ?worst_fitness,
140 population_size = population.len(),
141 "Generation complete"
142 );
143
144 debug!(
145 best = serde_json::to_string(&population.first()).ok(),
146 median = serde_json::to_string(&population.get(population.len() / 2)).ok(),
147 worst = serde_json::to_string(&population.last()).ok(),
148 "Full individual details"
149 );
150
151 let mut new_population = population.clone();
152
153 trace!(
154 before_selection = new_population.len(),
155 "Starting selection"
156 );
157 C::survive(&mut new_population, self.params.gap);
158 trace!(after_selection = new_population.len(), "Selection complete");
159
160 trace!(
161 crossover_percent = self.params.crossover_percent,
162 mutation_percent = self.params.mutation_percent,
163 "Starting variation"
164 );
165 C::variation(
166 &mut new_population,
167 self.params.crossover_percent,
168 self.params.mutation_percent,
169 self.params.program_parameters,
170 );
171 trace!(after_variation = new_population.len(), "Variation complete");
172
173 self.next_population = new_population;
174 self.generation += 1;
175
176 Some(population)
177 }
178}
179
180impl<T> HyperParameters<T>
181where
182 T: Core,
183{
184 pub fn build_engine(&self) -> CoreIter<T> {
185 update_seed(self.seed);
186
187 if let Some(n_threads) = self.n_threads {
188 rayon::ThreadPoolBuilder::new()
189 .num_threads(n_threads)
190 .build_global()
191 .ok();
192 }
193
194 CoreIter::new(self.clone())
195 }
196}
197
198pub trait Core {
199 type Individual: Ord + Clone + Send + Sync + Serialize + DeserializeOwned;
200 type ProgramParameters: Copy + Send + Sync + Clone + Serialize + DeserializeOwned + Args;
201 type State: State + Clone + Send + Sync;
202 type FitnessMarker;
203 type Generate: Generate<Self::ProgramParameters, Self::Individual> + Generate<(), Self::State>;
204 type Fitness: Fitness<Self::Individual, Self::State, Self::FitnessMarker>;
205 type Reset: Reset<Self::Individual> + Reset<Self::State>;
206 type Breed: Breed<Self::Individual>;
207 type Mutate: Mutate<Self::ProgramParameters, Self::Individual>;
208 type Status: Status<Self::Individual>;
209 type Freeze: Freeze<Self::Individual>;
210
211 fn init_population(
212 program_parameters: Self::ProgramParameters,
213 population_size: usize,
214 ) -> Vec<Self::Individual> {
215 repeat_with(|| Self::Generate::generate(program_parameters))
216 .take(population_size)
217 .collect()
218 }
219
220 fn eval_fitness(
221 population: &mut Vec<Self::Individual>,
222 trials: &[Self::State],
223 default_fitness: f64,
224 ) {
225 let n_trials = trials.len();
226 population.par_iter_mut().for_each(|individual| {
227 let total: f64 = trials
228 .iter()
229 .cloned()
230 .map(|mut trial| {
231 Self::Reset::reset(individual);
232 Self::Reset::reset(&mut trial);
233 let score = Self::Fitness::eval_fitness(individual, &mut trial);
234 if score.is_finite() {
235 score
236 } else {
237 default_fitness
238 }
239 })
240 .sum();
241 Self::Status::set_fitness(individual, total / n_trials as f64);
242 });
243 }
244
245 fn rank(population: &mut Vec<Self::Individual>) {
246 population.sort_by(|a, b| b.cmp(a));
247 debug_assert!(population.windows(2).all(|w| {
248 let a = &w[0];
249 let b = &w[1];
250
251 debug_assert!(a >= b);
252 a >= b
253 }));
254 }
255
256 fn survive(population: &mut Vec<Self::Individual>, gap: f64) {
257 let n_individuals = population.len();
258
259 let mut n_of_individuals_to_drop =
260 (n_individuals as isize) - ((1.0 - gap) * (n_individuals as f64)).floor() as isize;
261
262 population.retain(Self::Status::valid);
263 let n_individuals_dropped = n_individuals - population.len();
264 n_of_individuals_to_drop -= n_individuals_dropped as isize;
265
266 while n_of_individuals_to_drop > 0 {
267 n_of_individuals_to_drop -= 1;
268 population.pop();
269 }
270 }
271
272 fn variation(
273 population: &mut Vec<Self::Individual>,
274 crossover_percent: f64,
275 mutation_percent: f64,
276 program_parameters: Self::ProgramParameters,
277 ) {
278 debug_assert!(!population.is_empty());
279
280 let pop_cap = population.capacity();
281 let pop_len = population.len();
282
283 let remaining_pool_spots = pop_cap - pop_len;
284
285 if remaining_pool_spots == 0 {
286 return;
287 }
288
289 let n_mutations = (remaining_pool_spots as f64 * mutation_percent).floor() as usize;
290 let n_crossovers = (remaining_pool_spots as f64 * crossover_percent).floor() as usize;
291 let n_clones = remaining_pool_spots - n_mutations - n_crossovers;
292
293 let mut clone_offspring: Vec<Self::Individual> = Vec::with_capacity(n_clones);
294 let mut mutation_offspring: Vec<Self::Individual> = Vec::with_capacity(n_mutations);
295 let mut crossover_offspring: Vec<Self::Individual> = Vec::with_capacity(n_crossovers);
296
297 debug_assert!(n_mutations + n_crossovers <= remaining_pool_spots);
298
299 let rc_population = Arc::new(population.clone());
300
301 rayon::scope(|s| {
302 s.spawn(|_| {
303 crossover_offspring.extend((0..n_crossovers).filter_map(|_| {
304 let population_to_read = rc_population.clone();
305 let parent_a = population_to_read.iter().choose(&mut generator());
306 let parent_b = population_to_read.iter().choose(&mut generator());
307
308 if let (Some(parent_a), Some(parent_b)) = (parent_a, parent_b) {
309 let children = Self::Breed::two_point_crossover(parent_a, parent_b);
310 match generator().gen_range(0..2) {
311 0 => Some(children.0),
312 1 => Some(children.1),
313 _ => unreachable!(),
314 }
315 } else {
316 None
317 }
318 }));
319 });
320
321 s.spawn(|_| {
322 mutation_offspring.extend((0..n_mutations).filter_map(|_| {
323 let population_to_read = rc_population.clone();
324 let parent = population_to_read.iter().choose(&mut generator());
325
326 if let Some(internal_parent) = parent {
327 let mut clone = internal_parent.clone();
328 Self::Mutate::mutate(&mut clone, program_parameters);
329 Some(clone)
330 } else {
331 None
332 }
333 }))
334 });
335
336 s.spawn(|_| {
337 clone_offspring.extend((0..n_clones).filter_map(|_| {
338 let population_to_read = rc_population.clone();
339 let parent = population_to_read.iter().choose(&mut generator());
340
341 if let Some(internal_parent) = parent {
342 let mut clone = internal_parent.clone();
343 Self::Reset::reset(&mut clone);
344 Some(clone)
345 } else {
346 None
347 }
348 }))
349 });
350 });
351
352 population.append(&mut crossover_offspring);
354 population.append(&mut mutation_offspring);
355 population.append(&mut clone_offspring);
356 }
357}