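//! `slow_nn/network.rs`: a neural network stored as a graph of individually
//! connected neurons, with forward/backward passes and bincode persistence.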

mod neuron;

pub use neuron::*;

use std::collections::VecDeque;
use std::f64::NEG_INFINITY;
use std::fs::File;
use std::io::{Error, ErrorKind, Read, Write};

use serde::{Deserialize, Serialize};
/// Result of a forward pass, indexed by neuron.
struct Outputs {
    /// Activated value of each neuron.
    outputs: Vec<f64>,
    /// Pre-activation weighted input sum of each neuron.
    sums: Vec<f64>,
}

impl Outputs {
    /// Bundles per-neuron activations and pre-activation sums.
    fn new(outputs: Vec<f64>, sums: Vec<f64>) -> Self {
        Self { outputs, sums }
    }
}
/// Result of a backward pass, indexed by neuron.
struct Gradients {
    /// Gradient of the loss with respect to each neuron's output.
    grads: Vec<f64>,
    /// Gradient of the loss with respect to each incoming weight of each neuron.
    del_ws: Vec<Vec<f64>>,
}

impl Gradients {
    /// Bundles per-neuron output gradients and per-weight gradients.
    fn new(grads: Vec<f64>, del_ws: Vec<Vec<f64>>) -> Self {
        Self { grads, del_ws }
    }
}
32/// Neural Network struct
33#[derive(Serialize, Deserialize)]
34pub struct Network {
35    bias: f64,
36    inputs: usize,
37    outputs: usize,
38    hidden: usize,
39    neurons: Vec<Neuron>,
40}
41
42impl Network {
43    /// Creates a network from a slice of Connections
44    pub fn from_conns(
45        bias: f64,
46        inputs: usize,
47        outputs: usize,
48        hidden: usize,
49        conns: &[Connection],
50    ) -> Self {
51        let cap = 1 + inputs + outputs + hidden;
52        let mut neurons: Vec<_> = (0..).take(cap).map(|_| Neuron::new()).collect();
53
54        for conn in conns {
55            neurons[conn.to].connected_from(conn.from, conn.weight);
56        }
57
58        Self {
59            bias,
60            inputs,
61            outputs,
62            hidden,
63            neurons,
64        }
65    }
66
67    fn add_conns(
68        conns: &mut Vec<Connection>,
69        start1: usize,
70        end1: usize,
71        start2: usize,
72        end2: usize,
73    ) {
74        for i in start1..end1 {
75            for j in start2..end2 {
76                conns.push(Connection::new(i, j));
77            }
78        }
79    }
80
81    fn rng2vec(start: usize, end: usize) -> Vec<usize> {
82        (start..end).collect()
83    }
84
85    /// Creates a fully connected network with random weights
86    pub fn dense(bias: f64, inputs: usize, outputs: usize, layers: &[usize]) -> Self {
87        // Generate indeces for connections
88        let mut offset = 1 + inputs + outputs;
89        let mut indeces = vec![Self::rng2vec(0, 1 + inputs)];
90        for layer in layers {
91            indeces.push(Self::rng2vec(offset, offset + layer));
92            offset += layer;
93        }
94        let mut conns = Vec::new();
95        indeces.push(Self::rng2vec(1 + inputs, 1 + inputs + outputs));
96
97        // Create connections
98        for i in 1..indeces.len() {
99            for &index1 in &indeces[i - 1] {
100                for &index2 in &indeces[i] {
101                    conns.push(Connection::new(index1, index2));
102                }
103            }
104        }
105
106        let hidden = offset - 1 - inputs - outputs;
107        println!("Hidden = {}", hidden);
108
109        Network::from_conns(bias, inputs, outputs, hidden, &conns)
110    }
111
112    /// Does a forward pass
113    pub fn predict(&self, input: &[f64], activation: fn(f64) -> f64) -> Vec<f64> {
114        let Outputs { outputs, .. } = self.compute_outputs(input, activation);
115        let offset = 1 + self.inputs;
116
117        outputs[offset..offset + self.outputs]
118            .iter()
119            .map(|&i| i)
120            .collect()
121    }
122
123    /// Does both forward and backward pass and updates the weights
124    pub fn train(
125        &mut self,
126        input: &[f64],
127        expected: &[f64],
128        activation: fn(f64) -> f64,
129        deactivation: fn(f64) -> f64,
130        loss: fn(f64, f64) -> f64,
131        dloss: fn(f64, f64) -> f64,
132        lr: f64,
133    ) -> f64 {
134        let outs = self.compute_outputs(input, activation);
135        let pred: Vec<_> = outs.outputs.iter().skip(1 + self.inputs).take(self.outputs).cloned().collect();
136        let error = pred.iter().zip(expected.iter()).fold(0., |acc, (&a, &b)| acc + loss(a, b));
137        let Gradients { grads, del_ws } = self.compute_grads(outs, expected, deactivation, dloss, lr);
138
139        self.bias += grads[0] * lr;
140        for i in 0..self.neurons.len() {
141            for j in 0..self.neurons[i].connections() {
142                self.neurons[i].weights[j] -= del_ws[i][j] * lr;
143            }
144        }
145
146        error
147    }
148
149    fn compute_outputs(&self, input: &[f64], activation: fn(f64) -> f64) -> Outputs {
150        let len = self.neurons.len();
151        let mut outputs: Vec<_> = (0..)
152            .take(len)
153            .map(|i| match i {
154                0 => self.bias,
155                x if x <= self.inputs => input[i - 1],
156                _ => NEG_INFINITY,
157            })
158            .collect();
159        let mut sums: Vec<_> = (0..)
160            .take(len)
161            .map(|_| 0.)
162            .collect();
163        let mut stack: Vec<_> = (0..).skip(1 + self.inputs).take(self.outputs).collect();
164
165        while let Some(&top) = stack.last() {
166            let len = stack.len();
167            let mut sum = 0.0;
168            for i in 0..self.neurons[top].connections() {
169                let index = self.neurons[top].in_comes[i];
170                if outputs[index] == NEG_INFINITY {
171                    stack.push(index);
172                } else {
173                    let weight = self.neurons[top].weights[i];
174                    sum += weight * outputs[index];
175                }
176            }
177
178            if len == stack.len() {
179                stack.pop();
180                sums[top] = sum;
181                outputs[top] = activation(sum);
182            }
183        }
184
185        Outputs::new(outputs, sums)
186    }
187
188    fn compute_grads(
189        &self,
190        outs: Outputs,
191        expected: &[f64],
192        deactivation: fn(f64) -> f64,
193        dloss: fn(f64, f64) -> f64,
194        lr: f64,
195    ) -> Gradients {
196        let Outputs { outputs, sums } = outs;
197        let offset = 1 + self.inputs;
198        let mut grad: Vec<_> = (0..)
199            .take(offset)
200            .map(|_| NEG_INFINITY)
201            .chain(
202                (0..)
203                    .take(self.outputs)
204                    .map(|i| dloss(outputs[i + offset], expected[i])),
205            )
206            .chain((0..).take(self.hidden).map(|_| NEG_INFINITY))
207            .collect();
208
209        let mut del_ws: Vec<Vec<_>> = self
210            .neurons
211            .iter()
212            .map(|n| (0..n.connections()).map(|_| 0.).collect())
213            .collect();
214
215        let mut queue: VecDeque<_> = (0..).skip(1 + self.inputs).take(self.outputs).collect();
216
217        while let Some(top) = queue.pop_front() {
218            let da = grad[top];
219            let dz = da * deactivation(outputs[top]);
220
221            for i in 0..self.neurons[top].connections() {
222                let index = self.neurons[top].in_comes[i];
223
224                let dw = dz * outputs[index];
225                let dx = dz * self.neurons[top].weights[i];
226
227                del_ws[top][i] += dw;
228
229                if grad[index] == NEG_INFINITY {
230                    grad[index] = dx;
231                    queue.push_back(index);
232                } else {
233                    grad[index] += dx;
234                }
235            }
236        }
237
238        Gradients::new(grad, del_ws)
239    }
240
241    /// loads the network from a file
242    pub fn load(path: &str) -> Result<Self, Error> {
243        let mut file = File::open(path)?;
244        let mut buffer = Vec::new();
245        file.read_to_end(&mut buffer)?;
246        Ok(bincode::deserialize(&buffer).expect("Error reading file"))
247    }
248
249    /// Saves the network to a file
250    pub fn save(&self, path: &str) -> Result<(), Error> {
251        let mut file = File::create(path)?;
252        let buffer: Vec<u8> = bincode::serialize(&self).expect("Error serializing file");
253        file.write(&buffer)?;
254        Ok(())
255    }
256
257    /// Returns the bytes of the object after serializing
258    pub fn to_bytes(&self) -> Option<Vec<u8>> {
259        bincode::serialize(&self).ok()
260    }
261}