1use std::fmt::Display;
2
3use crate::ring::*;
4
5mod submatrix;
6mod transpose;
7mod owned;
8
9pub use submatrix::*;
10#[allow(unused_imports)]
11pub use transpose::*;
12pub use owned::*;
13
14pub mod transform;
19
20#[stability::unstable(feature = "enable")]
21pub fn format_matrix<'a, M, R>(row_count: usize, col_count: usize, matrix: M, ring: R) -> impl 'a + Display
22 where R: 'a + RingStore,
23 El<R>: 'a,
24 M: 'a + Fn(usize, usize) -> &'a El<R>
25{
26 struct DisplayWrapper<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> {
27 matrix: M,
28 ring: R,
29 row_count: usize,
30 col_count: usize
31 }
32
33 impl<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> Display for DisplayWrapper<'a, R, M> {
34
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 let strings = (0..self.row_count).flat_map(|i| (0..self.col_count).map(move |j| (i, j)))
37 .map(|(i, j)| format!("{}", self.ring.format((self.matrix)(i, j))))
38 .collect::<Vec<_>>();
39 let max_len = strings.iter().map(|s| s.chars().count()).chain([2].into_iter()).max().unwrap();
40 let mut strings = strings.into_iter();
41 for i in 0..self.row_count {
42 write!(f, "|")?;
43 if self.col_count > 0 {
44 write!(f, "{:>width$}", strings.next().unwrap(), width = max_len)?;
45 }
46 for _ in 1..self.col_count {
47 write!(f, ",{:>width$}", strings.next().unwrap(), width = max_len)?;
48 }
49 if i + 1 != self.row_count {
50 writeln!(f, "|")?;
51 } else {
52 write!(f, "|")?;
53 }
54 }
55 return Ok(());
56 }
57 }
58
59 DisplayWrapper { matrix, ring, col_count, row_count }
60}
61
62pub mod matrix_compare {
68 use super::*;
69
70 pub trait MatrixCompare<T> {
75
76 fn row_count(&self) -> usize;
80
81 fn col_count(&self) -> usize;
85
86 fn at(&self, i: usize, j: usize) -> &T;
90 }
91
92 impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [[T; COLS]; ROWS] {
93
94 fn col_count(&self) -> usize { COLS }
95 fn row_count(&self) -> usize { ROWS }
96 fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
97 }
98
99 impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [DerefArray<T, COLS>; ROWS] {
100
101 fn col_count(&self) -> usize { COLS }
102 fn row_count(&self) -> usize { ROWS }
103 fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
104 }
105
106 impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for Submatrix<'a, V, T> {
107
108 fn col_count(&self) -> usize { Submatrix::col_count(self) }
109 fn row_count(&self) -> usize { Submatrix::row_count(self) }
110 fn at(&self, i: usize, j: usize) -> &T { Submatrix::at(self, i, j) }
111 }
112
113 impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for SubmatrixMut<'a, V, T> {
114
115 fn col_count(&self) -> usize { SubmatrixMut::col_count(self) }
116 fn row_count(&self) -> usize { SubmatrixMut::row_count(self) }
117 fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
118 }
119
120 impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrix<'a, V, T, TRANSPOSED> {
121
122 fn col_count(&self) -> usize { TransposableSubmatrix::col_count(self) }
123 fn row_count(&self) -> usize { TransposableSubmatrix::row_count(self) }
124 fn at(&self, i: usize, j: usize) -> &T { TransposableSubmatrix::at(self, i, j) }
125 }
126
127 impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrixMut<'a, V, T, TRANSPOSED> {
128
129 fn col_count(&self) -> usize { TransposableSubmatrixMut::col_count(self) }
130 fn row_count(&self) -> usize { TransposableSubmatrixMut::row_count(self) }
131 fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
132 }
133
134 impl<T> MatrixCompare<T> for OwnedMatrix<T> {
135
136 fn col_count(&self) -> usize { OwnedMatrix::col_count(self) }
137 fn row_count(&self) -> usize { OwnedMatrix::row_count(self) }
138 fn at(&self, i: usize, j: usize) -> &T { OwnedMatrix::at(self, i, j) }
139 }
140
141 impl<'a, T, M: MatrixCompare<T>> MatrixCompare<T> for &'a M {
142
143 fn col_count(&self) -> usize { (**self).col_count() }
144 fn row_count(&self) -> usize { (**self).row_count() }
145 fn at(&self, i: usize, j: usize) -> &T { (**self).at(i, j) }
146 }
147
148 pub fn is_matrix_eq<R: ?Sized + RingBase, M1: MatrixCompare<R::Element>, M2: MatrixCompare<R::Element>>(ring: &R, lhs: &M1, rhs: &M2) -> bool {
155 if lhs.row_count() != rhs.row_count() || lhs.col_count() != rhs.col_count() {
156 return false;
157 }
158 for i in 0..lhs.row_count() {
159 for j in 0..lhs.col_count() {
160 if !ring.eq_el(lhs.at(i, j), rhs.at(i, j)) {
161 return false;
162 }
163 }
164 }
165 return true;
166 }
167}
168
/// Asserts that two matrices are equal, comparing entries with the equality of the given ring.
///
/// Usage: `assert_matrix_eq!(ring, lhs, rhs)`.
/// - `ring` is a `RingStore`.
/// - `lhs` and `rhs` must implement `matrix_compare::MatrixCompare` for the ring's element type.
///
/// The macro panics in two cases: the matrices have different shapes, or some pair of
/// corresponding entries is not equal according to the ring's `eq_el`. The panic message shows
/// both matrices as rendered by `format_matrix`. Each argument is evaluated exactly once and is
/// only borrowed. A trailing comma after the last argument is accepted, as with `assert_eq!`.
///
/// NOTE(review): the failure message uses `format_matrix`, which is marked
/// `#[stability::unstable(feature = "enable")]`. Confirm that this path is reachable from
/// downstream crates that invoke this macro without that feature enabled.
#[macro_export]
macro_rules! assert_matrix_eq {
    ($ring:expr, $lhs:expr, $rhs:expr $(,)?) => {
        // Borrow every argument exactly once, as `std::assert_eq!` does. This also keeps
        // temporaries alive for the whole assertion.
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    // Fully qualified so that the caller does not need `RingStore` in scope.
                    $crate::matrix::matrix_compare::is_matrix_eq($crate::ring::RingStore::get_ring(ring_val), lhs_val, rhs_val),
                    "Assertion failed: Expected\n{}\nto be\n{}",
                    $crate::matrix::format_matrix(<_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(lhs_val), <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(lhs_val), |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(lhs_val, i, j), ring_val),
                    $crate::matrix::format_matrix(<_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(rhs_val), <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(rhs_val), |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(rhs_val, i, j), ring_val)
                );
            }
        }
    }
}