Skip to main content

feanor_math/matrix/
transform.rs

1use std::marker::PhantomData;
2
3use super::{AsPointerToSlice, OwnedMatrix, SubmatrixMut};
4use crate::ring::*;
5
6/// A trait for a "target" that can "consume" elementary operations on matrices.
7///  
8/// This is mainly used during algorithms that work on matrices, since in many cases
9/// they transform matrices using elementary row or column operations, and have to
10/// accumulate data depending on these operations.
11pub trait TransformTarget<R>
12where
13    R: ?Sized + RingBase,
14{
15    /// The transformation given by the matrix `A` with `A[k, l]` being
16    ///  - `1` if `k = l` and `k != i, j`
17    ///  - `transform[0]` if `(k, l) = (i, i)`
18    ///  - `transform[1]` if `(k, l) = (i, j)`
19    ///  - `transform[2]` if `(k, l) = (j, i)`
20    ///  - `transform[3]` if `(k, l) = (j, j)`
21    ///  - `0` otherwise
22    ///
23    /// In other words, the matrix looks like
24    /// ```text
25    /// | 1  ...  0                       |
26    /// | ⋮        ⋮                       |
27    /// | 0  ...  1                       |
28    /// |    A             B              | <- i-th row
29    /// |            1  ...  0            |
30    /// |            ⋮        ⋮            |
31    /// |            0  ...  1            |
32    /// |    C             D              | <- j-th row
33    /// |                       1  ...  0 |
34    /// |                       ⋮        ⋮ |
35    /// |                       0  ...  1 |
36    ///      ^ i-th col    ^ j-th col
37    /// ```
38    /// where `transform = [A, B, C, D]`.
39    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]);
40
41    /// The transformation corresponding to subtracting `factor` times the `src`-th row
42    /// resp. col from the `dst`-th row resp. col.
43    ///
44    /// More precisely, the `(k, l)`-th entry of the transform matrix is defined to be
45    ///  - `1` if `k == l`
46    ///  - `-factor` if `k == dst, l == src`
47    ///  - `0` otherwise
48    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
49        self.transform(
50            ring,
51            src,
52            dst,
53            &[ring.one(), ring.zero(), ring.negate(ring.clone_el(factor)), ring.one()],
54        )
55    }
56
57    /// The transformation corresponding to the permutation matrix swapping `i`-th and `j`-th row
58    /// resp. column.
59    ///
60    /// More precisely, the `(k, l)`-th entry of the transform matrix is defined to be
61    ///  - `1` if `k == l, k != i, k != j`
62    ///  - `1` if `k == i, l == j`
63    ///  - `1` if `k == j, l == i`
64    ///  - `0` otherwise
65    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
66        self.transform(ring, i, j, &[ring.zero(), ring.one(), ring.one(), ring.zero()])
67    }
68}
69
70impl<T, R> TransformTarget<R> for &mut T
71where
72    R: ?Sized + RingBase,
73    T: TransformTarget<R>,
74{
75    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]) {
76        <T as TransformTarget<R>>::transform(*self, ring, i, j, transform)
77    }
78
79    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
80        <T as TransformTarget<R>>::subtract(*self, ring, src, dst, factor)
81    }
82
83    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
84        <T as TransformTarget<R>>::swap(*self, ring, i, j)
85    }
86}
87
88/// Wraps a [`SubmatrixMut`] to get a [`TransformTarget`]. Every transform is multiplied to
89/// the wrapped matrix from the left, i.e. applied to the rows of the matrix.
90///
91/// TODO: at next breaking release, remove the reference to the ring
92pub struct TransformRows<'a, V, R>(pub SubmatrixMut<'a, V, R::Element>, pub &'a R)
93where
94    V: AsPointerToSlice<R::Element>,
95    R: ?Sized + RingBase;
96
97/// Wraps a [`SubmatrixMut`] to get a [`TransformTarget`]. Every transform is multiplied to
98/// the wrapped matrix from the right, i.e. applied to the cols of the matrix.
99///
100/// TODO: at next breaking release, remove the reference to the ring
101pub struct TransformCols<'a, V, R>(pub SubmatrixMut<'a, V, R::Element>, pub &'a R)
102where
103    V: AsPointerToSlice<R::Element>,
104    R: ?Sized + RingBase;
105
106impl<'a, V, R> TransformTarget<R> for TransformRows<'a, V, R>
107where
108    V: AsPointerToSlice<R::Element>,
109    R: ?Sized + RingBase,
110{
111    fn transform<S: Copy + RingStore<Type = R>>(
112        &mut self,
113        ring: S,
114        i: usize,
115        j: usize,
116        transform: &[<R as RingBase>::Element; 4],
117    ) {
118        assert!(ring.get_ring() == self.1);
119        let A = &mut self.0;
120        for l in 0..A.col_count() {
121            let (new_i, new_j) = (
122                ring.add(
123                    ring.mul_ref(A.at(i, l), &transform[0]),
124                    ring.mul_ref(A.at(j, l), &transform[1]),
125                ),
126                ring.add(
127                    ring.mul_ref(A.at(i, l), &transform[2]),
128                    ring.mul_ref(A.at(j, l), &transform[3]),
129                ),
130            );
131            *A.at_mut(i, l) = new_i;
132            *A.at_mut(j, l) = new_j;
133        }
134    }
135
136    fn subtract<S: Copy + RingStore<Type = R>>(
137        &mut self,
138        ring: S,
139        src: usize,
140        dst: usize,
141        factor: &<R as RingBase>::Element,
142    ) {
143        assert!(ring.get_ring() == self.1);
144        let A = &mut self.0;
145        for j in 0..A.col_count() {
146            let to_sub = ring.mul_ref(factor, A.at(src, j));
147            ring.sub_assign(A.at_mut(dst, j), to_sub);
148        }
149    }
150}
151
152impl<'a, V, R> TransformTarget<R> for TransformCols<'a, V, R>
153where
154    V: AsPointerToSlice<R::Element>,
155    R: ?Sized + RingBase,
156{
157    fn transform<S: Copy + RingStore<Type = R>>(
158        &mut self,
159        ring: S,
160        i: usize,
161        j: usize,
162        transform: &[<R as RingBase>::Element; 4],
163    ) {
164        assert!(ring.get_ring() == self.1);
165        let A = &mut self.0;
166        for l in 0..A.row_count() {
167            let (new_i, new_j) = (
168                ring.add(
169                    ring.mul_ref(A.at(l, i), &transform[0]),
170                    ring.mul_ref(A.at(l, j), &transform[1]),
171                ),
172                ring.add(
173                    ring.mul_ref(A.at(l, i), &transform[2]),
174                    ring.mul_ref(A.at(l, j), &transform[3]),
175                ),
176            );
177            *A.at_mut(l, i) = new_i;
178            *A.at_mut(l, j) = new_j;
179        }
180    }
181
182    fn subtract<S: Copy + RingStore<Type = R>>(
183        &mut self,
184        ring: S,
185        src: usize,
186        dst: usize,
187        factor: &<R as RingBase>::Element,
188    ) {
189        assert!(ring.get_ring() == self.1);
190        let A = &mut self.0;
191        for i in 0..A.row_count() {
192            let to_sub = ring.mul_ref(factor, A.at(i, src));
193            ring.sub_assign(A.at_mut(i, dst), to_sub);
194        }
195    }
196}
197
198enum Transform<R>
199where
200    R: ?Sized + RingBase,
201{
202    General(usize, usize, [R::Element; 4]),
203    Subtract(usize, usize, R::Element),
204    Swap(usize, usize),
205}
206
207#[stability::unstable(feature = "enable")]
208pub struct TransformList<R>
209where
210    R: ?Sized + RingBase,
211{
212    transforms: Vec<Transform<R>>,
213    row_count: usize,
214}
215
216impl<R> TransformList<R>
217where
218    R: ?Sized + RingBase,
219{
220    #[stability::unstable(feature = "enable")]
221    pub fn new(row_count: usize) -> Self {
222        Self {
223            row_count,
224            transforms: Vec::new(),
225        }
226    }
227
228    #[stability::unstable(feature = "enable")]
229    pub fn replay<S: Copy + RingStore<Type = R>, T: TransformTarget<R>>(&self, ring: S, mut target: T) {
230        for transform in &self.transforms {
231            match transform {
232                Transform::General(i, j, matrix) => target.transform(ring, *i, *j, matrix),
233                Transform::Subtract(src, dst, factor) => target.subtract(ring, *src, *dst, factor),
234                Transform::Swap(i, j) => target.swap(ring, *i, *j),
235            }
236        }
237    }
238
239    #[stability::unstable(feature = "enable")]
240    pub fn replay_transposed<S: Copy + RingStore<Type = R>, T: TransformTarget<R>>(&self, ring: S, mut target: T) {
241        for transform in self.transforms.iter().rev() {
242            match transform {
243                Transform::General(i, j, matrix) => target.transform(
244                    ring,
245                    *i,
246                    *j,
247                    &[
248                        ring.clone_el(&matrix[0]),
249                        ring.clone_el(&matrix[2]),
250                        ring.clone_el(&matrix[1]),
251                        ring.clone_el(&matrix[3]),
252                    ],
253                ),
254                Transform::Subtract(src, dst, factor) => target.subtract(ring, *dst, *src, factor),
255                Transform::Swap(i, j) => target.swap(ring, *i, *j),
256            }
257        }
258    }
259
260    #[stability::unstable(feature = "enable")]
261    pub fn to_matrix<S: Copy + RingStore<Type = R>>(&self, ring: S) -> OwnedMatrix<R::Element> {
262        let mut result = OwnedMatrix::identity(self.row_count, self.row_count, ring);
263        self.replay(ring, TransformRows(result.data_mut(), ring.get_ring()));
264        return result;
265    }
266}
267
268impl<R> TransformTarget<R> for TransformList<R>
269where
270    R: ?Sized + RingBase,
271{
272    fn transform<S: Copy + RingStore<Type = R>>(
273        &mut self,
274        ring: S,
275        i: usize,
276        j: usize,
277        transform: &[<R as RingBase>::Element; 4],
278    ) {
279        debug_assert!(i < self.row_count);
280        debug_assert!(j < self.row_count);
281        self.transforms.push(Transform::General(
282            i,
283            j,
284            std::array::from_fn(|k| ring.clone_el(&transform[k])),
285        ))
286    }
287
288    fn subtract<S: Copy + RingStore<Type = R>>(
289        &mut self,
290        ring: S,
291        src: usize,
292        dst: usize,
293        factor: &<R as RingBase>::Element,
294    ) {
295        debug_assert!(src < self.row_count);
296        debug_assert!(dst < self.row_count);
297        self.transforms
298            .push(Transform::Subtract(src, dst, ring.clone_el(factor)))
299    }
300
301    fn swap<S: Copy + RingStore<Type = R>>(&mut self, _ring: S, i: usize, j: usize) {
302        debug_assert!(i < self.row_count);
303        debug_assert!(j < self.row_count);
304        self.transforms.push(Transform::Swap(i, j))
305    }
306}
307
308impl<R> TransformTarget<R> for ()
309where
310    R: ?Sized + RingBase,
311{
312    fn transform<S: Copy + RingStore<Type = R>>(&mut self, _: S, _: usize, _: usize, _: &[R::Element; 4]) {}
313
314    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, _: S, _: usize, _: usize, _: &R::Element) {}
315
316    fn swap<S: Copy + RingStore<Type = R>>(&mut self, _: S, _: usize, _: usize) {}
317}
318
319/// A [`TransformTarget`] that forwards all transforms to a fixed
320/// delegate, but offsets every row/column index by a given value.
321pub struct OffsetTransformIndex<R, T>
322where
323    R: ?Sized + RingBase,
324    T: TransformTarget<R>,
325{
326    delegate: T,
327    index_offset: usize,
328    ring: PhantomData<R>,
329}
330
331impl<R, T> OffsetTransformIndex<R, T>
332where
333    R: ?Sized + RingBase,
334    T: TransformTarget<R>,
335{
336    /// Creates a new [`OffsetTransformIndex`] that forwards all transforms to `delegate`.
337    pub fn new(delegate: T, offset: usize) -> Self {
338        Self {
339            delegate,
340            index_offset: offset,
341            ring: PhantomData,
342        }
343    }
344}
345
346impl<R, T> TransformTarget<R> for OffsetTransformIndex<R, T>
347where
348    R: ?Sized + RingBase,
349    T: TransformTarget<R>,
350{
351    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]) {
352        <T as TransformTarget<R>>::transform(
353            &mut self.delegate,
354            ring,
355            i + self.index_offset,
356            j + self.index_offset,
357            transform,
358        );
359    }
360
361    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
362        <T as TransformTarget<R>>::subtract(
363            &mut self.delegate,
364            ring,
365            src + self.index_offset,
366            dst + self.index_offset,
367            factor,
368        );
369    }
370
371    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
372        <T as TransformTarget<R>>::swap(&mut self.delegate, ring, i + self.index_offset, j + self.index_offset);
373    }
374}
375
376/// A [`TransformTarget`] that forwards all transforms to
377/// two fixed delegates.
378pub struct DuplicateTransforms<R, T1, T2>
379where
380    R: ?Sized + RingBase,
381    T1: TransformTarget<R>,
382    T2: TransformTarget<R>,
383{
384    delegate1: T1,
385    delegate2: T2,
386    ring: PhantomData<R>,
387}
388
389impl<R, T1, T2> DuplicateTransforms<R, T1, T2>
390where
391    R: ?Sized + RingBase,
392    T1: TransformTarget<R>,
393    T2: TransformTarget<R>,
394{
395    /// Creates a new [`DuplicateTransforms`] that forwards all transforms to `first` and `second`.
396    pub fn new(first: T1, second: T2) -> Self {
397        Self {
398            delegate1: first,
399            delegate2: second,
400            ring: PhantomData,
401        }
402    }
403}
404
405impl<R, T1, T2> TransformTarget<R> for DuplicateTransforms<R, T1, T2>
406where
407    R: ?Sized + RingBase,
408    T1: TransformTarget<R>,
409    T2: TransformTarget<R>,
410{
411    fn transform<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize, transform: &[R::Element; 4]) {
412        <T1 as TransformTarget<R>>::transform(&mut self.delegate1, ring, i, j, transform);
413        <T2 as TransformTarget<R>>::transform(&mut self.delegate2, ring, i, j, transform);
414    }
415
416    fn subtract<S: Copy + RingStore<Type = R>>(&mut self, ring: S, src: usize, dst: usize, factor: &R::Element) {
417        <T1 as TransformTarget<R>>::subtract(&mut self.delegate1, ring, src, dst, factor);
418        <T2 as TransformTarget<R>>::subtract(&mut self.delegate2, ring, src, dst, factor);
419    }
420
421    fn swap<S: Copy + RingStore<Type = R>>(&mut self, ring: S, i: usize, j: usize) {
422        <T1 as TransformTarget<R>>::swap(&mut self.delegate1, ring, i, j);
423        <T2 as TransformTarget<R>>::swap(&mut self.delegate2, ring, i, j);
424    }
425}