1#[cfg(feature = "serde-serialize")]
2use serde::{Deserialize, Serialize};
3
4use num::{One, Zero};
5use num_complex::Complex;
6
7use crate::ComplexHelper;
8use na::allocator::Allocator;
9use na::dimension::{Const, Dim, DimMin, DimMinimum};
10use na::storage::Storage;
11use na::{DefaultAllocator, Matrix, OMatrix, OVector, Scalar};
12
13use lapack;
14
15#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
24#[cfg_attr(
25 feature = "serde-serialize",
26 serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
27 Allocator<DimMinimum<R, C>>,
28 OMatrix<T, R, C>: Serialize,
29 OVector<i32, DimMinimum<R, C>>: Serialize"))
30)]
31#[cfg_attr(
32 feature = "serde-serialize",
33 serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
34 Allocator<DimMinimum<R, C>>,
35 OMatrix<T, R, C>: Deserialize<'de>,
36 OVector<i32, DimMinimum<R, C>>: Deserialize<'de>"))
37)]
38#[derive(Clone, Debug)]
39pub struct LU<T: Scalar, R: DimMin<C>, C: Dim>
40where
41 DefaultAllocator: Allocator<DimMinimum<R, C>> + Allocator<R, C>,
42{
43 lu: OMatrix<T, R, C>,
44 p: OVector<i32, DimMinimum<R, C>>,
45 singular: bool,
48}
49
50impl<T: Scalar + Copy, R: DimMin<C>, C: Dim> Copy for LU<T, R, C>
51where
52 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
53 OMatrix<T, R, C>: Copy,
54 OVector<i32, DimMinimum<R, C>>: Copy,
55{
56}
57
58impl<T: LUScalar, R: Dim, C: Dim> LU<T, R, C>
59where
60 T: Zero + One,
61 R: DimMin<C>,
62 DefaultAllocator: Allocator<R, C>
63 + Allocator<R, R>
64 + Allocator<R, DimMinimum<R, C>>
65 + Allocator<DimMinimum<R, C>, C>
66 + Allocator<DimMinimum<R, C>>,
67{
68 pub fn new(mut m: OMatrix<T, R, C>) -> Self {
70 let (nrows, ncols) = m.shape_generic();
71 let min_nrows_ncols = nrows.min(ncols);
72 let nrows = nrows.value() as i32;
73 let ncols = ncols.value() as i32;
74
75 let mut ipiv: OVector<i32, _> = Matrix::zeros_generic(min_nrows_ncols, Const::<1>);
76
77 let mut info = 0;
78
79 T::xgetrf(
80 nrows,
81 ncols,
82 m.as_mut_slice(),
83 nrows,
84 ipiv.as_mut_slice(),
85 &mut info,
86 );
87
88 if info < 0 {
103 lapack_panic!(info);
104 }
105
106 Self {
107 lu: m,
108 p: ipiv,
109 singular: info > 0,
110 }
111 }
112
113 #[inline]
115 #[must_use]
116 pub fn l(&self) -> OMatrix<T, R, DimMinimum<R, C>> {
117 let (nrows, ncols) = self.lu.shape_generic();
118 let mut res = self.lu.columns_generic(0, nrows.min(ncols)).into_owned();
119
120 res.fill_upper_triangle(Zero::zero(), 1);
121 res.fill_diagonal(One::one());
122
123 res
124 }
125
126 #[inline]
128 #[must_use]
129 pub fn u(&self) -> OMatrix<T, DimMinimum<R, C>, C> {
130 let (nrows, ncols) = self.lu.shape_generic();
131 let mut res = self.lu.rows_generic(0, nrows.min(ncols)).into_owned();
132
133 res.fill_lower_triangle(Zero::zero(), 1);
134
135 res
136 }
137
138 #[inline]
143 #[must_use]
144 pub fn p(&self) -> OMatrix<T, R, R> {
145 let (dim, _) = self.lu.shape_generic();
146 let mut id = Matrix::identity_generic(dim, dim);
147 self.permute(&mut id);
148
149 id
150 }
151
152 #[inline]
157 #[must_use]
158 pub fn permutation_indices(&self) -> &OVector<i32, DimMinimum<R, C>> {
159 &self.p
160 }
161
162 #[inline]
164 pub fn permute<C2: Dim>(&self, rhs: &mut OMatrix<T, R, C2>)
165 where
166 DefaultAllocator: Allocator<R, C2>,
167 {
168 let (nrows, ncols) = rhs.shape();
169
170 T::xlaswp(
171 ncols as i32,
172 rhs.as_mut_slice(),
173 nrows as i32,
174 1,
175 self.p.len() as i32,
176 self.p.as_slice(),
177 -1,
178 );
179 }
180
181 fn generic_solve_mut<R2: Dim, C2: Dim>(&self, trans: u8, b: &mut OMatrix<T, R2, C2>) -> bool
182 where
183 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
184 {
185 if self.singular {
186 return false;
187 }
188
189 let dim = self.lu.nrows();
190
191 assert!(
192 self.lu.is_square(),
193 "Unable to solve a set of under/over-determined equations."
194 );
195 assert!(
196 b.nrows() == dim,
197 "The number of rows of `b` must be equal to the dimension of the matrix `a`."
198 );
199
200 let nrhs = b.ncols() as i32;
201 let lda = dim as i32;
202 let ldb = dim as i32;
203 let mut info = 0;
204
205 T::xgetrs(
206 trans,
207 dim as i32,
208 nrhs,
209 self.lu.as_slice(),
210 lda,
211 self.p.as_slice(),
212 b.as_mut_slice(),
213 ldb,
214 &mut info,
215 );
216 lapack_test!(info)
217 }
218
219 pub fn solve<R2: Dim, C2: Dim, S2>(
221 &self,
222 b: &Matrix<T, R2, C2, S2>,
223 ) -> Option<OMatrix<T, R2, C2>>
224 where
225 S2: Storage<T, R2, C2>,
226 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
227 {
228 let mut res = b.clone_owned();
229 if self.generic_solve_mut(b'N', &mut res) {
230 Some(res)
231 } else {
232 None
233 }
234 }
235
236 pub fn solve_transpose<R2: Dim, C2: Dim, S2>(
239 &self,
240 b: &Matrix<T, R2, C2, S2>,
241 ) -> Option<OMatrix<T, R2, C2>>
242 where
243 S2: Storage<T, R2, C2>,
244 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
245 {
246 let mut res = b.clone_owned();
247 if self.generic_solve_mut(b'T', &mut res) {
248 Some(res)
249 } else {
250 None
251 }
252 }
253
254 pub fn solve_conjugate_transpose<R2: Dim, C2: Dim, S2>(
257 &self,
258 b: &Matrix<T, R2, C2, S2>,
259 ) -> Option<OMatrix<T, R2, C2>>
260 where
261 S2: Storage<T, R2, C2>,
262 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
263 {
264 let mut res = b.clone_owned();
265 if self.generic_solve_mut(b'C', &mut res) {
266 Some(res)
267 } else {
268 None
269 }
270 }
271
272 pub fn solve_mut<R2: Dim, C2: Dim>(&self, b: &mut OMatrix<T, R2, C2>) -> bool
276 where
277 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
278 {
279 self.generic_solve_mut(b'N', b)
280 }
281
282 pub fn solve_transpose_mut<R2: Dim, C2: Dim>(&self, b: &mut OMatrix<T, R2, C2>) -> bool
287 where
288 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
289 {
290 self.generic_solve_mut(b'T', b)
291 }
292
293 pub fn solve_adjoint_mut<R2: Dim, C2: Dim>(&self, b: &mut OMatrix<T, R2, C2>) -> bool
298 where
299 DefaultAllocator: Allocator<R2, C2> + Allocator<R2>,
300 {
301 self.generic_solve_mut(b'C', b)
302 }
303}
304
305impl<T: LUScalar, D: Dim> LU<T, D, D>
306where
307 T: Zero + One,
308 D: DimMin<D, Output = D>,
309 DefaultAllocator: Allocator<D, D> + Allocator<D>,
310{
311 pub fn inverse(mut self) -> Option<OMatrix<T, D, D>> {
313 if self.singular {
314 return None;
315 }
316
317 let dim = self.lu.nrows() as i32;
318 let mut info = 0;
319 let lwork = T::xgetri_work_size(
320 dim,
321 self.lu.as_mut_slice(),
322 dim,
323 self.p.as_mut_slice(),
324 &mut info,
325 );
326 lapack_check!(info);
327
328 let mut work = vec![T::zero(); lwork as usize];
329
330 T::xgetri(
331 dim,
332 self.lu.as_mut_slice(),
333 dim,
334 self.p.as_mut_slice(),
335 &mut work,
336 lwork,
337 &mut info,
338 );
339 lapack_check!(info);
340
341 Some(self.lu)
342 }
343}
344
345pub trait LUScalar: Scalar + Copy {
352 #[allow(missing_docs)]
353 fn xgetrf(m: i32, n: i32, a: &mut [Self], lda: i32, ipiv: &mut [i32], info: &mut i32);
354 #[allow(missing_docs)]
355 fn xlaswp(n: i32, a: &mut [Self], lda: i32, k1: i32, k2: i32, ipiv: &[i32], incx: i32);
356 #[allow(missing_docs)]
357 fn xgetrs(
358 trans: u8,
359 n: i32,
360 nrhs: i32,
361 a: &[Self],
362 lda: i32,
363 ipiv: &[i32],
364 b: &mut [Self],
365 ldb: i32,
366 info: &mut i32,
367 );
368 #[allow(missing_docs)]
369 fn xgetri(
370 n: i32,
371 a: &mut [Self],
372 lda: i32,
373 ipiv: &[i32],
374 work: &mut [Self],
375 lwork: i32,
376 info: &mut i32,
377 );
378 #[allow(missing_docs)]
379 fn xgetri_work_size(n: i32, a: &mut [Self], lda: i32, ipiv: &[i32], info: &mut i32) -> i32;
380}
381
/// Implements [`LUScalar`] for `$N` by forwarding each method to the given LAPACK bindings.
macro_rules! lup_scalar_impl(
    ($N: ty, $xgetrf: path, $xlaswp: path, $xgetrs: path, $xgetri: path) => (
        impl LUScalar for $N {
            // SAFETY (all methods below): these are thin FFI forwards. LAPACK trusts the
            // dimension / leading-dimension arguments to describe the slices it is handed.
            // NOTE(review): the wrappers do not validate slice lengths against those
            // dimensions, so soundness relies on callers (e.g. `LU::new`) passing values
            // derived from the owning matrix's shape.
            #[inline]
            fn xgetrf(m: i32, n: i32, a: &mut [Self], lda: i32, ipiv: &mut [i32], info: &mut i32) {
                unsafe { $xgetrf(m, n, a, lda, ipiv, info) }
            }

            #[inline]
            fn xlaswp(n: i32, a: &mut [Self], lda: i32, k1: i32, k2: i32, ipiv: &[i32], incx: i32) {
                unsafe { $xlaswp(n, a, lda, k1, k2, ipiv, incx) }
            }

            #[inline]
            fn xgetrs(trans: u8, n: i32, nrhs: i32, a: &[Self], lda: i32, ipiv: &[i32],
                      b: &mut [Self], ldb: i32, info: &mut i32) {
                unsafe { $xgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb, info) }
            }

            #[inline]
            fn xgetri(n: i32, a: &mut [Self], lda: i32, ipiv: &[i32],
                      work: &mut [Self], lwork: i32, info: &mut i32) {
                unsafe { $xgetri(n, a, lda, ipiv, work, lwork, info) }
            }

            #[inline]
            fn xgetri_work_size(n: i32, a: &mut [Self], lda: i32, ipiv: &[i32], info: &mut i32) -> i32 {
                // `lwork == -1` turns `?getri` into a workspace-size query. The optimal size is
                // written to `work[0]` (as a real or complex scalar), so one element suffices.
                let mut work = [Self::zero()];
                let lwork = -1;

                unsafe { $xgetri(n, a, lda, ipiv, &mut work, lwork, info); }
                ComplexHelper::real_part(work[0]) as i32
            }
        }
    )
);
418
419lup_scalar_impl!(
420 f32,
421 lapack::sgetrf,
422 lapack::slaswp,
423 lapack::sgetrs,
424 lapack::sgetri
425);
426lup_scalar_impl!(
427 f64,
428 lapack::dgetrf,
429 lapack::dlaswp,
430 lapack::dgetrs,
431 lapack::dgetri
432);
433lup_scalar_impl!(
434 Complex<f32>,
435 lapack::cgetrf,
436 lapack::claswp,
437 lapack::cgetrs,
438 lapack::cgetri
439);
440lup_scalar_impl!(
441 Complex<f64>,
442 lapack::zgetrf,
443 lapack::zlaswp,
444 lapack::zgetrs,
445 lapack::zgetri
446);