Skip to main content

lgp/core/engines/
core_engine.rs

1use std::{iter::repeat_with, sync::Arc};
2
3use clap::{Args, Parser};
4use derivative::Derivative;
5use itertools::Itertools;
6use rand::{seq::IteratorRandom, Rng};
7use rayon::prelude::*;
8
9use crate::{
10    core::{
11        engines::{breed_engine::Breed, reset_engine::Reset},
12        environment::State,
13    },
14    utils::random::{generator, update_seed},
15};
16
17use super::{
18    fitness_engine::Fitness, freeze_engine::Freeze, generate_engine::Generate,
19    mutate_engine::Mutate, status_engine::Status,
20};
21use derive_builder::Builder;
22use serde::{de::DeserializeOwned, Deserialize, Serialize};
23use tracing::{debug, instrument, trace};
24
25#[derive(Debug, Deserialize, Serialize, Builder, Copy, Derivative, Parser)]
26#[command(author, version, about, long_about=None)]
27#[command(propagate_version = true)]
28#[derivative(Clone)]
29pub struct HyperParameters<C>
30where
31    C: Core,
32{
33    #[builder(default = "0.")]
34    #[arg(long, default_value = "0.")]
35    pub default_fitness: f64,
36    #[builder(default = "100")]
37    #[arg(long, default_value = "100")]
38    pub population_size: usize,
39    #[builder(default = "0.5")]
40    #[arg(long, default_value = "0.5")]
41    pub gap: f64,
42    #[builder(default = "0.5")]
43    #[arg(long, default_value = "0.5")]
44    pub mutation_percent: f64,
45    #[builder(default = "0.5")]
46    #[arg(long, default_value = "0.5")]
47    pub crossover_percent: f64,
48    #[builder(default = "100")]
49    #[arg(long, default_value = "100")]
50    pub n_generations: usize,
51    #[builder(default = "100")]
52    #[arg(long, default_value = "100")]
53    pub n_trials: usize,
54    #[builder(default = "None")]
55    #[arg(long)]
56    pub seed: Option<u64>,
57    #[builder(default = "None")]
58    #[arg(long)]
59    pub n_threads: Option<usize>,
60    #[command(flatten)]
61    pub program_parameters: C::ProgramParameters,
62}
63
64pub struct CoreIter<C>
65where
66    C: Core,
67{
68    generation: usize,
69    next_population: Vec<C::Individual>,
70    params: HyperParameters<C>,
71    trials: Vec<C::State>,
72}
73
74impl<C> CoreIter<C>
75where
76    C: Core,
77{
78    #[instrument(skip_all, fields(
79        population_size = hp.population_size,
80        n_generations = hp.n_generations,
81        n_trials = hp.n_trials,
82        gap = hp.gap,
83        mutation_percent = hp.mutation_percent,
84        crossover_percent = hp.crossover_percent,
85        seed = ?hp.seed,
86        n_threads = ?hp.n_threads
87    ))]
88    pub fn new(hp: HyperParameters<C>) -> Self {
89        debug!("Initializing evolution engine");
90
91        let current_population = C::init_population(hp.program_parameters, hp.population_size);
92        trace!(
93            individuals = current_population.len(),
94            "Initial population generated"
95        );
96
97        let trials: Vec<C::State> = repeat_with(|| C::Generate::generate(()))
98            .take(hp.n_trials)
99            .collect_vec();
100        trace!(trials = trials.len(), "Trial environments generated");
101
102        Self {
103            generation: 0,
104            next_population: current_population,
105            params: hp,
106            trials,
107        }
108    }
109}
110
111impl<C> Iterator for CoreIter<C>
112where
113    C: Core,
114{
115    type Item = Vec<C::Individual>;
116
117    fn next(&mut self) -> Option<Self::Item> {
118        if self.generation > self.params.n_generations {
119            return None;
120        }
121
122        let mut population = self.next_population.clone();
123
124        C::eval_fitness(&mut population, &self.trials, self.params.default_fitness);
125        C::rank(&mut population);
126
127        assert!(population.iter().all(C::Status::evaluated));
128
129        let best_fitness = population.first().map(C::Status::get_fitness);
130        let median_fitness = population
131            .get(population.len() / 2)
132            .map(C::Status::get_fitness);
133        let worst_fitness = population.last().map(C::Status::get_fitness);
134
135        debug!(
136            generation = self.generation,
137            best_fitness = ?best_fitness,
138            median_fitness = ?median_fitness,
139            worst_fitness = ?worst_fitness,
140            population_size = population.len(),
141            "Generation complete"
142        );
143
144        debug!(
145            best = serde_json::to_string(&population.first()).ok(),
146            median = serde_json::to_string(&population.get(population.len() / 2)).ok(),
147            worst = serde_json::to_string(&population.last()).ok(),
148            "Full individual details"
149        );
150
151        let mut new_population = population.clone();
152
153        trace!(
154            before_selection = new_population.len(),
155            "Starting selection"
156        );
157        C::survive(&mut new_population, self.params.gap);
158        trace!(after_selection = new_population.len(), "Selection complete");
159
160        trace!(
161            crossover_percent = self.params.crossover_percent,
162            mutation_percent = self.params.mutation_percent,
163            "Starting variation"
164        );
165        C::variation(
166            &mut new_population,
167            self.params.crossover_percent,
168            self.params.mutation_percent,
169            self.params.program_parameters,
170        );
171        trace!(after_variation = new_population.len(), "Variation complete");
172
173        self.next_population = new_population;
174        self.generation += 1;
175
176        Some(population)
177    }
178}
179
180impl<T> HyperParameters<T>
181where
182    T: Core,
183{
184    pub fn build_engine(&self) -> CoreIter<T> {
185        update_seed(self.seed);
186
187        if let Some(n_threads) = self.n_threads {
188            rayon::ThreadPoolBuilder::new()
189                .num_threads(n_threads)
190                .build_global()
191                .ok();
192        }
193
194        CoreIter::new(self.clone())
195    }
196}
197
198pub trait Core {
199    type Individual: Ord + Clone + Send + Sync + Serialize + DeserializeOwned;
200    type ProgramParameters: Copy + Send + Sync + Clone + Serialize + DeserializeOwned + Args;
201    type State: State + Clone + Send + Sync;
202    type FitnessMarker;
203    type Generate: Generate<Self::ProgramParameters, Self::Individual> + Generate<(), Self::State>;
204    type Fitness: Fitness<Self::Individual, Self::State, Self::FitnessMarker>;
205    type Reset: Reset<Self::Individual> + Reset<Self::State>;
206    type Breed: Breed<Self::Individual>;
207    type Mutate: Mutate<Self::ProgramParameters, Self::Individual>;
208    type Status: Status<Self::Individual>;
209    type Freeze: Freeze<Self::Individual>;
210
211    fn init_population(
212        program_parameters: Self::ProgramParameters,
213        population_size: usize,
214    ) -> Vec<Self::Individual> {
215        repeat_with(|| Self::Generate::generate(program_parameters))
216            .take(population_size)
217            .collect()
218    }
219
220    fn eval_fitness(
221        population: &mut Vec<Self::Individual>,
222        trials: &[Self::State],
223        default_fitness: f64,
224    ) {
225        let n_trials = trials.len();
226        population.par_iter_mut().for_each(|individual| {
227            let total: f64 = trials
228                .iter()
229                .cloned()
230                .map(|mut trial| {
231                    Self::Reset::reset(individual);
232                    Self::Reset::reset(&mut trial);
233                    let score = Self::Fitness::eval_fitness(individual, &mut trial);
234                    if score.is_finite() {
235                        score
236                    } else {
237                        default_fitness
238                    }
239                })
240                .sum();
241            Self::Status::set_fitness(individual, total / n_trials as f64);
242        });
243    }
244
245    fn rank(population: &mut Vec<Self::Individual>) {
246        population.sort_by(|a, b| b.cmp(a));
247        debug_assert!(population.windows(2).all(|w| {
248            let a = &w[0];
249            let b = &w[1];
250
251            debug_assert!(a >= b);
252            a >= b
253        }));
254    }
255
256    fn survive(population: &mut Vec<Self::Individual>, gap: f64) {
257        let n_individuals = population.len();
258
259        let mut n_of_individuals_to_drop =
260            (n_individuals as isize) - ((1.0 - gap) * (n_individuals as f64)).floor() as isize;
261
262        population.retain(Self::Status::valid);
263        let n_individuals_dropped = n_individuals - population.len();
264        n_of_individuals_to_drop -= n_individuals_dropped as isize;
265
266        while n_of_individuals_to_drop > 0 {
267            n_of_individuals_to_drop -= 1;
268            population.pop();
269        }
270    }
271
272    fn variation(
273        population: &mut Vec<Self::Individual>,
274        crossover_percent: f64,
275        mutation_percent: f64,
276        program_parameters: Self::ProgramParameters,
277    ) {
278        debug_assert!(!population.is_empty());
279
280        let pop_cap = population.capacity();
281        let pop_len = population.len();
282
283        let remaining_pool_spots = pop_cap - pop_len;
284
285        if remaining_pool_spots == 0 {
286            return;
287        }
288
289        let n_mutations = (remaining_pool_spots as f64 * mutation_percent).floor() as usize;
290        let n_crossovers = (remaining_pool_spots as f64 * crossover_percent).floor() as usize;
291        let n_clones = remaining_pool_spots - n_mutations - n_crossovers;
292
293        let mut clone_offspring: Vec<Self::Individual> = Vec::with_capacity(n_clones);
294        let mut mutation_offspring: Vec<Self::Individual> = Vec::with_capacity(n_mutations);
295        let mut crossover_offspring: Vec<Self::Individual> = Vec::with_capacity(n_crossovers);
296
297        debug_assert!(n_mutations + n_crossovers <= remaining_pool_spots);
298
299        let rc_population = Arc::new(population.clone());
300
301        rayon::scope(|s| {
302            s.spawn(|_| {
303                crossover_offspring.extend((0..n_crossovers).filter_map(|_| {
304                    let population_to_read = rc_population.clone();
305                    let parent_a = population_to_read.iter().choose(&mut generator());
306                    let parent_b = population_to_read.iter().choose(&mut generator());
307
308                    if let (Some(parent_a), Some(parent_b)) = (parent_a, parent_b) {
309                        let children = Self::Breed::two_point_crossover(parent_a, parent_b);
310                        match generator().gen_range(0..2) {
311                            0 => Some(children.0),
312                            1 => Some(children.1),
313                            _ => unreachable!(),
314                        }
315                    } else {
316                        None
317                    }
318                }));
319            });
320
321            s.spawn(|_| {
322                mutation_offspring.extend((0..n_mutations).filter_map(|_| {
323                    let population_to_read = rc_population.clone();
324                    let parent = population_to_read.iter().choose(&mut generator());
325
326                    if let Some(internal_parent) = parent {
327                        let mut clone = internal_parent.clone();
328                        Self::Mutate::mutate(&mut clone, program_parameters);
329                        Some(clone)
330                    } else {
331                        None
332                    }
333                }))
334            });
335
336            s.spawn(|_| {
337                clone_offspring.extend((0..n_clones).filter_map(|_| {
338                    let population_to_read = rc_population.clone();
339                    let parent = population_to_read.iter().choose(&mut generator());
340
341                    if let Some(internal_parent) = parent {
342                        let mut clone = internal_parent.clone();
343                        Self::Reset::reset(&mut clone);
344                        Some(clone)
345                    } else {
346                        None
347                    }
348                }))
349            });
350        });
351
352        // Step 3: Add Children to Population
353        population.append(&mut crossover_offspring);
354        population.append(&mut mutation_offspring);
355        population.append(&mut clone_offspring);
356    }
357}