feanor_math/matrix/
mod.rs

1use std::fmt::Display;
2
3use crate::ring::*;
4
5mod submatrix;
6mod transpose;
7mod owned;
8
9pub use submatrix::*;
10#[allow(unused_imports)]
11pub use transpose::*;
12pub use owned::*;
13
14///
15/// Contains the trait [`transform::TransformTarget`], for "consumers" of elementary
16/// matrix operations.
17/// 
18pub mod transform;
19
20#[stability::unstable(feature = "enable")]
21pub fn format_matrix<'a, M, R>(row_count: usize, col_count: usize, matrix: M, ring: R) -> impl 'a + Display
22    where R: 'a + RingStore, 
23        El<R>: 'a,
24        M: 'a + Fn(usize, usize) -> &'a El<R>
25{
26    struct DisplayWrapper<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> {
27        matrix: M,
28        ring: R,
29        row_count: usize,
30        col_count: usize
31    }
32
33    impl<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> Display for DisplayWrapper<'a, R, M> {
34
35        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36            let strings = (0..self.row_count).flat_map(|i| (0..self.col_count).map(move |j| (i, j)))
37                .map(|(i, j)| format!("{}", self.ring.format((self.matrix)(i, j))))
38                .collect::<Vec<_>>();
39            let max_len = strings.iter().map(|s| s.chars().count()).chain([2].into_iter()).max().unwrap();
40            let mut strings = strings.into_iter();
41            for i in 0..self.row_count {
42                write!(f, "|")?;
43                if self.col_count > 0 {
44                    write!(f, "{:>width$}", strings.next().unwrap(), width = max_len)?;
45                }
46                for _ in 1..self.col_count {
47                    write!(f, ",{:>width$}", strings.next().unwrap(), width = max_len)?;
48                }
49                if i + 1 != self.row_count {
50                    writeln!(f, "|")?;
51                } else {
52                    write!(f, "|")?;
53                }
54            }
55            return Ok(());
56        }
57    }
58
59    DisplayWrapper { matrix, ring, col_count, row_count }
60}
61
62///
63/// Defines a very general way of comparing anything that can be interpreted as a
64/// matrix in [`matrix_compare::MatrixCompare`]. Used solely in [`crate::assert_matrix_eq`],
65/// with tests being the primary use case.
66/// 
67pub mod matrix_compare {
68    use super::*;
69
70    ///
71    /// Used by [`crate::assert_matrix_eq`] to compare objects that are matrices in a
72    /// very general sense.
73    /// 
74    pub trait MatrixCompare<T> {
75
76        ///
77        /// Returns the number of rows of the matrix.
78        /// 
79        fn row_count(&self) -> usize;
80
81        ///
82        /// Returns the number of columns of the matrix.
83        /// 
84        fn col_count(&self) -> usize;
85
86        ///
87        /// Returns a reference to the element at the given position.
88        /// 
89        fn at(&self, i: usize, j: usize) -> &T;
90    }
91
92    impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [[T; COLS]; ROWS] {
93
94        fn col_count(&self) -> usize { COLS }
95        fn row_count(&self) -> usize { ROWS }
96        fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
97    }
98
99    impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [DerefArray<T, COLS>; ROWS] {
100
101        fn col_count(&self) -> usize { COLS }
102        fn row_count(&self) -> usize { ROWS }
103        fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
104    }
105
106    impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for Submatrix<'a, V, T> {
107
108        fn col_count(&self) -> usize { Submatrix::col_count(self) }
109        fn row_count(&self) -> usize { Submatrix::row_count(self) }
110        fn at(&self, i: usize, j: usize) -> &T { Submatrix::at(self, i, j) }
111    }
112
113    impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for SubmatrixMut<'a, V, T> {
114
115        fn col_count(&self) -> usize { SubmatrixMut::col_count(self) }
116        fn row_count(&self) -> usize { SubmatrixMut::row_count(self) }
117        fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
118    }
119
120    impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrix<'a, V, T, TRANSPOSED> {
121
122        fn col_count(&self) -> usize { TransposableSubmatrix::col_count(self) }
123        fn row_count(&self) -> usize { TransposableSubmatrix::row_count(self) }
124        fn at(&self, i: usize, j: usize) -> &T { TransposableSubmatrix::at(self, i, j) }
125    }
126
127    impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T> for TransposableSubmatrixMut<'a, V, T, TRANSPOSED> {
128
129        fn col_count(&self) -> usize { TransposableSubmatrixMut::col_count(self) }
130        fn row_count(&self) -> usize { TransposableSubmatrixMut::row_count(self) }
131        fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
132    }
133
134    impl<T> MatrixCompare<T> for OwnedMatrix<T> {
135
136        fn col_count(&self) -> usize { OwnedMatrix::col_count(self) }
137        fn row_count(&self) -> usize { OwnedMatrix::row_count(self) }
138        fn at(&self, i: usize, j: usize) -> &T { OwnedMatrix::at(self, i, j) }
139    }
140
141    impl<'a, T, M: MatrixCompare<T>> MatrixCompare<T> for &'a M {
142        
143        fn col_count(&self) -> usize { (**self).col_count() }
144        fn row_count(&self) -> usize { (**self).row_count() }
145        fn at(&self, i: usize, j: usize) -> &T { (**self).at(i, j) }
146    }
147
148    ///
149    /// Checks whether two matrices are equal.
150    /// 
151    /// The prime use case is [`crate::assert_matrix_eq!`], which should be applicable
152    /// to all types that are, in some sense, interpretable as matrices.
153    /// 
154    pub fn is_matrix_eq<R: ?Sized + RingBase, M1: MatrixCompare<R::Element>, M2: MatrixCompare<R::Element>>(ring: &R, lhs: &M1, rhs: &M2) -> bool {
155        if lhs.row_count() != rhs.row_count() || lhs.col_count() != rhs.col_count() {
156            return false;
157        }
158        for i in 0..lhs.row_count() {
159            for j in 0..lhs.col_count() {
160                if !ring.eq_el(lhs.at(i, j), rhs.at(i, j)) {
161                    return false;
162                }
163            }
164        }
165        return true;
166    }
167}
168
///
/// Variant of `assert_eq!` for matrices, i.e. asserts that two matrices with entries in
/// the given ring are equal. Frequently used in tests.
/// 
/// This takes any arguments that implement [`matrix_compare::MatrixCompare`], a very minimal trait
/// for matrices that is implemented for all types that are, in some sense, interpretable as matrices.
/// Two matrices are considered equal if they have the same shape and all corresponding entries are
/// equal w.r.t. the ring. The ring and both matrices are each evaluated exactly once and are only
/// borrowed, so they can still be used afterwards. On failure, the macro panics with a message
/// that shows both matrices.
/// 
/// As with `assert_eq!`, a trailing comma is accepted, and an optional format string with
/// arguments may follow the three arguments; it is then included in the panic message.
/// 
/// # Example
/// ```rust
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::matrix::*;
/// # use feanor_math::assert_matrix_eq;
/// let lhs = [[0, 0, 0], [0, 0, 0], [0, 0, 0]];
/// let rhs = OwnedMatrix::zero(3, 3, StaticRing::<i64>::RING);
/// assert_matrix_eq!(StaticRing::<i64>::RING, lhs, rhs);
/// assert_matrix_eq!(StaticRing::<i64>::RING, lhs, rhs, "zero matrices of size {} should be equal", 3);
/// ```
/// 
#[macro_export]
macro_rules! assert_matrix_eq {
    // Internal rule: builds the `Display` value for `$matrix_val` (a reference to a
    // `MatrixCompare` implementor) w.r.t. the ring `$ring_val`, for the panic messages.
    (@display $ring_val:expr, $matrix_val:expr) => {
        $crate::matrix::format_matrix(
            <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count($matrix_val),
            <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count($matrix_val),
            |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at($matrix_val, i, j),
            $ring_val
        )
    };
    ($ring:expr, $lhs:expr, $rhs:expr $(,)?) => {
        // Binding through `match` (the same pattern `assert_eq!` uses) evaluates each argument
        // exactly once and keeps any temporaries alive until the check is done.
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    $crate::matrix::matrix_compare::is_matrix_eq(ring_val.get_ring(), lhs_val, rhs_val),
                    "Assertion failed: Expected\n{}\nto be\n{}",
                    $crate::assert_matrix_eq!(@display ring_val, lhs_val),
                    $crate::assert_matrix_eq!(@display ring_val, rhs_val)
                );
            }
        }
    };
    ($ring:expr, $lhs:expr, $rhs:expr, $($arg:tt)+) => {
        // Same as above, but additionally reports the user-supplied message.
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    $crate::matrix::matrix_compare::is_matrix_eq(ring_val.get_ring(), lhs_val, rhs_val),
                    "Assertion failed: {}\nExpected\n{}\nto be\n{}",
                    format_args!($($arg)+),
                    $crate::assert_matrix_eq!(@display ring_val, lhs_val),
                    $crate::assert_matrix_eq!(@display ring_val, rhs_val)
                );
            }
        }
    };
}