neat_rs/
neat.rs

1
2use super::traits::*;
3use std::collections::HashSet;
4use rand::random;
5
6fn randint(n: usize) -> usize {
7    ((n as f64)*random::<f64>()) as usize
8}
9
10fn pick_one(list: &[f64]) -> usize {
11    let total: f64 = list.iter().sum();
12    let mut r = random::<f64>();
13
14    for (index, &item) in list.iter().enumerate() {
15        let prob = item / total;
16        r -= prob;
17        if (r < 0.) {
18            return index;
19        }
20    }
21
22    list.len() - 1
23}
24
25/// Neat struct that takes care of evolving the population based on the fitness scores
26#[derive(Debug)]
27pub struct Neat<T: Gene> {
28    nodes: usize,
29    connections: HashSet<(usize, usize)>,
30    genomes: Vec<T>,
31    mutation_rate: f64
32}
33
34impl<T: Gene> Neat<T> {
35    /// Creates a new population
36    pub fn new(inputs: usize, outputs: usize, size: usize, mutation_rate: f64) -> Self {
37        Self {
38            nodes: 1 + inputs + outputs,
39            connections: HashSet::new(),
40            genomes: (0..size).map(|_| Gene::empty(inputs, outputs)).collect(),
41            mutation_rate
42        }
43    }
44
45    fn speciate(&self) -> Vec<Vec<usize>> {
46        let mut species_plural = Vec::new();
47        species_plural.push(vec![0]);
48
49        for i in 1..self.genomes.len() {
50            let mut added = false;
51            for species in &mut species_plural {
52                let index = species[randint(species.len())];
53                if self.genomes[i].is_same_species_as(&self.genomes[index]) {
54                    species.push(i);
55                    added = true;
56                    break;
57                }
58            }
59            if !added {
60                species_plural.push(vec![i]);
61            }
62        }
63        species_plural
64    }
65
66    /// Evaluates all the genomes based on the fitness function and returns all the scores
67    /// and the total score
68    pub fn calculate_fitness(&self, calculate: impl Fn(&T, bool) -> f64) -> (Vec<f64>, f64) {
69        let mut total_score = 0.;
70        let mut best_score = 0.;
71        let mut fittest = &self.genomes[0];
72        let mut scores = Vec::new();
73
74        for genome in &self.genomes {
75            let score = calculate(genome, false);
76            if score > best_score {
77                best_score = score;
78                fittest = genome;
79            }
80            scores.push(score);
81            total_score += score;
82        }
83        calculate(fittest, true);
84        (scores, total_score)
85    }
86
87    /// Takes care of speciation, selection and mutation and creates the
88    /// new population
89    pub fn next_generation(&mut self, scores: &[f64], total_score: f64) {
90        let total_mean = total_score / (scores.len() as f64);
91
92        let mut species_plural = self.speciate();
93
94        println!("Number of species = {}", species_plural.len());
95        println!("Number of genomes = {}", self.genomes.len());
96
97        let mut next_generation = Vec::new();
98
99        for species in &mut species_plural {
100            // Sort species based on fitness
101            species.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
102
103            let top = (0.6 * species.len() as f64).round() as usize;
104
105            let num: f64 = species
106                .iter()
107                .map(|&i| scores[i])
108                .sum();
109            
110            let num = (num / total_mean).round() as usize;
111
112            let species_scores: Vec<_> = species
113            .iter()
114            .take(top)
115            .map(|&i| scores[i])
116            .collect();
117
118            for _ in 0..num {
119                let index1 = species[pick_one(&species_scores)];
120                let index2 = species[pick_one(&species_scores)];
121
122                let one = &self.genomes[index1];
123                let two = &self.genomes[index2];
124
125                let mut child = if scores[index1] > scores[index2] {
126                    one.cross(two)
127                } else {
128                    two.cross(one)
129                };
130
131                if random::<f64>() < self.mutation_rate {
132                    child.mutate(self);
133                }
134
135                next_generation.push(child);
136            }
137        }
138
139        self.genomes = next_generation;
140    }
141}
142
143impl<T: Gene> GlobalNeatCounter for Neat<T> {
144    fn try_adding_connection(&mut self, from: usize, to: usize) -> Option<usize> {
145        let innov_num = self.connections.len();
146
147        if self.connections.insert((from, to)) {
148            Some(innov_num)
149        } else {
150            None
151        }
152    }
153
154    fn get_new_node(&mut self) -> usize {
155        let new_node = self.nodes;
156        self.nodes += 1;
157        new_node
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples `pick_one` 5000 times over a fixed weight list and checks that
    /// every result is a valid index and that the heaviest entry (weight 100)
    /// is drawn more often than the lightest one (weight 1).
    #[test]
    fn test_po() {
        let scores = [100., 1., 2., 4., 5., 8., 92.];
        let mut counts = [0usize; 7];

        for _ in 0..5000 {
            let index = pick_one(&scores);
            assert!(
                index < scores.len(),
                "pick_one returned out-of-range index {}",
                index
            );
            counts[index] += 1;
        }

        // Expected shares are 100/212 vs 1/212, so this ordering failing by
        // chance is practically impossible.
        assert!(
            counts[0] > counts[1],
            "weight 100 was picked {} times but weight 1 was picked {} times",
            counts[0],
            counts[1]
        );

        // Visible with `cargo test -- --nocapture` for eyeballing the spread.
        println!("{:?}", counts);
    }
}