1use crate::structure::*;
2use algebraeon_nzq::Natural;
3use algebraeon_sets::structure::*;
4use std::{borrow::Borrow, marker::PhantomData};
5
/// Errors returned by fallible matrix operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatOppErr {
    /// The operands' dimensions are incompatible for the requested operation.
    DimMismatch,
    /// A row or column index was out of range.
    InvalidIndex,
    /// The operation requires a square matrix.
    NotSquare,
    /// The matrix is singular (has no inverse).
    Singular,
}

impl std::fmt::Display for MatOppErr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            MatOppErr::DimMismatch => "matrix dimensions do not match",
            MatOppErr::InvalidIndex => "matrix index out of range",
            MatOppErr::NotSquare => "matrix is not square",
            MatOppErr::Singular => "matrix is singular",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for MatOppErr {}
13
/// A dense matrix stored as a flat row-major buffer of `dim1 * dim2` entries.
///
/// The logical matrix is a view of that buffer through three flags
/// (`transpose`, `flip_rows`, `flip_cols`), so transposing or reversing rows
/// or columns only toggles a flag and never moves entries.
#[derive(Debug, Clone)]
pub struct Matrix<Set: Clone> {
    /// Row count of the underlying storage (before `transpose` is applied).
    dim1: usize,
    /// Column count of the underlying storage (before `transpose` is applied).
    dim2: usize,
    /// When set, the logical matrix is the transpose of the stored one.
    transpose: bool,
    /// When set, logical rows are read in reverse order.
    flip_rows: bool,
    /// When set, logical columns are read in reverse order.
    flip_cols: bool,
    /// Stored entries; storage row `i` occupies `elems[i * dim2..(i + 1) * dim2]`.
    elems: Vec<Set>,
}
23
24impl<Set: Clone> Matrix<Set> {
25 #[allow(unused)]
26 fn check_invariants(&self) -> Result<(), &'static str> {
27 if self.elems.len() != self.dim1 * self.dim2 {
28 return Err("matrix entries has the wrong length");
29 }
30 Ok(())
31 }
32
33 pub fn full(rows: usize, cols: usize, elem: &Set) -> Self {
34 let mut elems = Vec::with_capacity(rows * cols);
35 for _i in 0..rows * cols {
36 elems.push(elem.clone());
37 }
38 Self {
39 dim1: rows,
40 dim2: cols,
41 transpose: false,
42 flip_rows: false,
43 flip_cols: false,
44 elems,
45 }
46 }
47
48 pub fn construct(rows: usize, cols: usize, make_entry: impl Fn(usize, usize) -> Set) -> Self {
63 let mut elems = Vec::with_capacity(rows * cols);
64 for idx in 0..rows * cols {
65 let (r, c) = (idx / cols, idx % cols); elems.push(make_entry(r, c).clone());
67 }
68 Self {
69 dim1: rows,
70 dim2: cols,
71 transpose: false,
72 flip_rows: false,
73 flip_cols: false,
74 elems,
75 }
76 }
77
78 pub fn from_rows(rows_elems: Vec<Vec<impl Into<Set> + Clone>>) -> Self {
80 let rows = rows_elems.len();
81 assert!(rows >= 1);
82 let cols = rows_elems[0].len();
83 #[allow(clippy::needless_range_loop)]
84 for r in 1..rows {
85 assert_eq!(rows_elems[r].len(), cols);
86 }
87 Self::construct(rows, cols, |r, c| rows_elems[r][c].clone().into())
88 }
89
90 pub fn from_cols(cols_elems: Vec<Vec<impl Into<Set> + Clone>>) -> Self {
92 Self::from_rows(cols_elems).transpose()
93 }
94
95 pub fn from_row(elems: Vec<impl Into<Set> + Clone>) -> Self {
97 Self::from_rows(vec![elems])
98 }
99
100 pub fn from_col(elems: Vec<impl Into<Set> + Clone>) -> Self {
102 Self::from_rows(vec![elems]).transpose()
103 }
104
105 fn rc_to_idx(&self, mut r: usize, mut c: usize) -> usize {
106 if self.flip_rows {
107 r = self.rows() - r - 1;
108 }
109 if self.flip_cols {
110 c = self.cols() - c - 1;
111 }
112 if self.transpose {
113 r + c * self.dim2
114 } else {
115 c + r * self.dim2
116 }
117 }
118
119 pub fn at(&self, r: usize, c: usize) -> Result<&Set, MatOppErr> {
121 if r >= self.rows() || c >= self.cols() {
122 Err(MatOppErr::InvalidIndex)
123 } else {
124 let idx = self.rc_to_idx(r, c);
125 Ok(&self.elems[idx])
126 }
127 }
128
129 pub fn at_mut(&mut self, r: usize, c: usize) -> Result<&mut Set, MatOppErr> {
131 if r >= self.rows() || c >= self.cols() {
132 Err(MatOppErr::InvalidIndex)
133 } else {
134 let idx = self.rc_to_idx(r, c);
135 Ok(&mut self.elems[idx])
136 }
137 }
138
139 pub fn rows(&self) -> usize {
140 if self.transpose { self.dim2 } else { self.dim1 }
141 }
142
143 pub fn cols(&self) -> usize {
144 if self.transpose { self.dim1 } else { self.dim2 }
145 }
146
147 pub fn submatrix(&self, rows: Vec<usize>, cols: Vec<usize>) -> Self {
149 let mut elems = vec![];
150 for r in &rows {
151 for c in &cols {
152 elems.push(self.at(*r, *c).unwrap().clone());
153 }
154 }
155 Matrix {
156 dim1: rows.len(),
157 dim2: cols.len(),
158 transpose: false,
159 flip_rows: false,
160 flip_cols: false,
161 elems,
162 }
163 }
164
165 pub fn get_row_submatrix(&self, row: usize) -> Self {
166 self.submatrix(vec![row], (0..self.cols()).collect())
167 }
168
169 pub fn get_col_submatrix(&self, col: usize) -> Self {
170 self.submatrix((0..self.rows()).collect(), vec![col])
171 }
172
173 pub fn get_row_refs(&self, row: usize) -> Vec<&Set> {
174 assert!(row < self.rows());
175 (0..self.cols()).map(|c| self.at(row, c).unwrap()).collect()
176 }
177
178 pub fn get_col_refs(&self, col: usize) -> Vec<&Set> {
179 assert!(col < self.cols());
180 (0..self.rows()).map(|r| self.at(r, col).unwrap()).collect()
181 }
182
183 pub fn get_row(&self, row: usize) -> Vec<Set> {
184 assert!(row < self.rows());
185 self.get_row_refs(row).into_iter().cloned().collect()
186 }
187
188 pub fn get_col(&self, col: usize) -> Vec<Set> {
189 assert!(col < self.cols());
190 self.get_col_refs(col).into_iter().cloned().collect()
191 }
192
193 pub fn apply_map<NewSet: Clone>(&self, f: impl Fn(&Set) -> NewSet) -> Matrix<NewSet> {
195 Matrix {
196 dim1: self.dim1,
197 dim2: self.dim2,
198 transpose: self.transpose,
199 flip_rows: self.flip_rows,
200 flip_cols: self.flip_cols,
201 elems: self.elems.iter().map(f).collect(),
202 }
203 }
204
205 pub fn transpose(mut self) -> Self {
206 self.transpose_mut();
207 self
208 }
209 pub fn transpose_ref(&self) -> Self {
210 self.clone().transpose()
211 }
212 pub fn transpose_mut(&mut self) {
213 self.transpose = !self.transpose;
214 (self.flip_rows, self.flip_cols) = (self.flip_cols, self.flip_rows);
215 }
216
217 pub fn flip_rows(mut self) -> Self {
218 self.flip_rows_mut();
219 self
220 }
221 pub fn flip_rows_ref(&self) -> Self {
222 self.clone().flip_rows()
223 }
224 pub fn flip_rows_mut(&mut self) {
225 self.flip_rows = !self.flip_rows;
226 }
227
228 pub fn flip_cols(mut self) -> Self {
229 self.flip_cols_mut();
230 self
231 }
232 pub fn flip_cols_ref(&self) -> Self {
233 self.clone().flip_cols()
234 }
235 pub fn flip_cols_mut(&mut self) {
236 self.flip_cols = !self.flip_cols;
237 }
238
239 pub fn join_rows<MatT: Borrow<Matrix<Set>>>(cols: usize, mats: Vec<MatT>) -> Matrix<Set> {
247 let mut rows = 0;
248 for mat in &mats {
249 assert_eq!(cols, mat.borrow().cols());
250 rows += mat.borrow().rows();
251 }
252 Matrix::construct(rows, cols, |r, c| {
253 let mut row_offset = 0;
255 for mat in &mats {
256 for mr in 0..mat.borrow().rows() {
257 for mc in 0..cols {
258 if r == row_offset + mr && c == mc {
259 return mat.borrow().at(mr, mc).unwrap().clone();
260 }
261 }
262 }
263 row_offset += mat.borrow().rows();
264 }
265 panic!();
266 })
267 }
268
269 pub fn join_cols<MatT: Borrow<Matrix<Set>>>(rows: usize, mats: Vec<MatT>) -> Matrix<Set> {
277 let mut t_mats = vec![];
278 for mat in mats {
279 t_mats.push(mat.borrow().clone().transpose());
280 }
281 let joined = Self::join_rows(rows, t_mats.iter().collect());
282 joined.transpose()
283 }
284
285 pub fn entries_list(&self) -> Vec<&Set> {
289 let mut entries = vec![];
290 for r in 0..self.rows() {
291 for c in 0..self.cols() {
292 entries.push(self.at(r, c).unwrap());
293 }
294 }
295 entries
296 }
297}
298
299#[derive(Debug, Clone, PartialEq, Eq)]
300pub struct MatrixStructure<RS: SetSignature, RSB: BorrowedStructure<RS>> {
301 _ring: PhantomData<RS>,
302 ring: RSB,
303}
304
305impl<RS: SetSignature, RSB: BorrowedStructure<RS>> Signature for MatrixStructure<RS, RSB> {}
306
307impl<RS: SetSignature, RSB: BorrowedStructure<RS>> SetSignature for MatrixStructure<RS, RSB> {
308 type Set = Matrix<RS::Set>;
309
310 fn is_element(&self, _x: &Self::Set) -> Result<(), String> {
311 Ok(())
312 }
313}
314
315impl<RS: SetSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
316 pub fn new(ring: RSB) -> Self {
317 Self {
318 _ring: PhantomData,
319 ring,
320 }
321 }
322
323 pub fn ring(&self) -> &RS {
324 self.ring.borrow()
325 }
326}
327
328pub trait RingMatricesSignature: SetSignature {
329 fn matrices(&self) -> MatrixStructure<Self, &Self> {
330 MatrixStructure::new(self)
331 }
332
333 fn into_matrices(self) -> MatrixStructure<Self, Self> {
334 MatrixStructure::new(self)
335 }
336}
337
338impl<RS: SetSignature> RingMatricesSignature for RS {}
339
340impl<RS: EqSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
341 pub fn equal(&self, a: &Matrix<RS::Set>, b: &Matrix<RS::Set>) -> bool {
342 let rows = a.rows();
343 let cols = a.cols();
344 if rows != b.rows() || cols != b.cols() {
345 false
346 } else {
347 for c in 0..cols {
348 for r in 0..rows {
349 if !self.ring().equal(a.at(r, c).unwrap(), b.at(r, c).unwrap()) {
350 return false;
351 }
352 }
353 }
354 true
355 }
356 }
357}
358
359impl<RS: ToStringSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
360 pub fn pprint(&self, mat: &Matrix<RS::Set>) {
361 let mut str_rows = vec![];
362 for r in 0..mat.rows() {
363 str_rows.push(vec![]);
364 for c in 0..mat.cols() {
365 str_rows[r].push(self.ring().to_string(mat.at(r, c).unwrap()));
366 }
367 }
368 #[allow(clippy::redundant_closure_for_method_calls)]
369 let cols_widths: Vec<usize> = (0..mat.cols())
370 .map(|c| {
371 (0..mat.rows())
372 .map(|r| str_rows[r][c].chars().count())
373 .fold(0usize, |a, b| a.max(b))
374 })
375 .collect();
376
377 #[allow(clippy::needless_range_loop)]
378 for r in 0..mat.rows() {
379 for c in 0..mat.cols() {
380 while str_rows[r][c].chars().count() < cols_widths[c] {
381 str_rows[r][c].push(' ');
382 }
383 debug_assert_eq!(str_rows[r][c].chars().count(), cols_widths[c]);
384 }
385 }
386
387 #[allow(clippy::needless_range_loop)]
388 for r in 0..mat.rows() {
389 if mat.rows() == 1 {
390 print!("( ");
391 } else if r == 0 {
392 print!("/ ");
393 } else if r == mat.rows() - 1 {
394 print!("\\ ");
395 } else {
396 print!("| ");
397 }
398 for c in 0..mat.cols() {
399 if c != 0 {
400 print!(" ");
401 }
402 print!("{}", str_rows[r][c]);
403 }
404 if mat.rows() == 1 {
405 print!(" )");
406 } else if r == 0 {
407 print!(" \\");
408 } else if r == mat.rows() - 1 {
409 print!(" /");
410 } else {
411 print!(" |");
412 }
413 println!();
414 }
415 }
416}
417
418impl<RS: RingSignature, RSB: BorrowedStructure<RS>> MatrixStructure<RS, RSB> {
419 pub fn zero(&self, rows: usize, cols: usize) -> Matrix<RS::Set> {
420 Matrix::construct(rows, cols, |_r, _c| self.ring().zero())
421 }
422
423 pub fn ident(&self, n: usize) -> Matrix<RS::Set> {
424 Matrix::construct(n, n, |r, c| {
425 if r == c {
426 self.ring().one()
427 } else {
428 self.ring().zero()
429 }
430 })
431 }
432
433 pub fn diag(&self, diag: &[RS::Set]) -> Matrix<RS::Set> {
434 Matrix::construct(diag.len(), diag.len(), |r, c| {
435 if r == c {
436 diag[r].clone()
437 } else {
438 self.ring().zero()
439 }
440 })
441 }
442
443 pub fn join_diag<MatT: Borrow<Matrix<RS::Set>>>(&self, mats: Vec<MatT>) -> Matrix<RS::Set> {
444 if mats.is_empty() {
445 Matrix::construct(0, 0, |_r, _c| unreachable!())
446 } else if mats.len() == 1 {
447 mats[0].borrow().clone()
448 } else {
449 let i = mats.len() / 2;
450 let (first, last) = mats.split_at(i);
451 #[allow(clippy::redundant_closure_for_method_calls)]
452 let first = self.join_diag(first.iter().map(|m| m.borrow()).collect());
453 #[allow(clippy::redundant_closure_for_method_calls)]
454 let last = self.join_diag(last.iter().map(|m| m.borrow()).collect());
455 Matrix::construct(
456 first.rows() + last.rows(),
457 first.cols() + last.cols(),
458 |r, c| {
459 if r < first.rows() && c < first.cols() {
460 first.at(r, c).unwrap().clone()
461 } else if first.rows() <= r && first.cols() <= c {
462 last.at(r - first.rows(), c - first.cols()).unwrap().clone()
463 } else {
464 self.ring().zero()
465 }
466 },
467 )
468 }
469 }
470
471 pub fn dot(&self, a: &Matrix<RS::Set>, b: &Matrix<RS::Set>) -> RS::Set {
472 let rows = a.rows();
473 let cols = a.cols();
474 assert_eq!(rows, b.rows());
475 assert_eq!(cols, b.cols());
476 let mut tot = self.ring().zero();
477 for r in 0..rows {
478 for c in 0..cols {
479 self.ring().add_mut(
480 &mut tot,
481 &self.ring().mul(a.at(r, c).unwrap(), b.at(r, c).unwrap()),
482 );
483 }
484 }
485 tot
486 }
487
488 pub fn add_mut(&self, a: &mut Matrix<RS::Set>, b: &Matrix<RS::Set>) -> Result<(), MatOppErr> {
489 if a.rows() != b.rows() || a.cols() != b.cols() {
490 Err(MatOppErr::DimMismatch)
491 } else {
492 let rows = a.rows();
493 let cols = a.cols();
494 for c in 0..cols {
495 for r in 0..rows {
496 self.ring()
497 .add_mut(a.at_mut(r, c).unwrap(), b.at(r, c).unwrap());
498 }
499 }
500 Ok(())
501 }
502 }
503
504 pub fn add(
505 &self,
506 a: &Matrix<RS::Set>,
507 b: &Matrix<RS::Set>,
508 ) -> Result<Matrix<RS::Set>, MatOppErr> {
509 let mut new_a = a.clone();
510 match self.add_mut(&mut new_a, b) {
511 Ok(()) => Ok(new_a),
512 Err(e) => Err(e),
513 }
514 }
515
516 pub fn neg_mut(&self, a: &mut Matrix<RS::Set>) {
517 for r in 0..a.rows() {
518 for c in 0..a.cols() {
519 let neg_elem = self.ring().neg(a.at(r, c).unwrap());
520 *a.at_mut(r, c).unwrap() = neg_elem;
521 }
522 }
523 }
524
525 pub fn neg(&self, mut a: Matrix<RS::Set>) -> Matrix<RS::Set> {
526 self.neg_mut(&mut a);
527 a
528 }
529
530 pub fn mul(
531 &self,
532 a: &Matrix<RS::Set>,
533 b: &Matrix<RS::Set>,
534 ) -> Result<Matrix<RS::Set>, MatOppErr> {
535 let mids = a.cols();
536 if mids != b.rows() {
537 return Err(MatOppErr::DimMismatch);
538 }
539 let rows = a.rows();
540 let cols = b.cols();
541 let mut s = self.zero(rows, cols);
542 for r in 0..rows {
543 for c in 0..cols {
544 for m in 0..mids {
545 self.ring().add_mut(
546 s.at_mut(r, c).unwrap(),
547 &self.ring().mul(a.at(r, m).unwrap(), b.at(m, c).unwrap()),
548 );
549 }
550 }
551 }
552 Ok(s)
553 }
554
555 pub fn apply_row(&self, mat: &Matrix<RS::Set>, row: &[RS::Set]) -> Vec<RS::Set> {
556 assert_eq!(mat.rows(), row.len());
557 (0..mat.cols())
558 .map(|c| {
559 self.ring().sum(
560 (0..mat.rows())
561 .map(|r| self.ring().mul(mat.at(r, c).unwrap(), &row[r]))
562 .collect(),
563 )
564 })
565 .collect()
566 }
567
568 pub fn apply_col(&self, mat: &Matrix<RS::Set>, col: &[RS::Set]) -> Vec<RS::Set> {
569 assert_eq!(mat.cols(), col.len());
570 (0..mat.rows())
571 .map(|r| {
572 self.ring().sum(
573 (0..mat.cols())
574 .map(|c| self.ring().mul(mat.at(r, c).unwrap(), &col[c]))
575 .collect(),
576 )
577 })
578 .collect()
579 }
580
581 pub fn mul_scalar(&self, mut a: Matrix<RS::Set>, scalar: &RS::Set) -> Matrix<RS::Set> {
582 for r in 0..a.rows() {
583 for c in 0..a.cols() {
584 self.ring().mul_mut(a.at_mut(r, c).unwrap(), scalar);
585 }
586 }
587 a
588 }
589
590 pub fn mul_scalar_ref(&self, a: &Matrix<RS::Set>, scalar: &RS::Set) -> Matrix<RS::Set> {
591 self.mul_scalar(a.clone(), scalar)
592 }
593
594 pub fn det_naive(&self, a: &Matrix<RS::Set>) -> Result<RS::Set, MatOppErr> {
595 let n = a.rows();
596 if n == a.cols() {
597 let mut det = self.ring().zero();
598 for perm in algebraeon_groups::permutation::Permutation::all_permutations(n) {
599 let mut prod = self.ring().one();
600 for k in 0..n {
601 self.ring()
602 .mul_mut(&mut prod, a.at(k, perm.call(k)).unwrap());
603 }
604 match perm.sign() {
605 algebraeon_groups::examples::c2::C2::Identity => {}
606 algebraeon_groups::examples::c2::C2::Flip => {
607 prod = self.ring().neg(&prod);
608 }
609 }
610
611 self.ring().add_mut(&mut det, &prod);
612 }
613 Ok(det)
614 } else {
615 Err(MatOppErr::NotSquare)
616 }
617 }
618
619 pub fn trace(&self, a: &Matrix<RS::Set>) -> Result<RS::Set, MatOppErr> {
620 let n = a.rows();
621 if n == a.cols() {
622 Ok(self
623 .ring()
624 .sum((0..n).map(|i| a.at(i, i).unwrap()).collect()))
625 } else {
626 Err(MatOppErr::NotSquare)
627 }
628 }
629
630 pub fn nat_pow(&self, a: &Matrix<RS::Set>, k: &Natural) -> Result<Matrix<RS::Set>, MatOppErr> {
631 let n = a.rows();
632 if n != a.cols() {
633 Err(MatOppErr::NotSquare)
634 } else if *k == Natural::ZERO {
635 Ok(self.ident(n))
636 } else if *k == Natural::ONE {
637 Ok(a.clone())
638 } else {
639 debug_assert!(*k >= Natural::TWO);
640 let bits: Vec<_> = k.bits().collect();
641 let mut pows = vec![a.clone()];
642 while pows.len() < bits.len() {
643 pows.push(
644 self.mul(pows.last().unwrap(), pows.last().unwrap())
645 .unwrap(),
646 );
647 }
648 let count = bits.len();
649 debug_assert_eq!(count, pows.len());
650 let mut ans = self.ident(n);
651 for i in 0..count {
652 if bits[i] {
653 ans = self.mul(&ans, &pows[i]).unwrap();
654 }
655 }
656 Ok(ans)
657 }
658 }
659}
660
661impl<R: MetaType> MetaType for Matrix<R>
662where
663 R::Signature: SetSignature,
664{
665 type Signature = MatrixStructure<R::Signature, R::Signature>;
666
667 fn structure() -> Self::Signature {
668 MatrixStructure::new(R::structure())
669 }
670}
671
672impl<R: MetaType> Matrix<R>
673where
674 R::Signature: ToStringSignature,
675{
676 pub fn pprint(&self) {
677 Self::structure().pprint(self);
678 }
679}
680
681impl<R: MetaType> PartialEq for Matrix<R>
682where
683 R::Signature: RingSignature + EqSignature,
684{
685 fn eq(&self, other: &Self) -> bool {
686 Self::structure().equal(self, other)
687 }
688}
689
690impl<R: MetaType> Eq for Matrix<R> where R::Signature: RingSignature + EqSignature {}
691
692impl<R: MetaType> Matrix<R>
693where
694 R::Signature: RingSignature,
695{
696 pub fn zero(rows: usize, cols: usize) -> Self {
697 Self::structure().zero(rows, cols)
698 }
699
700 pub fn ident(n: usize) -> Self {
701 Self::structure().ident(n)
702 }
703
704 pub fn diag(diag: &[R]) -> Self {
705 Self::structure().diag(diag)
706 }
707
708 pub fn dot(a: &Self, b: &Self) -> R {
709 Self::structure().dot(a, b)
710 }
711
712 pub fn add_mut(&mut self, b: &Self) -> Result<(), MatOppErr> {
713 Self::structure().add_mut(self, b)
714 }
715
716 pub fn add(a: &Self, b: &Self) -> Result<Self, MatOppErr> {
717 Self::structure().add(a, b)
718 }
719
720 pub fn neg_mut(&mut self) {
721 Self::structure().neg_mut(self);
722 }
723
724 pub fn neg(&self) -> Self {
725 Self::structure().neg(self.clone())
726 }
727
728 pub fn mul(a: &Self, b: &Self) -> Result<Self, MatOppErr> {
729 Self::structure().mul(a, b)
730 }
731
732 pub fn apply_row(&self, row: &[R]) -> Vec<R> {
733 Self::structure().apply_row(self, row)
734 }
735
736 pub fn apply_col(&self, col: &[R]) -> Vec<R> {
737 Self::structure().apply_col(self, col)
738 }
739
740 pub fn mul_scalar(&self, scalar: &R) -> Matrix<R> {
741 Self::structure().mul_scalar(self.clone(), scalar)
742 }
743
744 pub fn mul_scalar_ref(&self, scalar: &R) -> Matrix<R> {
745 Self::structure().mul_scalar_ref(self, scalar)
746 }
747
748 pub fn det_naive(&self) -> Result<R, MatOppErr> {
749 Self::structure().det_naive(self)
750 }
751
752 pub fn trace(&self) -> Result<R, MatOppErr> {
753 Self::structure().trace(self)
754 }
755}
756
#[cfg(test)]
mod tests {
    use algebraeon_nzq::Integer;

    use super::*;

    /// Builds an `Integer` matrix directly from its storage fields, so tests
    /// can exercise specific layouts such as a transposed buffer.
    fn raw(dim1: usize, dim2: usize, transpose: bool, elems: &[i32]) -> Matrix<Integer> {
        Matrix {
            dim1,
            dim2,
            transpose,
            flip_rows: false,
            flip_cols: false,
            elems: elems.iter().map(|&x| Integer::from(x)).collect(),
        }
    }

    /// Builds an `Integer` matrix from row-major literals.
    fn ints(rows: Vec<Vec<i32>>) -> Matrix<Integer> {
        Matrix::from_rows(rows)
    }

    /// Builds a vector of `Integer`s from literals.
    fn ivec(xs: &[i32]) -> Vec<Integer> {
        xs.iter().map(|&x| Integer::from(x)).collect()
    }

    #[test]
    fn test_join_rows() {
        let top = ints(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let bot = ints(vec![vec![7, 8, 9]]);

        let both = ints(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);

        println!("top");
        top.pprint();
        println!("bot");
        bot.pprint();
        println!("both");
        both.pprint();

        let ans = Matrix::join_rows(3, vec![top, bot]);
        println!("ans");
        ans.pprint();

        assert_eq!(ans, both);
    }

    #[test]
    fn join_rows_mixed_blocks() {
        let empty = Matrix::<Integer>::zero(0, 2);
        let a = ints(vec![vec![1, 2]]);
        // Stored as [[5, 7], [6, 8]], viewed transposed as [[5, 6], [7, 8]].
        let t = ints(vec![vec![5, 7], vec![6, 8]]).transpose();
        // Viewed with columns reversed as [[10, 9]].
        let f = ints(vec![vec![9, 10]]).flip_cols();
        let ans = Matrix::<Integer>::join_rows(2, vec![&empty, &a, &empty, &t, &f, &empty]);
        assert_eq!(
            ans,
            ints(vec![vec![1, 2], vec![5, 6], vec![7, 8], vec![10, 9]])
        );
    }

    #[test]
    fn test_join_cols() {
        let left = ints(vec![vec![1], vec![3]]);
        let right = ints(vec![vec![2, 5], vec![4, 6]]);
        let ans = Matrix::<Integer>::join_cols(2, vec![&left, &right]);
        assert_eq!(ans, ints(vec![vec![1, 2, 5], vec![3, 4, 6]]));
    }

    #[test]
    fn invariants() {
        assert!(raw(3, 4, false, &[1, 2, 3, 4, 5]).check_invariants().is_err());
        raw(2, 3, true, &[1, 2, 3, 4, 5, 6])
            .check_invariants()
            .unwrap();
    }

    #[test]
    fn transpose_eq() {
        let a = raw(2, 2, false, &[0, 1, 2, 3]);
        a.check_invariants().unwrap();

        let b = raw(2, 2, true, &[0, 2, 1, 3]);
        b.check_invariants().unwrap();

        assert_eq!(a, b);
    }

    #[test]
    fn flip_axes_eq() {
        let mut a = ints(vec![vec![1, 2], vec![3, 4]]);
        a.pprint();
        println!("flip rows");
        a.flip_rows_mut();
        a.pprint();
        assert_eq!(a, ints(vec![vec![3, 4], vec![1, 2]]));
        println!("transpose");
        a.transpose_mut();
        a.pprint();
        assert_eq!(a, ints(vec![vec![3, 1], vec![4, 2]]));
        println!("flip rows");
        a.flip_rows_mut();
        a.pprint();
        assert_eq!(a, ints(vec![vec![4, 2], vec![3, 1]]));
        println!("flip cols");
        a.flip_cols_mut();
        a.pprint();
        assert_eq!(a, ints(vec![vec![2, 4], vec![1, 3]]));
        println!("transpose");
        a.transpose_mut();
        a.pprint();
        assert_eq!(a, ints(vec![vec![2, 1], vec![4, 3]]));
        println!("flip cols");
        a.flip_cols_mut();
        a.pprint();
        assert_eq!(a, ints(vec![vec![1, 2], vec![3, 4]]));
    }

    #[test]
    fn add() {
        {
            let mut a = raw(2, 3, false, &[1, 2, 3, 4, 5, 6]);
            a.check_invariants().unwrap();
            let b = raw(2, 3, false, &[1, 2, 1, 2, 1, 2]);
            b.check_invariants().unwrap();
            let c = raw(2, 3, false, &[2, 4, 4, 6, 6, 8]);
            c.check_invariants().unwrap();

            a.add_mut(&b).unwrap();

            assert_eq!(a, c);
        }

        {
            // `b` is a 2x3 buffer viewed transposed, i.e. a 3x2 matrix.
            let mut a = raw(3, 2, false, &[1, 2, 3, 4, 5, 6]);
            a.check_invariants().unwrap();
            let b = raw(2, 3, true, &[10, 20, 30, 40, 50, 60]);
            b.check_invariants().unwrap();
            let c = raw(3, 2, false, &[11, 42, 23, 54, 35, 66]);
            c.check_invariants().unwrap();

            a.add_mut(&b).unwrap();

            assert_eq!(a, c);
        }

        {
            let mut a = raw(3, 2, false, &[1, 2, 3, 4, 5, 6]);
            a.check_invariants().unwrap();
            let b = raw(2, 3, false, &[1, 2, 1, 2, 1, 2]);
            b.check_invariants().unwrap();

            assert!(matches!(a.add_mut(&b), Err(MatOppErr::DimMismatch)));
        }

        {
            let a = raw(2, 3, false, &[1, 2, 3, 4, 5, 6]);
            a.check_invariants().unwrap();
            let b = raw(2, 3, false, &[1, 2, 1, 2, 1, 2]);
            b.check_invariants().unwrap();
            let c = raw(2, 3, false, &[2, 4, 4, 6, 6, 8]);
            c.check_invariants().unwrap();

            assert_eq!(Matrix::add(&a, &b).unwrap(), c);
        }
    }

    #[test]
    fn mul() {
        let a = raw(2, 4, false, &[3, 2, 1, 5, 9, 1, 3, 0]);
        a.check_invariants().unwrap();

        let b = raw(4, 3, false, &[2, 9, 0, 1, 3, 5, 2, 4, 7, 8, 1, 5]);
        b.check_invariants().unwrap();

        let c = raw(2, 3, false, &[50, 42, 42, 25, 96, 26]);
        c.check_invariants().unwrap();

        assert_eq!(Matrix::mul(&a, &b).unwrap(), c);
    }

    #[test]
    fn matrix_apply_row_and_col_test() {
        let m = ints(vec![vec![1, 2, 3], vec![6, 5, 4]]);

        assert_eq!(m.apply_row(&ivec(&[1, 0])), ivec(&[1, 2, 3]));
        assert_eq!(m.apply_row(&ivec(&[0, 1])), ivec(&[6, 5, 4]));
        assert_eq!(m.apply_row(&ivec(&[1, 1])), ivec(&[7, 7, 7]));

        assert_eq!(m.apply_col(&ivec(&[1, 0, 0])), ivec(&[1, 6]));
        assert_eq!(m.apply_col(&ivec(&[0, 1, 0])), ivec(&[2, 5]));
        assert_eq!(m.apply_col(&ivec(&[0, 0, 1])), ivec(&[3, 4]));
        assert_eq!(m.apply_col(&ivec(&[1, 1, 1])), ivec(&[6, 15]));
    }

    #[test]
    fn det_naive() {
        let m = ints(vec![vec![1, 3], vec![4, 2]]);
        println!("{}", m.det_naive().unwrap());
        assert_eq!(m.det_naive().unwrap(), Integer::from(-10));

        let m = ints(vec![vec![1, 3, 2], vec![-3, -1, -3], vec![2, 3, 1]]);
        println!("{}", m.det_naive().unwrap());
        assert_eq!(m.det_naive().unwrap(), Integer::from(-15));
    }

    #[test]
    fn nat_pow() {
        let structure = Matrix::<Integer>::structure();
        let fib = ints(vec![vec![1, 1], vec![1, 0]]);
        assert_eq!(
            structure.nat_pow(&fib, &Natural::ZERO).unwrap(),
            Matrix::<Integer>::ident(2)
        );
        assert_eq!(structure.nat_pow(&fib, &Natural::ONE).unwrap(), fib);
        assert_eq!(
            structure.nat_pow(&fib, &Natural::TWO).unwrap(),
            ints(vec![vec![2, 1], vec![1, 1]])
        );
        let rect = ints(vec![vec![1, 2, 3]]);
        assert!(matches!(
            structure.nat_pow(&rect, &Natural::TWO),
            Err(MatOppErr::NotSquare)
        ));
    }
}