radiate_engines/steps/
filter.rs

1use radiate_core::{
2    Chromosome, Ecosystem, EngineStep, Genotype, MetricSet, Phenotype, ReplacementStrategy, Valid,
3    metric_names,
4};
5use std::sync::Arc;
6
7pub struct FilterStep<C>
8where
9    C: Chromosome,
10{
11    pub(crate) replacer: Arc<dyn ReplacementStrategy<C>>,
12    pub(crate) encoder: Arc<dyn Fn() -> Genotype<C> + Send + Sync>,
13    pub(crate) max_age: usize,
14    pub(crate) max_species_age: usize,
15}
16
17impl<C> EngineStep<C> for FilterStep<C>
18where
19    C: Chromosome,
20{
21    fn execute(
22        &mut self,
23        generation: usize,
24        metrics: &mut MetricSet,
25        ecosystem: &mut Ecosystem<C>,
26    ) {
27        let mut age_count = 0_f32;
28        let mut invalid_count = 0_f32;
29        for i in 0..ecosystem.population.len() {
30            let phenotype = &ecosystem.population[i];
31
32            let mut removed = false;
33            if phenotype.age(generation) > self.max_age {
34                removed = true;
35                age_count += 1_f32;
36            } else if !phenotype.genotype().is_valid() {
37                removed = true;
38                invalid_count += 1_f32;
39            }
40
41            if removed {
42                let new_genotype = self
43                    .replacer
44                    .replace(ecosystem.population(), Arc::clone(&self.encoder));
45                ecosystem.population[i] = Phenotype::from((new_genotype, generation));
46            }
47        }
48
49        if let Some(species) = ecosystem.species_mut() {
50            let before_species = species.len();
51            species.retain(|species| species.age(generation) < self.max_species_age);
52            let species_count = (before_species - species.len()) as f32;
53
54            if species_count > 0_f32 {
55                metrics.upsert_value(metric_names::SPECIES_AGE_FAIL, species_count);
56            }
57        }
58
59        if age_count > 0_f32 {
60            metrics.upsert_value(metric_names::REPLACE_AGE, age_count);
61        }
62
63        if invalid_count > 0_f32 {
64            metrics.upsert_value(metric_names::REPLACE_INVALID, invalid_count);
65        }
66    }
67}