1use crate::no_pivoting;
2use dyn_stack::{PodStack, SizeOverflow, StackReq};
3use faer_core::{
4 permutation::{
5 permute_rows, permute_rows_in_place, permute_rows_in_place_req, Index, PermutationRef,
6 },
7 ComplexField, Conj, Entity, MatMut, MatRef, Parallelism,
8};
9use reborrow::*;
10
11#[inline]
14pub fn solve_in_place_req<I: Index, E: Entity>(
15 qr_size: usize,
16 qr_blocksize: usize,
17 rhs_ncols: usize,
18) -> Result<StackReq, SizeOverflow> {
19 StackReq::try_any_of([
20 no_pivoting::solve::solve_in_place_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
21 permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
22 ])
23}
24
25#[inline]
28pub fn solve_transpose_in_place_req<I: Index, E: Entity>(
29 qr_size: usize,
30 qr_blocksize: usize,
31 rhs_ncols: usize,
32) -> Result<StackReq, SizeOverflow> {
33 StackReq::try_any_of([
34 no_pivoting::solve::solve_transpose_in_place_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
35 permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
36 ])
37}
38
39#[inline]
42pub fn solve_req<I: Index, E: Entity>(
43 qr_size: usize,
44 qr_blocksize: usize,
45 rhs_ncols: usize,
46) -> Result<StackReq, SizeOverflow> {
47 StackReq::try_any_of([
48 no_pivoting::solve::solve_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
49 permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
50 ])
51}
52
53#[inline]
56pub fn solve_transpose_req<I: Index, E: Entity>(
57 qr_size: usize,
58 qr_blocksize: usize,
59 rhs_ncols: usize,
60) -> Result<StackReq, SizeOverflow> {
61 StackReq::try_any_of([
62 no_pivoting::solve::solve_transpose_req::<E>(qr_size, qr_blocksize, rhs_ncols)?,
63 permute_rows_in_place_req::<I, E>(qr_size, rhs_ncols)?,
64 ])
65}
66
67#[track_caller]
85pub fn solve_in_place<I: Index, E: ComplexField>(
86 qr_factors: MatRef<'_, E>,
87 householder_factor: MatRef<'_, E>,
88 col_perm: PermutationRef<'_, I, E>,
89 conj_lhs: Conj,
90 rhs: MatMut<'_, E>,
91 parallelism: Parallelism,
92 stack: PodStack<'_>,
93) {
94 let mut rhs = rhs;
95 let mut stack = stack;
96 no_pivoting::solve::solve_in_place(
97 qr_factors,
98 householder_factor,
99 conj_lhs,
100 rhs.rb_mut(),
101 parallelism,
102 stack.rb_mut(),
103 );
104 let size = qr_factors.ncols();
105 permute_rows_in_place(rhs.subrows_mut(0, size), col_perm.inverse(), stack);
106}
107
108#[track_caller]
127pub fn solve_transpose_in_place<I: Index, E: ComplexField>(
128 qr_factors: MatRef<'_, E>,
129 householder_factor: MatRef<'_, E>,
130 col_perm: PermutationRef<'_, I, E>,
131 conj_lhs: Conj,
132 rhs: MatMut<'_, E>,
133 parallelism: Parallelism,
134 stack: PodStack<'_>,
135) {
136 let mut rhs = rhs;
137 let mut stack = stack;
138 permute_rows_in_place(rhs.rb_mut(), col_perm, stack.rb_mut());
139 no_pivoting::solve::solve_transpose_in_place(
140 qr_factors,
141 householder_factor,
142 conj_lhs,
143 rhs.rb_mut(),
144 parallelism,
145 stack.rb_mut(),
146 );
147}
148
149#[track_caller]
168pub fn solve<I: Index, E: ComplexField>(
169 dst: MatMut<'_, E>,
170 qr_factors: MatRef<'_, E>,
171 householder_factor: MatRef<'_, E>,
172 col_perm: PermutationRef<'_, I, E>,
173 conj_lhs: Conj,
174 rhs: MatRef<'_, E>,
175 parallelism: Parallelism,
176 stack: PodStack<'_>,
177) {
178 let mut dst = dst;
179 let mut stack = stack;
180 no_pivoting::solve::solve(
181 dst.rb_mut(),
182 qr_factors,
183 householder_factor,
184 conj_lhs,
185 rhs,
186 parallelism,
187 stack.rb_mut(),
188 );
189 permute_rows_in_place(dst, col_perm.inverse(), stack);
190}
191
192#[track_caller]
211pub fn solve_transpose<I: Index, E: ComplexField>(
212 dst: MatMut<'_, E>,
213 qr_factors: MatRef<'_, E>,
214 householder_factor: MatRef<'_, E>,
215 col_perm: PermutationRef<'_, I, E>,
216 conj_lhs: Conj,
217 rhs: MatRef<'_, E>,
218 parallelism: Parallelism,
219 stack: PodStack<'_>,
220) {
221 let mut dst = dst;
222 let mut stack = stack;
223 permute_rows(dst.rb_mut(), rhs, col_perm);
224 no_pivoting::solve::solve_transpose_in_place(
225 qr_factors,
226 householder_factor,
227 conj_lhs,
228 dst.rb_mut(),
229 parallelism,
230 stack.rb_mut(),
231 );
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize};
    use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat};
    use rand::random;

    // Allocates a buffer for the given `Result<StackReq, _>` and wraps it in a `PodStack`.
    // The buffer is a temporary, so the stack is only usable within the enclosing statement.
    macro_rules! make_stack {
        ($req: expr) => {
            ::dyn_stack::PodStack::new(&mut ::dyn_stack::GlobalPodBuffer::new($req.unwrap()))
        };
    }

    /// Factorizes a random `n x n` matrix `A` with column-pivoted QR and solves `Op(A) X = B`
    /// with [`solve_in_place`] for a random `B`. Checks that `Op(A) X` reproduces `B`
    /// entrywise within `epsilon`, for both values of `conj_lhs`.
    fn test_solve_in_place<E: ComplexField>(mut random: impl FnMut() -> E, epsilon: E::Real) {
        let n = 32;
        let k = 6;

        let a = Mat::from_fn(n, n, |_, _| random());
        let rhs = Mat::from_fn(n, k, |_, _| random());

        let mut qr = a.clone();
        // Use the block size recommended for the scalar type under test (previously this was
        // always `f64`'s, regardless of `E`).
        let blocksize = recommended_blocksize::<E>(n, n);
        let mut householder = Mat::from_fn(blocksize, n, |_, _| E::faer_zero());
        let mut perm = vec![0usize; n];
        let mut perm_inv = vec![0usize; n];

        // The returned permutation borrows `perm`/`perm_inv`, so it shadows the buffer name.
        let (_, perm) = qr_in_place(
            qr.as_mut(),
            householder.as_mut(),
            &mut perm,
            &mut perm_inv,
            Parallelism::None,
            make_stack!(qr_in_place_req::<usize, E>(
                n,
                n,
                blocksize,
                Parallelism::None,
                Default::default(),
            )),
            Default::default(),
        );

        let qr = qr.as_ref();

        for conj_lhs in [Conj::No, Conj::Yes] {
            let mut sol = rhs.clone();
            solve_in_place(
                qr,
                householder.as_ref(),
                perm.rb(),
                conj_lhs,
                sol.as_mut(),
                Parallelism::None,
                make_stack!(solve_in_place_req::<usize, E>(n, blocksize, k)),
            );

            // Recompute `Op(A) X`. The clone only supplies a correctly shaped buffer;
            // presumably `None` means its old contents are not accumulated.
            let mut rhs_reconstructed = rhs.clone();
            matmul_with_conj(
                rhs_reconstructed.as_mut(),
                a.as_ref(),
                conj_lhs,
                sol.as_ref(),
                Conj::No,
                None,
                E::faer_one(),
                Parallelism::None,
            );

            for j in 0..k {
                for i in 0..n {
                    assert!(
                        (rhs_reconstructed.read(i, j).faer_sub(rhs.read(i, j))).faer_abs()
                            < epsilon
                    )
                }
            }
        }
    }

    /// Factorizes a random `n x n` matrix `A` with column-pivoted QR and solves
    /// `Op(A)^T X = B` with [`solve_transpose_in_place`] for a random `B`. Checks that
    /// `Op(A)^T X` reproduces `B` entrywise within `epsilon`, for both values of `conj_lhs`.
    fn test_solve_transpose_in_place<E: ComplexField>(
        mut random: impl FnMut() -> E,
        epsilon: E::Real,
    ) {
        let n = 32;
        let k = 6;

        let a = Mat::from_fn(n, n, |_, _| random());
        let rhs = Mat::from_fn(n, k, |_, _| random());

        let mut qr = a.clone();
        // Use the block size recommended for the scalar type under test.
        let blocksize = recommended_blocksize::<E>(n, n);
        let mut householder = Mat::from_fn(blocksize, n, |_, _| E::faer_zero());
        let mut perm = vec![0usize; n];
        let mut perm_inv = vec![0usize; n];

        let (_, perm) = qr_in_place(
            qr.as_mut(),
            householder.as_mut(),
            &mut perm,
            &mut perm_inv,
            Parallelism::None,
            make_stack!(qr_in_place_req::<usize, E>(
                n,
                n,
                blocksize,
                Parallelism::None,
                Default::default(),
            )),
            Default::default(),
        );

        let qr = qr.as_ref();

        for conj_lhs in [Conj::No, Conj::Yes] {
            let mut sol = rhs.clone();
            solve_transpose_in_place(
                qr,
                householder.as_ref(),
                perm.rb(),
                conj_lhs,
                sol.as_mut(),
                Parallelism::None,
                make_stack!(solve_transpose_in_place_req::<usize, E>(n, blocksize, k)),
            );

            // Recompute `Op(A)^T X` into a correctly shaped buffer.
            let mut rhs_reconstructed = rhs.clone();
            matmul_with_conj(
                rhs_reconstructed.as_mut(),
                a.as_ref().transpose(),
                conj_lhs,
                sol.as_ref(),
                Conj::No,
                None,
                E::faer_one(),
                Parallelism::None,
            );

            for j in 0..k {
                for i in 0..n {
                    assert!(
                        (rhs_reconstructed.read(i, j).faer_sub(rhs.read(i, j))).faer_abs()
                            < epsilon
                    )
                }
            }
        }
    }

    // Tolerances are loose for single precision because the inputs are unscaled random matrices.

    #[test]
    fn test_solve_in_place_f64() {
        test_solve_in_place(random::<f64>, 1e-6);
    }

    #[test]
    fn test_solve_in_place_f32() {
        test_solve_in_place(random::<f32>, 1e-1);
    }

    #[test]
    fn test_solve_in_place_c64() {
        test_solve_in_place(|| c64::new(random(), random()), 1e-6);
    }

    #[test]
    fn test_solve_in_place_c32() {
        test_solve_in_place(|| c32::new(random(), random()), 1e-1);
    }

    #[test]
    fn test_solve_transpose_in_place_f64() {
        test_solve_transpose_in_place(random::<f64>, 1e-6);
    }

    #[test]
    fn test_solve_transpose_in_place_f32() {
        test_solve_transpose_in_place(random::<f32>, 1e-1);
    }

    #[test]
    fn test_solve_transpose_in_place_c64() {
        test_solve_transpose_in_place(|| c64::new(random(), random()), 1e-6);
    }

    #[test]
    fn test_solve_transpose_in_place_c32() {
        test_solve_transpose_in_place(|| c32::new(random(), random()), 1e-1);
    }
}
420}