1#![allow(unused)]
2#![deny(missing_docs)]
3
4mod neat;
9mod genotype;
10mod traits;
11mod random;
12pub use neat::*;
13pub use genotype::*;
14pub use traits::*;
15
#[cfg(test)]
mod tests {
    //! Training runs that evolve `Genotype` networks against gym environments.
    //!
    //! NOTE(review): these tests drive real environments through `gym::GymClient`,
    //! so they presumably need that backend to be installed and are slow (100
    //! generations each); they assert nothing and act as training demos.

    use super::*;

    use std::ops::Index;
    use std::{thread, time};

    use gym::{GymClient, SpaceData::DISCRETE};

    /// Size of the `CartPole-v1` observation fed to the network.
    const CART_INPUTS: usize = 4;
    /// Size of the `MsPacman-ram-v0` observation fed to the network.
    const PACMAN_INPUTS: usize = 128;
    /// Number of network outputs considered when choosing a Pac-Man action.
    const PACMAN_ACTIONS: usize = 6;

    /// Activation function passed to `Gene::predict` by both fitness functions.
    fn tanh(x: f64) -> f64 {
        x.tanh()
    }

    /// Overwrites every slot of `input` with the element at the same index of
    /// `observation`.
    ///
    /// The loop bound comes from `input`, so the whole network input is always
    /// refreshed, whatever its size.
    ///
    /// # Panics
    ///
    /// Panics if `observation` has fewer elements than `input`.
    fn fill_input<O>(input: &mut [f64], observation: &O)
    where
        O: Index<usize, Output = f64>,
    {
        for (i, slot) in input.iter_mut().enumerate() {
            *slot = observation[i];
        }
    }

    /// Plays one `CartPole-v1` episode with `genome` as the policy and returns
    /// the summed reward of every step, including the terminating one.
    ///
    /// The first network output selects the action: `0` when it is negative,
    /// `1` otherwise. When `display` is true, each step is rendered and the
    /// final fitness is printed.
    ///
    /// # Panics
    ///
    /// Panics if the environment fails to reset or step, or if an observation
    /// is not a box space.
    fn calculate_cart(genome: &impl Gene, display: bool) -> f64 {
        let client = GymClient::default();
        let env = client.make("CartPole-v1");

        let mut input = [0f64; CART_INPUTS];
        fill_input(&mut input, &env.reset().unwrap().get_box().unwrap());

        let mut fitness = 0.0;
        loop {
            let pred = genome.predict(&input, tanh)[0];
            let action = DISCRETE(if pred < 0. { 0 } else { 1 });
            let state = env.step(&action).unwrap();
            if display {
                env.render();
            }
            // Accumulate before the done check so the final step's reward counts.
            fitness += state.reward;
            if state.is_done {
                break;
            }
            fill_input(&mut input, &state.observation.get_box().unwrap());
        }
        env.close();

        if display {
            println!("Fitness = {}", fitness);
        }

        fitness
    }

    /// Evolves a cart-pole controller for 100 generations.
    #[test]
    fn test_cart() {
        // The input/output counts must agree with CART_INPUTS and the single
        // output read by `calculate_cart`.
        let mut neat = Neat::<Genotype>::new(4, 1, 1000, 0.1);

        for generation in 1..=100 {
            println!("---------Gen #{}--------", generation);
            let (scores, total_score) = neat.calculate_fitness(calculate_cart);
            neat.next_generation(&scores, total_score);
        }
    }

    /// Plays one `MsPacman-ram-v0` episode with `genome` as the policy.
    ///
    /// The action is the index of the largest of the first `PACMAN_ACTIONS`
    /// network outputs (earliest index wins ties). Returns
    /// `steps * (1 + score)`, where `steps` counts every environment step and
    /// `score` is the summed reward, both including the terminating step. When
    /// `display` is true, each step is rendered, followed by a 200 ms pause, and
    /// the totals are printed.
    ///
    /// # Panics
    ///
    /// Panics if the environment fails to reset or step, if an observation is
    /// not a box space with at least `PACMAN_INPUTS` elements, or if the network
    /// yields fewer than `PACMAN_ACTIONS` outputs.
    fn calculate_pacman(genome: &impl Gene, display: bool) -> f64 {
        let client = GymClient::default();
        let env = client.make("MsPacman-ram-v0");

        // Every one of the PACMAN_INPUTS observation values is refreshed each step.
        let mut input = [0f64; PACMAN_INPUTS];
        fill_input(&mut input, &env.reset().unwrap().get_box().unwrap());

        // Named `steps` (not `time`) so it cannot be mistaken for the `time` module.
        let mut steps = 0.0;
        let mut score = 0.0;
        loop {
            let preds = genome.predict(&input, tanh);
            // Strict `>` keeps the earliest index on ties, as a manual scan would.
            let best = (1..PACMAN_ACTIONS).fold(0, |best, i| {
                if preds[i] > preds[best] {
                    i
                } else {
                    best
                }
            });

            let state = env.step(&DISCRETE(best)).unwrap();
            if display {
                env.render();
                thread::sleep(time::Duration::from_millis(200));
            }
            // Count the terminating step too before checking for the end of the episode.
            steps += 1.;
            score += state.reward;
            if state.is_done {
                break;
            }
            fill_input(&mut input, &state.observation.get_box().unwrap());
        }
        env.close();

        if display {
            println!("Time = {}, Score = {}", steps, score);
        }

        steps * (1. + score)
    }

    /// Evolves a Ms. Pac-Man controller for 100 generations.
    #[test]
    fn test_pacman() {
        // The input/output counts must agree with PACMAN_INPUTS and PACMAN_ACTIONS.
        let mut neat = Neat::<Genotype>::new(128, 6, 200, 1.);

        for generation in 1..=100 {
            println!("---------Gen #{}--------", generation);
            let (scores, total_score) = neat.calculate_fitness(calculate_pacman);
            neat.next_generation(&scores, total_score);
        }
    }
}