// Source: feanor_math/matrix/mod.rs

1use std::fmt::Display;
2
3use crate::ring::*;
4
5mod owned;
6mod submatrix;
7mod transpose;
8
9pub use owned::*;
10pub use submatrix::*;
11#[allow(unused_imports)]
12pub use transpose::*;
13
14/// Contains the trait [`transform::TransformTarget`], for "consumers" of elementary
15/// matrix operations.
16pub mod transform;
17
18#[stability::unstable(feature = "enable")]
19pub fn format_matrix<'a, M, R>(row_count: usize, col_count: usize, matrix: M, ring: R) -> impl 'a + Display
20where
21    R: 'a + RingStore,
22    El<R>: 'a,
23    M: 'a + Fn(usize, usize) -> &'a El<R>,
24{
25    struct DisplayWrapper<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> {
26        matrix: M,
27        ring: R,
28        row_count: usize,
29        col_count: usize,
30    }
31
32    impl<'a, R: 'a + RingStore, M: Fn(usize, usize) -> &'a El<R>> Display for DisplayWrapper<'a, R, M> {
33        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34            let strings = (0..self.row_count)
35                .flat_map(|i| (0..self.col_count).map(move |j| (i, j)))
36                .map(|(i, j)| format!("{}", self.ring.format((self.matrix)(i, j))))
37                .collect::<Vec<_>>();
38            let max_len = strings.iter().map(|s| s.chars().count()).chain([2]).max().unwrap();
39            let mut strings = strings.into_iter();
40            for i in 0..self.row_count {
41                write!(f, "|")?;
42                if self.col_count > 0 {
43                    write!(f, "{:>width$}", strings.next().unwrap(), width = max_len)?;
44                }
45                for _ in 1..self.col_count {
46                    write!(f, ",{:>width$}", strings.next().unwrap(), width = max_len)?;
47                }
48                if i + 1 != self.row_count {
49                    writeln!(f, "|")?;
50                } else {
51                    write!(f, "|")?;
52                }
53            }
54            return Ok(());
55        }
56    }
57
58    DisplayWrapper {
59        matrix,
60        ring,
61        col_count,
62        row_count,
63    }
64}
65
66/// Defines a very general way of comparing anything that can be interpreted as a
67/// matrix in [`matrix_compare::MatrixCompare`]. Used solely in [`crate::assert_matrix_eq`],
68/// with tests being the primary use case.
69pub mod matrix_compare {
70    use super::*;
71
72    /// Used by [`crate::assert_matrix_eq`] to compare objects that are matrices in a
73    /// very general sense.
74    pub trait MatrixCompare<T> {
75        /// Returns the number of rows of the matrix.
76        fn row_count(&self) -> usize;
77
78        /// Returns the number of columns of the matrix.
79        fn col_count(&self) -> usize;
80
81        /// Returns a reference to the element at the given position.
82        fn at(&self, i: usize, j: usize) -> &T;
83    }
84
85    impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [[T; COLS]; ROWS] {
86        fn col_count(&self) -> usize { COLS }
87        fn row_count(&self) -> usize { ROWS }
88        fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
89    }
90
91    impl<T, const ROWS: usize, const COLS: usize> MatrixCompare<T> for [DerefArray<T, COLS>; ROWS] {
92        fn col_count(&self) -> usize { COLS }
93        fn row_count(&self) -> usize { ROWS }
94        fn at(&self, i: usize, j: usize) -> &T { &self[i][j] }
95    }
96
97    impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for Submatrix<'a, V, T> {
98        fn col_count(&self) -> usize { Submatrix::col_count(self) }
99        fn row_count(&self) -> usize { Submatrix::row_count(self) }
100        fn at(&self, i: usize, j: usize) -> &T { Submatrix::at(self, i, j) }
101    }
102
103    impl<'a, V: AsPointerToSlice<T>, T> MatrixCompare<T> for SubmatrixMut<'a, V, T> {
104        fn col_count(&self) -> usize { SubmatrixMut::col_count(self) }
105        fn row_count(&self) -> usize { SubmatrixMut::row_count(self) }
106        fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
107    }
108
109    impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T>
110        for TransposableSubmatrix<'a, V, T, TRANSPOSED>
111    {
112        fn col_count(&self) -> usize { TransposableSubmatrix::col_count(self) }
113        fn row_count(&self) -> usize { TransposableSubmatrix::row_count(self) }
114        fn at(&self, i: usize, j: usize) -> &T { TransposableSubmatrix::at(self, i, j) }
115    }
116
117    impl<'a, V: AsPointerToSlice<T>, T, const TRANSPOSED: bool> MatrixCompare<T>
118        for TransposableSubmatrixMut<'a, V, T, TRANSPOSED>
119    {
120        fn col_count(&self) -> usize { TransposableSubmatrixMut::col_count(self) }
121        fn row_count(&self) -> usize { TransposableSubmatrixMut::row_count(self) }
122        fn at(&self, i: usize, j: usize) -> &T { self.as_const().into_at(i, j) }
123    }
124
125    impl<T> MatrixCompare<T> for OwnedMatrix<T> {
126        fn col_count(&self) -> usize { OwnedMatrix::col_count(self) }
127        fn row_count(&self) -> usize { OwnedMatrix::row_count(self) }
128        fn at(&self, i: usize, j: usize) -> &T { OwnedMatrix::at(self, i, j) }
129    }
130
131    impl<T, M: MatrixCompare<T>> MatrixCompare<T> for &M {
132        fn col_count(&self) -> usize { (**self).col_count() }
133        fn row_count(&self) -> usize { (**self).row_count() }
134        fn at(&self, i: usize, j: usize) -> &T { (**self).at(i, j) }
135    }
136
137    /// Checks whether two matrices are equal.
138    ///
139    /// The prime use case is [`crate::assert_matrix_eq!`], which should be applicable
140    /// to all types that are, in some sense, interpretable as matrices.
141    pub fn is_matrix_eq<R: ?Sized + RingBase, M1: MatrixCompare<R::Element>, M2: MatrixCompare<R::Element>>(
142        ring: &R,
143        lhs: &M1,
144        rhs: &M2,
145    ) -> bool {
146        if lhs.row_count() != rhs.row_count() || lhs.col_count() != rhs.col_count() {
147            return false;
148        }
149        for i in 0..lhs.row_count() {
150            for j in 0..lhs.col_count() {
151                if !ring.eq_el(lhs.at(i, j), rhs.at(i, j)) {
152                    return false;
153                }
154            }
155        }
156        return true;
157    }
158}
159
/// Variant of `assert_eq!` for matrices over a ring, i.e. asserts that two ring matrices are
/// equal, comparing entries with the ring's `eq_el`. Frequently used in tests.
///
/// This takes any arguments that implement [`matrix_compare::MatrixCompare`], a very minimal trait
/// for matrices that is implemented for all types that are, in some sense, interpretable as
/// matrices. The two matrix arguments are only borrowed, so they remain usable afterwards.
///
/// Matrices of different shapes are never equal. On failure, both matrices are printed.
/// As with `assert_eq!`, a trailing comma is accepted. An optional custom message with
/// `format!`-style arguments may follow the two matrices; it is appended to the failure message.
///
/// # Panics
/// Panics if the two matrices differ in shape or in any entry.
///
/// # Example
/// ```rust
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::matrix::*;
/// # use feanor_math::assert_matrix_eq;
/// let lhs = [[0, 0, 0], [0, 0, 0], [0, 0, 0]];
/// let rhs = OwnedMatrix::zero(3, 3, StaticRing::<i64>::RING);
/// assert_matrix_eq!(StaticRing::<i64>::RING, lhs, rhs);
/// assert_matrix_eq!(StaticRing::<i64>::RING, lhs, rhs, "zero matrices of size {}", 3);
/// ```
#[macro_export]
macro_rules! assert_matrix_eq {
    ($ring:expr, $lhs:expr, $rhs:expr $(,)?) => {
        // Bind each argument once, by reference, so every expression is evaluated exactly
        // once and the matrices are not moved (mirrors `assert_eq!`).
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    // Fully qualified so callers do not need `RingStore` in scope.
                    $crate::matrix::matrix_compare::is_matrix_eq(
                        $crate::ring::RingStore::get_ring(ring_val),
                        lhs_val,
                        rhs_val
                    ),
                    "Assertion failed: Expected\n{}\nto be\n{}",
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(lhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(lhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(lhs_val, i, j),
                        ring_val
                    ),
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(rhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(rhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(rhs_val, i, j),
                        ring_val
                    )
                );
            }
        }
    };
    ($ring:expr, $lhs:expr, $rhs:expr, $($arg:tt)+) => {
        // Same as the arm above, with the user-supplied message appended on failure.
        match (&$ring, &$lhs, &$rhs) {
            (ring_val, lhs_val, rhs_val) => {
                assert!(
                    $crate::matrix::matrix_compare::is_matrix_eq(
                        $crate::ring::RingStore::get_ring(ring_val),
                        lhs_val,
                        rhs_val
                    ),
                    "Assertion failed: Expected\n{}\nto be\n{}\n{}",
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(lhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(lhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(lhs_val, i, j),
                        ring_val
                    ),
                    $crate::matrix::format_matrix(
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::row_count(rhs_val),
                        <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::col_count(rhs_val),
                        |i, j| <_ as $crate::matrix::matrix_compare::MatrixCompare<_>>::at(rhs_val, i, j),
                        ring_val
                    ),
                    format_args!($($arg)+)
                );
            }
        }
    };
}