1use core::ops::{Add, Mul, Sub};
52use std::fmt;
53
54use ark_ff::Field;
55
56#[derive(Debug, Clone)]
57pub struct Matrix<F: Field> {
58 matrix: Vec<Vec<F>>,
59}
60
61impl<F: Field> fmt::Display for Matrix<F> {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 let mut s = String::new();
64 for row in self.matrix.iter() {
65 for entry in row {
66 s += &format!("{:20?}", entry);
67 }
68 s += "\n";
69 }
70 write!(f, "{}", s)
71 }
72}
73
74impl<F: Field> PartialEq for Matrix<F> {
75 fn eq(&self, other: &Matrix<F>) -> bool {
76 let num_rows = self.matrix.len();
77 let num_columns = self.matrix.first().unwrap().len();
78
79 assert_eq!(num_rows, num_columns, "Matrix is not square");
80
81 for i in 0..num_rows {
82 for j in 0..num_columns {
83 if self.matrix[i][j] != other.matrix[i][j] {
84 return false;
85 }
86 }
87 }
88
89 true
90 }
91}
92
93impl<F: Field> Add<Matrix<F>> for Matrix<F> {
94 type Output = Matrix<F>;
95
96 fn add(self, other: Matrix<F>) -> Matrix<F> {
97 assert_eq!(
98 self.matrix.len(),
99 other.matrix.len(),
100 "Matrices have different number of rows"
101 );
102 assert_eq!(
103 self.matrix.first().unwrap().len(),
104 other.matrix.first().unwrap().len(),
105 "Matrices have different number of columns"
106 );
107
108 let mut result = self.clone();
109
110 for i in 0..self.matrix.len() {
111 for j in 0..self.matrix.first().unwrap().len() {
112 result.matrix[i][j] = self.matrix[i][j] + other.matrix[i][j];
113 }
114 }
115
116 result
117 }
118}
119
120impl<F: Field> Sub<Matrix<F>> for Matrix<F> {
121 type Output = Matrix<F>;
122
123 fn sub(self, other: Matrix<F>) -> Matrix<F> {
124 assert_eq!(
125 self.matrix.len(),
126 other.matrix.len(),
127 "Matrices have different number of rows"
128 );
129 assert_eq!(
130 self.matrix.first().unwrap().len(),
131 other.matrix.first().unwrap().len(),
132 "Matrices have different number of columns"
133 );
134
135 let mut result = self.clone();
136
137 for i in 0..self.matrix.len() {
138 for j in 0..self.matrix.first().unwrap().len() {
139 result.matrix[i][j] = self.matrix[i][j] - other.matrix[i][j];
140 }
141 }
142
143 result
144 }
145}
146
147impl<F: Field> Mul<Matrix<F>> for Matrix<F> {
148 type Output = Matrix<F>;
149
150 fn mul(self, other: Matrix<F>) -> Matrix<F> {
151 assert_eq!(
152 self.matrix.first().unwrap().len(),
153 other.matrix.len(),
154 "Matrices cannot be multiplied"
155 );
156
157 let mut result = Matrix::new(vec![
158 vec![F::ZERO; other.matrix.first().unwrap().len()];
159 self.matrix.len()
160 ]);
161
162 for i in 0..self.matrix.len() {
163 for j in 0..other.matrix.first().unwrap().len() {
164 for k in 0..self.matrix.first().unwrap().len() {
165 result.matrix[i][j] += self.matrix[i][k] * other.matrix[k][j];
166 }
167 }
168 }
169
170 result
171 }
172}
173
174impl<F: Field> Matrix<F> {
175 pub fn new(matrix: Vec<Vec<F>>) -> Matrix<F> {
184 Matrix { matrix }
185 }
186
187 pub fn is_square(self) -> bool {
194 let num_rows = self.matrix.len();
195 let num_columns = self.matrix.first().unwrap().len();
196
197 if num_rows == 0 {
198 return false;
199 }
200
201 num_rows == num_columns
202 }
203
204 pub fn determinant(mut self) -> F {
210 assert_eq!(
211 self.matrix.len(),
212 self.matrix[0].len(),
213 "Matrix is not square"
214 );
215
216 let n = self.matrix.len();
217 let mut det = F::ONE;
218
219 for i in 0..n {
220 let mut pivot_row = i;
221 for j in (i + 1)..n {
222 if self.matrix[j][i] != F::ZERO {
223 pivot_row = j;
224 break;
225 }
226 }
227
228 if pivot_row != i {
229 self.matrix.swap(i, pivot_row);
230 det = -det;
231 }
232
233 let pivot = self.matrix[i][i];
234
235 if pivot == F::ZERO {
236 return F::ZERO;
237 }
238
239 det *= pivot;
240
241 for j in (i + 1)..n {
242 let factor = self.matrix[j][i] / pivot;
243 for k in (i + 1)..n {
244 self.matrix[j][k] = self.matrix[j][k] - factor * self.matrix[i][k];
245 }
246 }
247 }
248
249 det
250 }
251
252 pub fn is_diagonal(self) -> bool {
259 let num_rows = self.matrix.len();
260 let num_columns = self.matrix.first().unwrap().len();
261
262 if num_rows == 0 {
263 return false;
264 }
265
266 for i in 0..num_rows {
267 for j in 0..num_columns {
268 if i != j && self.matrix[i][j] != F::ZERO {
269 return false;
270 }
271 }
272 }
273
274 true
275 }
276
277 pub fn transpose(self) -> Matrix<F> {
283 let num_rows = self.matrix.len();
284 let num_columns = self.matrix.first().unwrap().len();
285
286 let mut new_rows = vec![vec![F::ZERO; num_rows]; num_columns];
287
288 for i in 0..num_rows {
289 for j in 0..num_columns {
290 new_rows[j][i] = self.matrix[i][j];
291 }
292 }
293
294 Matrix { matrix: new_rows }
295 }
296
297 pub fn adjoint(self) -> Matrix<F> {
303 let num_rows = self.matrix.len();
304 let num_columns = self.matrix.first().unwrap().len();
305
306 let mut new_rows = vec![vec![F::ZERO; num_rows]; num_columns];
307
308 for i in 0..num_rows {
309 for j in 0..num_columns {
310 new_rows[j][i] = self.matrix[i][j];
311 }
312 }
313
314 Matrix { matrix: new_rows }
315 }
316
317 pub fn inverse(self) -> Option<Matrix<F>> {
328 let (l_prima, u_prima) = self.lu_decomposition();
329 let (l, u) = (l_prima.matrix, u_prima.matrix);
330
331 let n = self.matrix.len();
332 let mut x = vec![vec![F::ZERO; n]; n];
333
334 for i in 0..n {
335 let mut b = vec![F::ZERO; n];
336 b[i] = F::ONE;
337
338 let mut y = vec![F::ZERO; n];
339
340 for j in 0..n {
342 let mut sum = F::ZERO;
343 for k in 0..j {
344 sum += l[j][k] * y[k];
345 }
346 y[j] = b[j] - sum;
347 }
348
349 for j in (0..n).rev() {
351 let mut sum = F::ZERO;
352 for k in j + 1..n {
353 sum += u[j][k] * x[k][i];
354 }
355 x[j][i] = (y[j] - sum) / u[j][j];
356 }
357 }
358
359 Some(Matrix { matrix: x })
360 }
361
362 pub fn is_identity(self) -> bool {
371 let num_rows = self.matrix.len();
372 let num_columns = self.matrix.first().unwrap().len();
373
374 if num_rows == 0 {
375 return false;
376 }
377
378 for i in 0..num_rows {
379 for j in 0..num_columns {
380 if i == j && self.matrix[i][j] != F::ONE {
381 return false;
382 }
383 if i != j && self.matrix[i][j] != F::ZERO {
384 return false;
385 }
386 }
387 }
388
389 true
390 }
391
392 pub fn ax_b_solve_for_x(self, b: Vec<F>) -> Vec<F> {
403 let (l, u) = self.lu_decomposition();
404 let n = l.matrix.len();
405
406 let mut y = vec![F::ZERO; n];
407 let mut x = vec![F::ZERO; n];
408
409 for i in 0..n {
411 let mut sum = F::ZERO;
412 for j in 0..i {
413 sum += l.matrix[i][j] * y[j];
414 }
415 y[i] = b[i] - sum;
416 }
417
418 for i in (0..n).rev() {
420 let mut sum = F::ZERO;
421 for j in (i + 1)..n {
422 sum += u.matrix[i][j] * x[i];
423 }
424 x[i] = (y[i] - sum) / u.matrix[i][i];
425 }
426
427 x
428 }
429
430 pub fn lu_decomposition(&self) -> (Matrix<F>, Matrix<F>) {
441 assert_ne!(self.clone().determinant(), F::ZERO, "Det(A) = 0");
442
443 let num_rows = self.matrix.len();
444 let num_columns = self.matrix.first().unwrap().len();
445
446 let mut l = vec![vec![F::ZERO; num_rows]; num_columns];
447 let mut u = vec![vec![F::ZERO; num_rows]; num_columns];
448
449 for i in 0..num_rows {
450 for j in 0..num_columns {
451 if i == j {
452 l[i][j] = F::ONE;
453 }
454 }
455 }
456
457 for i in 0..num_rows {
458 for j in 0..num_columns {
459 let mut sum = F::ZERO;
460 for k in 0..i {
461 sum += l[i][k] * u[k][j];
462 }
463 u[i][j] = self.matrix[i][j] - sum;
464 }
465
466 for j in 0..num_columns {
467 let mut sum = F::ZERO;
468 for k in 0..i {
469 sum += l[j][k] * u[k][i];
470 }
471 l[j][i] = (self.matrix[j][i] - sum) / u[i][i];
472 }
473 }
474
475 (Matrix { matrix: l }, Matrix { matrix: u })
476 }
477
478 pub fn scalar_mul(self, scalar: F) -> Matrix<F> {
486 let num_rows = self.matrix.len();
487 let num_columns = self.matrix.first().unwrap().len();
488
489 let mut new_rows = vec![vec![F::ZERO; num_rows]; num_columns];
490
491 for i in 0..num_rows {
492 for j in 0..num_columns {
493 new_rows[i][j] = self.matrix[i][j] * scalar;
494 }
495 }
496
497 Matrix { matrix: new_rows }
498 }
499
500 pub fn mul_vec(self, vec: Vec<F>) -> Vec<F> {
510 assert_eq!(
511 vec.len(),
512 self.matrix.len(),
513 "Vector and matrix can't be multiplied by the other"
514 );
515
516 let mut result = vec![F::ZERO; self.matrix.first().unwrap().len()];
517
518 for i in 0..self.matrix.first().unwrap().len() {
519 for j in 0..self.matrix.len() {
520 result[i] += vec[j] * self.matrix[j][i];
521 }
522 }
523
524 result
525 }
526
527 pub fn sum_of_matrix(self) -> F {
535 let num_rows = self.matrix.len();
536 let num_columns = self.matrix.first().unwrap().len();
537
538 let mut sum = F::ZERO;
539
540 for i in 0..num_rows {
541 for j in 0..num_columns {
542 sum += self.matrix[i][j];
543 }
544 }
545
546 sum
547 }
548
549 pub fn set_element(&mut self, row: usize, column: usize, new_value: F) {
566 let num_rows = self.matrix.len();
567 let num_columns = self.matrix.first().unwrap().len();
568
569 assert!(
570 row < num_rows && column < num_columns,
571 "Index out of bounds"
572 );
573
574 for i in 0..num_rows {
575 for j in 0..num_columns {
576 if i == row && j == column {
577 self.matrix[i][j] = new_value;
578 }
579 }
580 }
581 }
582
583 pub fn get_element(&self, row: usize, column: usize) -> F {
600 let num_rows = self.matrix.len();
601 let num_columns = self.matrix.first().unwrap().len();
602
603 assert!(
604 row < num_rows && column < num_columns,
605 "Index out of bounds"
606 );
607
608 self.matrix[row][column]
609 }
610}
611
612#[cfg(test)]
613mod tests {
614 use super::Matrix;
615
616 use ark_ff::{Fp64, MontBackend};
617
618 #[test]
619 fn test_matrix_utils_determinant() {
620 #[derive(ark_ff::MontConfig)]
621 #[modulus = "127"]
622 #[generator = "6"]
623 pub struct F127Config;
624 type F = Fp64<MontBackend<F127Config, 1>>;
625 let a: Matrix<F> = Matrix::new(vec![
626 vec![F::from(2), F::from(2)],
627 vec![F::from(3), F::from(4)],
628 ]);
629 let b: Matrix<F> = Matrix::new(vec![
630 vec![F::from(1), F::from(2), F::from(3)],
631 vec![F::from(4), F::from(5), F::from(6)],
632 vec![F::from(7), F::from(8), F::from(9)],
633 ]);
634
635 assert_eq!(a.determinant(), F::from(2));
636 assert_eq!(b.determinant(), F::from(0));
637 }
638
639 #[test]
640 fn test_matrix_utils_is_square() {
641 #[derive(ark_ff::MontConfig)]
642 #[modulus = "127"]
643 #[generator = "6"]
644 pub struct F127Config;
645 type F = Fp64<MontBackend<F127Config, 1>>;
646 let a: Matrix<F> = Matrix::new(vec![
647 vec![F::from(1), F::from(2)],
648 vec![F::from(3), F::from(4)],
649 ]);
650 let b = Matrix::new(vec![
651 vec![F::from(1), F::from(2)],
652 vec![F::from(3), F::from(4)],
653 vec![F::from(5), F::from(6)],
654 ]);
655
656 assert!(a.is_square());
657 assert!(!b.is_square());
658 }
659
660 #[test]
661 fn test_matrix_utils_is_diagonal() {
662 #[derive(ark_ff::MontConfig)]
663 #[modulus = "127"]
664 #[generator = "6"]
665 pub struct F127Config;
666 type F = Fp64<MontBackend<F127Config, 1>>;
667 let a: Matrix<F> = Matrix::new(vec![
668 vec![F::from(1), F::from(0)],
669 vec![F::from(0), F::from(4)],
670 ]);
671 let b = Matrix::new(vec![
672 vec![F::from(1), F::from(2)],
673 vec![F::from(3), F::from(4)],
674 ]);
675
676 assert!(a.is_diagonal());
677 assert!(!b.is_diagonal());
678 }
679
680 #[test]
681 fn test_matrix_utils_transpose() {
682 #[derive(ark_ff::MontConfig)]
683 #[modulus = "127"]
684 #[generator = "6"]
685 pub struct F127Config;
686 type F = Fp64<MontBackend<F127Config, 1>>;
687 let a: Matrix<F> = Matrix::new(vec![
688 vec![F::from(1), F::from(2)],
689 vec![F::from(3), F::from(4)],
690 ]);
691
692 let b = a.transpose();
693
694 assert_eq!(
695 b,
696 Matrix::new(vec![
697 vec![F::from(1), F::from(3)],
698 vec![F::from(2), F::from(4)]
699 ])
700 );
701 }
702
703 #[test]
704 fn test_matrix_utils_adjoint() {
705 #[derive(ark_ff::MontConfig)]
706 #[modulus = "127"]
707 #[generator = "6"]
708 pub struct F127Config;
709 type F = Fp64<MontBackend<F127Config, 1>>;
710 let a: Matrix<F> = Matrix::new(vec![
711 vec![F::from(2), F::from(3)],
712 vec![F::from(4), F::from(3)],
713 ]);
714
715 let b = a.adjoint();
716
717 assert_eq!(
718 b,
719 Matrix::new(vec![
720 vec![F::from(2), F::from(4)],
721 vec![F::from(3), F::from(3)]
722 ])
723 );
724 }
725
726 #[test]
727 fn test_matrix_utils_inverse() {
728 #[derive(ark_ff::MontConfig)]
729 #[modulus = "127"]
730 #[generator = "6"]
731 pub struct F127Config;
732 type F = Fp64<MontBackend<F127Config, 1>>;
733 let a: Matrix<F> = Matrix::new(vec![
734 vec![F::from(1), F::from(2)],
735 vec![F::from(3), F::from(7)],
736 ]);
737
738 let b = a.inverse().unwrap();
739
740 assert_eq!(
741 b,
742 Matrix::new(vec![
743 vec![F::from(7), -F::from(2)],
744 vec![-F::from(3), F::from(1)]
745 ])
746 );
747 }
748
749 #[test]
750 fn test_matrix_utils_is_identity() {
751 #[derive(ark_ff::MontConfig)]
752 #[modulus = "127"]
753 #[generator = "6"]
754 pub struct F127Config;
755 type F = Fp64<MontBackend<F127Config, 1>>;
756 let a: Matrix<F> = Matrix::new(vec![
757 vec![F::from(1), F::from(0)],
758 vec![F::from(0), F::from(1)],
759 ]);
760 let b = Matrix::new(vec![
761 vec![F::from(1), F::from(2)],
762 vec![F::from(3), F::from(4)],
763 ]);
764
765 assert!(a.is_identity());
766 assert!(!b.is_identity());
767 }
768
769 #[test]
771 fn test_matrix_utils_ax_b_solve_for_x() {
772 #[derive(ark_ff::MontConfig)]
773 #[modulus = "127"]
774 #[generator = "6"]
775 pub struct F127Config;
776 type F = Fp64<MontBackend<F127Config, 1>>;
777 let a: Matrix<F> = Matrix::new(vec![
778 vec![F::from(1), F::from(2)],
779 vec![F::from(3), F::from(4)],
780 ]);
781 let b = vec![F::from(1), F::from(2)];
782
783 let x = a.ax_b_solve_for_x(b);
784
785 assert_eq!(x, vec![F::from(1), F::from(1) / F::from(2)]);
786 }
787
788 #[test]
789 fn test_matrix_utils_lu_decomposition() {
790 #[derive(ark_ff::MontConfig)]
791 #[modulus = "127"]
792 #[generator = "6"]
793 pub struct F127Config;
794 type F = Fp64<MontBackend<F127Config, 1>>;
795 let a: Matrix<F> = Matrix::new(vec![
796 vec![F::from(1), F::from(2)],
797 vec![F::from(3), F::from(4)],
798 ]);
799
800 let (l, u) = a.lu_decomposition();
801
802 assert_eq!(
803 l,
804 Matrix::new(vec![
805 vec![F::from(1), F::from(0)],
806 vec![F::from(3), F::from(1)]
807 ])
808 );
809 assert_eq!(
810 u,
811 Matrix::new(vec![
812 vec![F::from(1), F::from(2)],
813 vec![F::from(0), -F::from(2)]
814 ])
815 );
816 }
817
818 #[test]
819 fn test_matrix_utils_matrix_addition() {
820 #[derive(ark_ff::MontConfig)]
821 #[modulus = "127"]
822 #[generator = "6"]
823 pub struct F127Config;
824 type F = Fp64<MontBackend<F127Config, 1>>;
825 let a: Matrix<F> = Matrix::new(vec![
826 vec![F::from(1), F::from(2)],
827 vec![F::from(3), F::from(4)],
828 ]);
829 let b: Matrix<F> = Matrix::new(vec![
830 vec![F::from(1), F::from(2)],
831 vec![F::from(3), F::from(4)],
832 ]);
833
834 let c = a + b;
835
836 assert_eq!(
837 c,
838 Matrix::new(vec![
839 vec![F::from(2), F::from(4)],
840 vec![F::from(6), F::from(8)]
841 ])
842 );
843 }
844
845 #[test]
846 fn test_matrix_utils_matrix_substraction() {
847 #[derive(ark_ff::MontConfig)]
848 #[modulus = "127"]
849 #[generator = "6"]
850 pub struct F127Config;
851 type F = Fp64<MontBackend<F127Config, 1>>;
852 let a: Matrix<F> = Matrix::new(vec![
853 vec![F::from(1), F::from(2)],
854 vec![F::from(3), F::from(4)],
855 ]);
856 let b: Matrix<F> = Matrix::new(vec![
857 vec![F::from(1), F::from(2)],
858 vec![F::from(3), F::from(4)],
859 ]);
860
861 let c = a - b;
862
863 assert_eq!(
864 c,
865 Matrix::new(vec![
866 vec![F::from(0), F::from(0)],
867 vec![F::from(0), F::from(0)]
868 ])
869 );
870 }
871
872 #[test]
873 fn test_matrix_utils_matrix_multiplication() {
874 #[derive(ark_ff::MontConfig)]
875 #[modulus = "127"]
876 #[generator = "6"]
877 pub struct F127Config;
878 type F = Fp64<MontBackend<F127Config, 1>>;
879 let a: Matrix<F> = Matrix::new(vec![
880 vec![F::from(1), F::from(2)],
881 vec![F::from(3), F::from(4)],
882 ]);
883 let b: Matrix<F> = Matrix::new(vec![
884 vec![F::from(2), F::from(3)],
885 vec![F::from(4), F::from(5)],
886 ]);
887
888 let c = a.clone() * (b.clone());
889 let d = b * (a.clone());
890
891 assert_eq!(
892 c,
893 Matrix::new(vec![
894 vec![F::from(10), F::from(13)],
895 vec![F::from(22), F::from(29)]
896 ])
897 );
898
899 assert_eq!(
900 d,
901 Matrix::new(vec![
902 vec![F::from(11), F::from(16)],
903 vec![F::from(19), F::from(28)]
904 ])
905 );
906 }
907
908 #[test]
909 fn test_matrix_utils_matrix_scalar_multiplication() {
910 #[derive(ark_ff::MontConfig)]
911 #[modulus = "127"]
912 #[generator = "6"]
913 pub struct F127Config;
914 type F = Fp64<MontBackend<F127Config, 1>>;
915 let a: Matrix<F> = Matrix::new(vec![
916 vec![F::from(1), F::from(2)],
917 vec![F::from(3), F::from(4)],
918 ]);
919
920 let b = a.scalar_mul(F::from(2));
921
922 assert_eq!(
923 b,
924 Matrix::new(vec![
925 vec![F::from(2), F::from(4)],
926 vec![F::from(6), F::from(8)]
927 ])
928 );
929 }
930
931 #[test]
933 fn test_matrix_utils_multiply_vec() {
934 #[derive(ark_ff::MontConfig)]
935 #[modulus = "127"]
936 #[generator = "6"]
937 pub struct F127Config;
938 type F = Fp64<MontBackend<F127Config, 1>>;
939 let a: Matrix<F> = Matrix::new(vec![
940 vec![F::from(1), F::from(2)],
941 vec![F::from(3), F::from(4)],
942 ]);
943 let b = vec![F::from(1), F::from(2)];
944
945 let c = a.mul_vec(b);
946
947 assert_eq!(c, vec![F::from(7), F::from(10)]);
948 }
949
950 #[test]
951 fn test_matrix_utils_sum_of_matrix() {
952 #[derive(ark_ff::MontConfig)]
953 #[modulus = "127"]
954 #[generator = "6"]
955 pub struct F127Config;
956 type F = Fp64<MontBackend<F127Config, 1>>;
957 let a: Matrix<F> = Matrix::new(vec![
958 vec![F::from(1), F::from(2)],
959 vec![F::from(3), F::from(4)],
960 ]);
961
962 let b = a.sum_of_matrix();
963
964 assert_eq!(b, F::from(10));
965 }
966
967 #[test]
968 fn test_matrix_utils_set_element() {
969 #[derive(ark_ff::MontConfig)]
970 #[modulus = "127"]
971 #[generator = "6"]
972 pub struct F127Config;
973 type F = Fp64<MontBackend<F127Config, 1>>;
974 let mut a: Matrix<F> = Matrix::new(vec![
975 vec![F::from(1), F::from(2)],
976 vec![F::from(3), F::from(4)],
977 ]);
978
979 a.set_element(0, 0, F::from(5));
980 a.set_element(1, 1, F::from(6));
981
982 assert_eq!(
983 a,
984 Matrix::new(vec![
985 vec![F::from(5), F::from(2)],
986 vec![F::from(3), F::from(6)]
987 ])
988 );
989 }
990
991 #[test]
992 fn test_matrix_utils_get_element() {
993 #[derive(ark_ff::MontConfig)]
994 #[modulus = "127"]
995 #[generator = "6"]
996 pub struct F127Config;
997 type F = Fp64<MontBackend<F127Config, 1>>;
998 let a: Matrix<F> = Matrix::new(vec![
999 vec![F::from(1), F::from(2)],
1000 vec![F::from(3), F::from(4)],
1001 ]);
1002
1003 assert_eq!(a.get_element(0, 0), F::from(1));
1004 assert_eq!(a.get_element(0, 1), F::from(2));
1005 assert_eq!(a.get_element(1, 0), F::from(3));
1006 assert_eq!(a.get_element(1, 1), F::from(4));
1007 }
1008}