neat_lib/neat/population.rs

1use std::vec;
2use serde::{Deserialize, Serialize};
3
4use crate::loader::save_load::{FileSaverLoader, Saver};
5
6use super::species::Species;
7use super::config::Config;
8use super::genome::Genome;
9use super::innovation_number::reset;
10use crate::neat::errors::*;
11
12#[derive(Serialize, Deserialize)]
13pub struct Population {
14    pub name: String,
15    pub generation: u32,
16    pub top_fitness: f64,
17    pub top_genome: Option<Genome>,
18    pub species: Vec<Species>,
19    pub staleness: u32,
20    config: Config,
21}
22
/// Stopping criteria for `Population::run`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EvaluteOption {
    /// Maximum number of generations to evaluate and breed.
    pub epochs: u64,
    /// When set, `run` returns the first evaluated genome whose fitness is `>=` this value.
    pub fitness_target: Option<f64>,
}

impl EvaluteOption {
    /// Creates options that run for `epochs` generations with no fitness target.
    #[must_use]
    pub fn new(epochs: u64) -> Self {
        EvaluteOption { epochs, fitness_target: None }
    }

    /// Returns these options with an early-stop fitness target set to `target`.
    #[must_use]
    pub fn with_fitness_target(self, target: f64) -> Self {
        EvaluteOption { fitness_target: Some(target), ..self }
    }
}
33
34pub type EvaluteFunc = fn(config: &Config, genome: &Genome) -> f64;
35
36impl Population {
37    pub fn new(config: &Config, name: &str) -> Self {
38        reset();
39        let mut pool = Population { 
40            name: name.to_owned(),
41            config: config.clone(),
42            generation: 1,
43            species: vec![],
44            top_fitness: 0.0,
45            top_genome: None,
46            staleness: 0,
47        };
48        pool.init_species();
49        pool
50    }
51
52    fn init_species(&mut self) {
53        let mut seed = Genome::new(&self.config);
54        seed.minimal_network();
55        for _ in 0..self.config.population {
56            let copy_gene = Genome::new_from(&seed, false);
57            self.add_to_species(copy_gene);
58        }
59    }
60
61    fn add_to_species(&mut self, genome: Genome) {
62        for i in 0..self.species.len() {
63            let species = &mut self.species[i];
64            if let Some(existing_genome) = species.get_genome_by_index(0) {
65                if Genome::is_same_species(&genome, existing_genome) {
66                    species.add_genome(genome);
67                    return
68                };
69            }
70        }
71        let mut new_species = Species::new(&self.config);
72        new_species.add_genome(genome);
73        self.species.push(new_species);
74    }
75
76    pub fn run(&mut self, evalute_func: EvaluteFunc, option: EvaluteOption) -> Result<Genome, Errors> {
77        let mut k = 0;
78        let config = self.config.clone();
79        while k < option.epochs {
80            k += 1;
81            let genomes = self.all_genomes();
82            for genome in genomes {
83                let fitness = evalute_func(&config, genome);
84                genome.fitness = fitness;
85                if let Some(fitness_target) = option.fitness_target {
86                    if fitness >= fitness_target {
87                        return Ok(genome.clone());
88                    }
89                }
90            }
91            self.breed_new_generation()?;
92            println!("round {:?} top fitness:{:?} species: {:?}", k, self.top_fitness, self.species.len());
93        }
94        Err(Errors::CanNotFindSolution())
95    }
96
97    pub fn set_fitness(&mut self, fitnesses: Vec<f64>) -> Result<(), Errors> {
98        let mut genomes = self.all_genomes();
99        if fitnesses.len() != genomes.len() {
100            return Err(Errors::InputSizeNotMatch("Set fitnesses size not match genomes size".to_owned()))
101        }
102        for (index, genome) in genomes.iter_mut().enumerate() {
103            genome.fitness = fitnesses[index];
104        }
105        Ok(())
106    }
107
108    pub fn all_genomes(&mut self) -> Vec<&mut Genome> {
109        self.species.iter_mut().fold(vec![], |mut acc, species| {
110            for genomo in &mut species.genomes {
111                acc.push(genomo);
112            }
113            acc
114        })
115    }
116
117    fn calculate_species_adjusted_fitness(&mut self) {
118        self.species.iter_mut()
119            .for_each(| species | {
120                species.calculate_adjusted_fitness();
121            });
122    }
123    
124    fn remove_stale_species(&mut self) {
125        self.species.iter_mut().for_each(| species | species.check_progress());
126        self.check_progress();
127        self.species.retain(| species | species.staleness < self.config.stale_species_threshold || species.top_fitness >= self.top_fitness );
128        if self.staleness >= self.config.stale_population_threshold {
129            self.species = vec![];
130            self.staleness = 0;
131            //self.species.retain(| species | species.top_fitness >= self.top_fitness )
132        }
133    }
134
135    fn check_progress(&mut self) {
136        let (top_fitness, genome) = self.get_top_genome();
137        if top_fitness > self.top_fitness {
138            self.top_fitness = top_fitness;
139            self.staleness = 0;
140            self.top_genome = genome;
141        } else {
142            self.staleness += 1;
143        }
144    }
145
146    pub fn breed_new_generation(&mut self) -> Result<(), Errors> {
147        let mut children = vec![];
148        self.remove_stale_species();
149        if self.species.is_empty() {
150            if let Some(genome) = self.top_genome.clone() {
151                for _ in 0..self.config.population {
152                    children.push(genome.clone());
153                }
154            } else {
155                return Err(Errors::PopulationExtinction());
156            }
157        } else {
158            self.calculate_species_adjusted_fitness();
159            let species_sizes = Self::calculate_pop_size_for_each_species(&self.species, self.config.population as usize, self.config.min_species_size);
160            for (target_size, species) in species_sizes.iter().zip(&mut self.species) {
161                children.extend(species.reproduce_to_size(*target_size));
162            }
163        }
164        for child in children {
165            self.add_to_species(child);
166        }
167        self.generation += 1;
168        Ok(())
169    }
170
171    fn calculate_pop_size_for_each_species(species: &Vec<Species>, pop_size: usize, min_species_size: usize) -> Vec<usize> {
172        let adjusted_fitness_sum = species.iter().fold(0.0, |acc, species| {
173             acc + species.adjusted_fitness
174        });
175        let mut new_species_pos_size = vec![];
176        for s in species {
177            let adjusted_fitness = s.adjusted_fitness;
178            let mut species_size: i64 = min_species_size as i64;
179            if adjusted_fitness_sum > 0.0 {
180                species_size = min_species_size.max((adjusted_fitness / adjusted_fitness_sum * pop_size as f64) as usize) as i64;
181            };
182            new_species_pos_size.push(species_size as usize);
183        }
184        let mut diff = pop_size as i64 - new_species_pos_size.iter().fold(0, | acc, size | acc + size ) as i64;
185        let mut counter = 0;
186        while diff != 0 {
187            let index = counter as usize % new_species_pos_size.len();
188            if diff > 0 {
189                new_species_pos_size[index] += 1;
190                diff -= 1;
191            } else {
192                if new_species_pos_size[index] > min_species_size {
193                    new_species_pos_size[index] -= 1;
194                    diff += 1;
195                } else {
196                    // every species is at min species size then give up.
197                    if !new_species_pos_size.iter().any(| size | *size > min_species_size) {
198                        break;
199                    }
200                }
201            }
202            counter += 1;
203        }
204        new_species_pos_size
205    }
206
207    pub fn get_top_genome(&self) -> (f64, Option<Genome>) {
208        let mut top_fitness = 0.0;
209        let mut top_genomo = None;
210        for species in &self.species {
211            if let Some(genome) = species.get_top_genome() {
212                if genome.fitness >= top_fitness {
213                    top_fitness = genome.fitness;
214                    top_genomo = Some(genome.clone());
215                }
216            }
217        }
218        return (top_fitness, top_genomo);
219    }
220
221    pub fn save_to(&self, file: &str) -> std::io::Result<()> {
222        let saver = FileSaverLoader::new(file);
223        saver.save(self)
224    }
225}
226
227impl std::fmt::Debug for Population {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        let (_, genome) = self.get_top_genome();
230        write!(f, "Generation {:?} \nTop fitness {:?}\nNum of species {:?}\nTop Genome {:?}\n", self.generation, self.top_fitness, self.species.len(), genome.unwrap())
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::config::Config;

    /// Builds a species holding one fresh genome per entry of `fitnesses`, in order.
    fn species_with_fitnesses(config: &Config, fitnesses: &[f64]) -> Species {
        let mut species = Species::new(config);
        species.genomes = fitnesses
            .iter()
            .map(|&fitness| {
                let mut genome = Genome::new(config);
                genome.fitness = fitness;
                genome
            })
            .collect();
        species
    }

    /// Builds a species with the given adjusted fitness and two fresh genomes.
    fn species_with_adjusted_fitness(config: &Config, adjusted_fitness: f64) -> Species {
        let mut species = Species::new(config);
        species.adjusted_fitness = adjusted_fitness;
        species.genomes = vec![Genome::new(config), Genome::new(config)];
        species
    }

    /// Runs the size allocator over species built from `adjusted_fitnesses`.
    fn pop_sizes(config: &Config, adjusted_fitnesses: &[f64], pop_size: usize, min_species_size: usize) -> Vec<usize> {
        let species: Vec<Species> = adjusted_fitnesses
            .iter()
            .map(|&adjusted| species_with_adjusted_fitness(config, adjusted))
            .collect();
        Population::calculate_pop_size_for_each_species(&species, pop_size, min_species_size)
    }

    #[test]
    fn test_init_species() {
        let config = Config::default();
        let population = Population::new(&config, "test");
        assert_eq!(population.species.len(), 1);
    }

    #[test]
    fn test_set_fitnesses() {
        let mut config = Config::default();
        config.population = 5;
        let mut population = Population::new(&config, "test");
        let fitnesses = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        population
            .set_fitness(fitnesses.clone())
            .expect("one fitness per genome should be accepted");
        let genomes = population.all_genomes();
        assert_eq!(genomes.len(), fitnesses.len());
        for (genome, expected) in genomes.iter().zip(&fitnesses) {
            assert_eq!(genome.fitness, *expected);
        }
    }

    #[test]
    fn test_check_progress_making_progress() {
        let config = Config::default();
        let species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        // higher fitness
        let species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);

        let mut population = Population {
            top_genome: None,
            name: "test".to_owned(),
            config,
            generation: 0,
            top_fitness: 0.0,
            staleness: 0,
            species: vec![species1, species2],
        };

        population.check_progress();
        assert_eq!(population.staleness, 0);
        assert_eq!(population.top_fitness, 5.0);
    }

    #[test]
    fn test_check_progress_not_making_progress() {
        let config = Config::default();
        let species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        let species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);

        let mut population = Population {
            top_genome: None,
            config,
            name: "test".to_owned(),
            generation: 0,
            top_fitness: 6.0,
            staleness: 0,
            species: vec![species1, species2],
        };

        population.check_progress();
        assert_eq!(population.staleness, 1);
        assert_eq!(population.top_fitness, 6.0);
    }

    #[test]
    fn test_remove_stale_species() {
        let mut config = Config::default();
        config.stale_species_threshold = 1;
        config.stale_population_threshold = 1;

        // should become stale
        let mut species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        species1.top_fitness = 5.0;
        // should not become stale
        let mut species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);
        species2.top_fitness = 4.0;
        // should not become stale
        let mut species3 = species_with_fitnesses(&config, &[6.0, 1.0, 2.0]);
        species3.top_fitness = 4.0;

        let mut population = Population {
            config,
            name: "test".to_owned(),
            generation: 0,
            top_fitness: 5.9,
            staleness: 0,
            species: vec![species1, species2, species3],
            top_genome: None,
        };

        population.remove_stale_species();
        assert_eq!(population.species.len(), 2);
        assert_eq!(population.top_fitness, 6.0);
        assert_eq!(population.species[0].top_fitness, 5.0);
        assert_eq!(population.species[1].top_fitness, 6.0);
    }

    #[test]
    fn test_remove_stale_species_when_population_is_stale() {
        let mut config = Config::default();
        config.stale_species_threshold = 1;
        config.stale_population_threshold = 1;

        // should become stale
        let mut species1 = species_with_fitnesses(&config, &[4.0, 1.0, 2.0]);
        species1.top_fitness = 5.0;
        // should not become stale
        let mut species2 = species_with_fitnesses(&config, &[5.0, 1.0, 2.0]);
        species2.top_fitness = 4.0;
        // should not become stale
        let mut species3 = species_with_fitnesses(&config, &[6.0, 1.0, 2.0]);
        species3.top_fitness = 4.0;

        let mut population = Population::new(&config, "test");
        population.species = vec![species1, species2, species3];
        population.remove_stale_species();
        assert_eq!(population.species.len(), 2);
        assert_eq!(population.species[0].top_fitness, 6.0);
    }

    #[test]
    fn test_breed_new_generation() {
        let config = Config::default();
        let mut population = Population::new(&config, "test");
        population
            .breed_new_generation()
            .expect("breeding a freshly created population should succeed");
        assert_eq!(population.generation, 2);
    }

    #[test]
    fn test_calculate_pop_size_for_each_species() {
        let config = Config::default();

        // only one species: it takes the whole population
        assert_eq!(pop_sizes(&config, &[0.0], 100, 2), vec![100]);
        // two equally fit species split evenly
        assert_eq!(pop_sizes(&config, &[0.1, 0.1], 100, 2), vec![50, 50]);
        // a zero-fitness species keeps only the minimum
        assert_eq!(pop_sizes(&config, &[0.0, 0.1], 100, 2), vec![2, 98]);
        assert_eq!(pop_sizes(&config, &[0.0, 0.1], 4, 2), vec![2, 2]);
        // minimums win even when they exceed the population size
        assert_eq!(pop_sizes(&config, &[0.0, 0.1], 3, 2), vec![2, 2]);
        assert_eq!(pop_sizes(&config, &[0.0, 0.1], 3, 1), vec![1, 2]);
        assert_eq!(pop_sizes(&config, &[0.0, 0.1], 99, 2), vec![2, 97]);
    }
}