faer_sparse/
cholesky.rs

1//! Computes the Cholesky decomposition (either LLT, LDLT, or Bunch-Kaufman) of a given sparse
2//! matrix. See [`faer_cholesky`] for more info.
3//!
//! The entry points in this module are [`SymbolicCholesky`] and [`factorize_symbolic_cholesky`].
5//!
6//! # Note
7//! The functions in this module accept unsorted input, producing a sorted decomposition factor
8//! (simplicial).
9
10// implementation inspired by https://gitlab.com/hodge_star/catamari
11
12use crate::{
13    amd::{self, Control},
14    ghost::{self, Array, Idx, MaybeIdx},
15    ghost_permute_hermitian_unsorted, ghost_permute_hermitian_unsorted_symbolic, make_raw_req, mem,
16    mem::NONE,
17    nomem, triangular_solve, try_collect, try_zeroed, windows2, FaerError, Index, PermutationRef,
18    Side, SliceGroup, SliceGroupMut, SparseColMatRef, SupernodalThreshold, SymbolicSparseColMatRef,
19    SymbolicSupernodalParams,
20};
21use core::{cell::Cell, iter::zip};
22use dyn_stack::{PodStack, SizeOverflow, StackReq};
23pub use faer_cholesky::{
24    bunch_kaufman::compute::BunchKaufmanRegularization,
25    ldlt_diagonal::compute::LdltRegularization,
26    llt::{compute::LltRegularization, CholeskyError},
27};
28use faer_core::{
29    assert, permutation::SignedIndex, temp_mat_req, temp_mat_uninit, unzipped, zipped,
30    ComplexField, Conj, Entity, MatMut, MatRef, Parallelism,
31};
32use faer_entity::{GroupFor, Symbolic};
33use reborrow::*;
34
35#[derive(Copy, Clone)]
36#[allow(dead_code)]
37enum Ordering<'a, I> {
38    Identity,
39    Custom(&'a [I]),
40    Algorithm(
41        &'a dyn Fn(
42            &mut [I],                       // perm
43            &mut [I],                       // perm_inv
44            SymbolicSparseColMatRef<'_, I>, // A
45            PodStack<'_>,
46        ) -> Result<(), FaerError>,
47    ),
48}
49
50pub mod simplicial {
51    use super::*;
52    use faer_core::assert;
53
54    /// Computes the elimination tree and column counts of the Cholesky factorization of the matrix
55    /// `A`.
56    pub fn prefactorize_symbolic_cholesky<'out, I: Index>(
57        etree: &'out mut [I::Signed],
58        col_counts: &mut [I],
59        A: SymbolicSparseColMatRef<'_, I>,
60        stack: PodStack<'_>,
61    ) -> EliminationTreeRef<'out, I> {
62        let n = A.nrows();
63        assert!(A.nrows() == A.ncols());
64        assert!(etree.len() == n);
65        assert!(col_counts.len() == n);
66
67        ghost::with_size(n, |N| {
68            ghost_prefactorize_symbolic_cholesky(
69                Array::from_mut(etree, N),
70                Array::from_mut(col_counts, N),
71                ghost::SymbolicSparseColMatRef::new(A, N, N),
72                stack,
73            );
74        });
75
76        simplicial::EliminationTreeRef { inner: etree }
77    }
78
79    fn ereach<'n, 'a, I: Index>(
80        stack: &'a mut Array<'n, I>,
81        A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>,
82        etree: &Array<'n, MaybeIdx<'n, I>>,
83        k: Idx<'n, usize>,
84        visited: &mut Array<'n, I::Signed>,
85    ) -> &'a [Idx<'n, I>] {
86        let N = A.ncols();
87
88        // invariant: stack[top..] elements are less than or equal to k
89        let mut top = *N;
90        let k_: I = *k.truncate();
91        visited[k] = k_.to_signed();
92        for mut i in A.row_indices_of_col(k) {
93            // (1): after this, we know i < k
94            if i >= k {
95                continue;
96            }
97            // invariant: stack[..len] elements are less than or equal to k
98            let mut len = 0usize;
99            loop {
100                if visited[i] == k_.to_signed() {
101                    break;
102                }
103
104                // inserted element is i < k, see (1)
105                let pushed: Idx<'n, I> = i.truncate::<I>();
106                stack[N.check(len)] = *pushed;
107                // len is incremented, maintaining the invariant
108                len += 1;
109
110                visited[i] = k_.to_signed();
111                i = N.check(etree[i].into_inner().zx());
112            }
113
114            // because stack[..len] elements are less than or equal to k
115            // stack[top - len..] elements are now less than or equal to k
116            stack.as_mut().copy_within(..len, top - len);
117            // top is decremented by len, maintaining the invariant
118            top -= len;
119        }
120
121        let stack = &stack.as_ref()[top..];
122
123        // SAFETY: stack[top..] elements are < k < N
124        unsafe { Idx::from_slice_ref_unchecked(stack) }
125    }
126
127    pub fn factorize_simplicial_symbolic_req<I: Index>(n: usize) -> Result<StackReq, SizeOverflow> {
128        let n_req = StackReq::try_new::<I>(n)?;
129        StackReq::try_all_of([n_req, n_req, n_req])
130    }
131
132    pub fn factorize_simplicial_symbolic_cholesky<I: Index>(
133        A: SymbolicSparseColMatRef<'_, I>,
134        etree: EliminationTreeRef<'_, I>,
135        col_counts: &[I],
136        stack: PodStack<'_>,
137    ) -> Result<SymbolicSimplicialCholesky<I>, FaerError> {
138        let n = A.nrows();
139        assert!(A.nrows() == A.ncols());
140        assert!(etree.inner.len() == n);
141        assert!(col_counts.len() == n);
142
143        ghost::with_size(n, |N| {
144            ghost_factorize_simplicial_symbolic_cholesky(
145                ghost::SymbolicSparseColMatRef::new(A, N, N),
146                etree.ghost_inner(N),
147                Array::from_ref(col_counts, N),
148                stack,
149            )
150        })
151    }
152
153    pub(crate) fn ghost_factorize_simplicial_symbolic_cholesky<'n, I: Index>(
154        A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>,
155        etree: &Array<'n, MaybeIdx<'n, I>>,
156        col_counts: &Array<'n, I>,
157        stack: PodStack<'_>,
158    ) -> Result<SymbolicSimplicialCholesky<I>, FaerError> {
159        let N = A.ncols();
160        let n = *N;
161
162        let mut L_col_ptrs = try_zeroed::<I>(n + 1)?;
163        for (&count, [p, p_next]) in zip(
164            col_counts.as_ref(),
165            windows2(Cell::as_slice_of_cells(Cell::from_mut(&mut L_col_ptrs))),
166        ) {
167            p_next.set(p.get() + count);
168        }
169        let l_nnz = L_col_ptrs[n].zx();
170        let mut L_row_ind = try_zeroed::<I>(l_nnz)?;
171
172        ghost::with_size(
173            l_nnz,
174            #[inline(always)]
175            move |L_NNZ| {
176                let (current_row_index, stack) = stack.make_raw::<I>(n);
177                let (ereach_stack, stack) = stack.make_raw::<I>(n);
178                let (visited, _) = stack.make_raw::<I::Signed>(n);
179
180                let ereach_stack = Array::from_mut(ereach_stack, N);
181                let visited = Array::from_mut(visited, N);
182
183                mem::fill_none(visited.as_mut());
184                let L_row_indices = Array::from_mut(&mut L_row_ind, L_NNZ);
185                let L_col_ptrs_start =
186                    Array::from_ref(Idx::from_slice_ref_checked(&L_col_ptrs[..n], L_NNZ), N);
187                let current_row_index = Array::from_mut(
188                    ghost::copy_slice(current_row_index, L_col_ptrs_start.as_ref()),
189                    N,
190                );
191
192                for k in N.indices() {
193                    let reach = ereach(ereach_stack, A, etree, k, visited);
194                    for &j in reach {
195                        let j = j.zx();
196                        let cj = &mut current_row_index[j];
197                        let row_idx = L_NNZ.check(*cj.zx() + 1);
198                        *cj = row_idx.truncate();
199                        L_row_indices[row_idx] = *k.truncate();
200                    }
201                    let k_start = L_col_ptrs_start[k].zx();
202                    L_row_indices[k_start] = *k.truncate();
203                }
204
205                let etree = try_collect(
206                    bytemuck::cast_slice::<I::Signed, I>(MaybeIdx::as_slice_ref(etree.as_ref()))
207                        .iter()
208                        .copied(),
209                )?;
210
211                let _ = SymbolicSparseColMatRef::new_unsorted_checked(
212                    n,
213                    n,
214                    &L_col_ptrs,
215                    None,
216                    &L_row_ind,
217                );
218
219                Ok(SymbolicSimplicialCholesky {
220                    dimension: n,
221                    col_ptrs: L_col_ptrs,
222                    row_indices: L_row_ind,
223                    etree,
224                })
225            },
226        )
227    }
228
229    #[derive(Copy, Clone, Debug)]
230    enum FactorizationKind {
231        Llt,
232        Ldlt,
233    }
234
235    fn factorize_simplicial_numeric_with_row_indices<I: Index, E: ComplexField>(
236        L_values: GroupFor<E, &mut [E::Unit]>,
237        L_row_indices: &mut [I],
238        L_col_ptrs: &[I],
239        kind: FactorizationKind,
240
241        etree: EliminationTreeRef<'_, I>,
242        A: SparseColMatRef<'_, I, E>,
243        regularization: LdltRegularization<'_, E>,
244
245        stack: PodStack<'_>,
246    ) -> Result<usize, CholeskyError> {
247        let n = A.ncols();
248        {
249            let L_values = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&L_values)));
250            assert!(L_values.rb().len() == L_row_indices.len());
251        }
252        assert!(L_col_ptrs.len() == n + 1);
253        assert!(etree.into_inner().len() == n);
254        let l_nnz = L_col_ptrs[n].zx();
255
256        ghost::with_size(
257            n,
258            #[inline(always)]
259            |N| {
260                let etree = etree.ghost_inner(N);
261                let A = ghost::SparseColMatRef::new(A, N, N);
262
263                ghost::with_size(
264                    l_nnz,
265                    #[inline(always)]
266                    move |L_NNZ| {
267                        let eps = regularization.dynamic_regularization_epsilon.faer_abs();
268                        let delta = regularization.dynamic_regularization_delta.faer_abs();
269                        let has_delta = delta > E::Real::faer_zero();
270                        let mut dynamic_regularization_count = 0usize;
271
272                        let (x, stack) = crate::make_raw::<E>(n, stack);
273                        let (current_row_index, stack) = stack.make_raw::<I>(n);
274                        let (ereach_stack, stack) = stack.make_raw::<I>(n);
275                        let (visited, _) = stack.make_raw::<I::Signed>(n);
276
277                        let ereach_stack = Array::from_mut(ereach_stack, N);
278                        let visited = Array::from_mut(visited, N);
279                        let mut x = ghost::ArrayGroupMut::<'_, '_, E>::new(x.into_inner(), N);
280
281                        SliceGroupMut::<'_, E>::new(x.rb_mut().into_slice()).fill_zero();
282                        mem::fill_none(visited.as_mut());
283
284                        let mut L_values = ghost::ArrayGroupMut::<'_, '_, E>::new(L_values, L_NNZ);
285                        let L_row_indices = Array::from_mut(L_row_indices, L_NNZ);
286
287                        let L_col_ptrs_start = Array::from_ref(
288                            Idx::from_slice_ref_checked(&L_col_ptrs[..n], L_NNZ),
289                            N,
290                        );
291
292                        let current_row_index = Array::from_mut(
293                            ghost::copy_slice(current_row_index, L_col_ptrs_start.as_ref()),
294                            N,
295                        );
296
297                        for k in N.indices() {
298                            let reach = ereach(ereach_stack, *A, etree, k, visited);
299
300                            for (i, aik) in zip(
301                                A.row_indices_of_col(k),
302                                SliceGroup::<'_, E>::new(A.values_of_col(k)).into_ref_iter(),
303                            ) {
304                                x.write(i, x.read(i).faer_add(aik.read().faer_conj()));
305                            }
306
307                            let mut d = x.read(k).faer_real();
308                            x.write(k, E::faer_zero());
309
310                            for &j in reach {
311                                let j = j.zx();
312
313                                let j_start = L_col_ptrs_start[j].zx();
314                                let cj = &mut current_row_index[j];
315                                let row_idx = L_NNZ.check(*cj.zx() + 1);
316                                *cj = row_idx.truncate();
317
318                                let mut xj = x.read(j);
319                                x.write(j, E::faer_zero());
320
321                                let dj = L_values.read(j_start).faer_real();
322                                let lkj = xj.faer_scale_real(dj.faer_inv());
323                                if matches!(kind, FactorizationKind::Llt) {
324                                    xj = lkj;
325                                }
326
327                                let range = j_start.next()..row_idx.to_inclusive();
328                                for (i, lij) in zip(
329                                    &L_row_indices[range.clone()],
330                                    SliceGroup::<'_, E>::new(L_values.rb().subslice(range))
331                                        .into_ref_iter(),
332                                ) {
333                                    let i = N.check(i.zx());
334                                    let mut xi = x.read(i);
335                                    let prod = lij.read().faer_conj().faer_mul(xj);
336                                    xi = xi.faer_sub(prod);
337                                    x.write(i, xi);
338                                }
339
340                                d = d.faer_sub(lkj.faer_mul(xj.faer_conj()).faer_real());
341
342                                L_row_indices[row_idx] = *k.truncate();
343                                L_values.write(row_idx, lkj);
344                            }
345
346                            let k_start = L_col_ptrs_start[k].zx();
347                            L_row_indices[k_start] = *k.truncate();
348
349                            if has_delta {
350                                match kind {
351                                    FactorizationKind::Llt => {
352                                        if d <= eps {
353                                            d = delta;
354                                            dynamic_regularization_count += 1;
355                                        }
356                                    }
357                                    FactorizationKind::Ldlt => {
358                                        if let Some(signs) =
359                                            regularization.dynamic_regularization_signs
360                                        {
361                                            if signs[*k] > 0 && d <= eps {
362                                                d = delta;
363                                                dynamic_regularization_count += 1;
364                                            } else if signs[*k] < 0 && d >= eps.faer_neg() {
365                                                d = delta.faer_neg();
366                                                dynamic_regularization_count += 1;
367                                            }
368                                        } else if d.faer_abs() <= eps {
369                                            if d < E::Real::faer_zero() {
370                                                d = delta.faer_neg();
371                                                dynamic_regularization_count += 1;
372                                            } else {
373                                                d = delta;
374                                                dynamic_regularization_count += 1;
375                                            }
376                                        }
377                                    }
378                                }
379                            }
380
381                            match kind {
382                                FactorizationKind::Llt => {
383                                    if d <= E::Real::faer_zero() {
384                                        return Err(CholeskyError {
385                                            non_positive_definite_minor: *k + 1,
386                                        });
387                                    }
388                                    L_values.write(k_start, E::faer_from_real(d.faer_sqrt()));
389                                }
390                                FactorizationKind::Ldlt => {
391                                    L_values.write(k_start, E::faer_from_real(d));
392                                }
393                            }
394                        }
395                        Ok(dynamic_regularization_count)
396                    },
397                )
398            },
399        )
400    }
401
402    fn factorize_simplicial_numeric<I: Index, E: ComplexField>(
403        L_values: GroupFor<E, &mut [E::Unit]>,
404        kind: FactorizationKind,
405        A: SparseColMatRef<'_, I, E>,
406        regularization: LdltRegularization<'_, E>,
407        symbolic: &SymbolicSimplicialCholesky<I>,
408        stack: PodStack<'_>,
409    ) -> Result<usize, CholeskyError> {
410        let n = A.ncols();
411        let L_row_indices = &*symbolic.row_indices;
412        let L_col_ptrs = &*symbolic.col_ptrs;
413        let etree = &*symbolic.etree;
414
415        {
416            let L_values = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&L_values)));
417            assert!(L_values.rb().len() == L_row_indices.len());
418        }
419        assert!(L_col_ptrs.len() == n + 1);
420        let l_nnz = L_col_ptrs[n].zx();
421
422        ghost::with_size(
423            n,
424            #[inline(always)]
425            |N| {
426                ghost::with_size(
427                    l_nnz,
428                    #[inline(always)]
429                    move |L_NNZ| {
430                        let etree = Array::from_ref(
431                            MaybeIdx::from_slice_ref_checked(
432                                bytemuck::cast_slice::<I, I::Signed>(etree),
433                                N,
434                            ),
435                            N,
436                        );
437                        let A = ghost::SparseColMatRef::new(A, N, N);
438
439                        let eps = regularization.dynamic_regularization_epsilon.faer_abs();
440                        let delta = regularization.dynamic_regularization_delta.faer_abs();
441                        let has_delta = delta > E::Real::faer_zero();
442                        let mut dynamic_regularization_count = 0usize;
443
444                        let (x, stack) = crate::make_raw::<E>(n, stack);
445                        let (current_row_index, stack) = stack.make_raw::<I>(n);
446                        let (ereach_stack, stack) = stack.make_raw::<I>(n);
447                        let (visited, _) = stack.make_raw::<I::Signed>(n);
448
449                        let ereach_stack = Array::from_mut(ereach_stack, N);
450                        let visited = Array::from_mut(visited, N);
451                        let mut x = ghost::ArrayGroupMut::<'_, '_, E>::new(x.into_inner(), N);
452
453                        SliceGroupMut::<'_, E>::new(x.rb_mut().into_slice()).fill_zero();
454                        mem::fill_none(visited.as_mut());
455
456                        let mut L_values = ghost::ArrayGroupMut::<'_, '_, E>::new(L_values, L_NNZ);
457                        let L_row_indices = Array::from_ref(L_row_indices, L_NNZ);
458
459                        let L_col_ptrs_start = Array::from_ref(
460                            Idx::from_slice_ref_checked(&L_col_ptrs[..n], L_NNZ),
461                            N,
462                        );
463
464                        let current_row_index = Array::from_mut(
465                            ghost::copy_slice(current_row_index, L_col_ptrs_start.as_ref()),
466                            N,
467                        );
468
469                        for k in N.indices() {
470                            let reach = ereach(ereach_stack, *A, etree, k, visited);
471
472                            for (i, aik) in zip(
473                                A.row_indices_of_col(k),
474                                SliceGroup::<'_, E>::new(A.values_of_col(k)).into_ref_iter(),
475                            ) {
476                                x.write(i, x.read(i).faer_add(aik.read().faer_conj()));
477                            }
478
479                            let mut d = x.read(k).faer_real();
480                            x.write(k, E::faer_zero());
481
482                            for &j in reach {
483                                let j = j.zx();
484
485                                let j_start = L_col_ptrs_start[j].zx();
486                                let cj = &mut current_row_index[j];
487                                let row_idx = L_NNZ.check(*cj.zx() + 1);
488                                *cj = row_idx.truncate();
489
490                                let mut xj = x.read(j);
491                                x.write(j, E::faer_zero());
492
493                                let dj = L_values.read(j_start).faer_real();
494                                let lkj = xj.faer_scale_real(dj.faer_inv());
495                                if matches!(kind, FactorizationKind::Llt) {
496                                    xj = lkj;
497                                }
498
499                                let range = j_start.next()..row_idx.to_inclusive();
500                                for (i, lij) in zip(
501                                    &L_row_indices[range.clone()],
502                                    SliceGroup::<'_, E>::new(L_values.rb().subslice(range))
503                                        .into_ref_iter(),
504                                ) {
505                                    let i = N.check(i.zx());
506                                    let mut xi = x.read(i);
507                                    let prod = lij.read().faer_conj().faer_mul(xj);
508                                    xi = xi.faer_sub(prod);
509                                    x.write(i, xi);
510                                }
511
512                                d = d.faer_sub(lkj.faer_mul(xj.faer_conj()).faer_real());
513
514                                L_values.write(row_idx, lkj);
515                            }
516
517                            let k_start = L_col_ptrs_start[k].zx();
518
519                            if has_delta {
520                                match kind {
521                                    FactorizationKind::Llt => {
522                                        if d <= eps {
523                                            d = delta;
524                                            dynamic_regularization_count += 1;
525                                        }
526                                    }
527                                    FactorizationKind::Ldlt => {
528                                        if let Some(signs) =
529                                            regularization.dynamic_regularization_signs
530                                        {
531                                            if signs[*k] > 0 && d <= eps {
532                                                d = delta;
533                                                dynamic_regularization_count += 1;
534                                            } else if signs[*k] < 0 && d >= eps.faer_neg() {
535                                                d = delta.faer_neg();
536                                                dynamic_regularization_count += 1;
537                                            }
538                                        } else if d.faer_abs() <= eps {
539                                            if d < E::Real::faer_zero() {
540                                                d = delta.faer_neg();
541                                                dynamic_regularization_count += 1;
542                                            } else {
543                                                d = delta;
544                                                dynamic_regularization_count += 1;
545                                            }
546                                        }
547                                    }
548                                }
549                            }
550
551                            match kind {
552                                FactorizationKind::Llt => {
553                                    if d <= E::Real::faer_zero() {
554                                        return Err(CholeskyError {
555                                            non_positive_definite_minor: *k + 1,
556                                        });
557                                    }
558                                    L_values.write(k_start, E::faer_from_real(d.faer_sqrt()));
559                                }
560                                FactorizationKind::Ldlt => {
561                                    L_values.write(k_start, E::faer_from_real(d));
562                                }
563                            }
564                        }
565                        Ok(dynamic_regularization_count)
566                    },
567                )
568            },
569        )
570    }
571
572    pub fn factorize_simplicial_numeric_llt<I: Index, E: ComplexField>(
573        L_values: GroupFor<E, &mut [E::Unit]>,
574        A: SparseColMatRef<'_, I, E>,
575        regularization: LltRegularization<E>,
576        symbolic: &SymbolicSimplicialCholesky<I>,
577        stack: PodStack<'_>,
578    ) -> Result<usize, CholeskyError> {
579        factorize_simplicial_numeric(
580            L_values,
581            FactorizationKind::Llt,
582            A,
583            LdltRegularization {
584                dynamic_regularization_signs: None,
585                dynamic_regularization_delta: regularization.dynamic_regularization_delta,
586                dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
587            },
588            symbolic,
589            stack,
590        )
591    }
592
593    pub fn factorize_simplicial_numeric_llt_with_row_indices<I: Index, E: ComplexField>(
594        L_values: GroupFor<E, &mut [E::Unit]>,
595        L_row_indices: &mut [I],
596        L_col_ptrs: &[I],
597
598        etree: EliminationTreeRef<'_, I>,
599        A: SparseColMatRef<'_, I, E>,
600        regularization: LltRegularization<E>,
601
602        stack: PodStack<'_>,
603    ) -> Result<usize, CholeskyError> {
604        factorize_simplicial_numeric_with_row_indices(
605            L_values,
606            L_row_indices,
607            L_col_ptrs,
608            FactorizationKind::Ldlt,
609            etree,
610            A,
611            LdltRegularization {
612                dynamic_regularization_signs: None,
613                dynamic_regularization_delta: regularization.dynamic_regularization_delta,
614                dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
615            },
616            stack,
617        )
618    }
619
620    pub fn factorize_simplicial_numeric_ldlt<I: Index, E: ComplexField>(
621        L_values: GroupFor<E, &mut [E::Unit]>,
622        A: SparseColMatRef<'_, I, E>,
623        regularization: LdltRegularization<'_, E>,
624        symbolic: &SymbolicSimplicialCholesky<I>,
625        stack: PodStack<'_>,
626    ) -> usize {
627        factorize_simplicial_numeric(
628            L_values,
629            FactorizationKind::Ldlt,
630            A,
631            regularization,
632            symbolic,
633            stack,
634        )
635        .unwrap()
636    }
637
638    pub fn factorize_simplicial_numeric_ldlt_with_row_indices<I: Index, E: ComplexField>(
639        L_values: GroupFor<E, &mut [E::Unit]>,
640        L_row_indices: &mut [I],
641        L_col_ptrs: &[I],
642
643        etree: EliminationTreeRef<'_, I>,
644        A: SparseColMatRef<'_, I, E>,
645        regularization: LdltRegularization<'_, E>,
646
647        stack: PodStack<'_>,
648    ) -> usize {
649        factorize_simplicial_numeric_with_row_indices(
650            L_values,
651            L_row_indices,
652            L_col_ptrs,
653            FactorizationKind::Ldlt,
654            etree,
655            A,
656            regularization,
657            stack,
658        )
659        .unwrap()
660    }
661
662    impl<'a, I: Index, E: Entity> SimplicialLltRef<'a, I, E> {
663        #[inline]
664        pub fn new(
665            symbolic: &'a SymbolicSimplicialCholesky<I>,
666            values: GroupFor<E, &'a [E::Unit]>,
667        ) -> Self {
668            let values = SliceGroup::new(values);
669            assert!(values.len() == symbolic.len_values());
670            Self { symbolic, values }
671        }
672
673        #[inline]
674        pub fn symbolic(self) -> &'a SymbolicSimplicialCholesky<I> {
675            self.symbolic
676        }
677
678        #[inline]
679        pub fn values(self) -> GroupFor<E, &'a [E::Unit]> {
680            self.values.into_inner()
681        }
682
683        pub fn solve_in_place_with_conj(
684            &self,
685            conj: Conj,
686            rhs: MatMut<'_, E>,
687            parallelism: Parallelism,
688            stack: PodStack<'_>,
689        ) where
690            E: ComplexField,
691        {
692            let _ = parallelism;
693            let _ = stack;
694            let n = self.symbolic().nrows();
695            assert!(rhs.nrows() == n);
696            let l = SparseColMatRef::<'_, I, E>::new(self.symbolic().ld_factors(), self.values());
697
698            let mut rhs = rhs;
699            triangular_solve::solve_lower_triangular_in_place(l, conj, rhs.rb_mut(), parallelism);
700            triangular_solve::solve_lower_triangular_transpose_in_place(
701                l,
702                conj.compose(Conj::Yes),
703                rhs.rb_mut(),
704                parallelism,
705            );
706        }
707    }
708
709    impl<'a, I: Index, E: Entity> SimplicialLdltRef<'a, I, E> {
710        #[inline]
711        pub fn new(
712            symbolic: &'a SymbolicSimplicialCholesky<I>,
713            values: GroupFor<E, &'a [E::Unit]>,
714        ) -> Self {
715            let values = SliceGroup::new(values);
716            assert!(values.len() == symbolic.len_values());
717            Self { symbolic, values }
718        }
719
720        #[inline]
721        pub fn symbolic(self) -> &'a SymbolicSimplicialCholesky<I> {
722            self.symbolic
723        }
724
725        #[inline]
726        pub fn values(self) -> GroupFor<E, &'a [E::Unit]> {
727            self.values.into_inner()
728        }
729
730        pub fn solve_in_place_with_conj(
731            &self,
732            conj: Conj,
733            rhs: MatMut<'_, E>,
734            parallelism: Parallelism,
735            stack: PodStack<'_>,
736        ) where
737            E: ComplexField,
738        {
739            let _ = parallelism;
740            let _ = stack;
741            let n = self.symbolic().nrows();
742            let ld = SparseColMatRef::<'_, I, E>::new(self.symbolic().ld_factors(), self.values());
743            assert!(rhs.nrows() == n);
744
745            let slice_group = SliceGroup::<'_, E>::new;
746
747            let mut x = rhs;
748            triangular_solve::solve_unit_lower_triangular_in_place(
749                ld,
750                conj,
751                x.rb_mut(),
752                parallelism,
753            );
754            for mut x in x.rb_mut().col_chunks_mut(1) {
755                for j in 0..n {
756                    let d_inv = slice_group(ld.values_of_col(j))
757                        .read(0)
758                        .faer_real()
759                        .faer_inv();
760                    x.write(j, 0, x.read(j, 0).faer_scale_real(d_inv));
761                }
762            }
763            triangular_solve::solve_unit_lower_triangular_transpose_in_place(
764                ld,
765                conj.compose(Conj::Yes),
766                x.rb_mut(),
767                parallelism,
768            );
769        }
770    }
771
772    impl<I: Index> SymbolicSimplicialCholesky<I> {
773        #[inline]
774        pub fn nrows(&self) -> usize {
775            self.dimension
776        }
777        #[inline]
778        pub fn ncols(&self) -> usize {
779            self.nrows()
780        }
781
782        #[inline]
783        pub fn len_values(&self) -> usize {
784            self.row_indices.len()
785        }
786
787        #[inline]
788        pub fn col_ptrs(&self) -> &[I] {
789            &self.col_ptrs
790        }
791
792        #[inline]
793        pub fn row_indices(&self) -> &[I] {
794            &self.row_indices
795        }
796
797        #[inline]
798        pub fn ld_factors(&self) -> SymbolicSparseColMatRef<'_, I> {
799            unsafe {
800                SymbolicSparseColMatRef::new_unchecked(
801                    self.dimension,
802                    self.dimension,
803                    &self.col_ptrs,
804                    None,
805                    &self.row_indices,
806                )
807            }
808        }
809
810        pub fn solve_in_place_req<E: Entity>(
811            &self,
812            rhs_ncols: usize,
813        ) -> Result<StackReq, SizeOverflow> {
814            let _ = rhs_ncols;
815            Ok(StackReq::empty())
816        }
817    }
818
819    pub fn factorize_simplicial_numeric_ldlt_req<I: Index, E: Entity>(
820        n: usize,
821    ) -> Result<StackReq, SizeOverflow> {
822        let n_req = StackReq::try_new::<I>(n)?;
823        StackReq::try_all_of([make_raw_req::<E>(n)?, n_req, n_req, n_req])
824    }
825
    /// Computes the workspace requirement of the numeric simplicial LLT factorization of a
    /// matrix of dimension `n`.
    ///
    /// This is the same requirement as [`factorize_simplicial_numeric_ldlt_req`].
    ///
    /// # Errors
    /// Returns [`SizeOverflow`] if a size computation overflows.
    pub fn factorize_simplicial_numeric_llt_req<I: Index, E: Entity>(
        n: usize,
    ) -> Result<StackReq, SizeOverflow> {
        factorize_simplicial_numeric_ldlt_req::<I, E>(n)
    }
831
    /// View over a simplicial LLT factorization: a symbolic structure together with the
    /// numeric values of the factor.
    #[derive(Debug)]
    pub struct SimplicialLltRef<'a, I: Index, E: Entity> {
        // sparsity structure of the factor
        symbolic: &'a SymbolicSimplicialCholesky<I>,
        // numeric values, laid out according to `symbolic`
        values: SliceGroup<'a, E>,
    }
837
    /// View over a simplicial LDLT factorization: a symbolic structure together with the
    /// numeric values of the factor.
    #[derive(Debug)]
    pub struct SimplicialLdltRef<'a, I: Index, E: Entity> {
        // sparsity structure of the factor
        symbolic: &'a SymbolicSimplicialCholesky<I>,
        // numeric values; `new` checks that their count equals `symbolic.len_values()`
        values: SliceGroup<'a, E>,
    }
843
    /// Symbolic structure of a simplicial Cholesky factor, stored in sparse column format.
    #[derive(Debug)]
    pub struct SymbolicSimplicialCholesky<I> {
        // number of rows, and also of columns, of the factor
        dimension: usize,
        // column pointers into `row_indices`
        col_ptrs: alloc::vec::Vec<I>,
        // row indices of the stored entries
        row_indices: alloc::vec::Vec<I>,
        // presumably the elimination tree the structure was built from - TODO confirm
        etree: alloc::vec::Vec<I>,
    }
851
    /// Borrowed elimination tree, stored as one signed index per node.
    #[derive(Copy, Clone, Debug)]
    pub struct EliminationTreeRef<'a, I: Index> {
        // one signed entry per node; reinterpreted as `MaybeIdx` by `ghost_inner`
        pub(crate) inner: &'a [I::Signed],
    }
856
857    impl<'a, I: Index> EliminationTreeRef<'a, I> {
858        #[inline]
859        pub fn into_inner(self) -> &'a [I::Signed] {
860            self.inner
861        }
862
863        #[inline]
864        #[track_caller]
865        pub(crate) fn ghost_inner<'n>(self, N: ghost::Size<'n>) -> &'a Array<'n, MaybeIdx<'n, I>> {
866            assert!(self.inner.len() == *N);
867            unsafe { Array::from_ref(MaybeIdx::from_slice_ref_unchecked(self.inner), N) }
868        }
869    }
870}
871
872pub mod supernodal {
873    use super::*;
874    use faer_core::{assert, debug_assert};
875
876    fn ereach_super<'n, 'nsuper, I: Index>(
877        A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>,
878        super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>,
879        index_to_super: &Array<'n, Idx<'nsuper, I>>,
880        current_row_positions: &mut Array<'nsuper, I>,
881        row_indices: &mut [Idx<'n, I>],
882        k: Idx<'n, usize>,
883        visited: &mut Array<'nsuper, I::Signed>,
884    ) {
885        let k_: I = *k.truncate();
886        visited[index_to_super[k].zx()] = k_.to_signed();
887        for i in A.row_indices_of_col(k) {
888            if i >= k {
889                continue;
890            }
891            let mut supernode_i = index_to_super[i].zx();
892            loop {
893                if visited[supernode_i] == k_.to_signed() {
894                    break;
895                }
896
897                row_indices[current_row_positions[supernode_i].zx()] = k.truncate();
898                current_row_positions[supernode_i] += I::truncate(1);
899
900                visited[supernode_i] = k_.to_signed();
901                supernode_i = super_etree[supernode_i].sx().idx().unwrap();
902            }
903        }
904    }
905
906    fn ereach_super_ata<'m, 'n, 'nsuper, I: Index>(
907        A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>,
908        perm: Option<ghost::PermutationRef<'n, '_, I, Symbolic>>,
909        min_col: &Array<'m, MaybeIdx<'n, I>>,
910        super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>,
911        index_to_super: &Array<'n, Idx<'nsuper, I>>,
912        current_row_positions: &mut Array<'nsuper, I>,
913        row_indices: &mut [Idx<'n, I>],
914        k: Idx<'n, usize>,
915        visited: &mut Array<'nsuper, I::Signed>,
916    ) {
917        let k_: I = *k.truncate();
918        visited[index_to_super[k].zx()] = k_.to_signed();
919
920        let fwd = perm.map(|perm| perm.into_arrays().0);
921        let fwd = |i: Idx<'n, usize>| fwd.map(|fwd| fwd[k].zx()).unwrap_or(i);
922        for i in A.row_indices_of_col(fwd(k)) {
923            let Some(i) = min_col[i].idx() else { continue };
924            let i = i.zx();
925
926            if i >= k {
927                continue;
928            }
929            let mut supernode_i = index_to_super[i].zx();
930            loop {
931                if visited[supernode_i] == k_.to_signed() {
932                    break;
933                }
934
935                row_indices[current_row_positions[supernode_i].zx()] = k.truncate();
936                current_row_positions[supernode_i] += I::truncate(1);
937
938                visited[supernode_i] = k_.to_signed();
939                supernode_i = super_etree[supernode_i].sx().idx().unwrap();
940            }
941        }
942    }
943
    /// Symbolic description of a single supernode.
    #[derive(Debug)]
    pub struct SymbolicSupernodeRef<'a, I> {
        // index of the first column of the supernode
        start: usize,
        // row indices of the part of the supernode below its diagonal block
        pattern: &'a [I],
    }
949
950    impl<'a, I: Index> SymbolicSupernodeRef<'a, I> {
951        #[inline]
952        pub fn start(self) -> usize {
953            self.start
954        }
955
956        pub fn pattern(self) -> &'a [I] {
957            self.pattern
958        }
959    }
960
961    impl<'a, I: Index, E: Entity> SupernodeRef<'a, I, E> {
962        #[inline]
963        pub fn start(self) -> usize {
964            self.symbolic.start
965        }
966
967        pub fn pattern(self) -> &'a [I] {
968            self.symbolic.pattern
969        }
970
971        pub fn matrix(self) -> MatRef<'a, E> {
972            self.matrix
973        }
974    }
975
    /// A single supernode: its dense numeric block together with its symbolic description.
    #[derive(Debug)]
    pub struct SupernodeRef<'a, I, E: Entity> {
        // dense block of the supernode; one column per supernode column, and one row per
        // supernode column followed by one row per pattern entry
        matrix: MatRef<'a, E>,
        // first column and row pattern of the supernode
        symbolic: SymbolicSupernodeRef<'a, I>,
    }
981
    impl<'a, I: Index, E: Entity> SupernodalIntranodeBunchKaufmanRef<'a, I, E> {
        /// Pairs a symbolic supernodal factorization with the numeric data of an intranode
        /// Bunch-Kaufman factorization: the factor values, the subdiagonal entries and the
        /// permutation.
        ///
        /// # Panics
        /// Panics if the number of values differs from `symbolic.len_values()`.
        // NOTE(review): the lengths of `subdiag` and `perm` are not checked here.
        #[inline]
        pub fn new(
            symbolic: &'a SymbolicSupernodalCholesky<I>,
            values: GroupFor<E, &'a [E::Unit]>,
            subdiag: GroupFor<E, &'a [E::Unit]>,
            perm: PermutationRef<'a, I, E>,
        ) -> Self {
            let values = SliceGroup::<'_, E>::new(values);
            let subdiag = SliceGroup::<'_, E>::new(subdiag);
            assert!(values.len() == symbolic.len_values());
            Self {
                symbolic,
                values,
                subdiag,
                perm,
            }
        }

        /// Returns the symbolic structure this view was created with.
        #[inline]
        pub fn symbolic(self) -> &'a SymbolicSupernodalCholesky<I> {
            self.symbolic
        }

        /// Returns the numeric values of the factor.
        #[inline]
        pub fn values(self) -> SliceGroup<'a, E> {
            self.values
        }

        /// Returns the `s`-th supernode: its dense block, first column and row pattern.
        ///
        /// # Panics
        /// Panics if `s` is not a valid supernode index.
        #[inline]
        pub fn supernode(self, s: usize) -> SupernodeRef<'a, I, E> {
            let symbolic = self.symbolic();
            let L_values = self.values();
            // column range covered by the supernode
            let s_start = symbolic.supernode_begin[s].zx();
            let s_end = symbolic.supernode_begin[s + 1].zx();

            // row indices stored for the supernode
            let s_pattern = &symbolic.row_indices()[symbolic.col_ptrs_for_row_indices()[s].zx()
                ..symbolic.col_ptrs_for_row_indices()[s + 1].zx()];
            let s_ncols = s_end - s_start;
            // the dense block has one row per supernode column, plus one per pattern entry
            let s_nrows = s_pattern.len() + s_ncols;

            let Ls = faer_core::mat::from_column_major_slice::<'_, E>(
                L_values
                    .subslice(
                        symbolic.col_ptrs_for_values()[s].zx()
                            ..symbolic.col_ptrs_for_values()[s + 1].zx(),
                    )
                    .into_inner(),
                s_nrows,
                s_ncols,
            );

            SupernodeRef {
                matrix: Ls,
                symbolic: SymbolicSupernodeRef {
                    start: s_start,
                    pattern: s_pattern,
                },
            }
        }

        /// Solves in place using the stored factorization, overwriting `rhs`.
        ///
        /// Three passes over the supernodes are performed: a forward pass with the unit lower
        /// triangular blocks, a pass applying the stored diagonal blocks (1x1 or 2x2), and a
        /// backward pass with the transposed blocks and `conj.compose(Conj::Yes)`.
        ///
        /// Rows listed in a supernode's pattern are accessed through the inverse array of the
        /// stored permutation, while the rows of its diagonal block are accessed directly.
        ///
        /// `stack` provides the temporary storage for one `pattern.len() x rhs.ncols()` matrix
        /// per supernode.
        ///
        /// # Panics
        /// Panics if `rhs.nrows()` differs from the dimension of the symbolic structure.
        pub fn solve_in_place_no_numeric_permute_with_conj(
            self,
            conj: Conj,
            rhs: MatMut<'_, E>,
            parallelism: Parallelism,
            stack: PodStack<'_>,
        ) where
            E: ComplexField,
        {
            let symbolic = self.symbolic();
            let n = symbolic.nrows();
            assert!(rhs.nrows() == n);
            let mut stack = stack;

            let mut x = rhs;

            let k = x.ncols();
            // forward pass
            for s in 0..symbolic.n_supernodes() {
                let s = self.supernode(s);
                let size = s.matrix.ncols();
                let Ls = s.matrix;
                // top: square diagonal block, bottom: rows listed in the pattern
                let (Ls_top, Ls_bot) = Ls.split_at_row(size);
                let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
                faer_core::solve::solve_unit_lower_triangular_in_place_with_conj(
                    Ls_top,
                    conj,
                    x_top.rb_mut(),
                    parallelism,
                );

                // tmp = Ls_bot * x_top, the update for the rows in the pattern
                let (mut tmp, _) = temp_mat_uninit::<E>(s.pattern().len(), k, stack.rb_mut());
                faer_core::mul::matmul_with_conj(
                    tmp.rb_mut(),
                    Ls_bot,
                    conj,
                    x_top.rb(),
                    Conj::No,
                    None,
                    E::faer_one(),
                    parallelism,
                );

                let inv = self.perm.into_arrays().1;
                // scatter: subtract the update from the rows given by the permuted pattern
                for j in 0..k {
                    for (idx, i) in s.pattern().iter().enumerate() {
                        let i = i.zx();
                        let i = inv[i].zx();
                        x.write(i, j, x.read(i, j).faer_sub(tmp.read(idx, j)))
                    }
                }
            }
            // diagonal pass
            for s in 0..symbolic.n_supernodes() {
                let s = self.supernode(s);
                let size = s.matrix.ncols();
                let Bs = s.matrix();
                let subdiag = self.subdiag.subslice(s.start()..s.start() + size);

                let mut idx = 0;
                while idx < size {
                    let subdiag = subdiag.read(idx);
                    let i = idx + s.start();
                    if subdiag == E::faer_zero() {
                        // 1x1 block: scale the row by the real part of the stored diagonal.
                        // NOTE(review): the entry is used as a multiplier, so the stored
                        // blocks are presumably already inverted - confirm against the
                        // factorization.
                        let d = Bs.read(idx, idx).faer_real();
                        for j in 0..k {
                            x.write(i, j, x.read(i, j).faer_scale_real(d))
                        }
                        idx += 1;
                    } else {
                        // 2x2 block on rows `i` and `i + 1`, with real diagonal entries
                        // `d11`, `d22` and off-diagonal entry `d21`
                        let d11 = Bs.read(idx, idx).faer_real();
                        let d22 = Bs.read(idx + 1, idx + 1).faer_real();
                        let d21 = subdiag;

                        if conj == Conj::Yes {
                            // multiply by [[d11, d21], [conj(d21), d22]]
                            for j in 0..k {
                                let xi = x.read(i, j);
                                let xip1 = x.read(i + 1, j);

                                x.write(i, j, xi.faer_scale_real(d11).faer_add(xip1.faer_mul(d21)));
                                x.write(
                                    i + 1,
                                    j,
                                    xip1.faer_scale_real(d22)
                                        .faer_add(xi.faer_mul(d21.faer_conj())),
                                );
                            }
                        } else {
                            // multiply by [[d11, conj(d21)], [d21, d22]]
                            for j in 0..k {
                                let xi = x.read(i, j);
                                let xip1 = x.read(i + 1, j);

                                x.write(
                                    i,
                                    j,
                                    xi.faer_scale_real(d11)
                                        .faer_add(xip1.faer_mul(d21.faer_conj())),
                                );
                                x.write(
                                    i + 1,
                                    j,
                                    xip1.faer_scale_real(d22).faer_add(xi.faer_mul(d21)),
                                );
                            }
                        }
                        idx += 2;
                    }
                }
            }
            // backward pass, visiting the supernodes in reverse order
            for s in (0..symbolic.n_supernodes()).rev() {
                let s = self.supernode(s);
                let size = s.matrix.ncols();
                let Ls = s.matrix;
                let (Ls_top, Ls_bot) = Ls.split_at_row(size);

                let (mut tmp, _) = temp_mat_uninit::<E>(s.pattern().len(), k, stack.rb_mut());
                let inv = self.perm.into_arrays().1;
                // gather the rows given by the permuted pattern into `tmp`
                for j in 0..k {
                    for (idx, i) in s.pattern().iter().enumerate() {
                        let i = i.zx();
                        let i = inv[i].zx();
                        tmp.write(idx, j, x.read(i, j));
                    }
                }

                let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
                // x_top -= Ls_bot^T * tmp, with the composed conjugation
                faer_core::mul::matmul_with_conj(
                    x_top.rb_mut(),
                    Ls_bot.transpose(),
                    conj.compose(Conj::Yes),
                    tmp.rb(),
                    Conj::No,
                    Some(E::faer_one()),
                    E::faer_one().faer_neg(),
                    parallelism,
                );
                faer_core::solve::solve_unit_upper_triangular_in_place_with_conj(
                    Ls_top.transpose(),
                    conj.compose(Conj::Yes),
                    x_top.rb_mut(),
                    parallelism,
                );
            }
        }
    }
1186
1187    impl<'a, I: Index, E: Entity> SupernodalLdltRef<'a, I, E> {
1188        #[inline]
1189        pub fn new(
1190            symbolic: &'a SymbolicSupernodalCholesky<I>,
1191            values: GroupFor<E, &'a [E::Unit]>,
1192        ) -> Self {
1193            let values = SliceGroup::new(values);
1194            assert!(values.len() == symbolic.len_values());
1195            Self { symbolic, values }
1196        }
1197
1198        #[inline]
1199        pub fn symbolic(self) -> &'a SymbolicSupernodalCholesky<I> {
1200            self.symbolic
1201        }
1202
1203        #[inline]
1204        pub fn values(self) -> SliceGroup<'a, E> {
1205            self.values
1206        }
1207
1208        #[inline]
1209        pub fn supernode(self, s: usize) -> SupernodeRef<'a, I, E> {
1210            let symbolic = self.symbolic();
1211            let L_values = self.values();
1212            let s_start = symbolic.supernode_begin[s].zx();
1213            let s_end = symbolic.supernode_begin[s + 1].zx();
1214
1215            let s_pattern = &symbolic.row_indices()[symbolic.col_ptrs_for_row_indices()[s].zx()
1216                ..symbolic.col_ptrs_for_row_indices()[s + 1].zx()];
1217            let s_ncols = s_end - s_start;
1218            let s_nrows = s_pattern.len() + s_ncols;
1219
1220            let Ls = faer_core::mat::from_column_major_slice::<'_, E>(
1221                L_values
1222                    .subslice(
1223                        symbolic.col_ptrs_for_values()[s].zx()
1224                            ..symbolic.col_ptrs_for_values()[s + 1].zx(),
1225                    )
1226                    .into_inner(),
1227                s_nrows,
1228                s_ncols,
1229            );
1230
1231            SupernodeRef {
1232                matrix: Ls,
1233                symbolic: SymbolicSupernodeRef {
1234                    start: s_start,
1235                    pattern: s_pattern,
1236                },
1237            }
1238        }
1239
1240        pub fn solve_in_place_with_conj(
1241            &self,
1242            conj: Conj,
1243            rhs: MatMut<'_, E>,
1244            parallelism: Parallelism,
1245            stack: PodStack<'_>,
1246        ) where
1247            E: ComplexField,
1248        {
1249            let symbolic = self.symbolic();
1250            let n = symbolic.nrows();
1251            assert!(rhs.nrows() == n);
1252
1253            let mut x = rhs;
1254            let mut stack = stack;
1255            let k = x.ncols();
1256            for s in 0..symbolic.n_supernodes() {
1257                let s = self.supernode(s);
1258                let size = s.matrix.ncols();
1259                let Ls = s.matrix;
1260                let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1261                let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1262                faer_core::solve::solve_unit_lower_triangular_in_place_with_conj(
1263                    Ls_top,
1264                    conj,
1265                    x_top.rb_mut(),
1266                    parallelism,
1267                );
1268
1269                let (mut tmp, _) = temp_mat_uninit::<E>(s.pattern().len(), k, stack.rb_mut());
1270                faer_core::mul::matmul_with_conj(
1271                    tmp.rb_mut(),
1272                    Ls_bot,
1273                    conj,
1274                    x_top.rb(),
1275                    Conj::No,
1276                    None,
1277                    E::faer_one(),
1278                    parallelism,
1279                );
1280
1281                for j in 0..k {
1282                    for (idx, i) in s.pattern().iter().enumerate() {
1283                        let i = i.zx();
1284                        x.write(i, j, x.read(i, j).faer_sub(tmp.read(idx, j)))
1285                    }
1286                }
1287            }
1288            for s in 0..symbolic.n_supernodes() {
1289                let s = self.supernode(s);
1290                let size = s.matrix.ncols();
1291                let Ds = s.matrix.diagonal().column_vector();
1292                for j in 0..k {
1293                    for idx in 0..size {
1294                        let d_inv = Ds.read(idx).faer_real();
1295                        let i = idx + s.start();
1296                        x.write(i, j, x.read(i, j).faer_scale_real(d_inv))
1297                    }
1298                }
1299            }
1300            for s in (0..symbolic.n_supernodes()).rev() {
1301                let s = self.supernode(s);
1302                let size = s.matrix.ncols();
1303                let Ls = s.matrix;
1304                let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1305
1306                let (mut tmp, _) = temp_mat_uninit::<E>(s.pattern().len(), k, stack.rb_mut());
1307                for j in 0..k {
1308                    for (idx, i) in s.pattern().iter().enumerate() {
1309                        let i = i.zx();
1310                        tmp.write(idx, j, x.read(i, j));
1311                    }
1312                }
1313
1314                let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1315                faer_core::mul::matmul_with_conj(
1316                    x_top.rb_mut(),
1317                    Ls_bot.transpose(),
1318                    conj.compose(Conj::Yes),
1319                    tmp.rb(),
1320                    Conj::No,
1321                    Some(E::faer_one()),
1322                    E::faer_one().faer_neg(),
1323                    parallelism,
1324                );
1325                faer_core::solve::solve_unit_upper_triangular_in_place_with_conj(
1326                    Ls_top.transpose(),
1327                    conj.compose(Conj::Yes),
1328                    x_top.rb_mut(),
1329                    parallelism,
1330                );
1331            }
1332        }
1333    }
1334
1335    impl<'a, I: Index, E: Entity> SupernodalLltRef<'a, I, E> {
1336        #[inline]
1337        pub fn new(
1338            symbolic: &'a SymbolicSupernodalCholesky<I>,
1339            values: GroupFor<E, &'a [E::Unit]>,
1340        ) -> Self {
1341            let values = SliceGroup::new(values);
1342            assert!(values.len() == symbolic.len_values());
1343            Self { symbolic, values }
1344        }
1345
1346        #[inline]
1347        pub fn symbolic(self) -> &'a SymbolicSupernodalCholesky<I> {
1348            self.symbolic
1349        }
1350
1351        #[inline]
1352        pub fn values(self) -> SliceGroup<'a, E> {
1353            self.values
1354        }
1355
1356        #[inline]
1357        pub fn supernode(self, s: usize) -> SupernodeRef<'a, I, E> {
1358            let symbolic = self.symbolic();
1359            let L_values = self.values();
1360            let s_start = symbolic.supernode_begin[s].zx();
1361            let s_end = symbolic.supernode_begin[s + 1].zx();
1362
1363            let s_pattern = &symbolic.row_indices()[symbolic.col_ptrs_for_row_indices()[s].zx()
1364                ..symbolic.col_ptrs_for_row_indices()[s + 1].zx()];
1365            let s_ncols = s_end - s_start;
1366            let s_nrows = s_pattern.len() + s_ncols;
1367
1368            let Ls = faer_core::mat::from_column_major_slice::<'_, E>(
1369                L_values
1370                    .subslice(
1371                        symbolic.col_ptrs_for_values()[s].zx()
1372                            ..symbolic.col_ptrs_for_values()[s + 1].zx(),
1373                    )
1374                    .into_inner(),
1375                s_nrows,
1376                s_ncols,
1377            );
1378
1379            SupernodeRef {
1380                matrix: Ls,
1381                symbolic: SymbolicSupernodeRef {
1382                    start: s_start,
1383                    pattern: s_pattern,
1384                },
1385            }
1386        }
1387
1388        pub fn solve_in_place_with_conj(
1389            &self,
1390            conj: Conj,
1391            rhs: MatMut<'_, E>,
1392            parallelism: Parallelism,
1393            stack: PodStack<'_>,
1394        ) where
1395            E: ComplexField,
1396        {
1397            let symbolic = self.symbolic();
1398            let n = symbolic.nrows();
1399            assert!(rhs.nrows() == n);
1400
1401            let mut x = rhs;
1402            let mut stack = stack;
1403            let k = x.ncols();
1404            for s in 0..symbolic.n_supernodes() {
1405                let s = self.supernode(s);
1406                let size = s.matrix.ncols();
1407                let Ls = s.matrix;
1408                let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1409                let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1410                faer_core::solve::solve_lower_triangular_in_place_with_conj(
1411                    Ls_top,
1412                    conj,
1413                    x_top.rb_mut(),
1414                    parallelism,
1415                );
1416
1417                let (mut tmp, _) = temp_mat_uninit::<E>(s.pattern().len(), k, stack.rb_mut());
1418                faer_core::mul::matmul_with_conj(
1419                    tmp.rb_mut(),
1420                    Ls_bot,
1421                    conj,
1422                    x_top.rb(),
1423                    Conj::No,
1424                    None,
1425                    E::faer_one(),
1426                    parallelism,
1427                );
1428
1429                for j in 0..k {
1430                    for (idx, i) in s.pattern().iter().enumerate() {
1431                        let i = i.zx();
1432                        x.write(i, j, x.read(i, j).faer_sub(tmp.read(idx, j)))
1433                    }
1434                }
1435            }
1436            for s in (0..symbolic.n_supernodes()).rev() {
1437                let s = self.supernode(s);
1438                let size = s.matrix.ncols();
1439                let Ls = s.matrix;
1440                let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1441
1442                let (mut tmp, _) = temp_mat_uninit::<E>(s.pattern().len(), k, stack.rb_mut());
1443                for j in 0..k {
1444                    for (idx, i) in s.pattern().iter().enumerate() {
1445                        let i = i.zx();
1446                        tmp.write(idx, j, x.read(i, j));
1447                    }
1448                }
1449
1450                let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1451                faer_core::mul::matmul_with_conj(
1452                    x_top.rb_mut(),
1453                    Ls_bot.transpose(),
1454                    conj.compose(Conj::Yes),
1455                    tmp.rb(),
1456                    Conj::No,
1457                    Some(E::faer_one()),
1458                    E::faer_one().faer_neg(),
1459                    parallelism,
1460                );
1461                faer_core::solve::solve_upper_triangular_in_place_with_conj(
1462                    Ls_top.transpose(),
1463                    conj.compose(Conj::Yes),
1464                    x_top.rb_mut(),
1465                    parallelism,
1466                );
1467            }
1468        }
1469    }
1470
1471    impl<I: Index> SymbolicSupernodalCholesky<I> {
1472        #[inline]
1473        pub fn n_supernodes(&self) -> usize {
1474            self.supernode_postorder.len()
1475        }
1476
1477        #[inline]
1478        pub fn nrows(&self) -> usize {
1479            self.dimension
1480        }
1481        #[inline]
1482        pub fn ncols(&self) -> usize {
1483            self.nrows()
1484        }
1485
1486        #[inline]
1487        pub fn len_values(&self) -> usize {
1488            self.col_ptrs_for_values()[self.n_supernodes()].zx()
1489        }
1490
1491        #[inline]
1492        pub fn supernode_begin(&self) -> &[I] {
1493            &self.supernode_begin[..self.n_supernodes()]
1494        }
1495
1496        #[inline]
1497        pub fn supernode_end(&self) -> &[I] {
1498            &self.supernode_begin[1..]
1499        }
1500
1501        #[inline]
1502        pub fn col_ptrs_for_row_indices(&self) -> &[I] {
1503            &self.col_ptrs_for_row_indices
1504        }
1505
1506        #[inline]
1507        pub fn col_ptrs_for_values(&self) -> &[I] {
1508            &self.col_ptrs_for_values
1509        }
1510
1511        #[inline]
1512        pub fn row_indices(&self) -> &[I] {
1513            &self.row_indices
1514        }
1515
1516        #[inline]
1517        pub fn supernode(&self, s: usize) -> supernodal::SymbolicSupernodeRef<'_, I> {
1518            let symbolic = self;
1519            let start = symbolic.supernode_begin[s].zx();
1520            let pattern = &symbolic.row_indices()[symbolic.col_ptrs_for_row_indices()[s].zx()
1521                ..symbolic.col_ptrs_for_row_indices()[s + 1].zx()];
1522            supernodal::SymbolicSupernodeRef { start, pattern }
1523        }
1524
1525        pub fn solve_in_place_req<E: Entity>(
1526            &self,
1527            rhs_ncols: usize,
1528        ) -> Result<StackReq, SizeOverflow> {
1529            let mut req = StackReq::empty();
1530            let symbolic = self;
1531            for s in 0..symbolic.n_supernodes() {
1532                let s = self.supernode(s);
1533                req = req.try_or(temp_mat_req::<E>(s.pattern.len(), rhs_ncols)?)?;
1534            }
1535            Ok(req)
1536        }
1537    }
1538
1539    pub fn factorize_supernodal_symbolic_cholesky_req<I: Index>(
1540        n: usize,
1541    ) -> Result<StackReq, SizeOverflow> {
1542        let n_req = StackReq::try_new::<I>(n)?;
1543        StackReq::try_all_of([n_req, n_req, n_req, n_req])
1544    }
1545
1546    pub fn factorize_supernodal_symbolic<I: Index>(
1547        A: SymbolicSparseColMatRef<'_, I>,
1548        etree: simplicial::EliminationTreeRef<'_, I>,
1549        col_counts: &[I],
1550        stack: PodStack<'_>,
1551        params: SymbolicSupernodalParams<'_>,
1552    ) -> Result<SymbolicSupernodalCholesky<I>, FaerError> {
1553        let n = A.nrows();
1554        assert!(A.nrows() == A.ncols());
1555        assert!(etree.into_inner().len() == n);
1556        assert!(col_counts.len() == n);
1557        ghost::with_size(n, |N| {
1558            ghost_factorize_supernodal_symbolic(
1559                ghost::SymbolicSparseColMatRef::new(A, N, N),
1560                None,
1561                None,
1562                CholeskyInput::A,
1563                etree.ghost_inner(N),
1564                Array::from_ref(col_counts, N),
1565                stack,
1566                params,
1567            )
1568        })
1569    }
1570
    /// Selects which kind of input the symbolic supernodal factorization is given.
    pub(crate) enum CholeskyInput {
        // passed by `factorize_supernodal_symbolic`, where the matrix itself is factorized
        A,
        // presumably the case where the pattern of `A^T A` is factorized - TODO confirm
        ATA,
    }
1575
1576    pub(crate) fn ghost_factorize_supernodal_symbolic<'m, 'n, I: Index>(
1577        A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>,
1578        col_perm: Option<ghost::PermutationRef<'n, '_, I, Symbolic>>,
1579        min_col: Option<&Array<'m, MaybeIdx<'n, I>>>,
1580        input: CholeskyInput,
1581        etree: &Array<'n, MaybeIdx<'n, I>>,
1582        col_counts: &Array<'n, I>,
1583        stack: PodStack<'_>,
1584        params: SymbolicSupernodalParams<'_>,
1585    ) -> Result<SymbolicSupernodalCholesky<I>, FaerError> {
        // Builds the symbolic structure of a supernodal Cholesky factorization.
        //
        // Steps performed below:
        // 1. group consecutive columns into fundamental supernodes using `etree` and `col_counts`,
        // 2. if `params.relax` is set, merge child supernodes into their parent when one of the
        //    relaxation cutoffs allows it,
        // 3. compute the supernode boundaries and the per-supernode pointers into the row indices
        //    and values, then fill the row indices,
        // 4. postorder the supernodal elimination tree and count the descendants of each
        //    supernode.
        //
        // `col_perm` and `min_col` are only read when `input` is not `CholeskyInput::A`; `min_col`
        // is unwrapped in that case, so it must be `Some`.
        // `stack` provides four scratch arrays of `n` entries and is later handed to
        // `ghost_postorder`.
        //
        // Errors returned by `try_zeroed`, `try_collect` and the scratch buffer allocation are
        // propagated; `FaerError::IndexOverflow` is returned when the total number of stored
        // values exceeds `I::Signed::MAX`.

        // Sizes that may not fit in `I` are accumulated in `u128`.
1586        let to_wide = |i: I| i.zx() as u128;
1587        let from_wide = |i: u128| I::truncate(i as usize);
1588        let from_wide_checked = |i: u128| -> Option<I> {
1589            (i <= to_wide(I::from_signed(I::Signed::MAX))).then_some(I::truncate(i as usize))
1590        };
1591
1592        let N = A.ncols();
1593        let n = *N;
1594
1595        let zero = I::truncate(0);
1596        let one = I::truncate(1);
        // Sentinel stored in `super_etree` for supernodes without a parent.
1597        let none = I::Signed::truncate(NONE);
1598
        // Empty matrix: no supernodes, and every pointer array holds the single entry 0.
1599        if n == 0 {
1600            // would be funny if this allocation failed
1601            return Ok(SymbolicSupernodalCholesky {
1602                dimension: n,
1603                supernode_postorder: alloc::vec::Vec::new(),
1604                supernode_postorder_inv: alloc::vec::Vec::new(),
1605                descendant_count: alloc::vec::Vec::new(),
1606
1607                supernode_begin: try_collect([zero])?,
1608                col_ptrs_for_row_indices: try_collect([zero])?,
1609                col_ptrs_for_values: try_collect([zero])?,
1610                row_indices: alloc::vec::Vec::new(),
1611            });
1612        }
1613        let mut original_stack = stack;
1614
        // Four scratch arrays of `n` entries. The stack is reborrowed so that `original_stack`
        // can be passed to `ghost_postorder` at the end. The `child_count` storage is reused
        // further down as `child_lists` and as `visited`.
1615        let (index_to_super__, stack) = original_stack.rb_mut().make_raw::<I>(n);
1616        let (super_etree__, stack) = stack.make_raw::<I::Signed>(n);
1617        let (supernode_sizes__, stack) = stack.make_raw::<I>(n);
1618        let (child_count__, _) = stack.make_raw::<I>(n);
1619
1620        let child_count = Array::from_mut(child_count__, N);
1621        let index_to_super = Array::from_mut(index_to_super__, N);
1622
        // Count the children of every node of the elimination tree.
1623        mem::fill_zero(child_count.as_mut());
1624        for j in N.indices() {
1625            if let Some(parent) = etree[j].idx() {
1626                child_count[parent.zx()] += one;
1627            }
1628        }
1629
        // Fundamental supernodes: column `j` extends the supernode of column `j - 1` only if `j`
        // is the parent of `j - 1`, `j - 1` is its only child, and the column count of `j - 1` is
        // exactly one more than the column count of `j`. Otherwise `j` starts a new supernode.
1630        mem::fill_zero(supernode_sizes__);
1631        let mut current_supernode = 0usize;
1632        supernode_sizes__[0] = one;
1633        for (j_prev, j) in zip(N.indices().take(n - 1), N.indices().skip(1)) {
1634            let is_parent_of_prev = (*etree[j_prev]).sx() == *j;
1635            let is_parent_of_only_prev = child_count[j] == one;
1636            let same_pattern_as_prev = col_counts[j_prev] == col_counts[j] + one;
1637
1638            if !(is_parent_of_prev && is_parent_of_only_prev && same_pattern_as_prev) {
1639                current_supernode += 1;
1640            }
1641            supernode_sizes__[current_supernode] += one;
1642        }
1643        let n_fundamental_supernodes = current_supernode + 1;
1644
1645        // last n elements contain supernode degrees
        // More precisely: the returned vector has one entry more than there are supernodes; entry
        // 0 is zero and entry `s + 1` holds the degree of supernode `s`. It is converted into the
        // actual supernode boundaries in the second `with_size` block below.
1646        let supernode_begin__ = ghost::with_size(
1647            n_fundamental_supernodes,
1648            |N_FUNDAMENTAL_SUPERNODES| -> Result<alloc::vec::Vec<I>, FaerError> {
1649                let supernode_sizes = Array::from_mut(
1650                    &mut supernode_sizes__[..n_fundamental_supernodes],
1651                    N_FUNDAMENTAL_SUPERNODES,
1652                );
1653                let super_etree = Array::from_mut(
1654                    &mut super_etree__[..n_fundamental_supernodes],
1655                    N_FUNDAMENTAL_SUPERNODES,
1656                );
1657
                // Map every column to the fundamental supernode that contains it.
1658                let mut supernode_begin = 0usize;
1659                for s in N_FUNDAMENTAL_SUPERNODES.indices() {
1660                    let size = supernode_sizes[s].zx();
1661                    index_to_super.as_mut()[supernode_begin..][..size].fill(*s.truncate::<I>());
1662                    supernode_begin += size;
1663                }
1664
1665                let index_to_super = Array::from_mut(
1666                    Idx::from_slice_mut_checked(index_to_super.as_mut(), N_FUNDAMENTAL_SUPERNODES),
1667                    N,
1668                );
1669
                // Supernodal elimination tree: the parent of supernode `s` is the supernode that
                // contains the etree parent of the last column of `s`, or `none` if that column
                // has no parent.
1670                let mut supernode_begin = 0usize;
1671                for s in N_FUNDAMENTAL_SUPERNODES.indices() {
1672                    let size = supernode_sizes[s].zx();
1673                    let last = supernode_begin + size - 1;
1674                    let last = N.check(last);
1675                    if let Some(parent) = etree[last].idx() {
1676                        super_etree[s] = index_to_super[parent.zx()].to_signed();
1677                    } else {
1678                        super_etree[s] = none;
1679                    }
1680                    supernode_begin += size;
1681                }
1682
1683                let super_etree = Array::from_mut(
1684                    MaybeIdx::<'_, I>::from_slice_mut_checked(
1685                        super_etree.as_mut(),
1686                        N_FUNDAMENTAL_SUPERNODES,
1687                    ),
1688                    N_FUNDAMENTAL_SUPERNODES,
1689                );
1690
                // Optional relaxation: merge supernodes at the cost of storing explicit zeros.
1691                if let Some(relax) = params.relax {
                    // The relaxation workspace (five arrays of `n_fundamental_supernodes`
                    // entries) is allocated separately rather than taken from the caller's stack.
1692                    let req = || -> Result<StackReq, SizeOverflow> {
1693                        let req = StackReq::try_new::<I>(n_fundamental_supernodes)?;
1694                        StackReq::try_all_of([req; 5])
1695                    };
1696                    let mut mem = dyn_stack::GlobalPodBuffer::try_new(req().map_err(nomem)?)
1697                        .map_err(nomem)?;
1698                    let stack = PodStack::new(&mut mem);
1699
                    // `child_lists` reuses the storage of `child_count`, which is no longer
                    // needed in its original role.
1700                    let child_lists = bytemuck::cast_slice_mut(
1701                        &mut child_count.as_mut()[..n_fundamental_supernodes],
1702                    );
1703                    let (child_list_heads, stack) =
1704                        stack.make_raw::<I::Signed>(n_fundamental_supernodes);
1705                    let (last_merged_children, stack) =
1706                        stack.make_raw::<I::Signed>(n_fundamental_supernodes);
1707                    let (merge_parents, stack) =
1708                        stack.make_raw::<I::Signed>(n_fundamental_supernodes);
1709                    let (fundamental_supernode_degrees, stack) =
1710                        stack.make_raw::<I>(n_fundamental_supernodes);
1711                    let (num_zeros, _) = stack.make_raw::<I>(n_fundamental_supernodes);
1712
1713                    let child_lists = Array::from_mut(
1714                        ghost::fill_none::<I>(child_lists, N_FUNDAMENTAL_SUPERNODES),
1715                        N_FUNDAMENTAL_SUPERNODES,
1716                    );
1717                    let child_list_heads = Array::from_mut(
1718                        ghost::fill_none::<I>(child_list_heads, N_FUNDAMENTAL_SUPERNODES),
1719                        N_FUNDAMENTAL_SUPERNODES,
1720                    );
1721                    let last_merged_children = Array::from_mut(
1722                        ghost::fill_none::<I>(last_merged_children, N_FUNDAMENTAL_SUPERNODES),
1723                        N_FUNDAMENTAL_SUPERNODES,
1724                    );
1725                    let merge_parents = Array::from_mut(
1726                        ghost::fill_none::<I>(merge_parents, N_FUNDAMENTAL_SUPERNODES),
1727                        N_FUNDAMENTAL_SUPERNODES,
1728                    );
1729                    let fundamental_supernode_degrees =
1730                        Array::from_mut(fundamental_supernode_degrees, N_FUNDAMENTAL_SUPERNODES);
1731                    let num_zeros = Array::from_mut(num_zeros, N_FUNDAMENTAL_SUPERNODES);
1732
                    // Degree of a supernode: column count of its last column minus one.
1733                    let mut supernode_begin = 0usize;
1734                    for s in N_FUNDAMENTAL_SUPERNODES.indices() {
1735                        let size = supernode_sizes[s].zx();
1736                        fundamental_supernode_degrees[s] =
1737                            col_counts[N.check(supernode_begin + size - 1)] - one;
1738                        supernode_begin += size;
1739                    }
1740
                    // Singly linked child lists: `child_list_heads[p]` is the first child of `p`
                    // and `child_lists[c]` is the next sibling of `c`.
1741                    for s in N_FUNDAMENTAL_SUPERNODES.indices() {
1742                        if let Some(parent) = super_etree[s].idx() {
1743                            let parent = parent.zx();
1744                            child_lists[s] = child_list_heads[parent];
1745                            child_list_heads[parent] = MaybeIdx::from_index(s.truncate());
1746                        }
1747                    }
1748
                    // For every parent, keep merging the best eligible child until no child
                    // qualifies any more.
1749                    mem::fill_zero(num_zeros.as_mut());
1750                    for parent in N_FUNDAMENTAL_SUPERNODES.indices() {
1751                        loop {
1752                            let mut merging_child = MaybeIdx::none();
1753                            let mut num_new_zeros = 0usize;
1754                            let mut num_merged_zeros = 0usize;
1755                            let mut largest_mergable_size = 0usize;
1756
1757                            let mut child_ = child_list_heads[parent];
1758                            while let Some(child) = child_.idx() {
1759                                let child = child.zx();
                                // Only the child numbered immediately before the parent is a
                                // candidate.
1760                                if *child + 1 != *parent {
1761                                    child_ = child_lists[child];
1762                                    continue;
1763                                }
1764
                                // Skip children that have already been merged.
1765                                if merge_parents[child].idx().is_some() {
1766                                    child_ = child_lists[child];
1767                                    continue;
1768                                }
1769
1770                                let parent_size = supernode_sizes[parent].zx();
1771                                let child_size = supernode_sizes[child].zx();
                                // Larger children are preferred over smaller ones.
1772                                if child_size < largest_mergable_size {
1773                                    child_ = child_lists[child];
1774                                    continue;
1775                                }
1776
1777                                let parent_degree = fundamental_supernode_degrees[parent].zx();
1778                                let child_degree = fundamental_supernode_degrees[child].zx();
1779
1780                                let num_parent_zeros = num_zeros[parent].zx();
1781                                let num_child_zeros = num_zeros[child].zx();
1782
                                // Total number of explicit zeros after the merge, or `NONE` when
                                // no relaxation cutoff admits it.
1783                                let status_num_merged_zeros = {
                                    // Presumably the zeros that the child's columns gain when
                                    // padded to the parent's row structure.
1784                                    let num_new_zeros =
1785                                        (parent_size + parent_degree - child_degree) * child_size;
1786
1787                                    if num_new_zeros == 0 {
                                        // A merge that adds no zeros is always accepted.
1788                                        num_parent_zeros + num_child_zeros
1789                                    } else {
1790                                        let num_old_zeros = num_child_zeros + num_parent_zeros;
1791                                        let num_zeros = num_new_zeros + num_old_zeros;
1792
                                        // Entries of the merged supernode: lower triangle of the
                                        // combined diagonal block plus `parent_degree` rows
                                        // below it.
1793                                        let combined_size = child_size + parent_size;
1794                                        let num_expanded_entries =
1795                                            (combined_size * (combined_size + 1)) / 2
1796                                                + parent_degree * combined_size;
1797
                                        // A cutoff `(max_size, max_zero_fraction)` admits the
                                        // merge when the combined size is at most `max_size`
                                        // and the zeros are at most that fraction of the
                                        // entries. The first matching cutoff wins.
1798                                        let f = || {
1799                                            for cutoff in relax {
1800                                                let num_zeros_cutoff =
1801                                                    num_expanded_entries as f64 * cutoff.1;
1802                                                if cutoff.0 >= combined_size
1803                                                    && num_zeros_cutoff >= num_zeros as f64
1804                                                {
1805                                                    return num_zeros;
1806                                                }
1807                                            }
1808                                            NONE
1809                                        };
1810                                        f()
1811                                    }
1812                                };
1813                                if status_num_merged_zeros == NONE {
1814                                    child_ = child_lists[child];
1815                                    continue;
1816                                }
1817
                                // Keep this candidate if it is strictly larger than the best
                                // one so far, or if it introduces fewer new zeros.
1818                                let num_proposed_new_zeros =
1819                                    status_num_merged_zeros - (num_child_zeros + num_parent_zeros);
1820                                if child_size > largest_mergable_size
1821                                    || num_proposed_new_zeros < num_new_zeros
1822                                {
1823                                    merging_child = MaybeIdx::from_index(child);
1824                                    num_new_zeros = num_proposed_new_zeros;
1825                                    num_merged_zeros = status_num_merged_zeros;
1826                                    largest_mergable_size = child_size;
1827                                }
1828
1829                                child_ = child_lists[child];
1830                            }
1831
1832                            if let Some(merging_child) = merging_child.idx() {
                                // Absorb the child: its columns are added to the parent and its
                                // own size drops to zero, which marks it as removed for the
                                // compaction below.
1833                                supernode_sizes[parent] =
1834                                    supernode_sizes[parent] + supernode_sizes[merging_child];
1835                                supernode_sizes[merging_child] = zero;
1836                                num_zeros[parent] = I::truncate(num_merged_zeros);
1837
                                // The merged child points at the parent's previously merged
                                // child if there is one, otherwise at the parent itself.
1838                                merge_parents[merging_child] =
1839                                    if let Some(child) = last_merged_children[parent].idx() {
1840                                        MaybeIdx::from_index(child)
1841                                    } else {
1842                                        MaybeIdx::from_index(parent.truncate())
1843                                    };
1844
                                // The parent inherits the child's last merged child, or records
                                // the child itself when the child had none.
1845                                last_merged_children[parent] = if let Some(child) =
1846                                    last_merged_children[merging_child].idx()
1847                                {
1848                                    MaybeIdx::from_index(child)
1849                                } else {
1850                                    MaybeIdx::from_index(merging_child.truncate())
1851                                };
1852                            } else {
1853                                break;
1854                            }
1855                        }
1856                    }
1857
                    // `last_merged_children` is no longer needed; its storage now records the
                    // new index of every surviving supernode.
1858                    let original_to_relaxed = last_merged_children;
1859                    original_to_relaxed.as_mut().fill(MaybeIdx::none());
1860
                    // Compact the sizes and degrees of the surviving (non-empty) supernodes to
                    // the front of their arrays.
1861                    let mut pos = 0usize;
1862                    for s in N_FUNDAMENTAL_SUPERNODES.indices() {
1863                        let idx = N_FUNDAMENTAL_SUPERNODES.check(pos);
1864                        let size = supernode_sizes[s];
1865                        let degree = fundamental_supernode_degrees[s];
1866                        if size > zero {
1867                            supernode_sizes[idx] = size;
1868                            fundamental_supernode_degrees[idx] = degree;
1869                            original_to_relaxed[s] = MaybeIdx::from_index(idx.truncate());
1870
1871                            pos += 1;
1872                        }
1873                    }
1874                    let n_relaxed_supernodes = pos;
1875
                    // Entry 0 stays zero; entries `1..` carry the degrees of the relaxed
                    // supernodes.
1876                    let mut supernode_begin__ = try_zeroed(n_relaxed_supernodes + 1)?;
1877                    supernode_begin__[1..].copy_from_slice(
1878                        &fundamental_supernode_degrees.as_ref()[..n_relaxed_supernodes],
1879                    );
1880
1881                    Ok(supernode_begin__)
1882                } else {
                    // No relaxation: store the degrees of the fundamental supernodes directly.
1883                    let mut supernode_begin__ = try_zeroed(n_fundamental_supernodes + 1)?;
1884
1885                    let mut supernode_begin = 0usize;
1886                    for s in N_FUNDAMENTAL_SUPERNODES.indices() {
1887                        let size = supernode_sizes[s].zx();
1888                        supernode_begin__[*s + 1] =
1889                            col_counts[N.check(supernode_begin + size - 1)] - one;
1890                        supernode_begin += size;
1891                    }
1892
1893                    Ok(supernode_begin__)
1894                }
1895            },
1896        )?;
1897
1898        let n_supernodes = supernode_begin__.len() - 1;
1899
1900        let (supernode_begin__, col_ptrs_for_row_indices__, col_ptrs_for_values__, row_indices__) =
1901            ghost::with_size(
1902                n_supernodes,
1903                |N_SUPERNODES| -> Result<
1904                    (
1905                        alloc::vec::Vec<I>,
1906                        alloc::vec::Vec<I>,
1907                        alloc::vec::Vec<I>,
1908                        alloc::vec::Vec<I>,
1909                    ),
1910                    FaerError,
1911                > {
1912                    let supernode_sizes =
1913                        Array::from_mut(&mut supernode_sizes__[..n_supernodes], N_SUPERNODES);
1914
                    // If relaxation merged supernodes, the column-to-supernode map and the
                    // supernodal elimination tree are rebuilt for the new numbering.
1915                    if n_supernodes != n_fundamental_supernodes {
1916                        let mut supernode_begin = 0usize;
1917                        for s in N_SUPERNODES.indices() {
1918                            let size = supernode_sizes[s].zx();
1919                            index_to_super.as_mut()[supernode_begin..][..size]
1920                                .fill(*s.truncate::<I>());
1921                            supernode_begin += size;
1922                        }
1923
1924                        let index_to_super = Array::from_mut(
1925                            Idx::<'_, I>::from_slice_mut_checked(
1926                                index_to_super.as_mut(),
1927                                N_SUPERNODES,
1928                            ),
1929                            N,
1930                        );
1931                        let super_etree =
1932                            Array::from_mut(&mut super_etree__[..n_supernodes], N_SUPERNODES);
1933
1934                        let mut supernode_begin = 0usize;
1935                        for s in N_SUPERNODES.indices() {
1936                            let size = supernode_sizes[s].zx();
1937                            let last = supernode_begin + size - 1;
1938                            if let Some(parent) = etree[N.check(last)].idx() {
1939                                super_etree[s] = index_to_super[parent.zx()].to_signed();
1940                            } else {
1941                                super_etree[s] = none;
1942                            }
1943                            supernode_begin += size;
1944                        }
1945                    }
1946
1947                    let index_to_super = Array::from_mut(
1948                        Idx::from_slice_mut_checked(index_to_super.as_mut(), N_SUPERNODES),
1949                        N,
1950                    );
1951
1952                    let mut supernode_begin__ = supernode_begin__;
1953                    let mut col_ptrs_for_row_indices__ = try_zeroed::<I>(n_supernodes + 1)?;
1954                    let mut col_ptrs_for_values__ = try_zeroed::<I>(n_supernodes + 1)?;
1955
1956                    let mut row_ptr = zero;
1957                    let mut val_ptr = zero;
1958
1959                    supernode_begin__[0] = zero;
1960
1961                    let mut row_indices__ = {
1962                        let mut wide_val_count = 0u128;
                        // Walk over adjacent pairs of `supernode_begin__`. On entry `next` holds
                        // the degree of supernode `s`; it is overwritten with the end column of
                        // `s`, turning the array into supernode boundaries. `Cell`s allow the
                        // overlapping windows to be written through.
1963                        for (s, [current, next]) in zip(
1964                            N_SUPERNODES.indices(),
1965                            windows2(Cell::as_slice_of_cells(Cell::from_mut(
1966                                &mut *supernode_begin__,
1967                            ))),
1968                        ) {
1969                            let degree = next.get();
1970                            let ncols = supernode_sizes[s];
1971                            let nrows = degree + ncols;
                            // `supernode_sizes` is repurposed: from here on it stores the
                            // position at which the next row index of `s` will be written.
1972                            supernode_sizes[s] = row_ptr;
1973                            next.set(current.get() + ncols);
1974
1975                            col_ptrs_for_row_indices__[*s] = row_ptr;
1976                            col_ptrs_for_values__[*s] = val_ptr;
1977
                            // Each supernode stores `nrows * ncols` values and `degree` row
                            // indices.
1978                            let wide_matrix_size = to_wide(nrows) * to_wide(ncols);
1979                            wide_val_count += wide_matrix_size;
1980
1981                            row_ptr += degree;
1982                            val_ptr = from_wide(to_wide(val_ptr) + wide_matrix_size);
1983                        }
1984                        col_ptrs_for_row_indices__[n_supernodes] = row_ptr;
1985                        col_ptrs_for_values__[n_supernodes] = val_ptr;
                        // `val_ptr` may have been truncated above; the exact `u128` total is
                        // validated here and the pointers are discarded on failure.
1986                        from_wide_checked(wide_val_count).ok_or(FaerError::IndexOverflow)?;
1987
1988                        try_zeroed::<I>(row_ptr.zx())?
1989                    };
1990
1991                    let super_etree = Array::from_ref(
1992                        MaybeIdx::from_slice_ref_checked(
1993                            &super_etree__[..n_supernodes],
1994                            N_SUPERNODES,
1995                        ),
1996                        N_SUPERNODES,
1997                    );
1998
1999                    let current_row_positions = supernode_sizes;
2000
2001                    let row_indices = Idx::from_slice_mut_checked(&mut row_indices__, N);
                    // `visited` reuses the storage that held `child_count` earlier.
2002                    let visited = Array::from_mut(
2003                        bytemuck::cast_slice_mut(&mut child_count.as_mut()[..n_supernodes]),
2004                        N_SUPERNODES,
2005                    );
2006
                    // Fill the row indices by visiting the columns of every supernode in order.
                    // The `ereach_super*` helpers receive the write positions and the `visited`
                    // markers (presumably they append row `k` to the supernodes reached from
                    // column `k` — see their definitions).
2007                    mem::fill_none::<I::Signed>(visited.as_mut());
2008                    if matches!(input, CholeskyInput::A) {
                        // In this mode the matrix is re-wrapped with `N` as both dimensions.
2009                        let A = ghost::SymbolicSparseColMatRef::new(A.into_inner(), N, N);
2010                        for s in N_SUPERNODES.indices() {
2011                            let k1 =
2012                                ghost::IdxInclusive::new_checked(supernode_begin__[*s].zx(), N);
2013                            let k2 =
2014                                ghost::IdxInclusive::new_checked(supernode_begin__[*s + 1].zx(), N);
2015
2016                            for k in k1.range_to(k2) {
2017                                ereach_super(
2018                                    A,
2019                                    super_etree,
2020                                    index_to_super,
2021                                    current_row_positions,
2022                                    row_indices,
2023                                    k,
2024                                    visited,
2025                                );
2026                            }
2027                        }
2028                    } else {
                        // Panics if the caller did not provide `min_col` in this mode.
2029                        let min_col = min_col.unwrap();
2030                        for s in N_SUPERNODES.indices() {
2031                            let k1 =
2032                                ghost::IdxInclusive::new_checked(supernode_begin__[*s].zx(), N);
2033                            let k2 =
2034                                ghost::IdxInclusive::new_checked(supernode_begin__[*s + 1].zx(), N);
2035
2036                            for k in k1.range_to(k2) {
2037                                ereach_super_ata(
2038                                    A,
2039                                    col_perm,
2040                                    min_col,
2041                                    super_etree,
2042                                    index_to_super,
2043                                    current_row_positions,
2044                                    row_indices,
2045                                    k,
2046                                    visited,
2047                                );
2048                            }
2049                        }
2050                    }
2051
                    // Sanity check: every supernode's write position must have advanced exactly
                    // to the end of its row-index range.
2052                    debug_assert!(
2053                        current_row_positions.as_ref() == &col_ptrs_for_row_indices__[1..]
2054                    );
2055
2056                    Ok((
2057                        supernode_begin__,
2058                        col_ptrs_for_row_indices__,
2059                        col_ptrs_for_values__,
2060                        row_indices__,
2061                    ))
2062                },
2063            )?;
2064
        // Owned copy of the supernodal elimination tree. After the postorder has been computed
        // the same vector is overwritten with the inverse postorder and returned as such.
2065        let mut supernode_etree__: alloc::vec::Vec<I> = try_collect(
2066            bytemuck::cast_slice(&super_etree__[..n_supernodes])
2067                .iter()
2068                .copied(),
2069        )?;
2070        let mut supernode_postorder__ = try_zeroed::<I>(n_supernodes)?;
2071
2072        let mut descendent_count__ = try_zeroed::<I>(n_supernodes)?;
2073
2074        ghost::with_size(n_supernodes, |N_SUPERNODES| {
2075            let post = Array::from_mut(&mut supernode_postorder__, N_SUPERNODES);
2076            let desc_count = Array::from_mut(&mut descendent_count__, N_SUPERNODES);
2077            let etree: &Array<'_, MaybeIdx<'_, I>> = Array::from_ref(
2078                MaybeIdx::from_slice_ref_checked(
2079                    bytemuck::cast_slice(&supernode_etree__),
2080                    N_SUPERNODES,
2081                ),
2082                N_SUPERNODES,
2083            );
2084
            // Accumulate descendant counts in increasing index order; this presumably relies on
            // every parent having a larger index than its children, so that a child's count is
            // final before it is added to the parent.
2085            for s in N_SUPERNODES.indices() {
2086                if let Some(parent) = etree[s].idx() {
2087                    let parent = parent.zx();
2088                    desc_count[parent] = desc_count[parent] + desc_count[s] + one;
2089                }
2090            }
2091
2092            ghost_postorder(post, etree, original_stack);
            // Invert the postorder in place of the (no longer needed) elimination tree copy.
2093            let post_inv = Array::from_mut(
2094                bytemuck::cast_slice_mut(&mut supernode_etree__),
2095                N_SUPERNODES,
2096            );
2097            for i in N_SUPERNODES.indices() {
2098                post_inv[N_SUPERNODES.check(post[i].zx())] = I::truncate(*i);
2099            }
2100        });
2101
2102        Ok(SymbolicSupernodalCholesky {
2103            dimension: n,
2104            supernode_postorder: supernode_postorder__,
2105            supernode_postorder_inv: supernode_etree__,
2106            descendant_count: descendent_count__,
2107            supernode_begin: supernode_begin__,
2108            col_ptrs_for_row_indices: col_ptrs_for_row_indices__,
2109            col_ptrs_for_values: col_ptrs_for_values__,
2110            row_indices: row_indices__,
2111        })
2112    }
2113
2114    #[inline]
2115    pub(crate) fn partition_fn<I: Index>(idx: usize) -> impl Fn(&I) -> bool {
2116        let idx = I::truncate(idx);
2117        move |&i| i < idx
2118    }
2119
2120    pub fn factorize_supernodal_numeric_llt_req<I: Index, E: Entity>(
2121        symbolic: &SymbolicSupernodalCholesky<I>,
2122        parallelism: Parallelism,
2123    ) -> Result<StackReq, SizeOverflow> {
2124        let n_supernodes = symbolic.n_supernodes();
2125        let n = symbolic.nrows();
2126        let post = &*symbolic.supernode_postorder;
2127        let post_inv = &*symbolic.supernode_postorder_inv;
2128
2129        let desc_count = &*symbolic.descendant_count;
2130
2131        let col_ptr_row = &*symbolic.col_ptrs_for_row_indices;
2132        let row_ind = &*symbolic.row_indices;
2133
2134        let mut req = StackReq::empty();
2135        for s in 0..n_supernodes {
2136            let s_start = symbolic.supernode_begin[s].zx();
2137            let s_end = symbolic.supernode_begin[s + 1].zx();
2138
2139            let s_ncols = s_end - s_start;
2140
2141            let s_postordered = post_inv[s].zx();
2142            let desc_count = desc_count[s].zx();
2143            for d in &post[s_postordered - desc_count..s_postordered] {
2144                let mut d_req = StackReq::empty();
2145                let d = d.zx();
2146
2147                let d_pattern = &row_ind[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2148                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2149                let d_pattern_mid_len =
2150                    d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2151
2152                d_req = d_req.try_and(temp_mat_req::<E>(
2153                    d_pattern.len() - d_pattern_start,
2154                    d_pattern_mid_len,
2155                )?)?;
2156                req = req.try_or(d_req)?;
2157            }
2158            req = req.try_or(
2159                faer_cholesky::ldlt_diagonal::compute::raw_cholesky_in_place_req::<E>(
2160                    s_ncols,
2161                    parallelism,
2162                    Default::default(),
2163                )?,
2164            )?;
2165        }
2166        req.try_and(StackReq::try_new::<I>(n)?)
2167    }
2168
2169    pub fn factorize_supernodal_numeric_ldlt_req<I: Index, E: Entity>(
2170        symbolic: &SymbolicSupernodalCholesky<I>,
2171        parallelism: Parallelism,
2172    ) -> Result<StackReq, SizeOverflow> {
2173        let n_supernodes = symbolic.n_supernodes();
2174        let n = symbolic.nrows();
2175        let post = &*symbolic.supernode_postorder;
2176        let post_inv = &*symbolic.supernode_postorder_inv;
2177
2178        let desc_count = &*symbolic.descendant_count;
2179
2180        let col_ptr_row = &*symbolic.col_ptrs_for_row_indices;
2181        let row_ind = &*symbolic.row_indices;
2182
2183        let mut req = StackReq::empty();
2184        for s in 0..n_supernodes {
2185            let s_start = symbolic.supernode_begin[s].zx();
2186            let s_end = symbolic.supernode_begin[s + 1].zx();
2187
2188            let s_ncols = s_end - s_start;
2189
2190            let s_postordered = post_inv[s].zx();
2191            let desc_count = desc_count[s].zx();
2192            for d in &post[s_postordered - desc_count..s_postordered] {
2193                let mut d_req = StackReq::empty();
2194
2195                let d = d.zx();
2196                let d_start = symbolic.supernode_begin[d].zx();
2197                let d_end = symbolic.supernode_begin[d + 1].zx();
2198
2199                let d_pattern = &row_ind[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2200
2201                let d_ncols = d_end - d_start;
2202
2203                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2204                let d_pattern_mid_len =
2205                    d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2206
2207                d_req = d_req.try_and(temp_mat_req::<E>(
2208                    d_pattern.len() - d_pattern_start,
2209                    d_pattern_mid_len,
2210                )?)?;
2211                d_req = d_req.try_and(temp_mat_req::<E>(d_ncols, d_pattern_mid_len)?)?;
2212                req = req.try_or(d_req)?;
2213            }
2214            req = req.try_or(
2215                faer_cholesky::ldlt_diagonal::compute::raw_cholesky_in_place_req::<E>(
2216                    s_ncols,
2217                    parallelism,
2218                    Default::default(),
2219                )?,
2220            )?;
2221        }
2222        req.try_and(StackReq::try_new::<I>(n)?)
2223    }
2224
2225    pub fn factorize_supernodal_numeric_intranode_bunch_kaufman_req<I: Index, E: Entity>(
2226        symbolic: &SymbolicSupernodalCholesky<I>,
2227        parallelism: Parallelism,
2228    ) -> Result<StackReq, SizeOverflow> {
2229        let n_supernodes = symbolic.n_supernodes();
2230        let n = symbolic.nrows();
2231        let post = &*symbolic.supernode_postorder;
2232        let post_inv = &*symbolic.supernode_postorder_inv;
2233
2234        let desc_count = &*symbolic.descendant_count;
2235
2236        let col_ptr_row = &*symbolic.col_ptrs_for_row_indices;
2237        let row_ind = &*symbolic.row_indices;
2238
2239        let mut req = StackReq::empty();
2240        for s in 0..n_supernodes {
2241            let s_start = symbolic.supernode_begin[s].zx();
2242            let s_end = symbolic.supernode_begin[s + 1].zx();
2243
2244            let s_ncols = s_end - s_start;
2245            let s_pattern = &row_ind[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
2246
2247            let s_postordered = post_inv[s].zx();
2248            let desc_count = desc_count[s].zx();
2249            for d in &post[s_postordered - desc_count..s_postordered] {
2250                let mut d_req = StackReq::empty();
2251
2252                let d = d.zx();
2253                let d_start = symbolic.supernode_begin[d].zx();
2254                let d_end = symbolic.supernode_begin[d + 1].zx();
2255
2256                let d_pattern = &row_ind[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2257
2258                let d_ncols = d_end - d_start;
2259
2260                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2261                let d_pattern_mid_len =
2262                    d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2263
2264                d_req = d_req.try_and(temp_mat_req::<E>(
2265                    d_pattern.len() - d_pattern_start,
2266                    d_pattern_mid_len,
2267                )?)?;
2268                d_req = d_req.try_and(temp_mat_req::<E>(d_ncols, d_pattern_mid_len)?)?;
2269                req = req.try_or(d_req)?;
2270            }
2271            req = StackReq::try_any_of([
2272                req,
2273                faer_cholesky::bunch_kaufman::compute::cholesky_in_place_req::<I, E>(
2274                    s_ncols,
2275                    parallelism,
2276                    Default::default(),
2277                )?,
2278                faer_core::permutation::permute_cols_in_place_req::<I, E>(
2279                    s_pattern.len(),
2280                    s_ncols,
2281                )?,
2282            ])?;
2283        }
2284        req.try_and(StackReq::try_new::<I>(n)?)
2285    }
2286
    /// Computes the numeric values of the supernodal Cholesky LLT factor of `A_lower`, using the
    /// structure described by `symbolic`, and writes them into `L_values`.
    ///
    /// Supernodes are processed in increasing index order. For each supernode `s`:
    /// 1. the entries of the columns of `A_lower` belonging to `s` are accumulated into the dense
    ///    column-major block `Ls` backing the supernode,
    /// 2. the contribution of each descendant supernode is subtracted from `Ls`,
    /// 3. the diagonal block of `Ls` is factorized with the dense LLT routine, and the block
    ///    below it is obtained with a triangular solve.
    ///
    /// # Parameters
    /// - `L_values`: output storage; it is zeroed before use.
    /// - `A_lower`: input matrix. NOTE(review): entries are presumably expected to lie in the
    ///   lower triangle; a row index smaller than the start of the supernode owning the column
    ///   would underflow in `i - s_start`.
    /// - `regularization`: forwarded unchanged to the dense LLT factorization of each diagonal
    ///   block.
    /// - `symbolic`: symbolic supernodal structure.
    /// - `parallelism`: forwarded to the dense kernels.
    /// - `stack`: scratch space. NOTE(review): presumably must satisfy the matching `*_req`
    ///   function of this module; confirm against the caller.
    ///
    /// # Returns
    /// The sum of the dynamic regularization counts reported by the dense factorizations.
    ///
    /// # Errors
    /// Returns a [`CholeskyError`] as soon as a diagonal block fails to factorize. The reported
    /// minor index is offset by the first column of the failing supernode.
    ///
    /// # Panics
    /// Panics if `A_lower` is not `n × n` with `n = symbolic.nrows()`, or if the length of
    /// `L_values` differs from `symbolic.len_values()`.
    pub fn factorize_supernodal_numeric_llt<I: Index, E: ComplexField>(
        L_values: GroupFor<E, &mut [E::Unit]>,
        A_lower: SparseColMatRef<'_, I, E>,
        regularization: LltRegularization<E>,
        symbolic: &SymbolicSupernodalCholesky<I>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) -> Result<usize, CholeskyError> {
        let n_supernodes = symbolic.n_supernodes();
        let n = symbolic.nrows();
        let mut dynamic_regularization_count = 0usize;
        let mut L_values = SliceGroupMut::<'_, E>::new(L_values);
        // the entries of `A_lower` are accumulated with additions, so start from zero
        L_values.fill_zero();

        assert!(A_lower.nrows() == n);
        assert!(A_lower.ncols() == n);
        assert!(L_values.len() == symbolic.len_values());
        let slice_group = SliceGroup::<'_, E>::new;

        let none = I::Signed::truncate(NONE);

        let post = &*symbolic.supernode_postorder;
        let post_inv = &*symbolic.supernode_postorder_inv;

        let desc_count = &*symbolic.descendant_count;

        let col_ptr_row = &*symbolic.col_ptrs_for_row_indices;
        let col_ptr_val = &*symbolic.col_ptrs_for_values;
        let row_ind = &*symbolic.row_indices;

        // mapping from global indices to local
        let (global_to_local, mut stack) = stack.make_raw::<I::Signed>(n);
        mem::fill_none(global_to_local.as_mut());

        for s in 0..n_supernodes {
            // supernode `s` owns the columns `s_start..s_end`
            let s_start = symbolic.supernode_begin[s].zx();
            let s_end = symbolic.supernode_begin[s + 1].zx();

            // global row indices of the rows stored below the diagonal block of `s`
            let s_pattern = &row_ind[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
            let s_ncols = s_end - s_start;
            let s_nrows = s_pattern.len() + s_ncols;

            // pattern rows are stored after the `s_ncols` rows of the diagonal block
            for (i, &row) in s_pattern.iter().enumerate() {
                global_to_local[row.zx()] = I::Signed::truncate(i + s_ncols);
            }

            // `head` holds the values of the supernodes before `s` (read-only from here on),
            // `tail` starts with the values of `s`
            let (head, tail) = L_values.rb_mut().split_at(col_ptr_val[s].zx());
            let head = head.rb();
            // dense column-major `s_nrows × s_ncols` view over the values of `s`
            let mut Ls = faer_core::mat::from_column_major_slice_mut::<'_, E>(
                tail.subslice(0..(col_ptr_val[s + 1] - col_ptr_val[s]).zx())
                    .into_inner(),
                s_nrows,
                s_ncols,
            );

            // accumulate the columns of `A_lower` owned by `s` into `Ls`
            for j in s_start..s_end {
                let j_shifted = j - s_start;
                for (i, val) in zip(
                    A_lower.row_indices_of_col(j),
                    slice_group(A_lower.values_of_col(j)).into_ref_iter(),
                ) {
                    let val = val.read();
                    // rows past the diagonal block go through the global-to-local map, rows
                    // inside it are simply shifted by `s_start`
                    let (ix, iy) = if i >= s_end {
                        (global_to_local[i].sx(), j_shifted)
                    } else {
                        (i - s_start, j_shifted)
                    };
                    Ls.write(ix, iy, Ls.read(ix, iy).faer_add(val));
                }
            }

            // the descendants of `s` are the `desc_count` supernodes that come right before it
            // in the postorder
            let s_postordered = post_inv[s].zx();
            let desc_count = desc_count[s].zx();
            for d in &post[s_postordered - desc_count..s_postordered] {
                let d = d.zx();
                let d_start = symbolic.supernode_begin[d].zx();
                let d_end = symbolic.supernode_begin[d + 1].zx();

                let d_pattern = &row_ind[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
                let d_ncols = d_end - d_start;
                let d_nrows = d_pattern.len() + d_ncols;

                // already computed values of the descendant supernode
                let Ld = faer_core::mat::from_column_major_slice::<'_, E>(
                    head.subslice(col_ptr_val[d].zx()..col_ptr_val[d + 1].zx())
                        .into_inner(),
                    d_nrows,
                    d_ncols,
                );

                // NOTE(review): `partition_fn` is defined elsewhere; presumably this splits
                // `d_pattern` into the rows before `s_start`, the rows inside `s_start..s_end`
                // (`d_pattern_mid_len` of them), and the rows from `s_end` on -- confirm.
                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
                let d_pattern_mid_len =
                    d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
                let d_pattern_mid = d_pattern_start + d_pattern_mid_len;

                // drop the diagonal block and the leading pattern rows of `Ld`, then split the
                // remainder into the middle rows and the bottom rows
                let (_, Ld_mid_bot) = Ld.split_at_row(d_ncols);
                let (_, Ld_mid_bot) = Ld_mid_bot.split_at_row(d_pattern_start);
                let (Ld_mid, Ld_bot) = Ld_mid_bot.split_at_row(d_pattern_mid_len);

                let stack = stack.rb_mut();

                // uninitialized workspace for the update; it is fully overwritten below where
                // it is later read (the matmul calls do not accumulate)
                let (tmp, _) = temp_mat_uninit::<E>(Ld_mid_bot.nrows(), d_pattern_mid_len, stack);

                let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len);

                use faer_core::{mul, mul::triangular};
                // tmp_top = Ld_mid * Ld_mid^H, lower triangular part only
                triangular::matmul(
                    tmp_top.rb_mut(),
                    triangular::BlockStructure::TriangularLower,
                    Ld_mid,
                    triangular::BlockStructure::Rectangular,
                    Ld_mid.rb().adjoint(),
                    triangular::BlockStructure::Rectangular,
                    None,
                    E::faer_one(),
                    parallelism,
                );
                // tmp_bot = Ld_bot * Ld_mid^H
                mul::matmul(
                    tmp_bot.rb_mut(),
                    Ld_bot,
                    Ld_mid.rb().adjoint(),
                    None,
                    E::faer_one(),
                    parallelism,
                );
                // subtract the lower triangle of `tmp_top` from the diagonal block of `Ls`
                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
                    let j = j.zx();
                    let j_s = j - s_start;
                    for (i_idx, i) in d_pattern[d_pattern_start..d_pattern_mid][j_idx..]
                        .iter()
                        .enumerate()
                    {
                        // the inner slice starts at `j_idx`, shift back to an index in `tmp_top`
                        let i_idx = i_idx + j_idx;

                        let i = i.zx();
                        let i_s = i - s_start;

                        debug_assert!(i_s >= j_s);

                        // SAFETY (review): `i_idx` and `j_idx` are positions in a slice of
                        // length `d_pattern_mid_len`, the dimension of `tmp_top`. The accesses
                        // to `Ls` rely on `i` and `j` lying in `s_start..s_end`, which depends
                        // on `partition_fn` and is not checked here.
                        unsafe {
                            Ls.write_unchecked(
                                i_s,
                                j_s,
                                Ls.read_unchecked(i_s, j_s)
                                    .faer_sub(tmp_top.read_unchecked(i_idx, j_idx)),
                            )
                        };
                    }
                }

                // subtract `tmp_bot` from the rows of `Ls` below the diagonal block
                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
                    let j = j.zx();
                    let j_s = j - s_start;
                    for (i_idx, i) in d_pattern[d_pattern_mid..].iter().enumerate() {
                        let i = i.zx();
                        let i_s = global_to_local[i].zx();
                        // SAFETY (review): `i_idx` ranges over the rows of `Ld_bot`, hence of
                        // `tmp_bot`. The access to `Ls` relies on row `i` being part of
                        // `s_pattern` (so that `global_to_local[i]` was set above), which is
                        // presumably guaranteed by the symbolic factorization -- confirm.
                        unsafe {
                            Ls.write_unchecked(
                                i_s,
                                j_s,
                                Ls.read_unchecked(i_s, j_s)
                                    .faer_sub(tmp_bot.read_unchecked(i_idx, j_idx)),
                            )
                        };
                    }
                }
            }

            // diagonal block on top, pattern rows below
            let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols);

            let params = Default::default();
            // dense LLT of the diagonal block
            dynamic_regularization_count += match faer_cholesky::llt::compute::cholesky_in_place(
                Ls_top.rb_mut(),
                regularization,
                parallelism,
                stack.rb_mut(),
                params,
            ) {
                Ok(count) => count,
                Err(err) => {
                    // the dense routine reports an index local to the block, offset it by the
                    // first column of the supernode
                    return Err(CholeskyError {
                        non_positive_definite_minor: err.non_positive_definite_minor + s_start,
                    })
                }
            }
            .dynamic_regularization_count;
            // solve against the factorized diagonal block to obtain the rows below it; the
            // solve is applied to the transpose of `Ls_bot`, in place
            faer_core::solve::solve_lower_triangular_in_place(
                Ls_top.rb().conjugate(),
                Ls_bot.rb_mut().transpose_mut(),
                parallelism,
            );

            // reset the map so that it only contains `none` for the next supernode
            for &row in s_pattern {
                global_to_local[row.zx()] = none;
            }
        }
        Ok(dynamic_regularization_count)
    }
2484
    /// Computes the numeric values of the supernodal Cholesky LDLT factor of `A_lower`, using the
    /// structure described by `symbolic`, and writes them into `L_values`.
    ///
    /// Supernodes are processed in increasing index order. For each supernode `s`:
    /// 1. the entries of the columns of `A_lower` belonging to `s` are accumulated into the dense
    ///    column-major block `Ls` backing the supernode,
    /// 2. the contribution of each descendant supernode, scaled column-wise using the diagonal
    ///    of the descendant's diagonal block, is subtracted from `Ls`,
    /// 3. the diagonal block of `Ls` is factorized with the dense diagonal LDLT routine, its
    ///    strict upper triangle is zeroed, and the block below it is obtained with a unit
    ///    triangular solve followed by a column scaling by the diagonal of the diagonal block.
    ///
    /// # Parameters
    /// - `L_values`: output storage; it is zeroed before use.
    /// - `A_lower`: input matrix. NOTE(review): entries are presumably expected to lie in the
    ///   lower triangle; a row index smaller than the start of the supernode owning the column
    ///   would underflow in `i - s_start`.
    /// - `regularization`: forwarded to the dense factorization of each diagonal block, with
    ///   `dynamic_regularization_signs` (when present) restricted to the columns of the
    ///   supernode.
    /// - `symbolic`: symbolic supernodal structure.
    /// - `parallelism`: forwarded to the dense kernels.
    /// - `stack`: scratch space. NOTE(review): presumably must satisfy the matching `*_req`
    ///   function of this module; confirm against the caller.
    ///
    /// # Returns
    /// The sum of the dynamic regularization counts reported by the dense factorizations.
    ///
    /// # Panics
    /// Panics if `A_lower` is not `n × n` with `n = symbolic.nrows()`, or if the length of
    /// `L_values` differs from `symbolic.len_values()`.
    pub fn factorize_supernodal_numeric_ldlt<I: Index, E: ComplexField>(
        L_values: GroupFor<E, &mut [E::Unit]>,
        A_lower: SparseColMatRef<'_, I, E>,
        regularization: LdltRegularization<'_, E>,
        symbolic: &SymbolicSupernodalCholesky<I>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) -> usize {
        let n_supernodes = symbolic.n_supernodes();
        let n = symbolic.nrows();
        let mut dynamic_regularization_count = 0usize;
        let mut L_values = SliceGroupMut::<'_, E>::new(L_values);
        // the entries of `A_lower` are accumulated with additions, so start from zero
        L_values.fill_zero();

        assert!(A_lower.nrows() == n);
        assert!(A_lower.ncols() == n);
        assert!(L_values.len() == symbolic.len_values());
        let slice_group = SliceGroup::<'_, E>::new;

        let none = I::Signed::truncate(NONE);

        let post = &*symbolic.supernode_postorder;
        let post_inv = &*symbolic.supernode_postorder_inv;

        let desc_count = &*symbolic.descendant_count;

        let col_ptr_row = &*symbolic.col_ptrs_for_row_indices;
        let col_ptr_val = &*symbolic.col_ptrs_for_values;
        let row_ind = &*symbolic.row_indices;

        // mapping from global indices to local
        let (global_to_local, mut stack) = stack.make_raw::<I::Signed>(n);
        mem::fill_none(global_to_local.as_mut());

        for s in 0..n_supernodes {
            // supernode `s` owns the columns `s_start..s_end`
            let s_start = symbolic.supernode_begin[s].zx();
            let s_end = symbolic.supernode_begin[s + 1].zx();

            // global row indices of the rows stored below the diagonal block of `s`
            let s_pattern = &row_ind[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
            let s_ncols = s_end - s_start;
            let s_nrows = s_pattern.len() + s_ncols;

            // pattern rows are stored after the `s_ncols` rows of the diagonal block
            for (i, &row) in s_pattern.iter().enumerate() {
                global_to_local[row.zx()] = I::Signed::truncate(i + s_ncols);
            }

            // `head` holds the values of the supernodes before `s` (read-only from here on),
            // `tail` starts with the values of `s`
            let (head, tail) = L_values.rb_mut().split_at(col_ptr_val[s].zx());
            let head = head.rb();
            // dense column-major `s_nrows × s_ncols` view over the values of `s`
            let mut Ls = faer_core::mat::from_column_major_slice_mut::<'_, E>(
                tail.subslice(0..(col_ptr_val[s + 1] - col_ptr_val[s]).zx())
                    .into_inner(),
                s_nrows,
                s_ncols,
            );

            // accumulate the columns of `A_lower` owned by `s` into `Ls`
            for j in s_start..s_end {
                let j_shifted = j - s_start;
                for (i, val) in zip(
                    A_lower.row_indices_of_col(j),
                    slice_group(A_lower.values_of_col(j)).into_ref_iter(),
                ) {
                    let val = val.read();
                    // rows past the diagonal block go through the global-to-local map, rows
                    // inside it are simply shifted by `s_start`
                    let (ix, iy) = if i >= s_end {
                        (global_to_local[i].sx(), j_shifted)
                    } else {
                        (i - s_start, j_shifted)
                    };
                    Ls.write(ix, iy, Ls.read(ix, iy).faer_add(val));
                }
            }

            // the descendants of `s` are the `desc_count` supernodes that come right before it
            // in the postorder
            let s_postordered = post_inv[s].zx();
            let desc_count = desc_count[s].zx();
            for d in &post[s_postordered - desc_count..s_postordered] {
                let d = d.zx();
                let d_start = symbolic.supernode_begin[d].zx();
                let d_end = symbolic.supernode_begin[d + 1].zx();

                let d_pattern = &row_ind[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
                let d_ncols = d_end - d_start;
                let d_nrows = d_pattern.len() + d_ncols;

                // already computed values of the descendant supernode
                let Ld = faer_core::mat::from_column_major_slice::<'_, E>(
                    head.subslice(col_ptr_val[d].zx()..col_ptr_val[d + 1].zx())
                        .into_inner(),
                    d_nrows,
                    d_ncols,
                );

                // NOTE(review): `partition_fn` is defined elsewhere; presumably this splits
                // `d_pattern` into the rows before `s_start`, the rows inside `s_start..s_end`
                // (`d_pattern_mid_len` of them), and the rows from `s_end` on -- confirm.
                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
                let d_pattern_mid_len =
                    d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
                let d_pattern_mid = d_pattern_start + d_pattern_mid_len;

                // keep the diagonal block for its diagonal, drop the leading pattern rows, then
                // split the remainder into the middle rows and the bottom rows
                let (Ld_top, Ld_mid_bot) = Ld.split_at_row(d_ncols);
                let (_, Ld_mid_bot) = Ld_mid_bot.split_at_row(d_pattern_start);
                let (Ld_mid, Ld_bot) = Ld_mid_bot.split_at_row(d_pattern_mid_len);
                // diagonal of the descendant's diagonal block, as left there by the
                // post-processing at the end of its own iteration
                let D = Ld_top.diagonal().column_vector();

                let stack = stack.rb_mut();

                // uninitialized workspaces: `tmp` receives the update, `tmp2` (viewed
                // transposed) receives the column-scaled copy of `Ld_mid`
                let (tmp, stack) =
                    temp_mat_uninit::<E>(Ld_mid_bot.nrows(), d_pattern_mid_len, stack);
                let (tmp2, _) = temp_mat_uninit::<E>(Ld_mid.ncols(), Ld_mid.nrows(), stack);
                let mut Ld_mid_x_D = tmp2.transpose_mut();

                // scale column `j` of `Ld_mid` by the reciprocal of the real part of `D[j]`
                for i in 0..d_pattern_mid_len {
                    for j in 0..d_ncols {
                        Ld_mid_x_D.write(
                            i,
                            j,
                            Ld_mid
                                .read(i, j)
                                .faer_scale_real(D.read(j).faer_real().faer_inv()),
                        );
                    }
                }

                let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len);

                use faer_core::{mul, mul::triangular};
                // tmp_top = Ld_mid * Ld_mid_x_D^H, lower triangular part only
                triangular::matmul(
                    tmp_top.rb_mut(),
                    triangular::BlockStructure::TriangularLower,
                    Ld_mid,
                    triangular::BlockStructure::Rectangular,
                    Ld_mid_x_D.rb().adjoint(),
                    triangular::BlockStructure::Rectangular,
                    None,
                    E::faer_one(),
                    parallelism,
                );
                // tmp_bot = Ld_bot * Ld_mid_x_D^H
                mul::matmul(
                    tmp_bot.rb_mut(),
                    Ld_bot,
                    Ld_mid_x_D.rb().adjoint(),
                    None,
                    E::faer_one(),
                    parallelism,
                );
                // subtract the lower triangle of `tmp_top` from the diagonal block of `Ls`
                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
                    let j = j.zx();
                    let j_s = j - s_start;
                    for (i_idx, i) in d_pattern[d_pattern_start..d_pattern_mid][j_idx..]
                        .iter()
                        .enumerate()
                    {
                        // the inner slice starts at `j_idx`, shift back to an index in `tmp_top`
                        let i_idx = i_idx + j_idx;

                        let i = i.zx();
                        let i_s = i - s_start;

                        debug_assert!(i_s >= j_s);

                        // SAFETY (review): `i_idx` and `j_idx` are positions in a slice of
                        // length `d_pattern_mid_len`, the dimension of `tmp_top`. The accesses
                        // to `Ls` rely on `i` and `j` lying in `s_start..s_end`, which depends
                        // on `partition_fn` and is not checked here.
                        unsafe {
                            Ls.write_unchecked(
                                i_s,
                                j_s,
                                Ls.read_unchecked(i_s, j_s)
                                    .faer_sub(tmp_top.read_unchecked(i_idx, j_idx)),
                            )
                        };
                    }
                }

                // subtract `tmp_bot` from the rows of `Ls` below the diagonal block
                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
                    let j = j.zx();
                    let j_s = j - s_start;
                    for (i_idx, i) in d_pattern[d_pattern_mid..].iter().enumerate() {
                        let i = i.zx();
                        let i_s = global_to_local[i].zx();
                        // SAFETY (review): `i_idx` ranges over the rows of `Ld_bot`, hence of
                        // `tmp_bot`. The access to `Ls` relies on row `i` being part of
                        // `s_pattern` (so that `global_to_local[i]` was set above), which is
                        // presumably guaranteed by the symbolic factorization -- confirm.
                        unsafe {
                            Ls.write_unchecked(
                                i_s,
                                j_s,
                                Ls.read_unchecked(i_s, j_s)
                                    .faer_sub(tmp_bot.read_unchecked(i_idx, j_idx)),
                            )
                        };
                    }
                }
            }

            // diagonal block on top, pattern rows below
            let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols);

            let params = Default::default();
            // dense LDLT of the diagonal block; the regularization signs, if any, are
            // restricted to the columns owned by this supernode
            dynamic_regularization_count +=
                faer_cholesky::ldlt_diagonal::compute::raw_cholesky_in_place(
                    Ls_top.rb_mut(),
                    LdltRegularization {
                        dynamic_regularization_signs: regularization
                            .dynamic_regularization_signs
                            .map(|signs| &signs[s_start..s_end]),
                        ..regularization
                    },
                    parallelism,
                    stack.rb_mut(),
                    params,
                )
                .dynamic_regularization_count;
            // clear the strict upper triangle of the diagonal block (the diagonal is kept)
            zipped!(Ls_top.rb_mut())
                .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut x)| {
                    x.write(E::faer_zero())
                });
            // unit lower triangular solve against the diagonal block, applied in place to the
            // transpose of `Ls_bot`
            faer_core::solve::solve_unit_lower_triangular_in_place(
                Ls_top.rb().conjugate(),
                Ls_bot.rb_mut().transpose_mut(),
                parallelism,
            );
            // scale column `j` of `Ls_bot` by the real part of the diagonal entry `(j, j)` of
            // the factorized diagonal block
            for j in 0..s_ncols {
                let d = Ls_top.read(j, j).faer_real();
                for i in 0..s_pattern.len() {
                    Ls_bot.write(i, j, Ls_bot.read(i, j).faer_scale_real(d));
                }
            }

            // reset the map so that it only contains `none` for the next supernode
            for &row in s_pattern {
                global_to_local[row.zx()] = none;
            }
        }
        dynamic_regularization_count
    }
2707
2708    pub fn factorize_supernodal_numeric_intranode_bunch_kaufman<I: Index, E: ComplexField>(
2709        L_values: GroupFor<E, &mut [E::Unit]>,
2710        subdiag: GroupFor<E, &mut [E::Unit]>,
2711        perm_forward: &mut [I],
2712        perm_inverse: &mut [I],
2713        A_lower: SparseColMatRef<'_, I, E>,
2714        regularization: BunchKaufmanRegularization<'_, E>,
2715        symbolic: &SymbolicSupernodalCholesky<I>,
2716        parallelism: Parallelism,
2717        stack: PodStack<'_>,
2718    ) -> usize {
2719        let mut regularization = regularization;
2720        let n_supernodes = symbolic.n_supernodes();
2721        let n = symbolic.nrows();
2722        let mut dynamic_regularization_count = 0usize;
2723        let mut L_values = SliceGroupMut::<'_, E>::new(L_values);
2724        let mut subdiag = SliceGroupMut::<'_, E>::new(subdiag);
2725        L_values.fill_zero();
2726
2727        assert!(A_lower.nrows() == n);
2728        assert!(A_lower.ncols() == n);
2729        assert!(perm_forward.len() == n);
2730        assert!(perm_inverse.len() == n);
2731        assert!(subdiag.len() == n);
2732        assert!(L_values.len() == symbolic.len_values());
2733        let slice_group = SliceGroup::<'_, E>::new;
2734
2735        let none = I::Signed::truncate(NONE);
2736
2737        let post = &*symbolic.supernode_postorder;
2738        let post_inv = &*symbolic.supernode_postorder_inv;
2739
2740        let desc_count = &*symbolic.descendant_count;
2741
2742        let col_ptr_row = &*symbolic.col_ptrs_for_row_indices;
2743        let col_ptr_val = &*symbolic.col_ptrs_for_values;
2744        let row_ind = &*symbolic.row_indices;
2745
2746        // mapping from global indices to local
2747        let (global_to_local, mut stack) = stack.make_raw::<I::Signed>(n);
2748        mem::fill_none(global_to_local.as_mut());
2749
2750        for s in 0..n_supernodes {
2751            let s_start = symbolic.supernode_begin[s].zx();
2752            let s_end = symbolic.supernode_begin[s + 1].zx();
2753
2754            let s_pattern = &row_ind[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
2755            let s_ncols = s_end - s_start;
2756            let s_nrows = s_pattern.len() + s_ncols;
2757
2758            for (i, &row) in s_pattern.iter().enumerate() {
2759                global_to_local[row.zx()] = I::Signed::truncate(i + s_ncols);
2760            }
2761
2762            let (head, tail) = L_values.rb_mut().split_at(col_ptr_val[s].zx());
2763            let head = head.rb();
2764            let mut Ls = faer_core::mat::from_column_major_slice_mut::<'_, E>(
2765                tail.subslice(0..(col_ptr_val[s + 1] - col_ptr_val[s]).zx())
2766                    .into_inner(),
2767                s_nrows,
2768                s_ncols,
2769            );
2770
2771            for j in s_start..s_end {
2772                let j_shifted = j - s_start;
2773                for (i, val) in zip(
2774                    A_lower.row_indices_of_col(j),
2775                    slice_group(A_lower.values_of_col(j)).into_ref_iter(),
2776                ) {
2777                    let val = val.read();
2778                    let (ix, iy) = if i >= s_end {
2779                        (global_to_local[i].sx(), j_shifted)
2780                    } else {
2781                        (i - s_start, j_shifted)
2782                    };
2783                    Ls.write(ix, iy, Ls.read(ix, iy).faer_add(val));
2784                }
2785            }
2786
2787            let s_postordered = post_inv[s].zx();
2788            let desc_count = desc_count[s].zx();
2789            for d in &post[s_postordered - desc_count..s_postordered] {
2790                let d = d.zx();
2791                let d_start = symbolic.supernode_begin[d].zx();
2792                let d_end = symbolic.supernode_begin[d + 1].zx();
2793
2794                let d_pattern = &row_ind[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2795                let d_ncols = d_end - d_start;
2796                let d_nrows = d_pattern.len() + d_ncols;
2797
2798                let Ld = faer_core::mat::from_column_major_slice::<'_, E>(
2799                    head.subslice(col_ptr_val[d].zx()..col_ptr_val[d + 1].zx())
2800                        .into_inner(),
2801                    d_nrows,
2802                    d_ncols,
2803                );
2804
2805                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2806                let d_pattern_mid_len =
2807                    d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2808                let d_pattern_mid = d_pattern_start + d_pattern_mid_len;
2809
2810                let (Ld_top, Ld_mid_bot) = Ld.split_at_row(d_ncols);
2811                let (_, Ld_mid_bot) = Ld_mid_bot.split_at_row(d_pattern_start);
2812                let (Ld_mid, Ld_bot) = Ld_mid_bot.split_at_row(d_pattern_mid_len);
2813                let d_subdiag = subdiag.rb().subslice(d_start..d_start + d_ncols);
2814
2815                let stack = stack.rb_mut();
2816
2817                let (tmp, stack) =
2818                    temp_mat_uninit::<E>(Ld_mid_bot.nrows(), d_pattern_mid_len, stack);
2819                let (tmp2, _) = temp_mat_uninit::<E>(Ld_mid.ncols(), Ld_mid.nrows(), stack);
2820                let mut Ld_mid_x_D = tmp2.transpose_mut();
2821
2822                let mut j = 0;
2823                while j < d_ncols {
2824                    let subdiag = d_subdiag.read(j);
2825                    if subdiag == E::faer_zero() {
2826                        let d = Ld_top.read(j, j).faer_real().faer_inv();
2827                        for i in 0..d_pattern_mid_len {
2828                            Ld_mid_x_D.write(i, j, Ld_mid.read(i, j).faer_scale_real(d));
2829                        }
2830                        j += 1;
2831                    } else {
2832                        // 1/d21
2833                        let akp1k = subdiag.faer_inv();
2834                        // d11/d21
2835                        let ak = akp1k.faer_scale_real(Ld_top.read(j, j).faer_real());
2836                        // d22/conj(d21)
2837                        let akp1 = akp1k
2838                            .faer_conj()
2839                            .faer_scale_real(Ld_top.read(j + 1, j + 1).faer_real());
2840
2841                        // (d11 * d21 / |d21|^2  -  1)^-1
2842                        // = |d21|^2 / ( d11 * d21 - |d21|^2 )
2843                        let denom = ak
2844                            .faer_mul(akp1)
2845                            .faer_real()
2846                            .faer_sub(E::Real::faer_one())
2847                            .faer_inv();
2848
2849                        for i in 0..d_pattern_mid_len {
2850                            // x1 / d21
2851                            let xk = Ld_mid.read(i, j).faer_mul(akp1k);
2852                            // x2 / conj(d21)
2853                            let xkp1 = Ld_mid.read(i, j + 1).faer_mul(akp1k.faer_conj());
2854
2855                            // d22/conj(d21) * x1/d21 * |d21|^2 / (d11 * d21 - |d21|^2)
2856                            // - x2/conj(d21) * |d21|^2 / (d11 * d21 - |d21|^2)
2857                            //
2858                            // =  x1 * d22/det - x2 * d21/det
2859                            Ld_mid_x_D.write(
2860                                i,
2861                                j,
2862                                (akp1.faer_mul(xk).faer_sub(xkp1)).faer_scale_real(denom),
2863                            );
2864                            Ld_mid_x_D.write(
2865                                i,
2866                                j + 1,
2867                                (ak.faer_mul(xkp1).faer_sub(xk)).faer_scale_real(denom),
2868                            );
2869                        }
2870                        j += 2;
2871                    }
2872                }
2873
2874                let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len);
2875
2876                use faer_core::{mul, mul::triangular};
2877                triangular::matmul(
2878                    tmp_top.rb_mut(),
2879                    triangular::BlockStructure::TriangularLower,
2880                    Ld_mid,
2881                    triangular::BlockStructure::Rectangular,
2882                    Ld_mid_x_D.rb().adjoint(),
2883                    triangular::BlockStructure::Rectangular,
2884                    None,
2885                    E::faer_one(),
2886                    parallelism,
2887                );
2888                mul::matmul(
2889                    tmp_bot.rb_mut(),
2890                    Ld_bot,
2891                    Ld_mid_x_D.rb().adjoint(),
2892                    None,
2893                    E::faer_one(),
2894                    parallelism,
2895                );
2896
2897                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
2898                    let j = j.zx();
2899                    let j_s = j - s_start;
2900                    for (i_idx, i) in d_pattern[d_pattern_start..d_pattern_mid][j_idx..]
2901                        .iter()
2902                        .enumerate()
2903                    {
2904                        let i_idx = i_idx + j_idx;
2905
2906                        let i = i.zx();
2907                        let i_s = i - s_start;
2908
2909                        debug_assert!(i_s >= j_s);
2910                        unsafe {
2911                            Ls.write_unchecked(
2912                                i_s,
2913                                j_s,
2914                                Ls.read_unchecked(i_s, j_s)
2915                                    .faer_sub(tmp_top.read_unchecked(i_idx, j_idx)),
2916                            )
2917                        };
2918                    }
2919                }
2920
2921                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
2922                    let j = j.zx();
2923                    let j_s = j - s_start;
2924                    for (i_idx, i) in d_pattern[d_pattern_mid..].iter().enumerate() {
2925                        let i = i.zx();
2926                        let i_s = global_to_local[i].zx();
2927                        unsafe {
2928                            Ls.write_unchecked(
2929                                i_s,
2930                                j_s,
2931                                Ls.read_unchecked(i_s, j_s)
2932                                    .faer_sub(tmp_bot.read_unchecked(i_idx, j_idx)),
2933                            )
2934                        };
2935                    }
2936                }
2937            }
2938
2939            let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols);
2940            let mut s_subdiag = subdiag.rb_mut().subslice(s_start..s_end);
2941
2942            let params = Default::default();
2943            let (info, perm) = faer_cholesky::bunch_kaufman::compute::cholesky_in_place(
2944                Ls_top.rb_mut(),
2945                faer_core::mat::from_column_major_slice_mut::<'_, E>(
2946                    s_subdiag.rb_mut().into_inner(),
2947                    s_ncols,
2948                    1,
2949                ),
2950                BunchKaufmanRegularization {
2951                    dynamic_regularization_signs: regularization
2952                        .dynamic_regularization_signs
2953                        .rb_mut()
2954                        .map(|signs| &mut signs[s_start..s_end]),
2955                    ..regularization
2956                },
2957                &mut perm_forward[s_start..s_end],
2958                &mut perm_inverse[s_start..s_end],
2959                parallelism,
2960                stack.rb_mut(),
2961                params,
2962            );
2963            dynamic_regularization_count += info.dynamic_regularization_count;
2964            zipped!(Ls_top.rb_mut())
2965                .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut x)| {
2966                    x.write(E::faer_zero())
2967                });
2968
2969            faer_core::permutation::permute_cols_in_place(
2970                Ls_bot.rb_mut(),
2971                perm.rb(),
2972                stack.rb_mut(),
2973            );
2974
2975            for p in &mut perm_forward[s_start..s_end] {
2976                *p += I::truncate(s_start);
2977            }
2978            for p in &mut perm_inverse[s_start..s_end] {
2979                *p += I::truncate(s_start);
2980            }
2981
2982            faer_core::solve::solve_unit_lower_triangular_in_place(
2983                Ls_top.rb().conjugate(),
2984                Ls_bot.rb_mut().transpose_mut(),
2985                parallelism,
2986            );
2987
2988            let mut j = 0;
2989            while j < s_ncols {
2990                if s_subdiag.read(j) == E::faer_zero() {
2991                    let d = Ls_top.read(j, j).faer_real();
2992                    for i in 0..s_pattern.len() {
2993                        Ls_bot.write(i, j, Ls_bot.read(i, j).faer_scale_real(d));
2994                    }
2995                    j += 1;
2996                } else {
2997                    let akp1k = s_subdiag.read(j);
2998                    let ak = Ls_top.read(j, j).faer_real();
2999                    let akp1 = Ls_top.read(j + 1, j + 1).faer_real();
3000
3001                    for i in 0..s_pattern.len() {
3002                        let xk = Ls_bot.read(i, j);
3003                        let xkp1 = Ls_bot.read(i, j + 1);
3004
3005                        Ls_bot.write(i, j, xk.faer_scale_real(ak).faer_add(xkp1.faer_mul(akp1k)));
3006                        Ls_bot.write(
3007                            i,
3008                            j + 1,
3009                            xkp1.faer_scale_real(akp1)
3010                                .faer_add(xk.faer_mul(akp1k.faer_conj())),
3011                        );
3012                    }
3013                    j += 2;
3014                }
3015            }
3016
3017            for &row in s_pattern {
3018                global_to_local[row.zx()] = none;
3019            }
3020        }
3021        dynamic_regularization_count
3022    }
3023
    /// Borrowed supernodal LLT factorization: a reference to the symbolic structure paired with
    /// the numerical values of the factor.
3024    #[derive(Debug)]
3025    pub struct SupernodalLltRef<'a, I, E: Entity> {
        // Symbolic (structure-only) part of the supernodal factorization.
3026        symbolic: &'a SymbolicSupernodalCholesky<I>,
        // Numerical values of the factor.
3027        values: SliceGroup<'a, E>,
3028    }
3029
    /// Borrowed supernodal LDLT factorization: a reference to the symbolic structure paired with
    /// the numerical values of the factor.
3030    #[derive(Debug)]
3031    pub struct SupernodalLdltRef<'a, I, E: Entity> {
        // Symbolic (structure-only) part of the supernodal factorization.
3032        symbolic: &'a SymbolicSupernodalCholesky<I>,
        // Numerical values of the factor.
3033        values: SliceGroup<'a, E>,
3034    }
3035
    /// Borrowed supernodal intranodal Bunch-Kaufman factorization: symbolic structure, factor
    /// values, subdiagonal values and the permutation produced by the numerical factorization.
3036    #[derive(Debug)]
3037    pub struct SupernodalIntranodeBunchKaufmanRef<'a, I, E: Entity> {
        // Symbolic (structure-only) part of the supernodal factorization.
3038        symbolic: &'a SymbolicSupernodalCholesky<I>,
        // Numerical values of the factor.
3039        values: SliceGroup<'a, E>,
        // Subdiagonal values; the numerical factorization treats a zero entry as a 1×1 pivot
        // and a non-zero entry as the off-diagonal of a 2×2 pivot.
3040        subdiag: SliceGroup<'a, E>,
        // Pivoting permutation; visible to the parent module.
3041        pub(super) perm: PermutationRef<'a, I, E>,
3042    }
3043
    /// Symbolic structure of a supernodal Cholesky factorization.
    ///
    /// NOTE(review): the per-field descriptions below are inferred from the field names and
    /// should be confirmed against the code that builds this struct.
3044    #[derive(Debug)]
3045    pub struct SymbolicSupernodalCholesky<I> {
        // Dimension of the (square) matrix.
3046        pub(crate) dimension: usize,
        // Presumably the postordering of the supernodes.
3047        pub(crate) supernode_postorder: alloc::vec::Vec<I>,
        // Presumably the inverse of `supernode_postorder`.
3048        pub(crate) supernode_postorder_inv: alloc::vec::Vec<I>,
        // Presumably the number of descendants of each supernode.
3049        pub(crate) descendant_count: alloc::vec::Vec<I>,
3050
        // Presumably the first column of each supernode.
3051        pub(crate) supernode_begin: alloc::vec::Vec<I>,
        // Presumably per-supernode offsets into `row_indices`.
3052        pub(crate) col_ptrs_for_row_indices: alloc::vec::Vec<I>,
        // Presumably per-supernode offsets into the numerical values.
3053        pub(crate) col_ptrs_for_values: alloc::vec::Vec<I>,
        // Presumably the row indices of the off-diagonal pattern of each supernode.
3054        pub(crate) row_indices: alloc::vec::Vec<I>,
3055    }
3056}
3057
3058// workspace: I×(n)
3059fn ghost_prefactorize_symbolic_cholesky<'n, 'out, I: Index>(
3060    etree: &'out mut Array<'n, I::Signed>,
3061    col_counts: &mut Array<'n, I>,
3062    A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>,
3063    stack: PodStack<'_>,
3064) -> &'out mut Array<'n, MaybeIdx<'n, I>> {
3065    let N = A.ncols();
3066    let (visited, _) = stack.make_raw::<I>(*N);
3067    let etree = Array::from_mut(ghost::fill_none::<I>(etree.as_mut(), N), N);
3068    let visited = Array::from_mut(visited, N);
3069
3070    for j in N.indices() {
3071        let j_ = j.truncate::<I>();
3072        visited[j] = *j_;
3073        col_counts[j] = I::truncate(1);
3074
3075        for mut i in A.row_indices_of_col(j) {
3076            if i < j {
3077                loop {
3078                    if visited[i] == *j_ {
3079                        break;
3080                    }
3081
3082                    let next_i = if let Some(parent) = etree[i].idx() {
3083                        parent.zx()
3084                    } else {
3085                        etree[i] = MaybeIdx::from_index(j_);
3086                        j
3087                    };
3088
3089                    col_counts[i] += I::truncate(1);
3090                    visited[i] = *j_;
3091                    i = next_i;
3092                }
3093            }
3094        }
3095    }
3096
3097    etree
3098}
3099
// Polynomial coefficients consumed by the `*_estimate` methods of this type.
// NOTE(review): the estimates presumably model running times measured on a reference
// machine — confirm units with whoever fitted the coefficients.
3100#[derive(Debug, Copy, Clone)]
3101#[doc(hidden)]
3102pub struct ComputationModel {
    // Coefficients of a cubic polynomial in `n` (see `ldl_estimate`).
3103    pub ldl: [f64; 4],
    // Coefficients of a polynomial in `n` and `k` (see `triangular_solve_estimate`).
3104    pub triangular_solve: [f64; 6],
    // Coefficients of a polynomial in `m + n`, `m * n` and `k` (see `matmul_estimate`).
3105    pub matmul: [f64; 6],
    // Coefficients of a bilinear polynomial in `br` and `bc` (see `assembly_estimate`).
3106    pub assembly: [f64; 4],
3107}
3108
3109impl ComputationModel {
3110    #[allow(clippy::excessive_precision)]
3111    pub const OPENBLAS_I7_1185G7: Self = ComputationModel {
3112        ldl: [
3113            3.527141723946874224e-07,
3114            -5.382557351808083451e-08,
3115            4.677984682984275924e-09,
3116            7.384424667338682676e-12,
3117        ],
3118        triangular_solve: [
3119            1.101115592925888909e-06,
3120            6.936563076265144074e-07,
3121            -1.827661167503034051e-09,
3122            1.959826916788009885e-09,
3123            1.079857543323972179e-09,
3124            2.963338652996178598e-11,
3125        ],
3126        matmul: [
3127            6.14190596709488416e-07,
3128            -4.489948374364910256e-09,
3129            5.943145978912038475e-10,
3130            -1.201283634136652872e-08,
3131            1.266858215451465993e-09,
3132            2.624001993284897048e-11,
3133        ],
3134        assembly: [
3135            3.069607518266660019e-07,
3136            3.763778311956422235e-08,
3137            1.991443920635728855e-07,
3138            3.788938150548870089e-09,
3139        ],
3140    };
3141
3142    #[inline]
3143    pub fn ldl_estimate(&self, n: f64) -> f64 {
3144        let p = self.ldl;
3145        p[0] + n * (p[1] + n * (p[2] + n * p[3]))
3146    }
3147
3148    #[inline]
3149    pub fn triangular_solve_estimate(&self, n: f64, k: f64) -> f64 {
3150        let p = self.triangular_solve;
3151        p[0] + n * (p[1] + n * p[2]) + k * (p[3] + n * (p[4] + n * p[5]))
3152    }
3153
3154    #[inline]
3155    pub fn matmul_estimate(&self, m: f64, n: f64, k: f64) -> f64 {
3156        let p = self.matmul;
3157        p[0] + (m + n) * p[1] + (m * n) * p[2] + k * (p[3] + (m + n) * p[4] + (m * n) * p[5])
3158    }
3159
3160    #[inline]
3161    pub fn assembly_estimate(&self, br: f64, bc: f64) -> f64 {
3162        let p = self.assembly;
3163        p[0] + br * p[1] + bc * p[2] + br * bc * p[3]
3164    }
3165}
3166
3167/// The inner factorization used for the symbolic Cholesky, either simplicial or supernodal.
3168#[derive(Debug)]
3169pub enum SymbolicCholeskyRaw<I> {
    /// Simplicial symbolic structure.
3170    Simplicial(simplicial::SymbolicSimplicialCholesky<I>),
    /// Supernodal symbolic structure.
3171    Supernodal(supernodal::SymbolicSupernodalCholesky<I>),
3172}
3173
3174/// The symbolic structure of a sparse Cholesky decomposition.
///
/// Bundles the simplicial or supernodal structure with the permutation computed during symbolic
/// analysis and the entry count used to size the workspace of the numerical factorizations.
3175#[derive(Debug)]
3176pub struct SymbolicCholesky<I> {
    // Simplicial or supernodal symbolic structure.
3177    raw: SymbolicCholeskyRaw<I>,
    // Forward array of the permutation returned by `perm`.
3178    perm_fwd: alloc::vec::Vec<I>,
    // Inverse array of the permutation returned by `perm`.
3179    perm_inv: alloc::vec::Vec<I>,
    // Entry count used to size the buffers holding the permuted copy of `A` in the numerical
    // factorizations.
3180    A_nnz: usize,
3181}
3182
3183impl<I: Index> SymbolicCholesky<I> {
3184    /// Returns the number of rows of the matrix.
3185    #[inline]
3186    pub fn nrows(&self) -> usize {
3187        match &self.raw {
3188            SymbolicCholeskyRaw::Simplicial(this) => this.nrows(),
3189            SymbolicCholeskyRaw::Supernodal(this) => this.nrows(),
3190        }
3191    }
3192
3193    /// Returns the number of columns of the matrix.
3194    #[inline]
3195    pub fn ncols(&self) -> usize {
3196        self.nrows()
3197    }
3198
3199    /// Returns the inner type of the factorization, either simplicial or symbolic.
3200    #[inline]
3201    pub fn raw(&self) -> &SymbolicCholeskyRaw<I> {
3202        &self.raw
3203    }
3204
3205    /// Returns the permutation that was computed during symbolic analysis.
3206    #[inline]
3207    pub fn perm(&self) -> PermutationRef<'_, I, Symbolic> {
3208        unsafe { PermutationRef::new_unchecked(&self.perm_fwd, &self.perm_inv) }
3209    }
3210
3211    /// Returns the length of the slice needed to store the numerical values of the Cholesky
3212    /// decomposition.
3213    #[inline]
3214    pub fn len_values(&self) -> usize {
3215        match &self.raw {
3216            SymbolicCholeskyRaw::Simplicial(this) => this.len_values(),
3217            SymbolicCholeskyRaw::Supernodal(this) => this.len_values(),
3218        }
3219    }
3220
3221    /// Computes the required workspace size and alignment for a numerical LLT factorization.
3222    #[inline]
3223    pub fn factorize_numeric_llt_req<E: Entity>(
3224        &self,
3225        parallelism: Parallelism,
3226    ) -> Result<StackReq, SizeOverflow> {
3227        let n = self.nrows();
3228        let A_nnz = self.A_nnz;
3229
3230        let n_req = StackReq::try_new::<I>(n)?;
3231        let A_req = StackReq::try_all_of([
3232            make_raw_req::<E>(A_nnz)?,
3233            StackReq::try_new::<I>(n + 1)?,
3234            StackReq::try_new::<I>(A_nnz)?,
3235        ])?;
3236        let permute_req = n_req;
3237
3238        let factor_req = match &self.raw {
3239            SymbolicCholeskyRaw::Simplicial(_) => {
3240                simplicial::factorize_simplicial_numeric_llt_req::<I, E>(n)?
3241            }
3242            SymbolicCholeskyRaw::Supernodal(this) => {
3243                supernodal::factorize_supernodal_numeric_llt_req::<I, E>(this, parallelism)?
3244            }
3245        };
3246
3247        StackReq::try_all_of([A_req, StackReq::try_or(permute_req, factor_req)?])
3248    }
3249
3250    /// Computes the required workspace size and alignment for a numerical LDLT factorization.
3251    #[inline]
3252    pub fn factorize_numeric_ldlt_req<E: Entity>(
3253        &self,
3254        with_regularization_signs: bool,
3255        parallelism: Parallelism,
3256    ) -> Result<StackReq, SizeOverflow> {
3257        let n = self.nrows();
3258        let A_nnz = self.A_nnz;
3259
3260        let regularization_signs = if with_regularization_signs {
3261            StackReq::try_new::<i8>(n)?
3262        } else {
3263            StackReq::empty()
3264        };
3265
3266        let n_req = StackReq::try_new::<I>(n)?;
3267        let A_req = StackReq::try_all_of([
3268            make_raw_req::<E>(A_nnz)?,
3269            StackReq::try_new::<I>(n + 1)?,
3270            StackReq::try_new::<I>(A_nnz)?,
3271        ])?;
3272        let permute_req = n_req;
3273
3274        let factor_req = match &self.raw {
3275            SymbolicCholeskyRaw::Simplicial(_) => {
3276                simplicial::factorize_simplicial_numeric_ldlt_req::<I, E>(n)?
3277            }
3278            SymbolicCholeskyRaw::Supernodal(this) => {
3279                supernodal::factorize_supernodal_numeric_ldlt_req::<I, E>(this, parallelism)?
3280            }
3281        };
3282
3283        StackReq::try_all_of([
3284            regularization_signs,
3285            A_req,
3286            StackReq::try_or(permute_req, factor_req)?,
3287        ])
3288    }
3289
3290    /// Computes the required workspace size and alignment for a numerical intranodal Bunch-Kaufman
3291    /// factorization.
3292    #[inline]
3293    pub fn factorize_numeric_intranode_bunch_kaufman_req<E: Entity>(
3294        &self,
3295        with_regularization_signs: bool,
3296        parallelism: Parallelism,
3297    ) -> Result<StackReq, SizeOverflow> {
3298        let n = self.nrows();
3299        let A_nnz = self.A_nnz;
3300
3301        let regularization_signs = if with_regularization_signs {
3302            StackReq::try_new::<i8>(n)?
3303        } else {
3304            StackReq::empty()
3305        };
3306
3307        let n_req = StackReq::try_new::<I>(n)?;
3308        let A_req = StackReq::try_all_of([
3309            make_raw_req::<E>(A_nnz)?,
3310            StackReq::try_new::<I>(n + 1)?,
3311            StackReq::try_new::<I>(A_nnz)?,
3312        ])?;
3313        let permute_req = n_req;
3314
3315        let factor_req = match &self.raw {
3316            SymbolicCholeskyRaw::Simplicial(_) => {
3317                simplicial::factorize_simplicial_numeric_ldlt_req::<I, E>(n)?
3318            }
3319            SymbolicCholeskyRaw::Supernodal(this) => {
3320                supernodal::factorize_supernodal_numeric_intranode_bunch_kaufman_req::<I, E>(
3321                    this,
3322                    parallelism,
3323                )?
3324            }
3325        };
3326
3327        StackReq::try_all_of([
3328            regularization_signs,
3329            A_req,
3330            StackReq::try_or(permute_req, factor_req)?,
3331        ])
3332    }
3333
3334    /// Computes a numerical LLT factorization of A, or returns a [`CholeskyError`] if the matrix
3335    /// is not numerically positive definite.
    ///
    /// `A` is first copied into workspace taken from `stack`, permuted by the permutation
    /// computed during symbolic analysis; the simplicial or supernodal kernel then writes the
    /// factor into `L_values`.
    ///
    /// # Panics
    /// Panics if `A` is not square, or if the length of `L_values` differs from
    /// [`Self::len_values`].
3336    #[track_caller]
3337    pub fn factorize_numeric_llt<'out, E: ComplexField>(
3338        &'out self,
3339        L_values: GroupFor<E, &'out mut [E::Unit]>,
3340        A: SparseColMatRef<'_, I, E>,
3341        side: Side,
3342        regularization: LltRegularization<E>,
3343        parallelism: Parallelism,
3344        stack: PodStack<'_>,
3345    ) -> Result<LltRef<'out, I, E>, CholeskyError> {
3346        assert!(A.nrows() == A.ncols());
3347        let n = A.nrows();
3348        let mut L_values = L_values;
3349
3350        ghost::with_size(n, |N| {
3351            let A_nnz = self.A_nnz;
3352            let A = ghost::SparseColMatRef::new(A, N, N);
3353
3354            let perm = ghost::PermutationRef::new(self.perm(), N);
3355
            // Buffers for the permuted copy of `A`: values, column pointers, row indices.
3356            let (mut new_values, stack) = crate::make_raw::<E>(A_nnz, stack);
3357            let (new_col_ptr, stack) = stack.make_raw::<I>(n + 1);
3358            let (new_row_ind, mut stack) = stack.make_raw::<I>(A_nnz);
3359
            // The simplicial kernel is fed the upper triangle, the supernodal one the lower.
3360            let out_side = match &self.raw {
3361                SymbolicCholeskyRaw::Simplicial(_) => Side::Upper,
3362                SymbolicCholeskyRaw::Supernodal(_) => Side::Lower,
3363            };
3364
            // NOTE(review): the safety contract of `ghost_permute_hermitian_unsorted` and the
            // meaning of the `false` flag are not visible from here — presumably the buffers
            // sized from `A_nnz` and `n + 1` above are what it requires; confirm.
3365            let A = unsafe {
3366                ghost_permute_hermitian_unsorted(
3367                    new_values.rb_mut(),
3368                    new_col_ptr,
3369                    new_row_ind,
3370                    A,
3371                    perm.cast(),
3372                    side,
3373                    out_side,
3374                    false,
3375                    stack.rb_mut(),
3376                )
3377            };
3378
            // The remaining stack is handed over to the numerical kernel.
3379            match &self.raw {
3380                SymbolicCholeskyRaw::Simplicial(this) => {
3381                    simplicial::factorize_simplicial_numeric_llt(
3382                        E::faer_rb_mut(E::faer_as_mut(&mut L_values)),
3383                        A.into_inner().into_const(),
3384                        regularization,
3385                        this,
3386                        stack,
3387                    )?;
3388                }
3389                SymbolicCholeskyRaw::Supernodal(this) => {
3390                    supernodal::factorize_supernodal_numeric_llt(
3391                        E::faer_rb_mut(E::faer_as_mut(&mut L_values)),
3392                        A.into_inner().into_const(),
3393                        regularization,
3394                        this,
3395                        parallelism,
3396                        stack,
3397                    )?;
3398                }
3399            }
3400
3401            Ok(LltRef::<'out, I, E>::new(
3402                self,
3403                E::faer_into_const(L_values),
3404            ))
3405        })
3406    }
3407
3408    /// Computes a numerical LDLT factorization of A.
    ///
    /// `A` is first copied into workspace taken from `stack`, permuted by the permutation
    /// computed during symbolic analysis. If dynamic regularization signs are provided, they
    /// are gathered through the forward permutation array into a temporary before being handed
    /// to the numerical kernel.
    ///
    /// # Panics
    /// Panics if `A` is not square, or if the length of `L_values` differs from
    /// [`Self::len_values`].
3409    #[inline]
3410    pub fn factorize_numeric_ldlt<'out, E: ComplexField>(
3411        &'out self,
3412        L_values: GroupFor<E, &'out mut [E::Unit]>,
3413        A: SparseColMatRef<'_, I, E>,
3414        side: Side,
3415        regularization: LdltRegularization<'_, E>,
3416        parallelism: Parallelism,
3417        stack: PodStack<'_>,
3418    ) -> LdltRef<'out, I, E> {
3419        assert!(A.nrows() == A.ncols());
3420        let n = A.nrows();
3421        let mut L_values = L_values;
3422
3423        ghost::with_size(n, |N| {
3424            let A_nnz = self.A_nnz;
3425            let A = ghost::SparseColMatRef::new(A, N, N);
3426
            // Room for the permuted signs; empty when no signs were supplied.
3427            let (new_signs, stack) =
3428                stack.make_raw::<i8>(if regularization.dynamic_regularization_signs.is_some() {
3429                    n
3430                } else {
3431                    0
3432                });
3433
3434            let perm = ghost::PermutationRef::new(self.perm(), N);
3435            let fwd = perm.into_arrays().0;
            // new_signs[i] = signs[fwd[i]]
3436            let signs = regularization.dynamic_regularization_signs.map(|signs| {
3437                {
3438                    let new_signs = Array::from_mut(new_signs, N);
3439                    let signs = Array::from_ref(signs, N);
3440                    for i in N.indices() {
3441                        new_signs[i] = signs[fwd[i].zx()];
3442                    }
3443                }
3444                &*new_signs
3445            });
3446            let regularization = LdltRegularization {
3447                dynamic_regularization_signs: signs,
3448                ..regularization
3449            };
3450
            // Buffers for the permuted copy of `A`: values, column pointers, row indices.
3451            let (mut new_values, stack) = crate::make_raw::<E>(A_nnz, stack);
3452            let (new_col_ptr, stack) = stack.make_raw::<I>(n + 1);
3453            let (new_row_ind, mut stack) = stack.make_raw::<I>(A_nnz);
3454
            // The simplicial kernel is fed the upper triangle, the supernodal one the lower.
3455            let out_side = match &self.raw {
3456                SymbolicCholeskyRaw::Simplicial(_) => Side::Upper,
3457                SymbolicCholeskyRaw::Supernodal(_) => Side::Lower,
3458            };
3459
            // NOTE(review): the safety contract of `ghost_permute_hermitian_unsorted` and the
            // meaning of the `false` flag are not visible from here — presumably the buffers
            // sized from `A_nnz` and `n + 1` above are what it requires; confirm.
3460            let A = unsafe {
3461                ghost_permute_hermitian_unsorted(
3462                    new_values.rb_mut(),
3463                    new_col_ptr,
3464                    new_row_ind,
3465                    A,
3466                    perm.cast(),
3467                    side,
3468                    out_side,
3469                    false,
3470                    stack.rb_mut(),
3471                )
3472            };
3473
            // The remaining stack is handed over to the numerical kernel.
3474            match &self.raw {
3475                SymbolicCholeskyRaw::Simplicial(this) => {
3476                    simplicial::factorize_simplicial_numeric_ldlt(
3477                        E::faer_rb_mut(E::faer_as_mut(&mut L_values)),
3478                        A.into_inner().into_const(),
3479                        regularization,
3480                        this,
3481                        stack,
3482                    );
3483                }
3484                SymbolicCholeskyRaw::Supernodal(this) => {
3485                    supernodal::factorize_supernodal_numeric_ldlt(
3486                        E::faer_rb_mut(E::faer_as_mut(&mut L_values)),
3487                        A.into_inner().into_const(),
3488                        regularization,
3489                        this,
3490                        parallelism,
3491                        stack,
3492                    );
3493                }
3494            }
3495
3496            LdltRef::<'out, I, E>::new(self, E::faer_into_const(L_values))
3497        })
3498    }
3499
3500    /// Computes a numerical intranodal Bunch-Kaufman factorization of A.
    ///
    /// `A` is first copied into workspace taken from `stack`, permuted by the permutation
    /// computed during symbolic analysis. For a supernodal structure, the factor, `subdiag`,
    /// `perm_forward` and `perm_inverse` are filled by the supernodal Bunch-Kaufman kernel. For
    /// a simplicial structure, a plain LDLT factorization is computed instead, `perm_forward`
    /// and `perm_inverse` are set to the identity and `subdiag` is not written.
    ///
    /// # Panics
    /// Panics if `A` is not square, or if the length of `L_values` differs from
    /// [`Self::len_values`].
3501    #[inline]
3502    pub fn factorize_numeric_intranode_bunch_kaufman<'out, E: ComplexField>(
3503        &'out self,
3504        L_values: GroupFor<E, &'out mut [E::Unit]>,
3505        subdiag: GroupFor<E, &'out mut [E::Unit]>,
3506        perm_forward: &'out mut [I],
3507        perm_inverse: &'out mut [I],
3508        A: SparseColMatRef<'_, I, E>,
3509        side: Side,
3510        regularization: LdltRegularization<'_, E>,
3511        parallelism: Parallelism,
3512        stack: PodStack<'_>,
3513    ) -> IntranodeBunchKaufmanRef<'out, I, E> {
3514        assert!(A.nrows() == A.ncols());
3515        let n = A.nrows();
3516        let mut L_values = L_values;
3517        let mut subdiag = subdiag;
3518
3519        ghost::with_size(n, move |N| {
3520            let A_nnz = self.A_nnz;
3521            let A = ghost::SparseColMatRef::new(A, N, N);
3522
            // Room for the permuted signs; empty when no signs were supplied.
3523            let (new_signs, stack) =
3524                stack.make_raw::<i8>(if regularization.dynamic_regularization_signs.is_some() {
3525                    n
3526                } else {
3527                    0
3528                });
3529
3530            let static_perm = ghost::PermutationRef::new(self.perm(), N);
            // new_signs[i] = signs[fwd[i]]; kept mutable because the Bunch-Kaufman
            // regularization takes the signs by mutable reference.
3531            let signs = regularization.dynamic_regularization_signs.map(|signs| {
3532                {
3533                    let fwd = static_perm.into_arrays().0;
3534                    let new_signs = Array::from_mut(new_signs, N);
3535                    let signs = Array::from_ref(signs, N);
3536                    for i in N.indices() {
3537                        new_signs[i] = signs[fwd[i].zx()];
3538                    }
3539                }
3540                &mut *new_signs
3541            });
3542
            // Buffers for the permuted copy of `A`: values, column pointers, row indices.
3543            let (mut new_values, stack) = crate::make_raw::<E>(A_nnz, stack);
3544            let (new_col_ptr, stack) = stack.make_raw::<I>(n + 1);
3545            let (new_row_ind, mut stack) = stack.make_raw::<I>(A_nnz);
3546
            // The simplicial kernel is fed the upper triangle, the supernodal one the lower.
3547            let out_side = match &self.raw {
3548                SymbolicCholeskyRaw::Simplicial(_) => Side::Upper,
3549                SymbolicCholeskyRaw::Supernodal(_) => Side::Lower,
3550            };
3551
            // NOTE(review): the safety contract of `ghost_permute_hermitian_unsorted` and the
            // meaning of the `false` flag are not visible from here — presumably the buffers
            // sized from `A_nnz` and `n + 1` above are what it requires; confirm.
3552            let A = unsafe {
3553                ghost_permute_hermitian_unsorted(
3554                    new_values.rb_mut(),
3555                    new_col_ptr,
3556                    new_row_ind,
3557                    A,
3558                    static_perm.cast(),
3559                    side,
3560                    out_side,
3561                    false,
3562                    stack.rb_mut(),
3563                )
3564            };
3565
3566            match &self.raw {
3567                SymbolicCholeskyRaw::Simplicial(this) => {
                    // Simplicial fallback: plain LDLT, no pivoting.
3568                    let regularization = LdltRegularization {
3569                        dynamic_regularization_signs: signs.rb(),
3570                        dynamic_regularization_delta: regularization.dynamic_regularization_delta,
3571                        dynamic_regularization_epsilon: regularization
3572                            .dynamic_regularization_epsilon,
3573                    };
                    // No pivoting takes place, so both permutation arrays are the identity.
3574                    for (i, p) in perm_forward.iter_mut().enumerate() {
3575                        *p = I::truncate(i);
3576                    }
3577                    for (i, p) in perm_inverse.iter_mut().enumerate() {
3578                        *p = I::truncate(i);
3579                    }
3580                    simplicial::factorize_simplicial_numeric_ldlt(
3581                        E::faer_rb_mut(E::faer_as_mut(&mut L_values)),
3582                        A.into_inner().into_const(),
3583                        regularization,
3584                        this,
3585                        stack,
3586                    );
3587                }
3588                SymbolicCholeskyRaw::Supernodal(this) => {
3589                    let regularization = BunchKaufmanRegularization {
3590                        dynamic_regularization_signs: signs,
3591                        dynamic_regularization_delta: regularization.dynamic_regularization_delta,
3592                        dynamic_regularization_epsilon: regularization
3593                            .dynamic_regularization_epsilon,
3594                    };
3595
3596                    supernodal::factorize_supernodal_numeric_intranode_bunch_kaufman(
3597                        E::faer_rb_mut(E::faer_as_mut(&mut L_values)),
3598                        E::faer_rb_mut(E::faer_as_mut(&mut subdiag)),
3599                        perm_forward,
3600                        perm_inverse,
3601                        A.into_inner().into_const(),
3602                        regularization,
3603                        this,
3604                        parallelism,
3605                        stack,
3606                    );
3607                }
3608            }
3609
            // NOTE(review): `new_unchecked` skips validation; relies on the branches above
            // having left a valid forward/inverse pair in the two arrays — confirm for the
            // supernodal kernel.
3610            IntranodeBunchKaufmanRef::<'out, I, E>::new(
3611                self,
3612                E::faer_into_const(L_values),
3613                E::faer_into_const(subdiag),
3614                unsafe { PermutationRef::<'out, I, E>::new_unchecked(perm_forward, perm_inverse) },
3615            )
3616        })
3617    }
3618
3619    /// Computes the required workspace size and alignment for a dense solve in place using an LLT,
3620    /// LDLT or intranodal Bunch-Kaufman factorization.
3621    pub fn solve_in_place_req<E: Entity>(
3622        &self,
3623        rhs_ncols: usize,
3624    ) -> Result<StackReq, SizeOverflow> {
3625        temp_mat_req::<E>(self.nrows(), rhs_ncols)?.try_and(match self.raw() {
3626            SymbolicCholeskyRaw::Simplicial(this) => this.solve_in_place_req::<E>(rhs_ncols)?,
3627            SymbolicCholeskyRaw::Supernodal(this) => this.solve_in_place_req::<E>(rhs_ncols)?,
3628        })
3629    }
3630}
3631
3632/// Sparse LLT factorization wrapper.
///
/// Pairs a reference to the symbolic structure with the numerical values of the factor.
3633#[derive(Debug)]
3634pub struct LltRef<'a, I: Index, E: Entity> {
    // Symbolic structure, including the permutation from symbolic analysis.
3635    symbolic: &'a SymbolicCholesky<I>,
    // Numerical values of the factor.
3636    values: SliceGroup<'a, E>,
3637}
3638
3639/// Sparse LDLT factorization wrapper.
///
/// Pairs a reference to the symbolic structure with the numerical values of the factor.
3640#[derive(Debug)]
3641pub struct LdltRef<'a, I: Index, E: Entity> {
    // Symbolic structure, including the permutation from symbolic analysis.
3642    symbolic: &'a SymbolicCholesky<I>,
    // Numerical values of the factor.
3643    values: SliceGroup<'a, E>,
3644}
3645
3646/// Sparse intranodal Bunch-Kaufman factorization wrapper.
///
/// Pairs a reference to the symbolic structure with the numerical values of the factor, the
/// subdiagonal values and the permutation produced by the numerical factorization.
3647#[derive(Debug)]
3648pub struct IntranodeBunchKaufmanRef<'a, I: Index, E: Entity> {
    // Symbolic structure, including the permutation from symbolic analysis.
3649    symbolic: &'a SymbolicCholesky<I>,
    // Numerical values of the factor.
3650    values: SliceGroup<'a, E>,
    // Subdiagonal values; only forwarded to the supernodal solver.
3651    subdiag: SliceGroup<'a, E>,
    // Permutation from the numerical factorization, composed with the symbolic one when
    // solving with a supernodal structure.
3652    perm: PermutationRef<'a, I, E>,
3653}
3654
3655impl<'a, I: Index, E: Entity> core::ops::Deref for LltRef<'a, I, E> {
3656    type Target = SymbolicCholesky<I>;
3657    #[inline]
3658    fn deref(&self) -> &Self::Target {
3659        &self.symbolic
3660    }
3661}
3662impl<'a, I: Index, E: Entity> core::ops::Deref for LdltRef<'a, I, E> {
3663    type Target = SymbolicCholesky<I>;
3664    #[inline]
3665    fn deref(&self) -> &Self::Target {
3666        &self.symbolic
3667    }
3668}
3669impl<'a, I: Index, E: Entity> core::ops::Deref for IntranodeBunchKaufmanRef<'a, I, E> {
3670    type Target = SymbolicCholesky<I>;
3671    #[inline]
3672    fn deref(&self) -> &Self::Target {
3673        &self.symbolic
3674    }
3675}
3676
// NOTE(review): `impl_copy!` is defined elsewhere in the crate; presumably it implements
// `Clone` and `Copy` for the listed reference types without requiring those traits on their
// generic parameters — confirm against the macro definition.

// Supernode views.
3677impl_copy!(<'a><I><supernodal::SymbolicSupernodeRef<'a, I>>);
3678impl_copy!(<'a><I, E: Entity><supernodal::SupernodeRef<'a, I, E>>);
3679
// Simplicial factorization views.
3680impl_copy!(<'a><I:Index, E: Entity><simplicial::SimplicialLdltRef<'a, I, E>>);
3681impl_copy!(<'a><I:Index, E: Entity><simplicial::SimplicialLltRef<'a, I, E>>);
3682
// Supernodal factorization views.
3683impl_copy!(<'a><I, E: Entity><supernodal::SupernodalLltRef<'a, I, E>>);
3684impl_copy!(<'a><I, E: Entity><supernodal::SupernodalLdltRef<'a, I, E>>);
3685impl_copy!(<'a><I, E: Entity><supernodal::SupernodalIntranodeBunchKaufmanRef<'a, I, E>>);
3686
// Top-level factorization wrappers.
3687impl_copy!(<'a><I: Index, E: Entity><IntranodeBunchKaufmanRef<'a, I, E>>);
3688impl_copy!(<'a><I: Index, E: Entity><LdltRef<'a, I, E>>);
3689impl_copy!(<'a><I: Index, E: Entity><LltRef<'a, I, E>>);
3690
3691impl<'a, I: Index, E: Entity> IntranodeBunchKaufmanRef<'a, I, E> {
3692    #[inline]
3693    pub fn new(
3694        symbolic: &'a SymbolicCholesky<I>,
3695        values: GroupFor<E, &'a [E::Unit]>,
3696        subdiag: GroupFor<E, &'a [E::Unit]>,
3697        perm: PermutationRef<'a, I, E>,
3698    ) -> Self {
3699        let values = SliceGroup::<'_, E>::new(values);
3700        let subdiag = SliceGroup::<'_, E>::new(subdiag);
3701        assert!(symbolic.len_values() == values.len());
3702        Self {
3703            symbolic,
3704            values,
3705            subdiag,
3706            perm,
3707        }
3708    }
3709
3710    #[inline]
3711    pub fn symbolic(self) -> &'a SymbolicCholesky<I> {
3712        self.symbolic
3713    }
3714
3715    pub fn solve_in_place_with_conj(
3716        &self,
3717        conj: Conj,
3718        rhs: MatMut<'_, E>,
3719        parallelism: Parallelism,
3720        stack: PodStack<'_>,
3721    ) where
3722        E: ComplexField,
3723    {
3724        let k = rhs.ncols();
3725        let n = self.symbolic.nrows();
3726
3727        let mut rhs = rhs;
3728
3729        let (mut x, stack) = temp_mat_uninit::<E>(n, k, stack);
3730        let (fwd, inv) = self.symbolic.perm().into_arrays();
3731
3732        match self.symbolic.raw() {
3733            SymbolicCholeskyRaw::Simplicial(symbolic) => {
3734                let this = simplicial::SimplicialLdltRef::new(symbolic, self.values.into_inner());
3735
3736                for j in 0..k {
3737                    for (i, fwd) in fwd.iter().enumerate() {
3738                        x.write(i, j, rhs.read(fwd.zx().zx(), j));
3739                    }
3740                }
3741                this.solve_in_place_with_conj(conj, x.rb_mut(), parallelism, stack);
3742                for j in 0..k {
3743                    for (i, inv) in inv.iter().enumerate() {
3744                        rhs.write(i, j, x.read(inv.zx().zx(), j));
3745                    }
3746                }
3747            }
3748            SymbolicCholeskyRaw::Supernodal(symbolic) => {
3749                let (dyn_fwd, dyn_inv) = self.perm.into_arrays();
3750                for j in 0..k {
3751                    for (i, dyn_fwd) in dyn_fwd.iter().enumerate() {
3752                        x.write(i, j, rhs.read(fwd[dyn_fwd.zx()].zx(), j));
3753                    }
3754                }
3755
3756                let this = supernodal::SupernodalIntranodeBunchKaufmanRef::new(
3757                    symbolic,
3758                    self.values.into_inner(),
3759                    self.subdiag.into_inner(),
3760                    self.perm,
3761                );
3762                this.solve_in_place_no_numeric_permute_with_conj(
3763                    conj,
3764                    x.rb_mut(),
3765                    parallelism,
3766                    stack,
3767                );
3768
3769                for j in 0..k {
3770                    for (i, inv) in inv.iter().enumerate() {
3771                        rhs.write(i, j, x.read(dyn_inv[inv.zx()].zx(), j));
3772                    }
3773                }
3774            }
3775        }
3776    }
3777}
3778
3779impl<'a, I: Index, E: Entity> LltRef<'a, I, E> {
3780    #[inline]
3781    pub fn new(symbolic: &'a SymbolicCholesky<I>, values: GroupFor<E, &'a [E::Unit]>) -> Self {
3782        let values = SliceGroup::<'_, E>::new(values);
3783        assert!(symbolic.len_values() == values.len());
3784        Self { symbolic, values }
3785    }
3786
3787    #[inline]
3788    pub fn symbolic(self) -> &'a SymbolicCholesky<I> {
3789        self.symbolic
3790    }
3791
3792    pub fn solve_in_place_with_conj(
3793        &self,
3794        conj: Conj,
3795        rhs: MatMut<'_, E>,
3796        parallelism: Parallelism,
3797        stack: PodStack<'_>,
3798    ) where
3799        E: ComplexField,
3800    {
3801        let k = rhs.ncols();
3802        let n = self.symbolic.nrows();
3803
3804        let mut rhs = rhs;
3805
3806        let (mut x, stack) = temp_mat_uninit::<E>(n, k, stack);
3807
3808        let (fwd, inv) = self.symbolic.perm().into_arrays();
3809        for j in 0..k {
3810            for (i, fwd) in fwd.iter().enumerate() {
3811                x.write(i, j, rhs.read(fwd.zx(), j));
3812            }
3813        }
3814
3815        match self.symbolic.raw() {
3816            SymbolicCholeskyRaw::Simplicial(symbolic) => {
3817                let this = simplicial::SimplicialLltRef::new(symbolic, self.values.into_inner());
3818                this.solve_in_place_with_conj(conj, x.rb_mut(), parallelism, stack);
3819            }
3820            SymbolicCholeskyRaw::Supernodal(symbolic) => {
3821                let this = supernodal::SupernodalLltRef::new(symbolic, self.values.into_inner());
3822                this.solve_in_place_with_conj(conj, x.rb_mut(), parallelism, stack);
3823            }
3824        }
3825
3826        for j in 0..k {
3827            for (i, inv) in inv.iter().enumerate() {
3828                rhs.write(i, j, x.read(inv.zx(), j));
3829            }
3830        }
3831    }
3832}
3833
3834impl<'a, I: Index, E: Entity> LdltRef<'a, I, E> {
3835    #[inline]
3836    pub fn new(symbolic: &'a SymbolicCholesky<I>, values: GroupFor<E, &'a [E::Unit]>) -> Self {
3837        let values = SliceGroup::<'_, E>::new(values);
3838        assert!(symbolic.len_values() == values.len());
3839        Self { symbolic, values }
3840    }
3841
3842    #[inline]
3843    pub fn symbolic(self) -> &'a SymbolicCholesky<I> {
3844        self.symbolic
3845    }
3846
3847    pub fn solve_in_place_with_conj(
3848        &self,
3849        conj: Conj,
3850        rhs: MatMut<'_, E>,
3851        parallelism: Parallelism,
3852        stack: PodStack<'_>,
3853    ) where
3854        E: ComplexField,
3855    {
3856        let k = rhs.ncols();
3857        let n = self.symbolic.nrows();
3858
3859        let mut rhs = rhs;
3860
3861        let (mut x, stack) = temp_mat_uninit::<E>(n, k, stack);
3862
3863        let (fwd, inv) = self.symbolic.perm().into_arrays();
3864        for j in 0..k {
3865            for (i, fwd) in fwd.iter().enumerate() {
3866                x.write(i, j, rhs.read(fwd.zx(), j));
3867            }
3868        }
3869
3870        match self.symbolic.raw() {
3871            SymbolicCholeskyRaw::Simplicial(symbolic) => {
3872                let this = simplicial::SimplicialLdltRef::new(symbolic, self.values.into_inner());
3873                this.solve_in_place_with_conj(conj, x.rb_mut(), parallelism, stack);
3874            }
3875            SymbolicCholeskyRaw::Supernodal(symbolic) => {
3876                let this = supernodal::SupernodalLdltRef::new(symbolic, self.values.into_inner());
3877                this.solve_in_place_with_conj(conj, x.rb_mut(), parallelism, stack);
3878            }
3879        }
3880
3881        for j in 0..k {
3882            for (i, inv) in inv.iter().enumerate() {
3883                rhs.write(i, j, x.read(inv.zx(), j));
3884            }
3885        }
3886    }
3887}
3888
3889fn postorder_depth_first_search<'n, I: Index>(
3890    post: &mut Array<'n, I>,
3891    root: usize,
3892    mut start_index: usize,
3893    stack: &mut Array<'n, I>,
3894    first_child: &mut Array<'n, MaybeIdx<'n, I>>,
3895    next_child: &Array<'n, I::Signed>,
3896) -> usize {
3897    let mut top = 1usize;
3898    let N = post.len();
3899
3900    stack[N.check(0)] = I::truncate(root);
3901    while top != 0 {
3902        let current_node = stack[N.check(top - 1)].zx();
3903        let first_child = &mut first_child[N.check(current_node)];
3904        let current_child = first_child.sx();
3905
3906        if let Some(current_child) = current_child.idx() {
3907            stack[N.check(top)] = *current_child.truncate::<I>();
3908            top += 1;
3909            *first_child = MaybeIdx::new_checked(next_child[current_child], N);
3910        } else {
3911            post[N.check(start_index)] = I::truncate(current_node);
3912            start_index += 1;
3913            top -= 1;
3914        }
3915    }
3916    start_index
3917}
3918
3919pub(crate) fn ghost_postorder<'n, I: Index>(
3920    post: &mut Array<'n, I>,
3921    etree: &Array<'n, MaybeIdx<'n, I>>,
3922    stack: PodStack<'_>,
3923) {
3924    let N = post.len();
3925    let n = *N;
3926
3927    if n == 0 {
3928        return;
3929    }
3930
3931    let (stack_, stack) = stack.make_raw::<I>(n);
3932    let (first_child, stack) = stack.make_raw::<I::Signed>(n);
3933    let (next_child, _) = stack.make_raw::<I::Signed>(n);
3934
3935    let stack = Array::from_mut(stack_, N);
3936    let next_child = Array::from_mut(next_child, N);
3937    let first_child = Array::from_mut(ghost::fill_none::<I>(first_child, N), N);
3938
3939    for j in N.indices().rev() {
3940        let parent = etree[j];
3941        if let Some(parent) = parent.idx() {
3942            let first = &mut first_child[parent.zx()];
3943            next_child[j] = **first;
3944            *first = MaybeIdx::from_index(j.truncate::<I>());
3945        }
3946    }
3947
3948    let mut start_index = 0usize;
3949    for (root, &parent) in etree.as_ref().iter().enumerate() {
3950        if parent.idx().is_none() {
3951            start_index = postorder_depth_first_search(
3952                post,
3953                root,
3954                start_index,
3955                stack,
3956                first_child,
3957                next_child,
3958            );
3959        }
3960    }
3961}
3962
3963#[derive(Copy, Clone, Debug, Default)]
3964pub struct CholeskySymbolicParams<'a> {
3965    pub amd_params: Control,
3966    pub supernodal_flop_ratio_threshold: SupernodalThreshold,
3967    pub supernodal_params: SymbolicSupernodalParams<'a>,
3968}
3969
3970/// Computes the symbolic Cholesky factorization of the matrix `A`, or returns an error if the
3971/// operation could not be completed.
3972pub fn factorize_symbolic_cholesky<I: Index>(
3973    A: SymbolicSparseColMatRef<'_, I>,
3974    side: Side,
3975    params: CholeskySymbolicParams<'_>,
3976) -> Result<SymbolicCholesky<I>, FaerError> {
3977    let n = A.nrows();
3978    let A_nnz = A.compute_nnz();
3979
3980    assert!(A.nrows() == A.ncols());
3981    ghost::with_size(n, |N| {
3982        let A = ghost::SymbolicSparseColMatRef::new(A, N, N);
3983
3984        let req = || -> Result<StackReq, SizeOverflow> {
3985            let n_req = StackReq::try_new::<I>(n)?;
3986            let A_req = StackReq::try_and(
3987                // new_col_ptr
3988                StackReq::try_new::<I>(n + 1)?,
3989                // new_row_ind
3990                StackReq::try_new::<I>(A_nnz)?,
3991            )?;
3992
3993            StackReq::try_or(
3994                amd::order_maybe_unsorted_req::<I>(n, A_nnz)?,
3995                StackReq::try_all_of([
3996                    A_req,
3997                    // permute_symmetric | etree
3998                    n_req,
3999                    // col_counts
4000                    n_req,
4001                    // ghost_prefactorize_symbolic
4002                    n_req,
4003                    // ghost_factorize_*_symbolic
4004                    StackReq::try_or(
4005                        supernodal::factorize_supernodal_symbolic_cholesky_req::<I>(n)?,
4006                        simplicial::factorize_simplicial_symbolic_req::<I>(n)?,
4007                    )?,
4008                ])?,
4009            )
4010        };
4011
4012        let req = req().map_err(nomem)?;
4013        let mut mem = dyn_stack::GlobalPodBuffer::try_new(req).map_err(nomem)?;
4014        let mut stack = PodStack::new(&mut mem);
4015
4016        let mut perm_fwd = try_zeroed(n)?;
4017        let mut perm_inv = try_zeroed(n)?;
4018        let flops = amd::order_maybe_unsorted(
4019            &mut perm_fwd,
4020            &mut perm_inv,
4021            A.into_inner(),
4022            params.amd_params,
4023            stack.rb_mut(),
4024        )?;
4025        let flops = flops.n_div + flops.n_mult_subs_ldl;
4026        let perm_ =
4027            ghost::PermutationRef::new(PermutationRef::new_checked(&perm_fwd, &perm_inv), N);
4028
4029        let (new_col_ptr, stack) = stack.make_raw::<I>(n + 1);
4030        let (new_row_ind, mut stack) = stack.make_raw::<I>(A_nnz);
4031        let A = unsafe {
4032            ghost_permute_hermitian_unsorted_symbolic(
4033                new_col_ptr,
4034                new_row_ind,
4035                A,
4036                perm_,
4037                side,
4038                Side::Upper,
4039                stack.rb_mut(),
4040            )
4041        };
4042
4043        let (etree, stack) = stack.make_raw::<I::Signed>(n);
4044        let (col_counts, mut stack) = stack.make_raw::<I>(n);
4045        let etree = Array::from_mut(etree, N);
4046        let col_counts = Array::from_mut(col_counts, N);
4047        let etree =
4048            &*ghost_prefactorize_symbolic_cholesky::<I>(etree, col_counts, A, stack.rb_mut());
4049        let L_nnz = I::sum_nonnegative(col_counts.as_ref()).ok_or(FaerError::IndexOverflow)?;
4050
4051        let raw = if (flops / L_nnz.zx() as f64)
4052            > params.supernodal_flop_ratio_threshold.0 * crate::CHOLESKY_SUPERNODAL_RATIO_FACTOR
4053        {
4054            SymbolicCholeskyRaw::Supernodal(supernodal::ghost_factorize_supernodal_symbolic(
4055                A,
4056                None,
4057                None,
4058                supernodal::CholeskyInput::A,
4059                etree,
4060                col_counts,
4061                stack.rb_mut(),
4062                params.supernodal_params,
4063            )?)
4064        } else {
4065            SymbolicCholeskyRaw::Simplicial(
4066                simplicial::ghost_factorize_simplicial_symbolic_cholesky(
4067                    A,
4068                    etree,
4069                    col_counts,
4070                    stack.rb_mut(),
4071                )?,
4072            )
4073        };
4074
4075        Ok(SymbolicCholesky {
4076            raw,
4077            perm_fwd,
4078            perm_inv,
4079            A_nnz,
4080        })
4081    })
4082}
4083
4084#[cfg(test)]
4085pub(crate) mod tests {
4086    use super::{supernodal::SupernodalLdltRef, *};
4087    use crate::{
4088        cholesky::supernodal::{CholeskyInput, SupernodalIntranodeBunchKaufmanRef},
4089        qd::Double,
4090    };
4091    use dyn_stack::GlobalPodBuffer;
4092    use faer_core::{assert, Mat};
4093    use num_complex::Complex;
4094    use rand::{Rng, SeedableRng};
4095
4096    fn test_counts<I: Index>() {
4097        let truncate = I::truncate;
4098
4099        let n = 11;
4100        let col_ptr = &[0, 3, 6, 10, 13, 16, 21, 24, 29, 31, 37, 43].map(truncate);
4101        let row_ind = &[
4102            0, 5, 6, // 0
4103            1, 2, 7, // 1
4104            1, 2, 9, 10, // 2
4105            3, 5, 9, // 3
4106            4, 7, 10, // 4
4107            0, 3, 5, 8, 9, // 5
4108            0, 6, 10, // 6
4109            1, 4, 7, 9, 10, // 7
4110            5, 8, // 8
4111            2, 3, 5, 7, 9, 10, // 9
4112            2, 4, 6, 7, 9, 10, // 10
4113        ]
4114        .map(truncate);
4115
4116        let A = SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind);
4117        let zero = truncate(0);
4118        let mut etree = vec![zero.to_signed(); n];
4119        let mut col_count = vec![zero; n];
4120        ghost::with_size(n, |N| {
4121            let A = ghost::SymbolicSparseColMatRef::new(A, N, N);
4122            let etree = ghost_prefactorize_symbolic_cholesky(
4123                Array::from_mut(&mut etree, N),
4124                Array::from_mut(&mut col_count, N),
4125                A,
4126                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(5 * n))),
4127            );
4128
4129            supernodal::ghost_factorize_supernodal_symbolic(
4130                A,
4131                None,
4132                None,
4133                CholeskyInput::A,
4134                etree,
4135                Array::from_ref(&col_count, N),
4136                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4137                Default::default(),
4138            )
4139            .unwrap();
4140        });
4141        assert_eq!(
4142            etree,
4143            [5, 2, 7, 5, 7, 6, 8, 9, 9, 10, NONE].map(I::Signed::truncate)
4144        );
4145        assert_eq!(col_count, [3, 3, 4, 3, 3, 4, 4, 3, 3, 2, 1].map(truncate));
4146    }
4147
4148    include!("../data.rs");
4149
4150    fn test_amd<I: Index>() {
4151        for &(_, (_, col_ptr, row_ind, _)) in ALL {
4152            let I = I::truncate;
4153            let n = col_ptr.len() - 1;
4154
4155            let (amd_perm, amd_perm_inv, _) =
4156                ::amd::order(n, col_ptr, row_ind, &Default::default()).unwrap();
4157            let col_ptr = &*col_ptr.iter().copied().map(I).collect::<Vec<_>>();
4158            let row_ind = &*row_ind.iter().copied().map(I).collect::<Vec<_>>();
4159            let amd_perm = &*amd_perm.iter().copied().map(I).collect::<Vec<_>>();
4160            let amd_perm_inv = &*amd_perm_inv.iter().copied().map(I).collect::<Vec<_>>();
4161            let A = SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind);
4162
4163            let perm = &mut vec![I(0); n];
4164            let perm_inv = &mut vec![I(0); n];
4165
4166            crate::amd::order_maybe_unsorted(
4167                perm,
4168                perm_inv,
4169                A,
4170                Default::default(),
4171                PodStack::new(&mut GlobalPodBuffer::new(
4172                    crate::amd::order_maybe_unsorted_req::<I>(n, row_ind.len()).unwrap(),
4173                )),
4174            )
4175            .unwrap();
4176
4177            assert!(perm == amd_perm);
4178            assert!(perm_inv == amd_perm_inv);
4179        }
4180    }
4181
4182    fn sparse_to_dense<I: Index, E: ComplexField>(sparse: SparseColMatRef<'_, I, E>) -> Mat<E> {
4183        let m = sparse.nrows();
4184        let n = sparse.ncols();
4185
4186        let mut dense = Mat::<E>::zeros(m, n);
4187        let slice_group = SliceGroup::<'_, E>::new;
4188
4189        for j in 0..n {
4190            for (i, val) in zip(
4191                sparse.row_indices_of_col(j),
4192                slice_group(sparse.values_of_col(j)).into_ref_iter(),
4193            ) {
4194                dense.write(i, j, val.read());
4195            }
4196        }
4197
4198        dense
4199    }
4200
4201    fn reconstruct_from_supernodal_llt<I: Index, E: ComplexField>(
4202        symbolic: &supernodal::SymbolicSupernodalCholesky<I>,
4203        L_values: GroupFor<E, &[E::Unit]>,
4204    ) -> Mat<E> {
4205        let L_values = SliceGroup::<'_, E>::new(L_values);
4206        let ldlt = SupernodalLdltRef::new(symbolic, L_values.into_inner());
4207        let n_supernodes = ldlt.symbolic().n_supernodes();
4208        let n = ldlt.symbolic().nrows();
4209
4210        let mut dense = Mat::<E>::zeros(n, n);
4211
4212        for s in 0..n_supernodes {
4213            let s = ldlt.supernode(s);
4214            let size = s.matrix().ncols();
4215
4216            let (Ls_top, Ls_bot) = s.matrix().split_at_row(size);
4217            dense
4218                .as_mut()
4219                .submatrix_mut(s.start(), s.start(), size, size)
4220                .copy_from(Ls_top);
4221
4222            for col in 0..size {
4223                for (i, row) in s.pattern().iter().enumerate() {
4224                    dense.write(row.zx(), s.start() + col, Ls_bot.read(i, col));
4225                }
4226            }
4227        }
4228
4229        &dense * dense.adjoint()
4230    }
4231
4232    fn reconstruct_from_supernodal_ldlt<I: Index, E: ComplexField>(
4233        symbolic: &supernodal::SymbolicSupernodalCholesky<I>,
4234        L_values: GroupFor<E, &[E::Unit]>,
4235    ) -> Mat<E> {
4236        let L_values = SliceGroup::<'_, E>::new(L_values);
4237        let ldlt = SupernodalLdltRef::new(symbolic, L_values.into_inner());
4238        let n_supernodes = ldlt.symbolic().n_supernodes();
4239        let n = ldlt.symbolic().nrows();
4240
4241        let mut dense = Mat::<E>::zeros(n, n);
4242
4243        for s in 0..n_supernodes {
4244            let s = ldlt.supernode(s);
4245            let size = s.matrix().ncols();
4246
4247            let (Ls_top, Ls_bot) = s.matrix().split_at_row(size);
4248            dense
4249                .as_mut()
4250                .submatrix_mut(s.start(), s.start(), size, size)
4251                .copy_from(Ls_top);
4252
4253            for col in 0..size {
4254                for (i, row) in s.pattern().iter().enumerate() {
4255                    dense.write(row.zx(), s.start() + col, Ls_bot.read(i, col));
4256                }
4257            }
4258        }
4259
4260        let mut D = Mat::<E>::zeros(n, n);
4261        zipped!(
4262            D.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
4263            dense.as_ref().diagonal().column_vector().as_2d()
4264        )
4265        .for_each(|unzipped!(mut dst, src)| dst.write(src.read().faer_inv()));
4266        dense
4267            .as_mut()
4268            .diagonal_mut()
4269            .column_vector_mut()
4270            .fill(E::faer_one());
4271        &dense * D * dense.adjoint()
4272    }
4273
4274    fn reconstruct_from_simplicial_llt<'a, I: Index, E: ComplexField>(
4275        symbolic: &'a simplicial::SymbolicSimplicialCholesky<I>,
4276        L_values: GroupFor<E, &'a [E::Unit]>,
4277    ) -> Mat<E> {
4278        let slice_group = SliceGroup::<'_, E>::new;
4279        let L_values = slice_group(L_values);
4280        let n = symbolic.nrows();
4281        let mut dense = Mat::<E>::zeros(n, n);
4282
4283        let L = SparseColMatRef::<'_, I, E>::new(
4284            SymbolicSparseColMatRef::new_unsorted_checked(
4285                n,
4286                n,
4287                symbolic.col_ptrs(),
4288                None,
4289                symbolic.row_indices(),
4290            ),
4291            L_values.into_inner(),
4292        );
4293
4294        for j in 0..n {
4295            for (i, val) in zip(
4296                L.row_indices_of_col(j),
4297                slice_group(L.values_of_col(j)).into_ref_iter(),
4298            ) {
4299                dense.write(i, j, val.read());
4300            }
4301        }
4302
4303        &dense * dense.adjoint()
4304    }
4305
4306    fn reconstruct_from_simplicial_ldlt<'a, I: Index, E: ComplexField>(
4307        symbolic: &'a simplicial::SymbolicSimplicialCholesky<I>,
4308        L_values: GroupFor<E, &'a [E::Unit]>,
4309    ) -> Mat<E> {
4310        let slice_group = SliceGroup::<'_, E>::new;
4311        let L_values = slice_group(L_values);
4312        let n = symbolic.nrows();
4313        let mut dense = Mat::<E>::zeros(n, n);
4314
4315        let L = SparseColMatRef::<'_, I, E>::new(
4316            SymbolicSparseColMatRef::new_unsorted_checked(
4317                n,
4318                n,
4319                symbolic.col_ptrs(),
4320                None,
4321                symbolic.row_indices(),
4322            ),
4323            L_values.into_inner(),
4324        );
4325
4326        for j in 0..n {
4327            for (i, val) in zip(
4328                L.row_indices_of_col(j),
4329                slice_group(L.values_of_col(j)).into_ref_iter(),
4330            ) {
4331                dense.write(i, j, val.read());
4332            }
4333        }
4334
4335        let mut D = Mat::<E>::zeros(n, n);
4336        D.as_mut()
4337            .diagonal_mut()
4338            .column_vector_mut()
4339            .copy_from(dense.as_ref().diagonal().column_vector());
4340        dense
4341            .as_mut()
4342            .diagonal_mut()
4343            .column_vector_mut()
4344            .fill(E::faer_one());
4345
4346        &dense * D * dense.adjoint()
4347    }
4348
4349    fn test_supernodal<I: Index>() {
4350        type E = Complex<Double<f64>>;
4351        let truncate = I::truncate;
4352
4353        let (_, col_ptr, row_ind, values) = MEDIUM;
4354
4355        let mut gen = rand::rngs::StdRng::seed_from_u64(0);
4356        let mut complexify = |e: E| {
4357            let i = E::faer_one().faer_neg().faer_sqrt();
4358            if e == E::faer_from_f64(1.0) {
4359                e.faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
4360            } else {
4361                e
4362            }
4363        };
4364
4365        let n = col_ptr.len() - 1;
4366        let nnz = values.len();
4367        let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
4368        let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
4369        let values_mat =
4370            faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i])));
4371        let values = values_mat.col_as_slice(0);
4372
4373        let A = SparseColMatRef::<'_, I, E>::new(
4374            SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
4375            values,
4376        );
4377        let zero = truncate(0);
4378        let mut etree = vec![zero.to_signed(); n];
4379        let mut col_count = vec![zero; n];
4380        ghost::with_size(n, |N| {
4381            let A = ghost::SparseColMatRef::new(A, N, N);
4382            let etree = ghost_prefactorize_symbolic_cholesky(
4383                Array::from_mut(&mut etree, N),
4384                Array::from_mut(&mut col_count, N),
4385                *A,
4386                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(5 * n))),
4387            );
4388
4389            let symbolic = supernodal::ghost_factorize_supernodal_symbolic(
4390                *A,
4391                None,
4392                None,
4393                CholeskyInput::A,
4394                etree,
4395                Array::from_ref(&col_count, N),
4396                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4397                Default::default(),
4398            )
4399            .unwrap();
4400
4401            let mut A_lower_col_ptr = col_ptr.to_vec();
4402            let mut A_lower_values = values_mat.clone();
4403            let mut A_lower_row_ind = row_ind.to_vec();
4404            let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0));
4405            let A_lower = faer_core::sparse::util::ghost_adjoint(
4406                &mut A_lower_col_ptr,
4407                &mut A_lower_row_ind,
4408                A_lower_values,
4409                A,
4410                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4411            );
4412            let mut values = faer_core::Mat::<E>::zeros(symbolic.len_values(), 1);
4413
4414            supernodal::factorize_supernodal_numeric_ldlt(
4415                values.col_as_slice_mut(0),
4416                A_lower.into_inner().into_const(),
4417                Default::default(),
4418                &symbolic,
4419                Parallelism::None,
4420                PodStack::new(&mut GlobalPodBuffer::new(
4421                    supernodal::factorize_supernodal_numeric_ldlt_req::<I, E>(
4422                        &symbolic,
4423                        Parallelism::None,
4424                    )
4425                    .unwrap(),
4426                )),
4427            );
4428            let mut A = sparse_to_dense(A.into_inner());
4429            for j in 0..n {
4430                for i in j + 1..n {
4431                    A.write(i, j, A.read(j, i).faer_conj());
4432                }
4433            }
4434
4435            let err =
4436                reconstruct_from_supernodal_ldlt::<I, E>(&symbolic, values.col_as_slice(0)) - A;
4437            let mut max = <E as ComplexField>::Real::faer_zero();
4438            for j in 0..n {
4439                for i in 0..n {
4440                    let x = err.read(i, j).faer_abs();
4441                    max = if max > x { max } else { x }
4442                }
4443            }
4444            assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
4445        });
4446    }
4447
4448    fn test_supernodal_ldlt<I: Index>() {
4449        type E = Complex<Double<f64>>;
4450        let truncate = I::truncate;
4451
4452        let (_, col_ptr, row_ind, values) = MEDIUM;
4453
4454        let mut gen = rand::rngs::StdRng::seed_from_u64(0);
4455        let i = E::faer_one().faer_neg().faer_sqrt();
4456        let mut complexify = |e: E| {
4457            if e == E::faer_from_f64(1.0) {
4458                e.faer_add(i.faer_mul(E::faer_from_f64(2000.0 * gen.gen::<f64>())))
4459                    .faer_add(E::faer_from_f64(2000.0 * gen.gen::<f64>()))
4460            } else {
4461                e.faer_add(E::faer_from_f64(100.0 * gen.gen::<f64>()))
4462            }
4463        };
4464
4465        let n = col_ptr.len() - 1;
4466        let nnz = values.len();
4467        let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
4468        let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
4469        let values_mat =
4470            faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i])));
4471        let values = values_mat.col_as_slice(0);
4472
4473        let A = SparseColMatRef::<'_, I, E>::new(
4474            SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
4475            values,
4476        );
4477        let mut A_dense = sparse_to_dense(A);
4478        for j in 0..n {
4479            for i in j + 1..n {
4480                A_dense.write(i, j, A_dense.read(j, i).faer_conj());
4481            }
4482        }
4483
4484        let zero = truncate(0);
4485        let mut etree = vec![zero.to_signed(); n];
4486        let mut col_count = vec![zero; n];
4487        ghost::with_size(n, |N| {
4488            let A = ghost::SparseColMatRef::new(A, N, N);
4489            let etree = ghost_prefactorize_symbolic_cholesky(
4490                Array::from_mut(&mut etree, N),
4491                Array::from_mut(&mut col_count, N),
4492                *A,
4493                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(5 * n))),
4494            );
4495
4496            let symbolic = supernodal::ghost_factorize_supernodal_symbolic(
4497                *A,
4498                None,
4499                None,
4500                CholeskyInput::A,
4501                etree,
4502                Array::from_ref(&col_count, N),
4503                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4504                Default::default(),
4505            )
4506            .unwrap();
4507
4508            let mut A_lower_col_ptr = col_ptr.to_vec();
4509            let mut A_lower_values = values_mat.clone();
4510            let mut A_lower_row_ind = row_ind.to_vec();
4511            let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0));
4512            let A_lower = faer_core::sparse::util::ghost_adjoint(
4513                &mut A_lower_col_ptr,
4514                &mut A_lower_row_ind,
4515                A_lower_values,
4516                A,
4517                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4518            );
4519            let mut values = faer_core::Mat::<E>::zeros(symbolic.len_values(), 1);
4520
4521            supernodal::factorize_supernodal_numeric_ldlt(
4522                values.col_as_slice_mut(0),
4523                A_lower.into_inner().into_const(),
4524                Default::default(),
4525                &symbolic,
4526                Parallelism::None,
4527                PodStack::new(&mut GlobalPodBuffer::new(
4528                    supernodal::factorize_supernodal_numeric_ldlt_req::<I, E>(
4529                        &symbolic,
4530                        Parallelism::None,
4531                    )
4532                    .unwrap(),
4533                )),
4534            );
4535            let k = 2;
4536
4537            let rhs = Mat::<E>::from_fn(n, k, |_, _| {
4538                E::faer_from_f64(gen.gen()).faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
4539            });
4540            for conj in [Conj::Yes, Conj::No] {
4541                let mut x = rhs.clone();
4542                let ldlt = SupernodalLdltRef::new(&symbolic, values.col_as_slice(0));
4543                ldlt.solve_in_place_with_conj(
4544                    conj,
4545                    x.as_mut(),
4546                    Parallelism::None,
4547                    PodStack::new(&mut GlobalPodBuffer::new(
4548                        symbolic.solve_in_place_req::<E>(k).unwrap(),
4549                    )),
4550                );
4551
4552                let rhs_reconstructed = if conj == Conj::No {
4553                    &A_dense * &x
4554                } else {
4555                    A_dense.conjugate() * &x
4556                };
4557                let mut max = <E as ComplexField>::Real::faer_zero();
4558                for j in 0..k {
4559                    for i in 0..n {
4560                        let x = rhs_reconstructed
4561                            .read(i, j)
4562                            .faer_sub(rhs.read(i, j))
4563                            .faer_abs();
4564                        max = if max > x { max } else { x }
4565                    }
4566                }
4567                assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
4568            }
4569        });
4570    }
4571
4572    fn test_supernodal_intranode_bk_1<I: Index>() {
4573        type E = Complex<f64>;
4574        let truncate = I::truncate;
4575
4576        let (_, col_ptr, row_ind, values) = MEDIUM;
4577
4578        let mut gen = rand::rngs::StdRng::seed_from_u64(0);
4579        let i = E::faer_one().faer_neg().faer_sqrt();
4580
4581        let n = col_ptr.len() - 1;
4582        let nnz = values.len();
4583        let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
4584        let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
4585
4586        let mut complexify = |e: E| {
4587            let i = E::faer_one().faer_neg().faer_sqrt();
4588            if e == E::faer_from_f64(1.0) {
4589                e.faer_add(i.faer_mul(E::faer_from_f64(1000.0 * gen.gen::<f64>())))
4590            } else {
4591                e.faer_add(E::faer_from_f64(1000.0 * gen.gen::<f64>()))
4592            }
4593        };
4594        let values_mat =
4595            faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i])));
4596        let values = values_mat.col_as_slice(0);
4597
4598        let A = SparseColMatRef::<'_, I, E>::new(
4599            SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
4600            values,
4601        );
4602        let mut A_dense = sparse_to_dense(A);
4603        for j in 0..n {
4604            for i in j + 1..n {
4605                A_dense.write(i, j, A_dense.read(j, i).faer_conj());
4606            }
4607        }
4608
4609        let zero = truncate(0);
4610        let mut etree = vec![zero.to_signed(); n];
4611        let mut col_count = vec![zero; n];
4612        ghost::with_size(n, |N| {
4613            let A = ghost::SparseColMatRef::new(A, N, N);
4614            let etree = ghost_prefactorize_symbolic_cholesky(
4615                Array::from_mut(&mut etree, N),
4616                Array::from_mut(&mut col_count, N),
4617                *A,
4618                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(5 * n))),
4619            );
4620
4621            let symbolic = supernodal::ghost_factorize_supernodal_symbolic(
4622                *A,
4623                None,
4624                None,
4625                CholeskyInput::A,
4626                etree,
4627                Array::from_ref(&col_count, N),
4628                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4629                Default::default(),
4630            )
4631            .unwrap();
4632
4633            let mut A_lower_col_ptr = col_ptr.to_vec();
4634            let mut A_lower_values = values_mat.clone();
4635            let mut A_lower_row_ind = row_ind.to_vec();
4636            let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0));
4637            let A_lower = faer_core::sparse::util::ghost_adjoint(
4638                &mut A_lower_col_ptr,
4639                &mut A_lower_row_ind,
4640                A_lower_values,
4641                A,
4642                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4643            );
4644            let mut values = faer_core::Mat::<E>::zeros(symbolic.len_values(), 1);
4645
4646            let mut fwd = vec![zero; n];
4647            let mut inv = vec![zero; n];
4648            let mut subdiag = Mat::<E>::zeros(n, 1);
4649
4650            supernodal::factorize_supernodal_numeric_intranode_bunch_kaufman(
4651                values.col_as_slice_mut(0),
4652                subdiag.col_as_slice_mut(0),
4653                &mut fwd,
4654                &mut inv,
4655                A_lower.into_inner().into_const(),
4656                Default::default(),
4657                &symbolic,
4658                Parallelism::None,
4659                PodStack::new(&mut GlobalPodBuffer::new(
4660                    supernodal::factorize_supernodal_numeric_intranode_bunch_kaufman_req::<I, E>(
4661                        &symbolic,
4662                        Parallelism::None,
4663                    )
4664                    .unwrap(),
4665                )),
4666            );
4667            let k = 2;
4668
4669            let rhs = Mat::<E>::from_fn(n, k, |_, _| {
4670                E::faer_from_f64(gen.gen()).faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
4671            });
4672            for conj in [Conj::Yes, Conj::No] {
4673                let mut x = rhs.clone();
4674                let lblt = SupernodalIntranodeBunchKaufmanRef::new(
4675                    &symbolic,
4676                    values.col_as_slice(0),
4677                    subdiag.col_as_slice(0),
4678                    PermutationRef::new_checked(&fwd, &inv),
4679                );
4680                faer_core::permutation::permute_rows_in_place(
4681                    x.as_mut(),
4682                    lblt.perm,
4683                    PodStack::new(&mut GlobalPodBuffer::new(
4684                        faer_core::permutation::permute_rows_in_place_req::<I, E>(n, k).unwrap(),
4685                    )),
4686                );
4687                lblt.solve_in_place_no_numeric_permute_with_conj(
4688                    conj,
4689                    x.as_mut(),
4690                    Parallelism::None,
4691                    PodStack::new(&mut GlobalPodBuffer::new(
4692                        symbolic.solve_in_place_req::<E>(k).unwrap(),
4693                    )),
4694                );
4695                faer_core::permutation::permute_rows_in_place(
4696                    x.as_mut(),
4697                    lblt.perm.inverse(),
4698                    PodStack::new(&mut GlobalPodBuffer::new(
4699                        faer_core::permutation::permute_rows_in_place_req::<I, E>(n, k).unwrap(),
4700                    )),
4701                );
4702
4703                let rhs_reconstructed = if conj == Conj::No {
4704                    &A_dense * &x
4705                } else {
4706                    A_dense.conjugate() * &x
4707                };
4708                let mut max = <E as ComplexField>::Real::faer_zero();
4709                for j in 0..k {
4710                    for i in 0..n {
4711                        let x = rhs_reconstructed
4712                            .read(i, j)
4713                            .faer_sub(rhs.read(i, j))
4714                            .faer_abs();
4715                        max = if max > x { max } else { x }
4716                    }
4717                }
4718                assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-10));
4719            }
4720        });
4721    }
4722
4723    fn test_supernodal_intranode_bk_2<I: Index>() {
4724        type E = Complex<f64>;
4725        let truncate = I::truncate;
4726
4727        let (_, col_ptr, row_ind, values) = MEDIUM_P;
4728
4729        let mut gen = rand::rngs::StdRng::seed_from_u64(0);
4730        let i = E::faer_one().faer_neg().faer_sqrt();
4731
4732        let n = col_ptr.len() - 1;
4733        let nnz = values.len();
4734        let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
4735        let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
4736        let values_mat = faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| values[i]);
4737        let values = values_mat.col_as_slice(0);
4738
4739        let A = SparseColMatRef::<'_, I, E>::new(
4740            SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
4741            values,
4742        );
4743        let mut A_dense = sparse_to_dense(A);
4744        for j in 0..n {
4745            for i in j + 1..n {
4746                A_dense.write(i, j, A_dense.read(j, i).faer_conj());
4747            }
4748        }
4749
4750        let zero = truncate(0);
4751        let mut etree = vec![zero.to_signed(); n];
4752        let mut col_count = vec![zero; n];
4753        ghost::with_size(n, |N| {
4754            let A = ghost::SparseColMatRef::new(A, N, N);
4755            let etree = ghost_prefactorize_symbolic_cholesky(
4756                Array::from_mut(&mut etree, N),
4757                Array::from_mut(&mut col_count, N),
4758                *A,
4759                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(5 * n))),
4760            );
4761
4762            let symbolic = supernodal::ghost_factorize_supernodal_symbolic(
4763                *A,
4764                None,
4765                None,
4766                CholeskyInput::A,
4767                etree,
4768                Array::from_ref(&col_count, N),
4769                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4770                Default::default(),
4771            )
4772            .unwrap();
4773
4774            let mut A_lower_col_ptr = col_ptr.to_vec();
4775            let mut A_lower_values = values_mat.clone();
4776            let mut A_lower_row_ind = row_ind.to_vec();
4777            let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0));
4778            let A_lower = faer_core::sparse::util::ghost_adjoint(
4779                &mut A_lower_col_ptr,
4780                &mut A_lower_row_ind,
4781                A_lower_values,
4782                A,
4783                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4784            );
4785            let mut values = faer_core::Mat::<E>::zeros(symbolic.len_values(), 1);
4786
4787            let mut fwd = vec![zero; n];
4788            let mut inv = vec![zero; n];
4789            let mut subdiag = Mat::<E>::zeros(n, 1);
4790
4791            supernodal::factorize_supernodal_numeric_intranode_bunch_kaufman(
4792                values.col_as_slice_mut(0),
4793                subdiag.col_as_slice_mut(0),
4794                &mut fwd,
4795                &mut inv,
4796                A_lower.into_inner().into_const(),
4797                Default::default(),
4798                &symbolic,
4799                Parallelism::None,
4800                PodStack::new(&mut GlobalPodBuffer::new(
4801                    supernodal::factorize_supernodal_numeric_intranode_bunch_kaufman_req::<I, E>(
4802                        &symbolic,
4803                        Parallelism::None,
4804                    )
4805                    .unwrap(),
4806                )),
4807            );
4808            let k = 2;
4809
4810            let rhs = Mat::<E>::from_fn(n, k, |_, _| {
4811                E::faer_from_f64(gen.gen()).faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
4812            });
4813            for conj in [Conj::Yes, Conj::No] {
4814                let mut x = rhs.clone();
4815                let lblt = SupernodalIntranodeBunchKaufmanRef::new(
4816                    &symbolic,
4817                    values.col_as_slice(0),
4818                    subdiag.col_as_slice(0),
4819                    PermutationRef::new_checked(&fwd, &inv),
4820                );
4821                faer_core::permutation::permute_rows_in_place(
4822                    x.as_mut(),
4823                    lblt.perm,
4824                    PodStack::new(&mut GlobalPodBuffer::new(
4825                        faer_core::permutation::permute_rows_in_place_req::<I, E>(n, k).unwrap(),
4826                    )),
4827                );
4828                lblt.solve_in_place_no_numeric_permute_with_conj(
4829                    conj,
4830                    x.as_mut(),
4831                    Parallelism::None,
4832                    PodStack::new(&mut GlobalPodBuffer::new(
4833                        symbolic.solve_in_place_req::<E>(k).unwrap(),
4834                    )),
4835                );
4836                faer_core::permutation::permute_rows_in_place(
4837                    x.as_mut(),
4838                    lblt.perm.inverse(),
4839                    PodStack::new(&mut GlobalPodBuffer::new(
4840                        faer_core::permutation::permute_rows_in_place_req::<I, E>(n, k).unwrap(),
4841                    )),
4842                );
4843
4844                let rhs_reconstructed = if conj == Conj::No {
4845                    &A_dense * &x
4846                } else {
4847                    A_dense.conjugate() * &x
4848                };
4849                let mut max = <E as ComplexField>::Real::faer_zero();
4850                for j in 0..k {
4851                    for i in 0..n {
4852                        let x = rhs_reconstructed
4853                            .read(i, j)
4854                            .faer_sub(rhs.read(i, j))
4855                            .faer_abs();
4856                        max = if max > x { max } else { x }
4857                    }
4858                }
4859                assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-10));
4860            }
4861        });
4862    }
4863
4864    fn test_simplicial<I: Index>() {
4865        type E = Complex<Double<f64>>;
4866        let truncate = I::truncate;
4867
4868        let (_, col_ptr, row_ind, values) = SMALL;
4869
4870        let mut gen = rand::rngs::StdRng::seed_from_u64(0);
4871        let mut complexify = |e: E| {
4872            let i = E::faer_one().faer_neg().faer_sqrt();
4873            if e == E::faer_from_f64(1.0) {
4874                e.faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
4875            } else {
4876                e
4877            }
4878        };
4879
4880        let n = col_ptr.len() - 1;
4881        let nnz = values.len();
4882        let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
4883        let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
4884        let values_mat =
4885            faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i])));
4886        let values = values_mat.col_as_slice(0);
4887
4888        let A = SparseColMatRef::<'_, I, E>::new(
4889            SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
4890            values,
4891        );
4892        let zero = truncate(0);
4893        let mut etree = vec![zero.to_signed(); n];
4894        let mut col_count = vec![zero; n];
4895        ghost::with_size(n, |N| {
4896            let A = ghost::SparseColMatRef::new(A, N, N);
4897            let etree = ghost_prefactorize_symbolic_cholesky(
4898                Array::from_mut(&mut etree, N),
4899                Array::from_mut(&mut col_count, N),
4900                *A,
4901                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(5 * n))),
4902            );
4903
4904            let symbolic = simplicial::ghost_factorize_simplicial_symbolic_cholesky(
4905                *A,
4906                etree,
4907                Array::from_ref(&col_count, N),
4908                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4909            )
4910            .unwrap();
4911
4912            let mut values = faer_core::Mat::<E>::zeros(symbolic.len_values(), 1);
4913
4914            simplicial::factorize_simplicial_numeric_ldlt::<I, E>(
4915                values.col_as_slice_mut(0),
4916                A.into_inner(),
4917                Default::default(),
4918                &symbolic,
4919                PodStack::new(&mut GlobalPodBuffer::new(
4920                    simplicial::factorize_simplicial_numeric_ldlt_req::<I, E>(n).unwrap(),
4921                )),
4922            );
4923            let mut A = sparse_to_dense(A.into_inner());
4924            for j in 0..n {
4925                for i in j + 1..n {
4926                    A.write(i, j, A.read(j, i).faer_conj());
4927                }
4928            }
4929
4930            let err =
4931                reconstruct_from_simplicial_ldlt::<I, E>(&symbolic, values.col_as_slice(0)) - &A;
4932
4933            let mut max = <E as ComplexField>::Real::faer_zero();
4934            for j in 0..n {
4935                for i in 0..n {
4936                    let x = err.read(i, j).faer_abs();
4937                    max = if max > x { max } else { x }
4938                }
4939            }
4940            assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
4941        });
4942    }
4943
4944    fn test_solver_llt<I: Index>() {
4945        type E = Complex<Double<f64>>;
4946        let truncate = I::truncate;
4947
4948        for (_, col_ptr, row_ind, values) in [SMALL, MEDIUM] {
4949            let mut gen = rand::rngs::StdRng::seed_from_u64(0);
4950            let i = E::faer_one().faer_neg().faer_sqrt();
4951            let mut complexify = |e: E| {
4952                if e == E::faer_from_f64(1.0) {
4953                    e.faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
4954                } else {
4955                    e
4956                }
4957            };
4958
4959            let n = col_ptr.len() - 1;
4960            let nnz = values.len();
4961            let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
4962            let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
4963            let values_mat = faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| {
4964                complexify(E::faer_from_f64(values[i]))
4965            });
4966            let values = values_mat.col_as_slice(0);
4967
4968            let A_upper = SparseColMatRef::<'_, I, E>::new(
4969                SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
4970                values,
4971            );
4972
4973            let mut A_lower_col_ptr = col_ptr.to_vec();
4974            let mut A_lower_values = values_mat.clone();
4975            let mut A_lower_row_ind = row_ind.to_vec();
4976            let A_lower = faer_core::sparse::util::adjoint(
4977                &mut A_lower_col_ptr,
4978                &mut A_lower_row_ind,
4979                A_lower_values.col_as_slice_mut(0),
4980                A_upper,
4981                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
4982            )
4983            .into_const();
4984
4985            let mut A_dense = sparse_to_dense(A_upper);
4986            for j in 0..n {
4987                for i in j + 1..n {
4988                    A_dense.write(i, j, A_dense.read(j, i).faer_conj());
4989                }
4990            }
4991
4992            for (A, side, supernodal_flop_ratio_threshold, parallelism) in [
4993                (
4994                    A_upper,
4995                    Side::Upper,
4996                    SupernodalThreshold::FORCE_SIMPLICIAL,
4997                    Parallelism::None,
4998                ),
4999                (
5000                    A_upper,
5001                    Side::Upper,
5002                    SupernodalThreshold::FORCE_SUPERNODAL,
5003                    Parallelism::None,
5004                ),
5005                (
5006                    A_lower,
5007                    Side::Lower,
5008                    SupernodalThreshold::FORCE_SIMPLICIAL,
5009                    Parallelism::None,
5010                ),
5011                (
5012                    A_lower,
5013                    Side::Lower,
5014                    SupernodalThreshold::FORCE_SUPERNODAL,
5015                    Parallelism::None,
5016                ),
5017            ] {
5018                let symbolic = factorize_symbolic_cholesky(
5019                    A.symbolic(),
5020                    side,
5021                    CholeskySymbolicParams {
5022                        supernodal_flop_ratio_threshold,
5023                        ..Default::default()
5024                    },
5025                )
5026                .unwrap();
5027                let mut mem = GlobalPodBuffer::new(
5028                    symbolic
5029                        .factorize_numeric_ldlt_req::<E>(false, parallelism)
5030                        .unwrap(),
5031                );
5032                let mut L_values = Mat::<E>::zeros(symbolic.len_values(), 1);
5033
5034                symbolic
5035                    .factorize_numeric_llt::<E>(
5036                        L_values.col_as_slice_mut(0),
5037                        A,
5038                        side,
5039                        Default::default(),
5040                        parallelism,
5041                        PodStack::new(&mut mem),
5042                    )
5043                    .unwrap();
5044                let L_values = L_values.col_as_slice(0);
5045
5046                let A_reconstructed = match symbolic.raw() {
5047                    SymbolicCholeskyRaw::Simplicial(symbolic) => {
5048                        reconstruct_from_simplicial_llt::<I, E>(symbolic, L_values)
5049                    }
5050                    SymbolicCholeskyRaw::Supernodal(symbolic) => {
5051                        reconstruct_from_supernodal_llt::<I, E>(symbolic, L_values)
5052                    }
5053                };
5054
5055                let (perm_fwd, _) = symbolic.perm().into_arrays();
5056
5057                let mut max = <E as ComplexField>::Real::faer_zero();
5058                for j in 0..n {
5059                    for i in 0..n {
5060                        let x = (A_reconstructed
5061                            .read(i, j)
5062                            .faer_sub(A_dense.read(perm_fwd[i].zx(), perm_fwd[j].zx())))
5063                        .faer_abs();
5064                        max = if max > x { max } else { x }
5065                    }
5066                }
5067                assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
5068
5069                for k in (1..16).chain(128..132) {
5070                    let rhs = Mat::<E>::from_fn(n, k, |_, _| {
5071                        E::faer_from_f64(gen.gen())
5072                            .faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
5073                    });
5074                    for conj in [Conj::Yes, Conj::No] {
5075                        let mut x = rhs.clone();
5076                        let llt = LltRef::new(&symbolic, L_values);
5077                        llt.solve_in_place_with_conj(
5078                            conj,
5079                            x.as_mut(),
5080                            parallelism,
5081                            PodStack::new(&mut GlobalPodBuffer::new(
5082                                symbolic.solve_in_place_req::<E>(k).unwrap(),
5083                            )),
5084                        );
5085
5086                        let rhs_reconstructed = if conj == Conj::No {
5087                            &A_dense * &x
5088                        } else {
5089                            A_dense.conjugate() * &x
5090                        };
5091                        let mut max = <E as ComplexField>::Real::faer_zero();
5092                        for j in 0..k {
5093                            for i in 0..n {
5094                                let x = rhs_reconstructed
5095                                    .read(i, j)
5096                                    .faer_sub(rhs.read(i, j))
5097                                    .faer_abs();
5098                                max = if max > x { max } else { x }
5099                            }
5100                        }
5101                        assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
5102                    }
5103                }
5104            }
5105        }
5106    }
5107
5108    fn test_solver_ldlt<I: Index>() {
5109        type E = Complex<Double<f64>>;
5110        let truncate = I::truncate;
5111
5112        for (_, col_ptr, row_ind, values) in [SMALL, MEDIUM] {
5113            let mut gen = rand::rngs::StdRng::seed_from_u64(0);
5114            let i = E::faer_one().faer_neg().faer_sqrt();
5115            let mut complexify = |e: E| {
5116                if e == E::faer_from_f64(1.0) {
5117                    e.faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
5118                } else {
5119                    e
5120                }
5121            };
5122
5123            let n = col_ptr.len() - 1;
5124            let nnz = values.len();
5125            let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
5126            let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
5127            let values_mat = faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| {
5128                complexify(E::faer_from_f64(values[i]))
5129            });
5130            let values = values_mat.col_as_slice(0);
5131
5132            let A_upper = SparseColMatRef::<'_, I, E>::new(
5133                SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
5134                values,
5135            );
5136
5137            let mut A_lower_col_ptr = col_ptr.to_vec();
5138            let mut A_lower_values = values_mat.clone();
5139            let mut A_lower_row_ind = row_ind.to_vec();
5140            let A_lower = faer_core::sparse::util::adjoint(
5141                &mut A_lower_col_ptr,
5142                &mut A_lower_row_ind,
5143                A_lower_values.col_as_slice_mut(0),
5144                A_upper,
5145                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
5146            )
5147            .into_const();
5148
5149            let mut A_dense = sparse_to_dense(A_upper);
5150            for j in 0..n {
5151                for i in j + 1..n {
5152                    A_dense.write(i, j, A_dense.read(j, i).faer_conj());
5153                }
5154            }
5155
5156            for (A, side, supernodal_flop_ratio_threshold, parallelism) in [
5157                (
5158                    A_upper,
5159                    Side::Upper,
5160                    SupernodalThreshold::FORCE_SIMPLICIAL,
5161                    Parallelism::None,
5162                ),
5163                (
5164                    A_upper,
5165                    Side::Upper,
5166                    SupernodalThreshold::FORCE_SUPERNODAL,
5167                    Parallelism::None,
5168                ),
5169                (
5170                    A_lower,
5171                    Side::Lower,
5172                    SupernodalThreshold::FORCE_SIMPLICIAL,
5173                    Parallelism::None,
5174                ),
5175                (
5176                    A_lower,
5177                    Side::Lower,
5178                    SupernodalThreshold::FORCE_SUPERNODAL,
5179                    Parallelism::None,
5180                ),
5181            ] {
5182                let symbolic = factorize_symbolic_cholesky(
5183                    A.symbolic(),
5184                    side,
5185                    CholeskySymbolicParams {
5186                        supernodal_flop_ratio_threshold,
5187                        ..Default::default()
5188                    },
5189                )
5190                .unwrap();
5191                let mut mem = GlobalPodBuffer::new(
5192                    symbolic
5193                        .factorize_numeric_ldlt_req::<E>(false, parallelism)
5194                        .unwrap(),
5195                );
5196                let mut L_values = Mat::<E>::zeros(symbolic.len_values(), 1);
5197
5198                symbolic.factorize_numeric_ldlt::<E>(
5199                    L_values.col_as_slice_mut(0),
5200                    A,
5201                    side,
5202                    Default::default(),
5203                    parallelism,
5204                    PodStack::new(&mut mem),
5205                );
5206                let L_values = L_values.col_as_slice(0);
5207                let A_reconstructed = match symbolic.raw() {
5208                    SymbolicCholeskyRaw::Simplicial(symbolic) => {
5209                        reconstruct_from_simplicial_ldlt::<I, E>(symbolic, L_values)
5210                    }
5211                    SymbolicCholeskyRaw::Supernodal(symbolic) => {
5212                        reconstruct_from_supernodal_ldlt::<I, E>(symbolic, L_values)
5213                    }
5214                };
5215
5216                let (perm_fwd, _) = symbolic.perm().into_arrays();
5217
5218                let mut max = <E as ComplexField>::Real::faer_zero();
5219                for j in 0..n {
5220                    for i in 0..n {
5221                        let x = (A_reconstructed
5222                            .read(i, j)
5223                            .faer_sub(A_dense.read(perm_fwd[i].zx(), perm_fwd[j].zx())))
5224                        .faer_abs();
5225                        max = if max > x { max } else { x }
5226                    }
5227                }
5228                assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
5229
5230                for k in (0..16).chain(128..132) {
5231                    let rhs = Mat::<E>::from_fn(n, k, |_, _| {
5232                        E::faer_from_f64(gen.gen())
5233                            .faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
5234                    });
5235                    for conj in [Conj::Yes, Conj::No] {
5236                        let mut x = rhs.clone();
5237                        let ldlt = LdltRef::new(&symbolic, L_values);
5238                        ldlt.solve_in_place_with_conj(
5239                            conj,
5240                            x.as_mut(),
5241                            parallelism,
5242                            PodStack::new(&mut GlobalPodBuffer::new(
5243                                symbolic.solve_in_place_req::<E>(k).unwrap(),
5244                            )),
5245                        );
5246
5247                        let rhs_reconstructed = if conj == Conj::No {
5248                            &A_dense * &x
5249                        } else {
5250                            A_dense.conjugate() * &x
5251                        };
5252                        let mut max = <E as ComplexField>::Real::faer_zero();
5253                        for j in 0..k {
5254                            for i in 0..n {
5255                                let x = rhs_reconstructed
5256                                    .read(i, j)
5257                                    .faer_sub(rhs.read(i, j))
5258                                    .faer_abs();
5259                                max = if max > x { max } else { x }
5260                            }
5261                        }
5262                        assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
5263                    }
5264                }
5265            }
5266        }
5267    }
5268
5269    fn test_solver_intranode_bk<I: Index>() {
5270        type E = Complex<Double<f64>>;
5271        let truncate = I::truncate;
5272
5273        for (_, col_ptr, row_ind, values) in [MEDIUM, SMALL] {
5274            let mut gen = rand::rngs::StdRng::seed_from_u64(0);
5275            let i = E::faer_one().faer_neg().faer_sqrt();
5276            let mut complexify = |e: E| {
5277                if e == E::faer_from_f64(1.0) {
5278                    e.faer_add(i.faer_mul(E::faer_from_f64(2000.0 * gen.gen::<f64>())))
5279                        .faer_add(E::faer_from_f64(2000.0 * gen.gen::<f64>()))
5280                } else {
5281                    e.faer_add(E::faer_from_f64(100.0 * gen.gen::<f64>()))
5282                }
5283            };
5284
5285            let n = col_ptr.len() - 1;
5286            let nnz = values.len();
5287            let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::<Vec<_>>();
5288            let row_ind = &*row_ind.iter().copied().map(truncate).collect::<Vec<_>>();
5289            let values_mat = faer_core::Mat::<E>::from_fn(nnz, 1, |i, _| {
5290                complexify(E::faer_from_f64(values[i]))
5291            });
5292            let values = values_mat.col_as_slice(0);
5293
5294            let A_upper = SparseColMatRef::<'_, I, E>::new(
5295                SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
5296                values,
5297            );
5298
5299            let mut A_lower_col_ptr = col_ptr.to_vec();
5300            let mut A_lower_values = values_mat.clone();
5301            let mut A_lower_row_ind = row_ind.to_vec();
5302            let A_lower = faer_core::sparse::util::adjoint(
5303                &mut A_lower_col_ptr,
5304                &mut A_lower_row_ind,
5305                A_lower_values.col_as_slice_mut(0),
5306                A_upper,
5307                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
5308            )
5309            .into_const();
5310
5311            let mut A_dense = sparse_to_dense(A_upper);
5312            for j in 0..n {
5313                for i in j + 1..n {
5314                    A_dense.write(i, j, A_dense.read(j, i).faer_conj());
5315                }
5316            }
5317
5318            for (A, side, supernodal_flop_ratio_threshold, parallelism) in [
5319                (
5320                    A_upper,
5321                    Side::Upper,
5322                    SupernodalThreshold::FORCE_SIMPLICIAL,
5323                    Parallelism::None,
5324                ),
5325                (
5326                    A_upper,
5327                    Side::Upper,
5328                    SupernodalThreshold::FORCE_SUPERNODAL,
5329                    Parallelism::None,
5330                ),
5331                (
5332                    A_lower,
5333                    Side::Lower,
5334                    SupernodalThreshold::FORCE_SIMPLICIAL,
5335                    Parallelism::None,
5336                ),
5337                (
5338                    A_lower,
5339                    Side::Lower,
5340                    SupernodalThreshold::FORCE_SUPERNODAL,
5341                    Parallelism::None,
5342                ),
5343            ] {
5344                let symbolic = factorize_symbolic_cholesky(
5345                    A.symbolic(),
5346                    side,
5347                    CholeskySymbolicParams {
5348                        supernodal_flop_ratio_threshold,
5349                        ..Default::default()
5350                    },
5351                )
5352                .unwrap();
5353                let mut mem = GlobalPodBuffer::new(
5354                    symbolic
5355                        .factorize_numeric_intranode_bunch_kaufman_req::<E>(false, parallelism)
5356                        .unwrap(),
5357                );
5358                let mut L_values = Mat::<E>::zeros(symbolic.len_values(), 1);
5359                let mut subdiag = Mat::<E>::zeros(n, 1);
5360                let mut fwd = vec![I::truncate(0); n];
5361                let mut inv = vec![I::truncate(0); n];
5362
5363                let lblt = symbolic.factorize_numeric_intranode_bunch_kaufman::<E>(
5364                    L_values.col_as_slice_mut(0),
5365                    subdiag.col_as_slice_mut(0),
5366                    &mut fwd,
5367                    &mut inv,
5368                    A,
5369                    side,
5370                    Default::default(),
5371                    parallelism,
5372                    PodStack::new(&mut mem),
5373                );
5374
5375                for k in (1..16).chain(128..132) {
5376                    let rhs = Mat::<E>::from_fn(n, k, |_, _| {
5377                        E::faer_from_f64(gen.gen())
5378                            .faer_add(i.faer_mul(E::faer_from_f64(gen.gen())))
5379                    });
5380                    for conj in [Conj::No, Conj::Yes] {
5381                        let mut x = rhs.clone();
5382                        lblt.solve_in_place_with_conj(
5383                            conj,
5384                            x.as_mut(),
5385                            parallelism,
5386                            PodStack::new(&mut GlobalPodBuffer::new(
5387                                symbolic.solve_in_place_req::<E>(k).unwrap(),
5388                            )),
5389                        );
5390
5391                        let rhs_reconstructed = if conj == Conj::No {
5392                            &A_dense * &x
5393                        } else {
5394                            A_dense.conjugate() * &x
5395                        };
5396                        let mut max = <E as ComplexField>::Real::faer_zero();
5397                        for j in 0..k {
5398                            for i in 0..n {
5399                                let x = rhs_reconstructed
5400                                    .read(i, j)
5401                                    .faer_sub(rhs.read(i, j))
5402                                    .faer_abs();
5403                                max = if max > x { max } else { x }
5404                            }
5405                        }
5406                        assert!(max < <E as ComplexField>::Real::faer_from_f64(1e-25));
5407                    }
5408                }
5409            }
5410        }
5411    }
5412
5413    fn test_solver_regularization<I: Index>() {
5414        type E = f64;
5415        let I = I::truncate;
5416
5417        for (_, col_ptr, row_ind, values) in [SMALL, MEDIUM] {
5418            let n = col_ptr.len() - 1;
5419            let nnz = values.len();
5420            let col_ptr = &*col_ptr.iter().copied().map(I).collect::<Vec<_>>();
5421            let row_ind = &*row_ind.iter().copied().map(I).collect::<Vec<_>>();
5422            // artificial zeros
5423            let values_mat = faer_core::Mat::<E>::from_fn(nnz, 1, |_, _| 0.0);
5424            let dynamic_regularization_epsilon = 1e-6;
5425            let dynamic_regularization_delta = 1e-2;
5426
5427            let values = values_mat.col_as_slice(0);
5428            let mut signs = vec![-1i8; n];
5429            signs[..8].fill(1);
5430
5431            let A_upper = SparseColMatRef::<'_, I, E>::new(
5432                SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_ind),
5433                values,
5434            );
5435
5436            let mut A_lower_col_ptr = col_ptr.to_vec();
5437            let mut A_lower_values = values_mat.clone();
5438            let mut A_lower_row_ind = row_ind.to_vec();
5439            let A_lower = faer_core::sparse::util::adjoint(
5440                &mut A_lower_col_ptr,
5441                &mut A_lower_row_ind,
5442                A_lower_values.col_as_slice_mut(0),
5443                A_upper,
5444                PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::<I>(20 * n))),
5445            )
5446            .into_const();
5447
5448            let mut A_dense = sparse_to_dense(A_upper);
5449            for (j, &sign) in signs.iter().enumerate() {
5450                A_dense.write(j, j, sign as f64 * dynamic_regularization_delta);
5451                for i in j + 1..n {
5452                    A_dense.write(i, j, A_dense.read(j, i).faer_conj());
5453                }
5454            }
5455
5456            for (A, side, supernodal_flop_ratio_threshold, parallelism) in [
5457                (
5458                    A_upper,
5459                    Side::Upper,
5460                    SupernodalThreshold::FORCE_SIMPLICIAL,
5461                    Parallelism::None,
5462                ),
5463                (
5464                    A_upper,
5465                    Side::Upper,
5466                    SupernodalThreshold::FORCE_SUPERNODAL,
5467                    Parallelism::None,
5468                ),
5469                (
5470                    A_lower,
5471                    Side::Lower,
5472                    SupernodalThreshold::FORCE_SIMPLICIAL,
5473                    Parallelism::None,
5474                ),
5475                (
5476                    A_lower,
5477                    Side::Lower,
5478                    SupernodalThreshold::FORCE_SUPERNODAL,
5479                    Parallelism::None,
5480                ),
5481            ] {
5482                let symbolic = factorize_symbolic_cholesky(
5483                    A.symbolic(),
5484                    side,
5485                    CholeskySymbolicParams {
5486                        supernodal_flop_ratio_threshold,
5487                        ..Default::default()
5488                    },
5489                )
5490                .unwrap();
5491                let mut mem = GlobalPodBuffer::new(
5492                    symbolic
5493                        .factorize_numeric_ldlt_req::<E>(true, parallelism)
5494                        .unwrap(),
5495                );
5496                let mut L_values = Mat::<E>::zeros(symbolic.len_values(), 1);
5497                let mut L_values = L_values.col_as_slice_mut(0);
5498
5499                symbolic.factorize_numeric_ldlt(
5500                    L_values.rb_mut(),
5501                    A,
5502                    side,
5503                    LdltRegularization {
5504                        dynamic_regularization_signs: Some(&signs),
5505                        dynamic_regularization_delta,
5506                        dynamic_regularization_epsilon,
5507                    },
5508                    parallelism,
5509                    PodStack::new(&mut mem),
5510                );
5511                let L_values = L_values.rb();
5512
5513                let A_reconstructed = match symbolic.raw() {
5514                    SymbolicCholeskyRaw::Simplicial(symbolic) => {
5515                        reconstruct_from_simplicial_ldlt::<I, E>(symbolic, L_values)
5516                    }
5517                    SymbolicCholeskyRaw::Supernodal(symbolic) => {
5518                        reconstruct_from_supernodal_ldlt::<I, E>(symbolic, L_values)
5519                    }
5520                };
5521
5522                let (perm_fwd, _) = symbolic.perm().into_arrays();
5523                let mut max = <E as ComplexField>::Real::faer_zero();
5524                for j in 0..n {
5525                    for i in 0..n {
5526                        let x = (A_reconstructed
5527                            .read(i, j)
5528                            .faer_sub(A_dense.read(perm_fwd[i].zx(), perm_fwd[j].zx())))
5529                        .abs();
5530                        max = if max > x { max } else { x }
5531                    }
5532                }
5533                assert!(max == 0.0);
5534            }
5535        }
5536    }
5537
    // Registration of the generic test functions of this module.
    // `test_amd` and `test_counts` are registered without an explicit index type; every other
    // test is registered with `u32` only.
    // NOTE(review): `monomorphize_test!` is defined outside this chunk; presumably the
    // one-argument form instantiates the test for several index types — confirm at the macro
    // definition.
    monomorphize_test!(test_amd);
    monomorphize_test!(test_counts);
    monomorphize_test!(test_supernodal, u32);
    monomorphize_test!(test_supernodal_ldlt, u32);
    monomorphize_test!(test_supernodal_intranode_bk_1, u32);
    monomorphize_test!(test_supernodal_intranode_bk_2, u32);
    monomorphize_test!(test_simplicial, u32);
    monomorphize_test!(test_solver_llt, u32);
    monomorphize_test!(test_solver_ldlt, u32);
    monomorphize_test!(test_solver_intranode_bk, u32);
    monomorphize_test!(test_solver_regularization, u32);
5549}