neat_rs/
lib.rs

1#![allow(unused)]
2#![deny(missing_docs)]
3
4//! # neat-rs
5//! 
6//! Implementation of neat algorithm in rust
7
8mod neat;
9mod genotype;
10mod traits;
11mod random;
12pub use neat::*;
13pub use genotype::*;
14pub use traits::*;
15
#[cfg(test)]
mod tests {
    //! Training runs of the NEAT population against OpenAI Gym environments.
    //!
    //! NOTE(review): `GymClient::default()` presumably connects to an external
    //! gym server, so these tests need that service running — confirm before
    //! running them in CI. Each test trains for 100 generations.

    use super::*;

    use std::{thread, time};

    use gym::{GymClient, SpaceData::DISCRETE};

    /// Hyperbolic tangent, handed to `Gene::predict` (presumably as the
    /// network's activation function).
    fn tanh(x: f64) -> f64 {
        x.tanh()
    }

    /// Overwrites every slot of `dst` with the element at the same index of
    /// `obs`.
    ///
    /// Generic over any `usize`-indexable container yielding `f64`, so it
    /// accepts whatever `SpaceData::get_box` returns. For `Vec`/slice-like
    /// containers, an `obs` shorter than `dst` panics on the out-of-bounds
    /// index, just like the explicit index loops this replaces.
    fn load_observation<O>(dst: &mut [f64], obs: &O)
    where
        O: std::ops::Index<usize, Output = f64> + ?Sized,
    {
        for (i, slot) in dst.iter_mut().enumerate() {
            *slot = obs[i];
        }
    }

    /// Returns the index of the largest value among `preds[0..n]`.
    ///
    /// Ties resolve to the lowest index (strict `>`), and a NaN never
    /// replaces the current best. Reads `preds[0]` even when `n <= 1`.
    fn argmax<P>(preds: &P, n: usize) -> usize
    where
        P: std::ops::Index<usize, Output = f64> + ?Sized,
    {
        let mut best = 0;
        for i in 1..n {
            if preds[i] > preds[best] {
                best = i;
            }
        }
        best
    }

    /// Plays one `CartPole-v1` episode with `genome` and returns its fitness:
    /// the sum of all rewards received, the terminal step included.
    ///
    /// The first network output selects the action (negative → `0`,
    /// otherwise → `1`). When `display` is true the environment is rendered
    /// each step and the final fitness is printed.
    ///
    /// # Panics
    ///
    /// Panics if the gym client fails to reset or step the environment, or
    /// if `get_box` fails on an observation.
    fn calculate_cart(genome: &impl Gene, display: bool) -> f64 {
        let client = GymClient::default();
        let env = client.make("CartPole-v1");

        let mut input = [0f64; 4];
        let first = env
            .reset()
            .expect("failed to reset CartPole-v1")
            .get_box()
            .expect("CartPole-v1 observation is not a box");
        load_observation(&mut input, &first);

        let mut fitness = 0.0;
        loop {
            let pred = genome.predict(&input, tanh)[0];
            let action = DISCRETE(if pred < 0. { 0 } else { 1 });
            let state = env.step(&action).expect("failed to step CartPole-v1");
            if display {
                env.render();
            }
            // Accumulate before checking `is_done` so the terminal step's
            // reward is not dropped.
            fitness += state.reward;
            if state.is_done {
                break;
            }
            let obs = state
                .observation
                .get_box()
                .expect("CartPole-v1 observation is not a box");
            load_observation(&mut input, &obs);
        }
        env.close();

        if display {
            println!("Fitness = {}", fitness);
        }

        fitness
    }

    #[test]
    fn test_cart() {
        let mut neat = Neat::<Genotype>::new(4, 1, 1000, 0.1);

        for generation in 1..=100 {
            println!("---------Gen #{}--------", generation);
            let (scores, total_score) = neat.calculate_fitness(calculate_cart);
            neat.next_generation(&scores, total_score);
        }
    }

    /// Plays one `MsPacman-ram-v0` episode with `genome` and returns
    /// `steps * (1 + score)`, where `steps` counts every step taken and
    /// `score` is the sum of all rewards (terminal step included).
    ///
    /// The network reads the 128-value RAM observation. The action is the
    /// argmax of the first six outputs. NOTE(review): only discrete actions
    /// 0..=5 are reachable; confirm against the env's action space. When
    /// `display` is true the environment is rendered with a 200 ms delay per
    /// step and the totals are printed.
    ///
    /// # Panics
    ///
    /// Panics if the gym client fails to reset or step the environment, if
    /// `get_box` fails on an observation, or if `predict` returns fewer than
    /// six outputs.
    fn calculate_pacman(genome: &impl Gene, display: bool) -> f64 {
        let client = GymClient::default();
        let env = client.make("MsPacman-ram-v0");

        let mut input = [0f64; 128];
        let first = env
            .reset()
            .expect("failed to reset MsPacman-ram-v0")
            .get_box()
            .expect("MsPacman-ram-v0 observation is not a box");
        load_observation(&mut input, &first);

        let mut steps = 0.0;
        let mut score = 0.0;
        loop {
            let preds = genome.predict(&input, tanh);
            let action = DISCRETE(argmax(&preds, 6));
            let state = env
                .step(&action)
                .expect("failed to step MsPacman-ram-v0");
            if display {
                env.render();
                thread::sleep(time::Duration::from_millis(200));
            }
            // Accumulate before checking `is_done` so the terminal step is
            // counted too.
            steps += 1.;
            score += state.reward;
            if state.is_done {
                break;
            }
            let obs = state
                .observation
                .get_box()
                .expect("MsPacman-ram-v0 observation is not a box");
            // Refresh all 128 RAM values, not just a prefix, so the network
            // sees the current game state.
            load_observation(&mut input, &obs);
        }
        env.close();

        if display {
            println!("Time = {}, Score = {}", steps, score);
        }

        steps * (1. + score)
    }

    #[test]
    fn test_pacman() {
        let mut neat = Neat::<Genotype>::new(128, 6, 200, 1.);

        for generation in 1..=100 {
            println!("---------Gen #{}--------", generation);
            let (scores, total_score) = neat.calculate_fitness(calculate_pacman);
            neat.next_generation(&scores, total_score);
        }
    }
}