1use std::fmt::Display;
2
3use crate::ring::*;
4
5mod submatrix;
6mod transpose;
7mod owned;
8
9pub use submatrix::*;
10#[allow(unused_imports)]
11pub use transpose::*;
12pub use owned::*;
13
14pub mod transform;
15
16#[stability::unstable(feature = "enable")]
17pub fn format_matrix<'a, M, R>(row_count: usize, col_count: usize, matrix: M, ring: R) -> impl 'a + Display
18 where R: 'a + RingStore,
19 El<R>: 'a,
20 M: 'a + Fn(usize, usize) -> &'a El<R>
21{
22 struct DisplayWrapper<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> {
23 matrix: M,
24 ring: R,
25 row_count: usize,
26 col_count: usize
27 }
28
29 impl<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> Display for DisplayWrapper<'a, R, M> {
30
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 let strings = (0..self.row_count).flat_map(|i| (0..self.col_count).map(move |j| (i, j)))
33 .map(|(i, j)| format!("{}", self.ring.format((self.matrix)(i, j))))
34 .collect::<Vec<_>>();
35 let max_len = strings.iter().map(|s| s.chars().count()).chain([2].into_iter()).max().unwrap();
36 let mut strings = strings.into_iter();
37 for i in 0..self.row_count {
38 write!(f, "|")?;
39 if self.col_count > 0 {
40 write!(f, "{:>width$}", strings.next().unwrap(), width = max_len)?;
41 }
42 for _ in 1..self.col_count {
43 write!(f, ",{:>width$}", strings.next().unwrap(), width = max_len)?;
44 }
45 if i + 1 != self.row_count {
46 writeln!(f, "|")?;
47 } else {
48 write!(f, "|")?;
49 }
50 }
51 return Ok(());
52 }
53 }
54
55 DisplayWrapper { matrix, ring, col_count, row_count }
56}
57
58pub mod matrix_compare {
64 use super::*;
65
66 pub trait MatrixCompare<T> {
71 fn row_count(&self) -> usize;
72 fn col_count(&self) -> usize;
73 fn at(&self, i: usize, j: usize) -> &T;
74 }
75
76 impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [[T; COLS]; ROWS] {
77
78 fn col_count(&self) -> usize { COLS }
79 fn row_count(&self) -> usize { ROWS }
80 fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
81 }
82
83 impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [DerefArray<T, COLS>; ROWS] {
84
85 fn col_count(&self) -> usize { COLS }
86 fn row_count(&self) -> usize { ROWS }
87 fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
88 }
89
90 impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for Submatrix<'a, V, T> {
91
92 fn col_count(&self) -> usize { Submatrix::col_count(self) }
93 fn row_count(&self) -> usize { Submatrix::row_count(self) }
94 fn at(&self, i: usize, j: usize) -> &T { Submatrix::at(self, i, j) }
95 }
96
97 impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for SubmatrixMut<'a, V, T> {
98
99 fn col_count(&self) -> usize { SubmatrixMut::col_count(self) }
100 fn row_count(&self) -> usize { SubmatrixMut::row_count(self) }
101 fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
102 }
103
104 impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrix<'a, V, T, TRANSPOSED> {
105
106 fn col_count(&self) -> usize { TransposableSubmatrix::col_count(self) }
107 fn row_count(&self) -> usize { TransposableSubmatrix::row_count(self) }
108 fn at(&self, i: usize, j: usize) -> &T { TransposableSubmatrix::at(self, i, j) }
109 }
110
111 impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrixMut<'a, V, T, TRANSPOSED> {
112
113 fn col_count(&self) -> usize { TransposableSubmatrixMut::col_count(self) }
114 fn row_count(&self) -> usize { TransposableSubmatrixMut::row_count(self) }
115 fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
116 }
117
118 impl<T> MatrixCompare<T> for OwnedMatrix<T> {
119
120 fn col_count(&self) -> usize { OwnedMatrix::col_count(self) }
121 fn row_count(&self) -> usize { OwnedMatrix::row_count(self) }
122 fn at(&self, i: usize, j: usize) -> &T { OwnedMatrix::at(self, i, j) }
123 }
124
125 impl<'a, T, M: MatrixCompare<T>> MatrixCompare<T> for &'a M {
126
127 fn col_count(&self) -> usize { (**self).col_count() }
128 fn row_count(&self) -> usize { (**self).row_count() }
129 fn at(&self, i: usize, j: usize) -> &T { (**self).at(i, j) }
130 }
131
132 pub fn is_matrix_eq<R: ?Sized + RingBase, M1: MatrixCompare<R::Element>, M2: MatrixCompare<R::Element>>(ring: &R, lhs: &M1, rhs: &M2) -> bool {
133 if lhs.row_count() != rhs.row_count() || lhs.col_count() != rhs.col_count() {
134 return false;
135 }
136 for i in 0..lhs.row_count() {
137 for j in 0..lhs.col_count() {
138 if !ring.eq_el(lhs.at(i, j), rhs.at(i, j)) {
139 return false;
140 }
141 }
142 }
143 return true;
144 }
145}
146
/// Asserts that two matrices are equal over a given ring, and panics with both matrices
/// printed if they are not.
///
/// Arguments:
/// - `$ring` is the ring whose equality (`get_ring().eq_el`) is used. It is also used to
///   format the entries in the failure message.
/// - `$lhs` and `$rhs` are the two matrices to compare. Both must implement
///   `matrix::matrix_compare::MatrixCompare` for the ring's element type.
///
/// The assertion holds if and only if
/// `matrix::matrix_compare::is_matrix_eq` returns `true`, i.e. both matrices have the
/// same shape and all corresponding entries are equal.
///
/// On failure, the panic message shows `$lhs` after "Expected" and `$rhs` after "to be". Both
/// are rendered with `matrix::format_matrix`.
///
/// Each operand expression is evaluated exactly once, and it is only borrowed, not moved.
#[macro_export]
macro_rules! assert_matrix_eq {
    ($ring:expr, $lhs:expr, $rhs:expr) => {
        // Binding through a `match` on a tuple of references (as `assert_eq!` does) evaluates
        // each operand once. It also keeps temporaries created by the operand expressions alive
        // until the end of the assertion.
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    $crate::matrix::matrix_compare::is_matrix_eq(ring_val.get_ring(), lhs_val, rhs_val),
                    "Assertion failed: Expected\n{}\nto be\n{}",
                    // Fully qualified `MatrixCompare` calls select the trait's accessors for the formatting closures.
                    $crate::matrix::format_matrix(<_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(lhs_val), <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(lhs_val), |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(lhs_val, i, j), ring_val),
                    $crate::matrix::format_matrix(<_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(rhs_val), <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(rhs_val), |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(rhs_val, i, j), ring_val)
                );
            }
        }
    }
}