1use crate::random;
2use std::fmt::Debug;
3use std::ops::{Add, Div, Mul, Sub};
4
/// A dense matrix stored in row-major order.
///
/// The element at (`row`, `col`) lives at `data[row * col_size + col]`, so
/// `data.len()` is expected to equal `row_size * col_size`. Because the fields
/// are public, that invariant is not enforced by construction.
///
/// Naming note: `row_size` is the number of rows and `col_size` the number of
/// columns.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    /// Elements in row-major order.
    pub data: Vec<T>,
    /// Number of rows.
    pub row_size: usize,
    /// Number of columns.
    pub col_size: usize,
}
11impl<T: Default + Clone> Matrix<T> {
12 pub fn new(row_size: usize, col_size: usize) -> Self {
14 Matrix {
15 data: vec![T::default(); row_size * col_size],
16 row_size,
17 col_size,
18 }
19 }
20
21 pub fn new_random(row_size: usize, col_size: usize) -> Self
23 where
24 T: random::Random,
25 {
26 let random_data: Vec<T> = random::gen_rand_vec(row_size * col_size);
27 Matrix {
28 data: random_data,
29 row_size,
30 col_size,
31 }
32 }
33
34 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
36 if row < self.row_size && col < self.col_size {
37 Some(&self.data[col + row * self.col_size])
38 } else {
39 None
40 }
41 }
42
43 pub fn get_diagonal(&self) -> Vec<T> {
45 (0..self.col_size)
46 .filter_map(|col_idx| self.get(col_idx, col_idx).cloned())
47 .collect()
48 }
49
50 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
52 if row < self.row_size && col < self.col_size {
53 Some(&mut self.data[col + row * self.col_size])
54 } else {
55 None
56 }
57 }
58
59 pub fn set(&mut self, row: usize, column: usize, value: T) -> bool {
61 if let Some(cell) = self.get_mut(row, column) {
62 *cell = value;
63 true
64 } else {
65 false
66 }
67 }
68
69 pub fn try_get_column(&self, column: usize) -> Option<Vec<T>> {
74 if column >= self.col_size {
76 return None;
77 }
78
79 let col_data: Vec<T> = (0..self.row_size)
81 .map(|row| self.data[row * self.col_size + column].clone())
82 .collect();
83
84 Some(col_data)
85 }
86
87 pub fn try_get_row(&self, row: usize) -> Option<Vec<T>> {
92 if row >= self.row_size {
94 return None;
95 }
96
97 let row_data: Vec<T> = (0..self.col_size)
99 .map(|col| self.data[row * self.col_size + col].clone())
100 .collect();
101
102 Some(row_data)
103 }
104
105 pub fn from_columns(cols: Vec<Vec<T>>) -> Matrix<T> {
107 if cols.is_empty() {
108 return Matrix {
109 data: Vec::new(),
110 row_size: 0,
111 col_size: 0,
112 };
113 }
114
115 let row_size = cols[0].len();
116 let col_size = cols.len();
117
118 let data = (0..row_size)
119 .flat_map(|row| cols.iter().filter_map(move |col| col.get(row).cloned()))
120 .collect();
121
122 Matrix {
123 data,
124 row_size,
125 col_size,
126 }
127 }
128
129 pub fn sub_matrix(&self, skip_row: usize, skip_col: usize) -> Matrix<T> {
131 let columns: Vec<Vec<T>> = (0..self.col_size)
132 .filter_map(|col| {
133 if col != skip_col {
134 Some(
135 self.try_get_column(col)?
136 .into_iter()
137 .enumerate()
138 .filter_map(|(row, val)| if row != skip_row { Some(val) } else { None })
139 .collect(),
140 )
141 } else {
142 None
143 }
144 })
145 .collect();
146
147 Matrix::from_columns(columns)
148 }
149
150 pub fn transpose(&self) -> Matrix<T> {
152 Matrix {
153 data: (0..self.col_size)
154 .flat_map(|col| {
155 (0..self.row_size).map(move |row| self.data[row * self.col_size + col].clone())
156 })
157 .collect(),
158
159 row_size: self.col_size,
160 col_size: self.row_size,
161 }
162 }
163}
164
165impl<T: Default + Clone> Default for Matrix<T> {
166 fn default() -> Self {
168 Self::new(2, 3)
169 }
170}
171impl<T: Default + Clone + Debug> Add for Matrix<T>
172where
173 T: Add<Output = T> + Clone,
174{
175 type Output = Matrix<T>;
176
177 fn add(self, rhs: Self) -> Matrix<T> {
180 let data: Vec<T> = (0..self.row_size)
181 .flat_map(|row| {
182 let row_a = self.try_get_row(row).expect("Invalid row in self");
183 let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
184 row_a.into_iter().zip(row_b).map(|(a, b)| a + b)
185 })
186 .collect();
187
188 Matrix {
189 data,
190 col_size: self.col_size,
191 row_size: self.row_size,
192 }
193 }
194}
195impl<T> Matrix<T>
196where
197 T: Default + Clone + Add<Output = T>,
198{
199 pub fn trace(&self) -> T {
205 self.get_diagonal()
206 .into_iter()
207 .fold(T::default(), |acc, diagonal| acc + diagonal)
208 }
209}
210impl<T: Default + Clone + Debug> Sub for Matrix<T>
211where
212 T: Sub<Output = T> + Clone,
213{
214 type Output = Matrix<T>;
215
216 fn sub(self, rhs: Self) -> Matrix<T> {
219 let data: Vec<T> = (0..self.row_size)
220 .flat_map(|row| {
221 let row_a = self.try_get_row(row).expect("Invalid row in self");
222 let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
223 row_a.into_iter().zip(row_b).map(|(a, b)| a - b)
224 })
225 .collect();
226
227 Matrix {
228 data,
229 row_size: self.row_size,
230 col_size: self.col_size,
231 }
232 }
233}
234impl<T: Default> Matrix<T>
235where
236 T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Clone,
237{
238 pub fn scalar_multiply(&self, scalar: T) -> Matrix<T> {
241 let data = self
242 .data
243 .iter()
244 .map(|value| value.clone() * scalar.clone())
245 .collect();
246
247 Matrix {
248 data,
249 row_size: self.row_size,
250 col_size: self.col_size,
251 }
252 }
253
254 pub fn multiply(&self, multiplier: &Matrix<T>) -> Option<Matrix<T>> {
257 if self.col_size != multiplier.row_size {
259 return None;
260 }
261
262 let data: Vec<T> = (0..self.row_size)
263 .flat_map(|i| {
264 (0..multiplier.col_size).map(move |j| {
265 (0..self.col_size)
266 .map(|k| {
267 self.data[i * self.col_size + k].clone()
268 * multiplier.data[k * multiplier.col_size + j].clone()
269 })
270 .fold(T::default(), |acc, x| acc + x)
271 })
272 })
273 .collect();
274
275 Some(Matrix {
276 data,
277 col_size: multiplier.col_size,
278 row_size: self.row_size,
279 })
280 }
281
282 pub fn vector_multiply(&self, multiplier: &[T]) -> Option<Vec<T>> {
285 if self.col_size != multiplier.len() {
287 return None;
288 }
289
290 let data: Vec<T> = (0..self.row_size)
291 .map(|i| {
292 (0..multiplier.len())
293 .map(|j| self.data[i * self.col_size + j].clone() * multiplier[j].clone())
294 .fold(T::default(), |acc, x| acc + x)
295 })
296 .collect();
297
298 Some(data)
299 }
300
301 pub fn determinant(&self) -> Option<T> {
305 if self.col_size != self.row_size {
307 return None;
308 }
309
310 if self.row_size == 1 {
312 return Some(self.data[0].clone());
313 }
314
315 if let Some(first_row) = self.try_get_row(0) {
316 let determinant =
317 first_row
318 .iter()
319 .enumerate()
320 .fold(T::default(), |acc, (col_idx, item)| {
321 let sub_matrix = self.sub_matrix(0, col_idx);
322
323 if let Some(sub_determinant) = sub_matrix.determinant() {
324 if col_idx % 2 == 0 {
327 acc + item.clone() * sub_determinant
328 } else {
329 acc - item.clone() * sub_determinant
330 }
331 } else {
332 T::default()
333 }
334 });
335
336 Some(determinant)
337 } else {
338 None
339 }
340 }
341}
342impl<T> Matrix<T>
343where
344 T: Default + Clone + From<f64>,
345{
346 pub fn identity(size: usize) -> Matrix<T> {
348 let data = (0..size * size)
349 .map(|i| {
350 if i % (size + 1) == 0 {
351 T::from(1.0)
352 } else {
353 T::default()
354 }
355 })
356 .collect();
357
358 Matrix {
359 data,
360 row_size: size,
361 col_size: size,
362 }
363 }
364}
365impl<T> Matrix<T>
366where
367 T: Default + Clone + Mul<Output = T> + Add<Output = T> + Into<f64>,
368{
369 pub fn frobenius_norm(&self) -> f64 {
371 let sum_of_squares: f64 = self
372 .data
373 .iter()
374 .map(|val| {
375 let val_f64: f64 = val.clone().into();
376 val_f64 * val_f64
377 })
378 .fold(f64::default(), |acc, x| acc + x);
379
380 sum_of_squares.sqrt()
381 }
382}
383impl<T> Matrix<T>
384where
385 T: Copy
386 + PartialOrd
387 + Default
388 + From<f64>
389 + Into<f64>
390 + Sub<Output = T>
391 + Add<Output = T>
392 + Mul<Output = T>
393 + Div<Output = T>
394 + Abs,
395{
396 pub fn inverse(&self) -> Option<Matrix<T>> {
400 if self.row_size != self.col_size {
402 return None;
403 }
404
405 let n = self.row_size;
406 let epsilon = T::from(1e-10);
407
408 let mut a = self.data.clone();
410 let mut b = Matrix::<T>::identity(n).data;
412
413 for i in 0..n {
414 let pivot = (i..n).fold(i, |acc, j| {
415 match a[j * n + i].abs() > a[acc * n + i].abs() {
416 true => j,
417 false => acc,
418 }
419 });
420
421 if a[pivot * n + i].abs() < epsilon {
423 return None;
424 }
425
426 if pivot != i {
427 (0..n).for_each(|k| {
428 a.swap(i * n + k, pivot * n + k);
429 b.swap(i * n + k, pivot * n + k);
430 });
431 }
432
433 let pivot_val = a[i * n + i];
434 (0..n).for_each(|k| {
435 a[i * n + k] = a[i * n + k] / pivot_val;
436 b[i * n + k] = b[i * n + k] / pivot_val;
437 });
438
439 (0..n).filter(|&j| j != i).for_each(|j| {
440 let factor = a[j * n + i];
441 (0..n).for_each(|k| {
442 a[j * n + k] = a[j * n + k] - factor * a[i * n + k];
443 b[j * n + k] = b[j * n + k] - factor * b[i * n + k];
444 })
445 })
446 }
447
448 Some(Matrix {
449 data: b,
450 row_size: n,
451 col_size: n,
452 })
453 }
454}
455
/// Absolute value for numeric matrix element types.
///
/// `Matrix::inverse` requires this bound to compare pivot magnitudes.
pub trait Abs {
    /// Returns the absolute value of `self`.
    fn abs(self) -> Self;
}

/// Implements [`Abs`] by forwarding to each primitive's inherent `abs`.
///
/// For the integer types, `abs` of `MIN` overflows, exactly as the inherent
/// method does.
macro_rules! impl_abs_via_inherent {
    ($($ty:ty),*) => {
        $(
            impl Abs for $ty {
                fn abs(self) -> Self {
                    // Inherent associated functions take precedence over trait
                    // methods, so this calls the primitive's own `abs` and does
                    // not recurse.
                    <$ty>::abs(self)
                }
            }
        )*
    };
}

impl_abs_via_inherent!(f64, i32, i64, f32);
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when both slices have the same length and every pair of
    /// corresponding elements differs by less than `epsilon`.
    ///
    /// The length check matters: `zip` stops at the shorter slice, so without
    /// it a truncated result would compare as equal.
    fn approx_equal(a: &[f64], b: &[f64], epsilon: f64) -> bool {
        a.len() == b.len()
            && a
                .iter()
                .zip(b.iter())
                .all(|(&a, &b)| (a - b).abs() < epsilon)
    }

    #[test]
    fn approx_equal_rejects_length_mismatch() {
        assert!(!approx_equal(&[1.0], &[1.0, 2.0], 1e-6));
        assert!(!approx_equal(&[], &[1.0], 1e-6));
        assert!(approx_equal(&[1.0, 2.0], &[1.0, 2.0 + 1e-9], 1e-6));
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let vector = vec![5, 6];

        let expected = vec![17, 39];
        let result = matrix.vector_multiply(&vector).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_multiplication() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let matrix_b = Matrix::<i32> {
            data: vec![2, 0, 1, 2],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![4, 4, 10, 8],
            row_size: 2,
            col_size: 2,
        };
        let result = matrix_a.multiply(&matrix_b).unwrap();
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_trace() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            col_size: 2,
            row_size: 2,
        };

        let expected: i32 = 5;
        let result = matrix.trace();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_diagonal() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: Vec<i32> = vec![1, 4];
        let result = matrix.get_diagonal();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_scalar_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![2, 4, 6, 8],
            row_size: 2,
            col_size: 2,
        };
        let result = matrix.scalar_multiply(2);

        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_subtraction() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![6, 5, 4, 3, 2, 1],
            row_size: 2,
            col_size: 3,
        };

        let expected = Matrix::<i32> {
            data: vec![-5, -3, -1, 1, 3, 5],
            row_size: 2,
            col_size: 3,
        };
        let result = matrix_a - matrix_b;
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_addition() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![4, 3, 2, 1],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![5, 5, 5, 5],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix_a + matrix_b;
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn new_matrix_has_correct_size() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(matrix.data.len(), 6);
    }

    #[test]
    fn default_matrix_is_same_as_new() {
        let default_matrix: Matrix<i32> = Matrix::default();
        let new_matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(default_matrix.data, new_matrix.data);
    }

    #[test]
    fn try_get_column_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(1);
        assert!(column.is_some());
        assert_eq!(column.unwrap(), vec![0, 0]);
    }

    #[test]
    fn try_get_column_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(3);
        assert!(column.is_none());
    }

    #[test]
    fn try_get_row_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(0);
        assert!(row.is_some());
        assert_eq!(row.unwrap(), vec![0, 0, 0]);
    }

    #[test]
    fn try_get_row_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(2);
        assert!(row.is_none());
    }

    #[test]
    fn transpose_works_correctly() {
        let mut matrix: Matrix<i32> = Matrix::new(2, 3);
        for i in 0..matrix.data.len() {
            matrix.data[i] = i as i32;
        }
        let transposed = matrix.transpose();
        assert_eq!(transposed.data, vec![0, 3, 1, 4, 2, 5]);
    }

    #[test]
    fn from_columns_builds_row_major_data() {
        let matrix = Matrix::from_columns(vec![vec![1, 4], vec![2, 5], vec![3, 6]]);
        assert_eq!(matrix.row_size, 2);
        assert_eq!(matrix.col_size, 3);
        assert_eq!(matrix.data, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn sub_matrix_removes_row_and_column() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
            row_size: 3,
            col_size: 3,
        };
        let sub = matrix.sub_matrix(1, 0);
        assert_eq!(sub.row_size, 2);
        assert_eq!(sub.col_size, 2);
        assert_eq!(sub.data, vec![2, 3, 8, 9]);
    }

    #[test]
    fn identity_has_ones_on_diagonal() {
        let identity = Matrix::<f64>::identity(2);
        assert_eq!(identity.data, vec![1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn determinant_of_1x1_matrix() {
        let matrix = Matrix {
            data: vec![7],
            row_size: 1,
            col_size: 1,
        };
        assert_eq!(matrix.determinant(), Some(7));
    }

    #[test]
    fn determinant_of_2x2_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.determinant(), Some(-2));
    }

    #[test]
    fn determinant_of_3x3_matrix() {
        let matrix = Matrix {
            data: vec![3, 2, 1, 0, 1, 4, 5, 6, 0],
            row_size: 3,
            col_size: 3,
        };
        let expected = -37;
        assert_eq!(matrix.determinant(), Some(expected));
    }

    #[test]
    fn determinant_non_square_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        assert_eq!(matrix.determinant(), None);
    }

    #[test]
    fn test_frobenius_norm_i32() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm_f32() {
        let matrix = Matrix::<f32> {
            data: vec![1.0, 2.0, 3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_2x2() {
        let matrix = Matrix {
            data: vec![4.0, 7.0, 2.0, 6.0],
            row_size: 2,
            col_size: 2,
        };

        let expected_inverse = vec![0.6, -0.7, -0.2, 0.4];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_3x3() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 5.0, 6.0, 0.0],
            row_size: 3,
            col_size: 3,
        };

        let expected_inverse = vec![-24.0, 18.0, 5.0, 20.0, -15.0, -4.0, -5.0, 4.0, 1.0];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_identity() {
        let matrix = Matrix::<f64>::identity(3);
        let result = matrix.inverse().unwrap();

        assert_eq!(result.data, matrix.data);
    }

    #[test]
    fn test_inverse_singular() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 2.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix.inverse();

        assert!(result.is_none());
    }

    #[test]
    fn test_inverse_1x1() {
        let matrix = Matrix {
            data: vec![2.0],
            row_size: 1,
            col_size: 1,
        };

        let expected_inverse = vec![0.5];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_random_1x1() {
        let result: Matrix<f64> = Matrix::new_random(1, 1);
        assert_eq!(result.data.len(), 1);
    }

    #[test]
    fn test_random_2x3() {
        let result: Matrix<i64> = Matrix::new_random(2, 3);
        assert_eq!(result.data.len(), 6);
    }

    #[test]
    fn test_random_20x20() {
        let result: Matrix<i64> = Matrix::new_random(20, 20);
        assert_eq!(result.data.len(), 400);
    }
}