feanor_math/matrix/
mod.rs

1use std::fmt::Display;
2
3use crate::ring::*;
4
5mod submatrix;
6mod transpose;
7mod owned;
8
9pub use submatrix::*;
10#[allow(unused_imports)]
11pub use transpose::*;
12pub use owned::*;
13
14pub mod transform;
15
16#[stability::unstable(feature = "enable")]
17pub fn format_matrix<'a, M, R>(row_count: usize, col_count: usize, matrix: M, ring: R) -> impl 'a + Display
18    where R: 'a + RingStore, 
19        El<R>: 'a,
20        M: 'a + Fn(usize, usize) -> &'a El<R>
21{
22    struct DisplayWrapper<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> {
23        matrix: M,
24        ring: R,
25        row_count: usize,
26        col_count: usize
27    }
28
29    impl<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> Display for DisplayWrapper<'a, R, M> {
30
31        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32            let strings = (0..self.row_count).flat_map(|i| (0..self.col_count).map(move |j| (i, j)))
33                .map(|(i, j)| format!("{}", self.ring.format((self.matrix)(i, j))))
34                .collect::<Vec<_>>();
35            let max_len = strings.iter().map(|s| s.chars().count()).chain([2].into_iter()).max().unwrap();
36            let mut strings = strings.into_iter();
37            for i in 0..self.row_count {
38                write!(f, "|")?;
39                if self.col_count > 0 {
40                    write!(f, "{:>width$}", strings.next().unwrap(), width = max_len)?;
41                }
42                for _ in 1..self.col_count {
43                    write!(f, ",{:>width$}", strings.next().unwrap(), width = max_len)?;
44                }
45                if i + 1 != self.row_count {
46                    writeln!(f, "|")?;
47                } else {
48                    write!(f, "|")?;
49                }
50            }
51            return Ok(());
52        }
53    }
54
55    DisplayWrapper { matrix, ring, col_count, row_count }
56}
57
58///
59/// Defines a very general way of comparing anything that can be interpreted as a
60/// matrix in [`matrix_compare::MatrixCompare`]. Used solely in [`crate::assert_matrix_eq`],
61/// with tests being the primary use case.
62/// 
63pub mod matrix_compare {
64    use super::*;
65
66    ///
67    /// Used by [`crate::assert_matrix_eq`] to compare objects that are matrices in a
68    /// very general sense.
69    /// 
70    pub trait MatrixCompare<T> {
71        fn row_count(&self) -> usize;
72        fn col_count(&self) -> usize;
73        fn at(&self, i: usize, j: usize) -> &T;
74    }
75
76    impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [[T; COLS]; ROWS] {
77
78        fn col_count(&self) -> usize { COLS }
79        fn row_count(&self) -> usize { ROWS }
80        fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
81    }
82
83    impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [DerefArray<T, COLS>; ROWS] {
84
85        fn col_count(&self) -> usize { COLS }
86        fn row_count(&self) -> usize { ROWS }
87        fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
88    }
89
90    impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for Submatrix<'a, V, T> {
91
92        fn col_count(&self) -> usize { Submatrix::col_count(self) }
93        fn row_count(&self) -> usize { Submatrix::row_count(self) }
94        fn at(&self, i: usize, j: usize) -> &T { Submatrix::at(self, i, j) }
95    }
96
97    impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for SubmatrixMut<'a, V, T> {
98
99        fn col_count(&self) -> usize { SubmatrixMut::col_count(self) }
100        fn row_count(&self) -> usize { SubmatrixMut::row_count(self) }
101        fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
102    }
103
104    impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrix<'a, V, T, TRANSPOSED> {
105
106        fn col_count(&self) -> usize { TransposableSubmatrix::col_count(self) }
107        fn row_count(&self) -> usize { TransposableSubmatrix::row_count(self) }
108        fn at(&self, i: usize, j: usize) -> &T { TransposableSubmatrix::at(self, i, j) }
109    }
110
111    impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrixMut<'a, V, T, TRANSPOSED> {
112
113        fn col_count(&self) -> usize { TransposableSubmatrixMut::col_count(self) }
114        fn row_count(&self) -> usize { TransposableSubmatrixMut::row_count(self) }
115        fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
116    }
117
118    impl<T> MatrixCompare<T> for OwnedMatrix<T> {
119
120        fn col_count(&self) -> usize { OwnedMatrix::col_count(self) }
121        fn row_count(&self) -> usize { OwnedMatrix::row_count(self) }
122        fn at(&self, i: usize, j: usize) -> &T { OwnedMatrix::at(self, i, j) }
123    }
124
125    impl<'a, T, M: MatrixCompare<T>> MatrixCompare<T> for &'a M {
126        
127        fn col_count(&self) -> usize { (**self).col_count() }
128        fn row_count(&self) -> usize { (**self).row_count() }
129        fn at(&self, i: usize, j: usize) -> &T { (**self).at(i, j) }
130    }
131
132    pub fn is_matrix_eq<R: ?Sized + RingBase, M1: MatrixCompare<R::Element>, M2: MatrixCompare<R::Element>>(ring: &R, lhs: &M1, rhs: &M2) -> bool {
133        if lhs.row_count() != rhs.row_count() || lhs.col_count() != rhs.col_count() {
134            return false;
135        }
136        for i in 0..lhs.row_count() {
137            for j in 0..lhs.col_count() {
138                if !ring.eq_el(lhs.at(i, j), rhs.at(i, j)) {
139                    return false;
140                }
141            }
142        }
143        return true;
144    }
145}
146
///
/// Variant of `assert_eq!` for matrices, i.e. asserts that two matrices over a ring are equal.
/// Frequently used in tests.
///
/// The matrices are compared by passing the result of `get_ring()` on the given ring, together
/// with references to both matrices, to `matrix::matrix_compare::is_matrix_eq()`. If the
/// comparison fails, this panics with a message that contains both matrices, as formatted by
/// `matrix::format_matrix()`.
///
/// As for `assert_eq!`, a trailing comma after the last argument is accepted.
///
/// ```
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::matrix::*;
/// # use feanor_math::assert_matrix_eq;
/// let lhs = [[0, 0, 0], [0, 0, 0], [0, 0, 0]];
/// let rhs = OwnedMatrix::zero(3, 3, StaticRing::<i64>::RING);
/// assert_matrix_eq!(StaticRing::<i64>::RING, lhs, rhs);
/// ```
///
#[macro_export]
macro_rules! assert_matrix_eq {
    ($ring:expr, $lhs:expr, $rhs:expr $(,)?) => {
        // matching on a tuple of references evaluates every argument exactly once, and
        // keeps temporaries alive for the whole assertion
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    $crate::matrix::matrix_compare::is_matrix_eq(ring_val.get_ring(), lhs_val, rhs_val), 
                    "Assertion failed: Expected\n{}\nto be\n{}", 
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(lhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(lhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(lhs_val, i, j),
                        ring_val
                    ),
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(rhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(rhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(rhs_val, i, j),
                        ring_val
                    )
                );
            }
        }
    }
}
175}