// neural_network_study/nn.rs
use crate::matrix::Matrix;
use rand::{Rng, SeedableRng, rngs::StdRng};
use serde::{Deserialize, Serialize};

5fn sigmoid(x: &mut Matrix) {
6 x.apply(|x| 1.0 / (1.0 + (-x).exp()))
7}
8
9fn sigmoid_derivative(x: &mut Matrix) {
10 x.apply(|x| x * (1.0 - x))
11}
12
13fn tanh(x: &mut Matrix) {
14 x.apply(|x| x.tanh())
15}
16
17fn tanh_derivative(x: &mut Matrix) {
18 x.apply(|x| 1.0 - x.tanh().powi(2))
19}
20
/// Identity activation: leaves the matrix untouched.
fn linear(_: &mut Matrix) {}
22
/// Derivative of the identity activation: every element becomes 1.
fn linear_derivative(x: &mut Matrix) {
    x.apply(|_| 1.0)
}
26
/// The non-linearity applied at both the hidden and the output layer.
/// Serializable so a saved network restores with the same activation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ActivationFunction {
    /// Logistic function; squashes values into (0, 1).
    Sigmoid,
    /// Hyperbolic tangent; squashes values into (-1, 1).
    Tanh,
    /// Identity; values pass through unchanged.
    Linear,
}
33
34impl Default for ActivationFunction {
35 fn default() -> Self {
36 ActivationFunction::Sigmoid
37 }
38}
39
40impl ActivationFunction {
41 fn apply(&self, x: &mut Matrix) {
42 match self {
43 ActivationFunction::Sigmoid => sigmoid(x),
44 ActivationFunction::Tanh => tanh(x),
45 ActivationFunction::Linear => linear(x),
46 }
47 }
48
49 fn derivative(&self, x: &mut Matrix) {
50 match self {
51 ActivationFunction::Sigmoid => sigmoid_derivative(x),
52 ActivationFunction::Tanh => tanh_derivative(x),
53 ActivationFunction::Linear => linear_derivative(x),
54 }
55 }
56}
57
/// A feed-forward neural network with a single hidden layer, trained by
/// per-example gradient descent (`train`) or evolved via `mutate`.
/// Serializable so trained networks can be persisted and restored.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NeuralNetwork {
    // (hidden_size x input_size); `weights * column_input` feeds the hidden layer.
    weights_input_hidden: Matrix,
    // (output_size x hidden_size); maps hidden activations to the output layer.
    weights_hidden_output: Matrix,
    // (hidden_size x 1) column vector added before the hidden activation.
    biases_hidden: Matrix,
    // (output_size x 1) column vector added before the output activation.
    biases_output: Matrix,
    // Step size used by `train`; `new` initialises it to 0.01.
    learning_rate: f64,
    // Non-linearity used at both layers.
    activation_function: ActivationFunction,
}
68
69impl NeuralNetwork {
70 pub fn new(
73 input_size: usize,
74 hidden_size: usize,
75 output_size: usize,
76 rng: Option<&mut StdRng>,
77 ) -> Self {
78 let rng = match rng {
79 Some(rng) => rng,
80 None => &mut StdRng::from_os_rng(),
81 };
82 NeuralNetwork {
83 weights_input_hidden: Matrix::random(rng, hidden_size, input_size),
84 weights_hidden_output: Matrix::random(rng, output_size, hidden_size),
85 biases_hidden: Matrix::random(rng, hidden_size, 1),
86 biases_output: Matrix::random(rng, output_size, 1),
87 learning_rate: 0.01,
88 activation_function: ActivationFunction::default(),
89 }
90 }
91
92 pub fn learning_rate(&self) -> f64 {
94 self.learning_rate
95 }
96
97 pub fn set_learning_rate(&mut self, learning_rate: f64) {
99 self.learning_rate = learning_rate;
100 }
101
102 pub fn activation_function(&self) -> &ActivationFunction {
104 &self.activation_function
105 }
106
107 pub fn set_activation_function(&mut self, activation_function: ActivationFunction) {
109 self.activation_function = activation_function;
110 }
111
112 pub fn predict(&self, input: Vec<f64>) -> Vec<f64> {
114 let input_matrix = Matrix::from_col_vec(input);
116 let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
117 hidden_layer_input += &self.biases_hidden;
118 let mut hidden_layer_output = hidden_layer_input;
119 self.activation_function.apply(&mut hidden_layer_output);
120 let output_layer_input =
122 &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
123 let mut output_layer_output = output_layer_input;
124 self.activation_function.apply(&mut output_layer_output);
125 output_layer_output.col(0)
127 }
128
129 pub fn train(&mut self, input: Vec<f64>, target: Vec<f64>) {
133 let input_matrix = Matrix::from_col_vec(input);
135 let mut hidden_layer_input = &self.weights_input_hidden * &input_matrix;
136 hidden_layer_input += &self.biases_hidden;
137 let mut hidden_layer_output = hidden_layer_input;
138 self.activation_function.apply(&mut hidden_layer_output);
139 let output_layer_input =
141 &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
142 let mut output_layer_output = output_layer_input;
143 self.activation_function.apply(&mut output_layer_output);
144
145 let target = Matrix::from_col_vec(target);
147
148 let mut output_errors = target;
151 output_errors -= &output_layer_output;
152
153 let mut gradients = output_layer_output;
155 self.activation_function.derivative(&mut gradients);
156 gradients.hadamard_product(&output_errors);
157 gradients *= self.learning_rate;
158
159 let hidden_transposed = hidden_layer_output.transpose();
161 let weight_hidden_output_deltas = &gradients * &hidden_transposed;
162
163 self.weights_hidden_output += &weight_hidden_output_deltas;
165 self.biases_output += &gradients;
167
168 let weight_hidden_output_transposed = self.weights_hidden_output.transpose();
170 let hidden_errors = &weight_hidden_output_transposed * &output_errors;
171
172 let mut hidden_gradient = hidden_layer_output;
174 self.activation_function.derivative(&mut hidden_gradient);
175 hidden_gradient.hadamard_product(&hidden_errors);
176 hidden_gradient *= self.learning_rate;
177
178 let inputs_transposed = input_matrix.transpose();
180 let weight_input_hidden_deltas = &hidden_gradient * &inputs_transposed;
181 self.weights_input_hidden += &weight_input_hidden_deltas;
182 self.biases_hidden += &hidden_gradient;
184 }
185
186 pub fn mutate(&mut self, rng: &mut StdRng, mutation_rate: f64) {
187 for i in 0..self.weights_input_hidden.rows() {
188 for j in 0..self.weights_input_hidden.cols() {
189 if rng.random::<f64>() < mutation_rate {
190 self.weights_input_hidden
191 .set(i, j, rng.random_range(-1.0..1.0));
192 }
193 }
194 }
195 for i in 0..self.weights_hidden_output.rows() {
196 for j in 0..self.weights_hidden_output.cols() {
197 if rng.random::<f64>() < mutation_rate {
198 self.weights_hidden_output
199 .set(i, j, rng.random_range(-1.0..1.0));
200 }
201 }
202 }
203 for i in 0..self.biases_hidden.rows() {
204 if rng.random::<f64>() < mutation_rate {
205 self.biases_hidden.set(i, 0, rng.random_range(-1.0..1.0));
206 }
207 }
208 for i in 0..self.biases_output.rows() {
209 if rng.random::<f64>() < mutation_rate {
210 self.biases_output.set(i, 0, rng.random_range(-1.0..1.0));
211 }
212 }
213 }
214}
215
// Bug fix: the test module was missing `#[cfg(test)]`, so it (and its
// `#[test]` shims) were compiled into every non-test build of the crate.
#[cfg(test)]
pub mod nn_tests {
    /// `new(3, 5, 2, _)` must allocate (hidden x input) and (output x hidden)
    /// weight matrices and column-vector biases of matching heights.
    #[test]
    fn it_creates_a_neural_network() {
        let m = super::NeuralNetwork::new(3, 5, 2, None);
        assert_eq!(m.weights_input_hidden.rows(), 5);
        assert_eq!(m.weights_input_hidden.cols(), 3);
        assert_eq!(m.weights_hidden_output.rows(), 2);
        assert_eq!(m.weights_hidden_output.cols(), 5);
        assert_eq!(m.biases_hidden.rows(), 5);
        assert_eq!(m.biases_hidden.cols(), 1);
        assert_eq!(m.biases_output.rows(), 2);
        assert_eq!(m.biases_output.cols(), 1);
    }

    /// A forward pass over a 3-input / 2-output network yields two values;
    /// with random weights the two outputs should (almost surely) differ.
    #[test]
    pub fn it_predicts() {
        let m = super::NeuralNetwork::new(3, 5, 2, None);
        let input = vec![0.5, 0.2, 0.1];
        let output = m.predict(input.clone());
        assert_eq!(output.len(), 2);
        assert_ne!(output[0], output[1]);
    }
}