// neural_network_study/matrix.rs

1use rand::{Rng, rngs::StdRng};
2use rayon::prelude::*;
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
5
6/// A simple 2-dimensional matrix with basic operations
7#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
8pub struct Matrix {
9    rows: usize,
10    cols: usize,
11    data: Vec<f64>,
12}
13
14impl Matrix {
15    /// Creates a new matrix with the given number of rows and columns,
16    /// initialized to zero.
17    pub fn new(rows: usize, cols: usize) -> Self {
18        let data = vec![0.0; rows * cols];
19        Self { rows, cols, data }
20    }
21
22    /// Creates a new matrix with the given number of rows and columns,
23    /// initialized with random values between -1.0 and 1.0.
24    pub fn random(rng: &mut StdRng, rows: usize, cols: usize) -> Self {
25        Self::random_range(rng, rows, cols, -1.0, 1.0)
26    }
27
28    /// Creates a new matrix with the given number of rows and columns,
29    /// initialized with random values in the provided half-open range.
30    pub fn random_range(rng: &mut StdRng, rows: usize, cols: usize, min: f64, max: f64) -> Self {
31        let data = (0..(rows * cols))
32            .map(|_| rng.random_range(min..max))
33            .collect();
34        Self { rows, cols, data }
35    }
36
37    /// Creates a new matrix from a 2D vector.
38    /// The outer vector represents the rows, and the inner vectors represent the columns.
39    /// Panics if the inner vectors have different lengths.
40    pub fn from_vec(rows: usize, cols: usize, data: Vec<f64>) -> Self {
41        if data.len() != rows * cols {
42            panic!("data length does not match row and col count")
43        }
44        Self { rows, cols, data }
45    }
46
47    /// Creates a new matrix from a column vector.
48    pub fn from_col_vec(data: Vec<f64>) -> Self {
49        let rows = data.len();
50        let cols = 1;
51        Self::from_vec(rows, cols, data)
52    }
53
54    /// Transposes the matrix.
55    pub fn transpose(&self) -> Self {
56        let mut transposed_data = vec![0.0; self.rows * self.cols];
57        for i in 0..self.rows {
58            for j in 0..self.cols {
59                transposed_data[j * self.rows + i] = self.data[i * self.cols + j];
60            }
61        }
62        Self::from_vec(self.cols, self.rows, transposed_data)
63    }
64
65    /// Returns the number of rows in the matrix.
66    pub fn rows(&self) -> usize {
67        self.rows
68    }
69
70    /// Returns the number of columns in the matrix.
71    pub fn cols(&self) -> usize {
72        self.cols
73    }
74
75    /// Returns the column at the given index as a vector.
76    /// Panics if the index is out of bounds.
77    pub fn col(&self, col: usize) -> Vec<f64> {
78        if col >= self.cols {
79            panic!("Index out of bounds");
80        }
81        (0..self.rows)
82            .map(|i| self.data[i * self.cols + col])
83            .collect()
84    }
85
86    /// Returns a reference to the data in the matrix.
87    pub fn data(&self) -> &Vec<f64> {
88        &self.data
89    }
90
91    /// Returns a mutable reference to the data in the matrix.
92    pub fn data_mut(&mut self) -> &mut Vec<f64> {
93        &mut self.data
94    }
95
96    /// Returns the value at the given row and column.
97    /// Panics if the indices are out of bounds.
98    pub fn get(&self, row: usize, col: usize) -> f64 {
99        if row >= self.rows || col >= self.cols {
100            panic!("Index out of bounds");
101        }
102        self.data[row * self.cols + col]
103    }
104
105    /// Returns a mutable reference to the value at the given row and column.
106    /// Panics if the indices are out of bounds.
107    pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
108        if row >= self.rows || col >= self.cols {
109            panic!("Index out of bounds");
110        }
111        &mut self.data[row * self.cols + col]
112    }
113
114    /// Sets the value at the given row and column.
115    /// Panics if the indices are out of bounds.
116    pub fn set(&mut self, row: usize, col: usize, value: f64) {
117        if row >= self.rows || col >= self.cols {
118            panic!("Index out of bounds");
119        }
120        self.data[row * self.cols + col] = value;
121    }
122
123    pub fn apply<F>(&mut self, f: F)
124    where
125        F: Fn(f64) -> f64,
126    {
127        for i in 0..self.rows {
128            for j in 0..self.cols {
129                let index = i * self.cols + j;
130                self.data[index] = f(self.data[index]);
131            }
132        }
133    }
134
135    pub fn hadamard_product(&mut self, other: &Matrix) {
136        if self.rows != other.rows || self.cols != other.cols {
137            panic!("Matrices must have the same dimensions for Hadamard product");
138        }
139        for i in 0..self.rows {
140            for j in 0..self.cols {
141                self.set(i, j, self.get(i, j) * other.get(i, j));
142            }
143        }
144    }
145
146    fn multiply_matrix_parallelized(&self, other: &Matrix) -> Matrix {
147        if self.cols != other.rows {
148            panic!("Matrices have incompatible dimensions for multiplication");
149        }
150
151        let other_t = Arc::new(other.transpose()); // wrap in Arc
152        let self_data = &self.data;
153        let other_data = &other_t.data;
154        let self_cols = self.cols;
155        let other_cols = other.cols;
156
157        let result_data: Vec<f64> = (0..self.rows)
158            .into_par_iter()
159            .flat_map_iter(|i| {
160                (0..other_t.rows).map(move |j| {
161                    let mut sum = 0.0;
162                    let row_start = i * self_cols;
163                    let col_start = j * self_cols;
164                    for k in 0..self_cols {
165                        sum += self_data[row_start + k] * other_data[col_start + k];
166                    }
167                    sum
168                })
169            })
170            .collect();
171
172        Matrix::from_vec(self.rows, other_cols, result_data)
173    }
174
175    fn multiply_matrix_naive(&self, other: &Matrix) -> Matrix {
176        if self.cols != other.rows {
177            panic!("Matrices have incompatible dimensions for multiplication");
178        }
179
180        let other_t = other.transpose(); // for better cache locality
181        let mut result = Matrix::new(self.rows, other.cols);
182
183        let self_data = &self.data;
184        let other_data = &other_t.data;
185        let result_data = &mut result.data;
186
187        let m = self.rows;
188        let n = self.cols;
189        let p = other.cols;
190
191        for i in 0..m {
192            for j in 0..p {
193                let mut sum = 0.0;
194                let a_row = i * n;
195                let b_row = j * n; // because other_t has shape [p x n]
196                for k in 0..n {
197                    sum += self_data[a_row + k] * other_data[b_row + k];
198                }
199                result_data[i * p + j] = sum;
200            }
201        }
202
203        result
204    }
205
206    pub fn multiply_matrix(&self, other: &Matrix) -> Matrix {
207        if self.rows * other.cols >= 128 * 128 {
208            self.multiply_matrix_parallelized(other)
209        } else {
210            self.multiply_matrix_naive(other)
211        }
212    }
213}
214
215impl Add<&Matrix> for Matrix {
216    type Output = Matrix;
217
218    /// Adds two matrices together, component-wise.
219    /// Panics if the matrices have different dimensions.
220    fn add(self, other: &Matrix) -> Matrix {
221        if self.rows != other.rows || self.cols != other.cols {
222            panic!("Matrices must have the same dimensions to be added");
223        }
224        let mut result = Matrix::new(self.rows, self.cols);
225        for i in 0..self.rows {
226            for j in 0..self.cols {
227                result.set(i, j, self.get(i, j) + other.get(i, j));
228            }
229        }
230        result
231    }
232}
233
234impl AddAssign<&Matrix> for Matrix {
235    /// Adds another matrix to this matrix, component-wise.
236    /// Panics if the matrices have different dimensions.
237    /// This is an in-place operation.
238    fn add_assign(&mut self, other: &Matrix) {
239        if self.rows != other.rows || self.cols != other.cols {
240            panic!("Matrices must have the same dimensions to be added");
241        }
242        for i in 0..self.rows {
243            for j in 0..self.cols {
244                self.set(i, j, self.get(i, j) + other.get(i, j));
245            }
246        }
247    }
248}
249
250impl Sub<&Matrix> for Matrix {
251    type Output = Matrix;
252
253    /// Subtracts another matrix from this matrix, component-wise.
254    /// Panics if the matrices have different dimensions.
255    fn sub(self, rhs: &Matrix) -> Self::Output {
256        if self.rows != rhs.rows || self.cols != rhs.cols {
257            panic!("Matrices must have the same dimensions to be subtracted");
258        }
259        let mut result = Matrix::new(self.rows, self.cols);
260        for i in 0..self.rows {
261            for j in 0..self.cols {
262                result.set(i, j, self.get(i, j) - rhs.get(i, j));
263            }
264        }
265        result
266    }
267}
268
269impl SubAssign<&Matrix> for Matrix {
270    fn sub_assign(&mut self, other: &Matrix) {
271        if self.rows != other.rows || self.cols != other.cols {
272            panic!("Matrices must have the same dimensions to be added");
273        }
274        for i in 0..self.rows {
275            for j in 0..self.cols {
276                self.set(i, j, self.get(i, j) - other.get(i, j));
277            }
278        }
279    }
280}
281
282impl Mul<f64> for Matrix {
283    type Output = Matrix;
284
285    /// Multiplies the matrix by a scalar.
286    fn mul(self, scalar: f64) -> Matrix {
287        let mut result = Matrix::new(self.rows, self.cols);
288        for i in 0..self.rows {
289            for j in 0..self.cols {
290                result.data[i * self.cols + j] = self.data[i * self.cols + j] * scalar;
291            }
292        }
293        result
294    }
295}
296
297impl MulAssign<f64> for Matrix {
298    /// Multiplies the matrix by a scalar in-place.
299    fn mul_assign(&mut self, scalar: f64) {
300        for i in 0..self.rows {
301            for j in 0..self.cols {
302                self.data[i * self.cols + j] *= scalar;
303            }
304        }
305    }
306}
307
308use std::sync::Arc;
309
310impl Mul<&Matrix> for &Matrix {
311    type Output = Matrix;
312
313    fn mul(self, other: &Matrix) -> Matrix {
314        self.multiply_matrix(other)
315    }
316}
317
#[cfg(test)]
mod matrix_tests {
    use rand::SeedableRng;

    use super::*;

    /// Fixed-seed RNG so randomized tests are reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0x5EED)
    }

    #[test]
    fn it_works() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 2 * 3);
    }

    #[test]
    fn it_creates_random_matrix() {
        let mut rng = seeded_rng();
        let m = Matrix::random(&mut rng, 2, 3);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data.len(), 2 * 3);
        // `random` samples the half-open range [-1.0, 1.0).
        assert!(m.data.iter().all(|x| (-1.0..1.0).contains(x)));
    }

    #[test]
    fn it_creates_a_matrix_from_a_vector() {
        let v = vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0];
        let m = Matrix::from_vec(2, 3, v.clone());
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data, v);
    }

    #[test]
    #[should_panic(expected = "data length does not match row and col count")]
    fn it_panics_on_mismatched_vector_length() {
        Matrix::from_vec(2, 3, vec![1.0; 5]);
    }

    #[test]
    fn it_creates_a_column_vector() {
        let m = Matrix::from_col_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 1);
        assert_eq!(m.col(0), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn it_transposes_matrix() {
        let m = Matrix::from_vec(
            3,
            2,
            vec![
                /* row 0 */ 1.0, 2.0, /* row 1 */ 5.0, 3.0, /* row 2 */ 4.0, 6.0,
            ],
        );
        let transposed = m.transpose();
        assert_eq!(transposed.rows, 2);
        assert_eq!(transposed.cols, 3);
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(0, 1), 5.0);
        assert_eq!(transposed.get(0, 2), 4.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(1, 1), 3.0);
        assert_eq!(transposed.get(1, 2), 6.0);
    }

    #[test]
    fn it_transposes_back_to_the_original() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn it_returns_a_column() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m.col(1), vec![2.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col() {
        let m = Matrix::new(2, 3);
        m.col(3);
    }

    #[test]
    fn it_gets_and_sets_values() {
        let mut m = Matrix::new(2, 3);
        m.set(0, 0, 1.0);
        m.set(1, 2, 2.0);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get() {
        let m = Matrix::new(2, 3);
        m.get(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_column() {
        let m = Matrix::new(2, 3);
        m.get(0, 3);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set() {
        let mut m = Matrix::new(2, 3);
        m.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut_column() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(0, 3);
    }

    #[test]
    fn it_gets_and_sets_mutable_values() {
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 0) = 1.0;
        *m.get_mut(1, 2) = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_returns_mutable_data() {
        let mut m = Matrix::new(2, 3);
        m.data_mut()[0] = 1.0;
        m.data_mut()[5] = 2.0; // row 1, col 2 in a 3-column matrix
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_adds_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let result = m1 + &m2;
        assert_eq!(result.get(0, 0), 6.0);
        assert_eq!(result.get(0, 1), 8.0);
        assert_eq!(result.get(1, 0), 10.0);
        assert_eq!(result.get(1, 1), 12.0);
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions to be added")]
    fn it_panics_when_adding_mismatched_matrices() {
        let m1 = Matrix::new(2, 3);
        let m2 = Matrix::new(3, 2);
        let _ = m1 + &m2;
    }

    #[test]
    fn it_adds_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1 += &m2;
        assert_eq!(m1.get(0, 0), 6.0);
        assert_eq!(m1.get(0, 1), 8.0);
        assert_eq!(m1.get(1, 0), 10.0);
        assert_eq!(m1.get(1, 1), 12.0);
    }

    #[test]
    fn it_subtracts_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m1 - &m2;
        assert_eq!(result, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 4.0]));
    }

    #[test]
    fn it_subtracts_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m1 -= &m2;
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 4.0]));
    }

    #[test]
    fn it_multiplies_by_scalar() {
        let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m * 2.0;
        assert_eq!(result.get(0, 0), 2.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(1, 1), 8.0);
    }

    #[test]
    fn it_multiplies_by_scalar_in_place() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m *= 2.0;
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }

    #[test]
    fn it_computes_hadamard_product() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1.hadamard_product(&m2);
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![5.0, 12.0, 21.0, 32.0]));
    }

    #[test]
    fn it_multiplies_matrices() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let n = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let e = Matrix::from_vec(2, 2, vec![58.0, 64.0, 139.0, 154.0]);
        let r = &m * &n;
        assert_eq!(r, e);
    }

    #[test]
    #[should_panic(expected = "Matrices have incompatible dimensions for multiplication")]
    fn it_panics_on_incompatible_multiplication() {
        let m = Matrix::new(2, 3);
        let n = Matrix::new(2, 3);
        let _ = &m * &n;
    }

    #[test]
    fn it_multiplies_with_empty_inner_dimension() {
        let m = Matrix::new(2, 0);
        let n = Matrix::new(0, 3);
        assert_eq!(&m * &n, Matrix::new(2, 3));
    }

    #[test]
    fn it_agrees_between_naive_and_parallel_multiplication() {
        let mut rng = seeded_rng();
        // A 128 x 128 output is exactly the size at which `multiply_matrix` goes parallel.
        let m = Matrix::random(&mut rng, 128, 3);
        let n = Matrix::random(&mut rng, 3, 128);
        let naive = m.multiply_matrix_naive(&n);
        let parallel = m.multiply_matrix_parallelized(&n);
        assert_eq!(naive.rows(), 128);
        assert_eq!(naive.cols(), 128);
        assert_eq!(naive, parallel);
        assert_eq!(&m * &n, parallel);
    }

    #[test]
    fn it_maps() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m.apply(|x| x * 2.0);
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }
}