neat_rs/
genotype.rs

1
2extern crate rand;
3extern crate slow_nn;
4
5use super::traits::*;
6use super::random::*;
7
/// Index-addressed insertion into a growable container of optional slots.
trait SparseInsert<T> {
    /// Stores `val` at `index`, growing the container first if `index` is past its end.
    fn sparse_insert(&mut self, index: usize, val: T);
}

impl<T> SparseInsert<T> for Vec<Option<T>> {
    /// Pads the vector with `None` up to and including `index`, then overwrites
    /// that slot with `Some(val)`. An existing value at `index` is replaced.
    fn sparse_insert(&mut self, index: usize, val: T) {
        // The inclusive range is empty when the slot already exists, so nothing is appended.
        let missing_slots = self.len()..=index;
        self.extend(missing_slots.map(|_| None));
        self[index] = Some(val);
    }
}
20
/// Node gene. Each node maps to a `u128` ordering key (see `Node::value`);
/// connections are only allowed from a lower key to a strictly higher key.
#[derive(Debug, Clone)]
enum Node {
    Input,
    Bias,
    /// Hidden node carrying its ordering key.
    Hidden(u128),
    Output
}

impl Node {
    /// Creates the bias node.
    fn bias() -> Self {
        Node::Bias
    }

    /// Creates an input node.
    fn input() -> Self {
        Node::Input
    }

    /// Creates an output node.
    fn output() -> Self {
        Node::Output
    }

    /// Returns `true` only for `Node::Hidden`.
    fn is_hidden(&self) -> bool {
        if let Node::Hidden(_) = *self {
            true
        } else {
            false
        }
    }

    /// Ordering key of the node: `0` for inputs and the bias, the stored value
    /// for hidden nodes and `u128::MAX` for outputs.
    fn value(&self) -> u128 {
        match *self {
            Node::Input | Node::Bias => 0,
            Node::Hidden(val) => val,
            Node::Output => std::u128::MAX
        }
    }

    /// A connection `self -> to` is allowed only when `self` has a strictly
    /// smaller ordering key, so nodes with equal keys (e.g. two inputs) never connect.
    fn can_connect_to(&self, to: &Self) -> bool {
        self.value() < to.value()
    }

    /// Creates a hidden node whose key lies strictly between the keys of `from` and `to`.
    ///
    /// Returns `None` when no such key exists: the keys are equal or adjacent,
    /// or `from` is ordered after `to`.
    fn new_hidden(from: &Self, to: &Self) -> Option<Self> {
        let low = from.value();
        let high = to.value();

        // `checked_sub` rejects `from` ordered after `to`; a plain `high - low`
        // would underflow there (panic in debug builds, wrap-around in release).
        let gap = high.checked_sub(low)?;
        let mid = low + gap / 2;

        if low < mid && mid < high {
            Some(Node::Hidden(mid))
        } else {
            None
        }
    }
}
74
/// Whether a connection gene is currently switched on or off.
#[derive(Debug, Clone)]
enum ConnectionState {
    Disabled,
    Enabled
}

use ConnectionState::*;

impl ConnectionState {
    /// Flips the state: `Enabled` becomes `Disabled` and vice versa.
    fn toggle(&mut self) {
        let flipped = if let Enabled = *self {
            Disabled
        } else {
            Enabled
        };
        *self = flipped;
    }

    /// Forces the state to `Enabled`, whatever it was before.
    fn enable(&mut self) {
        *self = Enabled;
    }

    /// Forces the state to `Disabled`, whatever it was before.
    fn disable(&mut self) {
        *self = Disabled;
    }
}
99
100/// Connection gene provided by the Genome struct that can be used to build the neural network
101#[derive(Debug, Clone)]
102pub struct Connection {
103    from: usize,
104    to: usize,
105    weight: f64,
106    state: ConnectionState
107}
108
109impl Connection {
110    fn new(from: usize, to: usize, weight: f64) -> Self {
111        Self {
112            from,
113            to,
114            weight,
115            state: Enabled
116        }
117    }
118
119    fn disable(&mut self) {
120        self.state.disable();
121    }
122
123    fn enable(&mut self) {
124        self.state.enable();
125    }
126
127    fn toggle(&mut self) {
128        self.state.toggle();
129    }
130
131    fn shift_weight(&mut self) {
132        self.weight *= 0.95;
133    }
134
135    fn change_weight(&mut self, weight: f64) {
136        self.weight = weight;
137    }
138    /// Checks if the connection is enabled of not
139    pub fn is_enabled(&self) -> bool {
140        match self.state {
141            Enabled => true,
142            Disabled => false
143        }
144    }
145}
146
147/// Genotype representation of the network
148#[derive(Debug)]
149pub struct Genotype {
150    nodes: Vec<Option<Node>>,
151    conns: Vec<Option<Connection>>,
152    bias: f64,
153    inputs: usize,
154    outputs: usize
155}
156
157impl Genotype {
158    fn distance_from(&self, other: &Self) -> f64 {
159        let mut disjoint_genes = 0.;
160        let mut delta_w = 0.;
161        let mut excess_genes = 0.;
162
163        let mut n1: f64 = 0.;
164        let mut n2: f64 = 0.;
165
166        if self.conns.len() > other.conns.len() {
167            for conn in self.conns.iter().skip(other.conns.len()) {
168                if let Some(_) = conn.as_ref() {
169                    excess_genes += 1.;
170                }
171            }
172            n1 += excess_genes;
173        } else if self.conns.len() < other.conns.len() {
174            for conn in other.conns.iter().skip(self.conns.len()) {
175                if let Some(_) = conn.as_ref() {
176                    excess_genes += 1.;
177                }
178            }
179            n2 += excess_genes;
180        }
181
182        for (conn1, conn2) in self.conns.iter().zip(other.conns.iter()) {
183            match (&conn1, &conn2) {
184                (Some(connection1), Some(connection2)) => {
185                    delta_w += (connection1.weight - connection2.weight).abs();
186                    n1 += 1.;
187                    n2 += 1.;
188                },
189                (Some(_), None) => {
190                    disjoint_genes += 1.;
191                    n1 += 1.;
192                },
193                (None, Some(_)) => {
194                    disjoint_genes += 1.;
195                    n2 += 1.;
196                },
197                _ => {}
198            }
199        }
200
201        let mut n = n1.max(n2);
202
203        if n < 20. {
204            n = 1.;
205        }
206
207        excess_genes/n + disjoint_genes/n + 3.*delta_w
208    }
209
210    fn change_bias(&mut self) {
211        self.bias *= 95.;
212    }
213
214    fn new_bias(&mut self) {
215        self.bias = random_bias();
216    }
217
218    fn add_connection<T: GlobalNeatCounter>(&mut self, neat: &mut T) {
219        for _ in 0..100 {
220            let from = randint(self.nodes.len());
221            let to = randint(self.nodes.len());
222
223            if let (Some(node1), Some(node2)) = (&self.nodes[from], &self.nodes[to]) {
224                if node1.can_connect_to(node2) {
225                    if let Some(innov) = neat.try_adding_connection(from, to) {
226                        let new_connection = Connection::new(from, to, random_weight());
227                        self.conns.sparse_insert(innov, new_connection);
228                        break;
229                    }
230                } else if node2.can_connect_to(node1) {
231                    if let Some(innov) = neat.try_adding_connection(to, from) {
232                        let new_connection = Connection::new(to, from, random_weight());
233                        self.conns.sparse_insert(innov, new_connection);
234                        break;
235                    }
236                    break;
237                }
238            }
239        }
240    }
241
242    fn add_node<T: GlobalNeatCounter>(&mut self, neat: &mut T) {
243        if self.conns.len() == 0 {
244            return;
245        }
246        for _ in 0..100 {
247            let index = randint(self.conns.len());
248
249            // Check if this connection exists in this genome
250            if self.conns[index].is_none() {
251                continue;
252            }
253
254            if let Disabled = self.conns[index].as_ref().unwrap().state {
255                continue;
256            }
257
258            let connection = self.conns[index].as_ref().unwrap();
259            let from = connection.from;
260            let to = connection.to;
261            // let weight = connection.weight;
262
263            let node1 = self.nodes[from].as_ref()
264                .expect("How can the node not exist when connection to this node does?");
265            let node2 = self.nodes[to].as_ref()
266                .expect("How can the node not exist when connection to this node does?");
267            
268            if let Some(new_node) = Node::new_hidden(node1, node2) {
269                let new_index = neat.get_new_node();
270                self.nodes.sparse_insert(new_index, new_node);
271
272                let innov = neat.try_adding_connection(from, new_index)
273                    .expect("How can this new node already have a connection?");
274                let connection = Connection::new(from, new_index, random_weight());
275                self.conns.sparse_insert(innov, connection);
276
277                let innov = neat.try_adding_connection(new_index, to)
278                    .expect("How can this new node already have a connection?");
279                let connection = Connection::new(new_index, to, random_weight());
280                self.conns.sparse_insert(innov, connection);
281
282                self.conns[index].as_mut().unwrap().disable();
283
284                break;
285            }
286        }
287    }
288
289    /// Creates and returns the network corresponding to the genotype
290    pub fn get_network(&self) -> slow_nn::Network {
291        let connections: Vec<_> = self
292            .conns
293            .iter()
294            .filter(|c| c.is_some())
295            .map(|c| match c.as_ref() {
296                Some(conns) => (conns.from, conns.to, conns.weight).into(),
297                _ => panic!("this line will never be reached"),
298            })
299            .collect();
300    
301        let inputs = self.inputs;
302        let outputs = self.outputs;
303        let hidden = self.nodes.len() - 1 - inputs - outputs;
304        
305        slow_nn::Network::from_conns(self.bias, inputs, outputs, hidden, &connections)
306    }
307}
308
309impl Gene for Genotype {
310    fn empty(inputs: usize, outputs: usize) -> Self {
311        let nodes = (0..1).map(|_| Some(Node::bias()))
312            .chain((0..inputs).map(|_| Some(Node::input())))
313            .chain((0..outputs).map(|_| Some(Node::output())))
314            .collect();
315        Self {
316            nodes,
317            conns: Vec::new(),
318            bias: random_bias(),
319            inputs: inputs,
320            outputs: outputs
321        }
322    }
323
324    fn is_same_species_as(&self, other: &Self) -> bool {
325        self.distance_from(other) < 4.
326    }
327
328    fn cross(&self, other: &Self) -> Self {
329        let mut nodes: Vec<_> = self
330            .nodes
331            .iter()
332            .take_while(|x| x.is_some() && !x.as_ref().unwrap().is_hidden())
333            .cloned()
334            .collect();
335        
336        let mut add_nodes = |from, to| {
337            nodes.sparse_insert(from, self.nodes[from].clone().unwrap());
338            nodes.sparse_insert(to, self.nodes[to].clone().unwrap());
339        };
340        
341        let mut conns = Vec::new();
342        let bias = self.bias;
343        
344        let len = (self.conns.len() as i32).min(other.conns.len() as i32) as usize;
345
346        for i in 0..len {
347            let new_conn = match (&self.conns[i], &other.conns[i]) {
348                (Some(conn1), Some(conn2)) => {
349                    if random::<f64>() < 0.8 {
350                        add_nodes(conn1.from, conn1.to);
351                        Some(conn1.clone())
352                    } else {
353                        add_nodes(conn2.from, conn2.to);
354                        Some(conn2.clone())
355                    }
356                },
357                (Some(conn), None) => {
358                    add_nodes(conn.from, conn.to);
359                    Some(conn.clone())
360                },
361                _ => {
362                    None
363                }
364            };
365            conns.push(new_conn);
366        }
367
368        for maybe_conn in self.conns.iter().skip(len) {
369            if let Some(conn) = maybe_conn {
370                add_nodes(conn.from, conn.to);
371                conns.push(Some((*conn).clone()));
372            } else {
373                conns.push(None);
374            }
375        }
376
377        Self {
378            nodes,
379            conns,
380            bias,
381            inputs: self.inputs,
382            outputs: self.outputs
383        }
384    }
385
386    fn mutate<T: GlobalNeatCounter>(&mut self, neat: &mut T) {
387        match randint(100) {
388            0..=2 => self.add_node(neat),
389            3 => self.new_bias(),
390            4 => self.change_bias(),
391            5..=34 => self.add_connection(neat),
392            34..=40 if self.conns.len() >= 1 => {
393                let index = randint(self.conns.len());
394                if let Some(connection) = self.conns[index].as_mut() {
395                    match randint(100) {
396                        0..=1 => connection.shift_weight(),
397                        2..=3 => connection.change_weight(random_weight()),
398                        _ => {}
399                    }
400                }
401            }
402            _ => {}
403        }
404    }
405
406    fn predict(&self, input: &[f64], activate: fn(f64) -> f64) -> Vec<f64> {
407        let connections: Vec<_> = self
408            .conns
409            .iter()
410            .filter(|c| c.is_some())
411            .map(|c| match c.as_ref() {
412                Some(conns) => (conns.from, conns.to, conns.weight).into(),
413                _ => panic!("this line will never be reached"),
414            })
415            .collect();
416        
417        let inputs = self.inputs;
418        let outputs = self.outputs;
419        let hidden = self.nodes.len() - 1 - inputs - outputs;
420        
421        let net = slow_nn::Network::from_conns(self.bias, inputs, outputs, hidden, &connections);
422        net.predict(input, activate)
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Minimal in-memory innovation counter used to drive mutations in tests.
    struct Neat {
        connections: HashSet<(usize, usize)>,
        nodes: usize
    }

    impl Neat {
        /// Starts with no known connections and `1 + inputs + outputs` node ids taken.
        fn new(inputs: usize, outputs: usize) -> Self {
            Self {
                connections: HashSet::new(),
                nodes: 1 + inputs + outputs
            }
        }
    }

    impl GlobalNeatCounter for Neat {
        /// Hands out the next innovation number for an unseen `(from, to)` pair
        /// and `None` for a pair that was registered before.
        fn try_adding_connection(&mut self, from: usize, to: usize) -> Option<usize> {
            let innov_num = self.connections.len();
            if self.connections.insert((from, to)) {
                Some(innov_num)
            } else {
                None
            }
        }

        /// Returns the next unused node id.
        fn get_new_node(&mut self) -> usize {
            let new_node = self.nodes;
            self.nodes += 1;
            new_node
        }
    }

    #[test]
    fn test_node() {
        let input = Node::input();
        let output = Node::output();

        let hidden = Node::new_hidden(&input, &output).unwrap();
        let hidden1 = Node::new_hidden(&input, &hidden).unwrap();
        let hidden2 = Node::new_hidden(&hidden1, &output).unwrap();
        let hidden3 = Node::new_hidden(&hidden1, &hidden).unwrap();

        for node in [&hidden, &hidden1, &hidden2, &hidden3].iter() {
            assert!(node.is_hidden());
        }
        assert!(!input.is_hidden());
        assert!(!output.is_hidden());

        // Every new node must sit strictly between the two nodes it was created from.
        let ordered = [&input, &hidden1, &hidden3, &hidden, &hidden2, &output];
        for pair in ordered.windows(2) {
            assert!(pair[0].can_connect_to(pair[1]));
            assert!(!pair[1].can_connect_to(pair[0]));
        }

        // Nodes with equal ordering keys leave no room for a node in between.
        assert!(Node::new_hidden(&input, &Node::bias()).is_none());
    }

    #[test]
    fn test_genome() {
        let mut genome1 = Genotype::empty(3, 2);
        let genome2 = Genotype::empty(3, 2);
        let mut neat = Neat::new(3, 2);

        for _ in 0..1000 {
            genome1.mutate(&mut neat);
        }

        // Mutation never removes nodes.
        assert!(genome1.nodes.len() >= 1 + 3 + 2);

        // Every connection must reference nodes that exist in the genome.
        for conn in genome1.conns.iter().filter_map(|c| c.as_ref()) {
            assert!(genome1.nodes[conn.from].is_some());
            assert!(genome1.nodes[conn.to].is_some());
        }

        assert_eq!(genome2.distance_from(&genome2), 0.);
        assert_eq!(genome1.distance_from(&genome2), genome2.distance_from(&genome1));

        // Crossing with an empty genome keeps every innovation slot of the first parent.
        let child = genome1.cross(&genome2);
        assert_eq!(child.conns.len(), genome1.conns.len());
    }
}