rusty_neat/
lib.rs

1#![allow(clippy::map_clone, clippy::needless_range_loop)]
2//#![feature(drain_filter)]
3
4use bincode::{serialize, deserialize};
5use serde::{Serialize, Deserialize};
6use std::fs::File;
7use std::io::prelude::*;
8
#[cfg(test)]
mod tests {
    use super::*;

    /// Evolves a small network and checks the public API's observable contract
    /// instead of the former tautological `assert_eq!(1, 1)`.
    #[test]
    fn it_works() {
        let mut net = NN::new(3, 2);
        for _ in 0..32 {
            net.mutate();
        }
        // every `mutate` call bumps the generation counter exactly once
        assert_eq!(net.generation, 32);

        // `set_chances` only overwrites the leading entries it is given
        let _ = net.get_chances();
        net.set_chances(&[0, 0, 0]);
        assert_eq!(net.get_chances(), &[0, 0, 0, 10, 10, 0, 0]);

        let out = net.forward(&[0.5, 0.2, 0.8]);
        assert_eq!(out.len(), 2, "forward must return one value per output node");
        assert!(out.iter().all(|v| v.is_finite()), "outputs must be finite: {:?}", out);

        println!("\nOrder: \n{:?}", net.layer_order);
        println!("\nConnections: \n{:?}", net.connections);
        println!("\nNodes: \n{:?}", net.nodes);
    }

    /// A network that was never mutated has no connections and no evaluation
    /// order, so its outputs stay at their initial value.
    #[test]
    fn fresh_network_outputs_zeros() {
        let mut net = NN::new(2, 3);
        assert_eq!(net.forward(&[1.0, -1.0]), vec![0.0, 0.0, 0.0]);
    }
}
30
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
33pub enum ActFunc {
34    Sigmoid,
35    Tanh,
36    ReLU,
37    None
38}
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40pub enum Genre {
41    Hidden,
42    Input,
43    Output,
44}
45
46
/// One neuron of the network.
///
/// Field order is part of the bincode format written by `NN::save`; do not
/// reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    value: f64, // value last computed or written by `NN::forward`; starts at 0.0
    pub bias: f64, // initialised in [-1, 1), kept clamped to [-1, 1] by mutations; always 0.0 for input nodes
    genre: Genre, // role of the node; serialised as 0 - hidden, 1 - input, 2 - output
    pub act_func: ActFunc, // activation; `ActFunc::None` for input nodes
    free_nodes: Vec<usize>, // node indices this node can still connect to without a cycle or duplicate edge (filled by `NN::free_nodes_calc`, stays empty for output nodes)
}
55impl Node {
56    // initializing to random values 
57    pub fn new(genre: Genre) -> Self { 
58        let mut af = match fastrand::usize(0..3) {
59            0 => ActFunc::Sigmoid,
60            1 => ActFunc::Tanh,
61            2 => ActFunc::ReLU,
62            _ => unreachable!(),
63        };
64        // input nodes don't have activation functions
65        if genre == Genre::Input {af = ActFunc::None;}
66
67        // input nodes don't have biases
68        let b = match genre {
69            Genre::Input => 0.0,
70            _ => fastrand::f64() * 2.0 - 1.0,
71        };
72
73        Self { 
74            value: 0.0, 
75            bias: b,
76            genre,
77            act_func: af,
78            free_nodes: vec![],
79        } 
80    }
81}
82impl Default for Node {
83    fn default() -> Self {Self::new(Genre::Hidden)}
84}
85
86
/// A directed, weighted edge between two nodes, addressed by index into
/// `NN::nodes`.
///
/// Field order is part of the bincode format written by `NN::save`; do not
/// reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Connection {
    pub from: usize, // index of the source node
    pub to: usize, // index of the target node
    pub weight: f64, // starts in [-1, 1); a weight of exactly 0.0 is how `m_connection_disable` switches a connection off
    pub active: bool // set to false when `m_node_add` splits this connection; inactive connections are skipped by `forward` but still counted by the ordering code
}
94
95impl Connection {
96    pub fn new(from: usize, to: usize) -> Self { 
97        Self { 
98            from,
99            to,
100            weight: fastrand::f64() * 2.0 - 1.0,
101            active: true,
102        } 
103    }
104}
105
/// A NEAT-style network: a list of nodes plus a list of directed connections,
/// evolved in place by `NN::mutate`.
///
/// Field order is part of the bincode format written by `NN::save`; do not
/// reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NN {
    pub nodes: Vec<Node>, // inputs first (`size.0`), then outputs (`size.1`), then hidden nodes appended by mutations
    pub connections: Vec<Connection>, // every connection ever created; split ones stay with `active == false`
    pub layer_order: Vec<Vec<usize>>, // topological layers of node indices, rebuilt by `topological_sort`;
                                        // `forward` evaluates nodes layer by layer. Layer 0 holds every node
                                        // with no incoming connection, which includes all input nodes.
                                        // Empty until the first `mutate`.
    pub generation: usize, // number of `mutate` calls so far; 0 also triggers one-time setup in `mutate`
    pub size: (usize, usize), // (input_count, output_count) as passed to `NN::new`

    chances: [usize; 7], // relative weights of the mutations picked in `mutate`, by index:
                         // 0 weight mangle, 1 bias mangle, 2 activation mangle, 3 add connection,
                         // 4 add node, 5 enable connection, 6 disable connection.
                         // They are weights, not percentages: the sum does NOT need to equal 100.
}
117
118impl NN {
119    pub fn new(input_count: usize, output_count: usize) -> Self { 
120        // create input and output nodes
121        let mut n = vec![];
122        for _ in 0..input_count { n.push(Node::new(Genre::Input)); }
123        for _ in 0..output_count { n.push(Node::new(Genre::Output)); }
124
125        Self { 
126            nodes: n, 
127            connections: vec![],
128            layer_order: vec![], 
129            generation: 0,
130            size: (input_count, output_count),
131            chances: [35, 35, 10, 10, 10, 0, 0],
132        } 
133    }
134
135    pub fn get_chances(&mut self) -> &[usize; 7] {
136        &self.chances
137    }
138
139    pub fn set_chances(&mut self, ch: &[usize]) {
140        let mut size = ch.len();
141        if size > 7 {size = 7;}
142        for i in 0..size {
143            self.chances[i] = ch[i];
144        }
145    }
146
147    pub fn forward(&mut self, input: &[f64]) -> Vec<f64>{
148        // read inputs
149        self.nodes.iter_mut().filter(|n| n.genre == Genre::Input).zip(input.iter()).for_each(|(n, v)|{
150            n.value = *v;
151        });
152
153        // process network: layer -> element
154        // faster method would be cool
155        self.layer_order.iter().for_each(|l|{
156            l.iter().for_each(|i|{
157                // collect active connections pointing to current node (from ordering)
158                self.connections.iter().filter(|f| f.to == *i && f.active).for_each(|c|{
159                    self.nodes[*i].value += self.nodes[c.from].value * c.weight;
160                });
161
162                // activate node
163                let v = self.nodes[*i].value;
164                match self.nodes[*i].act_func {
165                    ActFunc::Sigmoid => self.nodes[*i].value = 1.0 / (1.0 + (-v).exp()),
166                    ActFunc::Tanh => self.nodes[*i].value = v.tanh(),
167                    ActFunc::ReLU => self.nodes[*i].value = v.max(0.0),
168                    ActFunc::None => {},
169                };
170            });
171        });
172
173        // return outputs
174        self.nodes.iter()
175            .filter(|n| n.genre == Genre::Output)
176            .map(|n| n.value)
177            .collect::<Vec<_>>()
178    }
179
180    pub fn mutate(&mut self) {
181        // init network
182        if self.generation == 0 {
183            self.topological_sort();
184            self.free_nodes_calc();
185            self.m_connection_add();
186            self.topological_sort();
187            self.free_nodes_calc();
188        }
189        self.generation += 1; // increment generation
190
191        // choose mutation based on chances
192        let random_num = fastrand::usize(0..self.chances.iter().sum());
193        let mut cumulative_prob = 0;
194        let mut func_num = 99;
195        for (i, prob) in self.chances.iter().enumerate() {
196            cumulative_prob += prob;
197            if random_num <= cumulative_prob {
198                func_num = i;
199                break;
200            }
201        }
202
203        match func_num {
204            0 => self.m_weight_mangle(),
205            1 => self.m_bias_mangle(),
206            2 => self.m_act_mangle(),
207            3 => self.m_connection_add(),
208            4 => self.m_node_add(),
209            5 => self.m_connection_enable(),
210            6 => self.m_connection_disable(),
211            99 => {},
212            _ => unreachable!(),
213        }
214
215        self.topological_sort();
216        self.free_nodes_calc();
217    }
218
219// connections manipulations
220    fn m_node_add(&mut self) {
221        // get active connections, if none return
222        let mut con_filtered = self.connections.iter_mut().filter(|c| c.active).collect::<Vec<_>>();
223        if con_filtered.is_empty() {return;}
224
225        let idx = fastrand::usize(0..con_filtered.len());
226        let from = con_filtered[idx].from;
227        let to = con_filtered[idx].to;
228
229        // insert node in the middle of existing connection
230        self.nodes.push(Node::default());
231        con_filtered[idx].active = false;
232        self.connections.push(Connection::new(from, self.nodes.len()-1));
233        self.connections.push(Connection::new(self.nodes.len()-1, to));
234    }
235
236    fn m_connection_add(&mut self) {
237        // randomly select node, that isn't output and have free paths, if none return (full)
238        let mut from_v: Vec<usize> = vec![];
239        self.layer_order.iter().for_each(|l|{
240            l.iter().for_each(|p|{
241                if !self.nodes[*p].free_nodes.is_empty() {
242                    from_v.push(*p);
243                }
244            });
245        });
246        
247        
248        if from_v.is_empty(){println!("Err, no free nodes");return;}
249        let from = from_v[fastrand::usize(0..from_v.len())];
250
251        // randomly select to node
252        let ll = self.nodes[from].free_nodes.len();
253        let to = self.nodes[from].free_nodes[fastrand::usize(0..ll)];
254
255        self.connections.push(Connection::new(from, to));
256    }
257
258    fn m_connection_disable(&mut self) {
259        let mut con_filtered = self.connections.iter_mut().filter(|c| c.active && c.weight != 0.0).collect::<Vec<_>>();
260        if con_filtered.is_empty() {return;}
261        let idx = fastrand::usize(0..con_filtered.len());
262
263        con_filtered[idx].weight = 0.0;
264    }
265
266    fn m_connection_enable(&mut self) {
267        let mut con_filtered = self.connections.iter_mut().filter(|c| c.active && c.weight == 0.0).collect::<Vec<_>>();
268        if con_filtered.is_empty() {return;}
269        let idx = fastrand::usize(0..con_filtered.len());
270
271        con_filtered[idx].weight = fastrand::f64() * 2.0 - 1.0;
272    }
273
274// set of mutation funcs
275    fn m_weight_mangle(&mut self){
276        let mut con_filtered = self.connections.iter_mut().filter(|c| c.active && c.weight != 0.0).collect::<Vec<_>>();
277        if con_filtered.is_empty() {return;}
278        let idx = fastrand::usize(0..con_filtered.len());
279
280        con_filtered[idx].weight += fastrand::f64() / 2.5 - 0.2;
281        con_filtered[idx].weight = con_filtered[idx].weight.clamp(-1.0, 1.0);
282    }
283
284    fn m_bias_mangle(&mut self) {
285        let mut node_filtered = self.nodes.iter_mut().filter(|n| n.genre != Genre::Input).collect::<Vec<_>>();
286        let idx = fastrand::usize(0..node_filtered.len());
287
288        node_filtered[idx].bias += fastrand::f64() / 2.5 - 0.2;
289        node_filtered[idx].bias = node_filtered[idx].bias.clamp(-1.0, 1.0)
290    }
291
292    fn m_act_mangle(&mut self) {
293        let idx = fastrand::usize(self.size.0..self.nodes.len());
294        self.nodes[idx].act_func = match fastrand::usize(0..3) {
295            0 => ActFunc::Sigmoid,
296            1 => ActFunc::Tanh,
297            2 => ActFunc::ReLU,
298            _ => unreachable!(),
299        };
300    }
301
302// order funcs
303    fn free_nodes_calc(&mut self) {
304        // check, to which nodes is a free path from node
305        self.layer_order.iter().enumerate().for_each(|(i, l)|{
306            l.iter().for_each(|p|{
307
308                let mut free: Vec<usize> = vec![];
309                self.layer_order.iter().skip(i).for_each(|li|{
310                    li.iter().for_each(|pi|{
311                        if *p != *pi && self.nodes[*pi].genre != Genre::Input {free.push(*pi);}
312                    });
313                });
314
315                self.connections.iter().filter(|c| c.from == *p).for_each(|c|{
316                    //free.drain_filter(|e| c.to == *e);
317                    free.retain(|e| c.to != *e);
318                    // drain_filter is propably removed from rust, retain is simply inversion
319                });
320
321                if self.nodes[*p].genre != Genre::Output { 
322                    self.nodes[*p].free_nodes.clear(); 
323                    self.nodes[*p].free_nodes.append(&mut free); 
324                }
325                //else { self.nodes[*p].free_nodes = Vec::<usize>::new(); }
326            });
327        });
328    }
329
330    // belive or not, but that part was mostly written by chatgpt, 
331    // I have very little idea how it works, but it works, so...
332    fn topological_sort(&mut self) {
333        let mut indegrees = vec![0; self.nodes.len()];
334        let mut neighbors = vec![Vec::new(); self.nodes.len()];
335        
336        for connection in self.connections.as_slice() {
337            indegrees[connection.to] += 1;
338            neighbors[connection.from].push(connection.to);
339        }
340        
341        let mut queue = Vec::new();
342        let mut layers = Vec::new();
343        
344        for i in 0..self.nodes.len() {
345            if indegrees[i] == 0 {
346                queue.push(i);
347            }
348        }
349        
350        while !queue.is_empty() {
351            let mut layer = Vec::new();
352            let mut next_queue = Vec::new();
353            
354            for node in queue {
355                layer.push(node);
356                for neighbor in &neighbors[node] {
357                    indegrees[*neighbor] -= 1;
358                    if indegrees[*neighbor] == 0 {
359                        next_queue.push(*neighbor);
360                    }
361                }
362            }
363            
364            layers.push(layer);
365            queue = next_queue;
366        }
367        
368        self.layer_order = layers;
369    }
370
371    // save nn to file
372    pub fn save(&self, path: &str) {        
373        // convert simplified nn to Vec<u8>
374        let encoded: Vec<u8> = serialize(
375            &self
376        ).unwrap();
377    
378        // open file and write whole Vec<u8>
379        let mut file = File::create(path).unwrap();
380        file.write_all(&encoded).unwrap();
381    }
382
383    // load nn from file
384    pub fn load(&mut self, path: &str) {
385
386        // convert readed Vec<u8> to plain nn
387        let mut buffer = vec![];
388        let mut file = File::open(path).unwrap();
389        file.read_to_end(&mut buffer).unwrap();
390        let decoded: NN = deserialize(&buffer).unwrap();
391
392        *self = decoded;
393    }
394}