mdarray_linalg_lapack/lu/context.rs

//! LU decomposition:
//!     P * A = L * U
//! where:
//!     - A is m × n (input matrix)
//!     - P is m × m (permutation matrix; LAPACK reports it as a pivot vector, which is
//!       expanded into a dense matrix here)
//!     - L is m × min(m,n) (lower triangular, or lower trapezoidal if m > n, with unit diagonal)
//!     - U is min(m,n) × n (upper triangular, or upper trapezoidal if m < n)
//! This decomposition is used to solve linear systems, compute determinants, and invert matrices.
//! LAPACK's `getrf` computes the LU factorization of a general m × n matrix A using partial
//! pivoting (row interchanges).
//! This module also provides the Cholesky factorization (LAPACK `potrf`) of positive-definite
//! matrices.
11use super::simple::{getrf, getri, potrf};
12use mdarray_linalg::{get_dims, ipiv_to_perm_mat, transpose_in_place};
13
14use super::scalar::{LapackScalar, Workspace};
15use mdarray::{DSlice, DTensor, Dense, Layout, tensor};
16use mdarray_linalg::into_i32;
17use mdarray_linalg::lu::{InvError, InvResult, LU};
18use num_complex::ComplexFloat;
19
20use crate::Lapack;
21
22impl<T> LU<T> for Lapack
23where
24    T: ComplexFloat + Default + LapackScalar + Workspace,
25    T::Real: Into<T>,
26{
27    fn lu_overwrite<L: Layout, Ll: Layout, Lu: Layout, Lp: Layout>(
28        &self,
29        a: &mut DSlice<T, 2, L>,
30        l: &mut DSlice<T, 2, Ll>,
31        u: &mut DSlice<T, 2, Lu>,
32        p: &mut DSlice<T, 2, Lp>,
33    ) {
34        let (m, _) = get_dims!(a);
35        let ipiv = getrf(a, l, u);
36
37        let p_matrix = ipiv_to_perm_mat::<T>(&ipiv, m as usize);
38
39        for i in 0..(m as usize) {
40            for j in 0..(m as usize) {
41                p[[i, j]] = p_matrix[[i, j]];
42            }
43        }
44    }
45
46    fn lu<L: Layout>(
47        &self,
48        a: &mut DSlice<T, 2, L>,
49    ) -> (DTensor<T, 2>, DTensor<T, 2>, DTensor<T, 2>) {
50        let (m, n) = get_dims!(a);
51        let min_mn = m.min(n);
52        let mut l = tensor![[T::default(); min_mn as usize]; m as usize];
53        let mut u = tensor![[T::default(); n as usize]; min_mn as usize];
54        let ipiv = getrf::<_, Dense, Dense, T>(a, &mut l, &mut u);
55
56        let p_matrix = ipiv_to_perm_mat::<T>(&ipiv, m as usize);
57
58        (l, u, p_matrix)
59    }
60
61    fn inv_overwrite<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
62        let (m, n) = get_dims!(a);
63        if m != n {
64            return Err(InvError::NotSquare { rows: m, cols: n });
65        }
66
67        let min_mn = m.min(n);
68        let mut l = DTensor::<T, 2>::zeros([m as usize, min_mn as usize]);
69        let mut u = DTensor::<T, 2>::zeros([min_mn as usize, n as usize]);
70        let mut ipiv = getrf::<_, Dense, Dense, T>(a, &mut l, &mut u);
71
72        match getri::<_, T>(a, &mut ipiv) {
73            0 => Ok(()),
74            i if i > 0 => Err(InvError::Singular { pivot: i }),
75            i => Err(InvError::BackendError(i)),
76        }
77    }
78
79    fn inv<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> InvResult<T> {
80        let (m, n) = get_dims!(a);
81        if m != n {
82            return Err(InvError::NotSquare { rows: m, cols: n });
83        }
84
85        let mut a_inv = DTensor::<T, 2>::zeros([n as usize, n as usize]);
86        for i in 0..n as usize {
87            for j in 0..m as usize {
88                a_inv[[i, j]] = a[[i, j]];
89            }
90        }
91
92        let min_mn = m.min(n);
93        let mut l = DTensor::<T, 2>::zeros([m as usize, min_mn as usize]);
94        let mut u = DTensor::<T, 2>::zeros([min_mn as usize, n as usize]);
95        let mut ipiv = getrf::<_, Dense, Dense, T>(&mut a_inv, &mut l, &mut u);
96
97        match getri::<_, T>(&mut a_inv, &mut ipiv) {
98            0 => Ok(a_inv),
99            i if i > 0 => Err(InvError::Singular { pivot: i }),
100            i => Err(InvError::BackendError(i)),
101        }
102    }
103
104    fn det<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> T {
105        let (m, n) = get_dims!(a);
106        assert_eq!(m, n, "determinant is only defined for square matrices");
107
108        let mut l = tensor![[T::default(); n as usize]; n as usize];
109        let mut u = tensor![[T::default(); n as usize]; n as usize];
110
111        let ipiv = getrf::<_, Dense, Dense, T>(a, &mut l, &mut u);
112
113        let mut det = T::one();
114        for i in 0..n as usize {
115            det = det * u[[i, i]];
116        }
117
118        let mut sign = T::one();
119        for (i, &pivot) in ipiv.iter().enumerate() {
120            if (i as i32) != (pivot - 1) {
121                sign = sign * (-T::one());
122            }
123        }
124        det * sign
125    }
126
127    /// Computes the Cholesky decomposition, returning a lower-triangular matrix
128    fn choleski<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> InvResult<T> {
129        let (m, n) = get_dims!(a);
130        assert_eq!(m, n, "Matrix must be square for Cholesky decomposition");
131
132        let mut l = DTensor::<T, 2>::zeros([m as usize, n as usize]);
133
134        match potrf::<_, T>(a, 'L') {
135            0 => {
136                for i in 0..(m as usize) {
137                    for j in 0..(n as usize) {
138                        if i >= j {
139                            l[[i, j]] = a[[j, i]];
140                        } else {
141                            l[[i, j]] = T::zero();
142                        }
143                    }
144                }
145                Ok(l)
146            }
147            i if i > 0 => Err(InvError::NotPositiveDefinite { lpm: i }),
148            i => Err(InvError::BackendError(i)),
149        }
150    }
151
152    /// Computes the Cholesky decomposition in-place, overwriting the input matrix
153    fn choleski_overwrite<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
154        let (m, n) = get_dims!(a);
155        assert_eq!(m, n, "Matrix must be square for Cholesky decomposition");
156
157        match potrf::<_, T>(a, 'L') {
158            0 => {
159                transpose_in_place(a);
160                Ok(())
161            }
162            i if i > 0 => Err(InvError::NotPositiveDefinite { lpm: i }),
163            i => Err(InvError::BackendError(i)),
164        }
165    }
166}