matrix_oxide/
matrix.rs

1use crate::random;
2use std::fmt::Debug;
3use std::ops::{Add, Div, Mul, Sub};
4
/// A dense `M x N` matrix stored in row-major order.
///
/// The element at `(row, col)` lives at `data[row * col_size + col]`, so
/// `data.len()` is expected to equal `row_size * col_size`. The fields are
/// public; code that builds a `Matrix` by hand is responsible for keeping
/// that invariant.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    /// Elements in row-major order.
    pub data: Vec<T>,
    /// Number of rows (`M`).
    pub row_size: usize,
    /// Number of columns (`N`).
    pub col_size: usize,
}
11impl<T: Default + Clone> Matrix<T> {
12    /// Construct a new *non-empty* and *sized* `Matrix`
13    pub fn new(row_size: usize, col_size: usize) -> Self {
14        Matrix {
15            data: vec![T::default(); row_size * col_size],
16            row_size,
17            col_size,
18        }
19    }
20
21    /// Construct a new *non-empty* and *sized* `Matrix` with random values of `T`
22    pub fn new_random(row_size: usize, col_size: usize) -> Self
23    where
24        T: random::Random,
25    {
26        let random_data: Vec<T> = random::gen_rand_vec(row_size * col_size);
27        Matrix {
28            data: random_data,
29            row_size,
30            col_size,
31        }
32    }
33
34    /// Try to get a reference to the value at a given row and column from the matrix
35    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
36        if row < self.row_size && col < self.col_size {
37            Some(&self.data[col + row * self.col_size])
38        } else {
39            None
40        }
41    }
42
43    /// Get a vector of the diagonal elements of the matrix
44    pub fn get_diagonal(&self) -> Vec<T> {
45        (0..self.col_size)
46            .filter_map(|col_idx| self.get(col_idx, col_idx).cloned())
47            .collect()
48    }
49
50    /// Try to get a mutable reference to the value at a given row and column from the matrix
51    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
52        if row < self.row_size && col < self.col_size {
53            Some(&mut self.data[col + row * self.col_size])
54        } else {
55            None
56        }
57    }
58
59    /// Try to set a value at a given row and column in the matrix
60    pub fn set(&mut self, row: usize, column: usize, value: T) -> bool {
61        if let Some(cell) = self.get_mut(row, column) {
62            *cell = value;
63            true
64        } else {
65            false
66        }
67    }
68
69    /// Try to get all the values for a given column
70    ///
71    /// NOTE: If you pass a column value larger than the number of columns
72    /// this function will return None.
73    pub fn try_get_column(&self, column: usize) -> Option<Vec<T>> {
74        // Bounds check
75        if column >= self.col_size {
76            return None;
77        }
78
79        // Iterate over all the rows grabbing a specific column each time
80        let col_data: Vec<T> = (0..self.row_size)
81            .map(|row| self.data[row * self.col_size + column].clone())
82            .collect();
83
84        Some(col_data)
85    }
86
87    /// Try to get all the values for a given row
88    ///
89    /// NOTE: If you pass a row value larger than the number of rows
90    /// this function will return None.
91    pub fn try_get_row(&self, row: usize) -> Option<Vec<T>> {
92        // Bounds check
93        if row >= self.row_size {
94            return None;
95        }
96
97        // Iterate over all the rows grabbing a specific column each time
98        let row_data: Vec<T> = (0..self.col_size)
99            .map(|col| self.data[row * self.col_size + col].clone())
100            .collect();
101
102        Some(row_data)
103    }
104
105    /// Create a `Matrix` from a columns (vec of vec)
106    pub fn from_columns(cols: Vec<Vec<T>>) -> Matrix<T> {
107        if cols.is_empty() {
108            return Matrix {
109                data: Vec::new(),
110                row_size: 0,
111                col_size: 0,
112            };
113        }
114
115        let row_size = cols[0].len();
116        let col_size = cols.len();
117
118        let data = (0..row_size)
119            .flat_map(|row| cols.iter().filter_map(move |col| col.get(row).cloned()))
120            .collect();
121
122        Matrix {
123            data,
124            row_size,
125            col_size,
126        }
127    }
128
129    /// Create a sub matrix with a specific row and column to exclude
130    pub fn sub_matrix(&self, skip_row: usize, skip_col: usize) -> Matrix<T> {
131        let columns: Vec<Vec<T>> = (0..self.col_size)
132            .filter_map(|col| {
133                if col != skip_col {
134                    Some(
135                        self.try_get_column(col)?
136                            .into_iter()
137                            .enumerate()
138                            .filter_map(|(row, val)| if row != skip_row { Some(val) } else { None })
139                            .collect(),
140                    )
141                } else {
142                    None
143                }
144            })
145            .collect();
146
147        Matrix::from_columns(columns)
148    }
149
150    /// Perform a transpose operation (swap rows for columns and vice versa)
151    pub fn transpose(&self) -> Matrix<T> {
152        Matrix {
153            data: (0..self.col_size)
154                .flat_map(|col| {
155                    (0..self.row_size).map(move |row| self.data[row * self.col_size + col].clone())
156                })
157                .collect(),
158
159            row_size: self.col_size,
160            col_size: self.row_size,
161        }
162    }
163}
164
165impl<T: Default + Clone> Default for Matrix<T> {
166    /// Create a default `Matrix` instance
167    fn default() -> Self {
168        Self::new(2, 3)
169    }
170}
171impl<T: Default + Clone + Debug> Add for Matrix<T>
172where
173    T: Add<Output = T> + Clone,
174{
175    type Output = Matrix<T>;
176
177    /// Matrix addition
178    /// NOTE: the matrices you add MUST have the same dimensionality
179    fn add(self, rhs: Self) -> Matrix<T> {
180        let data: Vec<T> = (0..self.row_size)
181            .flat_map(|row| {
182                let row_a = self.try_get_row(row).expect("Invalid row in self");
183                let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
184                row_a.into_iter().zip(row_b).map(|(a, b)| a + b)
185            })
186            .collect();
187
188        Matrix {
189            data,
190            col_size: self.col_size,
191            row_size: self.row_size,
192        }
193    }
194}
195impl<T> Matrix<T>
196where
197    T: Default + Clone + Add<Output = T>,
198{
199    /// Perform the trace operation that computes the sum of all diagonal
200    /// elements in the matrix.
201    ///
202    /// NOTE: off-diagnonal elements do NOT contribute to the trace of the
203    /// matrix, so 2 very different matrices can have the same trace.
204    pub fn trace(&self) -> T {
205        self.get_diagonal()
206            .into_iter()
207            .fold(T::default(), |acc, diagonal| acc + diagonal)
208    }
209}
210impl<T: Default + Clone + Debug> Sub for Matrix<T>
211where
212    T: Sub<Output = T> + Clone,
213{
214    type Output = Matrix<T>;
215
216    /// Subtract a matrix by another matrix
217    /// NOTE: the matrix you subtract by MUST have the same dimensionality
218    fn sub(self, rhs: Self) -> Matrix<T> {
219        let data: Vec<T> = (0..self.row_size)
220            .flat_map(|row| {
221                let row_a = self.try_get_row(row).expect("Invalid row in self");
222                let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
223                row_a.into_iter().zip(row_b).map(|(a, b)| a - b)
224            })
225            .collect();
226
227        Matrix {
228            data,
229            row_size: self.row_size,
230            col_size: self.col_size,
231        }
232    }
233}
234impl<T: Default> Matrix<T>
235where
236    T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Clone,
237{
238    /// Multiply a matrix by a single number (scalar)
239    /// NOTE: The scalar type MUST match the matrix type.
240    pub fn scalar_multiply(&self, scalar: T) -> Matrix<T> {
241        let data = self
242            .data
243            .iter()
244            .map(|value| value.clone() * scalar.clone())
245            .collect();
246
247        Matrix {
248            data,
249            row_size: self.row_size,
250            col_size: self.col_size,
251        }
252    }
253
254    /// Multiply `Matrix` with another `Matrix` using standard matrix multiplication
255    /// NOTE: The matrices inner dimensions MUST match else returns None
256    pub fn multiply(&self, multiplier: &Matrix<T>) -> Option<Matrix<T>> {
257        // Validity check for the matrices inner dimensions
258        if self.col_size != multiplier.row_size {
259            return None;
260        }
261
262        let data: Vec<T> = (0..self.row_size)
263            .flat_map(|i| {
264                (0..multiplier.col_size).map(move |j| {
265                    (0..self.col_size)
266                        .map(|k| {
267                            self.data[i * self.col_size + k].clone()
268                                * multiplier.data[k * multiplier.col_size + j].clone()
269                        })
270                        .fold(T::default(), |acc, x| acc + x)
271                })
272            })
273            .collect();
274
275        Some(Matrix {
276            data,
277            col_size: multiplier.col_size,
278            row_size: self.row_size,
279        })
280    }
281
282    /// Multiply the `Matrix` by a vector
283    /// NOTE: The vectors length MUST match the vector columns, else returns None
284    pub fn vector_multiply(&self, multiplier: &[T]) -> Option<Vec<T>> {
285        // Validity check that the `Matrix` column size matches the vector column size
286        if self.col_size != multiplier.len() {
287            return None;
288        }
289
290        let data: Vec<T> = (0..self.row_size)
291            .map(|i| {
292                (0..multiplier.len())
293                    .map(|j| self.data[i * self.col_size + j].clone() * multiplier[j].clone())
294                    .fold(T::default(), |acc, x| acc + x)
295            })
296            .collect();
297
298        Some(data)
299    }
300
301    /// Compute a unique determinant for a `Matrix`
302    /// NOTE: Only computable for square (M x M) matrices.
303    /// NOTE: The determinant is 0 for a `Matrix` with rank r < M (non-invertable).
304    pub fn determinant(&self) -> Option<T> {
305        // Validity check that it's a square matrix
306        if self.col_size != self.row_size {
307            return None;
308        }
309
310        // Base case for recursion: 1x1 matrix
311        if self.row_size == 1 {
312            return Some(self.data[0].clone());
313        }
314
315        if let Some(first_row) = self.try_get_row(0) {
316            let determinant =
317                first_row
318                    .iter()
319                    .enumerate()
320                    .fold(T::default(), |acc, (col_idx, item)| {
321                        let sub_matrix = self.sub_matrix(0, col_idx);
322
323                        if let Some(sub_determinant) = sub_matrix.determinant() {
324                            // Alternate between addition and subtraction
325                            // operations every iteration
326                            if col_idx % 2 == 0 {
327                                acc + item.clone() * sub_determinant
328                            } else {
329                                acc - item.clone() * sub_determinant
330                            }
331                        } else {
332                            T::default()
333                        }
334                    });
335
336            Some(determinant)
337        } else {
338            None
339        }
340    }
341}
342impl<T> Matrix<T>
343where
344    T: Default + Clone + From<f64>,
345{
346    /// Create an identity matrix of given size
347    pub fn identity(size: usize) -> Matrix<T> {
348        let data = (0..size * size)
349            .map(|i| {
350                if i % (size + 1) == 0 {
351                    T::from(1.0)
352                } else {
353                    T::default()
354                }
355            })
356            .collect();
357
358        Matrix {
359            data,
360            row_size: size,
361            col_size: size,
362        }
363    }
364}
365impl<T> Matrix<T>
366where
367    T: Default + Clone + Mul<Output = T> + Add<Output = T> + Into<f64>,
368{
369    /// Compute the frobenius norm of a `Matrix`
370    pub fn frobenius_norm(&self) -> f64 {
371        let sum_of_squares: f64 = self
372            .data
373            .iter()
374            .map(|val| {
375                let val_f64: f64 = val.clone().into();
376                val_f64 * val_f64
377            })
378            .fold(f64::default(), |acc, x| acc + x);
379
380        sum_of_squares.sqrt()
381    }
382}
383impl<T> Matrix<T>
384where
385    T: Copy
386        + PartialOrd
387        + Default
388        + From<f64>
389        + Into<f64>
390        + Sub<Output = T>
391        + Add<Output = T>
392        + Mul<Output = T>
393        + Div<Output = T>
394        + Abs,
395{
396    /// Compute the inverse of a `Matrix`
397    /// NOTE: Only computable for square (M x M) matrices.
398    /// NOTE: Only computable for a `Matrix` where r = M (full rank).
399    pub fn inverse(&self) -> Option<Matrix<T>> {
400        // Validity check that it's a square matrix
401        if self.row_size != self.col_size {
402            return None;
403        }
404
405        let n = self.row_size;
406        let epsilon = T::from(1e-10);
407
408        // Regular matrix
409        let mut a = self.data.clone();
410        // Inverted matrix
411        let mut b = Matrix::<T>::identity(n).data;
412
413        for i in 0..n {
414            let pivot = (i..n).fold(i, |acc, j| {
415                match a[j * n + i].abs() > a[acc * n + i].abs() {
416                    true => j,
417                    false => acc,
418                }
419            });
420
421            // Validity check that it's not a singular matrix
422            if a[pivot * n + i].abs() < epsilon {
423                return None;
424            }
425
426            if pivot != i {
427                (0..n).for_each(|k| {
428                    a.swap(i * n + k, pivot * n + k);
429                    b.swap(i * n + k, pivot * n + k);
430                });
431            }
432
433            let pivot_val = a[i * n + i];
434            (0..n).for_each(|k| {
435                a[i * n + k] = a[i * n + k] / pivot_val;
436                b[i * n + k] = b[i * n + k] / pivot_val;
437            });
438
439            (0..n).filter(|&j| j != i).for_each(|j| {
440                let factor = a[j * n + i];
441                (0..n).for_each(|k| {
442                    a[j * n + k] = a[j * n + k] - factor * a[i * n + k];
443                    b[j * n + k] = b[j * n + k] - factor * b[i * n + k];
444                })
445            })
446        }
447
448        Some(Matrix {
449            data: b,
450            row_size: n,
451            col_size: n,
452        })
453    }
454}
455
/// Absolute value, abstracted over numeric types so that generic code (for
/// example the pivot search in `Matrix::inverse`) can compare magnitudes.
pub trait Abs {
    /// Return the absolute value of `self`.
    fn abs(self) -> Self;
}
459impl Abs for f64 {
460    fn abs(self) -> Self {
461        self.abs()
462    }
463}
464impl Abs for i32 {
465    fn abs(self) -> Self {
466        self.abs()
467    }
468}
469impl Abs for i64 {
470    fn abs(self) -> Self {
471        self.abs()
472    }
473}
474impl Abs for f32 {
475    fn abs(self) -> Self {
476        self.abs()
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Return `true` when `a` and `b` have the same length and every pair of
    /// corresponding elements differs by less than `epsilon`.
    ///
    /// The length check matters: `zip` stops at the shorter slice, so without
    /// it a result with missing elements would compare as equal.
    fn approx_equal(a: &[f64], b: &[f64], epsilon: f64) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b.iter())
                .all(|(&a, &b)| (a - b).abs() < epsilon)
    }

    #[test]
    fn approx_equal_requires_equal_lengths() {
        assert!(approx_equal(&[1.0, 2.0], &[1.0, 2.0 + 1e-9], 1e-6));
        assert!(!approx_equal(&[1.0], &[1.0, 2.0], 1e-6));
        assert!(!approx_equal(&[1.0, 2.0], &[1.0], 1e-6));
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let vector = vec![5, 6];

        let expected = vec![17, 39];
        let result = matrix.vector_multiply(&vector).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_multiplication() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let matrix_b = Matrix::<i32> {
            data: vec![2, 0, 1, 2],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![4, 4, 10, 8],
            row_size: 2,
            col_size: 2,
        };
        let result = matrix_a.multiply(&matrix_b).unwrap();
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_trace() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            col_size: 2,
            row_size: 2,
        };

        let expected: i32 = 5;
        let result = matrix.trace();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_diagonal() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: Vec<i32> = vec![1, 4];
        let result = matrix.get_diagonal();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_scalar_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![2, 4, 6, 8],
            row_size: 2,
            col_size: 2,
        };
        let result = matrix.scalar_multiply(2);

        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_subtraction() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![6, 5, 4, 3, 2, 1],
            row_size: 2,
            col_size: 3,
        };

        let expected = Matrix::<i32> {
            data: vec![-5, -3, -1, 1, 3, 5],
            row_size: 2,
            col_size: 3,
        };
        let result = matrix_a - matrix_b;
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_addition() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![4, 3, 2, 1],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![5, 5, 5, 5],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix_a + matrix_b;
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn new_matrix_has_correct_size() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(matrix.data.len(), 6);
    }

    #[test]
    fn default_matrix_is_same_as_new() {
        let default_matrix: Matrix<i32> = Matrix::default();
        let new_matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(default_matrix.data, new_matrix.data);
    }

    #[test]
    fn try_get_column_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(1);
        assert!(column.is_some());
        assert_eq!(column.unwrap(), vec![0, 0]);
    }

    #[test]
    fn try_get_column_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(3);
        assert!(column.is_none());
    }

    #[test]
    fn try_get_row_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(0);
        assert!(row.is_some());
        assert_eq!(row.unwrap(), vec![0, 0, 0]);
    }

    #[test]
    fn try_get_row_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(2);
        assert!(row.is_none());
    }

    #[test]
    fn transpose_works_correctly() {
        let mut matrix: Matrix<i32> = Matrix::new(2, 3);
        for i in 0..matrix.data.len() {
            matrix.data[i] = i as i32;
        }
        let transposed = matrix.transpose();
        assert_eq!(transposed.data, vec![0, 3, 1, 4, 2, 5]);
    }

    #[test]
    fn determinant_of_1x1_matrix() {
        let matrix = Matrix {
            data: vec![7],
            row_size: 1,
            col_size: 1,
        };
        assert_eq!(matrix.determinant(), Some(7));
    }

    #[test]
    fn determinant_of_2x2_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.determinant(), Some(-2));
    }

    #[test]
    fn determinant_of_3x3_matrix() {
        let matrix = Matrix {
            data: vec![3, 2, 1, 0, 1, 4, 5, 6, 0],
            row_size: 3,
            col_size: 3,
        };
        let expected = -37;
        assert_eq!(matrix.determinant(), Some(expected));
    }

    #[test]
    fn determinant_non_square_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        assert_eq!(matrix.determinant(), None);
    }

    #[test]
    fn test_frobenius_norm_i32() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm_f32() {
        let matrix = Matrix::<f32> {
            data: vec![1.0, 2.0, 3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_2x2() {
        let matrix = Matrix {
            data: vec![4.0, 7.0, 2.0, 6.0],
            row_size: 2,
            col_size: 2,
        };

        let expected_inverse = vec![0.6, -0.7, -0.2, 0.4];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_3x3() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 5.0, 6.0, 0.0],
            row_size: 3,
            col_size: 3,
        };

        let expected_inverse = vec![-24.0, 18.0, 5.0, 20.0, -15.0, -4.0, -5.0, 4.0, 1.0];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_identity() {
        let matrix = Matrix::<f64>::identity(3);
        let result = matrix.inverse().unwrap();

        assert_eq!(result.data, matrix.data);
    }

    #[test]
    fn test_inverse_singular() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 2.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix.inverse();

        assert!(result.is_none());
    }

    #[test]
    fn test_inverse_1x1() {
        let matrix = Matrix {
            data: vec![2.0],
            row_size: 1,
            col_size: 1,
        };

        let expected_inverse = vec![0.5];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    /// Test that a one by one matrix only has one value
    fn test_random_1x1() {
        let result: Matrix<f64> = Matrix::new_random(1, 1);
        assert_eq!(result.data.len(), 1);
    }

    #[test]
    /// Test that a two by three matrix has six values
    fn test_random_2x3() {
        let result: Matrix<i64> = Matrix::new_random(2, 3);
        assert_eq!(result.data.len(), 6);
    }

    #[test]
    /// Test that a twenty by twenty matrix has four hundred values
    fn test_random_20x20() {
        let result: Matrix<i64> = Matrix::new_random(20, 20);
        assert_eq!(result.data.len(), 400);
    }
}