faer/linalg/lu/partial_pivoting/factor.rs

use crate::internal_prelude::*;
use crate::perm::swap_rows_idx;
use crate::{assert, debug_assert};

/// swaps the elements of `col` at indices `i` and `j`
#[math]
#[inline]
fn swap_elems<T: ComplexField>(col: ColMut<'_, T>, i: usize, j: usize) {
	debug_assert!(all(i < col.nrows(), j < col.nrows()));
	let rs = col.row_stride();
	let col = col.as_ptr_mut();
	unsafe {
		// SAFETY: `i` and `j` are in bounds (checked above in debug builds)
		let a = col.offset(i as isize * rs);
		let b = col.offset(j as isize * rs);
		core::ptr::swap(a, b);
	}
}

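// unblocked right-looking LU with partial pivoting, applied to columns `start..end` of `matrix`:
// for each column, find the pivot row, swap it into place, scale the entries below the pivot by
// its inverse, then apply a rank-1 update to the trailing block. `trans[k]` records the pivot
// row's offset from the diagonal at step `k`, and the return value is the number of swaps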
#[math]
fn lu_in_place_unblocked<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, start: usize, end: usize, trans: &mut [I]) -> usize {
	let mut matrix = matrix;
	let m = matrix.nrows();

	if start == end {
		return 0;
	}

	let mut n_trans = 0;

	for j in start..end {
		let col = j;
		let row = j - start;

		let t = &mut trans[row];
		let mut imax = row;
		let mut max = zero();

		// find the row holding the entry with the largest absolute value in the current column
		for i in imax..m {
			let abs = abs1(matrix[(i, col)]);
			if abs > max {
				max = abs;
				imax = i;
			}
		}

		*t = I::truncate(imax - row);

		if imax != row {
			swap_rows_idx(matrix.rb_mut(), row, imax);
			n_trans += 1;
		}

		let mut matrix = matrix.rb_mut().get_mut(.., start..end);

		// scale the entries below the pivot by the pivot's inverse
		let inv = recip(matrix[(row, row)]);
		for i in row + 1..m {
			matrix[(i, row)] = matrix[(i, row)] * inv;
		}

		// rank-1 update of the trailing block
		let (_, A01, A10, A11) = matrix.rb_mut().split_at_mut(row + 1, row + 1);
		let A01 = A01.row(row);
		let A10 = A10.col(row);
		linalg::matmul::matmul(A11, Accum::Add, A10.as_mat(), A01.as_mat(), -one::<T>(), Par::Seq);
	}

	n_trans
}

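// blocked recursive LU of the panel `start..end`: factor the left half, solve the unit lower
// triangular system for the top-right block, update the trailing block with a matmul, recurse
// on it, then apply the accumulated row swaps to the columns outside the panel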
#[math]
pub(crate) fn lu_in_place_recursion<I: Index, T: ComplexField>(
	A: MatMut<'_, T>,
	start: usize,
	end: usize,
	trans: &mut [I],
	par: Par,
	params: Spec<PartialPivLuParams, T>,
) -> usize {
	let params = params.config;
	let mut A = A;
	let m = A.nrows();
	let ncols = A.ncols();
	let n = end - start;

	if n <= params.recursion_threshold {
		return lu_in_place_unblocked(A, start, end, trans);
	}

	// round half the panel width up to a multiple of min(16, next_power_of_two(n / 2))
	let half = n / 2;
	let pow = Ord::min(16, half.next_power_of_two());

	let blocksize = half.next_multiple_of(pow);

	let mut n_trans = 0;

	assert!(n <= m);

	// factor the left block of the panel
	n_trans += lu_in_place_recursion(
		A.rb_mut().get_mut(.., start..end),
		0,
		blocksize,
		&mut trans[..blocksize],
		par,
		params.into(),
	);

	{
		let mut A = A.rb_mut().get_mut(.., start..end);
		let (A00, mut A01, A10, mut A11) = A.rb_mut().split_at_mut(blocksize, blocksize);

		let A00 = A00.rb();
		let A10 = A10.rb();

		// A01 = L00^-1 A01
		linalg::triangular_solve::solve_unit_lower_triangular_in_place(A00.rb(), A01.rb_mut(), par);

		// A11 = A11 - A10 A01
		linalg::matmul::matmul(A11.rb_mut(), Accum::Add, A10.rb(), A01.rb(), -one::<T>(), par);

		// factor the updated right block of the panel
		n_trans += lu_in_place_recursion(
			A.rb_mut().get_mut(blocksize..m, ..),
			blocksize,
			n,
			&mut trans[blocksize..n],
			par,
			params.into(),
		);
	}

	// apply the accumulated row transpositions to the columns outside the panel
	let swap = |mat: MatMut<'_, T>| {
		let mut mat = mat;
		for j in 0..mat.ncols() {
			let mut col = mat.rb_mut().col_mut(j);
			for (i, &t) in trans[..n].iter().enumerate() {
				swap_elems(col.rb_mut(), i, t.zx() + i);
			}
		}
	};

	let (A_left, A_right) = A.rb_mut().split_at_col_mut(start);
	let A_right = A_right.get_mut(.., end - start..ncols - start);

	let par = if m * (ncols - n) > params.par_threshold { par } else { Par::Seq };

	match par {
		Par::Seq => {
			swap(A_left);
			swap(A_right);
		},
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			let nthreads = nthreads.get();
			// split the threads between the two sides proportionally to their column counts
			let len = (A_left.ncols() + A_right.ncols()) as f64;
			let left_threads = Ord::min((nthreads as f64 * (A_left.ncols() as f64 / len)) as usize, nthreads);
			let right_threads = nthreads - left_threads;

			use rayon::prelude::*;
			rayon::join(
				|| {
					if A_left.ncols() > 0 {
						A_left.par_col_partition_mut(left_threads).for_each(|A| swap(A))
					}
				},
				|| {
					if A_right.ncols() > 0 {
						A_right.par_col_partition_mut(right_threads).for_each(|A| swap(A))
					}
				},
			);
		},
	}

	n_trans
}

/// $LU$ factorization tuning parameters
#[derive(Copy, Clone, Debug)]
pub struct PartialPivLuParams {
	/// threshold at which the implementation should stop recursing
	pub recursion_threshold: usize,
	/// step size of the blocked variant
	pub blocksize: usize,
	/// size threshold below which parallelism is disabled
	pub par_threshold: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}

/// information about the resulting $LU$ factorization
#[derive(Copy, Clone, Debug)]
pub struct PartialPivLuInfo {
	/// number of transpositions that were performed; the determinant of $P$ is
	/// $(-1)^{\text{transposition count}}$
	pub transposition_count: usize,
}

/// error in the $LU$ factorization
#[derive(Copy, Clone, Debug)]
pub enum LdltError {
	/// a zero pivot was encountered at the given index
	ZeroPivot { index: usize },
}

impl<T: ComplexField> Auto<T> for PartialPivLuParams {
	#[inline]
	fn auto() -> Self {
		Self {
			recursion_threshold: 16,
			blocksize: 64,
			par_threshold: 128 * 128,
			non_exhaustive: NonExhaustive(()),
		}
	}
}

/// computes the size and alignment of the workspace required by [`lu_in_place`]
#[inline]
pub fn lu_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par, params: Spec<PartialPivLuParams, T>) -> StackReq {
	_ = par;
	_ = params;
	StackReq::new::<I>(Ord::min(nrows, ncols))
}
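
// typical workspace setup for `lu_in_place` (illustrative sketch, mirroring the test below):
//
//     let mut buf = MemBuffer::new(lu_in_place_scratch::<usize, f64>(nrows, ncols, par, params));
//     let stack = MemStack::new(&mut buf);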

/// computes the $LU$ factorization of $A$ with partial pivoting, such that $P A = L U$
///
/// the factorization is performed in place: the factor $L$ (whose unit diagonal is not stored)
/// overwrites the strictly lower part of $A$, and $U$ overwrites its upper part. the permutation
/// $P$ is written to `perm`, with its inverse written to `perm_inv`
pub fn lu_in_place<'out, I: Index, T: ComplexField>(
	A: MatMut<'_, T>,
	perm: &'out mut [I],
	perm_inv: &'out mut [I],
	par: Par,
	stack: &mut MemStack,
	params: Spec<PartialPivLuParams, T>,
) -> (PartialPivLuInfo, PermRef<'out, I>) {
	let truncate = I::truncate;

	#[cfg(feature = "perf-warn")]
	if (A.row_stride().unsigned_abs() != 1 && A.col_stride().unsigned_abs() != 1) && crate::__perf_warn!(LU_WARN) {
		log::warn!(target: "faer_perf", "LU with partial pivoting prefers column-major or row-major matrix. Found matrix with generic strides.");
	}

	let mut matrix = A;
	let mut stack = stack;
	let m = matrix.nrows();
	let n = matrix.ncols();

	let size = Ord::min(n, m);

	for i in 0..m {
		perm[i] = truncate(i);
	}

	let (mut transpositions, _) = stack.rb_mut().make_with(size, |_| truncate(0));
	let transpositions = transpositions.as_mut();

	let n_transpositions = lu_in_place_recursion(matrix.rb_mut(), 0, size, transpositions.as_mut(), par, params);

	// convert the sequence of transpositions into the permutation it represents
	for idx in 0..size {
		let t = transpositions[idx];
		perm.swap(idx, idx + t.zx());
	}

	if m < n {
		// for wide matrices, apply L^-1 to the columns to the right of the factored square block
		let (left, right) = matrix.split_at_col_mut(size);
		linalg::triangular_solve::solve_unit_lower_triangular_in_place(left.rb(), right, par);
	}

	for i in 0..m {
		perm_inv[perm[i].zx()] = truncate(i);
	}

	(
		PartialPivLuInfo {
			transposition_count: n_transpositions,
		},
		unsafe { PermRef::new_unchecked(perm, perm_inv, m) },
	)
}

#[cfg(test)]
mod tests {
	use dyn_stack::MemBuffer;

	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Mat, assert};

	#[test]
	fn test_plu() {
		let rng = &mut StdRng::seed_from_u64(0);

		let approx_eq = CwiseMat(ApproxEq {
			abs_tol: 1e-13,
			rel_tol: 1e-13,
		});

		for n in [1, 2, 3, 128, 255, 256, 257] {
			let A = CwiseMatDistribution {
				nrows: n,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);
			let A = A.as_ref();

			let mut LU = A.cloned();
			let perm = &mut *vec![0usize; n];
			let perm_inv = &mut *vec![0usize; n];

			let params = PartialPivLuParams {
				recursion_threshold: 2,
				blocksize: 2,
				..auto!(f64)
			};
			let p = lu_in_place(
				LU.as_mut(),
				perm,
				perm_inv,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, f64>(n, n, Par::Seq, params.into()))),
				params.into(),
			)
			.1;

			let mut L = LU.as_ref().cloned();
			let mut U = LU.as_ref().cloned();

			// extract L (unit diagonal, strictly lower part) and U (upper part) from the packed factors
			for j in 0..n {
				for i in 0..j {
					L[(i, j)] = 0.0;
				}
				L[(j, j)] = 1.0;
			}
			for j in 0..n {
				for i in j + 1..n {
					U[(i, j)] = 0.0;
				}
			}
			let L = L.as_ref();
			let U = U.as_ref();

			assert!(p.inverse() * L * U ~ A);
		}

		for m in [8, 128, 255, 256, 257] {
			let n = 8;

			let A = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);
			let A = A.as_ref();

			let mut LU = A.cloned();
			let perm = &mut *vec![0usize; m];
			let perm_inv = &mut *vec![0usize; m];

			let p = lu_in_place(
				LU.as_mut(),
				perm,
				perm_inv,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, f64>(m, n, Par::Seq, default()))),
				default(),
			)
			.1;

			let mut L = LU.as_ref().cloned();
			let mut U = LU.as_ref().cloned();

			for j in 0..n {
				for i in 0..j {
					L[(i, j)] = 0.0;
				}
				L[(j, j)] = 1.0;
			}
			for j in 0..n {
				for i in j + 1..m {
					U[(i, j)] = 0.0;
				}
			}
			let L = L.as_ref();
			let U = U.as_ref();

			// for the rectangular case, only the top n×n block of U is nonzero
			let U = U.subrows(0, n);

			assert!(p.inverse() * L * U ~ A);
		}
	}
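
	// illustrative sketch, not part of the original test suite: `lu_in_place` on a 2×2
	// permutation matrix must perform exactly one row swap, which is reported through
	// `PartialPivLuInfo::transposition_count`
	#[test]
	fn test_plu_transposition_count() {
		// A = [0, 1; 1, 0]: the largest entry of the first column is in the second row,
		// so partial pivoting swaps the two rows exactly once
		let mut LU = Mat::<f64>::zeros(2, 2);
		LU[(0, 1)] = 1.0;
		LU[(1, 0)] = 1.0;

		let perm = &mut [0usize; 2];
		let perm_inv = &mut [0usize; 2];

		let (info, _) = lu_in_place(
			LU.as_mut(),
			perm,
			perm_inv,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, f64>(2, 2, Par::Seq, default()))),
			default(),
		);

		// a single transposition, so det(P) = (-1)^1 = -1
		assert!(info.transposition_count == 1);
		// rows 0 and 1 were exchanged
		assert!(*perm == [1, 0]);
	}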
}