1use std::vec;
2use serde::{Deserialize, Serialize};
3
4use crate::loader::save_load::{FileSaverLoader, Saver};
5
6use super::species::Species;
7use super::config::Config;
8use super::genome::Genome;
9use super::innovation_number::reset;
10use crate::neat::errors::*;
11
12#[derive(Serialize, Deserialize)]
13pub struct Population {
14 pub name: String,
15 pub generation: u32,
16 pub top_fitness: f64,
17 pub top_genome: Option<Genome>,
18 pub species: Vec<Species>,
19 pub staleness: u32,
20 config: Config,
21}
22
/// Options controlling a call to `Population::run`.
///
/// `Debug`, `Clone` and `PartialEq` are derived so callers can log, duplicate
/// and compare option values.
#[derive(Debug, Clone, PartialEq)]
pub struct EvaluteOption {
    /// Maximum number of evaluation/breeding rounds to execute.
    pub epochs: u64,
    /// Optional fitness threshold; when set, the run stops as soon as a
    /// genome scores at least this value.
    pub fitness_target: Option<f64>,
}

impl EvaluteOption {
    /// Creates options that run for `epochs` rounds with no fitness target.
    pub fn new(epochs: u64) -> Self {
        EvaluteOption {
            epochs,
            fitness_target: None,
        }
    }
}
33
34pub type EvaluteFunc = fn(config: &Config, genome: &Genome) -> f64;
35
36impl Population {
37 pub fn new(config: &Config, name: &str) -> Self {
38 reset();
39 let mut pool = Population {
40 name: name.to_owned(),
41 config: config.clone(),
42 generation: 1,
43 species: vec![],
44 top_fitness: 0.0,
45 top_genome: None,
46 staleness: 0,
47 };
48 pool.init_species();
49 pool
50 }
51
52 fn init_species(&mut self) {
53 let mut seed = Genome::new(&self.config);
54 seed.minimal_network();
55 for _ in 0..self.config.population {
56 let copy_gene = Genome::new_from(&seed, false);
57 self.add_to_species(copy_gene);
58 }
59 }
60
61 fn add_to_species(&mut self, genome: Genome) {
62 for i in 0..self.species.len() {
63 let species = &mut self.species[i];
64 if let Some(existing_genome) = species.get_genome_by_index(0) {
65 if Genome::is_same_species(&genome, existing_genome) {
66 species.add_genome(genome);
67 return
68 };
69 }
70 }
71 let mut new_species = Species::new(&self.config);
72 new_species.add_genome(genome);
73 self.species.push(new_species);
74 }
75
76 pub fn run(&mut self, evalute_func: EvaluteFunc, option: EvaluteOption) -> Result<Genome, Errors> {
77 let mut k = 0;
78 let config = self.config.clone();
79 while k < option.epochs {
80 k += 1;
81 let genomes = self.all_genomes();
82 for genome in genomes {
83 let fitness = evalute_func(&config, genome);
84 genome.fitness = fitness;
85 if let Some(fitness_target) = option.fitness_target {
86 if fitness >= fitness_target {
87 return Ok(genome.clone());
88 }
89 }
90 }
91 self.breed_new_generation()?;
92 println!("round {:?} top fitness:{:?} species: {:?}", k, self.top_fitness, self.species.len());
93 }
94 Err(Errors::CanNotFindSolution())
95 }
96
97 pub fn set_fitness(&mut self, fitnesses: Vec<f64>) -> Result<(), Errors> {
98 let mut genomes = self.all_genomes();
99 if fitnesses.len() != genomes.len() {
100 return Err(Errors::InputSizeNotMatch("Set fitnesses size not match genomes size".to_owned()))
101 }
102 for (index, genome) in genomes.iter_mut().enumerate() {
103 genome.fitness = fitnesses[index];
104 }
105 Ok(())
106 }
107
108 pub fn all_genomes(&mut self) -> Vec<&mut Genome> {
109 self.species.iter_mut().fold(vec![], |mut acc, species| {
110 for genomo in &mut species.genomes {
111 acc.push(genomo);
112 }
113 acc
114 })
115 }
116
117 fn calculate_species_adjusted_fitness(&mut self) {
118 self.species.iter_mut()
119 .for_each(| species | {
120 species.calculate_adjusted_fitness();
121 });
122 }
123
124 fn remove_stale_species(&mut self) {
125 self.species.iter_mut().for_each(| species | species.check_progress());
126 self.check_progress();
127 self.species.retain(| species | species.staleness < self.config.stale_species_threshold || species.top_fitness >= self.top_fitness );
128 if self.staleness >= self.config.stale_population_threshold {
129 self.species = vec![];
130 self.staleness = 0;
131 }
133 }
134
135 fn check_progress(&mut self) {
136 let (top_fitness, genome) = self.get_top_genome();
137 if top_fitness > self.top_fitness {
138 self.top_fitness = top_fitness;
139 self.staleness = 0;
140 self.top_genome = genome;
141 } else {
142 self.staleness += 1;
143 }
144 }
145
146 pub fn breed_new_generation(&mut self) -> Result<(), Errors> {
147 let mut children = vec![];
148 self.remove_stale_species();
149 if self.species.is_empty() {
150 if let Some(genome) = self.top_genome.clone() {
151 for _ in 0..self.config.population {
152 children.push(genome.clone());
153 }
154 } else {
155 return Err(Errors::PopulationExtinction());
156 }
157 } else {
158 self.calculate_species_adjusted_fitness();
159 let species_sizes = Self::calculate_pop_size_for_each_species(&self.species, self.config.population as usize, self.config.min_species_size);
160 for (target_size, species) in species_sizes.iter().zip(&mut self.species) {
161 children.extend(species.reproduce_to_size(*target_size));
162 }
163 }
164 for child in children {
165 self.add_to_species(child);
166 }
167 self.generation += 1;
168 Ok(())
169 }
170
171 fn calculate_pop_size_for_each_species(species: &Vec<Species>, pop_size: usize, min_species_size: usize) -> Vec<usize> {
172 let adjusted_fitness_sum = species.iter().fold(0.0, |acc, species| {
173 acc + species.adjusted_fitness
174 });
175 let mut new_species_pos_size = vec![];
176 for s in species {
177 let adjusted_fitness = s.adjusted_fitness;
178 let mut species_size: i64 = min_species_size as i64;
179 if adjusted_fitness_sum > 0.0 {
180 species_size = min_species_size.max((adjusted_fitness / adjusted_fitness_sum * pop_size as f64) as usize) as i64;
181 };
182 new_species_pos_size.push(species_size as usize);
183 }
184 let mut diff = pop_size as i64 - new_species_pos_size.iter().fold(0, | acc, size | acc + size ) as i64;
185 let mut counter = 0;
186 while diff != 0 {
187 let index = counter as usize % new_species_pos_size.len();
188 if diff > 0 {
189 new_species_pos_size[index] += 1;
190 diff -= 1;
191 } else {
192 if new_species_pos_size[index] > min_species_size {
193 new_species_pos_size[index] -= 1;
194 diff += 1;
195 } else {
196 if !new_species_pos_size.iter().any(| size | *size > min_species_size) {
198 break;
199 }
200 }
201 }
202 counter += 1;
203 }
204 new_species_pos_size
205 }
206
207 pub fn get_top_genome(&self) -> (f64, Option<Genome>) {
208 let mut top_fitness = 0.0;
209 let mut top_genomo = None;
210 for species in &self.species {
211 if let Some(genome) = species.get_top_genome() {
212 if genome.fitness >= top_fitness {
213 top_fitness = genome.fitness;
214 top_genomo = Some(genome.clone());
215 }
216 }
217 }
218 return (top_fitness, top_genomo);
219 }
220
221 pub fn save_to(&self, file: &str) -> std::io::Result<()> {
222 let saver = FileSaverLoader::new(file);
223 saver.save(self)
224 }
225}
226
227impl std::fmt::Debug for Population {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 let (_, genome) = self.get_top_genome();
230 write!(f, "Generation {:?} \nTop fitness {:?}\nNum of species {:?}\nTop Genome {:?}\n", self.generation, self.top_fitness, self.species.len(), genome.unwrap())
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::config::Config;

    /// Builds a species whose genomes carry the given fitness values, in order.
    fn species_with_fitnesses(config: &Config, fitnesses: &[f64]) -> Species {
        let mut species = Species::new(config);
        species.genomes = fitnesses
            .iter()
            .map(|&fitness| {
                let mut genome = Genome::new(config);
                genome.fitness = fitness;
                genome
            })
            .collect();
        species
    }

    /// Builds one species per adjusted-fitness value, each holding two fresh genomes.
    fn species_with_adjusted_fitness(config: &Config, adjusted: &[f64]) -> Vec<Species> {
        adjusted
            .iter()
            .map(|&adjusted_fitness| {
                let mut species = Species::new(config);
                species.adjusted_fitness = adjusted_fitness;
                species.genomes = vec![Genome::new(config), Genome::new(config)];
                species
            })
            .collect()
    }

    /// Assembles a generation-0 population named "test" without running `Population::new`.
    fn population_with(config: Config, top_fitness: f64, species: Vec<Species>) -> Population {
        Population {
            name: "test".to_owned(),
            generation: 0,
            top_fitness,
            top_genome: None,
            species,
            staleness: 0,
            config,
        }
    }

    /// Runs the size calculation for one scenario and checks every species size.
    fn assert_species_sizes(
        config: &Config,
        adjusted: &[f64],
        pop_size: usize,
        min_species_size: usize,
        expected: &[usize],
    ) {
        let species = species_with_adjusted_fitness(config, adjusted);
        let sizes = Population::calculate_pop_size_for_each_species(&species, pop_size, min_species_size);
        assert_eq!(
            sizes.as_slice(),
            expected,
            "adjusted fitness {:?}, population {}, minimum species size {}",
            adjusted,
            pop_size,
            min_species_size
        );
    }

    #[test]
    fn test_init_species() {
        let config = Config::default();
        let population = Population::new(&config, "test");
        assert_eq!(population.species.len(), 1);
    }

    #[test]
    fn test_set_fitnesses() {
        let mut config = Config::default();
        config.population = 5;
        let mut population = Population::new(&config, "test");
        let fitnesses = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        // The result is checked explicitly so a size mismatch fails the test
        // instead of being silently discarded.
        assert!(population.set_fitness(fitnesses.clone()).is_ok());

        let genomes = population.all_genomes();
        assert_eq!(genomes.len(), fitnesses.len());
        for (genome, expected) in genomes.iter().zip(&fitnesses) {
            assert_eq!(genome.fitness, *expected);
        }
    }

    #[test]
    fn test_check_progress_making_progress() {
        let config = Config::default();
        let species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        let species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);
        let mut population = population_with(config, 0.0, vec![species1, species2]);

        population.check_progress();

        assert_eq!(population.staleness, 0);
        assert_eq!(population.top_fitness, 5.0);
    }

    #[test]
    fn test_check_progress_not_making_progress() {
        let config = Config::default();
        let species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        let species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);
        let mut population = population_with(config, 6.0, vec![species1, species2]);

        population.check_progress();

        assert_eq!(population.staleness, 1);
        assert_eq!(population.top_fitness, 6.0);
    }

    #[test]
    fn test_remove_stale_species() {
        let mut config = Config::default();
        config.stale_species_threshold = 1;
        config.stale_population_threshold = 1;

        let mut species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        species1.top_fitness = 5.0;
        let mut species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);
        species2.top_fitness = 4.0;
        let mut species3 = species_with_fitnesses(&config, &[6.0, 1.0, 2.0]);
        species3.top_fitness = 4.0;

        let mut population = population_with(config, 5.9, vec![species1, species2, species3]);

        population.remove_stale_species();

        assert_eq!(population.species.len(), 2);
        assert_eq!(population.top_fitness, 6.0);
        assert_eq!(population.species[0].top_fitness, 5.0);
        assert_eq!(population.species[1].top_fitness, 6.0);
    }

    #[test]
    fn test_remove_stale_species_when_population_is_stale() {
        let mut config = Config::default();
        config.stale_species_threshold = 1;
        config.stale_population_threshold = 1;

        let mut species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        species1.top_fitness = 5.0;
        let mut species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);
        species2.top_fitness = 4.0;
        let mut species3 = species_with_fitnesses(&config, &[6.0, 1.0, 2.0]);
        species3.top_fitness = 4.0;

        let mut population = Population::new(&config, "test");
        population.species = vec![species1, species2, species3];

        population.remove_stale_species();

        // NOTE(review): two species are expected to survive, so the branch that
        // clears all species for a stale population is not exercised here
        // despite the test name — confirm the intended scenario and the
        // expected `top_fitness` of the first survivor.
        assert_eq!(population.species.len(), 2);
        assert_eq!(population.species[0].top_fitness, 6.0);
    }

    #[test]
    fn test_breed_new_generation() {
        let config = Config::default();
        let mut population = Population::new(&config, "test");

        // Breeding must succeed; an ignored `Err` would hide the real failure.
        assert!(population.breed_new_generation().is_ok());
        assert_eq!(population.generation, 2);
    }

    #[test]
    fn test_calculate_pop_size_for_each_species() {
        let config = Config::default();

        // A single species absorbs the whole population.
        assert_species_sizes(&config, &[0.0], 100, 2, &[100]);
        // Equal adjusted fitness splits the population evenly.
        assert_species_sizes(&config, &[0.1, 0.1], 100, 2, &[50, 50]);
        // A zero-fitness species keeps only the minimum size.
        assert_species_sizes(&config, &[0.0, 0.1], 100, 2, &[2, 98]);
        // The population exactly covers the minimum sizes.
        assert_species_sizes(&config, &[0.0, 0.1], 4, 2, &[2, 2]);
        // Minimum sizes win when they exceed the population size.
        assert_species_sizes(&config, &[0.0, 0.1], 3, 2, &[2, 2]);
        // A smaller minimum lets the surplus be removed.
        assert_species_sizes(&config, &[0.0, 0.1], 3, 1, &[1, 2]);
        // Odd population sizes are balanced by shrinking the larger species.
        assert_species_sizes(&config, &[0.0, 0.1], 99, 2, &[2, 97]);
    }
}