neural_network_study/
nn.rs

1use crate::matrix::Matrix;
2use rand::{Rng, SeedableRng, rngs::StdRng};
3use serde::{Deserialize, Serialize};
4
5fn sigmoid(x: &mut Matrix) {
6    x.apply(|x| 1.0 / (1.0 + (-x).exp()))
7}
8
9fn sigmoid_derivative(x: &mut Matrix) {
10    x.apply(|x| x * (1.0 - x))
11}
12
13fn tanh(x: &mut Matrix) {
14    x.apply(|x| x.tanh())
15}
16
17fn tanh_derivative(x: &mut Matrix) {
18    x.apply(|x| 1.0 - x.tanh().powi(2))
19}
20
21fn linear(_: &mut Matrix) {}
22
23fn linear_derivative(x: &mut Matrix) {
24    x.apply(|_| 1.0)
25}
26
27#[derive(Clone, Debug, Serialize, Deserialize)]
28pub enum ActivationFunction {
29    Sigmoid,
30    Tanh,
31    Linear,
32}
33
34impl Default for ActivationFunction {
35    fn default() -> Self {
36        ActivationFunction::Sigmoid
37    }
38}
39
40impl ActivationFunction {
41    fn apply(&self, x: &mut Matrix) {
42        match self {
43            ActivationFunction::Sigmoid => sigmoid(x),
44            ActivationFunction::Tanh => tanh(x),
45            ActivationFunction::Linear => linear(x),
46        }
47    }
48
49    fn derivative(&self, x: &mut Matrix) {
50        match self {
51            ActivationFunction::Sigmoid => sigmoid_derivative(x),
52            ActivationFunction::Tanh => tanh_derivative(x),
53            ActivationFunction::Linear => linear_derivative(x),
54        }
55    }
56}
57
58/// A simple feedforward neural network with one hidden layer.
59#[derive(Clone, Debug, Default, Serialize, Deserialize)]
60pub struct NeuralNetwork {
61    weights_input_hidden: Matrix,
62    weights_hidden_output: Matrix,
63    biases_hidden: Matrix,
64    biases_output: Matrix,
65    learning_rate: f64,
66    activation_function: ActivationFunction,
67}
68
69impl NeuralNetwork {
70    /// Creates a new neural network with the given sizes for input, hidden, and output layers.
71    /// The weights and biases are initialized randomly.
72    pub fn new(
73        input_size: usize,
74        hidden_size: usize,
75        output_size: usize,
76        rng: Option<&mut StdRng>,
77    ) -> Self {
78        let rng = match rng {
79            Some(rng) => rng,
80            None => &mut StdRng::from_os_rng(),
81        };
82        NeuralNetwork {
83            weights_input_hidden: Matrix::random(rng, hidden_size, input_size),
84            weights_hidden_output: Matrix::random(rng, output_size, hidden_size),
85            biases_hidden: Matrix::random(rng, hidden_size, 1),
86            biases_output: Matrix::random(rng, output_size, 1),
87            learning_rate: 0.01,
88            activation_function: ActivationFunction::default(),
89        }
90    }
91
92    /// Returns the learning rate of the neural network.
93    pub fn learning_rate(&self) -> f64 {
94        self.learning_rate
95    }
96
97    /// Sets the learning rate for the neural network.
98    pub fn set_learning_rate(&mut self, learning_rate: f64) {
99        self.learning_rate = learning_rate;
100    }
101
102    /// Returns the activation function of the neural network.
103    pub fn activation_function(&self) -> &ActivationFunction {
104        &self.activation_function
105    }
106
107    /// Sets the activation function for the neural network.
108    pub fn set_activation_function(&mut self, activation_function: ActivationFunction) {
109        self.activation_function = activation_function;
110    }
111
112    /// Predicts the output for the given input using the neural network.
113    pub fn predict(&self, input: Vec<f64>) -> Vec<f64> {
114        // Generate the hidden outputs
115        let input_matrix = Matrix::from_col_vec(input);
116        let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
117        hidden_layer_input += &self.biases_hidden;
118        let mut hidden_layer_output = hidden_layer_input;
119        self.activation_function.apply(&mut hidden_layer_output);
120        // Generate the output's output
121        let output_layer_input =
122            &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
123        let mut output_layer_output = output_layer_input;
124        self.activation_function.apply(&mut output_layer_output);
125        // Return the output as a vector
126        output_layer_output.col(0)
127    }
128
129    /// Trains the neural network using the given input and target output.
130    /// The input and target should be vectors of the same length as the input and output sizes of the network.
131    /// The training process involves forward propagation and backpropagation to adjust the weights and biases.
132    pub fn train(&mut self, input: Vec<f64>, target: Vec<f64>) {
133        // Generate the hidden outputs
134        let input_matrix = Matrix::from_col_vec(input);
135        let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
136        hidden_layer_input += &self.biases_hidden;
137        let mut hidden_layer_output = hidden_layer_input;
138        self.activation_function.apply(&mut hidden_layer_output);
139        // Generate the output's outputs
140        let output_layer_input =
141            &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
142        let mut output_layer_output = output_layer_input;
143        self.activation_function.apply(&mut output_layer_output);
144
145        // Create target matrix
146        let target = Matrix::from_col_vec(target);
147
148        // Calculate the error
149        // ERROR = TARGET - OUTPUT
150        let mut output_errors = target;
151        output_errors -= &output_layer_output;
152
153        // Calculate gradients
154        let mut gradients = output_layer_output;
155        self.activation_function.derivative(&mut gradients);
156        gradients.hadamard_product(&output_errors);
157        gradients *= self.learning_rate;
158
159        // Calculcate deltas
160        let hidden_transposed = hidden_layer_output.transpose();
161        let weight_hidden_output_deltas = &gradients * &hidden_transposed;
162
163        // Adjust the weights by deltas
164        self.weights_hidden_output += &weight_hidden_output_deltas;
165        // Adjust the bias by its deltas (which is just the gradients)
166        self.biases_output += &gradients;
167
168        // Calculate the hidden layer errors
169        let weight_hidden_output_transposed = self.weights_hidden_output.transpose();
170        let hidden_errors = &weight_hidden_output_transposed * &output_errors;
171
172        // Calculate hidden gradients
173        let mut hidden_gradient = hidden_layer_output;
174        self.activation_function.derivative(&mut hidden_gradient);
175        hidden_gradient.hadamard_product(&hidden_errors);
176        hidden_gradient *= self.learning_rate;
177
178        // Calculate input -> hidden deltas
179        let inputs_transposed = input_matrix.transpose();
180        let weight_input_hidden_deltas = &hidden_gradient * &inputs_transposed;
181        self.weights_input_hidden += &weight_input_hidden_deltas;
182        // Adjust the bias by its deltas (which is just the gradient)
183        self.biases_hidden += &hidden_gradient;
184    }
185
186    pub fn mutate(&mut self, rng: &mut StdRng, mutation_rate: f64) {
187        for i in 0..self.weights_input_hidden.rows() {
188            for j in 0..self.weights_input_hidden.cols() {
189                if rng.random::<f64>() < mutation_rate {
190                    self.weights_input_hidden
191                        .set(i, j, rng.random_range(-1.0..1.0));
192                }
193            }
194        }
195        for i in 0..self.weights_hidden_output.rows() {
196            for j in 0..self.weights_hidden_output.cols() {
197                if rng.random::<f64>() < mutation_rate {
198                    self.weights_hidden_output
199                        .set(i, j, rng.random_range(-1.0..1.0));
200                }
201            }
202        }
203        for i in 0..self.biases_hidden.rows() {
204            if rng.random::<f64>() < mutation_rate {
205                self.biases_hidden.set(i, 0, rng.random_range(-1.0..1.0));
206            }
207        }
208        for i in 0..self.biases_output.rows() {
209            if rng.random::<f64>() < mutation_rate {
210                self.biases_output.set(i, 0, rng.random_range(-1.0..1.0));
211            }
212        }
213    }
214}
215
/// Unit tests for the neural network.
///
/// FIX: gated behind `#[cfg(test)]` — previously this public module (and its
/// private-field accesses) was compiled into every build and exposed in the
/// crate's API surface; test modules should only exist in test builds.
#[cfg(test)]
pub mod nn_tests {
    #[test]
    fn it_creates_a_neural_network() {
        let m = super::NeuralNetwork::new(3, 5, 2, None);
        // Weight shapes are (destination layer, source layer).
        assert_eq!(m.weights_input_hidden.rows(), 5);
        assert_eq!(m.weights_input_hidden.cols(), 3);
        assert_eq!(m.weights_hidden_output.rows(), 2);
        assert_eq!(m.weights_hidden_output.cols(), 5);
        // Biases are column vectors.
        assert_eq!(m.biases_hidden.rows(), 5);
        assert_eq!(m.biases_hidden.cols(), 1);
        assert_eq!(m.biases_output.rows(), 2);
        assert_eq!(m.biases_output.cols(), 1);
    }

    #[test]
    pub fn it_predicts() {
        let m = super::NeuralNetwork::new(3, 5, 2, None);
        let input = vec![0.5, 0.2, 0.1];
        let output = m.predict(input.clone());
        assert_eq!(output.len(), 2);
        // With random weights the two outputs are distinct (collision is
        // astronomically unlikely for f64 draws).
        assert_ne!(output[0], output[1]);
    }
}