// neural_network_rs/neural_network/mod.rs
pub mod activation_function;
pub mod cost_function;
pub mod layer;
pub mod optimizer;

use crate::dataset::Dataset;

use ndarray::Array2;

use self::{
    activation_function::ActivationFunction, cost_function::CostFunction, layer::Layer,
    optimizer::Optimizer,
};

15pub struct Network<'a> {
16 pub layers: Vec<Layer<'a>>,
17 optimizer: &'a mut dyn Optimizer,
18 pub shape: &'a [(&'a ActivationFunction, usize)],
19 pub cost_function: &'a CostFunction,
20}
22#[allow(non_snake_case)]
23impl Network<'_> {
24 pub fn new<'a>(
25 shape: &'a [(&ActivationFunction, usize)],
26 optimizer: &'a mut dyn Optimizer,
27 cost_function: &'a CostFunction,
28 ) -> Network<'a> {
29 let mut layers = Vec::new();
30 for i in 0..shape.len() - 1 {
31 let (activation_function, input_size) = shape[i];
32 let (_, output_size) = shape[i + 1];
33 layers.push(Layer::new(input_size, output_size, activation_function));
34 }
35
36 optimizer.initialize(&layers);
37
38 Network {
39 layers,
40 optimizer,
41 shape,
42 cost_function,
43 }
44 }
45
46 pub fn predict(&self, input: &Array2<f64>) -> Array2<f64> {
48 let mut output = input.clone();
49 for layer in &self.layers {
50 output = layer.predict(&output);
51 }
52 output
53 }
54
55 pub fn backprop(
57 &self,
58 X: &Array2<f64>,
59 y: &Array2<f64>,
60 ) -> (Vec<Array2<f64>>, Vec<Array2<f64>>) {
61 let mut nabla_bs = Vec::new();
62 let mut nabla_ws = Vec::new();
63
64 let mut activation = X.clone();
66 let mut activations = vec![activation.clone()];
67 let mut zs = Vec::new();
68 for layer in &self.layers {
69 let z = layer.forward(&activation);
70 zs.push(z.clone());
71 activation = layer.activation.function(&z);
72 activations.push(activation.clone());
73 }
74
75 let nabla_c = self.cost_function.cost_derivative(&activation, &y);
77
78 let sig_prime = self.layers[self.layers.len() - 1]
80 .activation
81 .derivative(&zs[zs.len() - 1]);
82
83 let mut delta = nabla_c * sig_prime;
85
86 nabla_bs.push(delta.clone());
88 nabla_ws.push((&activations[activations.len() - 2]).t().dot(&delta));
89
90 for i in 2..self.shape.len() {
92 let sig_prime = self.layers[self.layers.len() - i]
93 .activation
94 .derivative(&zs[zs.len() - i]);
95
96 let nabla_c = &delta.dot(&self.layers[self.layers.len() - i + 1].weights.t());
97
98 delta = nabla_c * sig_prime;
99
100 nabla_bs.push(delta.clone());
101 nabla_ws.push((&activations[activations.len() - i - 1].t()).dot(&delta));
102 }
103
104 nabla_bs.reverse();
106 nabla_ws.reverse();
107
108 let batch_size = X.nrows() as f64;
110 for (nabla_b, nabla_w) in nabla_bs.iter_mut().zip(nabla_ws.iter_mut()) {
111 *nabla_b = nabla_b
112 .sum_axis(ndarray::Axis(0))
113 .into_shape((1, nabla_b.ncols()))
114 .unwrap();
115
116 *nabla_b /= batch_size;
117 *nabla_w /= batch_size;
118 }
119
120 (nabla_bs, nabla_ws)
121 }
122
123 pub fn train_minibatch(&mut self, (X, y): &(Array2<f64>, Array2<f64>)) {
125 let (nabla_bs, nabla_ws) = self.backprop(X, y);
126
127 self.optimizer.pre_update();
128
129 self.optimizer
130 .update_params(&mut self.layers, &nabla_bs, &nabla_ws);
131
132 self.optimizer.post_update();
133 }
134
135 pub fn train_and_log(
137 &mut self,
138 data: &Dataset,
139 batch_size: usize,
140 verification_samples: usize,
141 epochs: i32,
142 ) -> Vec<(i32, f64)> {
143 let mut cost_history = Vec::new();
144
145 for epoch in 0..epochs {
146 self.train_minibatch(&data.get_batch(batch_size));
147
148 if epoch % (epochs / 100 + 1) == 0 {
149 let cost = self.eval(data, verification_samples);
150 cost_history.push((epoch, cost));
151
152 println!("Epoch: {}, Cost: {:.8}", epoch, cost);
153 }
154 }
155
156 cost_history
157 }
158
159 pub fn eval(&self, data: &Dataset, sample_size: usize) -> f64 {
161 let (x, y) = data.get_batch(sample_size);
162
163 let prediction = self.predict(&x);
164 let cost = self.cost_function.cost(&prediction, &y);
165 cost
166 }
167
168 pub fn predict_unit_square(&self, resolution: usize) -> ((usize, usize), Vec<Vec<f64>>) {
171 let unit_square = Dataset::get_2d_unit_square(resolution);
172 let pred = self.predict(&unit_square);
173
174 let res = pred
175 .lanes(ndarray::Axis(1))
176 .into_iter()
177 .map(|x| x.to_vec())
178 .collect();
179
180 ((resolution, resolution), res)
181 }
182}
/// A short, human-readable description of a component, e.g. for building
/// file-name-safe identifiers for a trained network.
pub trait Summary {
    // NOTE(review): "summerize" is a typo for "summarize", but renaming the
    // trait method would break every implementor, so the spelling is kept.
    fn summerize(&self) -> String;
}
188impl Summary for Network<'_> {
189 fn summerize(&self) -> String {
190 let shape = self.shape.iter().map(|x| x.1).collect::<Vec<_>>();
191
192 format!("{}_{:?}", self.optimizer.summerize(), shape).replace(" ", "")
193 }
194}