memega/
eval.rs

1use std::fmt;
2use std::hash::Hash;
3
4use concurrent_lru::sharded::LruCache;
5use derive_more::Display;
6use float_pretty_print::PrettyPrintFloat;
7
8use crate::cfg::Cfg;
9use crate::gen::species::{SpeciesId, NO_SPECIES};
10use crate::gen::Params;
11
/// Bounds every genome type must satisfy: cloneable, shareable across threads
/// (`Send + Sync`), partially ordered / partially comparable, and debug-printable.
///
/// NOTE: `trait Name = ...;` is trait-alias syntax, which requires the unstable
/// `trait_alias` feature to be enabled for the crate.
pub trait Genome = Clone + Send + Sync + PartialOrd + PartialEq + fmt::Debug;
/// A fitness function: a thread-safe, cloneable callable that takes a genome
/// reference plus a `usize` and returns an `f64` score.
///
/// NOTE(review): the `usize` is presumably the generation number, mirroring
/// `Evaluator::fitness` — confirm against callers.
pub trait FitnessFn<G: Genome> = Fn(&G, usize) -> f64 + Sync + Send + Clone;
14
/// A single population member: a genome bundled with its evolved parameters,
/// species assignment and fitness values.
///
/// `Display` prints the pretty-printed `base_fitness` followed by the species.
/// The derived `PartialOrd` compares fields in declaration order, so `genome`
/// is compared first.
#[derive(Clone, PartialOrd, PartialEq, Debug, Display)]
#[display(fmt = "fitness {} species {}", "PrettyPrintFloat(*base_fitness)", species)]
pub struct Mem<G: Genome> {
    /// Actual genome.
    pub genome: G,
    /// Adaptively evolved parameters.
    pub params: Params,
    /// Species index.
    pub species: SpeciesId,
    /// Original fitness, generated by the Evaluator fitness function.
    pub base_fitness: f64,
    /// Potentially adjusted fitness, used for selection.
    pub selection_fitness: f64,
}
24
25impl<G: Genome> Mem<G> {
26    pub fn new<E: Evaluator>(genome: G, cfg: &Cfg) -> Self {
27        Self {
28            genome,
29            params: Params::new::<E>(cfg),
30            species: NO_SPECIES,
31            base_fitness: 0.0,
32            selection_fitness: 0.0,
33        }
34    }
35}
36
/// Problem definition for the evolutionary search: supplies the genome type and
/// the crossover, mutation, fitness and distance operations over it.
/// Implementations must be `Send + Sync`.
pub trait Evaluator: Send + Sync {
    /// Genome type this evaluator operates on.
    type Genome: Genome;
    /// Number of crossover operators (defaults to 2).
    const NUM_CROSSOVER: usize = 2;
    /// Number of mutation operators (defaults to 1).
    const NUM_MUTATION: usize = 1;

    /// Applies a crossover operator to `s1` and `s2` in place.
    ///
    /// `idx` specifies which crossover function to use. 0 is conventionally
    /// "do nothing", with actual crossover starting from index 1.
    fn crossover(&self, s1: &mut Self::Genome, s2: &mut Self::Genome, idx: usize);

    /// Applies mutation operator `idx` to `s` in place, using `rate`.
    ///
    /// Unlike crossover, mutation is called for every mutation operator, so
    /// there is no need for a no-op operator.
    fn mutate(&self, s: &mut Self::Genome, rate: f64, idx: usize);
    /// Returns the fitness of `s`; `gen` is passed through to the implementation
    /// (presumably the current generation number — confirm against callers).
    fn fitness(&self, s: &Self::Genome, gen: usize) -> f64;
    /// Returns a distance value between genomes `s1` and `s2`.
    fn distance(&self, s1: &Self::Genome, s2: &Self::Genome) -> f64;
}
51
/// Evaluator wrapper which uses an LRU cache to cache fitness values.
///
/// Only fitness is cached; there is no distance cache, and the other
/// operations are forwarded to the wrapped evaluator. The genome must be
/// `Hash + Eq` because it is the cache key.
pub struct CachedEvaluator<E: Evaluator>
where
    E::Genome: Hash + Eq,
{
    /// Wrapped evaluator that performs the real work.
    eval: E,
    /// Fitness values keyed by genome.
    fitness_cache: LruCache<E::Genome, f64>,
}
60
61impl<E: Evaluator> CachedEvaluator<E>
62where
63    E::Genome: Hash + Eq,
64{
65    pub fn new(eval: E, cap: usize) -> Self {
66        Self { eval, fitness_cache: LruCache::new(cap as u64) }
67    }
68}
69
70impl<E: Evaluator> Evaluator for CachedEvaluator<E>
71where
72    E::Genome: Hash + Eq,
73{
74    type Genome = E::Genome;
75    const NUM_CROSSOVER: usize = E::NUM_CROSSOVER;
76    const NUM_MUTATION: usize = E::NUM_MUTATION;
77
78    fn crossover(&self, s1: &mut Self::Genome, s2: &mut Self::Genome, idx: usize) {
79        self.eval.crossover(s1, s2, idx);
80    }
81
82    fn mutate(&self, s: &mut Self::Genome, rate: f64, idx: usize) {
83        self.eval.mutate(s, rate, idx);
84    }
85
86    fn fitness(&self, s: &Self::Genome, gen: usize) -> f64 {
87        *self.fitness_cache.get_or_init(s.clone(), 1, |s| self.eval.fitness(s, gen)).value()
88    }
89
90    fn distance(&self, s1: &Self::Genome, s2: &Self::Genome) -> f64 {
91        self.eval.distance(s1, s2)
92    }
93}