1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#![allow(unused)]
#![deny(missing_docs)]
mod neat;
mod genotype;
mod traits;
mod random;
pub use neat::*;
pub use genotype::*;
pub use traits::*;
/// Integration tests that evolve genomes against the `gym` crate's CartPole-v0 environment.
///
/// NOTE(review): these tests need a working `gym` backend (presumably a Python gym
/// installation reachable by `GymClient`) — confirm before running in CI.
#[cfg(test)]
mod tests {
    use super::*;
    use gym::{GymClient, SpaceData::DISCRETE};

    /// Copies the first four components of a CartPole observation into a network input array.
    ///
    /// # Panics
    /// Panics if the observation yields fewer than four components, matching the previous
    /// behaviour of indexing the observation directly.
    fn to_input<'a>(observation: impl IntoIterator<Item = &'a f64>) -> [f64; 4] {
        let mut input = [0f64; 4];
        let mut copied = 0;
        for (dst, &src) in input.iter_mut().zip(observation) {
            *dst = src;
            copied += 1;
        }
        assert_eq!(
            copied,
            input.len(),
            "observation has fewer than 4 components"
        );
        input
    }

    /// Runs one CartPole-v0 episode driven by `genome` and returns the summed reward.
    ///
    /// At each step the genome's first output selects the action: a negative value
    /// chooses action `0`, anything else chooses action `1`. When `display` is true the
    /// environment is rendered after every step and the final fitness is printed.
    ///
    /// # Panics
    /// Panics if the environment fails to reset or step, or if an observation is not a
    /// box space with at least four components.
    fn calculate_cart(genome: &impl Gene, display: bool) -> f64 {
        let client = GymClient::default();
        let env = client.make("CartPole-v0");
        let initial = env.reset().unwrap().get_box().unwrap();
        let mut input = to_input(&initial);
        let mut fitness = 0.0;
        loop {
            let pred = genome.predict(&input)[0];
            let action = DISCRETE(if pred < 0. { 0 } else { 1 });
            let state = env.step(&action).unwrap();
            if display {
                env.render();
            }
            // Count the reward of every step, including the one that ends the episode,
            // so the fitness equals the environment's episode return. Previously the
            // terminal step's reward was skipped.
            fitness += state.reward;
            if state.is_done {
                break;
            }
            input = to_input(&state.observation.get_box().unwrap());
        }
        env.close();
        if display {
            println!("Fitness = {}", fitness);
        }
        fitness
    }

    /// Smoke test: evolves a population for 100 generations on CartPole-v0.
    ///
    /// It makes no assertions; it only checks that evolution runs without panicking.
    #[test]
    fn test_neat() {
        // NOTE(review): presumably (inputs = 4, outputs = 1, population = 100, rate = 0.01).
        // The 4 inputs / 1 output match `calculate_cart`; the last two are unconfirmed.
        let mut neat = Neat::<Genotype>::new(4, 1, 100, 0.01);
        for i in 1..=100 {
            println!("---------Gen #{}--------", i);
            neat.next_generation(calculate_cart);
        }
    }
}