1use std::iter::zip;
7
8use rand::Rng;
9
/// Slope of the leaky-ReLU activation for non-positive inputs
/// (prevents dead neurons by keeping a small gradient below zero).
const RELU_LEAK: f32 = 0.01;
11
/// A single artificial neuron with `I` inputs and a leaky-ReLU activation.
#[derive(Debug, Clone)]
pub struct Neuron<const I: usize> {
    /// One weight per input component.
    pub weights: [f32; I],
    /// Additive bias applied to the weighted sum before activation.
    pub bias: f32,
}
20impl<const I: usize> Neuron<I> {
21 pub fn zero() -> Self {
23 Self {
24 weights: [0.; I],
25 bias: 0.,
26 }
27 }
28 pub fn random_with_0_bias() -> Self {
30 let mut rng = rand::thread_rng();
31 Self {
32 weights: [0.; I].map(|_| rng.gen_range(-1.0..=1.0)),
33 bias: 0.0,
34 }
35 }
36 pub fn random() -> Self {
38 let mut rng = rand::thread_rng();
39 Self {
40 weights: [0.; I].map(|_| rng.gen_range(-1.0..=1.0)),
41 bias: rng.gen_range(-1.0..=1.0),
42 }
43 }
44 pub fn random_with_range(range: f32) -> Self {
46 let mut rng = rand::thread_rng();
47 Self {
48 weights: [0.; I].map(|_| rng.gen_range(-range..=range)),
49 bias: rng.gen_range(-range..=range),
50 }
51 }
52 fn leaky_relu(x: f32) -> f32 {
53 if x > 0. {
54 x
55 } else {
56 RELU_LEAK * x
57 }
58 }
59 fn leaky_relu_derivative(x: f32) -> f32 {
60 if x > 0. {
61 1.
62 } else {
63 RELU_LEAK
64 }
65 }
66 fn weighted_sum(&self, x: &[f32; I]) -> f32 {
67 zip(x, &self.weights).map(|(x, w)| x * w).sum::<f32>() + self.bias
68 }
69 pub fn output(&self, x: &[f32; I]) -> f32 {
71 Self::leaky_relu(self.weighted_sum(x))
72 }
73 fn derivative(&self, x: &[f32; I]) -> f32 {
74 Self::leaky_relu_derivative(self.weighted_sum(x))
75 }
76 fn compute_update(&self, x: &[f32; I], value: f32, e: f32) -> ([f32; I], f32) {
77 let d_y_e_der = -self.derivative(x) * value * e;
78 (x.map(|x_i| d_y_e_der * x_i), d_y_e_der)
79 }
80 fn update_d_weights(
81 d_weights_j: &[f32; I],
82 d_bias_j: f32,
83 d_weights: &mut [f32; I],
84 d_bias: &mut f32,
85 ) {
86 for (d_weight, d_weight_j) in zip(d_weights.iter_mut(), d_weights_j) {
87 *d_weight += d_weight_j;
88 }
89 *d_bias += d_bias_j;
90 }
91 fn update_d_weights_output(
92 &self,
93 x: &[f32; I],
94 y_data: f32,
95 e: f32,
96 d_weights: &mut [f32; I],
97 d_bias: &mut f32,
98 ) {
99 let y_pred = self.output(x);
100 let (d_weights_j, d_bias_j) = self.compute_update(x, y_pred - y_data, e);
101 Self::update_d_weights(&d_weights_j, d_bias_j, d_weights, d_bias);
102 }
103 fn update_d_weights_hidden(
104 &self,
105 x: &[f32; I],
106 d_hidden: f32,
107 e: f32,
108 d_weights: &mut [f32; I],
109 d_bias: &mut f32,
110 ) {
111 let (d_weights_j, d_bias_j) = self.compute_update(x, d_hidden, e);
112 Self::update_d_weights(&d_weights_j, d_bias_j, d_weights, d_bias);
113 }
114 fn update_weights(&mut self, d_weights: &[f32; I], d_bias: f32) {
115 for (w, d_w) in self.weights.iter_mut().zip(d_weights) {
116 *w += d_w;
117 }
118 self.bias += d_bias;
119 }
120 pub fn train<'a>(&mut self, data: impl Iterator<Item = &'a ([f32; I], f32)>, epsilon: f32) {
122 assert!(epsilon < 0.5);
123 let e = 2. * epsilon;
124 let mut d_weights = [0f32; I];
125 let mut d_bias = 0f32;
126 for (x, y_data) in data {
127 self.update_d_weights_output(x, *y_data, e, &mut d_weights, &mut d_bias);
128 }
129 self.update_weights(&d_weights, d_bias);
130 }
131}
132
/// A feed-forward network with `I` inputs, one hidden layer of `H`
/// neurons, and a single output neuron.
#[derive(Debug, Clone)]
pub struct NeuralNetwork<const I: usize, const H: usize> {
    /// The hidden layer, each neuron reading all `I` inputs.
    pub hidden_layer: [Neuron<I>; H],
    /// The output neuron, reading all `H` hidden activations.
    pub output_layer: Neuron<H>,
}
141impl<const I: usize, const H: usize> NeuralNetwork<I, H> {
142 fn x_mid(&self, x: &[f32; I]) -> [f32; H] {
143 self.hidden_layer
144 .iter()
145 .map(|n| n.output(x))
146 .collect::<Vec<_>>()
147 .try_into()
148 .unwrap()
149 }
150 pub fn output(&self, x: &[f32; I]) -> f32 {
152 self.output_layer.output(&self.x_mid(x))
153 }
154 pub fn train<'a>(&mut self, data: impl Iterator<Item = &'a ([f32; I], f32)>, epsilon: f32) {
156 assert!(epsilon < 0.5);
157 let e = 2. * epsilon;
158 let mut d_weights_hidden = [[0f32; I]; H];
160 let mut d_bias_hidden = [0f32; H];
161 let mut d_weights_output = [0f32; H];
162 let mut d_bias_output = 0f32;
163 for (x, y_data) in data {
165 let x_mid = self.x_mid(x);
167 self.output_layer.update_d_weights_output(
169 &x_mid,
170 *y_data,
171 e,
172 &mut d_weights_output,
173 &mut d_bias_output,
174 );
175 let y_pred = self.output_layer.output(&x_mid);
177 let d_output = self.output_layer.derivative(&x_mid) * (y_pred - y_data);
178 for ((neuron, w_output), (d_weights, d_bias)) in zip(
179 zip(self.hidden_layer.iter_mut(), self.output_layer.weights),
180 zip(d_weights_hidden.iter_mut(), d_bias_hidden.iter_mut()),
181 ) {
182 let d_hidden = d_output * w_output;
183 neuron.update_d_weights_hidden(x, d_hidden, e, d_weights, d_bias);
184 }
185 }
186 for (neuron, (d_weights, d_bias)) in zip(
188 self.hidden_layer.iter_mut(),
189 zip(d_weights_hidden.iter(), d_bias_hidden),
190 ) {
191 neuron.update_weights(d_weights, d_bias);
192 }
193 self.output_layer
194 .update_weights(&d_weights_output, d_bias_output);
195 }
196}
197
#[cfg(test)]
mod tests {
    use rand::Rng;

    use crate::NeuralNetwork;

    use super::Neuron;

    // Loose tolerance: gradient descent only converges approximately.
    fn approx_equal(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-4
    }
    fn assert_approx_equal(a: f32, b: f32) {
        if !approx_equal(a, b) {
            panic!("{a} is different than {b}");
        }
    }

    // Points on the line y = 1.5x + 1 (all with x >= 0, so the leaky-ReLU
    // stays in its linear region and a single neuron can fit exactly).
    const LINEAR_1D_DATA: [([f32; 1], f32); 3] = [([0.], 1.), ([1.], 2.5), ([2.], 4.)];

    // A single neuron should recover weight 1.5 and bias 1 exactly.
    #[test]
    fn linear_function_1d() {
        let mut neuron = Neuron::zero();
        for _i in 0..100 {
            neuron.train(LINEAR_1D_DATA.iter(), 0.1);
        }
        assert_approx_equal(neuron.weights[0], 1.5);
        assert_approx_equal(neuron.bias, 1.);
    }

    // Plane fit: data lies on y = 0.5*x0 + 1.0*x1 + 1. Starting weights are
    // deliberately off to exercise convergence from a bad initialization.
    #[test]
    fn linear_function_2d() {
        let mut neuron = Neuron {
            weights: [0.234, -1.43],
            bias: -1.425,
        };
        let data = [
            ([0., 0.], 1.),
            ([1., 0.], 1.5),
            ([0., 1.], 2.),
            ([1., 1.], 2.5),
        ];
        for _i in 0..150 {
            neuron.train(data.iter(), 0.1);
        }
        assert_approx_equal(neuron.weights[0], 0.5);
        assert_approx_equal(neuron.weights[1], 1.0);
        assert_approx_equal(neuron.bias, 1.);
    }

    // If the network already fits the data exactly, all gradients are zero
    // and training must leave the parameters unchanged. Two factorizations
    // of y = 1.5x + 1 are checked: scale in the output layer, and scale in
    // the hidden layer.
    #[test]
    fn two_layers_optimal_must_be_stable() {
        let mut network = NeuralNetwork {
            hidden_layer: [Neuron {
                weights: [1.0],
                bias: 0.0,
            }],
            output_layer: Neuron {
                weights: [1.5],
                bias: 1.0,
            },
        };
        for _ in 0..100 {
            network.train(LINEAR_1D_DATA.iter(), 0.1);
        }
        assert_approx_equal(network.hidden_layer[0].weights[0], 1.0);
        assert_approx_equal(network.hidden_layer[0].bias, 0.0);
        assert_approx_equal(network.output_layer.weights[0], 1.5);
        assert_approx_equal(network.output_layer.bias, 1.0);
        network = NeuralNetwork {
            hidden_layer: [Neuron {
                weights: [1.5],
                bias: 1.0,
            }],
            output_layer: Neuron {
                weights: [1.0],
                bias: 0.0,
            },
        };
        for _ in 0..100 {
            network.train(LINEAR_1D_DATA.iter(), 0.1);
        }
        assert_approx_equal(network.hidden_layer[0].weights[0], 1.5);
        assert_approx_equal(network.hidden_layer[0].bias, 1.0);
        assert_approx_equal(network.output_layer.weights[0], 1.0);
        assert_approx_equal(network.output_layer.bias, 0.0);
    }

    // With random initialization a run can land in a bad local optimum, so
    // the test retries 20 times and only requires the best run to fit.
    #[test]
    fn linear_function_1d_hidden() {
        let mut min_sse = f32::INFINITY;
        for _rerun in 0..20 {
            let mut network = NeuralNetwork {
                hidden_layer: [Neuron::<1>::random()],
                output_layer: Neuron::random(),
            };
            for _i in 0..500 {
                network.train(LINEAR_1D_DATA.iter(), 0.05);
            }
            let sse: f32 = LINEAR_1D_DATA
                .iter()
                .map(|([x], y)| {
                    let y_pred = network.output(&[*x]);
                    (y - y_pred) * (y - y_pred)
                })
                .sum();
            min_sse = sse.min(min_sse);
        }
        assert!(min_sse < 0.01, "min SSE {min_sse} is >= 0.01 over 20 runs");
    }

    // Approximates y = x^2 on [-2, 2) with 8 hidden neurons, training on
    // fresh random mini-batches with a linearly decaying learning rate.
    // Again only the best of the reruns has to clear the SSE bar.
    #[test]
    fn non_linear_function_1d() {
        fn f(x: f32) -> f32 {
            x * x
        }
        let mut network = NeuralNetwork {
            hidden_layer: [
                Neuron::<1>::random(),
                Neuron::random(),
                Neuron::random(),
                Neuron::random(),
                Neuron::random(),
                Neuron::random(),
                Neuron::random(),
                Neuron::random(),
            ],
            output_layer: Neuron::random(),
        };
        let mut rng = rand::thread_rng();
        let test_set = (0..100)
            .map(|_| {
                let x = rng.gen_range(-2.0..2.0);
                (x, f(x))
            })
            .collect::<Vec<_>>();
        let count = 1000;
        let mut min_avg_sse = f32::INFINITY;
        for _rerun in 0..50 {
            for batch in 0..count {
                let data = (0..4)
                    .map(|_| {
                        let x = rng.gen_range(-2.0..2.0);
                        ([x], f(x))
                    })
                    .collect::<Vec<_>>();
                // Decay the learning rate from 0.03 towards 0.01 over the run.
                let progress = batch as f32 / count as f32;
                let epsilon = 0.02 * (1.0 - progress) + 0.01;
                network.train(data.iter(), epsilon);
            }
            let sse: f32 = test_set
                .iter()
                .map(|(x, y)| {
                    let y_pred = network.output(&[*x]);
                    (y - y_pred) * (y - y_pred)
                })
                .sum();
            min_avg_sse = (sse / test_set.len() as f32).min(min_avg_sse);
        }
        assert!(
            min_avg_sse < 0.1,
            "min average SSE {min_avg_sse} is >= 0.1 over 20 runs"
        );
    }

    // XOR is the classic non-linearly-separable problem: it requires the
    // hidden layer, so this exercises backpropagation end to end.
    #[test]
    fn xor() {
        let xor_data = [
            ([0., 0.], 0.),
            ([1., 0.], 1.),
            ([0., 1.], 1.),
            ([1., 1.], 0.),
        ];
        let mut min_sse = f32::INFINITY;
        for _rerun in 0..20 {
            let mut network = NeuralNetwork {
                hidden_layer: [
                    Neuron::<2>::random_with_0_bias(),
                    Neuron::random_with_0_bias(),
                ],
                output_layer: Neuron::random_with_0_bias(),
            };
            for _i in 0..1000 {
                network.train(xor_data.iter(), 0.03);
            }
            let mut sse = 0.;
            for (x, y) in xor_data {
                let y_pred = network.output(&x);
                sse += (y_pred - y) * (y_pred - y);
            }
            min_sse = sse.min(min_sse);
        }
        assert!(min_sse < 0.01, "min SSE {min_sse} is >= 0.01 over 20 runs");
    }
}