//! Feed-forward neural network: `neural_network_rs/neural_network/mod.rs`.

pub mod activation_function;
pub mod cost_function;
pub mod layer;
pub mod optimizer;

use crate::dataset::Dataset;

use ndarray::Array2;

use self::{
    activation_function::ActivationFunction, cost_function::CostFunction, layer::Layer,
    optimizer::Optimizer,
};
15pub struct Network<'a> {
16    pub layers: Vec<Layer<'a>>,
17    optimizer: &'a mut dyn Optimizer,
18    pub shape: &'a [(&'a ActivationFunction, usize)],
19    pub cost_function: &'a CostFunction,
20}
22#[allow(non_snake_case)]
23impl Network<'_> {
24    pub fn new<'a>(
25        shape: &'a [(&ActivationFunction, usize)],
26        optimizer: &'a mut dyn Optimizer,
27        cost_function: &'a CostFunction,
28    ) -> Network<'a> {
29        let mut layers = Vec::new();
30        for i in 0..shape.len() - 1 {
31            let (activation_function, input_size) = shape[i];
32            let (_, output_size) = shape[i + 1];
33            layers.push(Layer::new(input_size, output_size, activation_function));
34        }
35
36        optimizer.initialize(&layers);
37
38        Network {
39            layers,
40            optimizer,
41            shape,
42            cost_function,
43        }
44    }
45
46    // Predicts the output of the network given an input
47    pub fn predict(&self, input: &Array2<f64>) -> Array2<f64> {
48        let mut output = input.clone();
49        for layer in &self.layers {
50            output = layer.predict(&output);
51        }
52        output
53    }
54
55    // Calculates the needed adjustments to the weights and biases for a given input and expected output
56    pub fn backprop(
57        &self,
58        X: &Array2<f64>,
59        y: &Array2<f64>,
60    ) -> (Vec<Array2<f64>>, Vec<Array2<f64>>) {
61        let mut nabla_bs = Vec::new();
62        let mut nabla_ws = Vec::new();
63
64        // Forward pass
65        let mut activation = X.clone();
66        let mut activations = vec![activation.clone()];
67        let mut zs = Vec::new();
68        for layer in &self.layers {
69            let z = layer.forward(&activation);
70            zs.push(z.clone());
71            activation = layer.activation.function(&z);
72            activations.push(activation.clone());
73        }
74
75        // Calculate the cost
76        let nabla_c = self.cost_function.cost_derivative(&activation, &y);
77
78        // Calculate sensitivity
79        let sig_prime = self.layers[self.layers.len() - 1]
80            .activation
81            .derivative(&zs[zs.len() - 1]);
82
83        // Calculate delta for last layer
84        let mut delta = nabla_c * sig_prime;
85
86        // Calculate nabla_b and nabla_w for last layer
87        nabla_bs.push(delta.clone());
88        nabla_ws.push((&activations[activations.len() - 2]).t().dot(&delta));
89
90        // Loop backwards through the layers, calculating delta, nabla_b and nabla_w
91        for i in 2..self.shape.len() {
92            let sig_prime = self.layers[self.layers.len() - i]
93                .activation
94                .derivative(&zs[zs.len() - i]);
95
96            let nabla_c = &delta.dot(&self.layers[self.layers.len() - i + 1].weights.t());
97
98            delta = nabla_c * sig_prime;
99
100            nabla_bs.push(delta.clone());
101            nabla_ws.push((&activations[activations.len() - i - 1].t()).dot(&delta));
102        }
103
104        // restore correct ordering
105        nabla_bs.reverse();
106        nabla_ws.reverse();
107
108        //Adjust for batch size
109        let batch_size = X.nrows() as f64;
110        for (nabla_b, nabla_w) in nabla_bs.iter_mut().zip(nabla_ws.iter_mut()) {
111            *nabla_b = nabla_b
112                .sum_axis(ndarray::Axis(0))
113                .into_shape((1, nabla_b.ncols()))
114                .unwrap();
115
116            *nabla_b /= batch_size;
117            *nabla_w /= batch_size;
118        }
119
120        (nabla_bs, nabla_ws)
121    }
122
123    // Trains the network using a minibatch
124    pub fn train_minibatch(&mut self, (X, y): &(Array2<f64>, Array2<f64>)) {
125        let (nabla_bs, nabla_ws) = self.backprop(X, y);
126
127        self.optimizer.pre_update();
128
129        self.optimizer
130            .update_params(&mut self.layers, &nabla_bs, &nabla_ws);
131
132        self.optimizer.post_update();
133    }
134
135    // Trains the network using a dataset, records the cost for each epoch
136    pub fn train_and_log(
137        &mut self,
138        data: &Dataset,
139        batch_size: usize,
140        verification_samples: usize,
141        epochs: i32,
142    ) -> Vec<(i32, f64)> {
143        let mut cost_history = Vec::new();
144
145        for epoch in 0..epochs {
146            self.train_minibatch(&data.get_batch(batch_size));
147
148            if epoch % (epochs / 100 + 1) == 0 {
149                let cost = self.eval(data, verification_samples);
150                cost_history.push((epoch, cost));
151
152                println!("Epoch: {}, Cost: {:.8}", epoch, cost);
153            }
154        }
155
156        cost_history
157    }
158
159    // Evaluates the network on a given dataset
160    pub fn eval(&self, data: &Dataset, sample_size: usize) -> f64 {
161        let (x, y) = data.get_batch(sample_size);
162
163        let prediction = self.predict(&x);
164        let cost = self.cost_function.cost(&prediction, &y);
165        cost
166    }
167
168    // evaluates the prediction-results for the unit-square, returns a list
169    // containing the result for each point in a row by row fashion
170    pub fn predict_unit_square(&self, resolution: usize) -> ((usize, usize), Vec<Vec<f64>>) {
171        let unit_square = Dataset::get_2d_unit_square(resolution);
172        let pred = self.predict(&unit_square);
173
174        let res = pred
175            .lanes(ndarray::Axis(1))
176            .into_iter()
177            .map(|x| x.to_vec())
178            .collect();
179
180        ((resolution, resolution), res)
181    }
182}
/// Produces a short, human-readable identifier for a component (used e.g.
/// to label runs). NOTE(review): `summerize` is a misspelling of
/// `summarize`, kept because implementors elsewhere in the crate (such as
/// the optimizers) already use this name.
pub trait Summary {
    fn summerize(&self) -> String;
}
188impl Summary for Network<'_> {
189    fn summerize(&self) -> String {
190        let shape = self.shape.iter().map(|x| x.1).collect::<Vec<_>>();
191
192        format!("{}_{:?}", self.optimizer.summerize(), shape).replace(" ", "")
193    }
194}