algebraeon_rings/matrix/matrix.rs

1use crate::structure::*;
2use algebraeon_nzq::Natural;
3use algebraeon_sets::structure::*;
4use std::{borrow::Borrow, marker::PhantomData};
5
/// Error returned by the fallible matrix operations.
#[derive(Debug)]
pub enum MatOppErr {
    /// The shapes of the operands are incompatible with the requested operation.
    DimMismatch,
    /// A row or column index was out of range.
    InvalidIndex,
    /// The operation requires a square matrix but the input was not square.
    NotSquare,
    /// The matrix is singular. Not produced by the operations in this file; presumably used
    /// by inversion / solving routines elsewhere in the crate.
    Singular,
}
13
/// A dense matrix whose entries have type `Set`.
///
/// The entries live in a flat row-major buffer describing a `dim1 x dim2` storage grid.
/// Transposition and reversal of rows / columns are performed lazily: they only toggle the
/// flags below, which change how logical `(row, col)` positions are mapped into the buffer.
#[derive(Debug, Clone)]
pub struct Matrix<Set: Clone> {
    /// Number of rows of the storage grid (not necessarily of the logical matrix).
    dim1: usize,
    /// Number of columns of the storage grid (not necessarily of the logical matrix).
    dim2: usize,
    /// When set, the logical matrix is the transpose of the storage grid.
    transpose: bool,
    /// When set, the logical rows are read in reverse order.
    flip_rows: bool,
    /// When set, the logical columns are read in reverse order.
    flip_cols: bool,
    /// Entries of the storage grid; length `dim1 * dim2`, with storage row `r` and storage
    /// column `c` at index `c + r * dim2`.
    elems: Vec<Set>,
}
23
24impl<Set: Clone> Matrix<Set> {
25    #[allow(unused)]
26    fn check_invariants(&self) -> Result<(), &'static str> {
27        if self.elems.len() != self.dim1 * self.dim2 {
28            return Err("matrix entries has the wrong length");
29        }
30        Ok(())
31    }
32
33    pub fn full(rows: usize, cols: usize, elem: &Set) -> Self {
34        let mut elems = Vec::with_capacity(rows * cols);
35        for _i in 0..rows * cols {
36            elems.push(elem.clone());
37        }
38        Self {
39            dim1: rows,
40            dim2: cols,
41            transpose: false,
42            flip_rows: false,
43            flip_cols: false,
44            elems,
45        }
46    }
47
48    /// Construct a matrix from a closure.
49    ///
50    /// ```rust
51    /// use algebraeon_nzq::Integer;
52    /// use algebraeon_rings::matrix::Matrix;
53    /// let a = Matrix::<Integer>::construct(2, 3, |r, c| if (r + c) % 2 == 0 { Integer::ZERO } else { Integer::ONE });
54    /// let b = Matrix::<Integer>::from_rows(
55    ///     vec![
56    ///         vec![Integer::ZERO, Integer::ONE, Integer::ZERO],
57    ///         vec![Integer::ONE, Integer::ZERO, Integer::ONE]
58    ///     ]
59    /// );
60    /// assert_eq!(a, b);
61    /// ```
62    pub fn construct(rows: usize, cols: usize, make_entry: impl Fn(usize, usize) -> Set) -> Self {
63        let mut elems = Vec::with_capacity(rows * cols);
64        for idx in 0..rows * cols {
65            let (r, c) = (idx / cols, idx % cols); //idx_to_rc for transpose=false
66            elems.push(make_entry(r, c).clone());
67        }
68        Self {
69            dim1: rows,
70            dim2: cols,
71            transpose: false,
72            flip_rows: false,
73            flip_cols: false,
74            elems,
75        }
76    }
77
78    /// Construct a matrix from a list of rows.
79    pub fn from_rows(rows_elems: Vec<Vec<impl Into<Set> + Clone>>) -> Self {
80        let rows = rows_elems.len();
81        assert!(rows >= 1);
82        let cols = rows_elems[0].len();
83        #[allow(clippy::needless_range_loop)]
84        for r in 1..rows {
85            assert_eq!(rows_elems[r].len(), cols);
86        }
87        Self::construct(rows, cols, |r, c| rows_elems[r][c].clone().into())
88    }
89
90    /// Construct a matrix from a list of columns.
91    pub fn from_cols(cols_elems: Vec<Vec<impl Into<Set> + Clone>>) -> Self {
92        Self::from_rows(cols_elems).transpose()
93    }
94
95    /// Construct a matrix from a row.
96    pub fn from_row(elems: Vec<impl Into<Set> + Clone>) -> Self {
97        Self::from_rows(vec![elems])
98    }
99
100    /// Construct a matrix from a column.
101    pub fn from_col(elems: Vec<impl Into<Set> + Clone>) -> Self {
102        Self::from_rows(vec![elems]).transpose()
103    }
104
105    fn rc_to_idx(&self, mut r: usize, mut c: usize) -> usize {
106        if self.flip_rows {
107            r = self.rows() - r - 1;
108        }
109        if self.flip_cols {
110            c = self.cols() - c - 1;
111        }
112        if self.transpose {
113            r + c * self.dim2
114        } else {
115            c + r * self.dim2
116        }
117    }
118
119    /// Get a reference to the entry at row `r` and column `c`.
120    pub fn at(&self, r: usize, c: usize) -> Result<&Set, MatOppErr> {
121        if r >= self.rows() || c >= self.cols() {
122            Err(MatOppErr::InvalidIndex)
123        } else {
124            let idx = self.rc_to_idx(r, c);
125            Ok(&self.elems[idx])
126        }
127    }
128
129    /// Get a mutable reference to the entry at row `r` and column `c`.
130    pub fn at_mut(&mut self, r: usize, c: usize) -> Result<&mut Set, MatOppErr> {
131        if r >= self.rows() || c >= self.cols() {
132            Err(MatOppErr::InvalidIndex)
133        } else {
134            let idx = self.rc_to_idx(r, c);
135            Ok(&mut self.elems[idx])
136        }
137    }
138
139    pub fn rows(&self) -> usize {
140        if self.transpose { self.dim2 } else { self.dim1 }
141    }
142
143    pub fn cols(&self) -> usize {
144        if self.transpose { self.dim1 } else { self.dim2 }
145    }
146
147    /// Return the submatrix given by the intersection of the rows defined by `rows` and the columns defined by `cols`.
148    pub fn submatrix(&self, rows: Vec<usize>, cols: Vec<usize>) -> Self {
149        let mut elems = vec![];
150        for r in &rows {
151            for c in &cols {
152                elems.push(self.at(*r, *c).unwrap().clone());
153            }
154        }
155        Matrix {
156            dim1: rows.len(),
157            dim2: cols.len(),
158            transpose: false,
159            flip_rows: false,
160            flip_cols: false,
161            elems,
162        }
163    }
164
165    pub fn get_row_submatrix(&self, row: usize) -> Self {
166        self.submatrix(vec![row], (0..self.cols()).collect())
167    }
168
169    pub fn get_col_submatrix(&self, col: usize) -> Self {
170        self.submatrix((0..self.rows()).collect(), vec![col])
171    }
172
173    pub fn get_row_refs(&self, row: usize) -> Vec<&Set> {
174        assert!(row < self.rows());
175        (0..self.cols()).map(|c| self.at(row, c).unwrap()).collect()
176    }
177
178    pub fn get_col_refs(&self, col: usize) -> Vec<&Set> {
179        assert!(col < self.cols());
180        (0..self.rows()).map(|r| self.at(r, col).unwrap()).collect()
181    }
182
183    pub fn get_row(&self, row: usize) -> Vec<Set> {
184        assert!(row < self.rows());
185        self.get_row_refs(row).into_iter().cloned().collect()
186    }
187
188    pub fn get_col(&self, col: usize) -> Vec<Set> {
189        assert!(col < self.cols());
190        self.get_col_refs(col).into_iter().cloned().collect()
191    }
192
193    /// Apply a function `f` to the entries of this matrix, producing a new matrix.
194    pub fn apply_map<NewSet: Clone>(&self, f: impl Fn(&Set) -> NewSet) -> Matrix<NewSet> {
195        Matrix {
196            dim1: self.dim1,
197            dim2: self.dim2,
198            transpose: self.transpose,
199            flip_rows: self.flip_rows,
200            flip_cols: self.flip_cols,
201            elems: self.elems.iter().map(f).collect(),
202        }
203    }
204
205    pub fn transpose(mut self) -> Self {
206        self.transpose_mut();
207        self
208    }
209    pub fn transpose_ref(&self) -> Self {
210        self.clone().transpose()
211    }
212    pub fn transpose_mut(&mut self) {
213        self.transpose = !self.transpose;
214        (self.flip_rows, self.flip_cols) = (self.flip_cols, self.flip_rows);
215    }
216
217    pub fn flip_rows(mut self) -> Self {
218        self.flip_rows_mut();
219        self
220    }
221    pub fn flip_rows_ref(&self) -> Self {
222        self.clone().flip_rows()
223    }
224    pub fn flip_rows_mut(&mut self) {
225        self.flip_rows = !self.flip_rows;
226    }
227
228    pub fn flip_cols(mut self) -> Self {
229        self.flip_cols_mut();
230        self
231    }
232    pub fn flip_cols_ref(&self) -> Self {
233        self.clone().flip_cols()
234    }
235    pub fn flip_cols_mut(&mut self) {
236        self.flip_cols = !self.flip_cols;
237    }
238
239    /// Concatenate the rows of the matrices in `mats` into a single matrix.
240    ///
241    /// `cols` must match the number of columns of every matrix in `mats`. The purpose of this input is to produce an empty matrix of the correct dimension when `mats` is empty.
242    ///
243    /// # Panics
244    ///
245    /// This function panics if `cols` does not match the number of columns of every matrix in `mats`.
246    pub fn join_rows<MatT: Borrow<Matrix<Set>>>(cols: usize, mats: Vec<MatT>) -> Matrix<Set> {
247        let mut rows = 0;
248        for mat in &mats {
249            assert_eq!(cols, mat.borrow().cols());
250            rows += mat.borrow().rows();
251        }
252        Matrix::construct(rows, cols, |r, c| {
253            //todo use a less cursed method
254            let mut row_offset = 0;
255            for mat in &mats {
256                for mr in 0..mat.borrow().rows() {
257                    for mc in 0..cols {
258                        if r == row_offset + mr && c == mc {
259                            return mat.borrow().at(mr, mc).unwrap().clone();
260                        }
261                    }
262                }
263                row_offset += mat.borrow().rows();
264            }
265            panic!();
266        })
267    }
268
269    /// Concatenate the columns of the matrices in `mats` into a single matrix.
270    ///
271    /// `rows` must match the number of rows of every matrix in `mats`. The purpose of this input is to produce an empty matrix of the correct dimension when `mats` is empty.
272    ///
273    /// # Panics
274    ///
275    /// This function panics if `rows` does not match the number of rows of every matrix in `mats`.
276    pub fn join_cols<MatT: Borrow<Matrix<Set>>>(rows: usize, mats: Vec<MatT>) -> Matrix<Set> {
277        let mut t_mats = vec![];
278        for mat in mats {
279            t_mats.push(mat.borrow().clone().transpose());
280        }
281        let joined = Self::join_rows(rows, t_mats.iter().collect());
282        joined.transpose()
283    }
284
285    /// Return a vector containing the entries of this matrix.
286    ///
287    /// Most useful when this matrix is a row vector or a column vector.
288    pub fn entries_list(&self) -> Vec<&Set> {
289        let mut entries = vec![];
290        for r in 0..self.rows() {
291            for c in 0..self.cols() {
292                entries.push(self.at(r, c).unwrap());
293            }
294        }
295        entries
296    }
297}
298
299#[derive(Debug, Clone, PartialEq, Eq)]
300pub struct MatrixStructure<RS: SetSignature, RSB: BorrowedStructure<RS>> {
301    _ring: PhantomData<RS>,
302    ring: RSB,
303}
304
305impl<RS: SetSignature, RSB: BorrowedStructure<RS>> Signature for MatrixStructure<RS, RSB> {}
306
307impl<RS: SetSignature, RSB: BorrowedStructure<RS>> SetSignature for MatrixStructure<RS, RSB> {
308    type Set = Matrix<RS::Set>;
309
310    fn is_element(&self, _x: &Self::Set) -> Result<(), String> {
311        Ok(())
312    }
313}
314
315impl<RS: SetSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
316    pub fn new(ring: RSB) -> Self {
317        Self {
318            _ring: PhantomData,
319            ring,
320        }
321    }
322
323    pub fn ring(&self) -> &RS {
324        self.ring.borrow()
325    }
326}
327
328pub trait RingMatricesSignature: SetSignature {
329    fn matrices(&self) -> MatrixStructure<Self, &Self> {
330        MatrixStructure::new(self)
331    }
332
333    fn into_matrices(self) -> MatrixStructure<Self, Self> {
334        MatrixStructure::new(self)
335    }
336}
337
338impl<RS: SetSignature> RingMatricesSignature for RS {}
339
340impl<RS: EqSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
341    pub fn equal(&self, a: &Matrix<RS::Set>, b: &Matrix<RS::Set>) -> bool {
342        let rows = a.rows();
343        let cols = a.cols();
344        if rows != b.rows() || cols != b.cols() {
345            false
346        } else {
347            for c in 0..cols {
348                for r in 0..rows {
349                    if !self.ring().equal(a.at(r, c).unwrap(), b.at(r, c).unwrap()) {
350                        return false;
351                    }
352                }
353            }
354            true
355        }
356    }
357}
358
359impl<RS: ToStringSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
360    pub fn pprint(&self, mat: &Matrix<RS::Set>) {
361        let mut str_rows = vec![];
362        for r in 0..mat.rows() {
363            str_rows.push(vec![]);
364            for c in 0..mat.cols() {
365                str_rows[r].push(self.ring().to_string(mat.at(r, c).unwrap()));
366            }
367        }
368        #[allow(clippy::redundant_closure_for_method_calls)]
369        let cols_widths: Vec<usize> = (0..mat.cols())
370            .map(|c| {
371                (0..mat.rows())
372                    .map(|r| str_rows[r][c].chars().count())
373                    .fold(0usize, |a, b| a.max(b))
374            })
375            .collect();
376
377        #[allow(clippy::needless_range_loop)]
378        for r in 0..mat.rows() {
379            for c in 0..mat.cols() {
380                while str_rows[r][c].chars().count() < cols_widths[c] {
381                    str_rows[r][c].push(' ');
382                }
383                debug_assert_eq!(str_rows[r][c].chars().count(), cols_widths[c]);
384            }
385        }
386
387        #[allow(clippy::needless_range_loop)]
388        for r in 0..mat.rows() {
389            if mat.rows() == 1 {
390                print!("( ");
391            } else if r == 0 {
392                print!("/ ");
393            } else if r == mat.rows() - 1 {
394                print!("\\ ");
395            } else {
396                print!("| ");
397            }
398            for c in 0..mat.cols() {
399                if c != 0 {
400                    print!("    ");
401                }
402                print!("{}", str_rows[r][c]);
403            }
404            if mat.rows() == 1 {
405                print!(" )");
406            } else if r == 0 {
407                print!(" \\");
408            } else if r == mat.rows() - 1 {
409                print!(" /");
410            } else {
411                print!(" |");
412            }
413            println!();
414        }
415    }
416}
417
418impl<RS: RingSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
419    pub fn zero(&self, rows: usize, cols: usize) -> Matrix<RS::Set> {
420        Matrix::construct(rows, cols, |_r, _c| self.ring().zero())
421    }
422
423    pub fn ident(&self, n: usize) -> Matrix<RS::Set> {
424        Matrix::construct(n, n, |r, c| {
425            if r == c {
426                self.ring().one()
427            } else {
428                self.ring().zero()
429            }
430        })
431    }
432
433    pub fn diag(&self, diag: &[RS::Set]) -> Matrix<RS::Set> {
434        Matrix::construct(diag.len(), diag.len(), |r, c| {
435            if r == c {
436                diag[r].clone()
437            } else {
438                self.ring().zero()
439            }
440        })
441    }
442
443    pub fn join_diag<MatT: Borrow<Matrix<RS::Set>>>(&self, mats: Vec<MatT>) -> Matrix<RS::Set> {
444        if mats.is_empty() {
445            Matrix::construct(0, 0, |_r, _c| unreachable!())
446        } else if mats.len() == 1 {
447            mats[0].borrow().clone()
448        } else {
449            let i = mats.len() / 2;
450            let (first, last) = mats.split_at(i);
451            #[allow(clippy::redundant_closure_for_method_calls)]
452            let first = self.join_diag(first.iter().map(|m| m.borrow()).collect());
453            #[allow(clippy::redundant_closure_for_method_calls)]
454            let last = self.join_diag(last.iter().map(|m| m.borrow()).collect());
455            Matrix::construct(
456                first.rows() + last.rows(),
457                first.cols() + last.cols(),
458                |r, c| {
459                    if r < first.rows() && c < first.cols() {
460                        first.at(r, c).unwrap().clone()
461                    } else if first.rows() <= r && first.cols() <= c {
462                        last.at(r - first.rows(), c - first.cols()).unwrap().clone()
463                    } else {
464                        self.ring().zero()
465                    }
466                },
467            )
468        }
469    }
470
471    pub fn dot(&self, a: &Matrix<RS::Set>, b: &Matrix<RS::Set>) -> RS::Set {
472        let rows = a.rows();
473        let cols = a.cols();
474        assert_eq!(rows, b.rows());
475        assert_eq!(cols, b.cols());
476        let mut tot = self.ring().zero();
477        for r in 0..rows {
478            for c in 0..cols {
479                self.ring().add_mut(
480                    &mut tot,
481                    &self.ring().mul(a.at(r, c).unwrap(), b.at(r, c).unwrap()),
482                );
483            }
484        }
485        tot
486    }
487
488    pub fn add_mut(&self, a: &mut Matrix<RS::Set>, b: &Matrix<RS::Set>) -> Result<(), MatOppErr> {
489        if a.rows() != b.rows() || a.cols() != b.cols() {
490            Err(MatOppErr::DimMismatch)
491        } else {
492            let rows = a.rows();
493            let cols = a.cols();
494            for c in 0..cols {
495                for r in 0..rows {
496                    self.ring()
497                        .add_mut(a.at_mut(r, c).unwrap(), b.at(r, c).unwrap());
498                }
499            }
500            Ok(())
501        }
502    }
503
504    pub fn add(
505        &self,
506        a: &Matrix<RS::Set>,
507        b: &Matrix<RS::Set>,
508    ) -> Result<Matrix<RS::Set>, MatOppErr> {
509        let mut new_a = a.clone();
510        match self.add_mut(&mut new_a, b) {
511            Ok(()) => Ok(new_a),
512            Err(e) => Err(e),
513        }
514    }
515
516    pub fn neg_mut(&self, a: &mut Matrix<RS::Set>) {
517        for r in 0..a.rows() {
518            for c in 0..a.cols() {
519                let neg_elem = self.ring().neg(a.at(r, c).unwrap());
520                *a.at_mut(r, c).unwrap() = neg_elem;
521            }
522        }
523    }
524
525    pub fn neg(&self, mut a: Matrix<RS::Set>) -> Matrix<RS::Set> {
526        self.neg_mut(&mut a);
527        a
528    }
529
530    pub fn mul(
531        &self,
532        a: &Matrix<RS::Set>,
533        b: &Matrix<RS::Set>,
534    ) -> Result<Matrix<RS::Set>, MatOppErr> {
535        let mids = a.cols();
536        if mids != b.rows() {
537            return Err(MatOppErr::DimMismatch);
538        }
539        let rows = a.rows();
540        let cols = b.cols();
541        let mut s = self.zero(rows, cols);
542        for r in 0..rows {
543            for c in 0..cols {
544                for m in 0..mids {
545                    self.ring().add_mut(
546                        s.at_mut(r, c).unwrap(),
547                        &self.ring().mul(a.at(r, m).unwrap(), b.at(m, c).unwrap()),
548                    );
549                }
550            }
551        }
552        Ok(s)
553    }
554
555    pub fn apply_row(&self, mat: &Matrix<RS::Set>, row: &[RS::Set]) -> Vec<RS::Set> {
556        assert_eq!(mat.rows(), row.len());
557        (0..mat.cols())
558            .map(|c| {
559                self.ring().sum(
560                    (0..mat.rows())
561                        .map(|r| self.ring().mul(mat.at(r, c).unwrap(), &row[r]))
562                        .collect(),
563                )
564            })
565            .collect()
566    }
567
568    pub fn apply_col(&self, mat: &Matrix<RS::Set>, col: &[RS::Set]) -> Vec<RS::Set> {
569        assert_eq!(mat.cols(), col.len());
570        (0..mat.rows())
571            .map(|r| {
572                self.ring().sum(
573                    (0..mat.cols())
574                        .map(|c| self.ring().mul(mat.at(r, c).unwrap(), &col[c]))
575                        .collect(),
576                )
577            })
578            .collect()
579    }
580
581    pub fn mul_scalar(&self, mut a: Matrix<RS::Set>, scalar: &RS::Set) -> Matrix<RS::Set> {
582        for r in 0..a.rows() {
583            for c in 0..a.cols() {
584                self.ring().mul_mut(a.at_mut(r, c).unwrap(), scalar);
585            }
586        }
587        a
588    }
589
590    pub fn mul_scalar_ref(&self, a: &Matrix<RS::Set>, scalar: &RS::Set) -> Matrix<RS::Set> {
591        self.mul_scalar(a.clone(), scalar)
592    }
593
594    pub fn det_naive(&self, a: &Matrix<RS::Set>) -> Result<RS::Set, MatOppErr> {
595        let n = a.rows();
596        if n == a.cols() {
597            let mut det = self.ring().zero();
598            for perm in algebraeon_groups::permutation::Permutation::all_permutations(n) {
599                let mut prod = self.ring().one();
600                for k in 0..n {
601                    self.ring()
602                        .mul_mut(&mut prod, a.at(k, perm.call(k)).unwrap());
603                }
604                match perm.sign() {
605                    algebraeon_groups::examples::c2::C2::Identity => {}
606                    algebraeon_groups::examples::c2::C2::Flip => {
607                        prod = self.ring().neg(&prod);
608                    }
609                }
610
611                self.ring().add_mut(&mut det, &prod);
612            }
613            Ok(det)
614        } else {
615            Err(MatOppErr::NotSquare)
616        }
617    }
618
619    pub fn trace(&self, a: &Matrix<RS::Set>) -> Result<RS::Set, MatOppErr> {
620        let n = a.rows();
621        if n == a.cols() {
622            Ok(self
623                .ring()
624                .sum((0..n).map(|i| a.at(i, i).unwrap()).collect()))
625        } else {
626            Err(MatOppErr::NotSquare)
627        }
628    }
629
630    pub fn nat_pow(&self, a: &Matrix<RS::Set>, k: &Natural) -> Result<Matrix<RS::Set>, MatOppErr> {
631        let n = a.rows();
632        if n != a.cols() {
633            Err(MatOppErr::NotSquare)
634        } else if *k == Natural::ZERO {
635            Ok(self.ident(n))
636        } else if *k == Natural::ONE {
637            Ok(a.clone())
638        } else {
639            debug_assert!(*k >= Natural::TWO);
640            let bits: Vec<_> = k.bits().collect();
641            let mut pows = vec![a.clone()];
642            while pows.len() < bits.len() {
643                pows.push(
644                    self.mul(pows.last().unwrap(), pows.last().unwrap())
645                        .unwrap(),
646                );
647            }
648            let count = bits.len();
649            debug_assert_eq!(count, pows.len());
650            let mut ans = self.ident(n);
651            for i in 0..count {
652                if bits[i] {
653                    ans = self.mul(&ans, &pows[i]).unwrap();
654                }
655            }
656            Ok(ans)
657        }
658    }
659}
660
661impl<R: MetaType> MetaType for Matrix<R>
662where
663    R::Signature: SetSignature,
664{
665    type Signature = MatrixStructure<R::Signature, R::Signature>;
666
667    fn structure() -> Self::Signature {
668        MatrixStructure::new(R::structure())
669    }
670}
671
672impl<R: MetaType> Matrix<R>
673where
674    R::Signature: ToStringSignature,
675{
676    pub fn pprint(&self) {
677        Self::structure().pprint(self);
678    }
679}
680
681impl<R: MetaType> PartialEq for Matrix<R>
682where
683    R::Signature: RingSignature + EqSignature,
684{
685    fn eq(&self, other: &Self) -> bool {
686        Self::structure().equal(self, other)
687    }
688}
689
690impl<R: MetaType> Eq for Matrix<R> where R::Signature: RingSignature + EqSignature {}
691
692impl<R: MetaType> Matrix<R>
693where
694    R::Signature: RingSignature,
695{
696    pub fn zero(rows: usize, cols: usize) -> Self {
697        Self::structure().zero(rows, cols)
698    }
699
700    pub fn ident(n: usize) -> Self {
701        Self::structure().ident(n)
702    }
703
704    pub fn diag(diag: &[R]) -> Self {
705        Self::structure().diag(diag)
706    }
707
708    pub fn dot(a: &Self, b: &Self) -> R {
709        Self::structure().dot(a, b)
710    }
711
712    pub fn add_mut(&mut self, b: &Self) -> Result<(), MatOppErr> {
713        Self::structure().add_mut(self, b)
714    }
715
716    pub fn add(a: &Self, b: &Self) -> Result<Self, MatOppErr> {
717        Self::structure().add(a, b)
718    }
719
720    pub fn neg_mut(&mut self) {
721        Self::structure().neg_mut(self);
722    }
723
724    pub fn neg(&self) -> Self {
725        Self::structure().neg(self.clone())
726    }
727
728    pub fn mul(a: &Self, b: &Self) -> Result<Self, MatOppErr> {
729        Self::structure().mul(a, b)
730    }
731
732    pub fn apply_row(&self, row: &[R]) -> Vec<R> {
733        Self::structure().apply_row(self, row)
734    }
735
736    pub fn apply_col(&self, col: &[R]) -> Vec<R> {
737        Self::structure().apply_col(self, col)
738    }
739
740    pub fn mul_scalar(&self, scalar: &R) -> Matrix<R> {
741        Self::structure().mul_scalar(self.clone(), scalar)
742    }
743
744    pub fn mul_scalar_ref(&self, scalar: &R) -> Matrix<R> {
745        Self::structure().mul_scalar_ref(self, scalar)
746    }
747
748    pub fn det_naive(&self) -> Result<R, MatOppErr> {
749        Self::structure().det_naive(self)
750    }
751
752    pub fn trace(&self) -> Result<R, MatOppErr> {
753        Self::structure().trace(self)
754    }
755}
756
757#[cfg(test)]
758mod tests {
759    use algebraeon_nzq::Integer;
760
761    use super::*;
762
763    #[test]
764    fn test_join_rows() {
765        let top = Matrix::<Integer>::from_rows(vec![vec![1, 2, 3], vec![4, 5, 6]]);
766        let bot = Matrix::from_rows(vec![vec![7, 8, 9]]);
767
768        let both = Matrix::from_rows(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
769
770        println!("top");
771        top.pprint();
772        println!("bot");
773        bot.pprint();
774        println!("both");
775        both.pprint();
776
777        let ans = Matrix::join_rows(3, vec![top, bot]);
778        println!("ans");
779        ans.pprint();
780
781        assert_eq!(ans, both);
782    }
783
784    #[test]
785    fn invariants() {
786        let m = Matrix {
787            dim1: 3,
788            dim2: 4,
789            transpose: false,
790            flip_rows: false,
791            flip_cols: false,
792            elems: vec![
793                Integer::from(1),
794                Integer::from(2),
795                Integer::from(3),
796                Integer::from(4),
797                Integer::from(5),
798            ],
799        };
800        if let Ok(()) = m.check_invariants() {
801            panic!();
802        }
803
804        let m = Matrix {
805            dim1: 2,
806            dim2: 3,
807            transpose: true,
808            flip_rows: false,
809            flip_cols: false,
810            elems: vec![
811                Integer::from(1),
812                Integer::from(2),
813                Integer::from(3),
814                Integer::from(4),
815                Integer::from(5),
816                Integer::from(6),
817            ],
818        };
819        m.check_invariants().unwrap();
820    }
821
822    #[test]
823    fn transpose_eq() {
824        let a = Matrix {
825            dim1: 2,
826            dim2: 2,
827            transpose: false,
828            flip_rows: false,
829            flip_cols: false,
830            elems: vec![
831                Integer::from(0),
832                Integer::from(1),
833                Integer::from(2),
834                Integer::from(3),
835            ],
836        };
837        a.check_invariants().unwrap();
838
839        let b = Matrix {
840            dim1: 2,
841            dim2: 2,
842            transpose: true,
843            flip_rows: false,
844            flip_cols: false,
845            elems: vec![
846                Integer::from(0),
847                Integer::from(2),
848                Integer::from(1),
849                Integer::from(3),
850            ],
851        };
852        b.check_invariants().unwrap();
853
854        assert_eq!(a, b);
855    }
856
857    #[test]
858    fn flip_axes_eq() {
859        let mut a = Matrix::<Integer>::from_rows(vec![vec![1, 2], vec![3, 4]]);
860        a.pprint();
861        println!("flip rows");
862        a.flip_rows_mut();
863        a.pprint();
864        assert_eq!(
865            a,
866            Matrix::from_rows(vec![
867                vec![Integer::from(3), Integer::from(4)],
868                vec![Integer::from(1), Integer::from(2)],
869            ])
870        );
871        println!("transpose");
872        a.transpose_mut();
873        a.pprint();
874        assert_eq!(
875            a,
876            Matrix::from_rows(vec![
877                vec![Integer::from(3), Integer::from(1)],
878                vec![Integer::from(4), Integer::from(2)],
879            ])
880        );
881        println!("flip rows");
882        a.flip_rows_mut();
883        a.pprint();
884        assert_eq!(
885            a,
886            Matrix::from_rows(vec![
887                vec![Integer::from(4), Integer::from(2)],
888                vec![Integer::from(3), Integer::from(1)],
889            ])
890        );
891        println!("flip cols");
892        a.flip_cols_mut();
893        a.pprint();
894        assert_eq!(
895            a,
896            Matrix::from_rows(vec![
897                vec![Integer::from(2), Integer::from(4)],
898                vec![Integer::from(1), Integer::from(3)],
899            ])
900        );
901        println!("transpose");
902        a.transpose_mut();
903        a.pprint();
904        assert_eq!(
905            a,
906            Matrix::from_rows(vec![
907                vec![Integer::from(2), Integer::from(1)],
908                vec![Integer::from(4), Integer::from(3)],
909            ])
910        );
911        println!("flip cols");
912        a.flip_cols_mut();
913        a.pprint();
914        assert_eq!(
915            a,
916            Matrix::from_rows(vec![
917                vec![Integer::from(1), Integer::from(2)],
918                vec![Integer::from(3), Integer::from(4)],
919            ])
920        );
921    }
922
923    #[test]
924    fn add() {
925        {
926            let mut a = Matrix {
927                dim1: 2,
928                dim2: 3,
929                transpose: false,
930                flip_rows: false,
931                flip_cols: false,
932                elems: vec![
933                    Integer::from(1),
934                    Integer::from(2),
935                    Integer::from(3),
936                    Integer::from(4),
937                    Integer::from(5),
938                    Integer::from(6),
939                ],
940            };
941            a.check_invariants().unwrap();
942
943            let b = Matrix {
944                dim1: 2,
945                dim2: 3,
946                transpose: false,
947                flip_rows: false,
948                flip_cols: false,
949                elems: vec![
950                    Integer::from(1),
951                    Integer::from(2),
952                    Integer::from(1),
953                    Integer::from(2),
954                    Integer::from(1),
955                    Integer::from(2),
956                ],
957            };
958            b.check_invariants().unwrap();
959
960            let c = Matrix {
961                dim1: 2,
962                dim2: 3,
963                transpose: false,
964                flip_rows: false,
965                flip_cols: false,
966                elems: vec![
967                    Integer::from(2),
968                    Integer::from(4),
969                    Integer::from(4),
970                    Integer::from(6),
971                    Integer::from(6),
972                    Integer::from(8),
973                ],
974            };
975            c.check_invariants().unwrap();
976
977            a.add_mut(&b).unwrap();
978
979            assert_eq!(a, c);
980        }
981
982        {
983            let mut a = Matrix {
984                dim1: 3,
985                dim2: 2,
986                transpose: false,
987                flip_rows: false,
988                flip_cols: false,
989                elems: vec![
990                    Integer::from(1),
991                    Integer::from(2),
992                    Integer::from(3),
993                    Integer::from(4),
994                    Integer::from(5),
995                    Integer::from(6),
996                ],
997            };
998            a.check_invariants().unwrap();
999
1000            let b = Matrix {
1001                dim1: 2,
1002                dim2: 3,
1003                transpose: true,
1004                flip_rows: false,
1005                flip_cols: false,
1006                elems: vec![
1007                    Integer::from(10),
1008                    Integer::from(20),
1009                    Integer::from(30),
1010                    Integer::from(40),
1011                    Integer::from(50),
1012                    Integer::from(60),
1013                ],
1014            };
1015            b.check_invariants().unwrap();
1016
1017            let c = Matrix {
1018                dim1: 3,
1019                dim2: 2,
1020                transpose: false,
1021                flip_rows: false,
1022                flip_cols: false,
1023                elems: vec![
1024                    Integer::from(11),
1025                    Integer::from(42),
1026                    Integer::from(23),
1027                    Integer::from(54),
1028                    Integer::from(35),
1029                    Integer::from(66),
1030                ],
1031            };
1032            c.check_invariants().unwrap();
1033
1034            a.add_mut(&b).unwrap();
1035
1036            assert_eq!(a, c);
1037        }
1038
1039        {
1040            let mut a = Matrix {
1041                dim1: 3,
1042                dim2: 2,
1043                transpose: false,
1044                flip_rows: false,
1045                flip_cols: false,
1046                elems: vec![
1047                    Integer::from(1),
1048                    Integer::from(2),
1049                    Integer::from(3),
1050                    Integer::from(4),
1051                    Integer::from(5),
1052                    Integer::from(6),
1053                ],
1054            };
1055            a.check_invariants().unwrap();
1056
1057            let b = Matrix {
1058                dim1: 2,
1059                dim2: 3,
1060                transpose: false,
1061                flip_rows: false,
1062                flip_cols: false,
1063                elems: vec![
1064                    Integer::from(1),
1065                    Integer::from(2),
1066                    Integer::from(1),
1067                    Integer::from(2),
1068                    Integer::from(1),
1069                    Integer::from(2),
1070                ],
1071            };
1072            b.check_invariants().unwrap();
1073
1074            match a.add_mut(&b) {
1075                Ok(()) => panic!(),
1076                Err(MatOppErr::DimMismatch) => {}
1077                Err(_) => panic!(),
1078            }
1079        }
1080
1081        {
1082            let a = Matrix {
1083                dim1: 2,
1084                dim2: 3,
1085                transpose: false,
1086                flip_rows: false,
1087                flip_cols: false,
1088                elems: vec![
1089                    Integer::from(1),
1090                    Integer::from(2),
1091                    Integer::from(3),
1092                    Integer::from(4),
1093                    Integer::from(5),
1094                    Integer::from(6),
1095                ],
1096            };
1097            a.check_invariants().unwrap();
1098
1099            let b = Matrix {
1100                dim1: 2,
1101                dim2: 3,
1102                transpose: false,
1103                flip_rows: false,
1104                flip_cols: false,
1105                elems: vec![
1106                    Integer::from(1),
1107                    Integer::from(2),
1108                    Integer::from(1),
1109                    Integer::from(2),
1110                    Integer::from(1),
1111                    Integer::from(2),
1112                ],
1113            };
1114            b.check_invariants().unwrap();
1115
1116            let c = Matrix {
1117                dim1: 2,
1118                dim2: 3,
1119                transpose: false,
1120                flip_rows: false,
1121                flip_cols: false,
1122                elems: vec![
1123                    Integer::from(2),
1124                    Integer::from(4),
1125                    Integer::from(4),
1126                    Integer::from(6),
1127                    Integer::from(6),
1128                    Integer::from(8),
1129                ],
1130            };
1131            c.check_invariants().unwrap();
1132
1133            assert_eq!(Matrix::add(&a, &b).unwrap(), c);
1134        }
1135    }
1136
1137    #[test]
1138    fn mul() {
1139        {
1140            let a = Matrix {
1141                dim1: 2,
1142                dim2: 4,
1143                transpose: false,
1144                flip_rows: false,
1145                flip_cols: false,
1146                elems: vec![
1147                    Integer::from(3),
1148                    Integer::from(2),
1149                    Integer::from(1),
1150                    Integer::from(5),
1151                    Integer::from(9),
1152                    Integer::from(1),
1153                    Integer::from(3),
1154                    Integer::from(0),
1155                ],
1156            };
1157            a.check_invariants().unwrap();
1158
1159            let b = Matrix {
1160                dim1: 4,
1161                dim2: 3,
1162                transpose: false,
1163                flip_rows: false,
1164                flip_cols: false,
1165                elems: vec![
1166                    Integer::from(2),
1167                    Integer::from(9),
1168                    Integer::from(0),
1169                    Integer::from(1),
1170                    Integer::from(3),
1171                    Integer::from(5),
1172                    Integer::from(2),
1173                    Integer::from(4),
1174                    Integer::from(7),
1175                    Integer::from(8),
1176                    Integer::from(1),
1177                    Integer::from(5),
1178                ],
1179            };
1180            b.check_invariants().unwrap();
1181
1182            let c = Matrix {
1183                dim1: 2,
1184                dim2: 3,
1185                transpose: false,
1186                flip_rows: false,
1187                flip_cols: false,
1188                elems: vec![
1189                    Integer::from(50),
1190                    Integer::from(42),
1191                    Integer::from(42),
1192                    Integer::from(25),
1193                    Integer::from(96),
1194                    Integer::from(26),
1195                ],
1196            };
1197            c.check_invariants().unwrap();
1198
1199            assert_eq!(Matrix::mul(&a, &b).unwrap(), c);
1200        }
1201    }
1202
1203    #[test]
1204    fn matrix_apply_row_and_col_test() {
1205        let m = Matrix::<Integer>::from_rows(vec![
1206            vec![Integer::from(1), Integer::from(2), Integer::from(3)],
1207            vec![Integer::from(6), Integer::from(5), Integer::from(4)],
1208        ]);
1209
1210        assert_eq!(
1211            m.apply_row(&[Integer::from(1), Integer::from(0)]),
1212            vec![Integer::from(1), Integer::from(2), Integer::from(3)]
1213        );
1214
1215        assert_eq!(
1216            m.apply_row(&[Integer::from(0), Integer::from(1)]),
1217            vec![Integer::from(6), Integer::from(5), Integer::from(4)]
1218        );
1219
1220        assert_eq!(
1221            m.apply_row(&[Integer::from(1), Integer::from(1)]),
1222            vec![Integer::from(7), Integer::from(7), Integer::from(7)]
1223        );
1224
1225        assert_eq!(
1226            m.apply_col(&[Integer::from(1), Integer::from(0), Integer::from(0)]),
1227            vec![Integer::from(1), Integer::from(6)]
1228        );
1229
1230        assert_eq!(
1231            m.apply_col(&[Integer::from(0), Integer::from(1), Integer::from(0)]),
1232            vec![Integer::from(2), Integer::from(5)]
1233        );
1234
1235        assert_eq!(
1236            m.apply_col(&[Integer::from(0), Integer::from(0), Integer::from(1)]),
1237            vec![Integer::from(3), Integer::from(4)]
1238        );
1239
1240        assert_eq!(
1241            m.apply_col(&[Integer::from(1), Integer::from(1), Integer::from(1)]),
1242            vec![Integer::from(6), Integer::from(15)]
1243        );
1244    }
1245
1246    #[test]
1247    fn det_naive() {
1248        let m = Matrix::<Integer>::from_rows(vec![
1249            vec![Integer::from(1), Integer::from(3)],
1250            vec![Integer::from(4), Integer::from(2)],
1251        ]);
1252        println!("{}", m.det_naive().unwrap());
1253        assert_eq!(m.det_naive().unwrap(), Integer::from(-10));
1254
1255        let m = Matrix::<Integer>::from_rows(vec![
1256            vec![Integer::from(1), Integer::from(3), Integer::from(2)],
1257            vec![Integer::from(-3), Integer::from(-1), Integer::from(-3)],
1258            vec![Integer::from(2), Integer::from(3), Integer::from(1)],
1259        ]);
1260        println!("{}", m.det_naive().unwrap());
1261        assert_eq!(m.det_naive().unwrap(), Integer::from(-15));
1262    }
1263}