field_matrix_utils/
lib.rs

//! Simple matrix library for Rust,
//! intended for matrices whose entries are finite-field elements.
//!
//! Not safe for production use;
//! it was written for educational purposes only.
//!
//! # Example
//! ```
//! use ark_ff::{Fp64, MontBackend};
//! use field_matrix_utils::Matrix;
//!
//! // Arkworks provides a derive macro that builds a prime field from a modulus and a generator.
//! // Type `F` is the field element type stored in our matrices.
//! // Any `ark_ff::Field` implementation should work; this is just an example.
//! #[derive(ark_ff::MontConfig)]
//! #[modulus = "127"]
//! #[generator = "6"]
//! pub struct F127Config;
//! type F = Fp64<MontBackend<F127Config, 1>>;
//!
//! // The good stuff starts here.
//! let a: Matrix<F> = Matrix::new(vec![
//!     vec![F::from(1), F::from(2)],
//!     vec![F::from(3), F::from(4)],
//! ]);
//! // Most operations take the matrix by value, so clone it when it is reused.
//! let b: Matrix<F> = a.clone().transpose();
//! let c: Matrix<F> = a.clone() + b.clone();
//! let d: Matrix<F> = a.clone() * b;
//! let det: F = a.determinant();
//! ```
//! # Features
//! - Addition
//! - Subtraction
//! - Multiplication
//! - Transpose
//! - Determinant
//! - Inverse
//! - Is square
//! - Is diagonal
//! - Adjoint (returns the transpose)
//! - LU decomposition
//! - Scalar multiplication
//! - Row-vector times matrix multiplication
//! - Summation of all entries
//! - Get element at index
//! - Set element at index
//! - Is identity
//! - Equality
//! - Display
//! - Solving linear systems `Ax = b` for `x`
//!
50
51use core::ops::{Add, Mul, Sub};
52use std::fmt;
53
54use ark_ff::Field;
55
56#[derive(Debug, Clone)]
57pub struct Matrix<F: Field> {
58    matrix: Vec<Vec<F>>,
59}
60
61impl<F: Field> fmt::Display for Matrix<F> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        let mut s = String::new();
64        for row in self.matrix.iter() {
65            for entry in row {
66                s += &format!("{:20?}", entry);
67            }
68            s += "\n";
69        }
70        write!(f, "{}", s)
71    }
72}
73
74impl<F: Field> PartialEq for Matrix<F> {
75    fn eq(&self, other: &Matrix<F>) -> bool {
76        let num_rows = self.matrix.len();
77        let num_columns = self.matrix.first().unwrap().len();
78
79        assert_eq!(num_rows, num_columns, "Matrix is not square");
80
81        for i in 0..num_rows {
82            for j in 0..num_columns {
83                if self.matrix[i][j] != other.matrix[i][j] {
84                    return false;
85                }
86            }
87        }
88
89        true
90    }
91}
92
93impl<F: Field> Add<Matrix<F>> for Matrix<F> {
94    type Output = Matrix<F>;
95
96    fn add(self, other: Matrix<F>) -> Matrix<F> {
97        assert_eq!(
98            self.matrix.len(),
99            other.matrix.len(),
100            "Matrices have different number of rows"
101        );
102        assert_eq!(
103            self.matrix.first().unwrap().len(),
104            other.matrix.first().unwrap().len(),
105            "Matrices have different number of columns"
106        );
107
108        let mut result = self.clone();
109
110        for i in 0..self.matrix.len() {
111            for j in 0..self.matrix.first().unwrap().len() {
112                result.matrix[i][j] = self.matrix[i][j] + other.matrix[i][j];
113            }
114        }
115
116        result
117    }
118}
119
120impl<F: Field> Sub<Matrix<F>> for Matrix<F> {
121    type Output = Matrix<F>;
122
123    fn sub(self, other: Matrix<F>) -> Matrix<F> {
124        assert_eq!(
125            self.matrix.len(),
126            other.matrix.len(),
127            "Matrices have different number of rows"
128        );
129        assert_eq!(
130            self.matrix.first().unwrap().len(),
131            other.matrix.first().unwrap().len(),
132            "Matrices have different number of columns"
133        );
134
135        let mut result = self.clone();
136
137        for i in 0..self.matrix.len() {
138            for j in 0..self.matrix.first().unwrap().len() {
139                result.matrix[i][j] = self.matrix[i][j] - other.matrix[i][j];
140            }
141        }
142
143        result
144    }
145}
146
147impl<F: Field> Mul<Matrix<F>> for Matrix<F> {
148    type Output = Matrix<F>;
149
150    fn mul(self, other: Matrix<F>) -> Matrix<F> {
151        assert_eq!(
152            self.matrix.first().unwrap().len(),
153            other.matrix.len(),
154            "Matrices cannot be multiplied"
155        );
156
157        let mut result = Matrix::new(vec![
158            vec![F::ZERO; other.matrix.first().unwrap().len()];
159            self.matrix.len()
160        ]);
161
162        for i in 0..self.matrix.len() {
163            for j in 0..other.matrix.first().unwrap().len() {
164                for k in 0..self.matrix.first().unwrap().len() {
165                    result.matrix[i][j] += self.matrix[i][k] * other.matrix[k][j];
166                }
167            }
168        }
169
170        result
171    }
172}
173
174impl<F: Field> Matrix<F> {
175    /// Creates a new matrix from a vector of vectors.
176    /// ## Example
177    /// ```
178    /// let a: Matrix<F> = Matrix::new(vec![
179    ///     vec![F::from(2), F::from(2)],
180    ///     vec![F::from(3), F::from(4)],
181    /// ]);
182    /// ```
183    pub fn new(matrix: Vec<Vec<F>>) -> Matrix<F> {
184        Matrix { matrix }
185    }
186
187    /// Returns whether or not the matrix is square.
188    /// ## Example
189    /// ```
190    /// let is_square: bool = a.is_square();
191    /// assert!(is_square);
192    /// ```
193    pub fn is_square(self) -> bool {
194        let num_rows = self.matrix.len();
195        let num_columns = self.matrix.first().unwrap().len();
196
197        if num_rows == 0 {
198            return false;
199        }
200
201        num_rows == num_columns
202    }
203
204    /// Returns the determinant of the matrix.
205    /// ## Example
206    /// ```
207    /// let det: F = a.determinant();
208    /// ```
209    pub fn determinant(mut self) -> F {
210        assert_eq!(
211            self.matrix.len(),
212            self.matrix[0].len(),
213            "Matrix is not square"
214        );
215
216        let n = self.matrix.len();
217        let mut det = F::ONE;
218
219        for i in 0..n {
220            let mut pivot_row = i;
221            for j in (i + 1)..n {
222                if self.matrix[j][i] != F::ZERO {
223                    pivot_row = j;
224                    break;
225                }
226            }
227
228            if pivot_row != i {
229                self.matrix.swap(i, pivot_row);
230                det = -det;
231            }
232
233            let pivot = self.matrix[i][i];
234
235            if pivot == F::ZERO {
236                return F::ZERO;
237            }
238
239            det *= pivot;
240
241            for j in (i + 1)..n {
242                let factor = self.matrix[j][i] / pivot;
243                for k in (i + 1)..n {
244                    self.matrix[j][k] = self.matrix[j][k] - factor * self.matrix[i][k];
245                }
246            }
247        }
248
249        det
250    }
251
252    /// Returns whether or not the matrix is diagonal.
253    /// ## Example
254    /// ```
255    /// let is_diagonal: bool = a.is_diagonal();
256    /// assert!(is_diagonal);
257    /// ```
258    pub fn is_diagonal(self) -> bool {
259        let num_rows = self.matrix.len();
260        let num_columns = self.matrix.first().unwrap().len();
261
262        if num_rows == 0 {
263            return false;
264        }
265
266        for i in 0..num_rows {
267            for j in 0..num_columns {
268                if i != j && self.matrix[i][j] != F::ZERO {
269                    return false;
270                }
271            }
272        }
273
274        true
275    }
276
277    /// Returns the transpose of the matrix.
278    /// ## Example
279    /// ```
280    /// let b: Matrix<F> = a.transpose();
281    /// ```
282    pub fn transpose(self) -> Matrix<F> {
283        let num_rows = self.matrix.len();
284        let num_columns = self.matrix.first().unwrap().len();
285
286        let mut new_rows = vec![vec![F::ZERO; num_rows]; num_columns];
287
288        for i in 0..num_rows {
289            for j in 0..num_columns {
290                new_rows[j][i] = self.matrix[i][j];
291            }
292        }
293
294        Matrix { matrix: new_rows }
295    }
296
297    /// Returns the adjoint of the matrix.
298    /// ## Example
299    /// ```
300    /// let b: Matrix<F> = a.adjoint();
301    /// ```
302    pub fn adjoint(self) -> Matrix<F> {
303        let num_rows = self.matrix.len();
304        let num_columns = self.matrix.first().unwrap().len();
305
306        let mut new_rows = vec![vec![F::ZERO; num_rows]; num_columns];
307
308        for i in 0..num_rows {
309            for j in 0..num_columns {
310                new_rows[j][i] = self.matrix[i][j];
311            }
312        }
313
314        Matrix { matrix: new_rows }
315    }
316
317    /// Returns the inverse of the matrix.
318    /// ## Example
319    /// ```
320    /// let b: Matrix<F> = a.inverse();
321    /// ```
322    /// ## Panics
323    /// Panics if the matrix is not invertible.
324    /// ## Notes
325    /// This function uses the LU decomposition to compute the inverse.
326    /// The LU decomposition is computed using the Doolittle algorithm.
327    pub fn inverse(self) -> Option<Matrix<F>> {
328        let (l_prima, u_prima) = self.lu_decomposition();
329        let (l, u) = (l_prima.matrix, u_prima.matrix);
330
331        let n = self.matrix.len();
332        let mut x = vec![vec![F::ZERO; n]; n];
333
334        for i in 0..n {
335            let mut b = vec![F::ZERO; n];
336            b[i] = F::ONE;
337
338            let mut y = vec![F::ZERO; n];
339
340            // solve Ly = b for y
341            for j in 0..n {
342                let mut sum = F::ZERO;
343                for k in 0..j {
344                    sum += l[j][k] * y[k];
345                }
346                y[j] = b[j] - sum;
347            }
348
349            // solve Ux = y for x
350            for j in (0..n).rev() {
351                let mut sum = F::ZERO;
352                for k in j + 1..n {
353                    sum += u[j][k] * x[k][i];
354                }
355                x[j][i] = (y[j] - sum) / u[j][j];
356            }
357        }
358
359        Some(Matrix { matrix: x })
360    }
361
362    /// Returns whether or not the matrix is the identity matrix.
363    /// ## Example
364    /// ```
365    /// let is_identity: bool = a.is_identity();
366    /// assert!(is_identity);
367    /// ```
368    /// ## Notes
369    /// This function returns false if the matrix is empty.
370    pub fn is_identity(self) -> bool {
371        let num_rows = self.matrix.len();
372        let num_columns = self.matrix.first().unwrap().len();
373
374        if num_rows == 0 {
375            return false;
376        }
377
378        for i in 0..num_rows {
379            for j in 0..num_columns {
380                if i == j && self.matrix[i][j] != F::ONE {
381                    return false;
382                }
383                if i != j && self.matrix[i][j] != F::ZERO {
384                    return false;
385                }
386            }
387        }
388
389        true
390    }
391
392    /// Solves the system of linear equations Ax = b for x.
393    /// ## Example
394    /// ```
395    /// let x: Vec<F> = a.solve_for_x(b);
396    /// ```
397    /// ## Panics
398    /// Panics if the matrix is the determinant of the matrix is zero.
399    /// ## Notes
400    /// This function uses the LU decomposition to solve the system of linear equations.
401    /// The LU decomposition is computed using the Doolittle algorithm.
402    pub fn ax_b_solve_for_x(self, b: Vec<F>) -> Vec<F> {
403        let (l, u) = self.lu_decomposition();
404        let n = l.matrix.len();
405
406        let mut y = vec![F::ZERO; n];
407        let mut x = vec![F::ZERO; n];
408
409        // Solve for Ly=b
410        for i in 0..n {
411            let mut sum = F::ZERO;
412            for j in 0..i {
413                sum += l.matrix[i][j] * y[j];
414            }
415            y[i] = b[i] - sum;
416        }
417
418        //Solve Ux = y
419        for i in (0..n).rev() {
420            let mut sum = F::ZERO;
421            for j in (i + 1)..n {
422                sum += u.matrix[i][j] * x[i];
423            }
424            x[i] = (y[i] - sum) / u.matrix[i][i];
425        }
426
427        x
428    }
429
430    /// Returns the LU decomposition of the matrix.
431    /// ## Example
432    /// ```
433    /// let (l, u): (Matrix<F>, Matrix<F>) = a.lu_decomposition();
434    /// ```
435    /// ## Panics
436    /// Panics if the matrix is the determinant of the matrix is zero.
437    /// ## Notes
438    /// This function uses the Doolittle algorithm to compute the LU decomposition.
439    /// The Doolittle algorithm is a variant of the Gaussian elimination algorithm.
440    pub fn lu_decomposition(&self) -> (Matrix<F>, Matrix<F>) {
441        assert_ne!(self.clone().determinant(), F::ZERO, "Det(A) = 0");
442
443        let num_rows = self.matrix.len();
444        let num_columns = self.matrix.first().unwrap().len();
445
446        let mut l = vec![vec![F::ZERO; num_rows]; num_columns];
447        let mut u = vec![vec![F::ZERO; num_rows]; num_columns];
448
449        for i in 0..num_rows {
450            for j in 0..num_columns {
451                if i == j {
452                    l[i][j] = F::ONE;
453                }
454            }
455        }
456
457        for i in 0..num_rows {
458            for j in 0..num_columns {
459                let mut sum = F::ZERO;
460                for k in 0..i {
461                    sum += l[i][k] * u[k][j];
462                }
463                u[i][j] = self.matrix[i][j] - sum;
464            }
465
466            for j in 0..num_columns {
467                let mut sum = F::ZERO;
468                for k in 0..i {
469                    sum += l[j][k] * u[k][i];
470                }
471                l[j][i] = (self.matrix[j][i] - sum) / u[i][i];
472            }
473        }
474
475        (Matrix { matrix: l }, Matrix { matrix: u })
476    }
477
478    /// Multiplies the matrix by a scalar.
479    /// ## Example
480    /// ```
481    /// let c: Matrix<F> = a.scalar_mul(b);
482    /// ```
483    /// ## Notes
484    /// This function is equivalent to multiplying each element of the matrix by the scalar.
485    pub fn scalar_mul(self, scalar: F) -> Matrix<F> {
486        let num_rows = self.matrix.len();
487        let num_columns = self.matrix.first().unwrap().len();
488
489        let mut new_rows = vec![vec![F::ZERO; num_rows]; num_columns];
490
491        for i in 0..num_rows {
492            for j in 0..num_columns {
493                new_rows[i][j] = self.matrix[i][j] * scalar;
494            }
495        }
496
497        Matrix { matrix: new_rows }
498    }
499
500    /// Multiplies the matrix by a vector.
501    /// ## Example
502    /// ```
503    /// let c: Vec<F> = a.mul_vec(b);
504    /// ```
505    /// ## Panics
506    /// Panics if the number of rows in the matrix is not equal to the number of elements in the vector.
507    /// ## Notes
508    /// This function is equivalent to multiplying the matrix by a column vector.
509    pub fn mul_vec(self, vec: Vec<F>) -> Vec<F> {
510        assert_eq!(
511            vec.len(),
512            self.matrix.len(),
513            "Vector and matrix can't be multiplied by the other"
514        );
515
516        let mut result = vec![F::ZERO; self.matrix.first().unwrap().len()];
517
518        for i in 0..self.matrix.first().unwrap().len() {
519            for j in 0..self.matrix.len() {
520                result[i] += vec[j] * self.matrix[j][i];
521            }
522        }
523
524        result
525    }
526
527    /// Returns the sum of all the elements in the matrix.
528    /// ## Example
529    /// ```
530    /// let c: F = a.sum_of_matrix();
531    /// ```
532    /// ## Notes
533    /// This function is equivalent to summing all the elements in the matrix.
534    pub fn sum_of_matrix(self) -> F {
535        let num_rows = self.matrix.len();
536        let num_columns = self.matrix.first().unwrap().len();
537
538        let mut sum = F::ZERO;
539
540        for i in 0..num_rows {
541            for j in 0..num_columns {
542                sum += self.matrix[i][j];
543            }
544        }
545
546        sum
547    }
548
549    /// Sets a specific element in the matrix.
550    /// ## Example
551    /// ```
552    /// type F = Field;
553    /// let a: Matrix<F> = Matrix::new(vec![
554    ///     vec![F::from(2), F::from(2)],
555    ///     vec![F::from(3), F::from(4)],
556    /// ]);
557    /// a.set_element(0, 0, F::from(1));
558    /// ```
559    /// ## Panics
560    /// Panics if the row or column is out of bounds.
561    /// ## Notes
562    /// This function is equivalent to setting the element in the matrix.
563    /// With row being the position in the outer vector and column being the position in the inner vector.
564    /// The first element in the outer vector is row 0 and the first element in the inner vector is column 0.
565    pub fn set_element(&mut self, row: usize, column: usize, new_value: F) {
566        let num_rows = self.matrix.len();
567        let num_columns = self.matrix.first().unwrap().len();
568
569        assert!(
570            row < num_rows && column < num_columns,
571            "Index out of bounds"
572        );
573
574        for i in 0..num_rows {
575            for j in 0..num_columns {
576                if i == row && j == column {
577                    self.matrix[i][j] = new_value;
578                }
579            }
580        }
581    }
582
583    /// Gets a specific element in the matrix.
584    /// ## Example
585    /// ```
586    /// let a: Matrix<F> = Matrix::new(vec![
587    ///   vec![F::from(2), F::from(2)],
588    ///   vec![F::from(3), F::from(4)],
589    /// ]);
590    /// let b: F = a.get_element(0, 0);
591    /// assert_eq!(b, F::from(2));
592    /// ```
593    /// ## Panics
594    /// Panics if the row or column is out of bounds.
595    /// ## Notes
596    /// This function is equivalent to getting the element in the matrix.
597    /// With row being the position in the outer vector and column being the position in the inner vector.
598    /// The first element in the outer vector is row 0 and the first element in the inner vector is column 0.
599    pub fn get_element(&self, row: usize, column: usize) -> F {
600        let num_rows = self.matrix.len();
601        let num_columns = self.matrix.first().unwrap().len();
602
603        assert!(
604            row < num_rows && column < num_columns,
605            "Index out of bounds"
606        );
607
608        self.matrix[row][column]
609    }
610}
611
612#[cfg(test)]
613mod tests {
614    use super::Matrix;
615
616    use ark_ff::{Fp64, MontBackend};
617
618    #[test]
619    fn test_matrix_utils_determinant() {
620        #[derive(ark_ff::MontConfig)]
621        #[modulus = "127"]
622        #[generator = "6"]
623        pub struct F127Config;
624        type F = Fp64<MontBackend<F127Config, 1>>;
625        let a: Matrix<F> = Matrix::new(vec![
626            vec![F::from(2), F::from(2)],
627            vec![F::from(3), F::from(4)],
628        ]);
629        let b: Matrix<F> = Matrix::new(vec![
630            vec![F::from(1), F::from(2), F::from(3)],
631            vec![F::from(4), F::from(5), F::from(6)],
632            vec![F::from(7), F::from(8), F::from(9)],
633        ]);
634
635        assert_eq!(a.determinant(), F::from(2));
636        assert_eq!(b.determinant(), F::from(0));
637    }
638
639    #[test]
640    fn test_matrix_utils_is_square() {
641        #[derive(ark_ff::MontConfig)]
642        #[modulus = "127"]
643        #[generator = "6"]
644        pub struct F127Config;
645        type F = Fp64<MontBackend<F127Config, 1>>;
646        let a: Matrix<F> = Matrix::new(vec![
647            vec![F::from(1), F::from(2)],
648            vec![F::from(3), F::from(4)],
649        ]);
650        let b = Matrix::new(vec![
651            vec![F::from(1), F::from(2)],
652            vec![F::from(3), F::from(4)],
653            vec![F::from(5), F::from(6)],
654        ]);
655
656        assert!(a.is_square());
657        assert!(!b.is_square());
658    }
659
660    #[test]
661    fn test_matrix_utils_is_diagonal() {
662        #[derive(ark_ff::MontConfig)]
663        #[modulus = "127"]
664        #[generator = "6"]
665        pub struct F127Config;
666        type F = Fp64<MontBackend<F127Config, 1>>;
667        let a: Matrix<F> = Matrix::new(vec![
668            vec![F::from(1), F::from(0)],
669            vec![F::from(0), F::from(4)],
670        ]);
671        let b = Matrix::new(vec![
672            vec![F::from(1), F::from(2)],
673            vec![F::from(3), F::from(4)],
674        ]);
675
676        assert!(a.is_diagonal());
677        assert!(!b.is_diagonal());
678    }
679
680    #[test]
681    fn test_matrix_utils_transpose() {
682        #[derive(ark_ff::MontConfig)]
683        #[modulus = "127"]
684        #[generator = "6"]
685        pub struct F127Config;
686        type F = Fp64<MontBackend<F127Config, 1>>;
687        let a: Matrix<F> = Matrix::new(vec![
688            vec![F::from(1), F::from(2)],
689            vec![F::from(3), F::from(4)],
690        ]);
691
692        let b = a.transpose();
693
694        assert_eq!(
695            b,
696            Matrix::new(vec![
697                vec![F::from(1), F::from(3)],
698                vec![F::from(2), F::from(4)]
699            ])
700        );
701    }
702
703    #[test]
704    fn test_matrix_utils_adjoint() {
705        #[derive(ark_ff::MontConfig)]
706        #[modulus = "127"]
707        #[generator = "6"]
708        pub struct F127Config;
709        type F = Fp64<MontBackend<F127Config, 1>>;
710        let a: Matrix<F> = Matrix::new(vec![
711            vec![F::from(2), F::from(3)],
712            vec![F::from(4), F::from(3)],
713        ]);
714
715        let b = a.adjoint();
716
717        assert_eq!(
718            b,
719            Matrix::new(vec![
720                vec![F::from(2), F::from(4)],
721                vec![F::from(3), F::from(3)]
722            ])
723        );
724    }
725
726    #[test]
727    fn test_matrix_utils_inverse() {
728        #[derive(ark_ff::MontConfig)]
729        #[modulus = "127"]
730        #[generator = "6"]
731        pub struct F127Config;
732        type F = Fp64<MontBackend<F127Config, 1>>;
733        let a: Matrix<F> = Matrix::new(vec![
734            vec![F::from(1), F::from(2)],
735            vec![F::from(3), F::from(7)],
736        ]);
737
738        let b = a.inverse().unwrap();
739
740        assert_eq!(
741            b,
742            Matrix::new(vec![
743                vec![F::from(7), -F::from(2)],
744                vec![-F::from(3), F::from(1)]
745            ])
746        );
747    }
748
749    #[test]
750    fn test_matrix_utils_is_identity() {
751        #[derive(ark_ff::MontConfig)]
752        #[modulus = "127"]
753        #[generator = "6"]
754        pub struct F127Config;
755        type F = Fp64<MontBackend<F127Config, 1>>;
756        let a: Matrix<F> = Matrix::new(vec![
757            vec![F::from(1), F::from(0)],
758            vec![F::from(0), F::from(1)],
759        ]);
760        let b = Matrix::new(vec![
761            vec![F::from(1), F::from(2)],
762            vec![F::from(3), F::from(4)],
763        ]);
764
765        assert!(a.is_identity());
766        assert!(!b.is_identity());
767    }
768
769    // Double check this test.
770    #[test]
771    fn test_matrix_utils_ax_b_solve_for_x() {
772        #[derive(ark_ff::MontConfig)]
773        #[modulus = "127"]
774        #[generator = "6"]
775        pub struct F127Config;
776        type F = Fp64<MontBackend<F127Config, 1>>;
777        let a: Matrix<F> = Matrix::new(vec![
778            vec![F::from(1), F::from(2)],
779            vec![F::from(3), F::from(4)],
780        ]);
781        let b = vec![F::from(1), F::from(2)];
782
783        let x = a.ax_b_solve_for_x(b);
784
785        assert_eq!(x, vec![F::from(1), F::from(1) / F::from(2)]);
786    }
787
788    #[test]
789    fn test_matrix_utils_lu_decomposition() {
790        #[derive(ark_ff::MontConfig)]
791        #[modulus = "127"]
792        #[generator = "6"]
793        pub struct F127Config;
794        type F = Fp64<MontBackend<F127Config, 1>>;
795        let a: Matrix<F> = Matrix::new(vec![
796            vec![F::from(1), F::from(2)],
797            vec![F::from(3), F::from(4)],
798        ]);
799
800        let (l, u) = a.lu_decomposition();
801
802        assert_eq!(
803            l,
804            Matrix::new(vec![
805                vec![F::from(1), F::from(0)],
806                vec![F::from(3), F::from(1)]
807            ])
808        );
809        assert_eq!(
810            u,
811            Matrix::new(vec![
812                vec![F::from(1), F::from(2)],
813                vec![F::from(0), -F::from(2)]
814            ])
815        );
816    }
817
818    #[test]
819    fn test_matrix_utils_matrix_addition() {
820        #[derive(ark_ff::MontConfig)]
821        #[modulus = "127"]
822        #[generator = "6"]
823        pub struct F127Config;
824        type F = Fp64<MontBackend<F127Config, 1>>;
825        let a: Matrix<F> = Matrix::new(vec![
826            vec![F::from(1), F::from(2)],
827            vec![F::from(3), F::from(4)],
828        ]);
829        let b: Matrix<F> = Matrix::new(vec![
830            vec![F::from(1), F::from(2)],
831            vec![F::from(3), F::from(4)],
832        ]);
833
834        let c = a + b;
835
836        assert_eq!(
837            c,
838            Matrix::new(vec![
839                vec![F::from(2), F::from(4)],
840                vec![F::from(6), F::from(8)]
841            ])
842        );
843    }
844
845    #[test]
846    fn test_matrix_utils_matrix_substraction() {
847        #[derive(ark_ff::MontConfig)]
848        #[modulus = "127"]
849        #[generator = "6"]
850        pub struct F127Config;
851        type F = Fp64<MontBackend<F127Config, 1>>;
852        let a: Matrix<F> = Matrix::new(vec![
853            vec![F::from(1), F::from(2)],
854            vec![F::from(3), F::from(4)],
855        ]);
856        let b: Matrix<F> = Matrix::new(vec![
857            vec![F::from(1), F::from(2)],
858            vec![F::from(3), F::from(4)],
859        ]);
860
861        let c = a - b;
862
863        assert_eq!(
864            c,
865            Matrix::new(vec![
866                vec![F::from(0), F::from(0)],
867                vec![F::from(0), F::from(0)]
868            ])
869        );
870    }
871
872    #[test]
873    fn test_matrix_utils_matrix_multiplication() {
874        #[derive(ark_ff::MontConfig)]
875        #[modulus = "127"]
876        #[generator = "6"]
877        pub struct F127Config;
878        type F = Fp64<MontBackend<F127Config, 1>>;
879        let a: Matrix<F> = Matrix::new(vec![
880            vec![F::from(1), F::from(2)],
881            vec![F::from(3), F::from(4)],
882        ]);
883        let b: Matrix<F> = Matrix::new(vec![
884            vec![F::from(2), F::from(3)],
885            vec![F::from(4), F::from(5)],
886        ]);
887
888        let c = a.clone() * (b.clone());
889        let d = b * (a.clone());
890
891        assert_eq!(
892            c,
893            Matrix::new(vec![
894                vec![F::from(10), F::from(13)],
895                vec![F::from(22), F::from(29)]
896            ])
897        );
898
899        assert_eq!(
900            d,
901            Matrix::new(vec![
902                vec![F::from(11), F::from(16)],
903                vec![F::from(19), F::from(28)]
904            ])
905        );
906    }
907
908    #[test]
909    fn test_matrix_utils_matrix_scalar_multiplication() {
910        #[derive(ark_ff::MontConfig)]
911        #[modulus = "127"]
912        #[generator = "6"]
913        pub struct F127Config;
914        type F = Fp64<MontBackend<F127Config, 1>>;
915        let a: Matrix<F> = Matrix::new(vec![
916            vec![F::from(1), F::from(2)],
917            vec![F::from(3), F::from(4)],
918        ]);
919
920        let b = a.scalar_mul(F::from(2));
921
922        assert_eq!(
923            b,
924            Matrix::new(vec![
925                vec![F::from(2), F::from(4)],
926                vec![F::from(6), F::from(8)]
927            ])
928        );
929    }
930
931    // Double check this test.
932    #[test]
933    fn test_matrix_utils_multiply_vec() {
934        #[derive(ark_ff::MontConfig)]
935        #[modulus = "127"]
936        #[generator = "6"]
937        pub struct F127Config;
938        type F = Fp64<MontBackend<F127Config, 1>>;
939        let a: Matrix<F> = Matrix::new(vec![
940            vec![F::from(1), F::from(2)],
941            vec![F::from(3), F::from(4)],
942        ]);
943        let b = vec![F::from(1), F::from(2)];
944
945        let c = a.mul_vec(b);
946
947        assert_eq!(c, vec![F::from(7), F::from(10)]);
948    }
949
950    #[test]
951    fn test_matrix_utils_sum_of_matrix() {
952        #[derive(ark_ff::MontConfig)]
953        #[modulus = "127"]
954        #[generator = "6"]
955        pub struct F127Config;
956        type F = Fp64<MontBackend<F127Config, 1>>;
957        let a: Matrix<F> = Matrix::new(vec![
958            vec![F::from(1), F::from(2)],
959            vec![F::from(3), F::from(4)],
960        ]);
961
962        let b = a.sum_of_matrix();
963
964        assert_eq!(b, F::from(10));
965    }
966
967    #[test]
968    fn test_matrix_utils_set_element() {
969        #[derive(ark_ff::MontConfig)]
970        #[modulus = "127"]
971        #[generator = "6"]
972        pub struct F127Config;
973        type F = Fp64<MontBackend<F127Config, 1>>;
974        let mut a: Matrix<F> = Matrix::new(vec![
975            vec![F::from(1), F::from(2)],
976            vec![F::from(3), F::from(4)],
977        ]);
978
979        a.set_element(0, 0, F::from(5));
980        a.set_element(1, 1, F::from(6));
981
982        assert_eq!(
983            a,
984            Matrix::new(vec![
985                vec![F::from(5), F::from(2)],
986                vec![F::from(3), F::from(6)]
987            ])
988        );
989    }
990
991    #[test]
992    fn test_matrix_utils_get_element() {
993        #[derive(ark_ff::MontConfig)]
994        #[modulus = "127"]
995        #[generator = "6"]
996        pub struct F127Config;
997        type F = Fp64<MontBackend<F127Config, 1>>;
998        let a: Matrix<F> = Matrix::new(vec![
999            vec![F::from(1), F::from(2)],
1000            vec![F::from(3), F::from(4)],
1001        ]);
1002
1003        assert_eq!(a.get_element(0, 0), F::from(1));
1004        assert_eq!(a.get_element(0, 1), F::from(2));
1005        assert_eq!(a.get_element(1, 0), F::from(3));
1006        assert_eq!(a.get_element(1, 1), F::from(4));
1007    }
1008}