faer/sparse/linalg/
lu.rs

1//! computes the $LU$ decomposition of a given sparse matrix. see
2//! [`faer::linalg::lu`](crate::linalg::lu) for more info
3//!
//! the entry points in this module are [`SymbolicLu`] and [`factorize_symbolic_lu`]
5//!
6//! # note
7//! the functions in this module accept unsorted inputs, and may produce unsorted decomposition
8//! factors.
9
10use crate::assert;
11use crate::internal_prelude_sp::*;
12use crate::sparse::utils;
13use linalg::lu::partial_pivoting::factor::PartialPivLuParams;
14use linalg_sp::cholesky::simplicial::EliminationTreeRef;
15use linalg_sp::{LuError, SupernodalThreshold, SymbolicSupernodalParams, colamd};
16
17#[inline(never)]
18fn resize_vec<T: Clone>(v: &mut alloc::vec::Vec<T>, n: usize, exact: bool, reserve_only: bool, value: T) -> Result<(), FaerError> {
19	let reserve = if exact {
20		alloc::vec::Vec::try_reserve_exact
21	} else {
22		alloc::vec::Vec::try_reserve
23	};
24	reserve(v, n.saturating_sub(v.len())).map_err(|_| FaerError::OutOfMemory)?;
25	if !reserve_only {
26		v.resize(Ord::max(n, v.len()), value);
27	}
28	Ok(())
29}
30
31/// supernodal factorization module
32///
/// a supernodal factorization is one that processes the elements of the $LU$ factors of the
/// input matrix by blocks, rather than by single elements. this is more efficient if the $LU$
/// factors are somewhat dense
36pub mod supernodal {
37	use super::*;
38	use crate::assert;
39
40	/// $LU$ factor structure containing the symbolic structure
41	#[derive(Debug, Clone)]
42	pub struct SymbolicSupernodalLu<I> {
43		pub(super) supernode_ptr: alloc::vec::Vec<I>,
44		pub(super) super_etree: alloc::vec::Vec<I>,
45		pub(super) supernode_postorder: alloc::vec::Vec<I>,
46		pub(super) supernode_postorder_inv: alloc::vec::Vec<I>,
47		pub(super) descendant_count: alloc::vec::Vec<I>,
48		pub(super) nrows: usize,
49		pub(super) ncols: usize,
50	}
51
52	/// $LU$ factor structure containing the symbolic and numerical representations
53	#[derive(Debug, Clone)]
54	pub struct SupernodalLu<I, T> {
55		nrows: usize,
56		ncols: usize,
57		nsupernodes: usize,
58
59		supernode_ptr: alloc::vec::Vec<I>,
60
61		l_col_ptr_for_row_idx: alloc::vec::Vec<I>,
62		l_col_ptr_for_val: alloc::vec::Vec<I>,
63		l_row_idx: alloc::vec::Vec<I>,
64		l_val: alloc::vec::Vec<T>,
65
66		ut_col_ptr_for_row_idx: alloc::vec::Vec<I>,
67		ut_col_ptr_for_val: alloc::vec::Vec<I>,
68		ut_row_idx: alloc::vec::Vec<I>,
69		ut_val: alloc::vec::Vec<T>,
70	}
71
72	impl<I: Index, T> Default for SupernodalLu<I, T> {
73		fn default() -> Self {
74			Self::new()
75		}
76	}
77
78	impl<I: Index, T> SupernodalLu<I, T> {
79		/// creates a new supernodal $LU$ of a $0 \times 0$ matrix
80		#[inline]
81		pub fn new() -> Self {
82			Self {
83				nrows: 0,
84				ncols: 0,
85				nsupernodes: 0,
86
87				supernode_ptr: alloc::vec::Vec::new(),
88
89				l_col_ptr_for_row_idx: alloc::vec::Vec::new(),
90				ut_col_ptr_for_row_idx: alloc::vec::Vec::new(),
91
92				l_col_ptr_for_val: alloc::vec::Vec::new(),
93				ut_col_ptr_for_val: alloc::vec::Vec::new(),
94
95				l_row_idx: alloc::vec::Vec::new(),
96				ut_row_idx: alloc::vec::Vec::new(),
97
98				l_val: alloc::vec::Vec::new(),
99				ut_val: alloc::vec::Vec::new(),
100			}
101		}
102
103		/// returns the number of rows of $A$
104		#[inline]
105		pub fn nrows(&self) -> usize {
106			self.nrows
107		}
108
109		/// returns the number of columns of $A$
110		#[inline]
111		pub fn ncols(&self) -> usize {
112			self.ncols
113		}
114
115		/// returns the number of supernodes
116		#[inline]
117		pub fn n_supernodes(&self) -> usize {
118			self.nsupernodes
119		}
120
121		/// solves the equation $A x = \text{rhs}$ and stores the result in `rhs`, implicitly
122		/// conjugating $A$ if needed
123		///
124		/// # panics
125		/// - panics if `self.nrows() != self.ncols()`
126		/// - panics if `rhs.nrows() != self.nrows()`
127		#[track_caller]
128		pub fn solve_in_place_with_conj(
129			&self,
130			row_perm: PermRef<'_, I>,
131			col_perm: PermRef<'_, I>,
132			conj_lhs: Conj,
133			rhs: MatMut<'_, T>,
134			par: Par,
135			work: MatMut<'_, T>,
136		) where
137			T: ComplexField,
138		{
139			assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows()));
140			let mut X = rhs;
141			let mut temp = work;
142
143			crate::perm::permute_rows(temp.rb_mut(), X.rb(), row_perm);
144			self.l_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
145			self.u_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
146			crate::perm::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse());
147		}
148
149		/// solves the equation $A^\top x = \text{rhs}$ and stores the result in `rhs`, implicitly
150		/// conjugating $A$ if needed
151		///
152		/// # panics
153		/// - panics if `self.nrows() != self.ncols()`
154		/// - panics if `rhs.nrows() != self.nrows()`
155		#[track_caller]
156		pub fn solve_transpose_in_place_with_conj(
157			&self,
158			row_perm: PermRef<'_, I>,
159			col_perm: PermRef<'_, I>,
160			conj_lhs: Conj,
161			rhs: MatMut<'_, T>,
162			par: Par,
163			work: MatMut<'_, T>,
164		) where
165			T: ComplexField,
166		{
167			assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows()));
168			let mut X = rhs;
169			let mut temp = work;
170			crate::perm::permute_rows(temp.rb_mut(), X.rb(), col_perm);
171			self.u_solve_transpose_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
172			self.l_solve_transpose_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
173			crate::perm::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse());
174		}
175
176		#[track_caller]
177		#[math]
178		pub(crate) fn l_solve_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
179		where
180			T: ComplexField,
181		{
182			let lu = self;
183
184			assert!(lu.nrows() == lu.ncols());
185			assert!(lu.nrows() == rhs.nrows());
186
187			let mut X = rhs;
188			let nrhs = X.ncols();
189
190			let supernode_ptr = &*lu.supernode_ptr;
191
192			for s in 0..lu.nsupernodes {
193				let s_begin = supernode_ptr[s].zx();
194				let s_end = supernode_ptr[s + 1].zx();
195				let s_size = s_end - s_begin;
196				let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();
197
198				let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
199				let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
200				let (L_top, L_bot) = L.split_at_row(s_size);
201				linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(
202					L_top,
203					conj_lhs,
204					X.rb_mut().subrows_mut(s_begin, s_size),
205					par,
206				);
207				linalg::matmul::matmul_with_conj(
208					work.rb_mut().subrows_mut(0, s_row_idx_count - s_size),
209					Accum::Replace,
210					L_bot,
211					conj_lhs,
212					X.rb().subrows(s_begin, s_size),
213					Conj::No,
214					one::<T>(),
215					par,
216				);
217
218				for j in 0..nrhs {
219					for (idx, &i) in lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()][s_size..]
220						.iter()
221						.enumerate()
222					{
223						let i = i.zx();
224						X[(i, j)] = X[(i, j)] - work[(idx, j)];
225					}
226				}
227			}
228		}
229
230		#[track_caller]
231		#[math]
232		pub(crate) fn l_solve_transpose_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
233		where
234			T: ComplexField,
235		{
236			let lu = self;
237
238			assert!(lu.nrows() == lu.ncols());
239			assert!(lu.nrows() == rhs.nrows());
240
241			let mut X = rhs;
242			let nrhs = X.ncols();
243
244			let supernode_ptr = &*lu.supernode_ptr;
245
246			for s in (0..lu.nsupernodes).rev() {
247				let s_begin = supernode_ptr[s].zx();
248				let s_end = supernode_ptr[s + 1].zx();
249				let s_size = s_end - s_begin;
250				let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();
251
252				let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
253				let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
254
255				let (L_top, L_bot) = L.split_at_row(s_size);
256
257				for j in 0..nrhs {
258					for (idx, &i) in lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()][s_size..]
259						.iter()
260						.enumerate()
261					{
262						let i = i.zx();
263						work[(idx, j)] = copy(X[(i, j)]);
264					}
265				}
266
267				linalg::matmul::matmul_with_conj(
268					X.rb_mut().subrows_mut(s_begin, s_size),
269					Accum::Add,
270					L_bot.transpose(),
271					conj_lhs,
272					work.rb().subrows(0, s_row_idx_count - s_size),
273					Conj::No,
274					-one::<T>(),
275					par,
276				);
277				linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(
278					L_top.transpose(),
279					conj_lhs,
280					X.rb_mut().subrows_mut(s_begin, s_size),
281					par,
282				);
283			}
284		}
285
286		#[track_caller]
287		#[math]
288		pub(crate) fn u_solve_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
289		where
290			T: ComplexField,
291		{
292			let lu = self;
293
294			assert!(lu.nrows() == lu.ncols());
295			assert!(lu.nrows() == rhs.nrows());
296
297			let mut X = rhs;
298			let nrhs = X.ncols();
299
300			let supernode_ptr = &*lu.supernode_ptr;
301
302			for s in (0..lu.nsupernodes).rev() {
303				let s_begin = supernode_ptr[s].zx();
304				let s_end = supernode_ptr[s + 1].zx();
305				let s_size = s_end - s_begin;
306				let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();
307				let s_col_index_count = lu.ut_col_ptr_for_row_idx[s + 1].zx() - lu.ut_col_ptr_for_row_idx[s].zx();
308
309				let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
310				let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
311				let U = &lu.ut_val[lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()];
312				let U_right = MatRef::from_column_major_slice(U, s_col_index_count, s_size).transpose();
313
314				for j in 0..nrhs {
315					for (idx, &i) in lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()]
316						.iter()
317						.enumerate()
318					{
319						let i = i.zx();
320						work[(idx, j)] = copy(X[(i, j)]);
321					}
322				}
323
324				let (U_left, _) = L.split_at_row(s_size);
325				linalg::matmul::matmul_with_conj(
326					X.rb_mut().subrows_mut(s_begin, s_size),
327					Accum::Add,
328					U_right,
329					conj_lhs,
330					work.rb().subrows(0, s_col_index_count),
331					Conj::No,
332					-one::<T>(),
333					par,
334				);
335				linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(U_left, conj_lhs, X.rb_mut().subrows_mut(s_begin, s_size), par);
336			}
337		}
338
339		#[track_caller]
340		#[math]
341		pub(crate) fn u_solve_transpose_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
342		where
343			T: ComplexField,
344		{
345			let lu = self;
346
347			assert!(lu.nrows() == lu.ncols());
348			assert!(lu.nrows() == rhs.nrows());
349
350			let mut X = rhs;
351			let nrhs = X.ncols();
352
353			let supernode_ptr = &*lu.supernode_ptr;
354
355			for s in 0..lu.nsupernodes {
356				let s_begin = supernode_ptr[s].zx();
357				let s_end = supernode_ptr[s + 1].zx();
358				let s_size = s_end - s_begin;
359				let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();
360				let s_col_index_count = lu.ut_col_ptr_for_row_idx[s + 1].zx() - lu.ut_col_ptr_for_row_idx[s].zx();
361
362				let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
363				let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
364				let U = &lu.ut_val[lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()];
365				let U_right = MatRef::from_column_major_slice(U, s_col_index_count, s_size).transpose();
366
367				let (U_left, _) = L.split_at_row(s_size);
368				linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(
369					U_left.transpose(),
370					conj_lhs,
371					X.rb_mut().subrows_mut(s_begin, s_size),
372					par,
373				);
374				linalg::matmul::matmul_with_conj(
375					work.rb_mut().subrows_mut(0, s_col_index_count),
376					Accum::Replace,
377					U_right.transpose(),
378					conj_lhs,
379					X.rb().subrows(s_begin, s_size),
380					Conj::No,
381					one::<T>(),
382					par,
383				);
384
385				for j in 0..nrhs {
386					for (idx, &i) in lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()]
387						.iter()
388						.enumerate()
389					{
390						let i = i.zx();
391						X[(i, j)] = X[(i, j)] - work[(idx, j)];
392					}
393				}
394			}
395		}
396	}
397
398	/// computes the size and alignment of the workspace required to compute the symbolic
399	/// $LU$ factorization of a square matrix with size `n`.
400	pub fn factorize_supernodal_symbolic_lu_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
401		let _ = nrows;
402		linalg_sp::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_scratch::<I>(ncols)
403	}
404
405	/// computes the symbolic structure of the $LU$ factors of the matrix $A$
406	#[track_caller]
407	pub fn factorize_supernodal_symbolic_lu<I: Index>(
408		A: SymbolicSparseColMatRef<'_, I>,
409		col_perm: Option<PermRef<'_, I>>,
410		min_col: &[I],
411		etree: EliminationTreeRef<'_, I>,
412		col_counts: &[I],
413		stack: &mut MemStack,
414		params: SymbolicSupernodalParams<'_>,
415	) -> Result<SymbolicSupernodalLu<I>, FaerError> {
416		let m = A.nrows();
417		let n = A.ncols();
418
419		with_dim!(M, m);
420		with_dim!(N, n);
421
422		let I = I::truncate;
423		let A = A.as_shape(M, N);
424		let min_col = Array::from_ref(MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(min_col), N), M);
425		let etree = etree.as_bound(N);
426
427		let L = linalg_sp::cholesky::supernodal::ghost_factorize_supernodal_symbolic(
428			A,
429			col_perm.map(|perm| perm.as_shape(N)),
430			Some(min_col),
431			linalg_sp::cholesky::supernodal::CholeskyInput::ATA,
432			etree,
433			Array::from_ref(col_counts, N),
434			stack,
435			params,
436		)?;
437		let n_supernodes = L.n_supernodes();
438		let mut super_etree = try_zeroed::<I>(n_supernodes)?;
439
440		let (index_to_super, _) = unsafe { stack.make_raw::<I>(*N) };
441
442		for s in 0..n_supernodes {
443			index_to_super[L.supernode_begin[s].zx()..L.supernode_begin[s + 1].zx()].fill(I(s));
444		}
445		for s in 0..n_supernodes {
446			let last = L.supernode_begin[s + 1].zx() - 1;
447			if let Some(parent) = etree[N.check(last)].idx() {
448				super_etree[s] = index_to_super[*parent.zx()];
449			} else {
450				super_etree[s] = I(NONE);
451			}
452		}
453
454		Ok(SymbolicSupernodalLu {
455			supernode_ptr: L.supernode_begin,
456			super_etree,
457			supernode_postorder: L.supernode_postorder,
458			supernode_postorder_inv: L.supernode_postorder_inv,
459			descendant_count: L.descendant_count,
460			nrows: *A.nrows(),
461			ncols: *A.ncols(),
462		})
463	}
464
465	struct MatU8 {
466		data: alloc::vec::Vec<u8>,
467		nrows: usize,
468	}
469	impl MatU8 {
470		fn new() -> Self {
471			Self {
472				data: alloc::vec::Vec::new(),
473				nrows: 0,
474			}
475		}
476
477		fn with_dims(nrows: usize, ncols: usize) -> Result<Self, FaerError> {
478			Ok(Self {
479				data: try_collect((0..(nrows * ncols)).map(|_| 1u8))?,
480				nrows,
481			})
482		}
483	}
484	impl core::ops::Index<(usize, usize)> for MatU8 {
485		type Output = u8;
486
487		#[inline(always)]
488		fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
489			&self.data[row + col * self.nrows]
490		}
491	}
492	impl core::ops::IndexMut<(usize, usize)> for MatU8 {
493		#[inline(always)]
494		fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
495			&mut self.data[row + col * self.nrows]
496		}
497	}
498
499	struct Front;
500	struct LPanel;
501	struct UPanel;
502
503	#[inline(never)]
504	fn noinline<T, R>(_: T, f: impl FnOnce() -> R) -> R {
505		f()
506	}
507
508	/// computes the size and alignment of the workspace required to perform a numeric $LU$
509	/// factorization
510	pub fn factorize_supernodal_numeric_lu_scratch<I: Index, T: ComplexField>(
511		symbolic: &SymbolicSupernodalLu<I>,
512		params: Spec<PartialPivLuParams, T>,
513	) -> StackReq {
514		let m = StackReq::new::<I>(symbolic.nrows);
515		let n = StackReq::new::<I>(symbolic.ncols);
516		_ = params;
517		StackReq::and(n, m.array(5))
518	}
519
520	/// computes the numeric values of the $LU$ factors of the matrix $A$ as well as the row
521	/// pivoting permutation, and stores them in `lu` and `row_perm`/`row_perm_inv`
522	#[math]
523	pub fn factorize_supernodal_numeric_lu<I: Index, T: ComplexField>(
524		row_perm: &mut [I],
525		row_perm_inv: &mut [I],
526		lu: &mut SupernodalLu<I, T>,
527
528		A: SparseColMatRef<'_, I, T>,
529		AT: SparseColMatRef<'_, I, T>,
530		col_perm: PermRef<'_, I>,
531		symbolic: &SymbolicSupernodalLu<I>,
532
533		par: Par,
534		stack: &mut MemStack,
535		params: Spec<PartialPivLuParams, T>,
536	) -> Result<(), LuError> {
537		use linalg_sp::cholesky::supernodal::partition_fn;
538		let SymbolicSupernodalLu {
539			supernode_ptr,
540			super_etree,
541			supernode_postorder,
542			supernode_postorder_inv,
543			descendant_count,
544			nrows: _,
545			ncols: _,
546		} = symbolic;
547
548		let I = I::truncate;
549		let I_checked = |x: usize| -> Result<I, FaerError> {
550			if x > I::Signed::MAX.zx() {
551				Err(FaerError::IndexOverflow)
552			} else {
553				Ok(I(x))
554			}
555		};
556		let to_wide = |x: I| -> u128 { x.zx() as _ };
557		let from_wide_checked = |x: u128| -> Result<I, FaerError> {
558			if x > I::Signed::MAX.zx() as u128 {
559				Err(FaerError::IndexOverflow)
560			} else {
561				Ok(I(x as _))
562			}
563		};
564
565		let m = A.nrows();
566		let n = A.ncols();
567		assert!(m >= n);
568		assert!(all(AT.nrows() == n, AT.ncols() == m));
569		assert!(all(row_perm.len() == m, row_perm_inv.len() == m));
570		let n_supernodes = super_etree.len();
571		assert!(supernode_postorder.len() == n_supernodes);
572		assert!(supernode_postorder_inv.len() == n_supernodes);
573		assert!(supernode_ptr.len() == n_supernodes + 1);
574		assert!(supernode_ptr[n_supernodes].zx() == n);
575
576		lu.nrows = 0;
577		lu.ncols = 0;
578		lu.nsupernodes = 0;
579		lu.supernode_ptr.clear();
580
581		let (col_global_to_local, stack) = unsafe { stack.make_raw::<I>(n) };
582		let (row_global_to_local, stack) = unsafe { stack.make_raw::<I>(m) };
583		let (marked, stack) = unsafe { stack.make_raw::<I>(m) };
584		let (indices, stack) = unsafe { stack.make_raw::<I>(m) };
585		let (transpositions, stack) = unsafe { stack.make_raw::<I>(m) };
586		let (d_active_rows, _) = unsafe { stack.make_raw::<I>(m) };
587
588		col_global_to_local.fill(I(NONE));
589		row_global_to_local.fill(I(NONE));
590
591		marked.fill(I(0));
592
593		resize_vec(&mut lu.l_col_ptr_for_row_idx, n_supernodes + 1, true, false, I(0))?;
594		resize_vec(&mut lu.ut_col_ptr_for_row_idx, n_supernodes + 1, true, false, I(0))?;
595		resize_vec(&mut lu.l_col_ptr_for_val, n_supernodes + 1, true, false, I(0))?;
596		resize_vec(&mut lu.ut_col_ptr_for_val, n_supernodes + 1, true, false, I(0))?;
597
598		lu.l_col_ptr_for_row_idx[0] = I(0);
599		lu.ut_col_ptr_for_row_idx[0] = I(0);
600		lu.l_col_ptr_for_val[0] = I(0);
601		lu.ut_col_ptr_for_val[0] = I(0);
602
603		for i in 0..m {
604			row_perm[i] = I(i);
605		}
606		for i in 0..m {
607			row_perm_inv[i] = I(i);
608		}
609
610		let (col_perm, col_perm_inv) = col_perm.arrays();
611
612		let mut contrib_work =
613			try_collect((0..n_supernodes).map(|_| (alloc::vec::Vec::<T>::new(), alloc::vec::Vec::<I>::new(), 0usize, MatU8::new())))?;
614
615		let work_to_mat_mut = |v: &mut alloc::vec::Vec<T>, nrows: usize, ncols: usize| unsafe {
616			MatMut::from_raw_parts_mut(v.as_mut_ptr(), nrows, ncols, 1, nrows as isize)
617		};
618
619		let mut A_leftover = A.compute_nnz();
620		for s in 0..n_supernodes {
621			let s_begin = supernode_ptr[s].zx();
622			let s_end = supernode_ptr[s + 1].zx();
623			let s_size = s_end - s_begin;
624
625			let s_postordered = supernode_postorder_inv[s].zx();
626			let desc_count = descendant_count[s].zx();
627			let mut s_row_idx_count = 0usize;
628			let (left_contrib, right_contrib) = contrib_work.split_at_mut(s);
629
630			let s_row_idxices = &mut *indices;
631			// add the rows from A[s_end:, s_begin:s_end]
632			for j in s_begin..s_end {
633				let pj = col_perm[j].zx();
634				let row_idx = A.row_idx_of_col_raw(pj);
635				for i in row_idx {
636					let i = i.zx();
637					let pi = row_perm_inv[i].zx();
638					if pi < s_begin {
639						continue;
640					}
641					if marked[i] < I(2 * s + 1) {
642						s_row_idxices[s_row_idx_count] = I(i);
643						s_row_idx_count += 1;
644						marked[i] = I(2 * s + 1);
645					}
646				}
647			}
648
649			// add the rows from child[s_begin:]
650			for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
651				let d = d.zx();
652				let d_begin = supernode_ptr[d].zx();
653				let d_end = supernode_ptr[d + 1].zx();
654				let d_size = d_end - d_begin;
655				let d_row_idx = &lu.l_row_idx[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
656				let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
657				let d_col_start = d_col_ind.partition_point(partition_fn(s_begin));
658
659				if d_col_start < d_col_ind.len() && d_col_ind[d_col_start].zx() < s_end {
660					for i in d_row_idx.iter() {
661						let i = i.zx();
662						let pi = row_perm_inv[i].zx();
663
664						if pi < s_begin {
665							continue;
666						}
667
668						if marked[i] < I(2 * s + 1) {
669							s_row_idxices[s_row_idx_count] = I(i);
670							s_row_idx_count += 1;
671							marked[i] = I(2 * s + 1);
672						}
673					}
674				}
675			}
676
677			lu.l_col_ptr_for_row_idx[s + 1] = I_checked(lu.l_col_ptr_for_row_idx[s].zx() + s_row_idx_count)?;
678			lu.l_col_ptr_for_val[s + 1] = from_wide_checked(to_wide(lu.l_col_ptr_for_val[s]) + ((s_row_idx_count) as u128 * s_size as u128))?;
679			resize_vec(&mut lu.l_row_idx, lu.l_col_ptr_for_row_idx[s + 1].zx(), false, false, I(0))?;
680			resize_vec::<T>(&mut lu.l_val, lu.l_col_ptr_for_val[s + 1].zx(), false, false, zero::<T>())?;
681			lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()].copy_from_slice(&s_row_idxices[..s_row_idx_count]);
682			lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()].sort_unstable();
683
684			let (left_row_idxices, right_row_idxices) = lu.l_row_idx.split_at_mut(lu.l_col_ptr_for_row_idx[s].zx());
685
686			let s_row_idxices = &mut right_row_idxices[0..lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx()];
687			for (idx, i) in s_row_idxices.iter().enumerate() {
688				row_global_to_local[i.zx()] = I(idx);
689			}
690			let s_L = &mut lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
691			let mut s_L = MatMut::from_column_major_slice_mut(s_L, s_row_idx_count, s_size);
692			s_L.fill(zero());
693
694			for j in s_begin..s_end {
695				let pj = col_perm[j].zx();
696				let row_idx = A.row_idx_of_col(pj);
697				let val = A.val_of_col(pj);
698
699				for (i, val) in iter::zip(row_idx, val) {
700					let pi = row_perm_inv[i].zx();
701					if pi < s_begin {
702						continue;
703					}
704					assert!(A_leftover > 0);
705					A_leftover -= 1;
706					let ix = row_global_to_local[i].zx();
707					let iy = j - s_begin;
708					s_L[(ix, iy)] = s_L[(ix, iy)] + *val;
709				}
710			}
711
712			noinline(LPanel, || {
713				for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
714					let d = d.zx();
715					if left_contrib[d].0.is_empty() {
716						continue;
717					}
718
719					let d_begin = supernode_ptr[d].zx();
720					let d_end = supernode_ptr[d + 1].zx();
721					let d_size = d_end - d_begin;
722					let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
723					let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
724					let d_col_start = d_col_ind.partition_point(partition_fn(s_begin));
725
726					if d_col_start < d_col_ind.len() && d_col_ind[d_col_start].zx() < s_end {
727						let d_col_mid = d_col_start + d_col_ind[d_col_start..].partition_point(partition_fn(s_end));
728
729						let mut d_LU_cols = work_to_mat_mut(&mut left_contrib[d].0, d_row_idx.len(), d_col_ind.len())
730							.subcols_mut(d_col_start, d_col_mid - d_col_start);
731						let left_contrib = &mut left_contrib[d];
732						let d_active = &mut left_contrib.1[d_col_start..];
733						let d_active_count = &mut left_contrib.2;
734						let d_active_mat = &mut left_contrib.3;
735
736						for (d_j, j) in d_col_ind[d_col_start..d_col_mid].iter().enumerate() {
737							if d_active[d_j] > I(0) {
738								let mut taken_rows = 0usize;
739								let j = j.zx();
740								let s_j = j - s_begin;
741								for (d_i, i) in d_row_idx.iter().enumerate() {
742									let i = i.zx();
743									let pi = row_perm_inv[i].zx();
744									if pi < s_begin {
745										continue;
746									}
747									let s_i = row_global_to_local[i].zx();
748
749									s_L[(s_i, s_j)] = s_L[(s_i, s_j)] - d_LU_cols[(d_i, d_j)];
750									d_LU_cols[(d_i, d_j)] = zero::<T>();
751									taken_rows += d_active_mat[(d_i, d_j + d_col_start)] as usize;
752									d_active_mat[(d_i, d_j + d_col_start)] = 0;
753								}
754								assert!(d_active[d_j] >= I(taken_rows));
755								d_active[d_j] -= I(taken_rows);
756								if d_active[d_j] == I(0) {
757									assert!(*d_active_count > 0);
758									*d_active_count -= 1;
759								}
760							}
761						}
762						if *d_active_count == 0 {
763							left_contrib.0.clear();
764							left_contrib.1 = alloc::vec::Vec::new();
765							left_contrib.2 = 0;
766							left_contrib.3 = MatU8::new();
767						}
768					}
769				}
770			});
771
772			if s_L.nrows() < s_L.ncols() {
773				return Err(LuError::SymbolicSingular {
774					index: s_begin + s_L.nrows(),
775				});
776			}
777			let transpositions = &mut transpositions[s_begin..s_end];
778			crate::linalg::lu::partial_pivoting::factor::lu_in_place_recursion(s_L.rb_mut(), 0, s_size, transpositions, par, params);
779
780			for (idx, t) in transpositions.iter().enumerate() {
781				let i_t = s_row_idxices[idx + t.zx()].zx();
782				let kk = row_perm_inv[i_t].zx();
783				row_perm.swap(s_begin + idx, row_perm_inv[i_t].zx());
784				row_perm_inv.swap(row_perm[s_begin + idx].zx(), row_perm[kk].zx());
785				s_row_idxices.swap(idx, idx + t.zx());
786			}
787			for (idx, t) in transpositions.iter().enumerate().rev() {
788				row_global_to_local.swap(s_row_idxices[idx].zx(), s_row_idxices[idx + t.zx()].zx());
789			}
790			for (idx, i) in s_row_idxices.iter().enumerate() {
791				assert!(row_global_to_local[i.zx()] == I(idx));
792			}
793
794			let s_col_indices = &mut indices[..n];
795			let mut s_col_index_count = 0usize;
796			for i in s_begin..s_end {
797				let pi = row_perm[i].zx();
798				for j in AT.row_idx_of_col(pi) {
799					let pj = col_perm_inv[j].zx();
800					if pj < s_end {
801						continue;
802					}
803					if marked[pj] < I(2 * s + 2) {
804						s_col_indices[s_col_index_count] = I(pj);
805						s_col_index_count += 1;
806						marked[pj] = I(2 * s + 2);
807					}
808				}
809			}
810
811			for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
812				let d = d.zx();
813
814				let d_begin = supernode_ptr[d].zx();
815				let d_end = supernode_ptr[d + 1].zx();
816				let d_size = d_end - d_begin;
817
818				let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
819				let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
820
821				let contributes_to_u = d_row_idx
822					.iter()
823					.any(|&i| row_perm_inv[i.zx()].zx() >= s_begin && row_perm_inv[i.zx()].zx() < s_end);
824
825				if contributes_to_u {
826					let d_col_start = d_col_ind.partition_point(partition_fn(s_end));
827					for j in &d_col_ind[d_col_start..] {
828						let j = j.zx();
829						if marked[j] < I(2 * s + 2) {
830							s_col_indices[s_col_index_count] = I(j);
831							s_col_index_count += 1;
832							marked[j] = I(2 * s + 2);
833						}
834					}
835				}
836			}
837
838			lu.ut_col_ptr_for_row_idx[s + 1] = I_checked(lu.ut_col_ptr_for_row_idx[s].zx() + s_col_index_count)?;
839			lu.ut_col_ptr_for_val[s + 1] = from_wide_checked(to_wide(lu.ut_col_ptr_for_val[s]) + (s_col_index_count as u128 * s_size as u128))?;
840			resize_vec(&mut lu.ut_row_idx, lu.ut_col_ptr_for_row_idx[s + 1].zx(), false, false, I(0))?;
841			resize_vec::<T>(&mut lu.ut_val, lu.ut_col_ptr_for_val[s + 1].zx(), false, false, zero::<T>())?;
842			lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()]
843				.copy_from_slice(&s_col_indices[..s_col_index_count]);
844			lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()].sort_unstable();
845
846			let s_col_indices = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()];
847			for (idx, j) in s_col_indices.iter().enumerate() {
848				col_global_to_local[j.zx()] = I(idx);
849			}
850
851			let s_U = &mut lu.ut_val[lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()];
852			let mut s_U = MatMut::from_column_major_slice_mut(s_U, s_col_index_count, s_size).transpose_mut();
853			s_U.fill(zero());
854
855			for i in s_begin..s_end {
856				let pi = row_perm[i].zx();
857				for (j, val) in iter::zip(AT.row_idx_of_col(pi), AT.val_of_col(pi)) {
858					let pj = col_perm_inv[j].zx();
859					if pj < s_end {
860						continue;
861					}
862					assert!(A_leftover > 0);
863					A_leftover -= 1;
864					let ix = i - s_begin;
865					let iy = col_global_to_local[pj].zx();
866					s_U[(ix, iy)] = s_U[(ix, iy)] + *val;
867				}
868			}
869
870			noinline(UPanel, || {
871				for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
872					let d = d.zx();
873					if left_contrib[d].0.is_empty() {
874						continue;
875					}
876
877					let d_begin = supernode_ptr[d].zx();
878					let d_end = supernode_ptr[d + 1].zx();
879					let d_size = d_end - d_begin;
880
881					let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
882					let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
883
884					let contributes_to_u = d_row_idx
885						.iter()
886						.any(|&i| row_perm_inv[i.zx()].zx() >= s_begin && row_perm_inv[i.zx()].zx() < s_end);
887
888					if contributes_to_u {
889						let d_col_start = d_col_ind.partition_point(partition_fn(s_end));
890						let d_LU = work_to_mat_mut(&mut left_contrib[d].0, d_row_idx.len(), d_col_ind.len());
891						let mut d_LU = d_LU.get_mut(.., d_col_start..);
892						let left_contrib = &mut left_contrib[d];
893						let d_active = &mut left_contrib.1[d_col_start..];
894						let d_active_count = &mut left_contrib.2;
895						let d_active_mat = &mut left_contrib.3;
896
897						for (d_j, j) in d_col_ind[d_col_start..].iter().enumerate() {
898							if d_active[d_j] > I(0) {
899								let mut taken_rows = 0usize;
900								let j = j.zx();
901								let s_j = col_global_to_local[j].zx();
902								for (d_i, i) in d_row_idx.iter().enumerate() {
903									let i = i.zx();
904									let pi = row_perm_inv[i].zx();
905
906									if pi >= s_begin && pi < s_end {
907										let s_i = row_global_to_local[i].zx();
908										s_U[(s_i, s_j)] = s_U[(s_i, s_j)] - (d_LU[(d_i, d_j)]);
909										d_LU[(d_i, d_j)] = zero::<T>();
910										taken_rows += d_active_mat[(d_i, d_j + d_col_start)] as usize;
911										d_active_mat[(d_i, d_j + d_col_start)] = 0;
912									}
913								}
914								assert!(d_active[d_j] >= I(taken_rows));
915								d_active[d_j] -= I(taken_rows);
916								if d_active[d_j] == I(0) {
917									assert!(*d_active_count > 0);
918									*d_active_count -= 1;
919								}
920							}
921						}
922						if *d_active_count == 0 {
923							left_contrib.0.clear();
924							left_contrib.1 = alloc::vec::Vec::new();
925							left_contrib.2 = 0;
926							left_contrib.3 = MatU8::new();
927						}
928					}
929				}
930			});
931			linalg::triangular_solve::solve_unit_lower_triangular_in_place(s_L.rb().subrows(0, s_size), s_U.rb_mut(), par);
932
933			if s_row_idx_count > s_size && s_col_index_count > 0 {
934				resize_vec::<T>(
935					&mut right_contrib[0].0,
936					from_wide_checked(to_wide(I(s_row_idx_count - s_size)) * to_wide(I(s_col_index_count)))?.zx(),
937					false,
938					false,
939					zero::<T>(),
940				)?;
941				right_contrib[0]
942					.1
943					.try_reserve_exact(s_col_index_count)
944					.ok()
945					.ok_or(FaerError::OutOfMemory)?;
946				right_contrib[0].1.resize(s_col_index_count, I(s_row_idx_count - s_size));
947				right_contrib[0].2 = s_col_index_count;
948				right_contrib[0].3 = MatU8::with_dims(s_row_idx_count - s_size, s_col_index_count)?;
949
950				let mut s_LU = work_to_mat_mut(&mut right_contrib[0].0, s_row_idx_count - s_size, s_col_index_count);
951				linalg::matmul::matmul(s_LU.rb_mut(), Accum::Replace, s_L.rb().get(s_size.., ..), s_U.rb(), one::<T>(), par);
952
953				noinline(Front, || {
954					for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
955						let d = d.zx();
956						if left_contrib[d].0.is_empty() {
957							continue;
958						}
959
960						let d_begin = supernode_ptr[d].zx();
961						let d_end = supernode_ptr[d + 1].zx();
962						let d_size = d_end - d_begin;
963
964						let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
965						let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
966
967						let contributes_to_front = d_row_idx.iter().any(|&i| row_perm_inv[i.zx()].zx() >= s_end);
968
969						if contributes_to_front {
970							let d_col_start = d_col_ind.partition_point(partition_fn(s_end));
971							let d_LU = work_to_mat_mut(&mut left_contrib[d].0, d_row_idx.len(), d_col_ind.len());
972							let mut d_LU = d_LU.get_mut(.., d_col_start..);
973							let left_contrib = &mut left_contrib[d];
974							let d_active = &mut left_contrib.1[d_col_start..];
975							let d_active_count = &mut left_contrib.2;
976							let d_active_mat = &mut left_contrib.3;
977
978							let mut d_active_row_count = 0usize;
979							let mut first_iter = true;
980
981							for (d_j, j) in d_col_ind[d_col_start..].iter().enumerate() {
982								if d_active[d_j] > I(0) {
983									if first_iter {
984										first_iter = false;
985										for (d_i, i) in d_row_idx.iter().enumerate() {
986											let i = i.zx();
987											let pi = row_perm_inv[i].zx();
988											if (pi < s_end) || (row_global_to_local[i] == I(NONE)) {
989												continue;
990											}
991
992											d_active_rows[d_active_row_count] = I(d_i);
993											d_active_row_count += 1;
994										}
995									}
996
997									let j = j.zx();
998									let mut taken_rows = 0usize;
999
1000									let s_j = col_global_to_local[j];
1001									if s_j == I(NONE) {
1002										continue;
1003									}
1004									let s_j = s_j.zx();
1005									let mut dst = s_LU.rb_mut().col_mut(s_j);
1006									let mut src = d_LU.rb_mut().col_mut(d_j);
1007									assert!(dst.row_stride() == 1);
1008									assert!(src.row_stride() == 1);
1009
1010									for d_i in &d_active_rows[..d_active_row_count] {
1011										let d_i = d_i.zx();
1012										let i = d_row_idx[d_i].zx();
1013										let d_active_mat = &mut d_active_mat[(d_i, d_j + d_col_start)];
1014										if *d_active_mat == 0 {
1015											continue;
1016										}
1017										let s_i = row_global_to_local[i].zx() - s_size;
1018
1019										dst[s_i] = dst[s_i] + (src[d_i]);
1020										src[d_i] = zero::<T>();
1021
1022										taken_rows += 1;
1023										*d_active_mat = 0;
1024									}
1025
1026									d_active[d_j] -= I(taken_rows);
1027									if d_active[d_j] == I(0) {
1028										*d_active_count -= 1;
1029									}
1030								}
1031							}
1032							if *d_active_count == 0 {
1033								left_contrib.0.clear();
1034								left_contrib.1 = alloc::vec::Vec::new();
1035								left_contrib.2 = 0;
1036								left_contrib.3 = MatU8::new();
1037							}
1038						}
1039					}
1040				})
1041			}
1042
1043			for i in s_row_idxices.iter() {
1044				row_global_to_local[i.zx()] = I(NONE);
1045			}
1046			for j in s_col_indices.iter() {
1047				col_global_to_local[j.zx()] = I(NONE);
1048			}
1049		}
1050		assert!(A_leftover == 0);
1051
1052		for idx in &mut lu.l_row_idx[..lu.l_col_ptr_for_row_idx[n_supernodes].zx()] {
1053			*idx = row_perm_inv[idx.zx()];
1054		}
1055
1056		lu.nrows = m;
1057		lu.ncols = n;
1058		lu.nsupernodes = n_supernodes;
1059		lu.supernode_ptr.clone_from(supernode_ptr);
1060
1061		Ok(())
1062	}
1063}
1064
1065/// simplicial factorization module
1066///
1067/// a supernodal factorization is one that processes the elements of the $LU$ factors of the
1068/// input matrix by single elements, rather than by blocks. this is more efficient if the lu
1069/// factors are very sparse
1070pub mod simplicial {
1071	use super::*;
1072	use crate::assert;
1073
1074	/// $LU$ factor structure containing the symbolic and numerical representations
1075	#[derive(Debug, Clone)]
1076	pub struct SimplicialLu<I, T> {
1077		nrows: usize,
1078		ncols: usize,
1079
1080		l_col_ptr: alloc::vec::Vec<I>,
1081		l_row_idx: alloc::vec::Vec<I>,
1082		l_val: alloc::vec::Vec<T>,
1083
1084		u_col_ptr: alloc::vec::Vec<I>,
1085		u_row_idx: alloc::vec::Vec<I>,
1086		u_val: alloc::vec::Vec<T>,
1087	}
1088
1089	impl<I: Index, T> Default for SimplicialLu<I, T> {
1090		fn default() -> Self {
1091			Self::new()
1092		}
1093	}
1094
1095	impl<I: Index, T> SimplicialLu<I, T> {
1096		/// creates a new simplicial $LU$ of a $0 \times 0$ matrix
1097		#[inline]
1098		pub fn new() -> Self {
1099			Self {
1100				nrows: 0,
1101				ncols: 0,
1102
1103				l_col_ptr: alloc::vec::Vec::new(),
1104				u_col_ptr: alloc::vec::Vec::new(),
1105
1106				l_row_idx: alloc::vec::Vec::new(),
1107				u_row_idx: alloc::vec::Vec::new(),
1108
1109				l_val: alloc::vec::Vec::new(),
1110				u_val: alloc::vec::Vec::new(),
1111			}
1112		}
1113
1114		/// returns the number of rows of $A$
1115		#[inline]
1116		pub fn nrows(&self) -> usize {
1117			self.nrows
1118		}
1119
1120		/// returns the number of columns of $A$
1121		#[inline]
1122		pub fn ncols(&self) -> usize {
1123			self.ncols
1124		}
1125
1126		/// returns the $L$ factor of the $LU$ factorization. the row indices may be unsorted
1127		#[inline]
1128		pub fn l_factor_unsorted(&self) -> SparseColMatRef<'_, I, T> {
1129			SparseColMatRef::<'_, I, T>::new(
1130				unsafe { SymbolicSparseColMatRef::new_unchecked(self.nrows(), self.ncols(), &self.l_col_ptr, None, &self.l_row_idx) },
1131				&self.l_val,
1132			)
1133		}
1134
1135		/// returns the $U$ factor of the $LU$ factorization. the row indices may be unsorted
1136		#[inline]
1137		pub fn u_factor_unsorted(&self) -> SparseColMatRef<'_, I, T> {
1138			SparseColMatRef::<'_, I, T>::new(
1139				unsafe { SymbolicSparseColMatRef::new_unchecked(self.ncols(), self.ncols(), &self.u_col_ptr, None, &self.u_row_idx) },
1140				&self.u_val,
1141			)
1142		}
1143
1144		/// solves the equation $A x = \text{rhs}$ and stores the result in `rhs`, implicitly
1145		/// conjugating $A$ if needed
1146		///
1147		/// # panics
1148		/// - panics if `self.nrows() != self.ncols()`
1149		/// - panics if `rhs.nrows() != self.nrows()`
1150		#[track_caller]
1151		pub fn solve_in_place_with_conj(
1152			&self,
1153			row_perm: PermRef<'_, I>,
1154			col_perm: PermRef<'_, I>,
1155			conj_lhs: Conj,
1156			rhs: MatMut<'_, T>,
1157			par: Par,
1158			work: MatMut<'_, T>,
1159		) where
1160			T: ComplexField,
1161		{
1162			assert!(self.nrows() == self.ncols());
1163			assert!(self.nrows() == rhs.nrows());
1164			let mut X = rhs;
1165			let mut temp = work;
1166
1167			let l = self.l_factor_unsorted();
1168			let u = self.u_factor_unsorted();
1169
1170			crate::perm::permute_rows(temp.rb_mut(), X.rb(), row_perm);
1171			linalg_sp::triangular_solve::solve_unit_lower_triangular_in_place(l, conj_lhs, temp.rb_mut(), par);
1172			linalg_sp::triangular_solve::solve_upper_triangular_in_place(u, conj_lhs, temp.rb_mut(), par);
1173			crate::perm::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse());
1174		}
1175
1176		/// solves the equation $A^\top x = \text{rhs}$ and stores the result in `rhs`,
1177		/// implicitly conjugating $A$ if needed
1178		///
1179		/// # panics
1180		/// - panics if `self.nrows() != self.ncols()`
1181		/// - panics if `rhs.nrows() != self.nrows()`
1182		#[track_caller]
1183		pub fn solve_transpose_in_place_with_conj(
1184			&self,
1185			row_perm: PermRef<'_, I>,
1186			col_perm: PermRef<'_, I>,
1187			conj_lhs: Conj,
1188			rhs: MatMut<'_, T>,
1189			par: Par,
1190			work: MatMut<'_, T>,
1191		) where
1192			T: ComplexField,
1193		{
1194			assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows()));
1195			let mut X = rhs;
1196			let mut temp = work;
1197
1198			let l = self.l_factor_unsorted();
1199			let u = self.u_factor_unsorted();
1200
1201			crate::perm::permute_rows(temp.rb_mut(), X.rb(), col_perm);
1202			linalg_sp::triangular_solve::solve_upper_triangular_transpose_in_place(u, conj_lhs, temp.rb_mut(), par);
1203			linalg_sp::triangular_solve::solve_unit_lower_triangular_transpose_in_place(l, conj_lhs, temp.rb_mut(), par);
1204			crate::perm::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse());
1205		}
1206	}
1207
1208	fn depth_first_search<I: Index>(
1209		marked: &mut [I],
1210		mark: I,
1211
1212		xi: &mut [I],
1213		l: SymbolicSparseColMatRef<'_, I>,
1214		row_perm_inv: &[I],
1215		b: usize,
1216		stack: &mut [I],
1217	) -> usize {
1218		let I = I::truncate;
1219
1220		let mut tail_start = xi.len();
1221		let mut head_len = 1usize;
1222		xi[0] = I(b);
1223
1224		let li = l.row_idx();
1225
1226		'dfs_loop: while head_len > 0 {
1227			let b = xi[head_len - 1].zx().zx();
1228			let pb = row_perm_inv[b].zx();
1229
1230			let range = if pb < l.ncols() { l.col_range(pb) } else { 0..0 };
1231			if marked[b] < mark {
1232				marked[b] = mark;
1233				stack[head_len - 1] = I(range.start);
1234			}
1235
1236			let start = stack[head_len - 1].zx();
1237			let end = range.end;
1238			for ptr in start..end {
1239				let i = li[ptr].zx();
1240				if marked[i] == mark {
1241					continue;
1242				}
1243				stack[head_len - 1] = I(ptr);
1244				xi[head_len] = I(i);
1245				head_len += 1;
1246				continue 'dfs_loop;
1247			}
1248
1249			head_len -= 1;
1250			tail_start -= 1;
1251			xi[tail_start] = I(b);
1252		}
1253
1254		tail_start
1255	}
1256
1257	fn reach<I: Index>(
1258		marked: &mut [I],
1259		mark: I,
1260
1261		xi: &mut [I],
1262		l: SymbolicSparseColMatRef<'_, I>,
1263		row_perm_inv: &[I],
1264		bi: &[I],
1265		stack: &mut [I],
1266	) -> usize {
1267		let n = l.nrows();
1268		let mut tail_start = n;
1269
1270		for b in bi {
1271			let b = b.zx();
1272			if marked[b] < mark {
1273				tail_start = depth_first_search(marked, mark, &mut xi[..tail_start], l, row_perm_inv, b, stack);
1274			}
1275		}
1276
1277		tail_start
1278	}
1279
1280	#[math]
1281	fn l_incomplete_solve_sparse<I: Index, T: ComplexField>(
1282		marked: &mut [I],
1283		mark: I,
1284
1285		xi: &mut [I],
1286		x: &mut [T],
1287		l: SparseColMatRef<'_, I, T>,
1288		row_perm_inv: &[I],
1289		bi: &[I],
1290		bx: &[T],
1291		stack: &mut [I],
1292	) -> usize {
1293		let tail_start = reach(marked, mark, xi, l.symbolic(), row_perm_inv, bi, stack);
1294
1295		let xi = &xi[tail_start..];
1296		for (i, b) in iter::zip(bi, bx) {
1297			let i = i.zx();
1298			x[i] = x[i] + *b;
1299		}
1300
1301		for i in xi {
1302			let i = i.zx();
1303			let pi = row_perm_inv[i].zx();
1304			if pi >= l.ncols() {
1305				continue;
1306			}
1307
1308			let li = l.row_idx_of_col_raw(pi);
1309			let lx = l.val_of_col(pi);
1310			let len = li.len();
1311
1312			let xi = copy(x[i]);
1313			for (li, lx) in iter::zip(&li[1..], &lx[1..len]) {
1314				let li = li.zx();
1315				x[li] = x[li] - *lx * xi;
1316			}
1317		}
1318
1319		tail_start
1320	}
1321
1322	/// computes the size and alignment of the workspace required to perform a numeric $LU$
1323	/// factorization
1324	pub fn factorize_simplicial_numeric_lu_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
1325		let idx = StackReq::new::<I>(nrows);
1326		let val = temp_mat_scratch::<T>(nrows, 1);
1327		let _ = ncols;
1328		StackReq::all_of(&[val, idx, idx, idx])
1329	}
1330
1331	/// computes the numeric values of the $LU$ factors of the matrix $A$ as well as the row
1332	/// pivoting permutation, and stores them in `lu` and `row_perm`/`row_perm_inv`
1333	#[math]
1334	pub fn factorize_simplicial_numeric_lu<I: Index, T: ComplexField>(
1335		row_perm: &mut [I],
1336		row_perm_inv: &mut [I],
1337		lu: &mut SimplicialLu<I, T>,
1338
1339		A: SparseColMatRef<'_, I, T>,
1340		col_perm: PermRef<'_, I>,
1341		stack: &mut MemStack,
1342	) -> Result<(), LuError> {
1343		let I = I::truncate;
1344
1345		assert!(all(
1346			A.nrows() == row_perm.len(),
1347			A.nrows() == row_perm_inv.len(),
1348			A.ncols() == col_perm.len(),
1349			A.nrows() == A.ncols()
1350		));
1351
1352		lu.nrows = 0;
1353		lu.ncols = 0;
1354
1355		let m = A.nrows();
1356		let n = A.ncols();
1357
1358		resize_vec(&mut lu.l_col_ptr, n + 1, true, false, I(0))?;
1359		resize_vec(&mut lu.u_col_ptr, n + 1, true, false, I(0))?;
1360
1361		let (mut x, stack) = temp_mat_zeroed::<T, _, _>(m, 1, stack);
1362		let x = x.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
1363
1364		let (marked, stack) = unsafe { stack.make_raw::<I>(m) };
1365		let (xj, stack) = unsafe { stack.make_raw::<I>(m) };
1366		let (stack, _) = unsafe { stack.make_raw::<I>(m) };
1367
1368		marked.fill(I(0));
1369		row_perm_inv.fill(I(n));
1370
1371		let mut l_pos = 0usize;
1372		let mut u_pos = 0usize;
1373		lu.l_col_ptr[0] = I(0);
1374		lu.u_col_ptr[0] = I(0);
1375		for j in 0..n {
1376			let l = SparseColMatRef::<'_, I, T>::new(
1377				unsafe { SymbolicSparseColMatRef::new_unchecked(m, j, &lu.l_col_ptr[..j + 1], None, &lu.l_row_idx) },
1378				&lu.l_val,
1379			);
1380
1381			let pj = col_perm.arrays().0[j].zx();
1382			let tail_start = l_incomplete_solve_sparse(
1383				marked,
1384				I(j + 1),
1385				xj,
1386				x,
1387				l,
1388				row_perm_inv,
1389				A.row_idx_of_col_raw(pj),
1390				A.val_of_col(pj),
1391				stack,
1392			);
1393			let xj = &xj[tail_start..];
1394
1395			resize_vec::<T>(&mut lu.l_val, l_pos + xj.len() + 1, false, false, zero::<T>())?;
1396			resize_vec(&mut lu.l_row_idx, l_pos + xj.len() + 1, false, false, I(0))?;
1397			resize_vec::<T>(&mut lu.u_val, u_pos + xj.len() + 1, false, false, zero::<T>())?;
1398			resize_vec(&mut lu.u_row_idx, u_pos + xj.len() + 1, false, false, I(0))?;
1399
1400			let l_val = &mut *lu.l_val;
1401			let u_val = &mut *lu.u_val;
1402
1403			let mut pivot_idx = n;
1404			let mut pivot_val = -one::<T::Real>();
1405			for i in xj {
1406				let i = i.zx();
1407				let xi = copy(x[i]);
1408				if row_perm_inv[i] == I(n) {
1409					let val = abs(xi);
1410					if matches!(val.partial_cmp(&pivot_val), None | Some(core::cmp::Ordering::Greater)) {
1411						pivot_idx = i;
1412						pivot_val = val;
1413					}
1414				} else {
1415					lu.u_row_idx[u_pos] = row_perm_inv[i];
1416					u_val[u_pos] = xi;
1417					u_pos += 1;
1418				}
1419			}
1420			if pivot_idx == n {
1421				return Err(LuError::SymbolicSingular { index: j });
1422			}
1423
1424			let x_piv = copy(x[pivot_idx]);
1425			if x_piv == zero::<T>() {
1426				panic!();
1427			}
1428			let x_piv_inv = recip(x_piv);
1429
1430			row_perm_inv[pivot_idx] = I(j);
1431
1432			lu.u_row_idx[u_pos] = I(j);
1433			u_val[u_pos] = x_piv;
1434			u_pos += 1;
1435			lu.u_col_ptr[j + 1] = I(u_pos);
1436
1437			lu.l_row_idx[l_pos] = I(pivot_idx);
1438			l_val[l_pos] = one::<T>();
1439			l_pos += 1;
1440
1441			for i in xj {
1442				let i = i.zx();
1443				let xi = copy(x[i]);
1444				if row_perm_inv[i] == I(n) {
1445					lu.l_row_idx[l_pos] = I(i);
1446					l_val[l_pos] = xi * x_piv_inv;
1447					l_pos += 1;
1448				}
1449				x[i] = zero::<T>();
1450			}
1451			lu.l_col_ptr[j + 1] = I(l_pos);
1452		}
1453
1454		for i in &mut lu.l_row_idx[..l_pos] {
1455			*i = row_perm_inv[(*i).zx()];
1456		}
1457
1458		for (idx, p) in row_perm_inv.iter().enumerate() {
1459			row_perm[p.zx()] = I(idx);
1460		}
1461
1462		lu.nrows = m;
1463		lu.ncols = n;
1464
1465		Ok(())
1466	}
1467}
1468
1469/// tuning parameters for the $LU$ symbolic factorization
1470#[derive(Copy, Clone, Debug, Default)]
1471pub struct LuSymbolicParams<'a> {
1472	/// parameters for the fill reducing column permutation
1473	pub colamd_params: colamd::Control,
1474	/// threshold for selecting the supernodal factorization
1475	pub supernodal_flop_ratio_threshold: SupernodalThreshold,
1476	/// supernodal factorization parameters
1477	pub supernodal_params: SymbolicSupernodalParams<'a>,
1478}
1479
1480/// the inner factorization used for the symbolic $LU$, either simplicial or symbolic
1481#[derive(Debug, Clone)]
1482pub enum SymbolicLuRaw<I> {
1483	/// simplicial structure
1484	Simplicial {
1485		/// number of rows of $A$
1486		nrows: usize,
1487		/// number of columns of $A$
1488		ncols: usize,
1489	},
1490	/// supernodal structure
1491	Supernodal(supernodal::SymbolicSupernodalLu<I>),
1492}
1493
1494/// the symbolic structure of a sparse $LU$ decomposition
1495#[derive(Debug, Clone)]
1496pub struct SymbolicLu<I> {
1497	raw: SymbolicLuRaw<I>,
1498	col_perm_fwd: alloc::vec::Vec<I>,
1499	col_perm_inv: alloc::vec::Vec<I>,
1500	A_nnz: usize,
1501}
1502
1503#[derive(Debug, Clone)]
1504enum NumericLuRaw<I, T> {
1505	None,
1506	Supernodal(supernodal::SupernodalLu<I, T>),
1507	Simplicial(simplicial::SimplicialLu<I, T>),
1508}
1509
1510/// structure that contains the numerical values and row pivoting permutation of the lu
1511/// decomposition
1512#[derive(Debug, Clone)]
1513pub struct NumericLu<I, T> {
1514	raw: NumericLuRaw<I, T>,
1515	row_perm_fwd: alloc::vec::Vec<I>,
1516	row_perm_inv: alloc::vec::Vec<I>,
1517}
1518
1519impl<I: Index, T> Default for NumericLu<I, T> {
1520	fn default() -> Self {
1521		Self::new()
1522	}
1523}
1524
1525impl<I: Index, T> NumericLu<I, T> {
1526	/// creates a new $LU$ of a $0\times 0$ matrix
1527	#[inline]
1528	pub fn new() -> Self {
1529		Self {
1530			raw: NumericLuRaw::None,
1531			row_perm_fwd: alloc::vec::Vec::new(),
1532			row_perm_inv: alloc::vec::Vec::new(),
1533		}
1534	}
1535}
1536
1537/// sparse $LU$ factorization wrapper
1538#[derive(Debug)]
1539pub struct LuRef<'a, I: Index, T> {
1540	symbolic: &'a SymbolicLu<I>,
1541	numeric: &'a NumericLu<I, T>,
1542}
1543impl<I: Index, T> Copy for LuRef<'_, I, T> {}
1544impl<I: Index, T> Clone for LuRef<'_, I, T> {
1545	fn clone(&self) -> Self {
1546		*self
1547	}
1548}
1549
1550impl<'a, I: Index, T> LuRef<'a, I, T> {
1551	/// creates $LU$ factors from their components
1552	///
1553	/// # safety
1554	/// the numeric part must be the output of [`SymbolicLu::factorize_numeric_lu`], called with a
1555	/// matrix having the same symbolic structure as the one used to create `symbolic`
1556	#[inline]
1557	pub unsafe fn new_unchecked(symbolic: &'a SymbolicLu<I>, numeric: &'a NumericLu<I, T>) -> Self {
1558		match (&symbolic.raw, &numeric.raw) {
1559			(SymbolicLuRaw::Simplicial { .. }, NumericLuRaw::Simplicial(_)) => {},
1560			(SymbolicLuRaw::Supernodal { .. }, NumericLuRaw::Supernodal(_)) => {},
1561			_ => panic!("incompatible symbolic and numeric variants"),
1562		}
1563		Self { symbolic, numeric }
1564	}
1565
1566	/// returns the symbolic structure of the $LU$ factorization
1567	#[inline]
1568	pub fn symbolic(self) -> &'a SymbolicLu<I> {
1569		self.symbolic
1570	}
1571
1572	/// returns the row pivoting permutation
1573	#[inline]
1574	pub fn row_perm(self) -> PermRef<'a, I> {
1575		unsafe { PermRef::new_unchecked(&self.numeric.row_perm_fwd, &self.numeric.row_perm_inv, self.symbolic.nrows()) }
1576	}
1577
1578	/// returns the fill reducing column permutation
1579	#[inline]
1580	pub fn col_perm(self) -> PermRef<'a, I> {
1581		self.symbolic.col_perm()
1582	}
1583
1584	/// solves the equation $A x = \text{rhs}$ and stores the result in `rhs`, implicitly
1585	/// conjugating $A$ if needed
1586	///
1587	/// # panics
1588	/// - panics if `self.nrows() != self.ncols()`
1589	/// - panics if `rhs.nrows() != self.nrows()`
1590	#[track_caller]
1591	pub fn solve_in_place_with_conj(self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1592	where
1593		T: ComplexField,
1594	{
1595		let (mut work, _) = unsafe { temp_mat_uninit(rhs.nrows(), rhs.ncols(), stack) };
1596		let work = work.as_mat_mut();
1597		match (&self.symbolic.raw, &self.numeric.raw) {
1598			(SymbolicLuRaw::Simplicial { .. }, NumericLuRaw::Simplicial(numeric)) => {
1599				numeric.solve_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
1600			},
1601			(SymbolicLuRaw::Supernodal(_), NumericLuRaw::Supernodal(numeric)) => {
1602				numeric.solve_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
1603			},
1604			_ => unreachable!(),
1605		}
1606	}
1607
1608	/// solves the equation $A^\top x = \text{rhs}$ and stores the result in `rhs`,
1609	/// implicitly conjugating $A$ if needed
1610	///
1611	/// # panics
1612	/// - panics if `self.nrows() != self.ncols()`
1613	/// - panics if `rhs.nrows() != self.nrows()`
1614	#[track_caller]
1615	pub fn solve_transpose_in_place_with_conj(self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1616	where
1617		T: ComplexField,
1618	{
1619		let (mut work, _) = unsafe { temp_mat_uninit(rhs.nrows(), rhs.ncols(), stack) };
1620		let work = work.as_mat_mut();
1621		match (&self.symbolic.raw, &self.numeric.raw) {
1622			(SymbolicLuRaw::Simplicial { .. }, NumericLuRaw::Simplicial(numeric)) => {
1623				numeric.solve_transpose_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
1624			},
1625			(SymbolicLuRaw::Supernodal(_), NumericLuRaw::Supernodal(numeric)) => {
1626				numeric.solve_transpose_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
1627			},
1628			_ => unreachable!(),
1629		}
1630	}
1631}
1632
1633impl<I: Index> SymbolicLu<I> {
1634	/// returns the number of rows of $A$
1635	#[inline]
1636	pub fn nrows(&self) -> usize {
1637		match &self.raw {
1638			SymbolicLuRaw::Simplicial { nrows, .. } => *nrows,
1639			SymbolicLuRaw::Supernodal(this) => this.nrows,
1640		}
1641	}
1642
1643	/// returns the number of columns of $A$
1644	#[inline]
1645	pub fn ncols(&self) -> usize {
1646		match &self.raw {
1647			SymbolicLuRaw::Simplicial { ncols, .. } => *ncols,
1648			SymbolicLuRaw::Supernodal(this) => this.ncols,
1649		}
1650	}
1651
1652	/// returns the fill-reducing column permutation that was computed during symbolic analysis
1653	#[inline]
1654	pub fn col_perm(&self) -> PermRef<'_, I> {
1655		unsafe { PermRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv, self.ncols()) }
1656	}
1657
1658	/// computes the size and alignment of the workspace required to compute the numerical $LU$
1659	/// factorization
1660	pub fn factorize_numeric_lu_scratch<T>(&self, par: Par, params: Spec<PartialPivLuParams, T>) -> StackReq
1661	where
1662		T: ComplexField,
1663	{
1664		match &self.raw {
1665			SymbolicLuRaw::Simplicial { nrows, ncols } => simplicial::factorize_simplicial_numeric_lu_scratch::<I, T>(*nrows, *ncols),
1666			SymbolicLuRaw::Supernodal(symbolic) => {
1667				let _ = par;
1668				let m = symbolic.nrows;
1669
1670				let A_nnz = self.A_nnz;
1671				let AT_scratch = StackReq::all_of(&[temp_mat_scratch::<T>(A_nnz, 1), StackReq::new::<I>(m + 1), StackReq::new::<I>(A_nnz)]);
1672				StackReq::and(AT_scratch, supernodal::factorize_supernodal_numeric_lu_scratch::<I, T>(symbolic, params))
1673			},
1674		}
1675	}
1676
1677	/// computes the size and alignment of the workspace required to solve the equation $A x = b$
1678	pub fn solve_in_place_scratch<T>(&self, rhs_ncols: usize, par: Par) -> StackReq
1679	where
1680		T: ComplexField,
1681	{
1682		let _ = par;
1683		temp_mat_scratch::<T>(self.nrows(), rhs_ncols)
1684	}
1685
1686	/// computes the size and alignment of the workspace required to solve the equation
1687	/// $A^\top x = b$
1688	pub fn solve_transpose_in_place_scratch<T>(&self, rhs_ncols: usize, par: Par) -> StackReq
1689	where
1690		T: ComplexField,
1691	{
1692		let _ = par;
1693		temp_mat_scratch::<T>(self.nrows(), rhs_ncols)
1694	}
1695
1696	/// computes a numerical $LU$ factorization of $A$
1697	#[track_caller]
1698	pub fn factorize_numeric_lu<'out, T: ComplexField>(
1699		&'out self,
1700		numeric: &'out mut NumericLu<I, T>,
1701		A: SparseColMatRef<'_, I, T>,
1702		par: Par,
1703		stack: &mut MemStack,
1704		params: Spec<PartialPivLuParams, T>,
1705	) -> Result<LuRef<'out, I, T>, LuError> {
1706		if matches!(self.raw, SymbolicLuRaw::Simplicial { .. }) && !matches!(numeric.raw, NumericLuRaw::Simplicial(_)) {
1707			numeric.raw = NumericLuRaw::Simplicial(simplicial::SimplicialLu::new());
1708		}
1709		if matches!(self.raw, SymbolicLuRaw::Supernodal(_)) && !matches!(numeric.raw, NumericLuRaw::Supernodal(_)) {
1710			numeric.raw = NumericLuRaw::Supernodal(supernodal::SupernodalLu::new());
1711		}
1712
1713		let nrows = self.nrows();
1714
1715		numeric
1716			.row_perm_fwd
1717			.try_reserve_exact(nrows.saturating_sub(numeric.row_perm_fwd.len()))
1718			.ok()
1719			.ok_or(FaerError::OutOfMemory)?;
1720		numeric
1721			.row_perm_inv
1722			.try_reserve_exact(nrows.saturating_sub(numeric.row_perm_inv.len()))
1723			.ok()
1724			.ok_or(FaerError::OutOfMemory)?;
1725		numeric.row_perm_fwd.resize(nrows, I::truncate(0));
1726		numeric.row_perm_inv.resize(nrows, I::truncate(0));
1727
1728		match (&self.raw, &mut numeric.raw) {
1729			(SymbolicLuRaw::Simplicial { nrows, ncols }, NumericLuRaw::Simplicial(lu)) => {
1730				assert!(all(A.nrows() == *nrows, A.ncols() == *ncols));
1731
1732				simplicial::factorize_simplicial_numeric_lu(&mut numeric.row_perm_fwd, &mut numeric.row_perm_inv, lu, A, self.col_perm(), stack)?;
1733			},
1734			(SymbolicLuRaw::Supernodal(symbolic), NumericLuRaw::Supernodal(lu)) => {
1735				let m = symbolic.nrows;
1736				let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(m + 1) };
1737				let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(self.A_nnz) };
1738				let (mut new_values, stack) = unsafe { temp_mat_uninit::<T, _, _>(self.A_nnz, 1, stack) };
1739				let new_values = new_values.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
1740				let AT = utils::transpose(new_values, new_col_ptr, new_row_idx, A, stack).into_const();
1741
1742				supernodal::factorize_supernodal_numeric_lu(
1743					&mut numeric.row_perm_fwd,
1744					&mut numeric.row_perm_inv,
1745					lu,
1746					A,
1747					AT,
1748					self.col_perm(),
1749					symbolic,
1750					par,
1751					stack,
1752					params,
1753				)?;
1754			},
1755			_ => unreachable!(),
1756		}
1757
1758		Ok(unsafe { LuRef::new_unchecked(self, numeric) })
1759	}
1760}
1761
1762/// computes the symbolic $LU$ factorization of the matrix $A$, or returns an error if the
1763/// operation could not be completed
1764#[track_caller]
1765pub fn factorize_symbolic_lu<I: Index>(A: SymbolicSparseColMatRef<'_, I>, params: LuSymbolicParams<'_>) -> Result<SymbolicLu<I>, FaerError> {
1766	assert!(A.nrows() == A.ncols());
1767	let m = A.nrows();
1768	let n = A.ncols();
1769	let A_nnz = A.compute_nnz();
1770
1771	with_dim!(M, m);
1772	with_dim!(N, n);
1773
1774	let A = A.as_shape(M, N);
1775
1776	let req = {
1777		let n_scratch = StackReq::new::<I>(n);
1778		let m_scratch = StackReq::new::<I>(m);
1779		let AT_scratch = StackReq::and(
1780			// new_col_ptr
1781			StackReq::new::<I>(m + 1),
1782			// new_row_idx
1783			StackReq::new::<I>(A_nnz),
1784		);
1785
1786		StackReq::or(
1787			linalg_sp::colamd::order_scratch::<I>(m, n, A_nnz),
1788			StackReq::all_of(&[
1789				n_scratch,
1790				n_scratch,
1791				n_scratch,
1792				n_scratch,
1793				AT_scratch,
1794				StackReq::any_of(&[
1795					StackReq::and(n_scratch, m_scratch),
1796					StackReq::all_of(&[n_scratch; 3]),
1797					StackReq::all_of(&[n_scratch, n_scratch, n_scratch, n_scratch, n_scratch, m_scratch]),
1798					supernodal::factorize_supernodal_symbolic_lu_scratch::<I>(m, n),
1799				]),
1800			]),
1801		)
1802	};
1803
1804	let mut mem = dyn_stack::MemBuffer::try_new(req).ok().ok_or(FaerError::OutOfMemory)?;
1805	let stack = MemStack::new(&mut mem);
1806
1807	let mut col_perm_fwd = try_zeroed::<I>(n)?;
1808	let mut col_perm_inv = try_zeroed::<I>(n)?;
1809	let mut min_row = try_zeroed::<I>(m)?;
1810
1811	linalg_sp::colamd::order(&mut col_perm_fwd, &mut col_perm_inv, A.as_dyn(), params.colamd_params, stack)?;
1812
1813	let col_perm = PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n).as_shape(N);
1814
1815	let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(m + 1) };
1816	let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(A_nnz) };
1817	let AT = utils::adjoint(
1818		Symbolic::materialize(new_row_idx.len()),
1819		new_col_ptr,
1820		new_row_idx,
1821		SparseColMatRef::new(A, Symbolic::materialize(A.row_idx().len())),
1822		stack,
1823	)
1824	.symbolic();
1825
1826	let (etree, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
1827	let (post, stack) = unsafe { stack.make_raw::<I>(n) };
1828	let (col_counts, stack) = unsafe { stack.make_raw::<I>(n) };
1829	let (h_col_counts, stack) = unsafe { stack.make_raw::<I>(n) };
1830
1831	linalg_sp::qr::ghost_col_etree(A, Some(col_perm), Array::from_mut(etree, N), stack);
1832	let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(etree, N), N);
1833	linalg_sp::cholesky::ghost_postorder(Array::from_mut(post, N), etree_, stack);
1834
1835	linalg_sp::qr::ghost_column_counts_aat(
1836		Array::from_mut(col_counts, N),
1837		Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M),
1838		AT,
1839		Some(col_perm),
1840		etree_,
1841		Array::from_ref(Idx::from_slice_ref_checked(post, N), N),
1842		stack,
1843	);
1844	let min_col = min_row;
1845
1846	let mut threshold = params.supernodal_flop_ratio_threshold;
1847	if threshold != SupernodalThreshold::FORCE_SIMPLICIAL && threshold != SupernodalThreshold::FORCE_SUPERNODAL {
1848		h_col_counts.fill(I::truncate(0));
1849		for i in 0..m {
1850			let min_col = min_col[i];
1851			if min_col.to_signed() < I::Signed::truncate(0) {
1852				continue;
1853			}
1854			h_col_counts[min_col.zx()] += I::truncate(1);
1855		}
1856		for j in 0..n {
1857			let parent = etree[j];
1858			if parent < I::Signed::truncate(0) {
1859				continue;
1860			}
1861			h_col_counts[parent.zx()] += h_col_counts[j] - I::truncate(1);
1862		}
1863
1864		let mut nnz = 0.0f64;
1865		let mut flops = 0.0f64;
1866		for j in 0..n {
1867			let hj = h_col_counts[j].zx() as f64;
1868			let rj = col_counts[j].zx() as f64;
1869			flops += hj + hj * rj;
1870			nnz += hj + rj;
1871		}
1872
1873		if flops / nnz > threshold.0 * linalg_sp::LU_SUPERNODAL_RATIO_FACTOR {
1874			threshold = SupernodalThreshold::FORCE_SUPERNODAL;
1875		} else {
1876			threshold = SupernodalThreshold::FORCE_SIMPLICIAL;
1877		}
1878	}
1879
1880	if threshold == SupernodalThreshold::FORCE_SUPERNODAL {
1881		let symbolic = supernodal::factorize_supernodal_symbolic_lu::<I>(
1882			A.as_dyn(),
1883			Some(col_perm.as_shape(n)),
1884			&min_col,
1885			EliminationTreeRef::<'_, I> { inner: etree },
1886			col_counts,
1887			stack,
1888			params.supernodal_params,
1889		)?;
1890		Ok(SymbolicLu {
1891			raw: SymbolicLuRaw::Supernodal(symbolic),
1892			col_perm_fwd,
1893			col_perm_inv,
1894			A_nnz,
1895		})
1896	} else {
1897		Ok(SymbolicLu {
1898			raw: SymbolicLuRaw::Simplicial { nrows: m, ncols: n },
1899			col_perm_fwd,
1900			col_perm_inv,
1901			A_nnz,
1902		})
1903	}
1904}
1905
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use dyn_stack::MemBuffer;
	use linalg_sp::cholesky::tests::load_mtx;
	use matrix_market_rs::MtxData;
	use std::path::PathBuf;

	/// draws a complex number whose real part, then imaginary part, are sampled with
	/// `rng.gen::<f64>()`
	fn random_c64(rng: &mut StdRng) -> c64 {
		c64::new(rng.gen::<f64>(), rng.gen::<f64>())
	}

	/// loads `test_data/sparse_lu/YAO.mtx` as `(nrows, ncols, col_ptr, row_idx, values)`.
	///
	/// the sparsity pattern from the file is kept, but every stored value is replaced by a random
	/// complex number drawn from `rng` (one draw per stored entry, in storage order).
	///
	/// every test in this module solves with both $A$ and $A^\top$ using the same right-hand
	/// side, so the loaded matrix is required to be square.
	fn load_yao_with_random_values(
		rng: &mut StdRng,
	) -> (usize, usize, alloc::vec::Vec<usize>, alloc::vec::Vec<usize>, alloc::vec::Vec<c64>) {
		let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_lu/YAO.mtx");
		let (m, n, col_ptr, row_idx, val) = load_mtx::<usize>(MtxData::from_file(path).unwrap());
		assert!(m == n);
		let val = val.iter().map(|_| random_c64(rng)).collect();
		(m, n, col_ptr, row_idx, val)
	}

	/// returns the forward and inverse index arrays of the identity permutation of length `n`
	fn identity_perm(n: usize) -> (alloc::vec::Vec<usize>, alloc::vec::Vec<usize>) {
		let fwd: alloc::vec::Vec<usize> = (0..n).collect();
		let inv = fwd.clone();
		(fwd, inv)
	}

	/// runs the supernodal symbolic and numeric $LU$ factorizations of `YAO` directly (identity
	/// column permutation), then checks all four solve variants against the dense matrix
	#[test]
	fn test_numeric_lu_multifrontal() {
		type T = c64;

		let mut rng = StdRng::seed_from_u64(0);
		let (m, n, col_ptr, row_idx, val) = load_yao_with_random_values(&mut rng);
		let A = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

		// the row permutation acts on the `m` rows of `A`
		let mut row_perm = vec![0usize; m];
		let mut row_perm_inv = vec![0usize; m];
		let (col_perm, col_perm_inv) = identity_perm(n);
		let col_perm = PermRef::<'_, usize>::new_checked(&col_perm, &col_perm_inv, n);

		let mut etree = vec![0usize; n];
		let mut min_col = vec![0usize; m];
		let mut col_counts = vec![0usize; n];

		// explicit transpose of `A`, consumed by the column-count analysis and by the numeric
		// factorization below
		let nnz = A.compute_nnz();
		let mut new_col_ptr = vec![0usize; m + 1];
		let mut new_row_idx = vec![0usize; nnz];
		let mut new_values = vec![zero::<T>(); nnz];
		let AT = utils::transpose(
			&mut *new_values,
			&mut new_col_ptr,
			&mut new_row_idx,
			A,
			MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(m))),
		)
		.into_const();

		// column elimination tree, its postorder, then `col_counts` / `min_col` for the
		// supernodal symbolic phase
		let etree = {
			let mut post = vec![0usize; n];

			let etree = linalg_sp::qr::col_etree(
				A.symbolic(),
				Some(col_perm),
				&mut etree,
				MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(m + n))),
			);
			linalg_sp::qr::postorder(&mut post, etree, MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(3 * n))));
			linalg_sp::qr::column_counts_ata(
				&mut col_counts,
				&mut min_col,
				AT.symbolic(),
				Some(col_perm),
				etree,
				&post,
				MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(5 * n + m))),
			);
			etree
		};

		let symbolic = linalg_sp::lu::supernodal::factorize_supernodal_symbolic_lu::<usize>(
			A.symbolic(),
			Some(col_perm),
			&min_col,
			etree,
			&col_counts,
			MemStack::new(&mut MemBuffer::new(super::supernodal::factorize_supernodal_symbolic_lu_scratch::<usize>(
				m, n,
			))),
			linalg_sp::SymbolicSupernodalParams {
				relax: Some(&[(4, 1.0), (16, 0.8), (48, 0.1), (usize::MAX, 0.05)]),
			},
		)
		.unwrap();

		let mut lu = supernodal::SupernodalLu::<usize, T>::new();
		supernodal::factorize_supernodal_numeric_lu(
			&mut row_perm,
			&mut row_perm_inv,
			&mut lu,
			A,
			AT,
			col_perm,
			&symbolic,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(supernodal::factorize_supernodal_numeric_lu_scratch::<usize, T>(
				&symbolic,
				Default::default(),
			))),
			Default::default(),
		)
		.unwrap();

		let k = 2;
		let rhs = Mat::from_fn(n, k, |_, _| random_c64(&mut rng));

		let mut work = rhs.clone();
		let A_dense = A.to_dense();
		let row_perm = PermRef::<'_, _>::new_checked(&row_perm, &row_perm_inv, m);

		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((&A_dense * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.conjugate() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.transpose() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.adjoint() * &x - &rhs).norm_max() < 1e-10);
		}
	}

	/// runs the simplicial numeric $LU$ factorization of `YAO` directly (identity column
	/// permutation), then checks all four solve variants against the dense matrix
	#[test]
	fn test_numeric_lu_simplicial() {
		type T = c64;

		let mut rng = StdRng::seed_from_u64(0);
		let (m, n, col_ptr, row_idx, val) = load_yao_with_random_values(&mut rng);
		let A = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

		// the row permutation acts on the `m` rows of `A`
		let mut row_perm = vec![0usize; m];
		let mut row_perm_inv = vec![0usize; m];
		let (col_perm, col_perm_inv) = identity_perm(n);
		let col_perm = PermRef::<'_, usize>::new_checked(&col_perm, &col_perm_inv, n);

		let mut lu = simplicial::SimplicialLu::<usize, T>::new();
		simplicial::factorize_simplicial_numeric_lu(
			&mut row_perm,
			&mut row_perm_inv,
			&mut lu,
			A,
			col_perm,
			MemStack::new(&mut MemBuffer::new(simplicial::factorize_simplicial_numeric_lu_scratch::<usize, T>(m, n))),
		)
		.unwrap();

		let k = 1;
		let rhs = Mat::from_fn(n, k, |_, _| random_c64(&mut rng));

		let mut work = rhs.clone();
		let A_dense = A.to_dense();
		let row_perm = PermRef::<'_, _>::new_checked(&row_perm, &row_perm_inv, m);

		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((&A_dense * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.conjugate() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.transpose() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.adjoint() * &x - &rhs).norm_max() < 1e-10);
		}
	}

	/// exercises the high-level `factorize_symbolic_lu` / `factorize_numeric_lu` api on `YAO`
	/// for every supernodal threshold mode (automatic, forced supernodal, forced simplicial),
	/// checking all four solve variants against the sparse matrix
	#[test]
	fn test_solver_lu_simplicial() {
		type T = c64;

		let mut rng = StdRng::seed_from_u64(0);
		let (m, n, col_ptr, row_idx, val) = load_yao_with_random_values(&mut rng);
		let A = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

		let rhs = Mat::<T>::from_fn(m, 6, |_, _| random_c64(&mut rng));

		for supernodal_flop_ratio_threshold in [
			SupernodalThreshold::AUTO,
			SupernodalThreshold::FORCE_SUPERNODAL,
			SupernodalThreshold::FORCE_SIMPLICIAL,
		] {
			let symbolic = factorize_symbolic_lu(
				A.symbolic(),
				LuSymbolicParams {
					supernodal_flop_ratio_threshold,
					..Default::default()
				},
			)
			.unwrap();
			let mut numeric = NumericLu::<usize, T>::new();
			let lu = symbolic
				.factorize_numeric_lu(
					&mut numeric,
					A,
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(
						symbolic.factorize_numeric_lu_scratch::<T>(Par::Seq, Default::default()),
					)),
					Default::default(),
				)
				.unwrap();

			{
				let mut x = rhs.clone();
				lu.solve_in_place_with_conj(
					crate::Conj::No,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}
			{
				let mut x = rhs.clone();
				lu.solve_in_place_with_conj(
					crate::Conj::Yes,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A.conjugate() * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}

			{
				let mut x = rhs.clone();
				lu.solve_transpose_in_place_with_conj(
					crate::Conj::No,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_transpose_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A.transpose() * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}
			{
				let mut x = rhs.clone();
				lu.solve_transpose_in_place_with_conj(
					crate::Conj::Yes,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_transpose_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A.adjoint() * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}
		}
	}
}