1
2use super::traits::*;
3use std::collections::HashSet;
4use rand::random;
5
6fn randint(n: usize) -> usize {
7 ((n as f64)*random::<f64>()) as usize
8}
9
10fn pick_one(list: &[f64]) -> usize {
11 let total: f64 = list.iter().sum();
12 let mut r = random::<f64>();
13
14 for (index, &item) in list.iter().enumerate() {
15 let prob = item / total;
16 r -= prob;
17 if (r < 0.) {
18 return index;
19 }
20 }
21
22 list.len() - 1
23}
24
/// Population state for a NEAT-style evolutionary run.
///
/// Besides holding the genomes, this value acts as the global counter shared by all
/// genomes: it hands out node ids and connection innovation numbers through its
/// `GlobalNeatCounter` implementation.
///
/// The `Gene` bound lives on the `impl` blocks rather than on the struct itself, so the
/// type can be named without repeating the bound.
#[derive(Debug)]
pub struct Neat<T> {
    /// Next unused node id; `new` starts it at `1 + inputs + outputs`.
    nodes: usize,
    /// Every distinct `(from, to)` connection registered so far; its size is the next
    /// innovation number to hand out.
    connections: HashSet<(usize, usize)>,
    /// The current population.
    genomes: Vec<T>,
    /// Probability that a freshly bred child is mutated.
    mutation_rate: f64,
}
33
34impl<T: Gene> Neat<T> {
35 pub fn new(inputs: usize, outputs: usize, size: usize, mutation_rate: f64) -> Self {
37 Self {
38 nodes: 1 + inputs + outputs,
39 connections: HashSet::new(),
40 genomes: (0..size).map(|_| Gene::empty(inputs, outputs)).collect(),
41 mutation_rate
42 }
43 }
44
45 fn speciate(&self) -> Vec<Vec<usize>> {
46 let mut species_plural = Vec::new();
47 species_plural.push(vec![0]);
48
49 for i in 1..self.genomes.len() {
50 let mut added = false;
51 for species in &mut species_plural {
52 let index = species[randint(species.len())];
53 if self.genomes[i].is_same_species_as(&self.genomes[index]) {
54 species.push(i);
55 added = true;
56 break;
57 }
58 }
59 if !added {
60 species_plural.push(vec![i]);
61 }
62 }
63 species_plural
64 }
65
66 pub fn calculate_fitness(&self, calculate: impl Fn(&T, bool) -> f64) -> (Vec<f64>, f64) {
69 let mut total_score = 0.;
70 let mut best_score = 0.;
71 let mut fittest = &self.genomes[0];
72 let mut scores = Vec::new();
73
74 for genome in &self.genomes {
75 let score = calculate(genome, false);
76 if score > best_score {
77 best_score = score;
78 fittest = genome;
79 }
80 scores.push(score);
81 total_score += score;
82 }
83 calculate(fittest, true);
84 (scores, total_score)
85 }
86
87 pub fn next_generation(&mut self, scores: &[f64], total_score: f64) {
90 let total_mean = total_score / (scores.len() as f64);
91
92 let mut species_plural = self.speciate();
93
94 println!("Number of species = {}", species_plural.len());
95 println!("Number of genomes = {}", self.genomes.len());
96
97 let mut next_generation = Vec::new();
98
99 for species in &mut species_plural {
100 species.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
102
103 let top = (0.6 * species.len() as f64).round() as usize;
104
105 let num: f64 = species
106 .iter()
107 .map(|&i| scores[i])
108 .sum();
109
110 let num = (num / total_mean).round() as usize;
111
112 let species_scores: Vec<_> = species
113 .iter()
114 .take(top)
115 .map(|&i| scores[i])
116 .collect();
117
118 for _ in 0..num {
119 let index1 = species[pick_one(&species_scores)];
120 let index2 = species[pick_one(&species_scores)];
121
122 let one = &self.genomes[index1];
123 let two = &self.genomes[index2];
124
125 let mut child = if scores[index1] > scores[index2] {
126 one.cross(two)
127 } else {
128 two.cross(one)
129 };
130
131 if random::<f64>() < self.mutation_rate {
132 child.mutate(self);
133 }
134
135 next_generation.push(child);
136 }
137 }
138
139 self.genomes = next_generation;
140 }
141}
142
143impl<T: Gene> GlobalNeatCounter for Neat<T> {
144 fn try_adding_connection(&mut self, from: usize, to: usize) -> Option<usize> {
145 let innov_num = self.connections.len();
146
147 if self.connections.insert((from, to)) {
148 Some(innov_num)
149 } else {
150 None
151 }
152 }
153
154 fn get_new_node(&mut self) -> usize {
155 let new_node = self.nodes;
156 self.nodes += 1;
157 new_node
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Weighted picks must always be valid indices, and heavily weighted entries must be
    /// chosen far more often than light ones.
    #[test]
    fn test_po() {
        let scores = [100., 1., 2., 4., 5., 8., 92.];
        let mut counts = [0usize; 7];
        for _ in 0..5000 {
            let index = pick_one(&scores);
            assert!(index < scores.len(), "pick_one returned out-of-range index {}", index);
            counts[index] += 1;
        }
        // Index 0 carries ~47% of the weight, index 6 ~43% and index 1 under 0.5%; with
        // 5000 draws the reverse outcome is astronomically unlikely.
        assert!(counts[0] > counts[1], "unexpected pick distribution: {:?}", counts);
        assert!(counts[6] > counts[1], "unexpected pick distribution: {:?}", counts);
    }

    /// A single non-zero weight takes the whole probability mass.
    #[test]
    fn pick_one_with_single_nonzero_weight_always_picks_it() {
        let weights = [0., 0., 5., 0.];
        for _ in 0..100 {
            assert_eq!(pick_one(&weights), 2);
        }
    }

    /// A one-element list can only ever yield index 0.
    #[test]
    fn pick_one_with_single_element_picks_it() {
        for _ in 0..100 {
            assert_eq!(pick_one(&[3.5]), 0);
        }
    }
}