1
2extern crate rand;
3extern crate slow_nn;
4
5use super::traits::*;
6use super::random::*;
7
/// Insertion into an index-addressed container that may have gaps.
///
/// Calling `sparse_insert(index, val)` stores `val` at position `index`,
/// growing the container first when `index` is past its current end.
trait SparseInsert<T> {
    /// Stores `val` at `index`, extending the container if it is too short.
    fn sparse_insert(&mut self, index: usize, val: T);
}
11
12impl<T> SparseInsert<T> for Vec<Option<T>> {
13 fn sparse_insert(&mut self, index: usize, val: T) {
14 while self.len() <= index {
15 self.push(None);
16 }
17 self[index] = Some(val);
18 }
19}
20
/// A node gene.
///
/// Nodes are ordered by [`Node::value`]. Bias and input nodes sit at the
/// bottom (0), output nodes at the top (`u128::MAX`), and each hidden node
/// carries the position it was given between two existing nodes.
#[derive(Debug, Clone)]
enum Node {
    /// An input node.
    Input,
    /// The bias node.
    Bias,
    /// A hidden node; the payload is its position in the node ordering.
    Hidden(u128),
    /// An output node.
    Output,
}
28
29impl Node {
30 fn bias() -> Self {
31 Node::Bias
32 }
33
34 fn input() -> Self{
35 Node::Input
36 }
37
38 fn output() -> Self {
39 Node::Output
40 }
41
42 fn is_hidden(&self) -> bool {
43 match *self {
44 Node::Hidden(_) => true,
45 _ => false
46 }
47 }
48
49 fn value(&self) -> u128 {
50 match *self {
51 Node::Input | Node::Bias => 0,
52 Node::Hidden(val) => val,
53 Node::Output => std::u128::MAX
54 }
55 }
56
57 fn can_connect_to(&self, to: &Self) -> bool {
58 self.value() < to.value()
59 }
60
61 fn new_hidden(from: &Self, to: &Self) -> Option<Self> {
62 let val1 = from.value();
63 let val2 = to.value();
64
65 let mid = val1 + (val2 - val1) / 2;
66
67 if val1 < mid && mid < val2 {
68 Some(Node::Hidden(mid))
69 } else {
70 None
71 }
72 }
73}
74
/// Whether a connection gene is currently switched on or off.
#[derive(Debug, Clone)]
enum ConnectionState {
    /// The connection is switched off.
    Disabled,
    /// The connection is switched on.
    Enabled,
}
80
81use ConnectionState::*;
82
83impl ConnectionState {
84 fn toggle(&mut self) {
85 *self = match *self {
86 Enabled => Disabled,
87 Disabled => Enabled
88 };
89 }
90
91 fn enable(&mut self) {
92 *self = Enabled;
93 }
94
95 fn disable(&mut self) {
96 *self = Disabled;
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct Connection {
103 from: usize,
104 to: usize,
105 weight: f64,
106 state: ConnectionState
107}
108
109impl Connection {
110 fn new(from: usize, to: usize, weight: f64) -> Self {
111 Self {
112 from,
113 to,
114 weight,
115 state: Enabled
116 }
117 }
118
119 fn disable(&mut self) {
120 self.state.disable();
121 }
122
123 fn enable(&mut self) {
124 self.state.enable();
125 }
126
127 fn toggle(&mut self) {
128 self.state.toggle();
129 }
130
131 fn shift_weight(&mut self) {
132 self.weight *= 0.95;
133 }
134
135 fn change_weight(&mut self, weight: f64) {
136 self.weight = weight;
137 }
138 pub fn is_enabled(&self) -> bool {
140 match self.state {
141 Enabled => true,
142 Disabled => false
143 }
144 }
145}
146
147#[derive(Debug)]
149pub struct Genotype {
150 nodes: Vec<Option<Node>>,
151 conns: Vec<Option<Connection>>,
152 bias: f64,
153 inputs: usize,
154 outputs: usize
155}
156
157impl Genotype {
158 fn distance_from(&self, other: &Self) -> f64 {
159 let mut disjoint_genes = 0.;
160 let mut delta_w = 0.;
161 let mut excess_genes = 0.;
162
163 let mut n1: f64 = 0.;
164 let mut n2: f64 = 0.;
165
166 if self.conns.len() > other.conns.len() {
167 for conn in self.conns.iter().skip(other.conns.len()) {
168 if let Some(_) = conn.as_ref() {
169 excess_genes += 1.;
170 }
171 }
172 n1 += excess_genes;
173 } else if self.conns.len() < other.conns.len() {
174 for conn in other.conns.iter().skip(self.conns.len()) {
175 if let Some(_) = conn.as_ref() {
176 excess_genes += 1.;
177 }
178 }
179 n2 += excess_genes;
180 }
181
182 for (conn1, conn2) in self.conns.iter().zip(other.conns.iter()) {
183 match (&conn1, &conn2) {
184 (Some(connection1), Some(connection2)) => {
185 delta_w += (connection1.weight - connection2.weight).abs();
186 n1 += 1.;
187 n2 += 1.;
188 },
189 (Some(_), None) => {
190 disjoint_genes += 1.;
191 n1 += 1.;
192 },
193 (None, Some(_)) => {
194 disjoint_genes += 1.;
195 n2 += 1.;
196 },
197 _ => {}
198 }
199 }
200
201 let mut n = n1.max(n2);
202
203 if n < 20. {
204 n = 1.;
205 }
206
207 excess_genes/n + disjoint_genes/n + 3.*delta_w
208 }
209
210 fn change_bias(&mut self) {
211 self.bias *= 95.;
212 }
213
214 fn new_bias(&mut self) {
215 self.bias = random_bias();
216 }
217
218 fn add_connection<T: GlobalNeatCounter>(&mut self, neat: &mut T) {
219 for _ in 0..100 {
220 let from = randint(self.nodes.len());
221 let to = randint(self.nodes.len());
222
223 if let (Some(node1), Some(node2)) = (&self.nodes[from], &self.nodes[to]) {
224 if node1.can_connect_to(node2) {
225 if let Some(innov) = neat.try_adding_connection(from, to) {
226 let new_connection = Connection::new(from, to, random_weight());
227 self.conns.sparse_insert(innov, new_connection);
228 break;
229 }
230 } else if node2.can_connect_to(node1) {
231 if let Some(innov) = neat.try_adding_connection(to, from) {
232 let new_connection = Connection::new(to, from, random_weight());
233 self.conns.sparse_insert(innov, new_connection);
234 break;
235 }
236 break;
237 }
238 }
239 }
240 }
241
242 fn add_node<T: GlobalNeatCounter>(&mut self, neat: &mut T) {
243 if self.conns.len() == 0 {
244 return;
245 }
246 for _ in 0..100 {
247 let index = randint(self.conns.len());
248
249 if self.conns[index].is_none() {
251 continue;
252 }
253
254 if let Disabled = self.conns[index].as_ref().unwrap().state {
255 continue;
256 }
257
258 let connection = self.conns[index].as_ref().unwrap();
259 let from = connection.from;
260 let to = connection.to;
261 let node1 = self.nodes[from].as_ref()
264 .expect("How can the node not exist when connection to this node does?");
265 let node2 = self.nodes[to].as_ref()
266 .expect("How can the node not exist when connection to this node does?");
267
268 if let Some(new_node) = Node::new_hidden(node1, node2) {
269 let new_index = neat.get_new_node();
270 self.nodes.sparse_insert(new_index, new_node);
271
272 let innov = neat.try_adding_connection(from, new_index)
273 .expect("How can this new node already have a connection?");
274 let connection = Connection::new(from, new_index, random_weight());
275 self.conns.sparse_insert(innov, connection);
276
277 let innov = neat.try_adding_connection(new_index, to)
278 .expect("How can this new node already have a connection?");
279 let connection = Connection::new(new_index, to, random_weight());
280 self.conns.sparse_insert(innov, connection);
281
282 self.conns[index].as_mut().unwrap().disable();
283
284 break;
285 }
286 }
287 }
288
289 pub fn get_network(&self) -> slow_nn::Network {
291 let connections: Vec<_> = self
292 .conns
293 .iter()
294 .filter(|c| c.is_some())
295 .map(|c| match c.as_ref() {
296 Some(conns) => (conns.from, conns.to, conns.weight).into(),
297 _ => panic!("this line will never be reached"),
298 })
299 .collect();
300
301 let inputs = self.inputs;
302 let outputs = self.outputs;
303 let hidden = self.nodes.len() - 1 - inputs - outputs;
304
305 slow_nn::Network::from_conns(self.bias, inputs, outputs, hidden, &connections)
306 }
307}
308
309impl Gene for Genotype {
310 fn empty(inputs: usize, outputs: usize) -> Self {
311 let nodes = (0..1).map(|_| Some(Node::bias()))
312 .chain((0..inputs).map(|_| Some(Node::input())))
313 .chain((0..outputs).map(|_| Some(Node::output())))
314 .collect();
315 Self {
316 nodes,
317 conns: Vec::new(),
318 bias: random_bias(),
319 inputs: inputs,
320 outputs: outputs
321 }
322 }
323
324 fn is_same_species_as(&self, other: &Self) -> bool {
325 self.distance_from(other) < 4.
326 }
327
328 fn cross(&self, other: &Self) -> Self {
329 let mut nodes: Vec<_> = self
330 .nodes
331 .iter()
332 .take_while(|x| x.is_some() && !x.as_ref().unwrap().is_hidden())
333 .cloned()
334 .collect();
335
336 let mut add_nodes = |from, to| {
337 nodes.sparse_insert(from, self.nodes[from].clone().unwrap());
338 nodes.sparse_insert(to, self.nodes[to].clone().unwrap());
339 };
340
341 let mut conns = Vec::new();
342 let bias = self.bias;
343
344 let len = (self.conns.len() as i32).min(other.conns.len() as i32) as usize;
345
346 for i in 0..len {
347 let new_conn = match (&self.conns[i], &other.conns[i]) {
348 (Some(conn1), Some(conn2)) => {
349 if random::<f64>() < 0.8 {
350 add_nodes(conn1.from, conn1.to);
351 Some(conn1.clone())
352 } else {
353 add_nodes(conn2.from, conn2.to);
354 Some(conn2.clone())
355 }
356 },
357 (Some(conn), None) => {
358 add_nodes(conn.from, conn.to);
359 Some(conn.clone())
360 },
361 _ => {
362 None
363 }
364 };
365 conns.push(new_conn);
366 }
367
368 for maybe_conn in self.conns.iter().skip(len) {
369 if let Some(conn) = maybe_conn {
370 add_nodes(conn.from, conn.to);
371 conns.push(Some((*conn).clone()));
372 } else {
373 conns.push(None);
374 }
375 }
376
377 Self {
378 nodes,
379 conns,
380 bias,
381 inputs: self.inputs,
382 outputs: self.outputs
383 }
384 }
385
386 fn mutate<T: GlobalNeatCounter>(&mut self, neat: &mut T) {
387 match randint(100) {
388 0..=2 => self.add_node(neat),
389 3 => self.new_bias(),
390 4 => self.change_bias(),
391 5..=34 => self.add_connection(neat),
392 34..=40 if self.conns.len() >= 1 => {
393 let index = randint(self.conns.len());
394 if let Some(connection) = self.conns[index].as_mut() {
395 match randint(100) {
396 0..=1 => connection.shift_weight(),
397 2..=3 => connection.change_weight(random_weight()),
398 _ => {}
399 }
400 }
401 }
402 _ => {}
403 }
404 }
405
406 fn predict(&self, input: &[f64], activate: fn(f64) -> f64) -> Vec<f64> {
407 let connections: Vec<_> = self
408 .conns
409 .iter()
410 .filter(|c| c.is_some())
411 .map(|c| match c.as_ref() {
412 Some(conns) => (conns.from, conns.to, conns.weight).into(),
413 _ => panic!("this line will never be reached"),
414 })
415 .collect();
416
417 let inputs = self.inputs;
418 let outputs = self.outputs;
419 let hidden = self.nodes.len() - 1 - inputs - outputs;
420
421 let net = slow_nn::Network::from_conns(self.bias, inputs, outputs, hidden, &connections);
422 net.predict(input, activate)
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Minimal global counter for tests.
    ///
    /// Each distinct `(from, to)` pair gets the next innovation number the
    /// first time it is requested, and `None` on later requests. Node ids are
    /// handed out sequentially after the fixed nodes.
    struct Neat {
        connections: HashSet<(usize, usize)>,
        nodes: usize,
    }

    impl Neat {
        fn new(inputs: usize, outputs: usize) -> Self {
            Self {
                connections: HashSet::new(),
                // The bias, input and output nodes occupy the first ids.
                nodes: 1 + inputs + outputs,
            }
        }
    }

    impl GlobalNeatCounter for Neat {
        fn try_adding_connection(&mut self, from: usize, to: usize) -> Option<usize> {
            let innov_num = self.connections.len();
            if self.connections.insert((from, to)) {
                Some(innov_num)
            } else {
                None
            }
        }

        fn get_new_node(&mut self) -> usize {
            let new_node = self.nodes;
            self.nodes += 1;
            new_node
        }
    }

    /// Asserts the structural invariants every genome must keep.
    ///
    /// The fixed nodes come first and are not hidden, and every connection
    /// joins two existing nodes in increasing node order.
    fn assert_well_formed(genome: &Genotype) {
        let fixed = 1 + genome.inputs + genome.outputs;
        assert!(genome.nodes.len() >= fixed);
        for node in &genome.nodes[..fixed] {
            let node = node.as_ref().expect("fixed node missing");
            assert!(!node.is_hidden());
        }
        for conn in genome.conns.iter().filter_map(|c| c.as_ref()) {
            let from = genome
                .nodes
                .get(conn.from)
                .and_then(|n| n.as_ref())
                .expect("connection source node missing");
            let to = genome
                .nodes
                .get(conn.to)
                .and_then(|n| n.as_ref())
                .expect("connection target node missing");
            assert!(from.can_connect_to(to));
        }
    }

    #[test]
    fn test_node() {
        let bias = Node::bias();
        let input = Node::input();
        let output = Node::output();

        assert!(!bias.is_hidden() && !input.is_hidden() && !output.is_hidden());
        assert!(input.can_connect_to(&output));
        assert!(!output.can_connect_to(&input));
        // Bias and input nodes share the bottom position, so neither feeds the other.
        assert!(!bias.can_connect_to(&input) && !input.can_connect_to(&bias));

        let hidden = Node::new_hidden(&input, &output).unwrap();
        let hidden1 = Node::new_hidden(&input, &hidden).unwrap();
        let hidden2 = Node::new_hidden(&hidden1, &output).unwrap();
        let hidden3 = Node::new_hidden(&hidden1, &hidden).unwrap();

        // Each new hidden node lands strictly between its two endpoints.
        for &(lo, mid, hi) in &[
            (&input, &hidden, &output),
            (&input, &hidden1, &hidden),
            (&hidden1, &hidden2, &output),
            (&hidden1, &hidden3, &hidden),
        ] {
            assert!(mid.is_hidden());
            assert!(lo.can_connect_to(mid) && mid.can_connect_to(hi));
        }

        // Adjacent or identical positions leave no room for a new node.
        assert!(Node::new_hidden(&Node::Hidden(1), &Node::Hidden(2)).is_none());
        assert!(Node::new_hidden(&input, &bias).is_none());
    }

    #[test]
    fn test_genome() {
        let mut genome1 = Genotype::empty(3, 2);
        let mut genome2 = Genotype::empty(3, 2);
        let mut neat = Neat::new(3, 2);

        for _ in 0..1000 {
            genome1.mutate(&mut neat);
            genome2.mutate(&mut neat);
        }
        assert_well_formed(&genome1);
        assert_well_formed(&genome2);

        // Crossing in either direction must keep the child well-formed.
        let child1 = genome1.cross(&genome2);
        let child2 = genome2.cross(&genome1);
        assert_well_formed(&child1);
        assert_well_formed(&child2);

        // A child's connection list is as long as its primary parent's.
        assert_eq!(child1.conns.len(), genome1.conns.len());
        assert_eq!(child2.conns.len(), genome2.conns.len());

        // A genome is always compatible with itself.
        assert!(genome1.is_same_species_as(&genome1));
        assert!(child1.is_same_species_as(&child1));
    }
}