mdarray_linalg_lapack/lu/
context.rs1use super::simple::{getrf, getri, potrf};
12use mdarray_linalg::{get_dims, ipiv_to_perm_mat, transpose_in_place};
13
14use super::scalar::{LapackScalar, Workspace};
15use mdarray::{DSlice, DTensor, Dense, Layout, tensor};
16use mdarray_linalg::into_i32;
17use mdarray_linalg::lu::{InvError, InvResult, LU};
18use num_complex::ComplexFloat;
19
20use crate::Lapack;
21
22impl<T> LU<T> for Lapack
23where
24 T: ComplexFloat + Default + LapackScalar + Workspace,
25 T::Real: Into<T>,
26{
27 fn lu_overwrite<L: Layout, Ll: Layout, Lu: Layout, Lp: Layout>(
28 &self,
29 a: &mut DSlice<T, 2, L>,
30 l: &mut DSlice<T, 2, Ll>,
31 u: &mut DSlice<T, 2, Lu>,
32 p: &mut DSlice<T, 2, Lp>,
33 ) {
34 let (m, _) = get_dims!(a);
35 let ipiv = getrf(a, l, u);
36
37 let p_matrix = ipiv_to_perm_mat::<T>(&ipiv, m as usize);
38
39 for i in 0..(m as usize) {
40 for j in 0..(m as usize) {
41 p[[i, j]] = p_matrix[[i, j]];
42 }
43 }
44 }
45
46 fn lu<L: Layout>(
47 &self,
48 a: &mut DSlice<T, 2, L>,
49 ) -> (DTensor<T, 2>, DTensor<T, 2>, DTensor<T, 2>) {
50 let (m, n) = get_dims!(a);
51 let min_mn = m.min(n);
52 let mut l = tensor![[T::default(); min_mn as usize]; m as usize];
53 let mut u = tensor![[T::default(); n as usize]; min_mn as usize];
54 let ipiv = getrf::<_, Dense, Dense, T>(a, &mut l, &mut u);
55
56 let p_matrix = ipiv_to_perm_mat::<T>(&ipiv, m as usize);
57
58 (l, u, p_matrix)
59 }
60
61 fn inv_overwrite<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
62 let (m, n) = get_dims!(a);
63 if m != n {
64 return Err(InvError::NotSquare { rows: m, cols: n });
65 }
66
67 let min_mn = m.min(n);
68 let mut l = DTensor::<T, 2>::zeros([m as usize, min_mn as usize]);
69 let mut u = DTensor::<T, 2>::zeros([min_mn as usize, n as usize]);
70 let mut ipiv = getrf::<_, Dense, Dense, T>(a, &mut l, &mut u);
71
72 match getri::<_, T>(a, &mut ipiv) {
73 0 => Ok(()),
74 i if i > 0 => Err(InvError::Singular { pivot: i }),
75 i => Err(InvError::BackendError(i)),
76 }
77 }
78
79 fn inv<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> InvResult<T> {
80 let (m, n) = get_dims!(a);
81 if m != n {
82 return Err(InvError::NotSquare { rows: m, cols: n });
83 }
84
85 let mut a_inv = DTensor::<T, 2>::zeros([n as usize, n as usize]);
86 for i in 0..n as usize {
87 for j in 0..m as usize {
88 a_inv[[i, j]] = a[[i, j]];
89 }
90 }
91
92 let min_mn = m.min(n);
93 let mut l = DTensor::<T, 2>::zeros([m as usize, min_mn as usize]);
94 let mut u = DTensor::<T, 2>::zeros([min_mn as usize, n as usize]);
95 let mut ipiv = getrf::<_, Dense, Dense, T>(&mut a_inv, &mut l, &mut u);
96
97 match getri::<_, T>(&mut a_inv, &mut ipiv) {
98 0 => Ok(a_inv),
99 i if i > 0 => Err(InvError::Singular { pivot: i }),
100 i => Err(InvError::BackendError(i)),
101 }
102 }
103
104 fn det<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> T {
105 let (m, n) = get_dims!(a);
106 assert_eq!(m, n, "determinant is only defined for square matrices");
107
108 let mut l = tensor![[T::default(); n as usize]; n as usize];
109 let mut u = tensor![[T::default(); n as usize]; n as usize];
110
111 let ipiv = getrf::<_, Dense, Dense, T>(a, &mut l, &mut u);
112
113 let mut det = T::one();
114 for i in 0..n as usize {
115 det = det * u[[i, i]];
116 }
117
118 let mut sign = T::one();
119 for (i, &pivot) in ipiv.iter().enumerate() {
120 if (i as i32) != (pivot - 1) {
121 sign = sign * (-T::one());
122 }
123 }
124 det * sign
125 }
126
127 fn choleski<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> InvResult<T> {
129 let (m, n) = get_dims!(a);
130 assert_eq!(m, n, "Matrix must be square for Cholesky decomposition");
131
132 let mut l = DTensor::<T, 2>::zeros([m as usize, n as usize]);
133
134 match potrf::<_, T>(a, 'L') {
135 0 => {
136 for i in 0..(m as usize) {
137 for j in 0..(n as usize) {
138 if i >= j {
139 l[[i, j]] = a[[j, i]];
140 } else {
141 l[[i, j]] = T::zero();
142 }
143 }
144 }
145 Ok(l)
146 }
147 i if i > 0 => Err(InvError::NotPositiveDefinite { lpm: i }),
148 i => Err(InvError::BackendError(i)),
149 }
150 }
151
152 fn choleski_overwrite<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
154 let (m, n) = get_dims!(a);
155 assert_eq!(m, n, "Matrix must be square for Cholesky decomposition");
156
157 match potrf::<_, T>(a, 'L') {
158 0 => {
159 transpose_in_place(a);
160 Ok(())
161 }
162 i if i > 0 => Err(InvError::NotPositiveDefinite { lpm: i }),
163 i => Err(InvError::BackendError(i)),
164 }
165 }
166}