1use crate::random;
2use std::fmt::Debug;
3use std::ops::{Add, Div, Mul, Sub};
4
/// A dense matrix stored in row-major order.
///
/// The entry at `(row, col)` lives at `data[row * col_size + col]`, so `data`
/// is expected to hold exactly `row_size * col_size` elements. The fields are
/// public, so that invariant is not enforced by construction.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    /// Entries in row-major order.
    pub data: Vec<T>,
    /// Number of rows.
    pub row_size: usize,
    /// Number of columns.
    pub col_size: usize,
}
11impl<T: Default + Clone> Matrix<T> {
12 pub fn new(row_size: usize, col_size: usize) -> Self {
14 Matrix {
15 data: vec![T::default(); row_size * col_size],
16 row_size,
17 col_size,
18 }
19 }
20
21 pub fn new_random(row_size: usize, col_size: usize) -> Self
23 where
24 T: random::Random,
25 {
26 let random_data: Vec<T> = random::gen_rand_vec(row_size * col_size);
27 Matrix {
28 data: random_data,
29 row_size,
30 col_size,
31 }
32 }
33
34 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
36 if row < self.row_size && col < self.col_size {
37 Some(&self.data[col + row * self.col_size])
38 } else {
39 None
40 }
41 }
42
43 pub fn get_diagonal(&self) -> Vec<T> {
45 (0..self.col_size)
46 .filter_map(|col_idx| self.get(col_idx, col_idx).cloned())
47 .collect()
48 }
49
50 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
52 if row < self.row_size && col < self.col_size {
53 Some(&mut self.data[col + row * self.col_size])
54 } else {
55 None
56 }
57 }
58
59 pub fn set(&mut self, row: usize, column: usize, value: T) -> bool {
61 if let Some(cell) = self.get_mut(row, column) {
62 *cell = value;
63 true
64 } else {
65 false
66 }
67 }
68
69 pub fn try_get_column(&self, column: usize) -> Option<Vec<T>> {
74 if column >= self.col_size {
76 return None;
77 }
78
79 let col_data: Vec<T> = (0..self.row_size)
81 .map(|row| self.data[row * self.col_size + column].clone())
82 .collect();
83
84 Some(col_data)
85 }
86
87 pub fn try_get_row(&self, row: usize) -> Option<Vec<T>> {
92 if row >= self.row_size {
94 return None;
95 }
96
97 let row_data: Vec<T> = (0..self.col_size)
99 .map(|col| self.data[row * self.col_size + col].clone())
100 .collect();
101
102 Some(row_data)
103 }
104
105 pub fn from_columns(cols: Vec<Vec<T>>) -> Matrix<T> {
107 if cols.is_empty() {
108 return Matrix {
109 data: Vec::new(),
110 row_size: 0,
111 col_size: 0,
112 };
113 }
114
115 let row_size = cols[0].len();
116 let col_size = cols.len();
117
118 let data = (0..row_size)
119 .flat_map(|row| cols.iter().filter_map(move |col| col.get(row).cloned()))
120 .collect();
121
122 Matrix {
123 data,
124 row_size,
125 col_size,
126 }
127 }
128
129 pub fn sub_matrix(&self, skip_row: usize, skip_col: usize) -> Matrix<T> {
131 let columns: Vec<Vec<T>> = (0..self.col_size)
132 .filter_map(|col| {
133 if col != skip_col {
134 Some(
135 self.try_get_column(col)?
136 .into_iter()
137 .enumerate()
138 .filter_map(|(row, val)| if row != skip_row { Some(val) } else { None })
139 .collect(),
140 )
141 } else {
142 None
143 }
144 })
145 .collect();
146
147 Matrix::from_columns(columns)
148 }
149
150 pub fn transpose(&self) -> Matrix<T> {
152 Matrix {
153 data: (0..self.col_size)
154 .flat_map(|col| {
155 (0..self.row_size).map(move |row| self.data[row * self.col_size + col].clone())
156 })
157 .collect(),
158
159 row_size: self.col_size,
160 col_size: self.row_size,
161 }
162 }
163}
164
165impl<T: Default + Clone> Default for Matrix<T> {
166 fn default() -> Self {
168 Self::new(2, 3)
169 }
170}
171impl<T: Default + Clone + Debug> Add for Matrix<T>
172where
173 T: Add<Output = T> + Clone,
174{
175 type Output = Matrix<T>;
176
177 fn add(self, rhs: Self) -> Matrix<T> {
180 let data: Vec<T> = (0..self.row_size)
181 .flat_map(|row| {
182 let row_a = self.try_get_row(row).expect("Invalid row in self");
183 let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
184 row_a.into_iter().zip(row_b).map(|(a, b)| a + b)
185 })
186 .collect();
187
188 Matrix {
189 data,
190 col_size: self.col_size,
191 row_size: self.row_size,
192 }
193 }
194}
195impl<T> Matrix<T>
196where
197 T: Default + Clone + Add<Output = T>,
198{
199 pub fn trace(&self) -> T {
205 self.get_diagonal()
206 .into_iter()
207 .fold(T::default(), |acc, diagonal| acc + diagonal)
208 }
209
210 pub fn sum(&self) -> T {
212 self.data
213 .iter()
214 .cloned()
215 .fold(T::default(), |acc, x| acc + x)
216 }
217}
218impl<T: Default + Clone + Debug> Sub for Matrix<T>
219where
220 T: Sub<Output = T> + Clone,
221{
222 type Output = Matrix<T>;
223
224 fn sub(self, rhs: Self) -> Matrix<T> {
227 let data: Vec<T> = (0..self.row_size)
228 .flat_map(|row| {
229 let row_a = self.try_get_row(row).expect("Invalid row in self");
230 let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
231 row_a.into_iter().zip(row_b).map(|(a, b)| a - b)
232 })
233 .collect();
234
235 Matrix {
236 data,
237 row_size: self.row_size,
238 col_size: self.col_size,
239 }
240 }
241}
242impl<T: Default> Matrix<T>
243where
244 T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Clone,
245{
246 pub fn scalar_multiply(&self, scalar: T) -> Matrix<T> {
249 let data = self
250 .data
251 .iter()
252 .map(|value| value.clone() * scalar.clone())
253 .collect();
254
255 Matrix {
256 data,
257 row_size: self.row_size,
258 col_size: self.col_size,
259 }
260 }
261
262 pub fn multiply(&self, multiplier: &Matrix<T>) -> Option<Matrix<T>> {
265 if self.col_size != multiplier.row_size {
267 return None;
268 }
269
270 let data: Vec<T> = (0..self.row_size)
271 .flat_map(|i| {
272 (0..multiplier.col_size).map(move |j| {
273 (0..self.col_size)
274 .map(|k| {
275 self.data[i * self.col_size + k].clone()
276 * multiplier.data[k * multiplier.col_size + j].clone()
277 })
278 .fold(T::default(), |acc, x| acc + x)
279 })
280 })
281 .collect();
282
283 Some(Matrix {
284 data,
285 col_size: multiplier.col_size,
286 row_size: self.row_size,
287 })
288 }
289
290 pub fn vector_multiply(&self, multiplier: &[T]) -> Option<Vec<T>> {
293 if self.col_size != multiplier.len() {
295 return None;
296 }
297
298 let data: Vec<T> = (0..self.row_size)
299 .map(|i| {
300 (0..multiplier.len())
301 .map(|j| self.data[i * self.col_size + j].clone() * multiplier[j].clone())
302 .fold(T::default(), |acc, x| acc + x)
303 })
304 .collect();
305
306 Some(data)
307 }
308
309 pub fn determinant(&self) -> Option<T> {
313 if self.col_size != self.row_size {
315 return None;
316 }
317
318 if self.row_size == 1 {
320 return Some(self.data[0].clone());
321 }
322
323 if let Some(first_row) = self.try_get_row(0) {
324 let determinant =
325 first_row
326 .iter()
327 .enumerate()
328 .fold(T::default(), |acc, (col_idx, item)| {
329 let sub_matrix = self.sub_matrix(0, col_idx);
330
331 if let Some(sub_determinant) = sub_matrix.determinant() {
332 if col_idx % 2 == 0 {
335 acc + item.clone() * sub_determinant
336 } else {
337 acc - item.clone() * sub_determinant
338 }
339 } else {
340 T::default()
341 }
342 });
343
344 Some(determinant)
345 } else {
346 None
347 }
348 }
349}
350impl<T> Matrix<T>
351where
352 T: Default + Clone + From<f64>,
353{
354 pub fn identity(size: usize) -> Matrix<T> {
356 let data = (0..size * size)
357 .map(|i| {
358 if i % (size + 1) == 0 {
359 T::from(1.0)
360 } else {
361 T::default()
362 }
363 })
364 .collect();
365
366 Matrix {
367 data,
368 row_size: size,
369 col_size: size,
370 }
371 }
372}
373impl<T> Matrix<T>
374where
375 T: Default + Clone + Mul<Output = T> + Add<Output = T> + Into<f64>,
376{
377 pub fn frobenius_norm(&self) -> f64 {
379 let sum_of_squares: f64 = self
380 .data
381 .iter()
382 .map(|val| {
383 let val_f64: f64 = val.clone().into();
384 val_f64 * val_f64
385 })
386 .fold(f64::default(), |acc, x| acc + x);
387
388 sum_of_squares.sqrt()
389 }
390}
391impl<T> Matrix<T>
392where
393 T: Copy
394 + PartialOrd
395 + Default
396 + From<f64>
397 + Into<f64>
398 + Sub<Output = T>
399 + Add<Output = T>
400 + Mul<Output = T>
401 + Div<Output = T>
402 + Abs,
403{
404 pub fn inverse(&self) -> Option<Matrix<T>> {
408 if self.row_size != self.col_size {
410 return None;
411 }
412
413 let n = self.row_size;
414 let epsilon = T::from(1e-10);
415
416 let mut a = self.data.clone();
418 let mut b = Matrix::<T>::identity(n).data;
420
421 for i in 0..n {
422 let pivot = (i..n).fold(i, |acc, j| {
423 match a[j * n + i].abs() > a[acc * n + i].abs() {
424 true => j,
425 false => acc,
426 }
427 });
428
429 if a[pivot * n + i].abs() < epsilon {
431 return None;
432 }
433
434 if pivot != i {
435 (0..n).for_each(|k| {
436 a.swap(i * n + k, pivot * n + k);
437 b.swap(i * n + k, pivot * n + k);
438 });
439 }
440
441 let pivot_val = a[i * n + i];
442 (0..n).for_each(|k| {
443 a[i * n + k] = a[i * n + k] / pivot_val;
444 b[i * n + k] = b[i * n + k] / pivot_val;
445 });
446
447 (0..n).filter(|&j| j != i).for_each(|j| {
448 let factor = a[j * n + i];
449 (0..n).for_each(|k| {
450 a[j * n + k] = a[j * n + k] - factor * a[i * n + k];
451 b[j * n + k] = b[j * n + k] - factor * b[i * n + k];
452 })
453 })
454 }
455
456 Some(Matrix {
457 data: b,
458 row_size: n,
459 col_size: n,
460 })
461 }
462}
463
/// Absolute value as a trait, so generic numeric code can compare magnitudes.
/// For example, pivot selection in `Matrix::inverse` uses it.
pub trait Abs {
    /// Returns the absolute value of `self`.
    fn abs(self) -> Self;
}

/// Implements [`Abs`] for each listed primitive by forwarding to the type's
/// inherent `abs` method.
macro_rules! impl_abs {
    ($($t:ty),*) => {
        $(
            impl Abs for $t {
                fn abs(self) -> Self {
                    // Inherent associated functions take precedence over trait
                    // methods, so this calls the primitive's own `abs`. It does
                    // not recurse into this impl.
                    <$t>::abs(self)
                }
            }
        )*
    };
}

impl_abs!(i8, i16, i32, i64, i128, isize, f32, f64);
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` have the same length and every pair of
    /// corresponding elements differs by less than `epsilon`.
    ///
    /// The length check matters: `zip` stops at the shorter slice, so without
    /// it a truncated result would compare as equal.
    fn approx_equal(a: &[f64], b: &[f64], epsilon: f64) -> bool {
        a.len() == b.len() && a.iter().zip(b).all(|(&x, &y)| (x - y).abs() < epsilon)
    }

    #[test]
    fn approx_equal_requires_matching_lengths() {
        assert!(approx_equal(&[1.0, 2.0], &[1.0, 2.0 + 1e-9], 1e-6));
        assert!(!approx_equal(&[1.0], &[1.0, 2.0], 1e-6));
        assert!(!approx_equal(&[1.0, 2.0], &[1.0, 2.1], 1e-6));
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let vector = [5, 6];

        let expected = vec![17, 39];
        let result = matrix.vector_multiply(&vector).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_multiplication() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let matrix_b = Matrix::<i32> {
            data: vec![2, 0, 1, 2],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix_a.multiply(&matrix_b).unwrap();
        assert_eq!(result.data, vec![4, 4, 10, 8]);
        assert_eq!((result.row_size, result.col_size), (2, 2));
    }

    #[test]
    fn test_matrix_trace() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: i32 = 5;
        let result = matrix.trace();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_diagonal() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: Vec<i32> = vec![1, 4];
        let result = matrix.get_diagonal();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_scalar_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix.scalar_multiply(2);
        assert_eq!(result.data, vec![2, 4, 6, 8]);
        assert_eq!((result.row_size, result.col_size), (2, 2));
    }

    #[test]
    fn test_matrix_subtraction() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![6, 5, 4, 3, 2, 1],
            row_size: 2,
            col_size: 3,
        };

        let result = matrix_a - matrix_b;
        assert_eq!(result.data, vec![-5, -3, -1, 1, 3, 5]);
        assert_eq!((result.row_size, result.col_size), (2, 3));
    }

    #[test]
    fn test_matrix_addition() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![4, 3, 2, 1],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix_a + matrix_b;
        assert_eq!(result.data, vec![5, 5, 5, 5]);
        assert_eq!((result.row_size, result.col_size), (2, 2));
    }

    #[test]
    fn new_matrix_has_correct_size() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(matrix.data.len(), 6);
        assert_eq!((matrix.row_size, matrix.col_size), (2, 3));
    }

    #[test]
    fn default_matrix_is_same_as_new() {
        let default_matrix: Matrix<i32> = Matrix::default();
        let new_matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(default_matrix.data, new_matrix.data);
        assert_eq!(default_matrix.row_size, new_matrix.row_size);
        assert_eq!(default_matrix.col_size, new_matrix.col_size);
    }

    #[test]
    fn try_get_column_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(1);
        assert_eq!(column, Some(vec![0, 0]));
    }

    #[test]
    fn try_get_column_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(3);
        assert!(column.is_none());
    }

    #[test]
    fn try_get_row_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(0);
        assert_eq!(row, Some(vec![0, 0, 0]));
    }

    #[test]
    fn try_get_row_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(2);
        assert!(row.is_none());
    }

    #[test]
    fn transpose_works_correctly() {
        let mut matrix: Matrix<i32> = Matrix::new(2, 3);
        for (i, value) in matrix.data.iter_mut().enumerate() {
            *value = i as i32;
        }
        let transposed = matrix.transpose();
        assert_eq!(transposed.data, vec![0, 3, 1, 4, 2, 5]);
        assert_eq!((transposed.row_size, transposed.col_size), (3, 2));
    }

    #[test]
    fn determinant_of_1x1_matrix() {
        let matrix = Matrix {
            data: vec![7],
            row_size: 1,
            col_size: 1,
        };
        assert_eq!(matrix.determinant(), Some(7));
    }

    #[test]
    fn determinant_of_2x2_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.determinant(), Some(-2));
    }

    #[test]
    fn determinant_of_3x3_matrix() {
        let matrix = Matrix {
            data: vec![3, 2, 1, 0, 1, 4, 5, 6, 0],
            row_size: 3,
            col_size: 3,
        };
        let expected = -37;
        assert_eq!(matrix.determinant(), Some(expected));
    }

    #[test]
    fn determinant_non_square_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        assert_eq!(matrix.determinant(), None);
    }

    #[test]
    fn test_frobenius_norm_i32() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm_f32() {
        let matrix = Matrix::<f32> {
            data: vec![1.0, 2.0, 3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_2x2() {
        let matrix = Matrix {
            data: vec![4.0, 7.0, 2.0, 6.0],
            row_size: 2,
            col_size: 2,
        };

        let expected_inverse = vec![0.6, -0.7, -0.2, 0.4];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
        assert_eq!((result.row_size, result.col_size), (2, 2));
    }

    #[test]
    fn test_inverse_3x3() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 5.0, 6.0, 0.0],
            row_size: 3,
            col_size: 3,
        };

        let expected_inverse = vec![-24.0, 18.0, 5.0, 20.0, -15.0, -4.0, -5.0, 4.0, 1.0];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_identity() {
        let matrix = Matrix::<f64>::identity(3);
        let result = matrix.inverse().unwrap();

        assert_eq!(result.data, matrix.data);
    }

    #[test]
    fn test_inverse_singular() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 2.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix.inverse();

        assert!(result.is_none());
    }

    #[test]
    fn test_inverse_1x1() {
        let matrix = Matrix {
            data: vec![2.0],
            row_size: 1,
            col_size: 1,
        };

        let expected_inverse = vec![0.5];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_random_1x1() {
        let result: Matrix<f64> = Matrix::new_random(1, 1);
        assert_eq!(result.data.len(), 1);
    }

    #[test]
    fn test_random_2x3() {
        let result: Matrix<i64> = Matrix::new_random(2, 3);
        assert_eq!(result.data.len(), 6);
        assert_eq!((result.row_size, result.col_size), (2, 3));
    }

    #[test]
    fn test_random_20x20() {
        let result: Matrix<i64> = Matrix::new_random(20, 20);
        assert_eq!(result.data.len(), 400);
    }

    #[test]
    fn test_sum_empty_matrix() {
        let mat = Matrix::<f64>::new(0, 0);
        assert_eq!(mat.sum(), 0.0);
    }

    #[test]
    fn test_sum_zero_matrix() {
        let mat = Matrix::<f64>::new(3, 3);
        assert_eq!(mat.sum(), 0.0);
    }

    #[test]
    fn test_sum_positive_elements() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.sum(), 10.0);
    }

    #[test]
    fn test_sum_negative_elements() {
        let matrix = Matrix {
            data: vec![-1.0, -2.0, -3.0, -4.0],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.sum(), -10.0);
    }

    #[test]
    fn test_sum_mixed_elements() {
        let matrix = Matrix {
            data: vec![-1.0, 2.0, -3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.sum(), 2.0);
    }
}
871}