1use crate::{DiagonalKind, LapackErrorCode, Side, Transposition, TriangularStructure, qr::QrReal};
2use na::{
3 Dim, DimMin, DimMinimum, IsContiguous, Matrix, RawStorage, RawStorageMut, RealField, Vector,
4};
5use num::{ConstOne, Zero};
6
7#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum Error {
10 #[error("incorrect matrix dimensions")]
12 Dimensions,
13 #[error("Lapack returned with error: {0}")]
15 Lapack(#[from] LapackErrorCode),
16 #[error("QR decomposition for underdetermined systems not supported")]
18 Underdetermined,
19 #[error("Matrix has rank zero")]
21 ZeroRank,
22}
23
24pub(crate) fn q_mul_mut<T, R1, C1, S1, C2, S2, S3>(
34 qr: &Matrix<T, R1, C1, S1>,
35 tau: &Vector<T, DimMinimum<R1, C1>, S3>,
36 b: &mut Matrix<T, R1, C2, S2>,
37) -> Result<(), Error>
38where
39 T: QrReal + Zero + RealField,
40 R1: DimMin<C1>,
41 C1: Dim,
42 S1: RawStorage<T, R1, C1> + IsContiguous,
43 C2: Dim,
44 S2: RawStorageMut<T, R1, C2> + IsContiguous,
45 S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
46{
47 if b.nrows() != qr.nrows() {
48 return Err(Error::Dimensions);
49 }
50 if qr.ncols().min(qr.nrows()) != tau.len() {
51 return Err(Error::Dimensions);
52 }
53 unsafe { multiply_q_mut(qr, tau, b, Side::Left, Transposition::No)? };
55 Ok(())
56}
57
58pub(crate) fn q_tr_mul_mut<T, R1, C1, S1, C2, S2, S3>(
68 qr: &Matrix<T, R1, C1, S1>,
69 tau: &Vector<T, DimMinimum<R1, C1>, S3>,
70 b: &mut Matrix<T, R1, C2, S2>,
71) -> Result<(), Error>
72where
73 T: QrReal + Zero + RealField,
74 R1: DimMin<C1>,
75 C1: Dim,
76 S1: RawStorage<T, R1, C1> + IsContiguous,
77 C2: Dim,
78 C2: Dim,
79 S2: RawStorageMut<T, R1, C2> + IsContiguous,
80 S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
81{
82 if b.nrows() != qr.nrows() {
83 return Err(Error::Dimensions);
84 }
85 if qr.ncols().min(qr.nrows()) != tau.len() {
86 return Err(Error::Dimensions);
87 }
88 unsafe { multiply_q_mut(qr, tau, b, Side::Left, Transposition::Transpose)? };
90 Ok(())
91}
92
93pub(crate) fn mul_q_mut<T, R1, C1, S1, R2, S2, S3>(
103 qr: &Matrix<T, R1, C1, S1>,
104 tau: &Vector<T, DimMinimum<R1, C1>, S3>,
105 b: &mut Matrix<T, R2, R1, S2>,
106) -> Result<(), Error>
107where
108 T: QrReal + Zero + RealField,
109 R1: DimMin<C1>,
110 C1: Dim,
111 S1: RawStorage<T, R1, C1> + IsContiguous,
112 R2: Dim,
113 S2: RawStorageMut<T, R2, R1> + IsContiguous,
114 S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
115{
116 if b.ncols() != qr.nrows() {
117 return Err(Error::Dimensions);
118 }
119 if qr.ncols().min(qr.nrows()) != tau.len() {
120 return Err(Error::Dimensions);
121 }
122 unsafe { multiply_q_mut(qr, tau, b, Side::Right, Transposition::No)? };
124 Ok(())
125}
126
127pub(crate) fn mul_q_tr_mut<T, R1, C1, S1, R2, S2, S3>(
137 qr: &Matrix<T, R1, C1, S1>,
138 tau: &Vector<T, DimMinimum<R1, C1>, S3>,
139 b: &mut Matrix<T, R2, R1, S2>,
140) -> Result<(), Error>
141where
142 T: QrReal + Zero + RealField,
143 R1: DimMin<C1>,
144 C1: Dim,
145 S1: RawStorage<T, R1, C1> + IsContiguous,
146 R2: Dim,
147 S2: RawStorageMut<T, R2, R1> + IsContiguous,
148 S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
149{
150 if b.ncols() != qr.nrows() {
151 return Err(Error::Dimensions);
152 }
153 if qr.ncols().min(qr.nrows()) != tau.len() {
154 return Err(Error::Dimensions);
155 }
156 unsafe { multiply_q_mut(qr, tau, b, Side::Right, Transposition::Transpose)? }
158 Ok(())
159}
160
161pub(crate) fn qr_solve_mut_with_rank_unpermuted<T, R1, C1, S1, C2: Dim, S3, S2, S4>(
168 qr: &Matrix<T, R1, C1, S1>,
169 tau: &Vector<T, DimMinimum<R1, C1>, S4>,
170 rank: u16,
171 x: &mut Matrix<T, C1, C2, S2>,
172 mut b: Matrix<T, R1, C2, S3>,
173) -> Result<(), Error>
174where
175 T: QrReal + Zero + RealField,
176 R1: DimMin<C1>,
177 C1: Dim,
178 S1: RawStorage<T, R1, C1> + IsContiguous,
179 S3: RawStorageMut<T, R1, C2> + IsContiguous,
180 S2: RawStorageMut<T, C1, C2> + IsContiguous,
181 S4: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
182{
183 if b.nrows() != qr.nrows() {
184 return Err(Error::Dimensions);
185 }
186
187 if qr.nrows() < qr.ncols() || qr.nrows() == 0 || qr.ncols() == 0 {
188 return Err(Error::Underdetermined);
189 }
190
191 if x.ncols() != b.ncols() || x.nrows() != qr.ncols() {
192 return Err(Error::Dimensions);
193 }
194
195 q_tr_mul_mut(qr, tau, &mut b)?;
196
197 if rank == 0 {
198 return Err(Error::ZeroRank);
199 }
200
201 debug_assert!(rank as usize <= qr.ncols().min(qr.nrows()));
202
203 if (rank as usize) < qr.ncols() {
204 x.view_mut((rank as usize, 0), (x.nrows() - rank as usize, x.ncols()))
205 .iter_mut()
206 .for_each(|val| val.set_zero());
207 }
208
209 let x_cols = x.ncols();
210 x.view_mut((0, 0), (rank as usize, x_cols))
211 .copy_from(&b.view((0, 0), (rank as usize, x_cols)));
212
213 let ldb: i32 = x
214 .nrows()
215 .try_into()
216 .expect("integer dimensions out of bounds");
217
218 unsafe {
221 T::xtrtrs(
222 TriangularStructure::Upper,
223 Transposition::No,
224 DiagonalKind::NonUnit,
225 rank.try_into().expect("rank out of bounds"),
226 x.ncols()
227 .try_into()
228 .expect("integer dimensions out of bounds"),
229 qr.as_slice(),
230 qr.nrows()
231 .try_into()
232 .expect("integer dimensions out of bounds"),
233 x.as_mut_slice(),
234 ldb,
235 )?;
236 }
237
238 Ok(())
239}
240
241#[inline]
252unsafe fn multiply_q_mut<T, R1, C1, S1, R2, C2, S2, S3>(
253 qr: &Matrix<T, R1, C1, S1>,
254 tau: &Vector<T, DimMinimum<R1, C1>, S3>,
255 mat: &mut Matrix<T, R2, C2, S2>,
256 side: Side,
257 transpose: Transposition,
258) -> Result<(), Error>
259where
260 T: QrReal,
261 R1: DimMin<C1>,
262 C1: Dim,
263 S2: RawStorageMut<T, R2, C2> + IsContiguous,
264 R2: Dim,
265 C2: Dim,
266 S1: IsContiguous + RawStorage<T, R1, C1>,
267 S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
268{
269 let a = qr.as_slice();
270 let lda = qr
271 .nrows()
272 .try_into()
273 .expect("integer dimension out of range");
274 let m = mat
275 .nrows()
276 .try_into()
277 .expect("integer dimension out of range");
278 let n = mat
279 .ncols()
280 .try_into()
281 .expect("integer dimension out of range");
282 let k = tau
283 .len()
284 .try_into()
285 .expect("integer dimension out of range");
286 let ldc = mat
287 .nrows()
288 .try_into()
289 .expect("integer dimension out of range");
290 let c = mat.as_mut_slice();
291 let trans = transpose;
292 let tau = tau.as_slice();
293
294 if k as usize != qr.ncols() {
295 return Err(Error::Dimensions);
296 }
297
298 match side {
301 Side::Left => {
302 if m < k {
303 return Err(Error::Dimensions);
304 }
305
306 if lda < m {
307 return Err(Error::Dimensions);
308 }
309 }
310 Side::Right => {
311 if n < k {
312 return Err(Error::Dimensions);
313 }
314
315 if lda < n {
316 return Err(Error::Dimensions);
317 }
318 }
319 }
320
321 if ldc < m {
322 return Err(Error::Dimensions);
323 }
324
325 let lwork = unsafe { T::xormqr_work_size(side, transpose, m, n, k, a, lda, tau, c, ldc)? };
330 let mut work = vec![T::zero(); lwork as usize];
331
332 unsafe {
335 T::xormqr(side, trans, m, n, k, a, lda, tau, c, ldc, &mut work, lwork)?;
336 }
337 Ok(())
338}
339
340pub fn r_xx_mul_mut<T, R1, C1, S1, C2, S2>(
343 qr: &Matrix<T, R1, C1, S1>,
344 transpose: Transposition,
345 b: &mut Matrix<T, C1, C2, S2>,
346) -> Result<(), Error>
347where
348 T: QrReal + ConstOne,
349 R1: Dim,
350 C1: Dim,
351 C2: Dim,
352 S1: RawStorage<T, R1, C1> + IsContiguous,
353 S2: RawStorageMut<T, C1, C2> + IsContiguous,
354{
355 if qr.nrows() < qr.ncols() {
360 return Err(Error::Underdetermined);
361 }
362
363 if qr.ncols() != b.nrows() {
364 return Err(Error::Dimensions);
365 }
366
367 multiply_r_mut(qr, transpose, Side::Left, b)?;
368 Ok(())
369}
370
371pub fn mul_r_xx_mut<T, R1, C1, S1, R2, S2>(
374 qr: &Matrix<T, R1, C1, S1>,
375 transpose: Transposition,
376 b: &mut Matrix<T, R2, C1, S2>,
377) -> Result<(), Error>
378where
379 T: QrReal + ConstOne,
380 R1: Dim,
381 C1: Dim,
382 R2: Dim,
383 S1: RawStorage<T, R1, C1> + IsContiguous,
384 S2: RawStorageMut<T, R2, C1> + IsContiguous,
385{
386 if qr.nrows() < qr.ncols() {
391 return Err(Error::Underdetermined);
392 }
393
394 if b.ncols() != qr.ncols() {
395 return Err(Error::Dimensions);
396 }
397
398 multiply_r_mut(qr, transpose, Side::Right, b)?;
399 Ok(())
400}
401#[inline]
416fn multiply_r_mut<T, R1, C1, S1, R2, C2, S2>(
417 qr: &Matrix<T, R1, C1, S1>,
418 transpose: Transposition,
419 side: Side,
420 mat: &mut Matrix<T, R2, C2, S2>,
421) -> Result<(), Error>
422where
423 T: QrReal + ConstOne,
424 R1: Dim,
425 C1: Dim,
426 S2: RawStorageMut<T, R2, C2> + IsContiguous,
427 R2: Dim,
428 C2: Dim,
429 S1: IsContiguous + RawStorage<T, R1, C1>,
430{
431 let m: i32 = mat
432 .nrows()
433 .try_into()
434 .expect("integer dimensions out of bounds");
435 let n: i32 = mat
436 .ncols()
437 .try_into()
438 .expect("integer dimensions out of bounds");
439 let lda: i32 = qr
440 .nrows()
441 .try_into()
442 .expect("integer dimensions out of bounds");
443 let ldb: i32 = mat
444 .nrows()
445 .try_into()
446 .expect("integer dimensions out of bounds");
447
448 match side {
451 Side::Left => {
452 if lda == 0 || lda < m {
453 return Err(Error::Dimensions);
454 }
455 if qr.ncols() != m as usize {
456 return Err(Error::Dimensions);
457 }
458 }
459 Side::Right => {
460 if lda == 0 || lda < n {
461 return Err(Error::Dimensions);
462 }
463 if qr.ncols() != n as usize {
464 return Err(Error::Dimensions);
465 }
466 }
467 }
468
469 unsafe {
472 T::xtrmm(
473 side,
474 TriangularStructure::Upper,
475 transpose,
476 DiagonalKind::NonUnit,
477 m,
478 n,
479 T::ONE,
480 qr.as_slice(),
481 lda,
482 mat.as_mut_slice(),
483 ldb,
484 );
485 }
486 Ok(())
487}