Skip to main content

feanor_math/matrix/
transform.rs

1use std::marker::PhantomData;
2
3use crate::ring::*;
4
5use super::{AsPointerToSlice, OwnedMatrix, SubmatrixMut};
6
7///
8/// A trait for a "target" that can "consume" elementary operations on matrices.
9///  
10/// This is mainly used during algorithms that work on matrices, since in many cases
11/// they transform matrices using elementary row or column operations, and have to
12/// accumulate data depending on these operations.
13/// 
14pub trait TransformTarget<R>
15    where R: ?Sized + RingBase
16{
17    ///
18    /// The transformation given by the matrix `A` with `A[k, l]` being
19    ///  - `1` if `k = l` and `k != i, j`
20    ///  - `transform[0]` if `(k, l) = (i, i)`
21    ///  - `transform[1]` if `(k, l) = (i, j)`
22    ///  - `transform[2]` if `(k, l) = (j, i)`
23    ///  - `transform[3]` if `(k, l) = (j, j)`
24    ///  - `0` otherwise
25    /// 
26    /// In other words, the matrix looks like
27    /// ```text
28    /// | 1  ...  0                       |
29    /// | ⋮        ⋮                       |
30    /// | 0  ...  1                       |
31    /// |    A             B              | <- i-th row
32    /// |            1  ...  0            |
33    /// |            ⋮        ⋮            |
34    /// |            0  ...  1            |
35    /// |    C             D              | <- j-th row
36    /// |                       1  ...  0 |
37    /// |                       ⋮        ⋮ |
38    /// |                       0  ...  1 |
39    ///      ^ i-th col    ^ j-th col
40    /// ```
41    /// where `transform = [A, B, C, D]`.
42    /// 
43    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]);
44
45    ///
46    /// The transformation corresponding to subtracting `factor` times the `src`-th row
47    /// resp. col from the `dst`-th row resp. col.
48    /// 
49    /// More precisely, the `(k, l)`-th entry of the transform matrix is defined to be
50    ///  - `1` if `k == l`
51    ///  - `-factor` if `k == dst, l == src`
52    ///  - `0` otherwise
53    /// 
54    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
55        self.transform(ring, src, dst, &[ring.one(), ring.zero(), ring.negate(ring.clone_el(factor)), ring.one()])
56    }
57
58    ///
59    /// The transformation corresponding to the permutation matrix swapping `i`-th and `j`-th row
60    /// resp. column.
61    /// 
62    /// More precisely, the `(k, l)`-th entry of the transform matrix is defined to be
63    ///  - `1` if `k == l, k != i, k != j`
64    ///  - `1` if `k == i, l == j`
65    ///  - `1` if `k == j, l == i`
66    ///  - `0` otherwise
67    /// 
68    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
69        self.transform(ring, i, j, &[ring.zero(), ring.one(), ring.one(), ring.zero()])
70    }
71}
72
73impl<'a, T, R> TransformTarget<R> for &'a mut T
74    where R: ?Sized + RingBase,
75        T: TransformTarget<R>
76{
77    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]) {
78        <T as TransformTarget<R>>::transform(*self, ring, i, j, transform)
79    }
80
81    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
82        <T as TransformTarget<R>>::subtract(*self, ring, src, dst, factor)
83    }
84
85    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
86        <T as TransformTarget<R>>::swap(*self, ring, i, j)
87    }
88}
89
90///
91/// Wraps a [`SubmatrixMut`] to get a [`TransformTarget`]. Every transform is multiplied to
92/// the wrapped matrix from the left, i.e. applied to the rows of the matrix.
93/// 
94/// TODO: at next breaking release, remove the reference to the ring
95/// 
96pub struct TransformRows<'a, V, R>(pub SubmatrixMut<'a, V, R::Element>, pub &'a R)
97    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase;
98
99///
100/// Wraps a [`SubmatrixMut`] to get a [`TransformTarget`]. Every transform is multiplied to
101/// the wrapped matrix from the right, i.e. applied to the cols of the matrix.
102/// 
103/// TODO: at next breaking release, remove the reference to the ring
104/// 
105pub struct TransformCols<'a, V, R>(pub SubmatrixMut<'a, V, R::Element>, pub &'a R)
106    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase;
107
108impl<'a, V, R> TransformTarget<R> for TransformRows<'a, V, R>
109    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase
110{
111    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[<R as RingBase>::Element; 4]) {
112        assert!(ring.get_ring() == self.1);
113        let A = &mut self.0;
114        for l in 0..A.col_count() {
115            let (new_i, new_j) = (
116                ring.add(ring.mul_ref(A.at(i, l), &transform[0]), ring.mul_ref(A.at(j, l), &transform[1])),
117                ring.add(ring.mul_ref(A.at(i, l), &transform[2]), ring.mul_ref(A.at(j, l), &transform[3]))
118            );
119            *A.at_mut(i, l) = new_i;
120            *A.at_mut(j, l) = new_j;
121        }
122    }
123
124    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &<R as RingBase>::Element) {
125        assert!(ring.get_ring() == self.1);
126        let A = &mut self.0;
127        for j in 0..A.col_count() {
128            let to_sub = ring.mul_ref(factor, A.at(src, j));
129            ring.sub_assign(A.at_mut(dst, j), to_sub);
130        }
131    }
132}
133
134impl<'a, V, R> TransformTarget<R> for TransformCols<'a, V, R>
135    where V: AsPointerToSlice<R::Element>, R: ?Sized + RingBase
136{
137    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[<R as RingBase>::Element; 4]) {
138        assert!(ring.get_ring() == self.1);
139        let A = &mut self.0;
140        for l in 0..A.row_count() {
141            let (new_i, new_j) = (
142                ring.add(ring.mul_ref(A.at(l, i), &transform[0]), ring.mul_ref(A.at(l, j), &transform[1])),
143                ring.add(ring.mul_ref(A.at(l, i), &transform[2]), ring.mul_ref(A.at(l, j), &transform[3]))
144            );
145            *A.at_mut(l, i) = new_i;
146            *A.at_mut(l, j) = new_j;
147        }
148    }
149
150    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &<R as RingBase>::Element) {
151        assert!(ring.get_ring() == self.1);
152        let A = &mut self.0;
153        for i in 0..A.row_count() {
154            let to_sub = ring.mul_ref(factor, A.at(i, src));
155            ring.sub_assign(A.at_mut(i, dst), to_sub);
156        }
157    }
158}
159
160enum Transform<R>
161    where R: ?Sized + RingBase
162{
163    General(usize, usize, [R::Element; 4]),
164    Subtract(usize, usize, R::Element),
165    Swap(usize, usize)
166}
167
168#[stability::unstable(feature = "enable")]
169pub struct TransformList<R>
170    where R: ?Sized + RingBase
171{
172    transforms: Vec<Transform<R>>,
173    row_count: usize
174}
175
176impl<R> TransformList<R>
177    where R: ?Sized + RingBase
178{    
179    #[stability::unstable(feature = "enable")]
180    pub fn new(row_count: usize) -> Self {
181        Self {
182            row_count: row_count,
183            transforms: Vec::new()
184        }
185    }
186
187    #[stability::unstable(feature = "enable")]
188    pub fn replay<S: Copy + RingStore<Type = R>, T: TransformTarget<R>>(&self, ring: S, mut target: T) {
189        for transform in &self.transforms {
190            match transform {
191                Transform::General(i, j, matrix) => target.transform(ring, *i, *j, matrix),
192                Transform::Subtract(src, dst, factor) => target.subtract(ring, *src, *dst, factor),
193                Transform::Swap(i, j) => target.swap(ring, *i, *j)
194            }
195        }
196    }
197
198    #[stability::unstable(feature = "enable")]
199    pub fn replay_transposed<S: Copy + RingStore<Type = R>, T: TransformTarget<R>>(&self, ring: S, mut target: T) {
200        for transform in self.transforms.iter().rev() {
201            match transform {
202                Transform::General(i, j, matrix) => {
203                    target.transform(ring, *i, *j, &[
204                        ring.clone_el(&matrix[0]),
205                        ring.clone_el(&matrix[2]),
206                        ring.clone_el(&matrix[1]),
207                        ring.clone_el(&matrix[3])
208                    ])
209                },
210                Transform::Subtract(src, dst, factor) => target.subtract(ring, *dst, *src, factor),
211                Transform::Swap(i, j) => target.swap(ring, *i, *j)
212            }
213        }
214    }
215
216    #[stability::unstable(feature = "enable")]
217    pub fn to_matrix<S: Copy + RingStore<Type = R>>(&self, ring: S) -> OwnedMatrix<R::Element> {
218        let mut result = OwnedMatrix::identity(self.row_count, self.row_count, ring);
219        self.replay(ring, TransformRows(result.data_mut(), ring.get_ring()));
220        return result;
221    }
222}
223
224impl<R> TransformTarget<R> for TransformList<R>
225    where R: ?Sized + RingBase
226{
227    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[<R as RingBase>::Element; 4]) {
228        debug_assert!(i < self.row_count);
229        debug_assert!(j < self.row_count);
230        self.transforms.push(Transform::General(i, j, std::array::from_fn(|k| ring.clone_el(&transform[k]))))
231    }
232
233    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &<R as RingBase>::Element) {
234        debug_assert!(src < self.row_count);
235        debug_assert!(dst < self.row_count);
236        self.transforms.push(Transform::Subtract(src, dst, ring.clone_el(factor)))
237    }
238
239    fn swap<S: Copy + RingStore<Type = R>>(&mut self, _ring: S, i: usize, j: usize) {
240        debug_assert!(i < self.row_count);
241        debug_assert!(j < self.row_count);
242        self.transforms.push(Transform::Swap(i, j))
243    }
244}
245
246impl<R> TransformTarget<R> for ()
247    where R: ?Sized + RingBase
248{
249    fn transform<S: Copy + RingStore<Type = R>>(&mut self, _: S, _: usize, _: usize, _: &[R::Element; 4]) {}
250
251    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, _: S, _: usize, _: usize, _: &R::Element) {}
252
253    fn swap<S: Copy + RingStore<Type = R>>(&mut self, _: S, _: usize, _: usize) {}
254}
255
256///
257/// A [`TransformTarget`] that forwards all transforms to a fixed
258/// delegate, but offsets every row/column index by a given value.
259/// 
260pub struct OffsetTransformIndex<R, T>
261    where R: ?Sized + RingBase,
262        T: TransformTarget<R>
263{
264    delegate: T,
265    index_offset: usize,
266    ring: PhantomData<R>
267}
268
269impl<R, T> OffsetTransformIndex<R, T>
270    where R: ?Sized + RingBase,
271        T: TransformTarget<R>
272{
273    ///
274    /// Creates a new [`OffsetTransformIndex`] that forwards all transforms to `delegate`.
275    /// 
276    pub fn new(delegate: T, offset: usize) -> Self {
277        Self {
278            delegate: delegate, 
279            index_offset: offset, 
280            ring: PhantomData
281        }
282    }
283}
284
285impl<R, T> TransformTarget<R> for OffsetTransformIndex<R, T>
286    where R: ?Sized + RingBase,
287        T: TransformTarget<R>
288{
289    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]) {
290        <T as TransformTarget<R>>::transform(&mut self.delegate, ring, i + self.index_offset, j + self.index_offset, transform);
291    }
292
293    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
294        <T as TransformTarget<R>>::subtract(&mut self.delegate, ring, src + self.index_offset, dst + self.index_offset, factor);
295    }
296
297    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
298        <T as TransformTarget<R>>::swap(&mut self.delegate, ring, i + self.index_offset, j + self.index_offset);
299    }
300}
301
302///
303/// A [`TransformTarget`] that forwards all transforms to
304/// two fixed delegates.
305/// 
306pub struct DuplicateTransforms<R, T1, T2>
307    where R: ?Sized + RingBase,
308        T1: TransformTarget<R>,
309        T2: TransformTarget<R>
310{
311    delegate1: T1,
312    delegate2: T2,
313    ring: PhantomData<R>
314}
315
316impl<R, T1, T2> DuplicateTransforms<R, T1, T2>
317    where R: ?Sized + RingBase,
318        T1: TransformTarget<R>,
319        T2: TransformTarget<R>
320{
321    ///
322    /// Creates a new [`DuplicateTransforms`] that forwards all transforms to `first` and `second`.
323    /// 
324    pub fn new(first: T1, second: T2) -> Self {
325        Self {
326            delegate1: first,
327            delegate2: second, 
328            ring: PhantomData
329        }
330    }
331}
332
333impl<R, T1, T2> TransformTarget<R> for DuplicateTransforms<R, T1, T2>
334    where R: ?Sized + RingBase,
335        T1: TransformTarget<R>,
336        T2: TransformTarget<R>
337{
338    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]) {
339        <T1 as TransformTarget<R>>::transform(&mut self.delegate1, ring, i, j, transform);
340        <T2 as TransformTarget<R>>::transform(&mut self.delegate2, ring, i, j, transform);
341    }
342
343    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
344        <T1 as TransformTarget<R>>::subtract(&mut self.delegate1, ring, src, dst, factor);
345        <T2 as TransformTarget<R>>::subtract(&mut self.delegate2, ring, src, dst, factor);
346    }
347
348    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
349        <T1 as TransformTarget<R>>::swap(&mut self.delegate1, ring, i, j);
350        <T2 as TransformTarget<R>>::swap(&mut self.delegate2, ring, i, j);
351    }
352}