1use crate::internal_prelude::*;
2use crate::perm::swap_rows_idx;
3use crate::{assert, debug_assert};
4
5#[math]
6#[inline]
7fn swap_elems<T: ComplexField>(col: ColMut<'_, T>, i: usize, j: usize) {
8 debug_assert!(all(i < col.nrows(), j < col.nrows()));
9 let rs = col.row_stride();
10 let col = col.as_ptr_mut();
11 unsafe {
12 let a = col.offset(i as isize * rs);
13 let b = col.offset(j as isize * rs);
14 core::ptr::swap(a, b);
15 }
16}
17
18#[math]
19fn lu_in_place_unblocked<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, start: usize, end: usize, trans: &mut [I]) -> usize {
20 let mut matrix = matrix;
21 let m = matrix.nrows();
22
23 if start == end {
24 return 0;
25 }
26
27 let mut n_trans = 0;
28
29 for j in start..end {
30 let col = j;
31 let row = j - start;
32
33 let t = &mut trans[row];
34 let mut imax = row;
35 let mut max = zero();
36
37 for i in imax..m {
38 let abs = abs1(matrix[(i, col)]);
39 if abs > max {
40 max = abs;
41 imax = i;
42 }
43 }
44
45 *t = I::truncate(imax - row);
46
47 if imax != row {
48 swap_rows_idx(matrix.rb_mut(), row, imax);
49 n_trans += 1;
50 }
51
52 let mut matrix = matrix.rb_mut().get_mut(.., start..end);
53
54 let inv = recip(matrix[(row, row)]);
55 for i in row + 1..m {
56 matrix[(i, row)] = matrix[(i, row)] * inv;
57 }
58
59 let (_, A01, A10, A11) = matrix.rb_mut().split_at_mut(row + 1, row + 1);
60 let A01 = A01.row(row);
61 let A10 = A10.col(row);
62 linalg::matmul::matmul(A11, Accum::Add, A10.as_mat(), A01.as_mat(), -one::<T>(), Par::Seq);
63 }
64
65 n_trans
66}
67
68#[math]
69pub(crate) fn lu_in_place_recursion<I: Index, T: ComplexField>(
70 A: MatMut<'_, T>,
71 start: usize,
72 end: usize,
73 trans: &mut [I],
74 par: Par,
75 params: Spec<PartialPivLuParams, T>,
76) -> usize {
77 let params = params.config;
78 let mut A = A;
79 let m = A.nrows();
80 let ncols = A.ncols();
81 let n = end - start;
82
83 if n <= params.recursion_threshold {
84 return lu_in_place_unblocked(A, start, end, trans);
85 }
86
87 let half = n / 2;
88 let pow = Ord::min(16, half.next_power_of_two());
89
90 let blocksize = half.next_multiple_of(pow);
91
92 let mut n_trans = 0;
93
94 assert!(n <= m);
95
96 n_trans += lu_in_place_recursion(
97 A.rb_mut().get_mut(.., start..end),
98 0,
99 blocksize,
100 &mut trans[..blocksize],
101 par,
102 params.into(),
103 );
104
105 {
106 let mut A = A.rb_mut().get_mut(.., start..end);
107 let (A00, mut A01, A10, mut A11) = A.rb_mut().split_at_mut(blocksize, blocksize);
108
109 let A00 = A00.rb();
110 let A10 = A10.rb();
111 {
112 linalg::triangular_solve::solve_unit_lower_triangular_in_place(A00.rb(), A01.rb_mut(), par);
113 }
114
115 linalg::matmul::matmul(A11.rb_mut(), Accum::Add, A10.rb(), A01.rb(), -one::<T>(), par);
116
117 n_trans += lu_in_place_recursion(
118 A.rb_mut().get_mut(blocksize..m, ..),
119 blocksize,
120 n,
121 &mut trans[blocksize..n],
122 par,
123 params.into(),
124 );
125 }
126
127 let swap = |mat: MatMut<'_, T>| {
128 let mut mat = mat;
129 for j in 0..mat.ncols() {
130 let mut col = mat.rb_mut().col_mut(j);
131
132 if col.row_stride() == 1 {
133 for (j, &t) in trans[..n].iter().enumerate() {
134 swap_elems(col.rb_mut(), j, t.zx() + j);
135 }
136 } else {
137 for (j, &t) in trans[..n].iter().enumerate() {
138 swap_elems(col.rb_mut(), j, t.zx() + j);
139 }
140 }
141 }
142 };
143
144 let (A_left, A_right) = A.rb_mut().split_at_col_mut(start);
145 let A_right = A_right.get_mut(.., end - start..ncols - start);
146
147 let par = if m * (ncols - n) > params.par_threshold { par } else { Par::Seq };
148
149 match par {
150 Par::Seq => {
151 swap(A_left);
152 swap(A_right);
153 },
154 #[cfg(feature = "rayon")]
155 Par::Rayon(nthreads) => {
156 let nthreads = nthreads.get();
157 let len = (A_left.ncols() + A_right.ncols()) as f64;
158 let left_threads = Ord::min((nthreads as f64 * (A_left.ncols() as f64 / len)) as usize, nthreads);
159 let right_threads = nthreads - left_threads;
160
161 use rayon::prelude::*;
162 rayon::join(
163 || {
164 if A_left.ncols() > 0 {
165 A_left.par_col_partition_mut(left_threads).for_each(|A| swap(A))
166 }
167 },
168 || {
169 if A_right.ncols() > 0 {
170 A_right.par_col_partition_mut(right_threads).for_each(|A| swap(A))
171 }
172 },
173 );
174 },
175 }
176
177 n_trans
178}
179
/// Tuning parameters for the LU factorization with partial pivoting.
#[derive(Copy, Clone, Debug)]
pub struct PartialPivLuParams {
    /// Panel width (in columns) at or below which the recursive factorization hands the panel
    /// to the unblocked kernel.
    pub recursion_threshold: usize,
    /// Blocking step size.
    ///
    /// NOTE(review): the recursive kernel in this file derives its block size from the panel
    /// width and never reads this field — confirm whether it is used elsewhere.
    pub blocksize: usize,
    /// The row exchanges applied to the columns outside of the current panel only run in
    /// parallel when `nrows * (ncols - panel_width)` exceeds this value.
    pub par_threshold: usize,

    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
193
/// Information reported by `lu_in_place` about the factorization it computed.
#[derive(Copy, Clone, Debug)]
pub struct PartialPivLuInfo {
    /// Number of row exchanges that were performed, i.e. the number of pivoting steps whose
    /// pivot row differed from the row being eliminated.
    pub transposition_count: usize,
}
201
/// Error describing a zero pivot.
///
/// NOTE(review): the name refers to LDLT although the type is declared next to the LU code,
/// and nothing visible in this file constructs or returns it — confirm whether it belongs here.
#[derive(Copy, Clone, Debug)]
pub enum LdltError {
    /// A zero pivot was found; `index` presumably identifies its position — confirm against
    /// the code that constructs this variant.
    ZeroPivot { index: usize },
}
207
208impl<T: ComplexField> Auto<T> for PartialPivLuParams {
209 #[inline]
210 fn auto() -> Self {
211 Self {
212 recursion_threshold: 16,
213 blocksize: 64,
214 par_threshold: 128 * 128,
215 non_exhaustive: NonExhaustive(()),
216 }
217 }
218}
219
220#[inline]
221pub fn lu_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par, params: Spec<PartialPivLuParams, T>) -> StackReq {
222 _ = par;
223 _ = params;
224 StackReq::new::<I>(Ord::min(nrows, ncols))
225}
226
227pub fn lu_in_place<'out, I: Index, T: ComplexField>(
228 A: MatMut<'_, T>,
229 perm: &'out mut [I],
230 perm_inv: &'out mut [I],
231 par: Par,
232 stack: &mut MemStack,
233 params: Spec<PartialPivLuParams, T>,
234) -> (PartialPivLuInfo, PermRef<'out, I>) {
235 let _ = ¶ms;
236 let truncate = I::truncate;
237
238 #[cfg(feature = "perf-warn")]
239 if (A.col_stride().unsigned_abs() == 1 || A.row_stride().unsigned_abs() != 1) && crate::__perf_warn!(LU_WARN) {
240 log::warn!(target: "faer_perf", "LU with partial pivoting prefers column-major or row-major matrix. Found matrix with generic strides.");
241 }
242
243 let mut matrix = A;
244 let mut stack = stack;
245 let m = matrix.nrows();
246 let n = matrix.ncols();
247
248 let size = Ord::min(n, m);
249
250 for i in 0..m {
251 let p = &mut perm[i];
252 *p = truncate(i);
253 }
254
255 let (mut transpositions, _) = stack.rb_mut().make_with(size, |_| truncate(0));
256 let transpositions = transpositions.as_mut();
257
258 let n_transpositions = lu_in_place_recursion(matrix.rb_mut(), 0, size, transpositions.as_mut(), par, params);
259
260 for idx in 0..size {
261 let t = transpositions[idx];
262 perm.as_mut().swap(idx, idx + t.zx());
263 }
264
265 if m < n {
266 let (left, right) = matrix.split_at_col_mut(size);
267 linalg::triangular_solve::solve_unit_lower_triangular_in_place(left.rb(), right, par);
268 }
269
270 for i in 0..m {
271 perm_inv[perm[i].zx()] = truncate(i);
272 }
273
274 (
275 PartialPivLuInfo {
276 transposition_count: n_transpositions,
277 },
278 unsafe { PermRef::new_unchecked(perm, perm_inv, m) },
279 )
280}
281
#[cfg(test)]
mod tests {
    use dyn_stack::MemBuffer;

    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Mat, assert};

    /// Factorizes random square and tall matrices and checks that `P⁻¹ · L · U` reproduces
    /// the input up to rounding.
    #[test]
    fn test_plu() {
        let rng = &mut StdRng::seed_from_u64(0);

        let approx_eq = CwiseMat(ApproxEq {
            abs_tol: 1e-13,
            rel_tol: 1e-13,
        });

        // square matrices, with a tiny recursion threshold so that the recursive path is
        // exercised even for small sizes
        for n in [1, 2, 3, 128, 255, 256, 257] {
            let A = CwiseMatDistribution {
                nrows: n,
                ncols: n,
                dist: StandardNormal,
            }
            .rand::<Mat<f64>>(rng);
            let A = A.as_ref();

            let mut LU = A.cloned();
            let perm = &mut *vec![0usize; n];
            let perm_inv = &mut *vec![0usize; n];

            let params = PartialPivLuParams {
                recursion_threshold: 2,
                blocksize: 2,
                ..auto!(f64)
            };
            let p = lu_in_place(
                LU.as_mut(),
                perm,
                perm_inv,
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, f64>(n, n, Par::Seq, params.into()))),
                params.into(),
            )
            .1;

            let mut L = LU.as_ref().cloned();
            let mut U = LU.as_ref().cloned();

            // L: unit diagonal, zeros above it
            for j in 0..n {
                for i in 0..j {
                    L[(i, j)] = 0.0;
                }
                L[(j, j)] = 1.0;
            }
            // U: zeros below the diagonal
            for j in 0..n {
                for i in j + 1..n {
                    U[(i, j)] = 0.0;
                }
            }
            let L = L.as_ref();
            let U = U.as_ref();

            assert!(p.inverse() * L * U ~ A);
        }

        // tall matrices (`m >= n`) with the default parameters
        for m in [8, 128, 255, 256, 257] {
            let n = 8;

            let A = CwiseMatDistribution {
                nrows: m,
                ncols: n,
                dist: StandardNormal,
            }
            .rand::<Mat<f64>>(rng);
            let A = A.as_ref();

            let mut LU = A.cloned();
            let perm = &mut *vec![0usize; m];
            let perm_inv = &mut *vec![0usize; m];

            let p = lu_in_place(
                LU.as_mut(),
                perm,
                perm_inv,
                Par::Seq,
                // query the workspace for the actual `m × n` shape of the matrix
                MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, f64>(m, n, Par::Seq, default()))),
                default(),
            )
            .1;

            let mut L = LU.as_ref().cloned();
            let mut U = LU.as_ref().cloned();

            // L: `m × n`, unit diagonal, zeros above it
            for j in 0..n {
                for i in 0..j {
                    L[(i, j)] = 0.0;
                }
                L[(j, j)] = 1.0;
            }
            // U: zeros below the diagonal, only the leading `n` rows are kept
            for j in 0..n {
                for i in j + 1..m {
                    U[(i, j)] = 0.0;
                }
            }
            let L = L.as_ref();
            let U = U.as_ref();

            let U = U.subrows(0, n);

            assert!(p.inverse() * L * U ~ A);
        }
    }
}