1use crate::matrix::Matrix;
2use rand::{Rng, SeedableRng, rngs::StdRng};
3use serde::{Deserialize, Serialize};
4use std::{error::Error, fmt};
5
6fn sigmoid(x: &mut Matrix) {
7 x.apply(|x| 1.0 / (1.0 + (-x).exp()))
8}
9
10fn sigmoid_derivative(x: &mut Matrix) {
11 x.apply(|x| x * (1.0 - x))
12}
13
14fn tanh(x: &mut Matrix) {
15 x.apply(|x| x.tanh())
16}
17
18fn tanh_derivative(x: &mut Matrix) {
19 x.apply(|x| 1.0 - x.powi(2))
20}
21
22fn linear(_: &mut Matrix) {}
23
24fn linear_derivative(x: &mut Matrix) {
25 x.apply(|_| 1.0)
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
29pub enum ActivationFunction {
30 Sigmoid,
31 Tanh,
32 Linear,
33}
34
35impl Default for ActivationFunction {
36 fn default() -> Self {
37 ActivationFunction::Sigmoid
38 }
39}
40
41impl ActivationFunction {
42 fn apply(&self, x: &mut Matrix) {
43 match self {
44 ActivationFunction::Sigmoid => sigmoid(x),
45 ActivationFunction::Tanh => tanh(x),
46 ActivationFunction::Linear => linear(x),
47 }
48 }
49
50 fn derivative(&self, x: &mut Matrix) {
51 match self {
52 ActivationFunction::Sigmoid => sigmoid_derivative(x),
53 ActivationFunction::Tanh => tanh_derivative(x),
54 ActivationFunction::Linear => linear_derivative(x),
55 }
56 }
57}
58
/// Errors reported by [`NeuralNetwork`] construction, prediction and training.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NeuralNetworkError {
    /// A layer size passed to `NeuralNetwork::new` was zero.
    InvalidLayerSize {
        /// Which layer was rejected: `"input"`, `"hidden"` or `"output"`.
        layer: &'static str,
        /// The rejected size.
        size: usize,
    },
    /// The input vector length differs from the network's input size.
    InputLengthMismatch {
        /// Length the network expects.
        expected: usize,
        /// Length that was supplied.
        got: usize,
    },
    /// The target vector length differs from the network's output size.
    TargetLengthMismatch {
        /// Length the network expects.
        expected: usize,
        /// Length that was supplied.
        got: usize,
    },
}
74
75impl fmt::Display for NeuralNetworkError {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match self {
78 NeuralNetworkError::InvalidLayerSize { layer, size } => {
79 write!(
80 f,
81 "invalid {layer} layer size: expected a positive size, got {size}"
82 )
83 }
84 NeuralNetworkError::InputLengthMismatch { expected, got } => {
85 write!(f, "input length mismatch: expected {expected}, got {got}")
86 }
87 NeuralNetworkError::TargetLengthMismatch { expected, got } => {
88 write!(f, "target length mismatch: expected {expected}, got {got}")
89 }
90 }
91 }
92}
93
// The derived `Debug` and the `Display` impl satisfy `Error`'s supertraits;
// the trait's default methods are used as-is (no underlying source error).
impl Error for NeuralNetworkError {}
95
/// A feed-forward network with a single hidden layer, trained one sample at a
/// time by `train` and serializable through serde.
///
/// NOTE(review): the derived `Default` and `Deserialize` bypass the size
/// validation performed by `new` — confirm callers only use them on
/// well-formed data.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NeuralNetwork {
    /// Input-to-hidden weights; `new` builds this as `hidden_size` x `input_size`.
    weights_input_hidden: Matrix,
    /// Hidden-to-output weights; `new` builds this as `output_size` x `hidden_size`.
    weights_hidden_output: Matrix,
    /// Hidden-layer biases; `new` builds this as a `hidden_size` x 1 column.
    biases_hidden: Matrix,
    /// Output-layer biases; `new` builds this as an `output_size` x 1 column.
    biases_output: Matrix,
    /// Factor applied to the gradients in `train`; `new` sets it to `0.01`.
    learning_rate: f64,
    /// Activation used for both the hidden and the output layer.
    activation_function: ActivationFunction,
}
106
107impl NeuralNetwork {
108 fn input_size(&self) -> usize {
109 self.weights_input_hidden.cols()
110 }
111
112 fn output_size(&self) -> usize {
113 self.weights_hidden_output.rows()
114 }
115
116 fn validate_input_len(&self, actual: usize) -> Result<(), NeuralNetworkError> {
117 if actual == self.input_size() {
118 Ok(())
119 } else {
120 Err(NeuralNetworkError::InputLengthMismatch {
121 expected: self.input_size(),
122 got: actual,
123 })
124 }
125 }
126
127 fn validate_target_len(&self, actual: usize) -> Result<(), NeuralNetworkError> {
128 if actual == self.output_size() {
129 Ok(())
130 } else {
131 Err(NeuralNetworkError::TargetLengthMismatch {
132 expected: self.output_size(),
133 got: actual,
134 })
135 }
136 }
137
138 pub fn new(
141 input_size: usize,
142 hidden_size: usize,
143 output_size: usize,
144 rng: Option<&mut StdRng>,
145 ) -> Result<Self, NeuralNetworkError> {
146 if input_size == 0 {
147 return Err(NeuralNetworkError::InvalidLayerSize {
148 layer: "input",
149 size: input_size,
150 });
151 }
152 if hidden_size == 0 {
153 return Err(NeuralNetworkError::InvalidLayerSize {
154 layer: "hidden",
155 size: hidden_size,
156 });
157 }
158 if output_size == 0 {
159 return Err(NeuralNetworkError::InvalidLayerSize {
160 layer: "output",
161 size: output_size,
162 });
163 }
164
165 let rng = match rng {
166 Some(rng) => rng,
167 None => &mut StdRng::from_os_rng(),
168 };
169
170 let limit_input_hidden = (6.0 / (input_size + hidden_size) as f64).sqrt();
171 let limit_hidden_output = (6.0 / (hidden_size + output_size) as f64).sqrt();
172
173 Ok(NeuralNetwork {
174 weights_input_hidden: Matrix::random_range(
175 rng,
176 hidden_size,
177 input_size,
178 -limit_input_hidden,
179 limit_input_hidden,
180 ),
181 weights_hidden_output: Matrix::random_range(
182 rng,
183 output_size,
184 hidden_size,
185 -limit_hidden_output,
186 limit_hidden_output,
187 ),
188 biases_hidden: Matrix::new(hidden_size, 1),
189 biases_output: Matrix::new(output_size, 1),
190 learning_rate: 0.01,
191 activation_function: ActivationFunction::default(),
192 })
193 }
194
195 pub fn learning_rate(&self) -> f64 {
197 self.learning_rate
198 }
199
200 pub fn set_learning_rate(&mut self, learning_rate: f64) {
202 self.learning_rate = learning_rate;
203 }
204
205 pub fn activation_function(&self) -> &ActivationFunction {
207 &self.activation_function
208 }
209
210 pub fn set_activation_function(&mut self, activation_function: ActivationFunction) {
212 self.activation_function = activation_function;
213 }
214
215 pub fn predict(&self, input: Vec<f64>) -> Result<Vec<f64>, NeuralNetworkError> {
217 self.validate_input_len(input.len())?;
218
219 let input_matrix = Matrix::from_col_vec(input);
220 let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
221 hidden_layer_input += &self.biases_hidden;
222 let mut hidden_layer_output = hidden_layer_input;
223 self.activation_function.apply(&mut hidden_layer_output);
224
225 let output_layer_input =
226 &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
227 let mut output_layer_output = output_layer_input;
228 self.activation_function.apply(&mut output_layer_output);
229
230 Ok(output_layer_output.col(0))
231 }
232
233 pub fn train(
235 &mut self,
236 input: Vec<f64>,
237 target: Vec<f64>,
238 ) -> Result<(), NeuralNetworkError> {
239 self.validate_input_len(input.len())?;
240 self.validate_target_len(target.len())?;
241
242 let input_matrix = Matrix::from_col_vec(input);
243 let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
244 hidden_layer_input += &self.biases_hidden;
245 let mut hidden_layer_output = hidden_layer_input;
246 self.activation_function.apply(&mut hidden_layer_output);
247
248 let output_layer_input =
249 &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
250 let mut output_layer_output = output_layer_input;
251 self.activation_function.apply(&mut output_layer_output);
252
253 let target = Matrix::from_col_vec(target);
254
255 let mut output_errors = target;
256 output_errors -= &output_layer_output;
257
258 let mut gradients = output_layer_output;
259 self.activation_function.derivative(&mut gradients);
260 gradients.hadamard_product(&output_errors);
261 gradients *= self.learning_rate;
262
263 let hidden_transposed = hidden_layer_output.transpose();
264 let weight_hidden_output_deltas = &gradients * &hidden_transposed;
265
266 let weight_hidden_output_transposed = self.weights_hidden_output.transpose();
267 let hidden_errors = &weight_hidden_output_transposed * &output_errors;
268
269 self.weights_hidden_output += &weight_hidden_output_deltas;
270 self.biases_output += &gradients;
271
272 let mut hidden_gradient = hidden_layer_output;
273 self.activation_function.derivative(&mut hidden_gradient);
274 hidden_gradient.hadamard_product(&hidden_errors);
275 hidden_gradient *= self.learning_rate;
276
277 let inputs_transposed = input_matrix.transpose();
278 let weight_input_hidden_deltas = &hidden_gradient * &inputs_transposed;
279 self.weights_input_hidden += &weight_input_hidden_deltas;
280 self.biases_hidden += &hidden_gradient;
281
282 Ok(())
283 }
284
285 pub fn mutate(&mut self, rng: &mut StdRng, mutation_rate: f64) {
286 for i in 0..self.weights_input_hidden.rows() {
287 for j in 0..self.weights_input_hidden.cols() {
288 if rng.random::<f64>() < mutation_rate {
289 self.weights_input_hidden
290 .set(i, j, rng.random_range(-1.0..1.0));
291 }
292 }
293 }
294 for i in 0..self.weights_hidden_output.rows() {
295 for j in 0..self.weights_hidden_output.cols() {
296 if rng.random::<f64>() < mutation_rate {
297 self.weights_hidden_output
298 .set(i, j, rng.random_range(-1.0..1.0));
299 }
300 }
301 }
302 for i in 0..self.biases_hidden.rows() {
303 if rng.random::<f64>() < mutation_rate {
304 self.biases_hidden.set(i, 0, rng.random_range(-1.0..1.0));
305 }
306 }
307 for i in 0..self.biases_output.rows() {
308 if rng.random::<f64>() < mutation_rate {
309 self.biases_output.set(i, 0, rng.random_range(-1.0..1.0));
310 }
311 }
312 }
313}
314
/// Unit tests for the network: construction, validation errors, learning on
/// small boolean functions, and serde round-tripping.
#[cfg(test)]
pub mod nn_tests {
    use super::{NeuralNetwork, NeuralNetworkError};
    use crate::matrix::Matrix;
    use rand::{SeedableRng, rngs::StdRng};

    /// One training example: `(input, target)`.
    type Sample = (Vec<f64>, Vec<f64>);

    /// Truth table of logical OR.
    fn or_samples() -> [Sample; 4] {
        [
            (vec![0.0, 0.0], vec![0.0]),
            (vec![0.0, 1.0], vec![1.0]),
            (vec![1.0, 0.0], vec![1.0]),
            (vec![1.0, 1.0], vec![1.0]),
        ]
    }

    /// Truth table of logical XOR.
    fn xor_samples() -> [Sample; 4] {
        [
            (vec![0.0, 0.0], vec![0.0]),
            (vec![0.0, 1.0], vec![1.0]),
            (vec![1.0, 0.0], vec![1.0]),
            (vec![1.0, 1.0], vec![0.0]),
        ]
    }

    /// Trains `nn` on every sample, in order, `epochs` times.
    fn train_epochs(nn: &mut NeuralNetwork, samples: &[Sample], epochs: usize) {
        for _ in 0..epochs {
            for (input, target) in samples {
                nn.train(input.clone(), target.clone()).unwrap();
            }
        }
    }

    #[test]
    fn it_creates_a_neural_network() {
        let m = NeuralNetwork::new(3, 5, 2, None).unwrap();
        assert_eq!(m.weights_input_hidden.rows(), 5);
        assert_eq!(m.input_size(), 3);
        assert_eq!(m.output_size(), 2);
        assert_eq!(m.weights_hidden_output.cols(), 5);
        assert_eq!(m.biases_hidden.rows(), 5);
        assert_eq!(m.biases_hidden.cols(), 1);
        assert_eq!(m.biases_output.rows(), 2);
        assert_eq!(m.biases_output.cols(), 1);
    }

    #[test]
    pub fn it_predicts() {
        // Seeded so the inequality below is checked against fixed weights
        // rather than a fresh OS-random draw on every run.
        let mut rng = StdRng::seed_from_u64(1);
        let m = NeuralNetwork::new(3, 5, 2, Some(&mut rng)).unwrap();
        let input = vec![0.5, 0.2, 0.1];
        let output = m.predict(input).unwrap();
        assert_eq!(output.len(), 2);
        assert_ne!(output[0], output[1]);
    }

    #[test]
    fn it_learns_the_or_function() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut nn = NeuralNetwork::new(2, 4, 1, Some(&mut rng)).unwrap();
        nn.set_learning_rate(0.5);

        train_epochs(&mut nn, &or_samples(), 10_000);

        assert!(nn.predict(vec![0.0, 0.0]).unwrap()[0] < 0.2);
        assert!(nn.predict(vec![0.0, 1.0]).unwrap()[0] > 0.8);
        assert!(nn.predict(vec![1.0, 0.0]).unwrap()[0] > 0.8);
        assert!(nn.predict(vec![1.0, 1.0]).unwrap()[0] > 0.8);
    }

    #[test]
    fn tanh_derivative_uses_activated_output() {
        let mut x = Matrix::from_col_vec(vec![0.5, -0.25]);
        super::tanh_derivative(&mut x);

        assert!((x.get(0, 0) - 0.75).abs() < 1e-12);
        assert!((x.get(1, 0) - 0.9375).abs() < 1e-12);
    }

    #[test]
    fn predict_returns_clear_error_for_wrong_input_size() {
        let nn = NeuralNetwork::new(3, 5, 2, None).unwrap();

        assert_eq!(
            nn.predict(vec![0.1, 0.2]),
            Err(NeuralNetworkError::InputLengthMismatch {
                expected: 3,
                got: 2,
            })
        );
    }

    #[test]
    fn train_returns_clear_error_for_wrong_target_size() {
        let mut nn = NeuralNetwork::new(3, 5, 2, None).unwrap();

        assert_eq!(
            nn.train(vec![0.1, 0.2, 0.3], vec![1.0]),
            Err(NeuralNetworkError::TargetLengthMismatch {
                expected: 2,
                got: 1,
            })
        );
    }

    #[test]
    fn new_rejects_zero_sized_layers() {
        assert_eq!(
            NeuralNetwork::new(0, 5, 2, None).unwrap_err(),
            NeuralNetworkError::InvalidLayerSize {
                layer: "input",
                size: 0,
            }
        );
        assert_eq!(
            NeuralNetwork::new(3, 0, 2, None).unwrap_err(),
            NeuralNetworkError::InvalidLayerSize {
                layer: "hidden",
                size: 0,
            }
        );
        assert_eq!(
            NeuralNetwork::new(3, 5, 0, None).unwrap_err(),
            NeuralNetworkError::InvalidLayerSize {
                layer: "output",
                size: 0,
            }
        );
    }

    #[test]
    fn new_uses_zero_biases() {
        let nn = NeuralNetwork::new(3, 5, 2, None).unwrap();

        assert!(nn.biases_hidden.data().iter().all(|value| *value == 0.0));
        assert!(nn.biases_output.data().iter().all(|value| *value == 0.0));
    }

    #[test]
    fn new_uses_xavier_weight_ranges() {
        let mut rng = StdRng::seed_from_u64(7);
        let nn = NeuralNetwork::new(3, 5, 2, Some(&mut rng)).unwrap();
        let limit_input_hidden = (6.0_f64 / 8.0_f64).sqrt();
        let limit_hidden_output = (6.0_f64 / 7.0_f64).sqrt();

        assert!(nn
            .weights_input_hidden
            .data()
            .iter()
            .all(|value| *value >= -limit_input_hidden && *value < limit_input_hidden));
        assert!(nn
            .weights_hidden_output
            .data()
            .iter()
            .all(|value| *value >= -limit_hidden_output && *value < limit_hidden_output));
    }

    #[test]
    fn it_learns_the_xor_function() {
        let mut rng = StdRng::seed_from_u64(99);
        let mut nn = NeuralNetwork::new(2, 4, 1, Some(&mut rng)).unwrap();
        nn.set_learning_rate(0.5);

        train_epochs(&mut nn, &xor_samples(), 20_000);

        assert!(nn.predict(vec![0.0, 0.0]).unwrap()[0] < 0.2);
        assert!(nn.predict(vec![0.0, 1.0]).unwrap()[0] > 0.8);
        assert!(nn.predict(vec![1.0, 0.0]).unwrap()[0] > 0.8);
        assert!(nn.predict(vec![1.0, 1.0]).unwrap()[0] < 0.2);
    }

    #[test]
    fn serde_round_trip_preserves_predictions() {
        let mut rng = StdRng::seed_from_u64(123);
        let mut nn = NeuralNetwork::new(2, 4, 1, Some(&mut rng)).unwrap();
        nn.set_learning_rate(0.5);

        train_epochs(&mut nn, &xor_samples(), 5_000);

        let probe_input = vec![0.25, 0.75];
        let before = nn.predict(probe_input.clone()).unwrap();

        let json = serde_json::to_string(&nn).unwrap();
        let restored: NeuralNetwork = serde_json::from_str(&json).unwrap();
        let after = restored.predict(probe_input).unwrap();

        assert_eq!(before, after);
    }
}