// neural_network_study/lib.rs

1use rand::Rng;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub};
3
/// A dense, row-major 2-dimensional matrix of `f64` values with basic operations.
///
/// The constructors maintain the invariant that `data` holds exactly `rows`
/// rows, each of length `cols`. Mutating the storage through `data_mut` can
/// break that invariant, so callers must preserve the shape themselves.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row-major storage, indexed as `data[row][col]`.
    data: Vec<Vec<f64>>,
}
11
12impl Matrix {
13    /// Creates a new matrix with the given number of rows and columns,
14    /// initialized to zero.
15    pub fn new(rows: usize, cols: usize) -> Self {
16        let data = vec![vec![0.0; cols]; rows];
17        Self { rows, cols, data }
18    }
19
20    /// Creates a new matrix with the given number of rows and columns,
21    /// initialized with random values between -1.0 and 1.0.
22    pub fn random(rows: usize, cols: usize) -> Self {
23        let mut rng = rand::r#rng();
24        let data = (0..rows)
25            .map(|_| (0..cols).map(|_| rng.random_range(-1.0..1.0)).collect())
26            .collect();
27        Self { rows, cols, data }
28    }
29
30    /// Creates a new matrix from a 2D vector.
31    /// The outer vector represents the rows, and the inner vectors represent the columns.
32    /// Panics if the inner vectors have different lengths.
33    pub fn from_vec(data: Vec<Vec<f64>>) -> Self {
34        let rows = data.len();
35        let cols = if rows > 0 { data[0].len() } else { 0 };
36        for row in &data {
37            if row.len() != cols {
38                panic!("All rows must have the same number of columns");
39            }
40        }
41        Self { rows, cols, data }
42    }
43
44    /// Creates a new matrix from a column vector.
45    pub fn from_col_vec(data: Vec<f64>) -> Self {
46        let rows = data.len();
47        let cols = 1;
48        let data = data.into_iter().map(|x| vec![x]).collect();
49        Self { rows, cols, data }
50    }
51
52    /// Transposes the matrix.
53    pub fn transpose(&self) -> Self {
54        let mut transposed_data = vec![vec![0.0; self.rows]; self.cols];
55        for i in 0..self.rows {
56            for j in 0..self.cols {
57                transposed_data[j][i] = self.data[i][j];
58            }
59        }   
60        Self::from_vec(transposed_data)
61    }
62
63    /// Returns the number of rows in the matrix.
64    pub fn rows(&self) -> usize {
65        self.rows
66    }
67
68    /// Returns the number of columns in the matrix.
69    pub fn cols(&self) -> usize {
70        self.cols
71    }
72
73    /// Returns the column at the given index as a vector.
74    /// Panics if the index is out of bounds.
75    pub fn col(&self, index: usize) -> Vec<f64> {
76        if index >= self.cols {
77            panic!("Index out of bounds");
78        }
79        (0..self.rows).map(|i| self.data[i][index]).collect()
80    }
81
82    /// Returns a reference to the data in the matrix.
83    pub fn data(&self) -> &Vec<Vec<f64>> {
84        &self.data
85    }
86
87    /// Returns a mutable reference to the data in the matrix.
88    pub fn data_mut(&mut self) -> &mut Vec<Vec<f64>> {
89        &mut self.data
90    }
91
92    /// Returns the value at the given row and column.
93    /// Panics if the indices are out of bounds.
94    pub fn get(&self, row: usize, col: usize) -> f64 {
95        if row >= self.rows || col >= self.cols {
96            panic!("Index out of bounds");
97        }
98        self.data[row][col]
99    }
100
101    /// Returns a mutable reference to the value at the given row and column.
102    /// Panics if the indices are out of bounds.
103    pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
104        if row >= self.rows || col >= self.cols {
105            panic!("Index out of bounds");
106        }
107        &mut self.data[row][col]
108    }
109
110    /// Sets the value at the given row and column.
111    /// Panics if the indices are out of bounds.
112    pub fn set(&mut self, row: usize, col: usize, value: f64) {
113        if row >= self.rows || col >= self.cols {
114            panic!("Index out of bounds");
115        }
116        self.data[row][col] = value;
117    }
118
119    /// Returns the matrix resulting from 
120    /// applying the function `f` to each element of the matrix.
121    pub fn map<F>(&self, f: F) -> Matrix
122    where
123        F: Fn(f64) -> f64,
124    {
125        let mut result = Matrix::new(self.rows, self.cols);
126        for i in 0..self.rows {
127            for j in 0..self.cols {
128                result.set(i, j, f(self.get(i, j)));
129            }
130        }
131        result
132    }
133
134    /// Applies the function `f` to each element of the matrix in place.
135    /// This is an in-place operation.
136    pub fn map_mut<F>(&mut self, f: F)
137    where
138        F: Fn(f64) -> f64,
139    {
140        for i in 0..self.rows {
141            for j in 0..self.cols {
142                self.set(i, j, f(self.get(i, j)));
143            }
144        }
145    }
146
147    pub fn hadamar_product(&mut self, other: &Matrix) {
148        if self.rows != other.rows || self.cols != other.cols {
149            panic!("Matrices must have the same dimensions for Hadamard product");
150        }
151        for i in 0..self.rows {
152            for j in 0..self.cols {
153                self.set(i, j, self.get(i, j) * other.get(i, j));
154            }
155        }
156    }
157}
158
159impl Add<&Matrix> for Matrix {
160    type Output = Matrix;
161
162    /// Adds two matrices together, component-wise.
163    /// Panics if the matrices have different dimensions.
164    fn add(self, other: &Matrix) -> Matrix {
165        if self.rows != other.rows || self.cols != other.cols {
166            panic!("Matrices must have the same dimensions to be added");
167        }
168        let mut result = Matrix::new(self.rows, self.cols);
169        for i in 0..self.rows {
170            for j in 0..self.cols {
171                result.set(i, j, self.get(i, j) + other.get(i, j));
172            }
173        }
174        result
175    }
176}
177
178impl AddAssign<&Matrix> for Matrix {
179    /// Adds another matrix to this matrix, component-wise.
180    /// Panics if the matrices have different dimensions.
181    /// This is an in-place operation.
182    fn add_assign(&mut self, other: &Matrix) {
183        if self.rows != other.rows || self.cols != other.cols {
184            panic!("Matrices must have the same dimensions to be added");
185        }
186        for i in 0..self.rows {
187            for j in 0..self.cols {
188                self.set(i, j, self.get(i, j) + other.get(i, j));
189            }
190        }
191    }
192}
193
194impl Sub<&Matrix> for Matrix  {
195    type Output = Matrix;
196
197    /// Subtracts another matrix from this matrix, component-wise.
198    /// Panics if the matrices have different dimensions.
199    fn sub(self, rhs: &Matrix) -> Self::Output {
200        if self.rows != rhs.rows || self.cols != rhs.cols {
201            panic!("Matrices must have the same dimensions to be subtracted");
202        }
203        let mut result = Matrix::new(self.rows, self.cols);
204        for i in 0..self.rows {
205            for j in 0..self.cols {
206                result.set(i, j, self.get(i, j) - rhs.get(i, j));
207            }
208        }
209        result
210    }
211}
212
213impl Mul<f64> for Matrix {
214    type Output = Matrix;
215
216    /// Multiplies the matrix by a scalar.
217    fn mul(self, scalar: f64) -> Matrix {
218        let mut result = Matrix::new(self.rows, self.cols);
219        for i in 0..self.rows {
220            for j in 0..self.cols {
221                result.set(i, j, self.get(i, j) * scalar);
222            }
223        }
224        result
225    }
226}
227
228impl MulAssign<f64> for Matrix {
229    /// Multiplies the matrix by a scalar in-place.
230    fn mul_assign(&mut self, scalar: f64) {
231        for i in 0..self.rows {
232            for j in 0..self.cols {
233                self.set(i, j, self.get(i, j) * scalar);
234            }
235        }
236    }
237}
238
239impl Mul<&Matrix> for &Matrix {
240    type Output = Matrix;
241
242    /// Multiplies two matrices together.
243    /// Panics if the matrices have incompatible dimensions.
244    fn mul(self, other: &Matrix) -> Matrix {
245        if self.cols != other.rows {
246            panic!("Matrices have incompatible dimensions for multiplication");
247        }
248        let mut result = Matrix::new(self.rows, other.cols);
249        for i in 0..self.rows {
250            for j in 0..other.cols {
251                let mut sum = 0.0;
252                for k in 0..self.cols {
253                    sum += self.get(i, k) * other.get(k, j);
254                }
255                result.set(i, j, sum);
256            }
257        }
258        result
259    }
260}
261
#[cfg(test)]
mod matrix_tests {
    use super::*;

    #[test]
    fn it_creates_a_zero_matrix() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data(), &vec![vec![0.0; 3]; 2]);
    }

    #[test]
    fn it_creates_random_matrix() {
        let m = Matrix::random(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 2);
        assert!(m.data().iter().all(|row| row.len() == 3));
        assert!(m.data().iter().flatten().all(|v| (-1.0..=1.0).contains(v)));
    }

    #[test]
    fn it_creates_a_matrix_from_a_vector() {
        let v = vec![vec![1.0, 2.0, 5.0], vec![3.0, 4.0, 6.0]];
        let m = Matrix::from_vec(v.clone());
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data(), &v);
    }

    #[test]
    #[should_panic(expected = "All rows must have the same number of columns")]
    fn it_panics_on_ragged_rows() {
        let _ = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0]]);
    }

    #[test]
    fn it_creates_a_column_matrix() {
        let m = Matrix::from_col_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 1);
        assert_eq!(m.col(0), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn it_transposes_matrix() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0, 5.0], vec![3.0, 4.0, 6.0]]);
        let transposed = m.transpose();
        assert_eq!(transposed.rows(), 3);
        assert_eq!(transposed.cols(), 2);
        assert_eq!(
            transposed.data(),
            &vec![vec![1.0, 3.0], vec![2.0, 4.0], vec![5.0, 6.0]]
        );
    }

    #[test]
    fn it_returns_a_column() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0, 5.0], vec![3.0, 4.0, 6.0]]);
        assert_eq!(m.col(1), vec![2.0, 4.0]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col() {
        let _ = Matrix::new(2, 3).col(3);
    }

    #[test]
    fn it_gets_and_sets_values() {
        let mut m = Matrix::new(2, 3);
        m.set(0, 0, 1.0);
        m.set(1, 2, 2.0);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get() {
        let m = Matrix::new(2, 3);
        m.get(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set() {
        let mut m = Matrix::new(2, 3);
        m.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set_mut() {
        // Writes through `get_mut` with an out-of-range *column*, which the
        // row-based `get_mut` test above does not exercise.
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 3) = 1.0;
    }

    #[test]
    fn it_gets_and_sets_mutable_values() {
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 0) = 1.0;
        *m.get_mut(1, 2) = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_returns_mutable_data() {
        let mut m = Matrix::new(2, 3);
        m.data_mut()[0][0] = 1.0;
        m.data_mut()[1][2] = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_adds_matrices() {
        let m1 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let m2 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        let result = m1 + &m2;
        assert_eq!(result.data(), &vec![vec![6.0, 8.0], vec![10.0, 12.0]]);
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions to be added")]
    fn it_panics_when_adding_mismatched_matrices() {
        let _ = Matrix::new(2, 2) + &Matrix::new(2, 3);
    }

    #[test]
    fn it_adds_and_assigns() {
        let mut m1 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let m2 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        m1 += &m2;
        assert_eq!(m1.data(), &vec![vec![6.0, 8.0], vec![10.0, 12.0]]);
    }

    #[test]
    fn it_subtracts_matrices() {
        let m1 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        let m2 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = m1 - &m2;
        assert_eq!(result.data(), &vec![vec![4.0, 4.0], vec![4.0, 4.0]]);
    }

    #[test]
    fn it_multiplies_by_scalar() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = m * 2.0;
        assert_eq!(result.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }

    #[test]
    fn it_multiplies_by_scalar_in_place() {
        let mut m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        m *= 2.0;
        assert_eq!(m.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }

    #[test]
    fn it_multiplies_matrices() {
        let a = Matrix::from_vec(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let b = Matrix::from_vec(vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]]);
        let product = &a * &b;
        assert_eq!(product.rows(), 2);
        assert_eq!(product.cols(), 2);
        assert_eq!(product.data(), &vec![vec![58.0, 64.0], vec![139.0, 154.0]]);
    }

    #[test]
    #[should_panic(expected = "Matrices have incompatible dimensions for multiplication")]
    fn it_panics_on_incompatible_multiplication() {
        let _ = &Matrix::new(2, 3) * &Matrix::new(2, 3);
    }

    #[test]
    fn it_computes_the_hadamard_product() {
        let mut m1 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let m2 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        m1.hadamar_product(&m2);
        assert_eq!(m1.data(), &vec![vec![5.0, 12.0], vec![21.0, 32.0]]);
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions for Hadamard product")]
    fn it_panics_on_mismatched_hadamard_product() {
        let mut m = Matrix::new(2, 2);
        m.hadamar_product(&Matrix::new(3, 2));
    }

    #[test]
    fn it_maps() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = m.map(|x| x * 2.0);
        assert_eq!(result.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
        // `map` must leave the source matrix untouched.
        assert_eq!(m.data(), &vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    }

    #[test]
    fn it_maps_mut() {
        let mut m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        m.map_mut(|x| x * 2.0);
        assert_eq!(m.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }
}
443
444fn sigmoid(x: &Matrix) -> Matrix {
445    x.map(|x| 1.0 / (1.0 + (-x).exp()))
446}
447
448fn sigmoid_derivative(x: &Matrix) -> Matrix {
449    x.map(|x| x * (1.0 - x))
450}
451
452fn tanh(x: &Matrix) -> Matrix {
453    x.map(|x| x.tanh())
454}
455
456fn tanh_derivative(x: &Matrix) -> Matrix {
457    x.map(|x| 1.0 - x.tanh().powi(2))
458}
459
460fn linear(x: &Matrix) -> Matrix {
461    x.clone()
462}
463
464fn linear_derivative(x: &Matrix) -> Matrix {
465    x.map(|_| 1.0)
466}
467
468/// A simple feedforward neural network with one hidden layer.
469#[derive(Clone, Debug)]
470pub struct NeuralNetwork {
471    weights_input_hidden: Matrix,
472    weights_hidden_output: Matrix,
473    biases_hidden: Matrix,
474    biases_output: Matrix,
475    learning_rate: f64,
476    activation_function: fn(&Matrix) -> Matrix,
477    activation_function_derivative: fn(&Matrix) -> Matrix,
478}
479
480impl NeuralNetwork {
481    /// Creates a new neural network with the given sizes for input, hidden, and output layers.
482    /// The weights and biases are initialized randomly.
483    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
484        NeuralNetwork {
485            weights_input_hidden: Matrix::random(hidden_size, input_size),
486            weights_hidden_output: Matrix::random(output_size, hidden_size),
487            biases_hidden: Matrix::random(hidden_size, 1),
488            biases_output: Matrix::random(output_size, 1),
489            learning_rate: 0.01,
490            activation_function: sigmoid,
491            activation_function_derivative: sigmoid_derivative,
492        }
493    }
494
495    /// Sets the learning rate for the neural network.
496    pub fn set_learning_rate(&mut self, learning_rate: f64) {
497        self.learning_rate = learning_rate;
498    }
499
500    /// Sets the activation function for the neural network.
501    pub fn set_activation_function(
502        &mut self,
503        activation_function: fn(&Matrix) -> Matrix,
504        activation_function_derivative: fn(&Matrix) -> Matrix,
505    ) {
506        self.activation_function = activation_function;
507        self.activation_function_derivative = activation_function_derivative;
508    }
509
510    pub fn set_linear_activation(&mut self) {
511        self.activation_function = linear;
512        self.activation_function_derivative = linear_derivative;
513    }
514
515    pub fn set_sigmoid_activation(&mut self) {
516        self.activation_function = sigmoid;
517        self.activation_function_derivative = sigmoid_derivative;
518    }
519
520    pub fn set_tanh_activation(&mut self) {
521        self.activation_function = tanh;
522        self.activation_function_derivative = tanh_derivative;
523    }
524
525    /// Predicts the output for the given input using the neural network.
526    pub fn predict(&self, input: Vec<f64>) -> Vec<f64> {
527        // Generate the hidden outputs
528        let input_matrix = Matrix::from_col_vec(input);
529        let hidden_layer_input = &self.weights_input_hidden * &input_matrix + &self.biases_hidden;
530        let hidden_layer_output = (self.activation_function)(&hidden_layer_input);
531        // Generate the output's output
532        let output_layer_input = &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
533        let output_layer_output = (self.activation_function)(&output_layer_input);
534        // Return the output as a vector
535        output_layer_output.col(0)
536    }
537
538    /// Trains the neural network using the given input and target output.
539    /// The input and target should be vectors of the same length as the input and output sizes of the network.
540    /// The training process involves forward propagation and backpropagation to adjust the weights and biases.
541    pub fn train(&mut self, input: Vec<f64>, target: Vec<f64>) {
542        // Generate the hidden outputs
543        let input = Matrix::from_col_vec(input);
544        let hidden_layer_input = &self.weights_input_hidden * &input + &self.biases_hidden;
545        let hidden_layer_output = (self.activation_function)(&hidden_layer_input);
546
547        // Generate the output's outputs
548        let output_layer_input = &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
549        let output_layer_output = (self.activation_function)(&output_layer_input);
550        
551        // Create target matrix
552        let target = Matrix::from_col_vec(target);
553
554        // Calculate the error
555        // ERROR = TARGET - OUTPUT
556        let output_errors = target - &output_layer_output;
557
558        // Calculate gradients
559        let mut gradients = (self.activation_function_derivative)(&output_layer_output);
560        gradients.hadamar_product(&output_errors);
561        gradients *= self.learning_rate;
562
563        // Calculcate deltas
564        let hidden_transposed = hidden_layer_output.transpose();
565        let weight_hidden_output_deltas = &gradients * &hidden_transposed;
566
567        // Adjust the weights by deltas
568        self.weights_hidden_output += &weight_hidden_output_deltas;
569        // Adjust the bias by its deltas (which is just the gradients)
570        self.biases_output += &gradients;
571
572        // Calculate the hidden layer errors
573        let weight_hidden_output_transposed = self.weights_hidden_output.transpose();
574        let hidden_errors = &weight_hidden_output_transposed * &output_errors;
575
576        // Calculate hidden gradients
577        let mut hidden_gradient = (self.activation_function_derivative)(&hidden_layer_output);
578        hidden_gradient.hadamar_product(&hidden_errors);
579        hidden_gradient *= self.learning_rate;
580
581        // Calculate input -> hidden deltas
582        let inputs_transposed = input.transpose();
583        let weight_input_hidden_deltas = &hidden_gradient * &inputs_transposed;
584
585        self.weights_input_hidden += &weight_input_hidden_deltas;
586        // Adjust the bias by its deltas (which is just the gradient)
587        self.biases_hidden += &hidden_gradient;
588    }
589}
590
#[cfg(test)]
mod nn_tests {
    use super::*;

    #[test]
    fn it_creates_a_neural_network() {
        let nn = NeuralNetwork::new(3, 5, 2);
        assert_eq!(nn.weights_input_hidden.rows(), 5);
        assert_eq!(nn.weights_input_hidden.cols(), 3);
        assert_eq!(nn.weights_hidden_output.rows(), 2);
        assert_eq!(nn.weights_hidden_output.cols(), 5);
        assert_eq!(nn.biases_hidden.rows(), 5);
        assert_eq!(nn.biases_hidden.cols(), 1);
        assert_eq!(nn.biases_output.rows(), 2);
        assert_eq!(nn.biases_output.cols(), 1);
    }

    #[test]
    fn it_predicts() {
        let nn = NeuralNetwork::new(3, 5, 2);
        let output = nn.predict(vec![0.5, 0.2, 0.1]);
        assert_eq!(output.len(), 2);
        // The default activation is sigmoid, so every output lies in [0, 1].
        assert!(output.iter().all(|&y| (0.0..=1.0).contains(&y)));
    }

    #[test]
    fn it_reduces_error_when_training_on_a_single_sample() {
        let mut nn = NeuralNetwork::new(3, 5, 2);
        nn.set_learning_rate(0.1);
        let input = vec![0.5, 0.2, 0.1];
        let target = vec![0.9, 0.1];
        let squared_error = |net: &NeuralNetwork| -> f64 {
            net.predict(input.clone())
                .iter()
                .zip(&target)
                .map(|(y, t)| (t - y).powi(2))
                .sum()
        };

        let before = squared_error(&nn);
        for _ in 0..500 {
            nn.train(input.clone(), target.clone());
        }
        let after = squared_error(&nn);
        assert!(after < before, "error did not decrease: {before} -> {after}");
    }
}