Skip to main content

neural_network_study/
nn.rs

1use crate::matrix::Matrix;
2use rand::{Rng, SeedableRng, rngs::StdRng};
3use serde::{Deserialize, Serialize};
4use std::{error::Error, fmt};
5
6fn sigmoid(x: &mut Matrix) {
7    x.apply(|x| 1.0 / (1.0 + (-x).exp()))
8}
9
10fn sigmoid_derivative(x: &mut Matrix) {
11    x.apply(|x| x * (1.0 - x))
12}
13
14fn tanh(x: &mut Matrix) {
15    x.apply(|x| x.tanh())
16}
17
18fn tanh_derivative(x: &mut Matrix) {
19    x.apply(|x| 1.0 - x.powi(2))
20}
21
22fn linear(_: &mut Matrix) {}
23
24fn linear_derivative(x: &mut Matrix) {
25    x.apply(|_| 1.0)
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
29pub enum ActivationFunction {
30    Sigmoid,
31    Tanh,
32    Linear,
33}
34
35impl Default for ActivationFunction {
36    fn default() -> Self {
37        ActivationFunction::Sigmoid
38    }
39}
40
41impl ActivationFunction {
42    fn apply(&self, x: &mut Matrix) {
43        match self {
44            ActivationFunction::Sigmoid => sigmoid(x),
45            ActivationFunction::Tanh => tanh(x),
46            ActivationFunction::Linear => linear(x),
47        }
48    }
49
50    fn derivative(&self, x: &mut Matrix) {
51        match self {
52            ActivationFunction::Sigmoid => sigmoid_derivative(x),
53            ActivationFunction::Tanh => tanh_derivative(x),
54            ActivationFunction::Linear => linear_derivative(x),
55        }
56    }
57}
58
/// Errors reported when constructing, querying or training a neural network.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NeuralNetworkError {
    /// A layer was requested with a size that cannot be used (zero).
    InvalidLayerSize {
        /// Which layer was rejected: `"input"`, `"hidden"` or `"output"`.
        layer: &'static str,
        /// The rejected size.
        size: usize,
    },
    /// The input vector length does not match the network's input layer.
    InputLengthMismatch {
        /// Length the network expects.
        expected: usize,
        /// Length that was actually supplied.
        got: usize,
    },
    /// The target vector length does not match the network's output layer.
    TargetLengthMismatch {
        /// Length the network expects.
        expected: usize,
        /// Length that was actually supplied.
        got: usize,
    },
}
74
75impl fmt::Display for NeuralNetworkError {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        match self {
78            NeuralNetworkError::InvalidLayerSize { layer, size } => {
79                write!(
80                    f,
81                    "invalid {layer} layer size: expected a positive size, got {size}"
82                )
83            }
84            NeuralNetworkError::InputLengthMismatch { expected, got } => {
85                write!(f, "input length mismatch: expected {expected}, got {got}")
86            }
87            NeuralNetworkError::TargetLengthMismatch { expected, got } => {
88                write!(f, "target length mismatch: expected {expected}, got {got}")
89            }
90        }
91    }
92}
93
94impl Error for NeuralNetworkError {}
95
96/// A simple feedforward neural network with one hidden layer.
97#[derive(Clone, Debug, Default, Serialize, Deserialize)]
98pub struct NeuralNetwork {
99    weights_input_hidden: Matrix,
100    weights_hidden_output: Matrix,
101    biases_hidden: Matrix,
102    biases_output: Matrix,
103    learning_rate: f64,
104    activation_function: ActivationFunction,
105}
106
107impl NeuralNetwork {
108    fn input_size(&self) -> usize {
109        self.weights_input_hidden.cols()
110    }
111
112    fn output_size(&self) -> usize {
113        self.weights_hidden_output.rows()
114    }
115
116    fn validate_input_len(&self, actual: usize) -> Result<(), NeuralNetworkError> {
117        if actual == self.input_size() {
118            Ok(())
119        } else {
120            Err(NeuralNetworkError::InputLengthMismatch {
121                expected: self.input_size(),
122                got: actual,
123            })
124        }
125    }
126
127    fn validate_target_len(&self, actual: usize) -> Result<(), NeuralNetworkError> {
128        if actual == self.output_size() {
129            Ok(())
130        } else {
131            Err(NeuralNetworkError::TargetLengthMismatch {
132                expected: self.output_size(),
133                got: actual,
134            })
135        }
136    }
137
138    /// Creates a new neural network with the given sizes for input, hidden, and output layers.
139    /// The weights use Xavier-style initialization and biases start at zero.
140    pub fn new(
141        input_size: usize,
142        hidden_size: usize,
143        output_size: usize,
144        rng: Option<&mut StdRng>,
145    ) -> Result<Self, NeuralNetworkError> {
146        if input_size == 0 {
147            return Err(NeuralNetworkError::InvalidLayerSize {
148                layer: "input",
149                size: input_size,
150            });
151        }
152        if hidden_size == 0 {
153            return Err(NeuralNetworkError::InvalidLayerSize {
154                layer: "hidden",
155                size: hidden_size,
156            });
157        }
158        if output_size == 0 {
159            return Err(NeuralNetworkError::InvalidLayerSize {
160                layer: "output",
161                size: output_size,
162            });
163        }
164
165        let rng = match rng {
166            Some(rng) => rng,
167            None => &mut StdRng::from_os_rng(),
168        };
169
170        let limit_input_hidden = (6.0 / (input_size + hidden_size) as f64).sqrt();
171        let limit_hidden_output = (6.0 / (hidden_size + output_size) as f64).sqrt();
172
173        Ok(NeuralNetwork {
174            weights_input_hidden: Matrix::random_range(
175                rng,
176                hidden_size,
177                input_size,
178                -limit_input_hidden,
179                limit_input_hidden,
180            ),
181            weights_hidden_output: Matrix::random_range(
182                rng,
183                output_size,
184                hidden_size,
185                -limit_hidden_output,
186                limit_hidden_output,
187            ),
188            biases_hidden: Matrix::new(hidden_size, 1),
189            biases_output: Matrix::new(output_size, 1),
190            learning_rate: 0.01,
191            activation_function: ActivationFunction::default(),
192        })
193    }
194
195    /// Returns the learning rate of the neural network.
196    pub fn learning_rate(&self) -> f64 {
197        self.learning_rate
198    }
199
200    /// Sets the learning rate for the neural network.
201    pub fn set_learning_rate(&mut self, learning_rate: f64) {
202        self.learning_rate = learning_rate;
203    }
204
205    /// Returns the activation function of the neural network.
206    pub fn activation_function(&self) -> &ActivationFunction {
207        &self.activation_function
208    }
209
210    /// Sets the activation function for the neural network.
211    pub fn set_activation_function(&mut self, activation_function: ActivationFunction) {
212        self.activation_function = activation_function;
213    }
214
215    /// Predicts the output for the given input using the neural network.
216    pub fn predict(&self, input: Vec<f64>) -> Result<Vec<f64>, NeuralNetworkError> {
217        self.validate_input_len(input.len())?;
218
219        let input_matrix = Matrix::from_col_vec(input);
220        let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
221        hidden_layer_input += &self.biases_hidden;
222        let mut hidden_layer_output = hidden_layer_input;
223        self.activation_function.apply(&mut hidden_layer_output);
224
225        let output_layer_input =
226            &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
227        let mut output_layer_output = output_layer_input;
228        self.activation_function.apply(&mut output_layer_output);
229
230        Ok(output_layer_output.col(0))
231    }
232
233    /// Trains the neural network using the given input and target output.
234    pub fn train(
235        &mut self,
236        input: Vec<f64>,
237        target: Vec<f64>,
238    ) -> Result<(), NeuralNetworkError> {
239        self.validate_input_len(input.len())?;
240        self.validate_target_len(target.len())?;
241
242        let input_matrix = Matrix::from_col_vec(input);
243        let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
244        hidden_layer_input += &self.biases_hidden;
245        let mut hidden_layer_output = hidden_layer_input;
246        self.activation_function.apply(&mut hidden_layer_output);
247
248        let output_layer_input =
249            &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
250        let mut output_layer_output = output_layer_input;
251        self.activation_function.apply(&mut output_layer_output);
252
253        let target = Matrix::from_col_vec(target);
254
255        let mut output_errors = target;
256        output_errors -= &output_layer_output;
257
258        let mut gradients = output_layer_output;
259        self.activation_function.derivative(&mut gradients);
260        gradients.hadamard_product(&output_errors);
261        gradients *= self.learning_rate;
262
263        let hidden_transposed = hidden_layer_output.transpose();
264        let weight_hidden_output_deltas = &gradients * &hidden_transposed;
265
266        let weight_hidden_output_transposed = self.weights_hidden_output.transpose();
267        let hidden_errors = &weight_hidden_output_transposed * &output_errors;
268
269        self.weights_hidden_output += &weight_hidden_output_deltas;
270        self.biases_output += &gradients;
271
272        let mut hidden_gradient = hidden_layer_output;
273        self.activation_function.derivative(&mut hidden_gradient);
274        hidden_gradient.hadamard_product(&hidden_errors);
275        hidden_gradient *= self.learning_rate;
276
277        let inputs_transposed = input_matrix.transpose();
278        let weight_input_hidden_deltas = &hidden_gradient * &inputs_transposed;
279        self.weights_input_hidden += &weight_input_hidden_deltas;
280        self.biases_hidden += &hidden_gradient;
281
282        Ok(())
283    }
284
285    pub fn mutate(&mut self, rng: &mut StdRng, mutation_rate: f64) {
286        for i in 0..self.weights_input_hidden.rows() {
287            for j in 0..self.weights_input_hidden.cols() {
288                if rng.random::<f64>() < mutation_rate {
289                    self.weights_input_hidden
290                        .set(i, j, rng.random_range(-1.0..1.0));
291                }
292            }
293        }
294        for i in 0..self.weights_hidden_output.rows() {
295            for j in 0..self.weights_hidden_output.cols() {
296                if rng.random::<f64>() < mutation_rate {
297                    self.weights_hidden_output
298                        .set(i, j, rng.random_range(-1.0..1.0));
299                }
300            }
301        }
302        for i in 0..self.biases_hidden.rows() {
303            if rng.random::<f64>() < mutation_rate {
304                self.biases_hidden.set(i, 0, rng.random_range(-1.0..1.0));
305            }
306        }
307        for i in 0..self.biases_output.rows() {
308            if rng.random::<f64>() < mutation_rate {
309                self.biases_output.set(i, 0, rng.random_range(-1.0..1.0));
310            }
311        }
312    }
313}
314
#[cfg(test)]
pub mod nn_tests {
    use super::{NeuralNetwork, NeuralNetworkError};
    use crate::matrix::Matrix;
    use rand::{SeedableRng, rngs::StdRng};

    /// One training example: (input vector, target vector).
    type Sample = (Vec<f64>, Vec<f64>);

    /// Truth table of the logical OR function.
    fn or_samples() -> [Sample; 4] {
        [
            (vec![0.0, 0.0], vec![0.0]),
            (vec![0.0, 1.0], vec![1.0]),
            (vec![1.0, 0.0], vec![1.0]),
            (vec![1.0, 1.0], vec![1.0]),
        ]
    }

    /// Truth table of the logical XOR function.
    fn xor_samples() -> [Sample; 4] {
        [
            (vec![0.0, 0.0], vec![0.0]),
            (vec![0.0, 1.0], vec![1.0]),
            (vec![1.0, 0.0], vec![1.0]),
            (vec![1.0, 1.0], vec![0.0]),
        ]
    }

    /// Presents every sample, in order, `epochs` times.
    fn train_epochs(nn: &mut NeuralNetwork, samples: &[Sample], epochs: usize) {
        for _ in 0..epochs {
            for (input, target) in samples {
                nn.train(input.clone(), target.clone()).unwrap();
            }
        }
    }

    /// Predicts for a two-value input and returns the first output value.
    fn first_output(nn: &NeuralNetwork, input: [f64; 2]) -> f64 {
        nn.predict(input.to_vec()).unwrap()[0]
    }

    #[test]
    fn it_creates_a_neural_network() {
        let m = NeuralNetwork::new(3, 5, 2, None).unwrap();
        assert_eq!(m.weights_input_hidden.rows(), 5);
        assert_eq!(m.input_size(), 3);
        assert_eq!(m.output_size(), 2);
        assert_eq!(m.weights_hidden_output.cols(), 5);
        assert_eq!(m.biases_hidden.rows(), 5);
        assert_eq!(m.biases_hidden.cols(), 1);
        assert_eq!(m.biases_output.rows(), 2);
        assert_eq!(m.biases_output.cols(), 1);
    }

    #[test]
    pub fn it_predicts() {
        // Seeded so the value-dependent assertion below is reproducible.
        let mut rng = StdRng::seed_from_u64(1);
        let m = NeuralNetwork::new(3, 5, 2, Some(&mut rng)).unwrap();
        let output = m.predict(vec![0.5, 0.2, 0.1]).unwrap();
        assert_eq!(output.len(), 2);
        assert_ne!(output[0], output[1]);
    }

    #[test]
    fn it_learns_the_or_function() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut nn = NeuralNetwork::new(2, 4, 1, Some(&mut rng)).unwrap();
        nn.set_learning_rate(0.5);

        train_epochs(&mut nn, &or_samples(), 10_000);

        assert!(first_output(&nn, [0.0, 0.0]) < 0.2);
        assert!(first_output(&nn, [0.0, 1.0]) > 0.8);
        assert!(first_output(&nn, [1.0, 0.0]) > 0.8);
        assert!(first_output(&nn, [1.0, 1.0]) > 0.8);
    }

    #[test]
    fn tanh_derivative_uses_activated_output() {
        let mut x = Matrix::from_col_vec(vec![0.5, -0.25]);
        super::tanh_derivative(&mut x);

        // 1 - 0.5^2 = 0.75 and 1 - (-0.25)^2 = 0.9375.
        assert!((x.get(0, 0) - 0.75).abs() < 1e-12);
        assert!((x.get(1, 0) - 0.9375).abs() < 1e-12);
    }

    #[test]
    fn predict_returns_clear_error_for_wrong_input_size() {
        let nn = NeuralNetwork::new(3, 5, 2, None).unwrap();

        assert_eq!(
            nn.predict(vec![0.1, 0.2]),
            Err(NeuralNetworkError::InputLengthMismatch {
                expected: 3,
                got: 2,
            })
        );
    }

    #[test]
    fn train_returns_clear_error_for_wrong_target_size() {
        let mut nn = NeuralNetwork::new(3, 5, 2, None).unwrap();

        assert_eq!(
            nn.train(vec![0.1, 0.2, 0.3], vec![1.0]),
            Err(NeuralNetworkError::TargetLengthMismatch {
                expected: 2,
                got: 1,
            })
        );
    }

    #[test]
    fn new_rejects_zero_sized_layers() {
        let cases = [
            ((0, 5, 2), "input"),
            ((3, 0, 2), "hidden"),
            ((3, 5, 0), "output"),
        ];

        for ((input_size, hidden_size, output_size), layer) in cases {
            assert_eq!(
                NeuralNetwork::new(input_size, hidden_size, output_size, None).unwrap_err(),
                NeuralNetworkError::InvalidLayerSize { layer, size: 0 }
            );
        }
    }

    #[test]
    fn new_uses_zero_biases() {
        let nn = NeuralNetwork::new(3, 5, 2, None).unwrap();

        assert!(nn.biases_hidden.data().iter().all(|value| *value == 0.0));
        assert!(nn.biases_output.data().iter().all(|value| *value == 0.0));
    }

    #[test]
    fn new_uses_xavier_weight_ranges() {
        let mut rng = StdRng::seed_from_u64(7);
        let nn = NeuralNetwork::new(3, 5, 2, Some(&mut rng)).unwrap();
        // limit = sqrt(6 / (fan_in + fan_out)): 3 + 5 and 5 + 2 respectively.
        let limit_input_hidden = (6.0_f64 / 8.0_f64).sqrt();
        let limit_hidden_output = (6.0_f64 / 7.0_f64).sqrt();

        assert!(nn
            .weights_input_hidden
            .data()
            .iter()
            .all(|value| (-limit_input_hidden..limit_input_hidden).contains(value)));
        assert!(nn
            .weights_hidden_output
            .data()
            .iter()
            .all(|value| (-limit_hidden_output..limit_hidden_output).contains(value)));
    }

    #[test]
    fn it_learns_the_xor_function() {
        let mut rng = StdRng::seed_from_u64(99);
        let mut nn = NeuralNetwork::new(2, 4, 1, Some(&mut rng)).unwrap();
        nn.set_learning_rate(0.5);

        train_epochs(&mut nn, &xor_samples(), 20_000);

        assert!(first_output(&nn, [0.0, 0.0]) < 0.2);
        assert!(first_output(&nn, [0.0, 1.0]) > 0.8);
        assert!(first_output(&nn, [1.0, 0.0]) > 0.8);
        assert!(first_output(&nn, [1.0, 1.0]) < 0.2);
    }

    #[test]
    fn serde_round_trip_preserves_predictions() {
        let mut rng = StdRng::seed_from_u64(123);
        let mut nn = NeuralNetwork::new(2, 4, 1, Some(&mut rng)).unwrap();
        nn.set_learning_rate(0.5);

        train_epochs(&mut nn, &xor_samples(), 5_000);

        let probe_input = vec![0.25, 0.75];
        let before = nn.predict(probe_input.clone()).unwrap();

        let json = serde_json::to_string(&nn).unwrap();
        let restored: NeuralNetwork = serde_json::from_str(&json).unwrap();
        let after = restored.predict(probe_input).unwrap();

        assert_eq!(before, after);
    }
}