1use rand::Rng;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub};
3
/// A dense matrix of `f64` values stored as a vector of rows.
///
/// The constructors keep `data.len() == rows` and every row exactly `cols` long.
/// Equality (`PartialEq`) compares the shape and every entry with `==`, so a matrix
/// containing `NaN` is never equal to anything, including itself.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns (the length of every row).
    cols: usize,
    /// Row storage: `data[r][c]` is the entry at row `r`, column `c`.
    data: Vec<Vec<f64>>,
}
11
12impl Matrix {
13 pub fn new(rows: usize, cols: usize) -> Self {
16 let data = vec![vec![0.0; cols]; rows];
17 Self { rows, cols, data }
18 }
19
20 pub fn random(rows: usize, cols: usize) -> Self {
23 let mut rng = rand::r#rng();
24 let data = (0..rows)
25 .map(|_| (0..cols).map(|_| rng.random_range(-1.0..1.0)).collect())
26 .collect();
27 Self { rows, cols, data }
28 }
29
30 pub fn from_vec(data: Vec<Vec<f64>>) -> Self {
34 let rows = data.len();
35 let cols = if rows > 0 { data[0].len() } else { 0 };
36 for row in &data {
37 if row.len() != cols {
38 panic!("All rows must have the same number of columns");
39 }
40 }
41 Self { rows, cols, data }
42 }
43
44 pub fn from_col_vec(data: Vec<f64>) -> Self {
46 let rows = data.len();
47 let cols = 1;
48 let data = data.into_iter().map(|x| vec![x]).collect();
49 Self { rows, cols, data }
50 }
51
52 pub fn transpose(&self) -> Self {
54 let mut transposed_data = vec![vec![0.0; self.rows]; self.cols];
55 for i in 0..self.rows {
56 for j in 0..self.cols {
57 transposed_data[j][i] = self.data[i][j];
58 }
59 }
60 Self::from_vec(transposed_data)
61 }
62
63 pub fn rows(&self) -> usize {
65 self.rows
66 }
67
68 pub fn cols(&self) -> usize {
70 self.cols
71 }
72
73 pub fn col(&self, index: usize) -> Vec<f64> {
76 if index >= self.cols {
77 panic!("Index out of bounds");
78 }
79 (0..self.rows).map(|i| self.data[i][index]).collect()
80 }
81
82 pub fn data(&self) -> &Vec<Vec<f64>> {
84 &self.data
85 }
86
87 pub fn data_mut(&mut self) -> &mut Vec<Vec<f64>> {
89 &mut self.data
90 }
91
92 pub fn get(&self, row: usize, col: usize) -> f64 {
95 if row >= self.rows || col >= self.cols {
96 panic!("Index out of bounds");
97 }
98 self.data[row][col]
99 }
100
101 pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
104 if row >= self.rows || col >= self.cols {
105 panic!("Index out of bounds");
106 }
107 &mut self.data[row][col]
108 }
109
110 pub fn set(&mut self, row: usize, col: usize, value: f64) {
113 if row >= self.rows || col >= self.cols {
114 panic!("Index out of bounds");
115 }
116 self.data[row][col] = value;
117 }
118
119 pub fn map<F>(&self, f: F) -> Matrix
122 where
123 F: Fn(f64) -> f64,
124 {
125 let mut result = Matrix::new(self.rows, self.cols);
126 for i in 0..self.rows {
127 for j in 0..self.cols {
128 result.set(i, j, f(self.get(i, j)));
129 }
130 }
131 result
132 }
133
134 pub fn map_mut<F>(&mut self, f: F)
137 where
138 F: Fn(f64) -> f64,
139 {
140 for i in 0..self.rows {
141 for j in 0..self.cols {
142 self.set(i, j, f(self.get(i, j)));
143 }
144 }
145 }
146
147 pub fn hadamar_product(&mut self, other: &Matrix) {
148 if self.rows != other.rows || self.cols != other.cols {
149 panic!("Matrices must have the same dimensions for Hadamard product");
150 }
151 for i in 0..self.rows {
152 for j in 0..self.cols {
153 self.set(i, j, self.get(i, j) * other.get(i, j));
154 }
155 }
156 }
157}
158
159impl Add<&Matrix> for Matrix {
160 type Output = Matrix;
161
162 fn add(self, other: &Matrix) -> Matrix {
165 if self.rows != other.rows || self.cols != other.cols {
166 panic!("Matrices must have the same dimensions to be added");
167 }
168 let mut result = Matrix::new(self.rows, self.cols);
169 for i in 0..self.rows {
170 for j in 0..self.cols {
171 result.set(i, j, self.get(i, j) + other.get(i, j));
172 }
173 }
174 result
175 }
176}
177
178impl AddAssign<&Matrix> for Matrix {
179 fn add_assign(&mut self, other: &Matrix) {
183 if self.rows != other.rows || self.cols != other.cols {
184 panic!("Matrices must have the same dimensions to be added");
185 }
186 for i in 0..self.rows {
187 for j in 0..self.cols {
188 self.set(i, j, self.get(i, j) + other.get(i, j));
189 }
190 }
191 }
192}
193
194impl Sub<&Matrix> for Matrix {
195 type Output = Matrix;
196
197 fn sub(self, rhs: &Matrix) -> Self::Output {
200 if self.rows != rhs.rows || self.cols != rhs.cols {
201 panic!("Matrices must have the same dimensions to be subtracted");
202 }
203 let mut result = Matrix::new(self.rows, self.cols);
204 for i in 0..self.rows {
205 for j in 0..self.cols {
206 result.set(i, j, self.get(i, j) - rhs.get(i, j));
207 }
208 }
209 result
210 }
211}
212
213impl Mul<f64> for Matrix {
214 type Output = Matrix;
215
216 fn mul(self, scalar: f64) -> Matrix {
218 let mut result = Matrix::new(self.rows, self.cols);
219 for i in 0..self.rows {
220 for j in 0..self.cols {
221 result.set(i, j, self.get(i, j) * scalar);
222 }
223 }
224 result
225 }
226}
227
228impl MulAssign<f64> for Matrix {
229 fn mul_assign(&mut self, scalar: f64) {
231 for i in 0..self.rows {
232 for j in 0..self.cols {
233 self.set(i, j, self.get(i, j) * scalar);
234 }
235 }
236 }
237}
238
239impl Mul<&Matrix> for &Matrix {
240 type Output = Matrix;
241
242 fn mul(self, other: &Matrix) -> Matrix {
245 if self.cols != other.rows {
246 panic!("Matrices have incompatible dimensions for multiplication");
247 }
248 let mut result = Matrix::new(self.rows, other.cols);
249 for i in 0..self.rows {
250 for j in 0..other.cols {
251 let mut sum = 0.0;
252 for k in 0..self.cols {
253 sum += self.get(i, k) * other.get(k, j);
254 }
255 result.set(i, j, sum);
256 }
257 }
258 result
259 }
260}
261
#[cfg(test)]
mod matrix_tests {
    use super::*;

    #[test]
    fn it_works() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data(), &vec![vec![0.0; 3]; 2]);
    }

    #[test]
    fn it_creates_random_matrix() {
        let m = Matrix::random(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 2);
        assert!(m.data().iter().all(|row| row.len() == 3));
        assert!(m.data().iter().flatten().all(|&v| (-1.0..=1.0).contains(&v)));
    }

    #[test]
    fn it_creates_a_matrix_from_a_vector() {
        let v = vec![vec![1.0, 2.0, 5.0], vec![3.0, 4.0, 6.0]];
        let m = Matrix::from_vec(v.clone());
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data(), &v);
    }

    #[test]
    #[should_panic(expected = "All rows must have the same number of columns")]
    fn it_rejects_ragged_rows() {
        let _ = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0]]);
    }

    #[test]
    fn it_creates_a_column_vector() {
        let m = Matrix::from_col_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 1);
        assert_eq!(m.col(0), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn it_transposes_matrix() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0, 5.0], vec![3.0, 4.0, 6.0]]);
        let transposed = m.transpose();
        assert_eq!(transposed.rows(), 3);
        assert_eq!(transposed.cols(), 2);
        assert_eq!(
            transposed.data(),
            &vec![vec![1.0, 3.0], vec![2.0, 4.0], vec![5.0, 6.0]]
        );
    }

    #[test]
    fn it_returns_a_column() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0, 5.0], vec![3.0, 4.0, 6.0]]);
        assert_eq!(m.col(2), vec![5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col() {
        Matrix::new(2, 3).col(3);
    }

    #[test]
    fn it_gets_and_sets_values() {
        let mut m = Matrix::new(2, 3);
        m.set(0, 0, 1.0);
        m.set(1, 2, 2.0);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get() {
        let m = Matrix::new(2, 3);
        m.get(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set() {
        let mut m = Matrix::new(2, 3);
        m.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(2, 0);
    }

    // Row overflow is covered above; this exercises the column bound of `get_mut`.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_column_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(0, 3);
    }

    #[test]
    fn it_gets_and_sets_mutable_values() {
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 0) = 1.0;
        *m.get_mut(1, 2) = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_returns_mutable_data() {
        let mut m = Matrix::new(2, 3);
        m.data_mut()[0][0] = 1.0;
        m.data_mut()[1][2] = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_adds_matrices() {
        let m1 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let m2 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        let result = m1 + &m2;
        assert_eq!(result.data(), &vec![vec![6.0, 8.0], vec![10.0, 12.0]]);
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions to be added")]
    fn it_panics_on_mismatched_addition() {
        let _ = Matrix::new(2, 2) + &Matrix::new(2, 3);
    }

    #[test]
    fn it_adds_and_assigns() {
        let mut m1 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let m2 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        m1 += &m2;
        assert_eq!(m1.data(), &vec![vec![6.0, 8.0], vec![10.0, 12.0]]);
    }

    #[test]
    fn it_subtracts_matrices() {
        let m1 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        let m2 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = m1 - &m2;
        assert_eq!(result.data(), &vec![vec![4.0, 4.0], vec![4.0, 4.0]]);
    }

    #[test]
    fn it_multiplies_by_scalar() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = m * 2.0;
        assert_eq!(result.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }

    #[test]
    fn it_multiplies_by_scalar_in_place() {
        let mut m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        m *= 2.0;
        assert_eq!(m.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }

    #[test]
    fn it_multiplies_matrices() {
        let a = Matrix::from_vec(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let b = Matrix::from_vec(vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]]);
        let product = &a * &b;
        assert_eq!(product.rows(), 2);
        assert_eq!(product.cols(), 2);
        assert_eq!(product.data(), &vec![vec![58.0, 64.0], vec![139.0, 154.0]]);
    }

    #[test]
    #[should_panic(expected = "Matrices have incompatible dimensions for multiplication")]
    fn it_panics_on_incompatible_multiplication() {
        let _ = &Matrix::new(2, 3) * &Matrix::new(2, 3);
    }

    #[test]
    fn it_computes_hadamard_product() {
        let mut m1 = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let m2 = Matrix::from_vec(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
        m1.hadamar_product(&m2);
        assert_eq!(m1.data(), &vec![vec![5.0, 12.0], vec![21.0, 32.0]]);
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions for Hadamard product")]
    fn it_panics_on_mismatched_hadamard_product() {
        Matrix::new(2, 2).hadamar_product(&Matrix::new(3, 2));
    }

    #[test]
    fn it_maps() {
        let m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = m.map(|x| x * 2.0);
        assert_eq!(result.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }

    #[test]
    fn it_maps_mut() {
        let mut m = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        m.map_mut(|x| x * 2.0);
        assert_eq!(m.data(), &vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
    }
}
443
444fn sigmoid(x: &Matrix) -> Matrix {
445 x.map(|x| 1.0 / (1.0 + (-x).exp()))
446}
447
448fn sigmoid_derivative(x: &Matrix) -> Matrix {
449 x.map(|x| x * (1.0 - x))
450}
451
452fn tanh(x: &Matrix) -> Matrix {
453 x.map(|x| x.tanh())
454}
455
456fn tanh_derivative(x: &Matrix) -> Matrix {
457 x.map(|x| 1.0 - x.tanh().powi(2))
458}
459
460fn linear(x: &Matrix) -> Matrix {
461 x.clone()
462}
463
464fn linear_derivative(x: &Matrix) -> Matrix {
465 x.map(|_| 1.0)
466}
467
468#[derive(Clone, Debug)]
470pub struct NeuralNetwork {
471 weights_input_hidden: Matrix,
472 weights_hidden_output: Matrix,
473 biases_hidden: Matrix,
474 biases_output: Matrix,
475 learning_rate: f64,
476 activation_function: fn(&Matrix) -> Matrix,
477 activation_function_derivative: fn(&Matrix) -> Matrix,
478}
479
480impl NeuralNetwork {
481 pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
484 NeuralNetwork {
485 weights_input_hidden: Matrix::random(hidden_size, input_size),
486 weights_hidden_output: Matrix::random(output_size, hidden_size),
487 biases_hidden: Matrix::random(hidden_size, 1),
488 biases_output: Matrix::random(output_size, 1),
489 learning_rate: 0.01,
490 activation_function: sigmoid,
491 activation_function_derivative: sigmoid_derivative,
492 }
493 }
494
495 pub fn set_learning_rate(&mut self, learning_rate: f64) {
497 self.learning_rate = learning_rate;
498 }
499
500 pub fn set_activation_function(
502 &mut self,
503 activation_function: fn(&Matrix) -> Matrix,
504 activation_function_derivative: fn(&Matrix) -> Matrix,
505 ) {
506 self.activation_function = activation_function;
507 self.activation_function_derivative = activation_function_derivative;
508 }
509
510 pub fn set_linear_activation(&mut self) {
511 self.activation_function = linear;
512 self.activation_function_derivative = linear_derivative;
513 }
514
515 pub fn set_sigmoid_activation(&mut self) {
516 self.activation_function = sigmoid;
517 self.activation_function_derivative = sigmoid_derivative;
518 }
519
520 pub fn set_tanh_activation(&mut self) {
521 self.activation_function = tanh;
522 self.activation_function_derivative = tanh_derivative;
523 }
524
525 pub fn predict(&self, input: Vec<f64>) -> Vec<f64> {
527 let input_matrix = Matrix::from_col_vec(input);
529 let hidden_layer_input = &self.weights_input_hidden * &input_matrix + &self.biases_hidden;
530 let hidden_layer_output = (self.activation_function)(&hidden_layer_input);
531 let output_layer_input = &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
533 let output_layer_output = (self.activation_function)(&output_layer_input);
534 output_layer_output.col(0)
536 }
537
538 pub fn train(&mut self, input: Vec<f64>, target: Vec<f64>) {
542 let input = Matrix::from_col_vec(input);
544 let hidden_layer_input = &self.weights_input_hidden * &input + &self.biases_hidden;
545 let hidden_layer_output = (self.activation_function)(&hidden_layer_input);
546
547 let output_layer_input = &self.weights_hidden_output * &hidden_layer_output + &self.biases_output;
549 let output_layer_output = (self.activation_function)(&output_layer_input);
550
551 let target = Matrix::from_col_vec(target);
553
554 let output_errors = target - &output_layer_output;
557
558 let mut gradients = (self.activation_function_derivative)(&output_layer_output);
560 gradients.hadamar_product(&output_errors);
561 gradients *= self.learning_rate;
562
563 let hidden_transposed = hidden_layer_output.transpose();
565 let weight_hidden_output_deltas = &gradients * &hidden_transposed;
566
567 self.weights_hidden_output += &weight_hidden_output_deltas;
569 self.biases_output += &gradients;
571
572 let weight_hidden_output_transposed = self.weights_hidden_output.transpose();
574 let hidden_errors = &weight_hidden_output_transposed * &output_errors;
575
576 let mut hidden_gradient = (self.activation_function_derivative)(&hidden_layer_output);
578 hidden_gradient.hadamar_product(&hidden_errors);
579 hidden_gradient *= self.learning_rate;
580
581 let inputs_transposed = input.transpose();
583 let weight_input_hidden_deltas = &hidden_gradient * &inputs_transposed;
584
585 self.weights_input_hidden += &weight_input_hidden_deltas;
586 self.biases_hidden += &hidden_gradient;
588 }
589}
590
#[cfg(test)]
mod nn_tests {
    use super::*;

    /// Builds a 2-2-1 network with hand-picked parameters so that results are
    /// reproducible. The random initialisation from `new` is overwritten.
    fn fixed_network() -> NeuralNetwork {
        let mut nn = NeuralNetwork::new(2, 2, 1);
        nn.weights_input_hidden = Matrix::from_vec(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        nn.biases_hidden = Matrix::from_col_vec(vec![0.5, -0.5]);
        nn.weights_hidden_output = Matrix::from_vec(vec![vec![1.0, -1.0]]);
        nn.biases_output = Matrix::from_col_vec(vec![0.25]);
        nn
    }

    #[test]
    fn it_creates_a_neural_network() {
        let m = NeuralNetwork::new(3, 5, 2);
        assert_eq!(m.weights_input_hidden.rows(), 5);
        assert_eq!(m.weights_input_hidden.cols(), 3);
        assert_eq!(m.weights_hidden_output.rows(), 2);
        assert_eq!(m.weights_hidden_output.cols(), 5);
        assert_eq!(m.biases_hidden.rows(), 5);
        assert_eq!(m.biases_hidden.cols(), 1);
        assert_eq!(m.biases_output.rows(), 2);
        assert_eq!(m.biases_output.cols(), 1);
    }

    #[test]
    fn it_predicts() {
        let m = NeuralNetwork::new(3, 5, 2);
        let output = m.predict(vec![0.5, 0.2, 0.1]);
        assert_eq!(output.len(), 2);
        // The default sigmoid activation keeps every output strictly inside (0, 1).
        assert!(output.iter().all(|&y| y > 0.0 && y < 1.0));
    }

    #[test]
    fn it_predicts_deterministically_with_fixed_parameters() {
        let mut nn = fixed_network();
        nn.set_linear_activation();
        // hidden = [1 + 2 + 0.5, 3 + 4 - 0.5] = [3.5, 6.5]; output = 3.5 - 6.5 + 0.25.
        assert_eq!(nn.predict(vec![1.0, 1.0]), vec![-2.75]);
    }

    #[test]
    fn training_moves_the_prediction_toward_the_target() {
        let mut nn = fixed_network();
        nn.set_learning_rate(0.5);
        let input = vec![0.5, -0.5];
        let target = 0.8;
        let initial_error = (target - nn.predict(input.clone())[0]).abs();
        for _ in 0..500 {
            nn.train(input.clone(), vec![target]);
        }
        let final_error = (target - nn.predict(input)[0]).abs();
        assert!(final_error < initial_error);
        assert!(final_error < 0.02, "final error {final_error} is too large");
    }
}