faer_qr/col_pivoting/
solve.rs

1use crate::no_pivoting;
2use dyn_stack::{PodStack, SizeOverflow, StackReq};
3use faer_core::{
4    permutation::{
5        permute_rows, permute_rows_in_place, permute_rows_in_place_req, Index, PermutationRef,
6    },
7    ComplexField, Conj, Entity, MatMut, MatRef, Parallelism,
8};
9use reborrow::*;
10
11/// Computes the size and alignment of required workspace for solving a linear system defined by a
12/// matrix in place, given its QR decomposition with column pivoting.
13#[inline]
14pub fn solve_in_place_req<I: Index, E: Entity>(
15    qr_size: usize,
16    qr_blocksize: usize,
17    rhs_ncols: usize,
18) -> Result<StackReq, SizeOverflow> {
19    StackReq::try_any_of([
20        no_pivoting::solve::solve_in_place_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
21        permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
22    ])
23}
24
25/// Computes the size and alignment of required workspace for solving a linear system defined by
26/// the transpose of a matrix in place, given its QR decomposition with column pivoting.
27#[inline]
28pub fn solve_transpose_in_place_req<I: Index, E: Entity>(
29    qr_size: usize,
30    qr_blocksize: usize,
31    rhs_ncols: usize,
32) -> Result<StackReq, SizeOverflow> {
33    StackReq::try_any_of([
34        no_pivoting::solve::solve_transpose_in_place_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
35        permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
36    ])
37}
38
39/// Computes the size and alignment of required workspace for solving a linear system defined by a
40/// matrix out of place, given its QR decomposition with column pivoting.
41#[inline]
42pub fn solve_req<I: Index, E: Entity>(
43    qr_size: usize,
44    qr_blocksize: usize,
45    rhs_ncols: usize,
46) -> Result<StackReq, SizeOverflow> {
47    StackReq::try_any_of([
48        no_pivoting::solve::solve_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
49        permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
50    ])
51}
52
53/// Computes the size and alignment of required workspace for solving a linear system defined by
54/// the transpose of a matrix ouf of place, given its QR decomposition with column pivoting.
55#[inline]
56pub fn solve_transpose_req<I: Index, E: Entity>(
57    qr_size: usize,
58    qr_blocksize: usize,
59    rhs_ncols: usize,
60) -> Result<StackReq, SizeOverflow> {
61    StackReq::try_any_of([
62        no_pivoting::solve::solve_transpose_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
63        permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
64    ])
65}
66
67/// Given the QR factors with column pivoting of a matrix $A$ and a matrix $B$ stored in `rhs`,
68/// this function computes the solution of the linear system in the sense of least squares:
69/// $$\text{Op}_A(A)X = B.$$
70///
71/// $\text{Op}_A$ is either the identity or the conjugation depending on the value of `conj_lhs`.  
72///
73/// The solution of the linear system is stored in `rhs`.
74///
75/// # Panics
76///
77/// - Panics if `qr_factors` is not a tall matrix.
78/// - Panics if the number of columns of `householder_factor` isn't the same as the minimum of the
79/// number of rows and the number of columns of `qr_factors`.
80/// - Panics if the block size is zero.
81/// - Panics if `col_perm` doesn't have the same dimension as `qr_factors`.
82/// - Panics if `rhs` doesn't have the same number of rows as the number of columns of `qr_factors`.
83/// - Panics if the provided memory in `stack` is insufficient (see [`solve_in_place_req`]).
84#[track_caller]
85pub fn solve_in_place<I: Index, E: ComplexField>(
86    qr_factors: MatRef<'_, E>,
87    householder_factor: MatRef<'_, E>,
88    col_perm: PermutationRef<'_, I, E>,
89    conj_lhs: Conj,
90    rhs: MatMut<'_, E>,
91    parallelism: Parallelism,
92    stack: PodStack<'_>,
93) {
94    let mut rhs = rhs;
95    let mut stack = stack;
96    no_pivoting::solve::solve_in_place(
97        qr_factors,
98        householder_factor,
99        conj_lhs,
100        rhs.rb_mut(),
101        parallelism,
102        stack.rb_mut(),
103    );
104    let size = qr_factors.ncols();
105    permute_rows_in_place(rhs.subrows_mut(0, size), col_perm.inverse(), stack);
106}
107
108/// Given the QR factors with column pivoting of a matrix $A$ and a matrix $B$ stored in `rhs`,
109/// this function computes the solution of the linear system in the sense of least squares:
110/// $$\text{Op}_A(A)\top X = B.$$
111///
112/// $\text{Op}_A$ is either the identity or the conjugation depending on the value of `conj_lhs`.  
113///
114/// The solution of the linear system is stored in `rhs`.
115///
116/// # Panics
117///
118/// - Panics if `qr_factors` is not a square matrix.
119/// - Panics if the number of columns of `householder_factor` isn't the same as the minimum of the
120/// number of rows and the number of columns of `qr_factors`.
121/// - Panics if the block size is zero.
122/// - Panics if `col_perm` doesn't have the same dimension as `qr_factors`.
123/// - Panics if `rhs` doesn't have the same number of rows as the dimension of `qr_factors`.
124/// - Panics if the provided memory in `stack` is insufficient (see
125///   [`solve_transpose_in_place_req`]).
126#[track_caller]
127pub fn solve_transpose_in_place<I: Index, E: ComplexField>(
128    qr_factors: MatRef<'_, E>,
129    householder_factor: MatRef<'_, E>,
130    col_perm: PermutationRef<'_, I, E>,
131    conj_lhs: Conj,
132    rhs: MatMut<'_, E>,
133    parallelism: Parallelism,
134    stack: PodStack<'_>,
135) {
136    let mut rhs = rhs;
137    let mut stack = stack;
138    permute_rows_in_place(rhs.rb_mut(), col_perm, stack.rb_mut());
139    no_pivoting::solve::solve_transpose_in_place(
140        qr_factors,
141        householder_factor,
142        conj_lhs,
143        rhs.rb_mut(),
144        parallelism,
145        stack.rb_mut(),
146    );
147}
148
149/// Given the QR factors with column pivoting of a matrix $A$ and a matrix $B$ stored in `rhs`,
150/// this function computes the solution of the linear system:
151/// $$\text{Op}_A(A)X = B.$$
152///
153/// $\text{Op}_A$ is either the identity or the conjugation depending on the value of `conj_lhs`.  
154///
155/// The solution of the linear system is stored in `dst`.
156///
157/// # Panics
158///
159/// - Panics if `qr_factors` is not a square matrix.
160/// - Panics if the number of columns of `householder_factor` isn't the same as the minimum of the
161/// number of rows and the number of columns of `qr_factors`.
162/// - Panics if the block size is zero.
163/// - Panics if `col_perm` doesn't have the same dimension as `qr_factors`.
164/// - Panics if `rhs` doesn't have the same number of rows as the dimension of `qr_factors`.
165/// - Panics if `rhs` and `dst` don't have the same shape.
166/// - Panics if the provided memory in `stack` is insufficient (see [`solve_req`]).
167#[track_caller]
168pub fn solve<I: Index, E: ComplexField>(
169    dst: MatMut<'_, E>,
170    qr_factors: MatRef<'_, E>,
171    householder_factor: MatRef<'_, E>,
172    col_perm: PermutationRef<'_, I, E>,
173    conj_lhs: Conj,
174    rhs: MatRef<'_, E>,
175    parallelism: Parallelism,
176    stack: PodStack<'_>,
177) {
178    let mut dst = dst;
179    let mut stack = stack;
180    no_pivoting::solve::solve(
181        dst.rb_mut(),
182        qr_factors,
183        householder_factor,
184        conj_lhs,
185        rhs,
186        parallelism,
187        stack.rb_mut(),
188    );
189    permute_rows_in_place(dst, col_perm.inverse(), stack);
190}
191
192/// Given the QR factors with column pivoting of a matrix $A$ and a matrix $B$ stored in `rhs`,
193/// this function computes the solution of the linear system:
194/// $$\text{Op}_A(A)^\top X = B.$$
195///
196/// $\text{Op}_A$ is either the identity or the conjugation depending on the value of `conj_lhs`.  
197///
198/// The solution of the linear system is stored in `dst`.
199///
200/// # Panics
201///
202/// - Panics if `qr_factors` is not a square matrix.
203/// - Panics if the number of columns of `householder_factor` isn't the same as the minimum of the
204/// number of rows and the number of columns of `qr_factors`.
205/// - Panics if the block size is zero.
206/// - Panics if `col_perm` doesn't have the same dimension as `qr_factors`.
207/// - Panics if `rhs` doesn't have the same number of rows as the dimension of `qr_factors`.
208/// - Panics if `rhs` and `dst` don't have the same shape.
209/// - Panics if the provided memory in `stack` is insufficient (see [`solve_transpose_req`]).
210#[track_caller]
211pub fn solve_transpose<I: Index, E: ComplexField>(
212    dst: MatMut<'_, E>,
213    qr_factors: MatRef<'_, E>,
214    householder_factor: MatRef<'_, E>,
215    col_perm: PermutationRef<'_, I, E>,
216    conj_lhs: Conj,
217    rhs: MatRef<'_, E>,
218    parallelism: Parallelism,
219    stack: PodStack<'_>,
220) {
221    let mut dst = dst;
222    let mut stack = stack;
223    permute_rows(dst.rb_mut(), rhs, col_perm);
224    no_pivoting::solve::solve_transpose_in_place(
225        qr_factors,
226        householder_factor,
227        conj_lhs,
228        dst.rb_mut(),
229        parallelism,
230        stack.rb_mut(),
231    );
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize};
    use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat};
    use rand::random;

    // Builds a `PodStack` over a freshly allocated `GlobalPodBuffer` sized for `$req` (a
    // `Result`, unwrapped here). The buffer is a temporary, so the expansion has to be used
    // directly as a call argument.
    macro_rules! make_stack {
        ($req: expr) => {
            ::dyn_stack::PodStack::new(&mut ::dyn_stack::GlobalPodBuffer::new($req.unwrap()))
        };
    }

    /// Factorizes a random `n`-by-`n` matrix with column pivoting, solves `Op(A) X = B` in place
    /// for both conjugation settings, and checks entrywise that `Op(A) X` reproduces `B` within
    /// `epsilon`.
    fn test_solve_in_place<E: ComplexField>(mut random: impl FnMut() -> E, epsilon: E::Real) {
        let n = 32;
        let k = 6;

        let a = Mat::from_fn(n, n, |_, _| random());
        let rhs = Mat::from_fn(n, k, |_, _| random());

        let mut qr = a.clone();
        // Ask for the block size of the scalar type under test rather than a fixed `f64`.
        let blocksize = recommended_blocksize::<E>(n, n);
        let mut householder = Mat::from_fn(blocksize, n, |_, _| E::faer_zero());
        let mut perm = vec![0usize; n];
        let mut perm_inv = vec![0usize; n];

        let (_, perm) = qr_in_place(
            qr.as_mut(),
            householder.as_mut(),
            &mut perm,
            &mut perm_inv,
            Parallelism::None,
            make_stack!(qr_in_place_req::<usize, E>(
                n,
                n,
                blocksize,
                Parallelism::None,
                Default::default(),
            )),
            Default::default(),
        );

        let qr = qr.as_ref();

        for conj_lhs in [Conj::No, Conj::Yes] {
            let mut sol = rhs.clone();
            solve_in_place(
                qr,
                householder.as_ref(),
                perm.rb(),
                conj_lhs,
                sol.as_mut(),
                Parallelism::None,
                make_stack!(solve_in_place_req::<usize, E>(n, blocksize, k)),
            );

            // The clone only provides a destination of the right shape; with `None` as the
            // accumulation factor the product presumably overwrites it.
            let mut rhs_reconstructed = rhs.clone();
            matmul_with_conj(
                rhs_reconstructed.as_mut(),
                a.as_ref(),
                conj_lhs,
                sol.as_ref(),
                Conj::No,
                None,
                E::faer_one(),
                Parallelism::None,
            );

            for j in 0..k {
                for i in 0..n {
                    let err = (rhs_reconstructed.read(i, j).faer_sub(rhs.read(i, j))).faer_abs();
                    assert!(err < epsilon)
                }
            }
        }
    }

    /// Same as `test_solve_in_place`, but for the transposed system `Op(A)^T X = B` solved with
    /// `solve_transpose_in_place`.
    fn test_solve_transpose_in_place<E: ComplexField>(
        mut random: impl FnMut() -> E,
        epsilon: E::Real,
    ) {
        let n = 32;
        let k = 6;

        let a = Mat::from_fn(n, n, |_, _| random());
        let rhs = Mat::from_fn(n, k, |_, _| random());

        let mut qr = a.clone();
        // Ask for the block size of the scalar type under test rather than a fixed `f64`.
        let blocksize = recommended_blocksize::<E>(n, n);
        let mut householder = Mat::from_fn(blocksize, n, |_, _| E::faer_zero());
        let mut perm = vec![0usize; n];
        let mut perm_inv = vec![0usize; n];

        let (_, perm) = qr_in_place(
            qr.as_mut(),
            householder.as_mut(),
            &mut perm,
            &mut perm_inv,
            Parallelism::None,
            make_stack!(qr_in_place_req::<usize, E>(
                n,
                n,
                blocksize,
                Parallelism::None,
                Default::default(),
            )),
            Default::default(),
        );

        let qr = qr.as_ref();

        for conj_lhs in [Conj::No, Conj::Yes] {
            let mut sol = rhs.clone();
            solve_transpose_in_place(
                qr,
                householder.as_ref(),
                perm.rb(),
                conj_lhs,
                sol.as_mut(),
                Parallelism::None,
                make_stack!(solve_transpose_in_place_req::<usize, E>(n, blocksize, k)),
            );

            // The clone only provides a destination of the right shape; with `None` as the
            // accumulation factor the product presumably overwrites it.
            let mut rhs_reconstructed = rhs.clone();
            matmul_with_conj(
                rhs_reconstructed.as_mut(),
                a.as_ref().transpose(),
                conj_lhs,
                sol.as_ref(),
                Conj::No,
                None,
                E::faer_one(),
                Parallelism::None,
            );

            for j in 0..k {
                for i in 0..n {
                    let err = (rhs_reconstructed.read(i, j).faer_sub(rhs.read(i, j))).faer_abs();
                    assert!(err < epsilon)
                }
            }
        }
    }

    #[test]
    fn test_solve_in_place_f64() {
        test_solve_in_place(random::<f64>, 1e-6);
    }

    #[test]
    fn test_solve_in_place_f32() {
        test_solve_in_place(random::<f32>, 1e-1);
    }

    #[test]
    fn test_solve_in_place_c64() {
        test_solve_in_place(|| c64::new(random(), random()), 1e-6);
    }

    #[test]
    fn test_solve_in_place_c32() {
        test_solve_in_place(|| c32::new(random(), random()), 1e-1);
    }

    #[test]
    fn test_solve_transpose_in_place_f64() {
        test_solve_transpose_in_place(random::<f64>, 1e-6);
    }

    #[test]
    fn test_solve_transpose_in_place_f32() {
        test_solve_transpose_in_place(random::<f32>, 1e-1);
    }

    #[test]
    fn test_solve_transpose_in_place_c64() {
        test_solve_transpose_in_place(|| c64::new(random(), random()), 1e-6);
    }

    #[test]
    fn test_solve_transpose_in_place_c32() {
        test_solve_transpose_in_place(|| c32::new(random(), random()), 1e-1);
    }
}