1use std::fmt::Display;
2
3use crate::ring::*;
4
5mod owned;
6mod submatrix;
7mod transpose;
8
9pub use owned::*;
10pub use submatrix::*;
11#[allow(unused_imports)]
12pub use transpose::*;
13
14pub mod transform;
17
18#[stability::unstable(feature = "enable")]
19pub fn format_matrix<'a, M, R>(row_count: usize, col_count: usize, matrix: M, ring: R) -> impl 'a + Display
20where
21 R: 'a + RingStore,
22 El<R>: 'a,
23 M: 'a + Fn(usize, usize) -> &'a El<R>,
24{
25 struct DisplayWrapper<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> {
26 matrix: M,
27 ring: R,
28 row_count: usize,
29 col_count: usize,
30 }
31
32 impl<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> Display for DisplayWrapper<'a, R, M> {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 let strings = (0..self.row_count)
35 .flat_map(|i| (0..self.col_count).map(move |j| (i, j)))
36 .map(|(i, j)| format!("{}", self.ring.format((self.matrix)(i, j))))
37 .collect::<Vec<_>>();
38 let max_len = strings.iter().map(|s| s.chars().count()).chain([2]).max().unwrap();
39 let mut strings = strings.into_iter();
40 for i in 0..self.row_count {
41 write!(f, "|")?;
42 if self.col_count > 0 {
43 write!(f, "{:>width$}", strings.next().unwrap(), width = max_len)?;
44 }
45 for _ in 1..self.col_count {
46 write!(f, ",{:>width$}", strings.next().unwrap(), width = max_len)?;
47 }
48 if i + 1 != self.row_count {
49 writeln!(f, "|")?;
50 } else {
51 write!(f, "|")?;
52 }
53 }
54 return Ok(());
55 }
56 }
57
58 DisplayWrapper {
59 matrix,
60 ring,
61 col_count,
62 row_count,
63 }
64}
65
66pub mod matrix_compare {
70 use super::*;
71
72 pub trait MatrixCompare<T> {
75 fn row_count(&self) -> usize;
77
78 fn col_count(&self) -> usize;
80
81 fn at(&self, i: usize, j: usize) -> &T;
83 }
84
85 impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [[T; COLS]; ROWS] {
86 fn col_count(&self) -> usize { COLS }
87 fn row_count(&self) -> usize { ROWS }
88 fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
89 }
90
91 impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [DerefArray<T, COLS>; ROWS] {
92 fn col_count(&self) -> usize { COLS }
93 fn row_count(&self) -> usize { ROWS }
94 fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
95 }
96
97 impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for Submatrix<'a, V, T> {
98 fn col_count(&self) -> usize { Submatrix::col_count(self) }
99 fn row_count(&self) -> usize { Submatrix::row_count(self) }
100 fn at(&self, i: usize, j: usize) -> &T { Submatrix::at(self, i, j) }
101 }
102
103 impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for SubmatrixMut<'a, V, T> {
104 fn col_count(&self) -> usize { SubmatrixMut::col_count(self) }
105 fn row_count(&self) -> usize { SubmatrixMut::row_count(self) }
106 fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
107 }
108
109 impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T>
110 for TransposableSubmatrix<'a, V, T, TRANSPOSED>
111 {
112 fn col_count(&self) -> usize { TransposableSubmatrix::col_count(self) }
113 fn row_count(&self) -> usize { TransposableSubmatrix::row_count(self) }
114 fn at(&self, i: usize, j: usize) -> &T { TransposableSubmatrix::at(self, i, j) }
115 }
116
117 impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T>
118 for TransposableSubmatrixMut<'a, V, T, TRANSPOSED>
119 {
120 fn col_count(&self) -> usize { TransposableSubmatrixMut::col_count(self) }
121 fn row_count(&self) -> usize { TransposableSubmatrixMut::row_count(self) }
122 fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
123 }
124
125 impl<T> MatrixCompare<T> for OwnedMatrix<T> {
126 fn col_count(&self) -> usize { OwnedMatrix::col_count(self) }
127 fn row_count(&self) -> usize { OwnedMatrix::row_count(self) }
128 fn at(&self, i: usize, j: usize) -> &T { OwnedMatrix::at(self, i, j) }
129 }
130
131 impl<T, M: MatrixCompare<T>> MatrixCompare<T> for &M {
132 fn col_count(&self) -> usize { (**self).col_count() }
133 fn row_count(&self) -> usize { (**self).row_count() }
134 fn at(&self, i: usize, j: usize) -> &T { (**self).at(i, j) }
135 }
136
137 pub fn is_matrix_eq<R: ?Sized + RingBase, M1: MatrixCompare<R::Element>, M2: MatrixCompare<R::Element>>(
142 ring: &R,
143 lhs: &M1,
144 rhs: &M2,
145 ) -> bool {
146 if lhs.row_count() != rhs.row_count() || lhs.col_count() != rhs.col_count() {
147 return false;
148 }
149 for i in 0..lhs.row_count() {
150 for j in 0..lhs.col_count() {
151 if !ring.eq_el(lhs.at(i, j), rhs.at(i, j)) {
152 return false;
153 }
154 }
155 }
156 return true;
157 }
158}
159
/// Asserts that two matrices are equal as matrices over the given ring.
///
/// Usage: `assert_matrix_eq!(ring, lhs, rhs)`; a trailing comma is accepted.
///
/// `lhs` and `rhs` may be of different types; both must implement
/// `matrix::matrix_compare::MatrixCompare` for the ring's element type. The check delegates
/// to `matrix::matrix_compare::is_matrix_eq()`, called on `ring.get_ring()`. Each of the
/// three arguments is evaluated exactly once.
///
/// # Panics
///
/// Panics if the shapes differ or some pair of entries is unequal. The panic message shows
/// both matrices rendered with `matrix::format_matrix()`.
#[macro_export]
macro_rules! assert_matrix_eq {
    ($ring:expr, $lhs:expr, $rhs:expr $(,)?) => {
        // Binding the arguments through `match` evaluates each expression exactly once and
        // keeps any temporaries alive for the whole assertion.
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    $crate::matrix::matrix_compare::is_matrix_eq(ring_val.get_ring(), lhs_val, rhs_val),
                    "Assertion failed: Expected\n{}\nto be\n{}",
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(lhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(lhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(lhs_val, i, j),
                        ring_val
                    ),
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(rhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(rhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(rhs_val, i, j),
                        ring_val
                    )
                );
            }
        }
    };
}