neural_network_study/
matrix.rs

1use rand::{Rng, rngs::StdRng};
2use rayon::prelude::*;
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
5
6/// A simple 2-dimensional matrix with basic operations
7#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
8pub struct Matrix {
9    rows: usize,
10    cols: usize,
11    data: Vec<f64>,
12}
13
14impl Matrix {
15    /// Creates a new matrix with the given number of rows and columns,
16    /// initialized to zero.
17    pub fn new(rows: usize, cols: usize) -> Self {
18        let data = vec![0.0; rows * cols];
19        Self { rows, cols, data }
20    }
21
22    /// Creates a new matrix with the given number of rows and columns,
23    /// initialized with random values between -1.0 and 1.0.
24    pub fn random(rng: &mut StdRng, rows: usize, cols: usize) -> Self {
25        let data = (0..(rows * cols))
26            .map(|_| rng.random_range(-1.0..1.0))
27            .collect();
28        Self { rows, cols, data }
29    }
30
31    /// Creates a new matrix from a 2D vector.
32    /// The outer vector represents the rows, and the inner vectors represent the columns.
33    /// Panics if the inner vectors have different lengths.
34    pub fn from_vec(rows: usize, cols: usize, data: Vec<f64>) -> Self {
35        if data.len() != rows * cols {
36            panic!("data length does not match row and col count")
37        }
38        Self { rows, cols, data }
39    }
40
41    /// Creates a new matrix from a column vector.
42    pub fn from_col_vec(data: Vec<f64>) -> Self {
43        let rows = data.len();
44        let cols = 1;
45        Self::from_vec(rows, cols, data)
46    }
47
48    /// Transposes the matrix.
49    pub fn transpose(&self) -> Self {
50        let mut transposed_data = vec![0.0; self.rows * self.cols];
51        for i in 0..self.rows {
52            for j in 0..self.cols {
53                transposed_data[j * self.rows + i] = self.data[i * self.cols + j];
54            }
55        }
56        Self::from_vec(self.cols, self.rows, transposed_data)
57    }
58
59    /// Returns the number of rows in the matrix.
60    pub fn rows(&self) -> usize {
61        self.rows
62    }
63
64    /// Returns the number of columns in the matrix.
65    pub fn cols(&self) -> usize {
66        self.cols
67    }
68
69    /// Returns the column at the given index as a vector.
70    /// Panics if the index is out of bounds.
71    pub fn col(&self, col: usize) -> Vec<f64> {
72        if col >= self.cols {
73            panic!("Index out of bounds");
74        }
75        (0..self.rows)
76            .map(|i| self.data[i * self.cols + col])
77            .collect()
78    }
79
80    /// Returns a reference to the data in the matrix.
81    pub fn data(&self) -> &Vec<f64> {
82        &self.data
83    }
84
85    /// Returns a mutable reference to the data in the matrix.
86    pub fn data_mut(&mut self) -> &mut Vec<f64> {
87        &mut self.data
88    }
89
90    /// Returns the value at the given row and column.
91    /// Panics if the indices are out of bounds.
92    pub fn get(&self, row: usize, col: usize) -> f64 {
93        if row >= self.rows || col >= self.cols {
94            panic!("Index out of bounds");
95        }
96        self.data[row * self.cols + col]
97    }
98
99    /// Returns a mutable reference to the value at the given row and column.
100    /// Panics if the indices are out of bounds.
101    pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
102        if row >= self.rows || col >= self.cols {
103            panic!("Index out of bounds");
104        }
105        &mut self.data[row * self.cols + col]
106    }
107
108    /// Sets the value at the given row and column.
109    /// Panics if the indices are out of bounds.
110    pub fn set(&mut self, row: usize, col: usize, value: f64) {
111        if row >= self.rows || col >= self.cols {
112            panic!("Index out of bounds");
113        }
114        self.data[row * self.cols + col] = value;
115    }
116
117    pub fn apply<F>(&mut self, f: F)
118    where
119        F: Fn(f64) -> f64,
120    {
121        for i in 0..self.rows {
122            for j in 0..self.cols {
123                let index = i * self.cols + j;
124                self.data[index] = f(self.data[index]);
125            }
126        }
127    }
128
129    pub fn hadamard_product(&mut self, other: &Matrix) {
130        if self.rows != other.rows || self.cols != other.cols {
131            panic!("Matrices must have the same dimensions for Hadamard product");
132        }
133        for i in 0..self.rows {
134            for j in 0..self.cols {
135                self.set(i, j, self.get(i, j) * other.get(i, j));
136            }
137        }
138    }
139
140    fn multiply_matrix_parallelized(&self, other: &Matrix) -> Matrix {
141        if self.cols != other.rows {
142            panic!("Matrices have incompatible dimensions for multiplication");
143        }
144
145        let other_t = Arc::new(other.transpose()); // wrap in Arc
146        let self_data = &self.data;
147        let other_data = &other_t.data;
148        let self_cols = self.cols;
149        let other_cols = other.cols;
150
151        let result_data: Vec<f64> = (0..self.rows)
152            .into_par_iter()
153            .flat_map_iter(|i| {
154                (0..other_t.rows).map(move |j| {
155                    let mut sum = 0.0;
156                    let row_start = i * self_cols;
157                    let col_start = j * self_cols;
158                    for k in 0..self_cols {
159                        sum += self_data[row_start + k] * other_data[col_start + k];
160                    }
161                    sum
162                })
163            })
164            .collect();
165
166        Matrix::from_vec(self.rows, other_cols, result_data)
167    }
168
169    fn multiply_matrix_naive(&self, other: &Matrix) -> Matrix {
170        if self.cols != other.rows {
171            panic!("Matrices have incompatible dimensions for multiplication");
172        }
173
174        let other_t = other.transpose(); // for better cache locality
175        let mut result = Matrix::new(self.rows, other.cols);
176
177        let self_data = &self.data;
178        let other_data = &other_t.data;
179        let result_data = &mut result.data;
180
181        let m = self.rows;
182        let n = self.cols;
183        let p = other.cols;
184
185        for i in 0..m {
186            for j in 0..p {
187                let mut sum = 0.0;
188                let a_row = i * n;
189                let b_row = j * n; // because other_t has shape [p x n]
190                for k in 0..n {
191                    sum += self_data[a_row + k] * other_data[b_row + k];
192                }
193                result_data[i * p + j] = sum;
194            }
195        }
196
197        result
198    }
199
200    pub fn multiply_matrix(&self, other: &Matrix) -> Matrix {
201        if self.rows * other.cols >= 128 * 128 {
202            self.multiply_matrix_parallelized(other)
203        } else {
204            self.multiply_matrix_naive(other)
205        }
206    }
207}
208
209impl Add<&Matrix> for Matrix {
210    type Output = Matrix;
211
212    /// Adds two matrices together, component-wise.
213    /// Panics if the matrices have different dimensions.
214    fn add(self, other: &Matrix) -> Matrix {
215        if self.rows != other.rows || self.cols != other.cols {
216            panic!("Matrices must have the same dimensions to be added");
217        }
218        let mut result = Matrix::new(self.rows, self.cols);
219        for i in 0..self.rows {
220            for j in 0..self.cols {
221                result.set(i, j, self.get(i, j) + other.get(i, j));
222            }
223        }
224        result
225    }
226}
227
228impl AddAssign<&Matrix> for Matrix {
229    /// Adds another matrix to this matrix, component-wise.
230    /// Panics if the matrices have different dimensions.
231    /// This is an in-place operation.
232    fn add_assign(&mut self, other: &Matrix) {
233        if self.rows != other.rows || self.cols != other.cols {
234            panic!("Matrices must have the same dimensions to be added");
235        }
236        for i in 0..self.rows {
237            for j in 0..self.cols {
238                self.set(i, j, self.get(i, j) + other.get(i, j));
239            }
240        }
241    }
242}
243
244impl Sub<&Matrix> for Matrix {
245    type Output = Matrix;
246
247    /// Subtracts another matrix from this matrix, component-wise.
248    /// Panics if the matrices have different dimensions.
249    fn sub(self, rhs: &Matrix) -> Self::Output {
250        if self.rows != rhs.rows || self.cols != rhs.cols {
251            panic!("Matrices must have the same dimensions to be subtracted");
252        }
253        let mut result = Matrix::new(self.rows, self.cols);
254        for i in 0..self.rows {
255            for j in 0..self.cols {
256                result.set(i, j, self.get(i, j) - rhs.get(i, j));
257            }
258        }
259        result
260    }
261}
262
263impl SubAssign<&Matrix> for Matrix {
264    fn sub_assign(&mut self, other: &Matrix) {
265        if self.rows != other.rows || self.cols != other.cols {
266            panic!("Matrices must have the same dimensions to be added");
267        }
268        for i in 0..self.rows {
269            for j in 0..self.cols {
270                self.set(i, j, self.get(i, j) - other.get(i, j));
271            }
272        }
273    }
274}
275
276impl Mul<f64> for Matrix {
277    type Output = Matrix;
278
279    /// Multiplies the matrix by a scalar.
280    fn mul(self, scalar: f64) -> Matrix {
281        let mut result = Matrix::new(self.rows, self.cols);
282        for i in 0..self.rows {
283            for j in 0..self.cols {
284                result.data[i * self.cols + j] = self.data[i * self.cols + j] * scalar;
285            }
286        }
287        result
288    }
289}
290
291impl MulAssign<f64> for Matrix {
292    /// Multiplies the matrix by a scalar in-place.
293    fn mul_assign(&mut self, scalar: f64) {
294        for i in 0..self.rows {
295            for j in 0..self.cols {
296                self.data[i * self.cols + j] *= scalar;
297            }
298        }
299    }
300}
301
302use std::sync::Arc;
303
304impl Mul<&Matrix> for &Matrix {
305    type Output = Matrix;
306
307    fn mul(self, other: &Matrix) -> Matrix {
308        self.multiply_matrix(other)
309    }
310}
311
#[cfg(test)]
mod matrix_tests {
    use rand::SeedableRng;

    use super::*;

    /// Fixed-seed RNG so randomized tests are reproducible across runs.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0x5EED)
    }

    #[test]
    fn it_works() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 2 * 3);
    }

    #[test]
    fn it_creates_random_matrix() {
        let mut rng = seeded_rng();
        let m = Matrix::random(&mut rng, 2, 3);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data.len(), 2 * 3);
        for i in 0..2 {
            for j in 0..3 {
                assert!(m.get(i, j) >= -1.0 && m.get(i, j) <= 1.0);
            }
        }
    }

    #[test]
    fn it_creates_a_matrix_from_a_vector() {
        let v = vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0];
        let m = Matrix::from_vec(2, 3, v.clone());
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data, v);
    }

    #[test]
    #[should_panic(expected = "data length does not match")]
    fn it_panics_on_mismatched_data_length() {
        let _ = Matrix::from_vec(2, 3, vec![1.0; 5]);
    }

    #[test]
    fn it_creates_a_column_vector() {
        let m = Matrix::from_col_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 1);
        assert_eq!(m.col(0), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn it_returns_a_column() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m.col(1), vec![2.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col() {
        let _ = Matrix::new(2, 3).col(3);
    }

    #[test]
    fn it_transposes_matrix() {
        let m = Matrix::from_vec(
            3,
            2,
            vec![
                /* row 0 */ 1.0, 2.0, /* row 1 */ 5.0, 3.0, /* row 2 */ 4.0, 6.0,
            ],
        );
        let transposed = m.transpose();
        assert_eq!(transposed.rows, 2);
        assert_eq!(transposed.cols, 3);
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(0, 1), 5.0);
        assert_eq!(transposed.get(0, 2), 4.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(1, 1), 3.0);
        assert_eq!(transposed.get(1, 2), 6.0);
    }

    #[test]
    fn it_gets_and_sets_values() {
        let mut m = Matrix::new(2, 3);
        m.set(0, 0, 1.0);
        m.set(1, 2, 2.0);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get() {
        let m = Matrix::new(2, 3);
        m.get(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set() {
        let mut m = Matrix::new(2, 3);
        m.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set_mut() {
        let mut m = Matrix::new(2, 3);
        // Writes through `get_mut` with an out-of-range column; the row case is
        // covered by `it_panics_on_out_of_bounds_get_mut`.
        *m.get_mut(0, 3) = 1.0;
    }

    #[test]
    fn it_gets_and_sets_mutable_values() {
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 0) = 1.0;
        *m.get_mut(1, 2) = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_returns_mutable_data() {
        let mut m = Matrix::new(2, 3);
        m.data_mut()[0] = 1.0;
        m.data_mut()[5] = 2.0; // row 1 * 3 columns + column 2
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_adds_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let result = m1 + &m2;
        assert_eq!(result.get(0, 0), 6.0);
        assert_eq!(result.get(0, 1), 8.0);
        assert_eq!(result.get(1, 0), 10.0);
        assert_eq!(result.get(1, 1), 12.0);
    }

    #[test]
    fn it_adds_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1 += &m2;
        assert_eq!(m1.get(0, 0), 6.0);
        assert_eq!(m1.get(0, 1), 8.0);
        assert_eq!(m1.get(1, 0), 10.0);
        assert_eq!(m1.get(1, 1), 12.0);
    }

    #[test]
    fn it_subtracts_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m1 - &m2;
        assert_eq!(result, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 4.0]));
    }

    #[test]
    fn it_subtracts_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m1 -= &m2;
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 4.0]));
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions")]
    fn it_panics_on_mismatched_sub_assign() {
        let mut m1 = Matrix::new(2, 2);
        m1 -= &Matrix::new(2, 3);
    }

    #[test]
    fn it_multiplies_by_scalar() {
        let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m * 2.0;
        assert_eq!(result.get(0, 0), 2.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(1, 1), 8.0);
    }

    #[test]
    fn it_multiplies_by_scalar_in_place() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m *= 2.0;
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }

    #[test]
    fn it_computes_hadamard_product() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1.hadamard_product(&m2);
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![5.0, 12.0, 21.0, 32.0]));
    }

    #[test]
    fn it_multiplies_matrices() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let n = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let e = Matrix::from_vec(2, 2, vec![58.0, 64.0, 139.0, 154.0]);
        let r = &m * &n;
        assert_eq!(r, e);
    }

    #[test]
    fn it_multiplies_with_empty_inner_dimension() {
        let r = &Matrix::new(2, 0) * &Matrix::new(0, 3);
        assert_eq!(r, Matrix::new(2, 3));
    }

    #[test]
    fn parallel_and_naive_multiplication_agree() {
        let mut rng = seeded_rng();
        let a = Matrix::random(&mut rng, 9, 13);
        let b = Matrix::random(&mut rng, 13, 5);
        assert_eq!(
            a.multiply_matrix_naive(&b),
            a.multiply_matrix_parallelized(&b)
        );
    }

    #[test]
    fn it_multiplies_large_matrices_on_the_parallel_path() {
        // A 128x1 by 1x128 product has 128 * 128 outputs, which reaches the
        // parallel dispatch threshold in `multiply_matrix`.
        let mut rng = seeded_rng();
        let a = Matrix::random(&mut rng, 128, 1);
        let b = Matrix::random(&mut rng, 1, 128);
        let r = &a * &b;
        assert_eq!((r.rows(), r.cols()), (128, 128));
        assert_eq!(r, a.multiply_matrix_naive(&b));
    }

    #[test]
    #[should_panic(expected = "incompatible dimensions")]
    fn it_panics_on_incompatible_multiplication() {
        let _ = &Matrix::new(2, 3) * &Matrix::new(2, 3);
    }

    #[test]
    fn it_maps() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m.apply(|x| x * 2.0);
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }
}