feanor_math/matrix/
transform.rs

1use crate::ring::*;
2
3use super::{AsPointerToSlice, OwnedMatrix, SubmatrixMut};
4
5///
6/// A trait for a "target" that can "consume" elementary operations on matrices.
7///  
8/// This is mainly used during algorithms that work on matrices, since in many cases
9/// they transform matrices using elementary row or column operations, and have to
10/// accumulate data depending on these operations.
11/// 
12pub trait TransformTarget<R>
13    where R: ?Sized + RingBase
14{
15    ///
16    /// The transformation given by the matrix `A` with `A[k, l]` being
17    ///  - `1` if `k = l` and `k != i, j`
18    ///  - `transform[0]` if `(k, l) = (i, i)`
19    ///  - `transform[1]` if `(k, l) = (i, j)`
20    ///  - `transform[2]` if `(k, l) = (j, i)`
21    ///  - `transform[3]` if `(k, l) = (j, j)`
22    ///  - `0` otherwise
23    /// 
24    /// In other words, the matrix looks like
25    /// ```text
26    /// | 1  ...  0                       |
27    /// | ⋮        ⋮                       |
28    /// | 0  ...  1                       |
29    /// |    A             B              | <- i-th row
30    /// |            1  ...  0            |
31    /// |            ⋮        ⋮            |
32    /// |            0  ...  1            |
33    /// |    C             D              | <- j-th row
34    /// |                       1  ...  0 |
35    /// |                       ⋮        ⋮ |
36    /// |                       0  ...  1 |
37    ///      ^ i-th col    ^ j-th col
38    /// ```
39    /// where `transform = [A, B, C, D]`.
40    /// 
41    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]);
42
43    ///
44    /// The transformation corresponding to subtracting `factor` times the `src`-th row
45    /// resp. col from the `dst`-th row resp. col.
46    /// 
47    /// More precisely, the `(k, l)`-th entry of the transform matrix is defined to be
48    ///  - `1` if `k == l`
49    ///  - `-factor` if `k == dst, l == src`
50    ///  - `0` otherwise
51    /// 
52    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
53        self.transform(ring, src, dst, &[ring.one(), ring.zero(), ring.negate(ring.clone_el(factor)), ring.one()])
54    }
55
56    ///
57    /// The transformation corresponding to the permutation matrix swapping `i`-th and `j`-th row
58    /// resp. column.
59    /// 
60    /// More precisely, the `(k, l)`-th entry of the transform matrix is defined to be
61    ///  - `1` if `k == l, k != i, k != j`
62    ///  - `1` if `k == i, l == j`
63    ///  - `1` if `k == j, l == i`
64    ///  - `0` otherwise
65    /// 
66    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
67        self.transform(ring, i, j, &[ring.zero(), ring.one(), ring.one(), ring.zero()])
68    }
69}
70
71///
72/// Wraps a [`SubmatrixMut`] to get a [`TransformTarget`]. Every transform is multiplied to
73/// the wrapped matrix from the left, i.e. applied to the rows of the matrix.
74/// 
75pub struct TransformRows<'a, V, R>(pub SubmatrixMut<'a, V, R::Element>, pub &'a R)
76    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase;
77
78///
79/// Wraps a [`SubmatrixMut`] to get a [`TransformTarget`]. Every transform is multiplied to
80/// the wrapped matrix from the right, i.e. applied to the cols of the matrix.
81/// 
82pub struct TransformCols<'a, V, R>(pub SubmatrixMut<'a, V, R::Element>, pub &'a R)
83    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase;
84
85impl<'a, V, R> TransformTarget<R> for TransformRows<'a, V, R>
86    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase
87{
88    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[<R as RingBase>::Element; 4]) {
89        assert!(ring.get_ring() == self.1);
90        let A = &mut self.0;
91        for l in 0..A.col_count() {
92            let (new_i, new_j) = (
93                ring.add(ring.mul_ref(A.at(i, l), &transform[0]), ring.mul_ref(A.at(j, l), &transform[1])),
94                ring.add(ring.mul_ref(A.at(i, l), &transform[2]), ring.mul_ref(A.at(j, l), &transform[3]))
95            );
96            *A.at_mut(i, l) = new_i;
97            *A.at_mut(j, l) = new_j;
98        }
99    }
100
101    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &<R as RingBase>::Element) {
102        assert!(ring.get_ring() == self.1);
103        let A = &mut self.0;
104        for j in 0..A.col_count() {
105            let to_sub = ring.mul_ref(factor, A.at(src, j));
106            ring.sub_assign(A.at_mut(dst, j), to_sub);
107        }
108    }
109}
110
111impl<'a, V, R> TransformTarget<R> for TransformCols<'a, V, R>
112    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase
113{
114    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[<R as RingBase>::Element; 4]) {
115        assert!(ring.get_ring() == self.1);
116        let A = &mut self.0;
117        for l in 0..A.row_count() {
118            let (new_i, new_j) = (
119                ring.add(ring.mul_ref(A.at(l, i), &transform[0]), ring.mul_ref(A.at(l, j), &transform[1])),
120                ring.add(ring.mul_ref(A.at(l, i), &transform[2]), ring.mul_ref(A.at(l, j), &transform[3]))
121            );
122            *A.at_mut(l, i) = new_i;
123            *A.at_mut(l, j) = new_j;
124        }
125    }
126
127    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &<R as RingBase>::Element) {
128        assert!(ring.get_ring() == self.1);
129        let A = &mut self.0;
130        for i in 0..A.row_count() {
131            let to_sub = ring.mul_ref(factor, A.at(i, src));
132            ring.sub_assign(A.at_mut(i, dst), to_sub);
133        }
134    }
135}
136
137enum Transform<R>
138    where R: ?Sized + RingBase
139{
140    General(usize, usize, [R::Element; 4]),
141    Subtract(usize, usize, R::Element),
142    Swap(usize, usize)
143}
144
145#[stability::unstable(feature = "enable")]
146pub struct TransformList<R>
147    where R: ?Sized + RingBase
148{
149    transforms: Vec<Transform<R>>,
150    row_count: usize
151}
152
153impl<R> TransformList<R>
154    where R: ?Sized + RingBase
155{    
156    #[stability::unstable(feature = "enable")]
157    pub fn new(row_count: usize) -> Self {
158        Self {
159            row_count: row_count,
160            transforms: Vec::new()
161        }
162    }
163
164    #[stability::unstable(feature = "enable")]
165    pub fn replay<S: Copy + RingStore<Type = R>, T: TransformTarget<R>>(&self, ring: S, mut target: T) {
166        for transform in &self.transforms {
167            match transform {
168                Transform::General(i, j, matrix) => target.transform(ring, *i, *j, matrix),
169                Transform::Subtract(src, dst, factor) => target.subtract(ring, *src, *dst, factor),
170                Transform::Swap(i, j) => target.swap(ring, *i, *j)
171            }
172        }
173    }
174
175    #[stability::unstable(feature = "enable")]
176    pub fn replay_transposed<S: Copy + RingStore<Type = R>, T: TransformTarget<R>>(&self, ring: S, mut target: T) {
177        for transform in self.transforms.iter().rev() {
178            match transform {
179                Transform::General(i, j, matrix) => {
180                    target.transform(ring, *i, *j, &[
181                        ring.clone_el(&matrix[0]),
182                        ring.clone_el(&matrix[2]),
183                        ring.clone_el(&matrix[1]),
184                        ring.clone_el(&matrix[3])
185                    ])
186                },
187                Transform::Subtract(src, dst, factor) => target.subtract(ring, *dst, *src, factor),
188                Transform::Swap(i, j) => target.swap(ring, *i, *j)
189            }
190        }
191    }
192
193    #[stability::unstable(feature = "enable")]
194    pub fn to_matrix<S: Copy + RingStore<Type = R>>(&self, ring: S) -> OwnedMatrix<R::Element> {
195        let mut result = OwnedMatrix::identity(self.row_count, self.row_count, ring);
196        self.replay(ring, TransformRows(result.data_mut(), ring.get_ring()));
197        return result;
198    }
199}
200
201impl<R> TransformTarget<R> for TransformList<R>
202    where R: ?Sized + RingBase
203{
204    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[<R as RingBase>::Element; 4]) {
205        debug_assert!(i < self.row_count);
206        debug_assert!(j < self.row_count);
207        self.transforms.push(Transform::General(i, j, std::array::from_fn(|k| ring.clone_el(&transform[k]))))
208    }
209
210    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &<R as RingBase>::Element) {
211        debug_assert!(src < self.row_count);
212        debug_assert!(dst < self.row_count);
213        self.transforms.push(Transform::Subtract(src, dst, ring.clone_el(factor)))
214    }
215
216    fn swap<S: Copy + RingStore<Type = R>>(&mut self, _ring: S, i: usize, j: usize) {
217        debug_assert!(i < self.row_count);
218        debug_assert!(j < self.row_count);
219        self.transforms.push(Transform::Swap(i, j))
220    }
221}