1use crate::assert;
11use crate::internal_prelude_sp::*;
12use crate::sparse::utils;
13use linalg::lu::partial_pivoting::factor::PartialPivLuParams;
14use linalg_sp::cholesky::simplicial::EliminationTreeRef;
15use linalg_sp::{LuError, SupernodalThreshold, SymbolicSupernodalParams, colamd};
16
17#[inline(never)]
18fn resize_vec<T: Clone>(v: &mut alloc::vec::Vec<T>, n: usize, exact: bool, reserve_only: bool, value: T) -> Result<(), FaerError> {
19 let reserve = if exact {
20 alloc::vec::Vec::try_reserve_exact
21 } else {
22 alloc::vec::Vec::try_reserve
23 };
24 reserve(v, n.saturating_sub(v.len())).map_err(|_| FaerError::OutOfMemory)?;
25 if !reserve_only {
26 v.resize(Ord::max(n, v.len()), value);
27 }
28 Ok(())
29}
30
31pub mod supernodal {
37 use super::*;
38 use crate::assert;
39
    /// Symbolic structure of the supernodal LU factorization: the supernode
    /// partition of the columns and the supernodal elimination tree, as computed
    /// by the symbolic analysis.
    #[derive(Debug, Clone)]
    pub struct SymbolicSupernodalLu<I> {
        // supernode_ptr[s]..supernode_ptr[s + 1] is the column range of supernode s
        pub(super) supernode_ptr: alloc::vec::Vec<I>,
        // parent of each supernode in the supernodal elimination tree (NONE for roots)
        pub(super) super_etree: alloc::vec::Vec<I>,
        // postorder of the supernodal elimination tree, and its inverse permutation
        pub(super) supernode_postorder: alloc::vec::Vec<I>,
        pub(super) supernode_postorder_inv: alloc::vec::Vec<I>,
        // number of descendants of each supernode in the postordered tree
        pub(super) descendant_count: alloc::vec::Vec<I>,
        pub(super) nrows: usize,
        pub(super) ncols: usize,
    }
51
    /// Numeric supernodal LU factors, stored per supernode as dense
    /// column-major panels (the U panels are stored transposed, hence the
    /// `ut_` prefix).
    #[derive(Debug, Clone)]
    pub struct SupernodalLu<I, T> {
        nrows: usize,
        ncols: usize,
        nsupernodes: usize,

        // column range of each supernode (length nsupernodes + 1)
        supernode_ptr: alloc::vec::Vec<I>,

        // per-supernode offsets into `l_row_idx` / `l_val`
        l_col_ptr_for_row_idx: alloc::vec::Vec<I>,
        l_col_ptr_for_val: alloc::vec::Vec<I>,
        // row indices and values of the L panels
        l_row_idx: alloc::vec::Vec<I>,
        l_val: alloc::vec::Vec<T>,

        // per-supernode offsets into `ut_row_idx` / `ut_val`
        ut_col_ptr_for_row_idx: alloc::vec::Vec<I>,
        ut_col_ptr_for_val: alloc::vec::Vec<I>,
        // column indices and values of the (transposed) U panels
        ut_row_idx: alloc::vec::Vec<I>,
        ut_val: alloc::vec::Vec<T>,
    }
71
    impl<I: Index, T> Default for SupernodalLu<I, T> {
        /// Equivalent to [`SupernodalLu::new`]: an empty factorization.
        fn default() -> Self {
            Self::new()
        }
    }
77
    impl<I: Index, T> SupernodalLu<I, T> {
        /// Creates an empty LU factorization, to be filled in by the numeric
        /// factorization routine.
        #[inline]
        pub fn new() -> Self {
            Self {
                nrows: 0,
                ncols: 0,
                nsupernodes: 0,

                supernode_ptr: alloc::vec::Vec::new(),

                l_col_ptr_for_row_idx: alloc::vec::Vec::new(),
                ut_col_ptr_for_row_idx: alloc::vec::Vec::new(),

                l_col_ptr_for_val: alloc::vec::Vec::new(),
                ut_col_ptr_for_val: alloc::vec::Vec::new(),

                l_row_idx: alloc::vec::Vec::new(),
                ut_row_idx: alloc::vec::Vec::new(),

                l_val: alloc::vec::Vec::new(),
                ut_val: alloc::vec::Vec::new(),
            }
        }

        /// Returns the number of rows of the factored matrix.
        #[inline]
        pub fn nrows(&self) -> usize {
            self.nrows
        }

        /// Returns the number of columns of the factored matrix.
        #[inline]
        pub fn ncols(&self) -> usize {
            self.ncols
        }

        /// Returns the number of supernodes in the factorization.
        #[inline]
        pub fn n_supernodes(&self) -> usize {
            self.nsupernodes
        }

        /// Solves the factored linear system in place, overwriting `rhs` with the
        /// solution. `row_perm` and `col_perm` are the permutations used during
        /// factorization, `conj_lhs` selects whether the factors are implicitly
        /// conjugated, and `work` is scratch storage with the same shape as `rhs`.
        ///
        /// # Panics
        /// Panics if the factorization is not square, or if `rhs.nrows()` does not
        /// match its dimension.
        #[track_caller]
        pub fn solve_in_place_with_conj(
            &self,
            row_perm: PermRef<'_, I>,
            col_perm: PermRef<'_, I>,
            conj_lhs: Conj,
            rhs: MatMut<'_, T>,
            par: Par,
            work: MatMut<'_, T>,
        ) where
            T: ComplexField,
        {
            assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows()));
            let mut X = rhs;
            let mut temp = work;

            // apply the row permutation, then the two triangular solves, then
            // undo the column permutation
            crate::perm::permute_rows(temp.rb_mut(), X.rb(), row_perm);
            self.l_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
            self.u_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
            crate::perm::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse());
        }

        /// Solves the transposed system in place, overwriting `rhs` with the
        /// solution. Arguments are as in [`Self::solve_in_place_with_conj`];
        /// the solve order is reversed (U^T first, then L^T) and the roles of
        /// the two permutations are swapped.
        ///
        /// # Panics
        /// Panics if the factorization is not square, or if `rhs.nrows()` does not
        /// match its dimension.
        #[track_caller]
        pub fn solve_transpose_in_place_with_conj(
            &self,
            row_perm: PermRef<'_, I>,
            col_perm: PermRef<'_, I>,
            conj_lhs: Conj,
            rhs: MatMut<'_, T>,
            par: Par,
            work: MatMut<'_, T>,
        ) where
            T: ComplexField,
        {
            assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows()));
            let mut X = rhs;
            let mut temp = work;
            crate::perm::permute_rows(temp.rb_mut(), X.rb(), col_perm);
            self.u_solve_transpose_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
            self.l_solve_transpose_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), par);
            crate::perm::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse());
        }

        // Forward substitution with the unit-lower factor L, one supernode at a
        // time in increasing order. `work` receives the dense product of each
        // panel's sub-diagonal block, which is then scattered back into `rhs`
        // through the panel's row indices.
        #[track_caller]
        #[math]
        pub(crate) fn l_solve_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
        where
            T: ComplexField,
        {
            let lu = self;

            assert!(lu.nrows() == lu.ncols());
            assert!(lu.nrows() == rhs.nrows());

            let mut X = rhs;
            let nrhs = X.ncols();

            let supernode_ptr = &*lu.supernode_ptr;

            for s in 0..lu.nsupernodes {
                let s_begin = supernode_ptr[s].zx();
                let s_end = supernode_ptr[s + 1].zx();
                let s_size = s_end - s_begin;
                let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();

                let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
                let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
                // top block: unit-lower triangular part; bottom block: sub-diagonal rows
                let (L_top, L_bot) = L.split_at_row(s_size);
                linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(
                    L_top,
                    conj_lhs,
                    X.rb_mut().subrows_mut(s_begin, s_size),
                    par,
                );
                linalg::matmul::matmul_with_conj(
                    work.rb_mut().subrows_mut(0, s_row_idx_count - s_size),
                    Accum::Replace,
                    L_bot,
                    conj_lhs,
                    X.rb().subrows(s_begin, s_size),
                    Conj::No,
                    one::<T>(),
                    par,
                );

                // scatter-subtract the sub-diagonal contribution into the rows
                // listed by the panel's row indices
                for j in 0..nrhs {
                    for (idx, &i) in lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()][s_size..]
                        .iter()
                        .enumerate()
                    {
                        let i = i.zx();
                        X[(i, j)] = X[(i, j)] - work[(idx, j)];
                    }
                }
            }
        }

        // Backward substitution with L^T: supernodes are visited in decreasing
        // order, and each step is the transpose of the corresponding step of
        // `l_solve_in_place_with_conj` (gather instead of scatter, transposed
        // triangular solve last).
        #[track_caller]
        #[math]
        pub(crate) fn l_solve_transpose_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
        where
            T: ComplexField,
        {
            let lu = self;

            assert!(lu.nrows() == lu.ncols());
            assert!(lu.nrows() == rhs.nrows());

            let mut X = rhs;
            let nrhs = X.ncols();

            let supernode_ptr = &*lu.supernode_ptr;

            for s in (0..lu.nsupernodes).rev() {
                let s_begin = supernode_ptr[s].zx();
                let s_end = supernode_ptr[s + 1].zx();
                let s_size = s_end - s_begin;
                let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();

                let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
                let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);

                let (L_top, L_bot) = L.split_at_row(s_size);

                // gather the rows referenced by the panel into contiguous storage
                for j in 0..nrhs {
                    for (idx, &i) in lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()][s_size..]
                        .iter()
                        .enumerate()
                    {
                        let i = i.zx();
                        work[(idx, j)] = copy(X[(i, j)]);
                    }
                }

                linalg::matmul::matmul_with_conj(
                    X.rb_mut().subrows_mut(s_begin, s_size),
                    Accum::Add,
                    L_bot.transpose(),
                    conj_lhs,
                    work.rb().subrows(0, s_row_idx_count - s_size),
                    Conj::No,
                    -one::<T>(),
                    par,
                );
                // L_top is unit lower triangular, so its transpose is unit upper
                linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(
                    L_top.transpose(),
                    conj_lhs,
                    X.rb_mut().subrows_mut(s_begin, s_size),
                    par,
                );
            }
        }

        // Backward substitution with U: supernodes are visited in decreasing
        // order. The diagonal block of U is stored on and above the diagonal of
        // the top block of the L panel; the off-diagonal part of U is stored
        // transposed in the `ut_*` arrays.
        #[track_caller]
        #[math]
        pub(crate) fn u_solve_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
        where
            T: ComplexField,
        {
            let lu = self;

            assert!(lu.nrows() == lu.ncols());
            assert!(lu.nrows() == rhs.nrows());

            let mut X = rhs;
            let nrhs = X.ncols();

            let supernode_ptr = &*lu.supernode_ptr;

            for s in (0..lu.nsupernodes).rev() {
                let s_begin = supernode_ptr[s].zx();
                let s_end = supernode_ptr[s + 1].zx();
                let s_size = s_end - s_begin;
                let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();
                let s_col_index_count = lu.ut_col_ptr_for_row_idx[s + 1].zx() - lu.ut_col_ptr_for_row_idx[s].zx();

                let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
                let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
                let U = &lu.ut_val[lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()];
                let U_right = MatRef::from_column_major_slice(U, s_col_index_count, s_size).transpose();

                // gather the entries of X referenced by U's column indices
                for j in 0..nrhs {
                    for (idx, &i) in lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()]
                        .iter()
                        .enumerate()
                    {
                        let i = i.zx();
                        work[(idx, j)] = copy(X[(i, j)]);
                    }
                }

                // the upper-triangular diagonal block of U lives in the top of the L panel
                let (U_left, _) = L.split_at_row(s_size);
                linalg::matmul::matmul_with_conj(
                    X.rb_mut().subrows_mut(s_begin, s_size),
                    Accum::Add,
                    U_right,
                    conj_lhs,
                    work.rb().subrows(0, s_col_index_count),
                    Conj::No,
                    -one::<T>(),
                    par,
                );
                linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(U_left, conj_lhs, X.rb_mut().subrows_mut(s_begin, s_size), par);
            }
        }

        // Forward substitution with U^T: supernodes are visited in increasing
        // order, mirroring `u_solve_in_place_with_conj` with every operation
        // transposed (triangular solve first, then scatter-subtract).
        #[track_caller]
        #[math]
        pub(crate) fn u_solve_transpose_in_place_with_conj(&self, conj_lhs: Conj, rhs: MatMut<'_, T>, mut work: MatMut<'_, T>, par: Par)
        where
            T: ComplexField,
        {
            let lu = self;

            assert!(lu.nrows() == lu.ncols());
            assert!(lu.nrows() == rhs.nrows());

            let mut X = rhs;
            let nrhs = X.ncols();

            let supernode_ptr = &*lu.supernode_ptr;

            for s in 0..lu.nsupernodes {
                let s_begin = supernode_ptr[s].zx();
                let s_end = supernode_ptr[s + 1].zx();
                let s_size = s_end - s_begin;
                let s_row_idx_count = lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx();
                let s_col_index_count = lu.ut_col_ptr_for_row_idx[s + 1].zx() - lu.ut_col_ptr_for_row_idx[s].zx();

                let L = &lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
                let L = MatRef::from_column_major_slice(L, s_row_idx_count, s_size);
                let U = &lu.ut_val[lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()];
                let U_right = MatRef::from_column_major_slice(U, s_col_index_count, s_size).transpose();

                // U_left^T is lower triangular
                let (U_left, _) = L.split_at_row(s_size);
                linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(
                    U_left.transpose(),
                    conj_lhs,
                    X.rb_mut().subrows_mut(s_begin, s_size),
                    par,
                );
                linalg::matmul::matmul_with_conj(
                    work.rb_mut().subrows_mut(0, s_col_index_count),
                    Accum::Replace,
                    U_right.transpose(),
                    conj_lhs,
                    X.rb().subrows(s_begin, s_size),
                    Conj::No,
                    one::<T>(),
                    par,
                );

                // scatter-subtract through U's column indices
                for j in 0..nrhs {
                    for (idx, &i) in lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()]
                        .iter()
                        .enumerate()
                    {
                        let i = i.zx();
                        X[(i, j)] = X[(i, j)] - work[(idx, j)];
                    }
                }
            }
        }
    }
397
398 pub fn factorize_supernodal_symbolic_lu_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
401 let _ = nrows;
402 linalg_sp::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_scratch::<I>(ncols)
403 }
404
    /// Computes the symbolic structure of the supernodal LU factorization of `A`
    /// (under the optional fill-reducing column permutation `col_perm`), given
    /// the column elimination tree `etree`, the per-row minimum column `min_col`,
    /// and the column counts `col_counts`.
    ///
    /// # Errors
    /// Propagates allocation failures from the underlying symbolic analysis.
    #[track_caller]
    pub fn factorize_supernodal_symbolic_lu<I: Index>(
        A: SymbolicSparseColMatRef<'_, I>,
        col_perm: Option<PermRef<'_, I>>,
        min_col: &[I],
        etree: EliminationTreeRef<'_, I>,
        col_counts: &[I],
        stack: &mut MemStack,
        params: SymbolicSupernodalParams<'_>,
    ) -> Result<SymbolicSupernodalLu<I>, FaerError> {
        let m = A.nrows();
        let n = A.ncols();

        // bind the runtime dimensions as branded ("ghost") dimensions for
        // compile-time index checking
        with_dim!(M, m);
        with_dim!(N, n);

        let I = I::truncate;
        let A = A.as_shape(M, N);
        let min_col = Array::from_ref(MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(min_col), N), M);
        let etree = etree.as_bound(N);

        // the supernodal structure of LU is obtained from the symbolic Cholesky
        // analysis of A^T A
        let L = linalg_sp::cholesky::supernodal::ghost_factorize_supernodal_symbolic(
            A,
            col_perm.map(|perm| perm.as_shape(N)),
            Some(min_col),
            linalg_sp::cholesky::supernodal::CholeskyInput::ATA,
            etree,
            Array::from_ref(col_counts, N),
            stack,
            params,
        )?;
        let n_supernodes = L.n_supernodes();
        let mut super_etree = try_zeroed::<I>(n_supernodes)?;

        let (index_to_super, _) = unsafe { stack.make_raw::<I>(*N) };

        // map each column to the supernode containing it
        for s in 0..n_supernodes {
            index_to_super[L.supernode_begin[s].zx()..L.supernode_begin[s + 1].zx()].fill(I(s));
        }
        // parent of a supernode = supernode containing the etree parent of its
        // last column (NONE for roots)
        for s in 0..n_supernodes {
            let last = L.supernode_begin[s + 1].zx() - 1;
            if let Some(parent) = etree[N.check(last)].idx() {
                super_etree[s] = index_to_super[*parent.zx()];
            } else {
                super_etree[s] = I(NONE);
            }
        }

        Ok(SymbolicSupernodalLu {
            supernode_ptr: L.supernode_begin,
            super_etree,
            supernode_postorder: L.supernode_postorder,
            supernode_postorder_inv: L.supernode_postorder_inv,
            descendant_count: L.descendant_count,
            nrows: *A.nrows(),
            ncols: *A.ncols(),
        })
    }
464
    /// Dense column-major matrix of `u8` flags, used by the numeric
    /// factorization to track which entries of a pending contribution block are
    /// still active (`1`) or already consumed (`0`).
    struct MatU8 {
        data: alloc::vec::Vec<u8>,
        nrows: usize,
    }
469 impl MatU8 {
470 fn new() -> Self {
471 Self {
472 data: alloc::vec::Vec::new(),
473 nrows: 0,
474 }
475 }
476
477 fn with_dims(nrows: usize, ncols: usize) -> Result<Self, FaerError> {
478 Ok(Self {
479 data: try_collect((0..(nrows * ncols)).map(|_| 1u8))?,
480 nrows,
481 })
482 }
483 }
484 impl core::ops::Index<(usize, usize)> for MatU8 {
485 type Output = u8;
486
487 #[inline(always)]
488 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
489 &self.data[row + col * self.nrows]
490 }
491 }
492 impl core::ops::IndexMut<(usize, usize)> for MatU8 {
493 #[inline(always)]
494 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
495 &mut self.data[row + col * self.nrows]
496 }
497 }
498
    // Zero-sized markers identifying the three assembly passes of the numeric
    // factorization; passing them to `noinline` gives each pass its own
    // non-inlined monomorphization (easier to profile, smaller hot function).
    struct Front;
    struct LPanel;
    struct UPanel;

    // Runs `f` behind a call that is never inlined, tagged by the marker `_`.
    #[inline(never)]
    fn noinline<T, R>(_: T, f: impl FnOnce() -> R) -> R {
        f()
    }
507
508 pub fn factorize_supernodal_numeric_lu_scratch<I: Index, T: ComplexField>(
511 symbolic: &SymbolicSupernodalLu<I>,
512 params: Spec<PartialPivLuParams, T>,
513 ) -> StackReq {
514 let m = StackReq::new::<I>(symbolic.nrows);
515 let n = StackReq::new::<I>(symbolic.ncols);
516 _ = params;
517 StackReq::and(n, m.array(5))
518 }
519
    /// Computes the numeric values of the supernodal LU factors of `A` with
    /// partial pivoting, as well as the row pivoting permutation (written to
    /// `row_perm`/`row_perm_inv`), storing the factors in `lu`.
    ///
    /// `AT` must be the transpose of `A`, and `symbolic`/`col_perm` must come
    /// from the symbolic analysis of the same matrix.
    ///
    /// # Errors
    /// Returns `LuError::SymbolicSingular` when a panel has fewer candidate
    /// pivot rows than columns, and propagates allocation/overflow errors.
    ///
    /// # Panics
    /// Panics if the dimensions of `A`, `AT`, the permutations, or `symbolic`
    /// are inconsistent, or if `A.nrows() < A.ncols()`.
    #[math]
    pub fn factorize_supernodal_numeric_lu<I: Index, T: ComplexField>(
        row_perm: &mut [I],
        row_perm_inv: &mut [I],
        lu: &mut SupernodalLu<I, T>,

        A: SparseColMatRef<'_, I, T>,
        AT: SparseColMatRef<'_, I, T>,
        col_perm: PermRef<'_, I>,
        symbolic: &SymbolicSupernodalLu<I>,

        par: Par,
        stack: &mut MemStack,
        params: Spec<PartialPivLuParams, T>,
    ) -> Result<(), LuError> {
        use linalg_sp::cholesky::supernodal::partition_fn;
        let SymbolicSupernodalLu {
            supernode_ptr,
            super_etree,
            supernode_postorder,
            supernode_postorder_inv,
            descendant_count,
            nrows: _,
            ncols: _,
        } = symbolic;

        let I = I::truncate;
        // checked conversions: report index overflow as an error instead of wrapping
        let I_checked = |x: usize| -> Result<I, FaerError> {
            if x > I::Signed::MAX.zx() {
                Err(FaerError::IndexOverflow)
            } else {
                Ok(I(x))
            }
        };
        let to_wide = |x: I| -> u128 { x.zx() as _ };
        let from_wide_checked = |x: u128| -> Result<I, FaerError> {
            if x > I::Signed::MAX.zx() as u128 {
                Err(FaerError::IndexOverflow)
            } else {
                Ok(I(x as _))
            }
        };

        let m = A.nrows();
        let n = A.ncols();
        assert!(m >= n);
        assert!(all(AT.nrows() == n, AT.ncols() == m));
        assert!(all(row_perm.len() == m, row_perm_inv.len() == m));
        let n_supernodes = super_etree.len();
        assert!(supernode_postorder.len() == n_supernodes);
        assert!(supernode_postorder_inv.len() == n_supernodes);
        assert!(supernode_ptr.len() == n_supernodes + 1);
        assert!(supernode_ptr[n_supernodes].zx() == n);

        // reset the output so it is left empty on early error return
        lu.nrows = 0;
        lu.ncols = 0;
        lu.nsupernodes = 0;
        lu.supernode_ptr.clear();

        // scatter maps from global row/column indices to local panel positions,
        // a visitation stamp array (`marked`), and index scratch buffers
        let (col_global_to_local, stack) = unsafe { stack.make_raw::<I>(n) };
        let (row_global_to_local, stack) = unsafe { stack.make_raw::<I>(m) };
        let (marked, stack) = unsafe { stack.make_raw::<I>(m) };
        let (indices, stack) = unsafe { stack.make_raw::<I>(m) };
        let (transpositions, stack) = unsafe { stack.make_raw::<I>(m) };
        let (d_active_rows, _) = unsafe { stack.make_raw::<I>(m) };

        col_global_to_local.fill(I(NONE));
        row_global_to_local.fill(I(NONE));

        marked.fill(I(0));

        resize_vec(&mut lu.l_col_ptr_for_row_idx, n_supernodes + 1, true, false, I(0))?;
        resize_vec(&mut lu.ut_col_ptr_for_row_idx, n_supernodes + 1, true, false, I(0))?;
        resize_vec(&mut lu.l_col_ptr_for_val, n_supernodes + 1, true, false, I(0))?;
        resize_vec(&mut lu.ut_col_ptr_for_val, n_supernodes + 1, true, false, I(0))?;

        lu.l_col_ptr_for_row_idx[0] = I(0);
        lu.ut_col_ptr_for_row_idx[0] = I(0);
        lu.l_col_ptr_for_val[0] = I(0);
        lu.ut_col_ptr_for_val[0] = I(0);

        // start from the identity row permutation; pivoting updates it in place
        for i in 0..m {
            row_perm[i] = I(i);
        }
        for i in 0..m {
            row_perm_inv[i] = I(i);
        }

        let (col_perm, col_perm_inv) = col_perm.arrays();

        // one pending contribution block per supernode:
        // (values, active count per column, number of active columns, active-entry flags)
        let mut contrib_work =
            try_collect((0..n_supernodes).map(|_| (alloc::vec::Vec::<T>::new(), alloc::vec::Vec::<I>::new(), 0usize, MatU8::new())))?;

        // views a value buffer as a column-major matrix without reallocation
        let work_to_mat_mut = |v: &mut alloc::vec::Vec<T>, nrows: usize, ncols: usize| unsafe {
            MatMut::from_raw_parts_mut(v.as_mut_ptr(), nrows, ncols, 1, nrows as isize)
        };

        // number of entries of A not yet assembled; used as a sanity check below
        let mut A_leftover = A.compute_nnz();
        for s in 0..n_supernodes {
            let s_begin = supernode_ptr[s].zx();
            let s_end = supernode_ptr[s + 1].zx();
            let s_size = s_end - s_begin;

            let s_postordered = supernode_postorder_inv[s].zx();
            let desc_count = descendant_count[s].zx();
            let mut s_row_idx_count = 0usize;
            // split so descendants (already factored) and this supernode can be
            // borrowed simultaneously
            let (left_contrib, right_contrib) = contrib_work.split_at_mut(s);

            let s_row_idxices = &mut *indices;
            // gather the row structure of the L panel: rows of A in this
            // supernode's columns that have not already been pivoted above it
            for j in s_begin..s_end {
                let pj = col_perm[j].zx();
                let row_idx = A.row_idx_of_col_raw(pj);
                for i in row_idx {
                    let i = i.zx();
                    let pi = row_perm_inv[i].zx();
                    if pi < s_begin {
                        continue;
                    }
                    // stamp 2s + 1 marks rows seen during the L pass of supernode s
                    if marked[i] < I(2 * s + 1) {
                        s_row_idxices[s_row_idx_count] = I(i);
                        s_row_idx_count += 1;
                        marked[i] = I(2 * s + 1);
                    }
                }
            }

            // add fill-in rows coming from descendants whose U structure
            // intersects this supernode's column range
            for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
                let d = d.zx();
                let d_begin = supernode_ptr[d].zx();
                let d_end = supernode_ptr[d + 1].zx();
                let d_size = d_end - d_begin;
                let d_row_idx = &lu.l_row_idx[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
                let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
                let d_col_start = d_col_ind.partition_point(partition_fn(s_begin));

                if d_col_start < d_col_ind.len() && d_col_ind[d_col_start].zx() < s_end {
                    for i in d_row_idx.iter() {
                        let i = i.zx();
                        let pi = row_perm_inv[i].zx();

                        if pi < s_begin {
                            continue;
                        }

                        if marked[i] < I(2 * s + 1) {
                            s_row_idxices[s_row_idx_count] = I(i);
                            s_row_idx_count += 1;
                            marked[i] = I(2 * s + 1);
                        }
                    }
                }
            }

            // record the L panel's extent and make room for its indices/values
            lu.l_col_ptr_for_row_idx[s + 1] = I_checked(lu.l_col_ptr_for_row_idx[s].zx() + s_row_idx_count)?;
            lu.l_col_ptr_for_val[s + 1] = from_wide_checked(to_wide(lu.l_col_ptr_for_val[s]) + ((s_row_idx_count) as u128 * s_size as u128))?;
            resize_vec(&mut lu.l_row_idx, lu.l_col_ptr_for_row_idx[s + 1].zx(), false, false, I(0))?;
            resize_vec::<T>(&mut lu.l_val, lu.l_col_ptr_for_val[s + 1].zx(), false, false, zero::<T>())?;
            lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()].copy_from_slice(&s_row_idxices[..s_row_idx_count]);
            lu.l_row_idx[lu.l_col_ptr_for_row_idx[s].zx()..lu.l_col_ptr_for_row_idx[s + 1].zx()].sort_unstable();

            let (left_row_idxices, right_row_idxices) = lu.l_row_idx.split_at_mut(lu.l_col_ptr_for_row_idx[s].zx());

            let s_row_idxices = &mut right_row_idxices[0..lu.l_col_ptr_for_row_idx[s + 1].zx() - lu.l_col_ptr_for_row_idx[s].zx()];
            // build the global-row -> local-panel-row scatter map
            for (idx, i) in s_row_idxices.iter().enumerate() {
                row_global_to_local[i.zx()] = I(idx);
            }
            let s_L = &mut lu.l_val[lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()];
            let mut s_L = MatMut::from_column_major_slice_mut(s_L, s_row_idx_count, s_size);
            s_L.fill(zero());

            // scatter the entries of A into the L panel
            for j in s_begin..s_end {
                let pj = col_perm[j].zx();
                let row_idx = A.row_idx_of_col(pj);
                let val = A.val_of_col(pj);

                for (i, val) in iter::zip(row_idx, val) {
                    let pi = row_perm_inv[i].zx();
                    if pi < s_begin {
                        continue;
                    }
                    assert!(A_leftover > 0);
                    A_leftover -= 1;
                    let ix = row_global_to_local[i].zx();
                    let iy = j - s_begin;
                    s_L[(ix, iy)] = s_L[(ix, iy)] + *val;
                }
            }

            // subtract the descendants' pending contributions from the L panel
            noinline(LPanel, || {
                for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
                    let d = d.zx();
                    if left_contrib[d].0.is_empty() {
                        continue;
                    }

                    let d_begin = supernode_ptr[d].zx();
                    let d_end = supernode_ptr[d + 1].zx();
                    let d_size = d_end - d_begin;
                    let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
                    let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];
                    // columns of the descendant's contribution that fall inside
                    // [s_begin, s_end)
                    let d_col_start = d_col_ind.partition_point(partition_fn(s_begin));

                    if d_col_start < d_col_ind.len() && d_col_ind[d_col_start].zx() < s_end {
                        let d_col_mid = d_col_start + d_col_ind[d_col_start..].partition_point(partition_fn(s_end));

                        let mut d_LU_cols = work_to_mat_mut(&mut left_contrib[d].0, d_row_idx.len(), d_col_ind.len())
                            .subcols_mut(d_col_start, d_col_mid - d_col_start);
                        let left_contrib = &mut left_contrib[d];
                        let d_active = &mut left_contrib.1[d_col_start..];
                        let d_active_count = &mut left_contrib.2;
                        let d_active_mat = &mut left_contrib.3;

                        for (d_j, j) in d_col_ind[d_col_start..d_col_mid].iter().enumerate() {
                            if d_active[d_j] > I(0) {
                                let mut taken_rows = 0usize;
                                let j = j.zx();
                                let s_j = j - s_begin;
                                for (d_i, i) in d_row_idx.iter().enumerate() {
                                    let i = i.zx();
                                    let pi = row_perm_inv[i].zx();
                                    if pi < s_begin {
                                        continue;
                                    }
                                    let s_i = row_global_to_local[i].zx();

                                    // consume the contribution entry and mark it inactive
                                    s_L[(s_i, s_j)] = s_L[(s_i, s_j)] - d_LU_cols[(d_i, d_j)];
                                    d_LU_cols[(d_i, d_j)] = zero::<T>();
                                    taken_rows += d_active_mat[(d_i, d_j + d_col_start)] as usize;
                                    d_active_mat[(d_i, d_j + d_col_start)] = 0;
                                }
                                assert!(d_active[d_j] >= I(taken_rows));
                                d_active[d_j] -= I(taken_rows);
                                if d_active[d_j] == I(0) {
                                    assert!(*d_active_count > 0);
                                    *d_active_count -= 1;
                                }
                            }
                        }
                        if *d_active_count == 0 {
                            // contribution fully consumed: release its storage
                            left_contrib.0.clear();
                            left_contrib.1 = alloc::vec::Vec::new();
                            left_contrib.2 = 0;
                            left_contrib.3 = MatU8::new();
                        }
                    }
                }
            });

            if s_L.nrows() < s_L.ncols() {
                // fewer candidate pivot rows than columns: structurally singular
                return Err(LuError::SymbolicSingular {
                    index: s_begin + s_L.nrows(),
                });
            }
            // dense partially-pivoted LU of the panel; pivots are returned as
            // transpositions relative to the panel's local row order
            let transpositions = &mut transpositions[s_begin..s_end];
            crate::linalg::lu::partial_pivoting::factor::lu_in_place_recursion(s_L.rb_mut(), 0, s_size, transpositions, par, params);

            // apply the panel's transpositions to the global row permutation and
            // to the panel's row index list
            for (idx, t) in transpositions.iter().enumerate() {
                let i_t = s_row_idxices[idx + t.zx()].zx();
                let kk = row_perm_inv[i_t].zx();
                row_perm.swap(s_begin + idx, row_perm_inv[i_t].zx());
                row_perm_inv.swap(row_perm[s_begin + idx].zx(), row_perm[kk].zx());
                s_row_idxices.swap(idx, idx + t.zx());
            }
            // replay the swaps on the scatter map (in reverse, to undo ordering effects)
            for (idx, t) in transpositions.iter().enumerate().rev() {
                row_global_to_local.swap(s_row_idxices[idx].zx(), s_row_idxices[idx + t.zx()].zx());
            }
            for (idx, i) in s_row_idxices.iter().enumerate() {
                assert!(row_global_to_local[i.zx()] == I(idx));
            }

            // gather the column structure of the U panel from A^T ...
            let s_col_indices = &mut indices[..n];
            let mut s_col_index_count = 0usize;
            for i in s_begin..s_end {
                let pi = row_perm[i].zx();
                for j in AT.row_idx_of_col(pi) {
                    let pj = col_perm_inv[j].zx();
                    if pj < s_end {
                        continue;
                    }
                    // stamp 2s + 2 marks columns seen during the U pass of supernode s
                    if marked[pj] < I(2 * s + 2) {
                        s_col_indices[s_col_index_count] = I(pj);
                        s_col_index_count += 1;
                        marked[pj] = I(2 * s + 2);
                    }
                }
            }

            // ... and from descendants whose rows were pivoted into this supernode
            for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
                let d = d.zx();

                let d_begin = supernode_ptr[d].zx();
                let d_end = supernode_ptr[d + 1].zx();
                let d_size = d_end - d_begin;

                let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
                let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];

                let contributes_to_u = d_row_idx
                    .iter()
                    .any(|&i| row_perm_inv[i.zx()].zx() >= s_begin && row_perm_inv[i.zx()].zx() < s_end);

                if contributes_to_u {
                    let d_col_start = d_col_ind.partition_point(partition_fn(s_end));
                    for j in &d_col_ind[d_col_start..] {
                        let j = j.zx();
                        if marked[j] < I(2 * s + 2) {
                            s_col_indices[s_col_index_count] = I(j);
                            s_col_index_count += 1;
                            marked[j] = I(2 * s + 2);
                        }
                    }
                }
            }

            // record the U panel's extent and make room for its indices/values
            lu.ut_col_ptr_for_row_idx[s + 1] = I_checked(lu.ut_col_ptr_for_row_idx[s].zx() + s_col_index_count)?;
            lu.ut_col_ptr_for_val[s + 1] = from_wide_checked(to_wide(lu.ut_col_ptr_for_val[s]) + (s_col_index_count as u128 * s_size as u128))?;
            resize_vec(&mut lu.ut_row_idx, lu.ut_col_ptr_for_row_idx[s + 1].zx(), false, false, I(0))?;
            resize_vec::<T>(&mut lu.ut_val, lu.ut_col_ptr_for_val[s + 1].zx(), false, false, zero::<T>())?;
            lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()]
                .copy_from_slice(&s_col_indices[..s_col_index_count]);
            lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()].sort_unstable();

            // build the global-column -> local-panel-column scatter map
            let s_col_indices = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[s].zx()..lu.ut_col_ptr_for_row_idx[s + 1].zx()];
            for (idx, j) in s_col_indices.iter().enumerate() {
                col_global_to_local[j.zx()] = I(idx);
            }

            let s_U = &mut lu.ut_val[lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()];
            let mut s_U = MatMut::from_column_major_slice_mut(s_U, s_col_index_count, s_size).transpose_mut();
            s_U.fill(zero());

            // scatter the entries of A^T into the U panel
            for i in s_begin..s_end {
                let pi = row_perm[i].zx();
                for (j, val) in iter::zip(AT.row_idx_of_col(pi), AT.val_of_col(pi)) {
                    let pj = col_perm_inv[j].zx();
                    if pj < s_end {
                        continue;
                    }
                    assert!(A_leftover > 0);
                    A_leftover -= 1;
                    let ix = i - s_begin;
                    let iy = col_global_to_local[pj].zx();
                    s_U[(ix, iy)] = s_U[(ix, iy)] + *val;
                }
            }

            // subtract the descendants' pending contributions from the U panel
            noinline(UPanel, || {
                for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
                    let d = d.zx();
                    if left_contrib[d].0.is_empty() {
                        continue;
                    }

                    let d_begin = supernode_ptr[d].zx();
                    let d_end = supernode_ptr[d + 1].zx();
                    let d_size = d_end - d_begin;

                    let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
                    let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];

                    let contributes_to_u = d_row_idx
                        .iter()
                        .any(|&i| row_perm_inv[i.zx()].zx() >= s_begin && row_perm_inv[i.zx()].zx() < s_end);

                    if contributes_to_u {
                        let d_col_start = d_col_ind.partition_point(partition_fn(s_end));
                        let d_LU = work_to_mat_mut(&mut left_contrib[d].0, d_row_idx.len(), d_col_ind.len());
                        let mut d_LU = d_LU.get_mut(.., d_col_start..);
                        let left_contrib = &mut left_contrib[d];
                        let d_active = &mut left_contrib.1[d_col_start..];
                        let d_active_count = &mut left_contrib.2;
                        let d_active_mat = &mut left_contrib.3;

                        for (d_j, j) in d_col_ind[d_col_start..].iter().enumerate() {
                            if d_active[d_j] > I(0) {
                                let mut taken_rows = 0usize;
                                let j = j.zx();
                                let s_j = col_global_to_local[j].zx();
                                for (d_i, i) in d_row_idx.iter().enumerate() {
                                    let i = i.zx();
                                    let pi = row_perm_inv[i].zx();

                                    // only rows pivoted into this supernode contribute to U
                                    if pi >= s_begin && pi < s_end {
                                        let s_i = row_global_to_local[i].zx();
                                        s_U[(s_i, s_j)] = s_U[(s_i, s_j)] - (d_LU[(d_i, d_j)]);
                                        d_LU[(d_i, d_j)] = zero::<T>();
                                        taken_rows += d_active_mat[(d_i, d_j + d_col_start)] as usize;
                                        d_active_mat[(d_i, d_j + d_col_start)] = 0;
                                    }
                                }
                                assert!(d_active[d_j] >= I(taken_rows));
                                d_active[d_j] -= I(taken_rows);
                                if d_active[d_j] == I(0) {
                                    assert!(*d_active_count > 0);
                                    *d_active_count -= 1;
                                }
                            }
                        }
                        if *d_active_count == 0 {
                            left_contrib.0.clear();
                            left_contrib.1 = alloc::vec::Vec::new();
                            left_contrib.2 = 0;
                            left_contrib.3 = MatU8::new();
                        }
                    }
                }
            });
            // finish the U panel: U <- L_diag^-1 U
            linalg::triangular_solve::solve_unit_lower_triangular_in_place(s_L.rb().subrows(0, s_size), s_U.rb_mut(), par);

            // compute this supernode's Schur-complement contribution block
            // (sub-diagonal L times U), to be consumed by its ancestors
            if s_row_idx_count > s_size && s_col_index_count > 0 {
                resize_vec::<T>(
                    &mut right_contrib[0].0,
                    from_wide_checked(to_wide(I(s_row_idx_count - s_size)) * to_wide(I(s_col_index_count)))?.zx(),
                    false,
                    false,
                    zero::<T>(),
                )?;
                right_contrib[0]
                    .1
                    .try_reserve_exact(s_col_index_count)
                    .ok()
                    .ok_or(FaerError::OutOfMemory)?;
                // every column starts with all sub-diagonal rows active
                right_contrib[0].1.resize(s_col_index_count, I(s_row_idx_count - s_size));
                right_contrib[0].2 = s_col_index_count;
                right_contrib[0].3 = MatU8::with_dims(s_row_idx_count - s_size, s_col_index_count)?;

                let mut s_LU = work_to_mat_mut(&mut right_contrib[0].0, s_row_idx_count - s_size, s_col_index_count);
                linalg::matmul::matmul(s_LU.rb_mut(), Accum::Replace, s_L.rb().get(s_size.., ..), s_U.rb(), one::<T>(), par);

                // fold leftover descendant contributions (rows pivoted below
                // this supernode) into the new contribution block
                noinline(Front, || {
                    for d in &supernode_postorder[s_postordered - desc_count..s_postordered] {
                        let d = d.zx();
                        if left_contrib[d].0.is_empty() {
                            continue;
                        }

                        let d_begin = supernode_ptr[d].zx();
                        let d_end = supernode_ptr[d + 1].zx();
                        let d_size = d_end - d_begin;

                        let d_row_idx = &left_row_idxices[lu.l_col_ptr_for_row_idx[d].zx()..lu.l_col_ptr_for_row_idx[d + 1].zx()][d_size..];
                        let d_col_ind = &lu.ut_row_idx[lu.ut_col_ptr_for_row_idx[d].zx()..lu.ut_col_ptr_for_row_idx[d + 1].zx()];

                        let contributes_to_front = d_row_idx.iter().any(|&i| row_perm_inv[i.zx()].zx() >= s_end);

                        if contributes_to_front {
                            let d_col_start = d_col_ind.partition_point(partition_fn(s_end));
                            let d_LU = work_to_mat_mut(&mut left_contrib[d].0, d_row_idx.len(), d_col_ind.len());
                            let mut d_LU = d_LU.get_mut(.., d_col_start..);
                            let left_contrib = &mut left_contrib[d];
                            let d_active = &mut left_contrib.1[d_col_start..];
                            let d_active_count = &mut left_contrib.2;
                            let d_active_mat = &mut left_contrib.3;

                            let mut d_active_row_count = 0usize;
                            let mut first_iter = true;

                            for (d_j, j) in d_col_ind[d_col_start..].iter().enumerate() {
                                if d_active[d_j] > I(0) {
                                    if first_iter {
                                        first_iter = false;
                                        // collect, once, the descendant rows that map into
                                        // this supernode's sub-diagonal block
                                        for (d_i, i) in d_row_idx.iter().enumerate() {
                                            let i = i.zx();
                                            let pi = row_perm_inv[i].zx();
                                            if (pi < s_end) || (row_global_to_local[i] == I(NONE)) {
                                                continue;
                                            }

                                            d_active_rows[d_active_row_count] = I(d_i);
                                            d_active_row_count += 1;
                                        }
                                    }

                                    let j = j.zx();
                                    let mut taken_rows = 0usize;

                                    let s_j = col_global_to_local[j];
                                    if s_j == I(NONE) {
                                        continue;
                                    }
                                    let s_j = s_j.zx();
                                    let mut dst = s_LU.rb_mut().col_mut(s_j);
                                    let mut src = d_LU.rb_mut().col_mut(d_j);
                                    assert!(dst.row_stride() == 1);
                                    assert!(src.row_stride() == 1);

                                    for d_i in &d_active_rows[..d_active_row_count] {
                                        let d_i = d_i.zx();
                                        let i = d_row_idx[d_i].zx();
                                        let d_active_mat = &mut d_active_mat[(d_i, d_j + d_col_start)];
                                        if *d_active_mat == 0 {
                                            continue;
                                        }
                                        let s_i = row_global_to_local[i].zx() - s_size;

                                        dst[s_i] = dst[s_i] + (src[d_i]);
                                        src[d_i] = zero::<T>();

                                        taken_rows += 1;
                                        *d_active_mat = 0;
                                    }

                                    d_active[d_j] -= I(taken_rows);
                                    if d_active[d_j] == I(0) {
                                        *d_active_count -= 1;
                                    }
                                }
                            }
                            if *d_active_count == 0 {
                                left_contrib.0.clear();
                                left_contrib.1 = alloc::vec::Vec::new();
                                left_contrib.2 = 0;
                                left_contrib.3 = MatU8::new();
                            }
                        }
                    }
                })
            }

            // reset the scatter maps for the next supernode
            for i in s_row_idxices.iter() {
                row_global_to_local[i.zx()] = I(NONE);
            }
            for j in s_col_indices.iter() {
                col_global_to_local[j.zx()] = I(NONE);
            }
        }
        // every entry of A must have been assembled exactly once
        assert!(A_leftover == 0);

        // convert L's row indices from original to pivoted row numbering
        for idx in &mut lu.l_row_idx[..lu.l_col_ptr_for_row_idx[n_supernodes].zx()] {
            *idx = row_perm_inv[idx.zx()];
        }

        lu.nrows = m;
        lu.ncols = n;
        lu.nsupernodes = n_supernodes;
        lu.supernode_ptr.clone_from(supernode_ptr);

        Ok(())
    }
1063}
1064
1065pub mod simplicial {
1071 use super::*;
1072 use crate::assert;
1073
	/// Numeric LU factorization computed column by column (simplicial algorithm).
	///
	/// Holds the unit-lower-triangular factor `L` and the upper-triangular factor
	/// `U`, each stored in compressed sparse column layout
	/// (`*_col_ptr`, `*_row_idx`, `*_val`).
	#[derive(Debug, Clone)]
	pub struct SimplicialLu<I, T> {
		// dimensions of the factored matrix
		nrows: usize,
		ncols: usize,

		// CSC storage of L; the unit diagonal entry is stored explicitly
		// (value one) as the first entry of each column
		l_col_ptr: alloc::vec::Vec<I>,
		l_row_idx: alloc::vec::Vec<I>,
		l_val: alloc::vec::Vec<T>,

		// CSC storage of U
		u_col_ptr: alloc::vec::Vec<I>,
		u_row_idx: alloc::vec::Vec<I>,
		u_val: alloc::vec::Vec<T>,
	}
1088
1089 impl<I: Index, T> Default for SimplicialLu<I, T> {
1090 fn default() -> Self {
1091 Self::new()
1092 }
1093 }
1094
	impl<I: Index, T> SimplicialLu<I, T> {
		/// Creates an empty LU factorization object, to be filled by
		/// [`factorize_simplicial_numeric_lu`].
		#[inline]
		pub fn new() -> Self {
			Self {
				nrows: 0,
				ncols: 0,

				l_col_ptr: alloc::vec::Vec::new(),
				u_col_ptr: alloc::vec::Vec::new(),

				l_row_idx: alloc::vec::Vec::new(),
				u_row_idx: alloc::vec::Vec::new(),

				l_val: alloc::vec::Vec::new(),
				u_val: alloc::vec::Vec::new(),
			}
		}

		/// Returns the number of rows of the factored matrix.
		#[inline]
		pub fn nrows(&self) -> usize {
			self.nrows
		}

		/// Returns the number of columns of the factored matrix.
		#[inline]
		pub fn ncols(&self) -> usize {
			self.ncols
		}

		/// Returns a sparse view of the `L` factor.
		///
		/// Row indices within a column are not guaranteed to be sorted.
		#[inline]
		pub fn l_factor_unsorted(&self) -> SparseColMatRef<'_, I, T> {
			SparseColMatRef::<'_, I, T>::new(
				// SAFETY: assumes the buffers form a valid CSC structure, as filled
				// by a successful run of the numeric factorization.
				unsafe { SymbolicSparseColMatRef::new_unchecked(self.nrows(), self.ncols(), &self.l_col_ptr, None, &self.l_row_idx) },
				&self.l_val,
			)
		}

		/// Returns a sparse view of the `U` factor.
		///
		/// Row indices within a column are not guaranteed to be sorted.
		#[inline]
		pub fn u_factor_unsorted(&self) -> SparseColMatRef<'_, I, T> {
			SparseColMatRef::<'_, I, T>::new(
				// SAFETY: assumes the buffers form a valid CSC structure, as filled
				// by a successful run of the numeric factorization.
				unsafe { SymbolicSparseColMatRef::new_unchecked(self.ncols(), self.ncols(), &self.u_col_ptr, None, &self.u_row_idx) },
				&self.u_val,
			)
		}

		/// Solves the linear system in place, overwriting `rhs` with the solution.
		///
		/// Applies, in order: the row permutation, a unit-lower triangular solve
		/// with `L`, an upper triangular solve with `U`, and the inverse column
		/// permutation. `work` must have the same dimensions as `rhs`.
		///
		/// # Panics
		/// Panics if the factorization is not square or if `rhs` has a mismatched
		/// number of rows.
		#[track_caller]
		pub fn solve_in_place_with_conj(
			&self,
			row_perm: PermRef<'_, I>,
			col_perm: PermRef<'_, I>,
			conj_lhs: Conj,
			rhs: MatMut<'_, T>,
			par: Par,
			work: MatMut<'_, T>,
		) where
			T: ComplexField,
		{
			assert!(self.nrows() == self.ncols());
			assert!(self.nrows() == rhs.nrows());
			let mut X = rhs;
			let mut temp = work;

			let l = self.l_factor_unsorted();
			let u = self.u_factor_unsorted();

			crate::perm::permute_rows(temp.rb_mut(), X.rb(), row_perm);
			linalg_sp::triangular_solve::solve_unit_lower_triangular_in_place(l, conj_lhs, temp.rb_mut(), par);
			linalg_sp::triangular_solve::solve_upper_triangular_in_place(u, conj_lhs, temp.rb_mut(), par);
			crate::perm::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse());
		}

		/// Solves the transposed system in place, overwriting `rhs` with the
		/// solution.
		///
		/// Mirror of [`Self::solve_in_place_with_conj`]: the column permutation and
		/// the transposed triangular solves are applied in reverse order, ending
		/// with the inverse row permutation. `work` must have the same dimensions
		/// as `rhs`.
		///
		/// # Panics
		/// Panics if the factorization is not square or if `rhs` has a mismatched
		/// number of rows.
		#[track_caller]
		pub fn solve_transpose_in_place_with_conj(
			&self,
			row_perm: PermRef<'_, I>,
			col_perm: PermRef<'_, I>,
			conj_lhs: Conj,
			rhs: MatMut<'_, T>,
			par: Par,
			work: MatMut<'_, T>,
		) where
			T: ComplexField,
		{
			assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows()));
			let mut X = rhs;
			let mut temp = work;

			let l = self.l_factor_unsorted();
			let u = self.u_factor_unsorted();

			crate::perm::permute_rows(temp.rb_mut(), X.rb(), col_perm);
			linalg_sp::triangular_solve::solve_upper_triangular_transpose_in_place(u, conj_lhs, temp.rb_mut(), par);
			linalg_sp::triangular_solve::solve_unit_lower_triangular_transpose_in_place(l, conj_lhs, temp.rb_mut(), par);
			crate::perm::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse());
		}
	}
1207
1208 fn depth_first_search<I: Index>(
1209 marked: &mut [I],
1210 mark: I,
1211
1212 xi: &mut [I],
1213 l: SymbolicSparseColMatRef<'_, I>,
1214 row_perm_inv: &[I],
1215 b: usize,
1216 stack: &mut [I],
1217 ) -> usize {
1218 let I = I::truncate;
1219
1220 let mut tail_start = xi.len();
1221 let mut head_len = 1usize;
1222 xi[0] = I(b);
1223
1224 let li = l.row_idx();
1225
1226 'dfs_loop: while head_len > 0 {
1227 let b = xi[head_len - 1].zx().zx();
1228 let pb = row_perm_inv[b].zx();
1229
1230 let range = if pb < l.ncols() { l.col_range(pb) } else { 0..0 };
1231 if marked[b] < mark {
1232 marked[b] = mark;
1233 stack[head_len - 1] = I(range.start);
1234 }
1235
1236 let start = stack[head_len - 1].zx();
1237 let end = range.end;
1238 for ptr in start..end {
1239 let i = li[ptr].zx();
1240 if marked[i] == mark {
1241 continue;
1242 }
1243 stack[head_len - 1] = I(ptr);
1244 xi[head_len] = I(i);
1245 head_len += 1;
1246 continue 'dfs_loop;
1247 }
1248
1249 head_len -= 1;
1250 tail_start -= 1;
1251 xi[tail_start] = I(b);
1252 }
1253
1254 tail_start
1255 }
1256
1257 fn reach<I: Index>(
1258 marked: &mut [I],
1259 mark: I,
1260
1261 xi: &mut [I],
1262 l: SymbolicSparseColMatRef<'_, I>,
1263 row_perm_inv: &[I],
1264 bi: &[I],
1265 stack: &mut [I],
1266 ) -> usize {
1267 let n = l.nrows();
1268 let mut tail_start = n;
1269
1270 for b in bi {
1271 let b = b.zx();
1272 if marked[b] < mark {
1273 tail_start = depth_first_search(marked, mark, &mut xi[..tail_start], l, row_perm_inv, b, stack);
1274 }
1275 }
1276
1277 tail_start
1278 }
1279
	/// Sparse triangular solve with the partially-built `L` factor against the
	/// sparse right-hand side `(bi, bx)`.
	///
	/// Only rows that already have a factored column participate (rows whose
	/// `row_perm_inv` entry is `>= l.ncols()` are skipped — hence "incomplete").
	/// The nonzero pattern of the result is computed first via [`reach`]; it is
	/// returned in `xi[tail_start..]` in topological order, while the numeric
	/// values are scattered into the dense accumulator `x`. Returns `tail_start`.
	#[math]
	fn l_incomplete_solve_sparse<I: Index, T: ComplexField>(
		marked: &mut [I],
		mark: I,

		xi: &mut [I],
		x: &mut [T],
		l: SparseColMatRef<'_, I, T>,
		row_perm_inv: &[I],
		bi: &[I],
		bx: &[T],
		stack: &mut [I],
	) -> usize {
		// nonzero pattern of the solution, stored in xi[tail_start..]
		let tail_start = reach(marked, mark, xi, l.symbolic(), row_perm_inv, bi, stack);

		let xi = &xi[tail_start..];
		// scatter the right-hand side into the dense accumulator
		for (i, b) in iter::zip(bi, bx) {
			let i = i.zx();
			x[i] = x[i] + *b;
		}

		// process the pattern in topological order, applying the elimination
		// update from each already-factored column
		for i in xi {
			let i = i.zx();
			let pi = row_perm_inv[i].zx();
			if pi >= l.ncols() {
				// row is not pivotal yet: no column of L to eliminate with
				continue;
			}

			let li = l.row_idx_of_col_raw(pi);
			let lx = l.val_of_col(pi);
			let len = li.len();

			let xi = copy(x[i]);
			// skip the explicitly stored unit diagonal (first entry of the column)
			// and update the remaining rows
			for (li, lx) in iter::zip(&li[1..], &lx[1..len]) {
				let li = li.zx();
				x[li] = x[li] - *lx * xi;
			}
		}

		tail_start
	}
1321
1322 pub fn factorize_simplicial_numeric_lu_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
1325 let idx = StackReq::new::<I>(nrows);
1326 let val = temp_mat_scratch::<T>(nrows, 1);
1327 let _ = ncols;
1328 StackReq::all_of(&[val, idx, idx, idx])
1329 }
1330
	/// Computes the numeric LU factorization of `A`, with partial pivoting, using
	/// a left-looking simplicial (column-at-a-time) algorithm.
	///
	/// - `row_perm`/`row_perm_inv` are filled with the row pivoting permutation.
	/// - `lu` receives the `L` and `U` factors in CSC form.
	/// - `col_perm` is the fill-reducing column permutation from the symbolic
	///   phase.
	///
	/// # Errors
	/// Returns [`LuError::SymbolicSingular`] if no candidate pivot row exists for
	/// some column.
	///
	/// # Panics
	/// Panics if `A` is not square, if the permutation slice lengths do not match
	/// the matrix dimensions, or if the selected pivot is numerically exactly
	/// zero.
	#[math]
	pub fn factorize_simplicial_numeric_lu<I: Index, T: ComplexField>(
		row_perm: &mut [I],
		row_perm_inv: &mut [I],
		lu: &mut SimplicialLu<I, T>,

		A: SparseColMatRef<'_, I, T>,
		col_perm: PermRef<'_, I>,
		stack: &mut MemStack,
	) -> Result<(), LuError> {
		let I = I::truncate;

		assert!(all(
			A.nrows() == row_perm.len(),
			A.nrows() == row_perm_inv.len(),
			A.ncols() == col_perm.len(),
			A.nrows() == A.ncols()
		));

		// invalidate the stored dimensions until the factorization succeeds
		lu.nrows = 0;
		lu.ncols = 0;

		let m = A.nrows();
		let n = A.ncols();

		resize_vec(&mut lu.l_col_ptr, n + 1, true, false, I(0))?;
		resize_vec(&mut lu.u_col_ptr, n + 1, true, false, I(0))?;

		// dense accumulator for the current column's values
		let (mut x, stack) = temp_mat_zeroed::<T, _, _>(m, 1, stack);
		let x = x.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();

		// DFS visit marks, pattern buffer, and DFS scan-position stack
		let (marked, stack) = unsafe { stack.make_raw::<I>(m) };
		let (xj, stack) = unsafe { stack.make_raw::<I>(m) };
		let (stack, _) = unsafe { stack.make_raw::<I>(m) };

		marked.fill(I(0));
		// `n` acts as a sentinel meaning "this row has not been chosen as a pivot
		// yet"
		row_perm_inv.fill(I(n));

		let mut l_pos = 0usize;
		let mut u_pos = 0usize;
		lu.l_col_ptr[0] = I(0);
		lu.u_col_ptr[0] = I(0);
		for j in 0..n {
			// view of the `j` columns of L computed so far
			let l = SparseColMatRef::<'_, I, T>::new(
				// SAFETY: the first `j + 1` column pointers and the corresponding
				// row indices/values were filled by previous iterations.
				unsafe { SymbolicSparseColMatRef::new_unchecked(m, j, &lu.l_col_ptr[..j + 1], None, &lu.l_row_idx) },
				&lu.l_val,
			);

			// sparse solve of L x = A[:, col_perm(j)]: yields the pattern in
			// xj[tail_start..] and the values in the dense accumulator x
			let pj = col_perm.arrays().0[j].zx();
			let tail_start = l_incomplete_solve_sparse(
				marked,
				I(j + 1),
				xj,
				x,
				l,
				row_perm_inv,
				A.row_idx_of_col_raw(pj),
				A.val_of_col(pj),
				stack,
			);
			let xj = &xj[tail_start..];

			// make room for this column's entries, plus one slot for the pivot
			resize_vec::<T>(&mut lu.l_val, l_pos + xj.len() + 1, false, false, zero::<T>())?;
			resize_vec(&mut lu.l_row_idx, l_pos + xj.len() + 1, false, false, I(0))?;
			resize_vec::<T>(&mut lu.u_val, u_pos + xj.len() + 1, false, false, zero::<T>())?;
			resize_vec(&mut lu.u_row_idx, u_pos + xj.len() + 1, false, false, I(0))?;

			let l_val = &mut *lu.l_val;
			let u_val = &mut *lu.u_val;

			// partial pivoting: among rows not yet pivotal, pick the one with the
			// largest absolute value; rows that are already pivotal belong to U
			let mut pivot_idx = n;
			let mut pivot_val = -one::<T::Real>();
			for i in xj {
				let i = i.zx();
				let xi = copy(x[i]);
				if row_perm_inv[i] == I(n) {
					let val = abs(xi);
					// `None` (a NaN comparison) is accepted like `Greater`, so a
					// candidate is still selected in the presence of NaNs
					if matches!(val.partial_cmp(&pivot_val), None | Some(core::cmp::Ordering::Greater)) {
						pivot_idx = i;
						pivot_val = val;
					}
				} else {
					lu.u_row_idx[u_pos] = row_perm_inv[i];
					u_val[u_pos] = xi;
					u_pos += 1;
				}
			}
			if pivot_idx == n {
				// no candidate row at all: structurally singular at column j
				return Err(LuError::SymbolicSingular { index: j });
			}

			let x_piv = copy(x[pivot_idx]);
			if x_piv == zero::<T>() {
				// NOTE(review): an exactly-zero pivot panics instead of returning an
				// error — confirm this is the intended failure mode
				panic!();
			}
			let x_piv_inv = recip(x_piv);

			row_perm_inv[pivot_idx] = I(j);

			// the pivot goes on the diagonal of U
			lu.u_row_idx[u_pos] = I(j);
			u_val[u_pos] = x_piv;
			u_pos += 1;
			lu.u_col_ptr[j + 1] = I(u_pos);

			// the unit diagonal of L is stored explicitly, first in the column
			lu.l_row_idx[l_pos] = I(pivot_idx);
			l_val[l_pos] = one::<T>();
			l_pos += 1;

			// scale the remaining sub-pivot rows and clear the dense accumulator
			for i in xj {
				let i = i.zx();
				let xi = copy(x[i]);
				if row_perm_inv[i] == I(n) {
					lu.l_row_idx[l_pos] = I(i);
					l_val[l_pos] = xi * x_piv_inv;
					l_pos += 1;
				}
				x[i] = zero::<T>();
			}
			lu.l_col_ptr[j + 1] = I(l_pos);
		}

		// translate the row indices of L from original to pivoted row numbering
		for i in &mut lu.l_row_idx[..l_pos] {
			*i = row_perm_inv[(*i).zx()];
		}

		// invert the inverse permutation to obtain the forward one
		for (idx, p) in row_perm_inv.iter().enumerate() {
			row_perm[p.zx()] = I(idx);
		}

		lu.nrows = m;
		lu.ncols = n;

		Ok(())
	}
1467}
1468
/// Tuning parameters for the symbolic phase of the sparse LU factorization.
#[derive(Copy, Clone, Debug, Default)]
pub struct LuSymbolicParams<'a> {
	/// Parameters for the COLAMD fill-reducing column ordering.
	pub colamd_params: colamd::Control,
	/// Threshold controlling the automatic choice between the supernodal and
	/// simplicial factorizations, based on the estimated flop-per-nonzero ratio.
	pub supernodal_flop_ratio_threshold: SupernodalThreshold,
	/// Parameters forwarded to the supernodal symbolic factorization.
	pub supernodal_params: SymbolicSupernodalParams<'a>,
}
1479
/// Implementation-specific payload of the symbolic LU factorization.
#[derive(Debug, Clone)]
pub enum SymbolicLuRaw<I> {
	/// Simplicial (column-at-a-time) variant: only the matrix dimensions are
	/// needed symbolically.
	Simplicial {
		/// Number of rows of the analyzed matrix.
		nrows: usize,
		/// Number of columns of the analyzed matrix.
		ncols: usize,
	},
	/// Supernodal (blocked) variant, with its full symbolic structure.
	Supernodal(supernodal::SymbolicSupernodalLu<I>),
}
1493
/// Symbolic structure of a sparse LU factorization: the chosen algorithm
/// variant plus the fill-reducing column permutation.
#[derive(Debug, Clone)]
pub struct SymbolicLu<I> {
	raw: SymbolicLuRaw<I>,
	// forward and inverse fill-reducing column permutation (COLAMD)
	col_perm_fwd: alloc::vec::Vec<I>,
	col_perm_inv: alloc::vec::Vec<I>,
	// nonzero count of the analyzed matrix, used for scratch sizing
	A_nnz: usize,
}
1502
/// Implementation-specific payload of the numeric LU factorization.
#[derive(Debug, Clone)]
enum NumericLuRaw<I, T> {
	/// No factorization computed yet.
	None,
	/// Supernodal numeric factorization.
	Supernodal(supernodal::SupernodalLu<I, T>),
	/// Simplicial numeric factorization.
	Simplicial(simplicial::SimplicialLu<I, T>),
}
1509
/// Numeric LU factorization container: the factor data for whichever variant
/// was selected, plus the row pivoting permutation.
#[derive(Debug, Clone)]
pub struct NumericLu<I, T> {
	raw: NumericLuRaw<I, T>,
	// forward and inverse row pivoting permutation
	row_perm_fwd: alloc::vec::Vec<I>,
	row_perm_inv: alloc::vec::Vec<I>,
}
1518
1519impl<I: Index, T> Default for NumericLu<I, T> {
1520 fn default() -> Self {
1521 Self::new()
1522 }
1523}
1524
1525impl<I: Index, T> NumericLu<I, T> {
1526 #[inline]
1528 pub fn new() -> Self {
1529 Self {
1530 raw: NumericLuRaw::None,
1531 row_perm_fwd: alloc::vec::Vec::new(),
1532 row_perm_inv: alloc::vec::Vec::new(),
1533 }
1534 }
1535}
1536
/// Sparse LU factorization view, pairing a symbolic structure with a
/// compatible numeric factorization.
#[derive(Debug)]
pub struct LuRef<'a, I: Index, T> {
	symbolic: &'a SymbolicLu<I>,
	numeric: &'a NumericLu<I, T>,
}
// `LuRef` only holds shared references, so it is trivially copyable; `Clone`
// is implemented manually since the derive would require `I: Clone, T: Clone`.
impl<I: Index, T> Copy for LuRef<'_, I, T> {}
impl<I: Index, T> Clone for LuRef<'_, I, T> {
	fn clone(&self) -> Self {
		*self
	}
}
1549
impl<'a, I: Index, T> LuRef<'a, I, T> {
	/// Creates an LU view from its symbolic and numeric parts.
	///
	/// # Safety
	/// The numeric factorization must have been produced from this exact
	/// symbolic structure; only variant compatibility (simplicial vs
	/// supernodal) is checked here.
	///
	/// # Panics
	/// Panics if the symbolic and numeric variants do not match.
	#[inline]
	pub unsafe fn new_unchecked(symbolic: &'a SymbolicLu<I>, numeric: &'a NumericLu<I, T>) -> Self {
		match (&symbolic.raw, &numeric.raw) {
			(SymbolicLuRaw::Simplicial { .. }, NumericLuRaw::Simplicial(_)) => {},
			(SymbolicLuRaw::Supernodal { .. }, NumericLuRaw::Supernodal(_)) => {},
			_ => panic!("incompatible symbolic and numeric variants"),
		}
		Self { symbolic, numeric }
	}

	/// Returns the symbolic structure of the factorization.
	#[inline]
	pub fn symbolic(self) -> &'a SymbolicLu<I> {
		self.symbolic
	}

	/// Returns the row pivoting permutation.
	#[inline]
	pub fn row_perm(self) -> PermRef<'a, I> {
		// SAFETY: the permutation arrays are assumed to have been filled as a
		// valid permutation of size `nrows` by the numeric factorization.
		unsafe { PermRef::new_unchecked(&self.numeric.row_perm_fwd, &self.numeric.row_perm_inv, self.symbolic.nrows()) }
	}

	/// Returns the fill-reducing column permutation.
	#[inline]
	pub fn col_perm(self) -> PermRef<'a, I> {
		self.symbolic.col_perm()
	}

	/// Solves the linear system in place, overwriting `rhs` with the solution.
	///
	/// Dispatches to the simplicial or supernodal solver according to the
	/// factorization variant. `stack` must be sized by
	/// [`SymbolicLu::solve_in_place_scratch`].
	#[track_caller]
	pub fn solve_in_place_with_conj(self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
	where
		T: ComplexField,
	{
		// SAFETY: the workspace is fully overwritten before being read by the
		// underlying solvers.
		let (mut work, _) = unsafe { temp_mat_uninit(rhs.nrows(), rhs.ncols(), stack) };
		let work = work.as_mat_mut();
		match (&self.symbolic.raw, &self.numeric.raw) {
			(SymbolicLuRaw::Simplicial { .. }, NumericLuRaw::Simplicial(numeric)) => {
				numeric.solve_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
			},
			(SymbolicLuRaw::Supernodal(_), NumericLuRaw::Supernodal(numeric)) => {
				numeric.solve_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
			},
			// variant compatibility was checked in `new_unchecked`
			_ => unreachable!(),
		}
	}

	/// Solves the transposed system in place, overwriting `rhs` with the
	/// solution.
	///
	/// `stack` must be sized by
	/// [`SymbolicLu::solve_transpose_in_place_scratch`].
	#[track_caller]
	pub fn solve_transpose_in_place_with_conj(self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
	where
		T: ComplexField,
	{
		// SAFETY: the workspace is fully overwritten before being read by the
		// underlying solvers.
		let (mut work, _) = unsafe { temp_mat_uninit(rhs.nrows(), rhs.ncols(), stack) };
		let work = work.as_mat_mut();
		match (&self.symbolic.raw, &self.numeric.raw) {
			(SymbolicLuRaw::Simplicial { .. }, NumericLuRaw::Simplicial(numeric)) => {
				numeric.solve_transpose_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
			},
			(SymbolicLuRaw::Supernodal(_), NumericLuRaw::Supernodal(numeric)) => {
				numeric.solve_transpose_in_place_with_conj(self.row_perm(), self.col_perm(), conj, rhs, par, work)
			},
			// variant compatibility was checked in `new_unchecked`
			_ => unreachable!(),
		}
	}
}
1632
impl<I: Index> SymbolicLu<I> {
	/// Returns the number of rows of the analyzed matrix.
	#[inline]
	pub fn nrows(&self) -> usize {
		match &self.raw {
			SymbolicLuRaw::Simplicial { nrows, .. } => *nrows,
			SymbolicLuRaw::Supernodal(this) => this.nrows,
		}
	}

	/// Returns the number of columns of the analyzed matrix.
	#[inline]
	pub fn ncols(&self) -> usize {
		match &self.raw {
			SymbolicLuRaw::Simplicial { ncols, .. } => *ncols,
			SymbolicLuRaw::Supernodal(this) => this.ncols,
		}
	}

	/// Returns the fill-reducing column permutation computed during the
	/// symbolic phase.
	#[inline]
	pub fn col_perm(&self) -> PermRef<'_, I> {
		// SAFETY: the arrays were produced as a valid permutation of size
		// `ncols` when `self` was constructed.
		unsafe { PermRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv, self.ncols()) }
	}

	/// Computes the size and alignment of the workspace required by
	/// [`Self::factorize_numeric_lu`].
	pub fn factorize_numeric_lu_scratch<T>(&self, par: Par, params: Spec<PartialPivLuParams, T>) -> StackReq
	where
		T: ComplexField,
	{
		match &self.raw {
			SymbolicLuRaw::Simplicial { nrows, ncols } => simplicial::factorize_simplicial_numeric_lu_scratch::<I, T>(*nrows, *ncols),
			SymbolicLuRaw::Supernodal(symbolic) => {
				let _ = par;
				let m = symbolic.nrows;

				// the supernodal path additionally materializes the transpose of A
				let A_nnz = self.A_nnz;
				let AT_scratch = StackReq::all_of(&[temp_mat_scratch::<T>(A_nnz, 1), StackReq::new::<I>(m + 1), StackReq::new::<I>(A_nnz)]);
				StackReq::and(AT_scratch, supernodal::factorize_supernodal_numeric_lu_scratch::<I, T>(symbolic, params))
			},
		}
	}

	/// Computes the size and alignment of the workspace required by
	/// [`LuRef::solve_in_place_with_conj`] for `rhs_ncols` right-hand sides.
	pub fn solve_in_place_scratch<T>(&self, rhs_ncols: usize, par: Par) -> StackReq
	where
		T: ComplexField,
	{
		let _ = par;
		temp_mat_scratch::<T>(self.nrows(), rhs_ncols)
	}

	/// Computes the size and alignment of the workspace required by
	/// [`LuRef::solve_transpose_in_place_with_conj`] for `rhs_ncols` right-hand
	/// sides.
	pub fn solve_transpose_in_place_scratch<T>(&self, rhs_ncols: usize, par: Par) -> StackReq
	where
		T: ComplexField,
	{
		let _ = par;
		temp_mat_scratch::<T>(self.nrows(), rhs_ncols)
	}

	/// Computes the numeric LU factorization of `A`, reusing (or replacing) the
	/// storage in `numeric`, and returns a view tying it to `self`.
	///
	/// # Errors
	/// Returns an error on allocation failure or if the factorization fails
	/// (e.g. a structurally singular matrix).
	///
	/// # Panics
	/// Panics if the dimensions of `A` do not match the symbolic analysis.
	#[track_caller]
	pub fn factorize_numeric_lu<'out, T: ComplexField>(
		&'out self,
		numeric: &'out mut NumericLu<I, T>,
		A: SparseColMatRef<'_, I, T>,
		par: Par,
		stack: &mut MemStack,
		params: Spec<PartialPivLuParams, T>,
	) -> Result<LuRef<'out, I, T>, LuError> {
		// make the numeric variant agree with the symbolic one, discarding a
		// previously stored factorization of the other variant
		if matches!(self.raw, SymbolicLuRaw::Simplicial { .. }) && !matches!(numeric.raw, NumericLuRaw::Simplicial(_)) {
			numeric.raw = NumericLuRaw::Simplicial(simplicial::SimplicialLu::new());
		}
		if matches!(self.raw, SymbolicLuRaw::Supernodal(_)) && !matches!(numeric.raw, NumericLuRaw::Supernodal(_)) {
			numeric.raw = NumericLuRaw::Supernodal(supernodal::SupernodalLu::new());
		}

		let nrows = self.nrows();

		// grow the permutation buffers with fallible allocation
		numeric
			.row_perm_fwd
			.try_reserve_exact(nrows.saturating_sub(numeric.row_perm_fwd.len()))
			.ok()
			.ok_or(FaerError::OutOfMemory)?;
		numeric
			.row_perm_inv
			.try_reserve_exact(nrows.saturating_sub(numeric.row_perm_inv.len()))
			.ok()
			.ok_or(FaerError::OutOfMemory)?;
		numeric.row_perm_fwd.resize(nrows, I::truncate(0));
		numeric.row_perm_inv.resize(nrows, I::truncate(0));

		match (&self.raw, &mut numeric.raw) {
			(SymbolicLuRaw::Simplicial { nrows, ncols }, NumericLuRaw::Simplicial(lu)) => {
				assert!(all(A.nrows() == *nrows, A.ncols() == *ncols));

				simplicial::factorize_simplicial_numeric_lu(&mut numeric.row_perm_fwd, &mut numeric.row_perm_inv, lu, A, self.col_perm(), stack)?;
			},
			(SymbolicLuRaw::Supernodal(symbolic), NumericLuRaw::Supernodal(lu)) => {
				// the supernodal kernel also needs the transpose of A, built in
				// workspace memory
				let m = symbolic.nrows;
				let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(m + 1) };
				let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(self.A_nnz) };
				let (mut new_values, stack) = unsafe { temp_mat_uninit::<T, _, _>(self.A_nnz, 1, stack) };
				let new_values = new_values.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
				let AT = utils::transpose(new_values, new_col_ptr, new_row_idx, A, stack).into_const();

				supernodal::factorize_supernodal_numeric_lu(
					&mut numeric.row_perm_fwd,
					&mut numeric.row_perm_inv,
					lu,
					A,
					AT,
					self.col_perm(),
					symbolic,
					par,
					stack,
					params,
				)?;
			},
			// the variants were made compatible at the top of this function
			_ => unreachable!(),
		}

		// SAFETY: `numeric` was just filled from this exact symbolic structure.
		Ok(unsafe { LuRef::new_unchecked(self, numeric) })
	}
}
1761
/// Computes the symbolic LU factorization of `A`: a COLAMD fill-reducing
/// column ordering followed by either a supernodal or a simplicial symbolic
/// analysis, chosen automatically from an estimated flop-per-nonzero ratio
/// (unless forced by `params.supernodal_flop_ratio_threshold`).
///
/// # Errors
/// Returns an error on allocation failure.
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn factorize_symbolic_lu<I: Index>(A: SymbolicSparseColMatRef<'_, I>, params: LuSymbolicParams<'_>) -> Result<SymbolicLu<I>, FaerError> {
	assert!(A.nrows() == A.ncols());
	let m = A.nrows();
	let n = A.ncols();
	let A_nnz = A.compute_nnz();

	// branded dimensions for the checked-index ("ghost") APIs below
	with_dim!(M, m);
	with_dim!(N, n);

	let A = A.as_shape(M, N);

	// workspace requirement covering both the COLAMD ordering and the
	// etree/postorder/column-count analysis that follows it
	let req = {
		let n_scratch = StackReq::new::<I>(n);
		let m_scratch = StackReq::new::<I>(m);
		let AT_scratch = StackReq::and(
			StackReq::new::<I>(m + 1),
			StackReq::new::<I>(A_nnz),
		);

		StackReq::or(
			linalg_sp::colamd::order_scratch::<I>(m, n, A_nnz),
			StackReq::all_of(&[
				n_scratch,
				n_scratch,
				n_scratch,
				n_scratch,
				AT_scratch,
				StackReq::any_of(&[
					StackReq::and(n_scratch, m_scratch),
					StackReq::all_of(&[n_scratch; 3]),
					StackReq::all_of(&[n_scratch, n_scratch, n_scratch, n_scratch, n_scratch, m_scratch]),
					supernodal::factorize_supernodal_symbolic_lu_scratch::<I>(m, n),
				]),
			]),
		)
	};

	let mut mem = dyn_stack::MemBuffer::try_new(req).ok().ok_or(FaerError::OutOfMemory)?;
	let stack = MemStack::new(&mut mem);

	let mut col_perm_fwd = try_zeroed::<I>(n)?;
	let mut col_perm_inv = try_zeroed::<I>(n)?;
	let mut min_row = try_zeroed::<I>(m)?;

	// fill-reducing column ordering
	linalg_sp::colamd::order(&mut col_perm_fwd, &mut col_perm_inv, A.as_dyn(), params.colamd_params, stack)?;

	let col_perm = PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n).as_shape(N);

	// symbolic transpose of A, needed by the column-count computation
	let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(m + 1) };
	let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(A_nnz) };
	let AT = utils::adjoint(
		Symbolic::materialize(new_row_idx.len()),
		new_col_ptr,
		new_row_idx,
		SparseColMatRef::new(A, Symbolic::materialize(A.row_idx().len())),
		stack,
	)
	.symbolic();

	let (etree, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
	let (post, stack) = unsafe { stack.make_raw::<I>(n) };
	let (col_counts, stack) = unsafe { stack.make_raw::<I>(n) };
	let (h_col_counts, stack) = unsafe { stack.make_raw::<I>(n) };

	// column elimination tree, its postorder, and column counts of A^T A
	linalg_sp::qr::ghost_col_etree(A, Some(col_perm), Array::from_mut(etree, N), stack);
	let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(etree, N), N);
	linalg_sp::cholesky::ghost_postorder(Array::from_mut(post, N), etree_, stack);

	linalg_sp::qr::ghost_column_counts_aat(
		Array::from_mut(col_counts, N),
		Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M),
		AT,
		Some(col_perm),
		etree_,
		Array::from_ref(Idx::from_slice_ref_checked(post, N), N),
		stack,
	);
	let min_col = min_row;

	// estimate the flop-per-nonzero ratio to decide between the supernodal and
	// simplicial factorizations, unless the caller forced a choice
	let mut threshold = params.supernodal_flop_ratio_threshold;
	if threshold != SupernodalThreshold::FORCE_SIMPLICIAL && threshold != SupernodalThreshold::FORCE_SUPERNODAL {
		// accumulate per-column "household" counts up the elimination tree
		h_col_counts.fill(I::truncate(0));
		for i in 0..m {
			let min_col = min_col[i];
			if min_col.to_signed() < I::Signed::truncate(0) {
				continue;
			}
			h_col_counts[min_col.zx()] += I::truncate(1);
		}
		for j in 0..n {
			let parent = etree[j];
			if parent < I::Signed::truncate(0) {
				continue;
			}
			h_col_counts[parent.zx()] += h_col_counts[j] - I::truncate(1);
		}

		let mut nnz = 0.0f64;
		let mut flops = 0.0f64;
		for j in 0..n {
			let hj = h_col_counts[j].zx() as f64;
			let rj = col_counts[j].zx() as f64;
			flops += hj + hj * rj;
			nnz += hj + rj;
		}

		if flops / nnz > threshold.0 * linalg_sp::LU_SUPERNODAL_RATIO_FACTOR {
			threshold = SupernodalThreshold::FORCE_SUPERNODAL;
		} else {
			threshold = SupernodalThreshold::FORCE_SIMPLICIAL;
		}
	}

	if threshold == SupernodalThreshold::FORCE_SUPERNODAL {
		let symbolic = supernodal::factorize_supernodal_symbolic_lu::<I>(
			A.as_dyn(),
			Some(col_perm.as_shape(n)),
			&min_col,
			EliminationTreeRef::<'_, I> { inner: etree },
			col_counts,
			stack,
			params.supernodal_params,
		)?;
		Ok(SymbolicLu {
			raw: SymbolicLuRaw::Supernodal(symbolic),
			col_perm_fwd,
			col_perm_inv,
			A_nnz,
		})
	} else {
		// the simplicial variant needs no further symbolic work
		Ok(SymbolicLu {
			raw: SymbolicLuRaw::Simplicial { nrows: m, ncols: n },
			col_perm_fwd,
			col_perm_inv,
			A_nnz,
		})
	}
}
1905
1906#[cfg(test)]
1907mod tests {
1908 use super::*;
1909 use crate::assert;
1910 use crate::stats::prelude::*;
1911 use dyn_stack::MemBuffer;
1912 use linalg_sp::cholesky::tests::load_mtx;
1913 use matrix_market_rs::MtxData;
1914 use std::path::PathBuf;
1915
	/// End-to-end test of the supernodal (multifrontal) LU: symbolic analysis,
	/// numeric factorization, and all four solve variants checked against a
	/// dense reference.
	#[test]
	fn test_numeric_lu_multifrontal() {
		type T = c64;

		let (m, n, col_ptr, row_idx, val) =
			load_mtx::<usize>(MtxData::from_file(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_lu/YAO.mtx")).unwrap());

		// replace the loaded values with reproducible random complex entries
		let mut rng = StdRng::seed_from_u64(0);
		let mut gen = || T::new(rng.gen::<f64>(), rng.gen::<f64>());

		let val = val.iter().map(|_| gen()).collect::<alloc::vec::Vec<_>>();
		let A = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);
		// identity column permutation; the test matrix is assumed square
		let mut row_perm = vec![0usize; n];
		let mut row_perm_inv = vec![0usize; n];
		let mut col_perm = vec![0usize; n];
		let mut col_perm_inv = vec![0usize; n];
		for i in 0..n {
			col_perm[i] = i;
			col_perm_inv[i] = i;
		}
		let col_perm = PermRef::<'_, usize>::new_checked(&col_perm, &col_perm_inv, n);

		let mut etree = vec![0usize; n];
		let mut min_col = vec![0usize; m];
		let mut col_counts = vec![0usize; n];

		// materialize the transpose, required by the symbolic and numeric phases
		let nnz = A.compute_nnz();
		let mut new_col_ptr = vec![0usize; m + 1];
		let mut new_row_idx = vec![0usize; nnz];
		let mut new_values = vec![zero::<T>(); nnz];
		let AT = utils::transpose(
			&mut *new_values,
			&mut new_col_ptr,
			&mut new_row_idx,
			A,
			MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(m))),
		)
		.into_const();

		// column elimination tree, postorder, and column counts for the symbolic
		// analysis
		let etree = {
			let mut post = vec![0usize; n];

			let etree = linalg_sp::qr::col_etree(
				A.symbolic(),
				Some(col_perm),
				&mut etree,
				MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(m + n))),
			);
			linalg_sp::qr::postorder(&mut post, etree, MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(3 * n))));
			linalg_sp::qr::column_counts_ata(
				&mut col_counts,
				&mut min_col,
				AT.symbolic(),
				Some(col_perm),
				etree,
				&post,
				MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(5 * n + m))),
			);
			etree
		};

		let symbolic = linalg_sp::lu::supernodal::factorize_supernodal_symbolic_lu::<usize>(
			A.symbolic(),
			Some(col_perm),
			&min_col,
			etree,
			&col_counts,
			MemStack::new(&mut MemBuffer::new(super::supernodal::factorize_supernodal_symbolic_lu_scratch::<usize>(
				m, n,
			))),
			linalg_sp::SymbolicSupernodalParams {
				relax: Some(&[(4, 1.0), (16, 0.8), (48, 0.1), (usize::MAX, 0.05)]),
			},
		)
		.unwrap();

		let mut lu = supernodal::SupernodalLu::<usize, T>::new();
		supernodal::factorize_supernodal_numeric_lu(
			&mut row_perm,
			&mut row_perm_inv,
			&mut lu,
			A,
			AT,
			col_perm,
			&symbolic,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(supernodal::factorize_supernodal_numeric_lu_scratch::<usize, T>(
				&symbolic,
				Default::default(),
			))),
			Default::default(),
		)
		.unwrap();

		// check all four solve variants against a dense reference solution
		let k = 2;
		let rhs = Mat::from_fn(n, k, |_, _| gen());

		let mut work = rhs.clone();
		let A_dense = A.to_dense();
		let row_perm = PermRef::<'_, _>::new_checked(&row_perm, &row_perm_inv, m);

		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((&A_dense * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.conjugate() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.transpose() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.adjoint() * &x - &rhs).norm_max() < 1e-10);
		}
	}
2042
	/// Test of the simplicial numeric LU: factorization and all four solve
	/// variants checked against a dense reference.
	#[test]
	fn test_numeric_lu_simplicial() {
		type T = c64;

		let (m, n, col_ptr, row_idx, val) =
			load_mtx::<usize>(MtxData::from_file(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_lu/YAO.mtx")).unwrap());

		// replace the loaded values with reproducible random complex entries
		let mut rng = StdRng::seed_from_u64(0);
		let mut gen = || T::new(rng.gen::<f64>(), rng.gen::<f64>());

		let val = val.iter().map(|_| gen()).collect::<alloc::vec::Vec<_>>();
		let A = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);
		// identity column permutation; the test matrix is assumed square
		let mut row_perm = vec![0usize; n];
		let mut row_perm_inv = vec![0usize; n];
		let mut col_perm = vec![0usize; n];
		let mut col_perm_inv = vec![0usize; n];
		for i in 0..n {
			col_perm[i] = i;
			col_perm_inv[i] = i;
		}
		let col_perm = PermRef::<'_, usize>::new_checked(&col_perm, &col_perm_inv, n);

		let mut lu = simplicial::SimplicialLu::<usize, T>::new();
		simplicial::factorize_simplicial_numeric_lu(
			&mut row_perm,
			&mut row_perm_inv,
			&mut lu,
			A,
			col_perm,
			MemStack::new(&mut MemBuffer::new(simplicial::factorize_simplicial_numeric_lu_scratch::<usize, T>(m, n))),
		)
		.unwrap();

		// check all four solve variants against a dense reference solution
		let k = 1;
		let rhs = Mat::from_fn(n, k, |_, _| gen());

		let mut work = rhs.clone();
		let A_dense = A.to_dense();
		let row_perm = PermRef::<'_, _>::new_checked(&row_perm, &row_perm_inv, m);

		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((&A_dense * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.conjugate() * &x - &rhs).norm_max() < 1e-10);
		}

		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::No, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.transpose() * &x - &rhs).norm_max() < 1e-10);
		}
		{
			let mut x = rhs.clone();

			lu.solve_transpose_in_place_with_conj(row_perm, col_perm, Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());
			assert!((A_dense.adjoint() * &x - &rhs).norm_max() < 1e-10);
		}
	}
2109
	/// Test of the high-level solver API, exercising the automatic
	/// supernodal/simplicial selection as well as both forced variants.
	#[test]
	fn test_solver_lu_simplicial() {
		type T = c64;

		let (m, n, col_ptr, row_idx, val) =
			load_mtx::<usize>(MtxData::from_file(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_lu/YAO.mtx")).unwrap());

		// replace the loaded values with reproducible random complex entries
		let mut rng = StdRng::seed_from_u64(0);
		let mut gen = || T::new(rng.gen::<f64>(), rng.gen::<f64>());

		let val = val.iter().map(|_| gen()).collect::<alloc::vec::Vec<_>>();
		let A = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

		let rhs = Mat::<T>::from_fn(m, 6, |_, _| gen());

		// run the same checks for each threshold mode
		for supernodal_flop_ratio_threshold in [
			SupernodalThreshold::AUTO,
			SupernodalThreshold::FORCE_SUPERNODAL,
			SupernodalThreshold::FORCE_SIMPLICIAL,
		] {
			let symbolic = factorize_symbolic_lu(
				A.symbolic(),
				LuSymbolicParams {
					supernodal_flop_ratio_threshold,
					..Default::default()
				},
			)
			.unwrap();
			let mut numeric = NumericLu::<usize, T>::new();
			let lu = symbolic
				.factorize_numeric_lu(
					&mut numeric,
					A,
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(
						symbolic.factorize_numeric_lu_scratch::<T>(Par::Seq, Default::default()),
					)),
					Default::default(),
				)
				.unwrap();

			// check all four solve variants against dense reference products
			{
				let mut x = rhs.clone();
				lu.solve_in_place_with_conj(
					crate::Conj::No,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}
			{
				let mut x = rhs.clone();
				lu.solve_in_place_with_conj(
					crate::Conj::Yes,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A.conjugate() * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}

			{
				let mut x = rhs.clone();
				lu.solve_transpose_in_place_with_conj(
					crate::Conj::No,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_transpose_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A.transpose() * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}
			{
				let mut x = rhs.clone();
				lu.solve_transpose_in_place_with_conj(
					crate::Conj::Yes,
					x.as_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(symbolic.solve_transpose_in_place_scratch::<T>(rhs.ncols(), Par::Seq))),
				);

				let linsolve_diff = A.adjoint() * &x - &rhs;
				assert!(linsolve_diff.norm_max() <= 1e-10);
			}
		}
	}
2202}