1use crate::sealed::Sealed;
2use crate::{ComplexHelper, DiagonalKind, Side, Transposition, TriangularStructure, qr_util};
3use crate::{LapackErrorCode, lapack_error::check_lapack_info};
4use lapack;
5use na::allocator::Allocator;
6use na::dimension::{Const, Dim, DimMin, DimMinimum};
7use na::{
8 ComplexField, DefaultAllocator, IsContiguous, Matrix, OMatrix, OVector, RawStorageMut,
9 RealField, Scalar,
10};
11use num::Zero;
12#[cfg(feature = "serde-serialize")]
13use serde::{Deserialize, Serialize};
14
15pub use crate::qr_util::Error;
16pub(crate) mod abstraction;
17pub use abstraction::QrDecomposition;
18
19#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
21#[cfg_attr(
22 feature = "serde-serialize",
23 serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
24 Allocator<DimMinimum<R, C>>,
25 OMatrix<T, R, C>: Serialize,
26 OVector<T, DimMinimum<R, C>>: Serialize"))
27)]
28#[cfg_attr(
29 feature = "serde-serialize",
30 serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
31 Allocator<DimMinimum<R, C>>,
32 OMatrix<T, R, C>: Deserialize<'de>,
33 OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
34)]
35#[derive(Clone, Debug)]
36pub struct QR<T, R, C>
37where
38 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
39 T: Scalar,
40 R: DimMin<C>,
41 C: Dim,
42{
43 qr: OMatrix<T, R, C>,
44 tau: OVector<T, DimMinimum<R, C>>,
45}
46
47impl<T: Scalar + Copy, R: DimMin<C>, C: Dim> Copy for QR<T, R, C>
48where
49 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
50 OMatrix<T, R, C>: Copy,
51 OVector<T, DimMinimum<R, C>>: Copy,
52{
53}
54
55impl<T, R, C> QR<T, R, C>
56where
57 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
58 T: Scalar,
59 R: DimMin<C>,
60 C: Dim,
61{
62 pub fn new(mut m: OMatrix<T, R, C>) -> Result<Self, Error>
64 where
65 T: QrScalar + Zero,
66 {
67 let (nrows, ncols) = m.shape_generic();
68
69 let mut tau = Matrix::zeros_generic(nrows.min(ncols), Const::<1>);
70
71 if nrows.value() < ncols.value() {
72 return Err(Error::Underdetermined);
73 }
74
75 if nrows.value() == 0 || ncols.value() == 0 {
76 return Ok(Self { qr: m, tau });
77 }
78
79 let lwork = unsafe {
80 T::xgeqrf_work_size(
81 nrows.value() as i32,
82 ncols.value() as i32,
83 m.as_mut_slice(),
84 nrows.value() as i32,
85 tau.as_mut_slice(),
86 )?
87 };
88
89 let mut work = vec![T::zero(); lwork as usize];
90
91 unsafe {
92 T::xgeqrf(
93 nrows.value() as i32,
94 ncols.value() as i32,
95 m.as_mut_slice(),
96 nrows.value() as i32,
97 tau.as_mut_slice(),
98 &mut work,
99 lwork,
100 )?;
101 }
102
103 Ok(Self { qr: m, tau })
104 }
105}
106
107impl<T, R, C> Sealed for QR<T, R, C>
108where
109 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
110 T: Scalar,
111 R: DimMin<C>,
112 C: Dim,
113{
114}
115
116impl<T, R, C> QrDecomposition<T, R, C> for QR<T, R, C>
117where
118 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
119 R: DimMin<C, Output = C>,
120 C: Dim,
121 T: Scalar + RealField + QrReal,
122{
123 fn __lapack_qr_ref(&self) -> &OMatrix<T, R, C> {
124 &self.qr
125 }
126
127 fn __lapack_tau_ref(&self) -> &OVector<T, DimMinimum<R, C>> {
128 &self.tau
129 }
130
131 fn solve_mut<C2: Dim, S, S2>(
132 &self,
133 x: &mut Matrix<T, C, C2, S2>,
134 b: Matrix<T, R, C2, S>,
135 ) -> Result<(), Error>
136 where
137 S: RawStorageMut<T, R, C2> + IsContiguous,
138 S2: RawStorageMut<T, C, C2> + IsContiguous,
139 T: Zero,
140 {
141 if self.nrows() < self.ncols() {
143 return Err(Error::Underdetermined);
144 }
145
146 let rank = self
149 .nrows()
150 .min(self.ncols())
151 .try_into()
152 .expect("integer dimensions out of bounds");
153 qr_util::qr_solve_mut_with_rank_unpermuted(&self.qr, &self.tau, rank, x, b)?;
154 Ok(())
155 }
156}
157
158#[allow(missing_docs)]
166pub trait QrScalar: ComplexField + Scalar + Copy + Sealed {
167 unsafe fn xgeqrf(
168 m: i32,
169 n: i32,
170 a: &mut [Self],
171 lda: i32,
172 tau: &mut [Self],
173 work: &mut [Self],
174 lwork: i32,
175 ) -> Result<(), LapackErrorCode>;
176
177 unsafe fn xgeqrf_work_size(
178 m: i32,
179 n: i32,
180 a: &mut [Self],
181 lda: i32,
182 tau: &mut [Self],
183 ) -> Result<i32, LapackErrorCode>;
184
185 unsafe fn xgeqp3(
189 m: i32,
190 n: i32,
191 a: &mut [Self],
192 lda: i32,
193 jpvt: &mut [i32],
194 tau: &mut [Self],
195 work: &mut [Self],
196 lwork: i32,
197 ) -> Result<(), LapackErrorCode>;
198
199 unsafe fn xgeqp3_work_size(
200 m: i32,
201 n: i32,
202 a: &mut [Self],
203 lda: i32,
204 jpvt: &mut [i32],
205 tau: &mut [Self],
206 ) -> Result<i32, LapackErrorCode>;
207
208 unsafe fn xtrtrs(
209 uplo: TriangularStructure,
210 trans: Transposition,
211 diag: DiagonalKind,
212 n: i32,
213 nrhs: i32,
214 a: &[Self],
215 lda: i32,
216 b: &mut [Self],
217 ldb: i32,
218 ) -> Result<(), LapackErrorCode>;
219
220 unsafe fn xlapmt(
221 forwrd: bool,
222 m: i32,
223 n: i32,
224 x: &mut [Self],
225 ldx: i32,
226 k: &mut [i32],
227 ) -> Result<(), LapackErrorCode>;
228
229 unsafe fn xlapmr(
230 forwrd: bool,
231 m: i32,
232 n: i32,
233 x: &mut [Self],
234 ldx: i32,
235 k: &mut [i32],
236 ) -> Result<(), LapackErrorCode>;
237}
238
// Implements `QrScalar` for `$type` by forwarding every method to the given
// precision-specific LAPACK routines.
macro_rules! qr_scalar_impl(
    ($type:ty,
        xgeqrf = $xgeqrf:path,
        xgeqp3 = $xgeqp3:path,
        xtrtrs = $xtrtrs:path,
        xlapmt = $xlapmt:path,
        xlapmr = $xlapmr:path $(,)?) => (
        impl QrScalar for $type {
            #[inline]
            unsafe fn xgeqrf(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                tau: &mut [Self],
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xgeqrf(m, n, a, lda, tau, work, lwork, &mut info) }
                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xgeqrf_work_size(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                tau: &mut [Self],
            ) -> Result<i32, LapackErrorCode> {
                let mut info = 0;
                let mut work = [Zero::zero()];
                // `lwork = -1` requests a workspace query: LAPACK writes the
                // optimal length into `work[0]` instead of factoring `a`.
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xgeqrf(m, n, a, lda, tau, &mut work, -1, &mut info) }
                check_lapack_info(info)?;
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            #[inline]
            unsafe fn xgeqp3(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                jpvt: &mut [i32],
                tau: &mut [Self],
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xgeqp3(m, n, a, lda, jpvt, tau, work, lwork, &mut info) }
                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xgeqp3_work_size(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                jpvt: &mut [i32],
                tau: &mut [Self],
            ) -> Result<i32, LapackErrorCode> {
                let mut info = 0;
                let mut work = [Zero::zero()];
                // Workspace query (`lwork = -1`); the optimal length lands in
                // `work[0]`.
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xgeqp3(m, n, a, lda, jpvt, tau, &mut work, -1, &mut info) }
                check_lapack_info(info)?;
                // Read the size the same way as the other workspace queries.
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            #[inline]
            unsafe fn xtrtrs(
                uplo: TriangularStructure,
                trans: Transposition,
                diag: DiagonalKind,
                n: i32,
                nrhs: i32,
                a: &[Self],
                lda: i32,
                b: &mut [Self],
                ldb: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                let trans = match trans {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };

                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe {
                    $xtrtrs(
                        uplo.into_lapack_uplo_character(),
                        trans,
                        diag.into_lapack_diag_character(),
                        n,
                        nrhs,
                        a,
                        lda,
                        b,
                        ldb,
                        &mut info,
                    );
                }

                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xlapmt(
                forwrd: bool,
                m: i32,
                n: i32,
                x: &mut [Self],
                ldx: i32,
                k: &mut [i32],
            ) -> Result<(), LapackErrorCode> {
                debug_assert_eq!(k.len(), n as usize);

                // The LOGICAL `forwrd` flag is passed as a one-element integer slice.
                let forward = [i32::from(forwrd)];
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xlapmt(forward.as_slice(), m, n, x, ldx, k) }
                // This routine has no `info` output, so there is nothing to check.
                Ok(())
            }

            #[inline]
            unsafe fn xlapmr(
                forwrd: bool,
                m: i32,
                n: i32,
                x: &mut [Self],
                ldx: i32,
                k: &mut [i32],
            ) -> Result<(), LapackErrorCode> {
                debug_assert_eq!(k.len(), m as usize);

                // The LOGICAL `forwrd` flag is passed as a one-element integer slice.
                let forward = [i32::from(forwrd)];
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xlapmr(forward.as_slice(), m, n, x, ldx, k) }
                // This routine has no `info` output, so there is nothing to check.
                Ok(())
            }
        }
    )
);
366
367pub trait QrReal: QrScalar {
370 #[allow(missing_docs)]
371 unsafe fn xorgqr(
372 m: i32,
373 n: i32,
374 k: i32,
375 a: &mut [Self],
376 lda: i32,
377 tau: &[Self],
378 work: &mut [Self],
379 lwork: i32,
380 ) -> Result<(), LapackErrorCode>;
381
382 #[allow(missing_docs)]
383 unsafe fn xorgqr_work_size(
384 m: i32,
385 n: i32,
386 k: i32,
387 a: &mut [Self],
388 lda: i32,
389 tau: &[Self],
390 ) -> Result<i32, LapackErrorCode>;
391
392 #[allow(missing_docs)]
393 unsafe fn xormqr(
394 side: Side,
395 trans: Transposition,
396 m: i32,
397 n: i32,
398 k: i32,
399 a: &[Self],
400 lda: i32,
401 tau: &[Self],
402 c: &mut [Self],
403 ldc: i32,
404 work: &mut [Self],
405 lwork: i32,
406 ) -> Result<(), LapackErrorCode>;
407
408 #[allow(missing_docs)]
409 unsafe fn xormqr_work_size(
410 side: Side,
411 trans: Transposition,
412 m: i32,
413 n: i32,
414 k: i32,
415 a: &[Self],
416 lda: i32,
417 tau: &[Self],
418 c: &mut [Self],
419 ldc: i32,
420 ) -> Result<i32, LapackErrorCode>;
421
422 unsafe fn xtrmm(
424 side: Side,
425 uplo: TriangularStructure,
426 transa: Transposition,
427 diag: DiagonalKind,
428 m: i32,
429 n: i32,
430 alpha: Self,
431 a: &[Self],
432 lda: i32,
433 b: &mut [Self],
434 ldb: i32,
435 );
436}
437
// Implements `QrReal` for `$type` by forwarding to the given LAPACK/BLAS
// routines. A trailing comma after the last argument is accepted, matching
// `qr_scalar_impl!`.
macro_rules! qr_real_impl(
    ($type:ty,
        xorgqr = $xorgqr:path,
        xormqr = $xormqr:path,
        xtrmm = $xtrmm:path $(,)?) => (
        impl QrReal for $type {
            #[inline]
            unsafe fn xorgqr(
                m: i32,
                n: i32,
                k: i32,
                a: &mut [Self],
                lda: i32,
                tau: &[Self],
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xorgqr(m, n, k, a, lda, tau, work, lwork, &mut info) }
                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xorgqr_work_size(
                m: i32,
                n: i32,
                k: i32,
                a: &mut [Self],
                lda: i32,
                tau: &[Self],
            ) -> Result<i32, LapackErrorCode> {
                let mut info = 0;
                let mut work = [Zero::zero()];
                // Workspace query (`lwork = -1`); the optimal length lands in
                // `work[0]`.
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe { $xorgqr(m, n, k, a, lda, tau, &mut work, -1, &mut info) }
                check_lapack_info(info)?;
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            #[inline]
            unsafe fn xormqr(
                side: Side,
                trans: Transposition,
                m: i32,
                n: i32,
                k: i32,
                a: &[Self],
                lda: i32,
                tau: &[Self],
                c: &mut [Self],
                ldc: i32,
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                let side = side.into_lapack_side_character();
                let trans = match trans {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };

                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe {
                    $xormqr(
                        side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, &mut info,
                    );
                }
                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xormqr_work_size(
                side: Side,
                trans: Transposition,
                m: i32,
                n: i32,
                k: i32,
                a: &[Self],
                lda: i32,
                tau: &[Self],
                c: &mut [Self],
                ldc: i32,
            ) -> Result<i32, LapackErrorCode> {
                let mut info = 0;
                let side = side.into_lapack_side_character();
                let trans = match trans {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };

                let mut work = [Zero::zero()];
                // Workspace query (`lwork = -1`); the optimal length lands in
                // `work[0]`.
                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe {
                    $xormqr(
                        side, trans, m, n, k, a, lda, tau, c, ldc, &mut work, -1, &mut info,
                    );
                }
                check_lapack_info(info)?;
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            #[inline]
            unsafe fn xtrmm(
                side: Side,
                uplo: TriangularStructure,
                transa: Transposition,
                diag: DiagonalKind,
                m: i32,
                n: i32,
                alpha: Self,
                a: &[Self],
                lda: i32,
                b: &mut [Self],
                ldb: i32,
            ) {
                let transa = match transa {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };

                // SAFETY: arguments are forwarded verbatim; the caller of this
                // `unsafe fn` guarantees the buffers fit the given dimensions.
                unsafe {
                    $xtrmm(
                        side.into_lapack_side_character(),
                        uplo.into_lapack_uplo_character(),
                        transa,
                        diag.into_lapack_diag_character(),
                        m,
                        n,
                        alpha,
                        a,
                        lda,
                        b,
                        ldb,
                    )
                }
            }
        }
    )
);
560
561qr_scalar_impl!(
562 f32,
563 xgeqrf = lapack::sgeqrf,
564 xgeqp3 = lapack::sgeqp3,
565 xtrtrs = lapack::strtrs,
566 xlapmt = lapack::slapmt,
567 xlapmr = lapack::slapmr
568);
569
570qr_scalar_impl!(
571 f64,
572 xgeqrf = lapack::dgeqrf,
573 xgeqp3 = lapack::dgeqp3,
574 xtrtrs = lapack::dtrtrs,
575 xlapmt = lapack::dlapmt,
576 xlapmr = lapack::dlapmr
577);
578
579qr_real_impl!(
580 f32,
581 xorgqr = lapack::sorgqr,
582 xormqr = lapack::sormqr,
583 xtrmm = blas::strmm
584);
585qr_real_impl!(
586 f64,
587 xorgqr = lapack::dorgqr,
588 xormqr = lapack::dormqr,
589 xtrmm = blas::dtrmm
590);