Skip to main content

gam_linalg/
faer_ndarray.rs

1use dyn_stack::{MemBuffer, MemStack};
2use faer::diag::{Diag, DiagRef};
3use faer::linalg::solvers::{self, Solve};
4pub use faer::linalg::solvers::{
5    Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
6};
7use faer::linalg::svd::{self, ComputeSvdVectors};
8use faer::prelude::ReborrowMut;
9use faer::{Conj, Mat, MatMut, MatRef, Par, Side, Unbind, get_global_parallelism};
10use ndarray::{Array1, Array2, ArrayBase, ArrayViewMut1, Data, Ix1, Ix2};
11use std::marker::PhantomData;
12use std::panic::{AssertUnwindSafe, catch_unwind};
13use thiserror::Error;
14
15const RRQR_RANK_ALPHA: f64 = 100.0;
16
thread_local! {
    /// Per-thread count of active [`NestedParallelGuard`]s; non-zero means the
    /// current thread is inside a data-parallel row region.
    static NESTED_PARALLEL_DEPTH: std::cell::Cell<usize> = const { std::cell::Cell::new(0) };
}

/// RAII marker that increments this thread's `NESTED_PARALLEL_DEPTH` on
/// creation and decrements it on drop.
///
/// The counter is thread-local, so the guard must be dropped on the thread that
/// created it. The `PhantomData<*const ()>` field makes the type `!Send` and
/// `!Sync`, so the compiler rejects moving a guard to another thread. Such a
/// move would decrement the wrong thread's counter.
struct NestedParallelGuard {
    _not_send: PhantomData<*const ()>,
}

impl NestedParallelGuard {
    /// Mark the current thread as inside one more nested parallel region.
    #[inline]
    #[must_use = "the nested-parallel region ends as soon as the guard is dropped"]
    fn enter() -> Self {
        NESTED_PARALLEL_DEPTH.with(|depth| depth.set(depth.get().saturating_add(1)));
        Self {
            _not_send: PhantomData,
        }
    }
}

impl Drop for NestedParallelGuard {
    #[inline]
    fn drop(&mut self) {
        // Runs on normal scope exit and during panic unwinding alike.
        NESTED_PARALLEL_DEPTH.with(|depth| depth.set(depth.get().saturating_sub(1)));
    }
}
37
38/// Run `body` with the current thread marked as inside a data-parallel row
39/// region, so any faer GEMM it issues (directly or transitively) pins to
40/// `Par::Seq` via [`effective_global_parallelism`] instead of re-fanning the
41/// global Rayon pool. The guard is held for exactly the duration of `body` and
42/// dropped on return — including early `?` returns from inside `body`, since the
43/// guard lives in this function's frame.
44///
45/// Call this from the per-chunk/per-row closure of an `into_par_iter` whose body
46/// performs GEMM, to prevent the Rayon-pool × faer-pool oversubscription.
47#[inline]
48pub fn with_nested_parallel<T>(body: impl FnOnce() -> T) -> T {
49    let guard = NestedParallelGuard::enter();
50    let out = body();
51    drop(guard);
52    out
53}
54
55/// `true` when the current thread is inside at least one [`NestedParallelGuard`]
56/// scope, i.e. a parallel row reduction is already in flight on this thread.
57#[inline]
58pub fn in_nested_parallel_region() -> bool {
59    NESTED_PARALLEL_DEPTH.with(|depth| depth.get() > 0)
60}
61
62/// faer parallelism policy that respects nested data-parallel regions: returns
63/// faer's global policy at the top level, but `Par::Seq` once a
64/// [`NestedParallelGuard`] is active so a GEMM issued from inside a parallel row
65/// fan-out does not multiply the live thread count against the outer pool.
66///
67/// Use this in place of `faer::get_global_parallelism()` for any matmul that can
68/// be reached from inside a row-parallel closure.
69#[inline]
70pub fn effective_global_parallelism() -> Par {
71    if in_nested_parallel_region() {
72        Par::Seq
73    } else {
74        get_global_parallelism()
75    }
76}
77
/// Errors produced by the faer-backed dense linear-algebra helpers.
#[derive(Debug, Error)]
pub enum FaerLinalgError {
    /// A factorization could not be produced; `context` is a short static
    /// description of where it was attempted.
    #[error("Factorization failed in {context}")]
    FactorizationFailed { context: &'static str },
    /// An SVD did not converge; `context` describes where it was attempted.
    #[error("SVD failed to converge in {context}")]
    SvdNoConvergence { context: &'static str },
    /// A self-adjoint eigendecomposition was requested on input containing
    /// non-finite values.
    #[error("Self-adjoint eigendecomposition input contains non-finite values in {context}")]
    SelfAdjointEigenNonFiniteInput { context: &'static str },
    /// Wraps the error reported by faer's self-adjoint eigendecomposition.
    #[error("Self-adjoint eigendecomposition failed: {0:?}")]
    SelfAdjointEigen(solvers::EvdError),
    /// Wraps the error reported by faer's Cholesky (LLᵀ) factorization.
    #[error("Cholesky factorization failed: {0:?}")]
    Cholesky(solvers::LltError),
    /// Wraps the error reported by faer's LDLᵀ factorization.
    #[error("LDLT factorization failed: {0:?}")]
    Ldlt(solvers::LdltError),
}
93
/// A dense symmetric factorization, tagged by which method succeeded in
/// [`factorize_symmetricwith_fallback`] (tried in the order LLT → LDLT → LBLT).
pub enum FaerSymmetricFactor {
    /// Cholesky factor `A = L·Lᵀ`; chosen whenever Cholesky succeeds.
    Llt(FaerLlt<f64>),
    /// `A = L·D·Lᵀ` with diagonal `D`; used when Cholesky fails.
    Ldlt(FaerLdlt<f64>),
    /// `A = L·B·Lᵀ` with block-diagonal `B` (may contain 2×2 blocks); the
    /// last-resort fallback when LDLT also fails.
    Lblt(FaerLblt<f64>),
}
99
100#[inline]
101pub fn cholesky_factor_logdet(factor: MatRef<'_, f64>) -> f64 {
102    2.0 * diagonal_log_sum(factor.diagonal())
103}
104
105#[inline]
106fn diagonal_log_sum(diagonal: DiagRef<'_, f64>) -> f64 {
107    diagonal
108        .column_vector()
109        .iter()
110        .map(|&x| x.ln())
111        .sum::<f64>()
112}
113
114impl FaerSymmetricFactor {
115    /// Returns the dimension of the factorized square matrix.
116    #[inline]
117    pub fn n(&self) -> usize {
118        use faer::linalg::solvers::ShapeCore;
119        match self {
120            FaerSymmetricFactor::Llt(f) => f.nrows(),
121            FaerSymmetricFactor::Ldlt(f) => f.nrows(),
122            FaerSymmetricFactor::Lblt(f) => f.nrows(),
123        }
124    }
125
126    #[inline]
127    pub fn solve(&self, rhs: MatRef<'_, f64>) -> Mat<f64> {
128        match self {
129            FaerSymmetricFactor::Llt(f) => f.solve(rhs),
130            FaerSymmetricFactor::Ldlt(f) => f.solve(rhs),
131            FaerSymmetricFactor::Lblt(f) => f.solve(rhs),
132        }
133    }
134
135    #[inline]
136    pub fn solve_in_place(&self, rhs: MatMut<'_, f64>) {
137        match self {
138            FaerSymmetricFactor::Llt(f) => f.solve_in_place(rhs),
139            FaerSymmetricFactor::Ldlt(f) => f.solve_in_place(rhs),
140            FaerSymmetricFactor::Lblt(f) => f.solve_in_place(rhs),
141        }
142    }
143}
144
145impl crate::matrix::FactorizedSystem for FaerSymmetricFactor {
146    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
147        let mut out = rhs.clone();
148        let mut out_mat = array1_to_col_matmut(&mut out);
149        self.solve_in_place(out_mat.as_mut());
150        if !out.iter().all(|v| v.is_finite()) {
151            return Err("symmetric factor solve produced non-finite values".to_string());
152        }
153        Ok(out)
154    }
155
156    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
157        let mut out = Array2::<f64>::zeros(rhs.raw_dim());
158        for j in 0..rhs.ncols() {
159            for i in 0..rhs.nrows() {
160                out[[i, j]] = rhs[[i, j]];
161            }
162        }
163        let mut out_mat = array2_to_matmut(&mut out);
164        self.solve_in_place(out_mat.as_mut());
165        if !out.iter().all(|v| v.is_finite()) {
166            return Err("symmetric factor multi-solve produced non-finite values".to_string());
167        }
168        Ok(out)
169    }
170
171    fn logdet(&self) -> f64 {
172        match self {
173            FaerSymmetricFactor::Llt(f) => cholesky_factor_logdet(f.L()),
174            FaerSymmetricFactor::Ldlt(f) => diagonal_log_sum(f.D()),
175            FaerSymmetricFactor::Lblt(..) => {
176                // lblt doesn't easily expose diagonal determinant. Fallback to sparse or other representations if needed, but typically Lblt is indefinite!
177                // Actually faer doesn't easily expose lblt logdet since it has 2x2 blocks.
178                // For our ML systems, if we dropped to LBLT, the matrix was indefinite and logdet is ill-defined (or complex).
179                f64::NAN
180            }
181        }
182    }
183}
184
185/// Factorize a symmetric system with LLT -> LDLT -> LBLT fallback.
186#[inline]
187pub fn factorize_symmetricwith_fallback(
188    matrix: MatRef<'_, f64>,
189    side: Side,
190) -> Result<FaerSymmetricFactor, FaerLinalgError> {
191    if let Ok(llt) = FaerLlt::new(matrix, side) {
192        return Ok(FaerSymmetricFactor::Llt(llt));
193    }
194    let ldlt_err = match FaerLdlt::new(matrix, side) {
195        Ok(ldlt) => return Ok(FaerSymmetricFactor::Ldlt(ldlt)),
196        Err(err) => err,
197    };
198    let lblt = catch_unwind(AssertUnwindSafe(|| FaerLblt::new(matrix, side)))
199        .map_err(|_| FaerLinalgError::Ldlt(ldlt_err))?;
200    Ok(FaerSymmetricFactor::Lblt(lblt))
201}
202
/// Central dispatch policy for the product helpers.
///
/// Tiny products stay on ndarray to avoid faer's setup overhead. A product
/// goes to faer GEMM/GEMV only when at least one dimension is moderately large
/// *and* the saturated `m·n·k` work estimate reaches `64²`.
#[inline]
const fn should_use_faer_matmul(m: usize, n: usize, k: usize) -> bool {
    const MIN_DIM: usize = 32;
    const MIN_FLOP_SCALE: usize = 64 * 64;
    let has_large_dim = m >= MIN_DIM || n >= MIN_DIM || k >= MIN_DIM;
    if !has_large_dim {
        return false;
    }
    let work = m.saturating_mul(n).saturating_mul(k);
    work >= MIN_FLOP_SCALE
}
213
214#[inline]
215pub fn matmul_parallelism(m: usize, n: usize, k: usize) -> Par {
216    // Prefer a work-based policy over per-dimension thresholds.
217    // Tall/skinny products (e.g. N x p with large N, modest p) should still
218    // parallelize when total work is high.
219    const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
220    const PAR_MIN_LONG_DIM: usize = 256;
221    let flop_scale = m.saturating_mul(n).saturating_mul(k);
222    let long_dim = m.max(n).max(k);
223    if flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM {
224        // `effective_global_parallelism` collapses to `Par::Seq` when this GEMM
225        // is reached from inside a `NestedParallelGuard` row region, preventing
226        // the Rayon-pool × faer-pool multiplicative oversubscription.
227        effective_global_parallelism()
228    } else {
229        Par::Seq
230    }
231}
232
233#[inline]
234pub fn array2_to_matmut(array: &mut Array2<f64>) -> MatMut<'_, f64> {
235    let (rows, cols) = array.dim();
236    let strides = array.strides();
237
238    // Check if we can get a pointer.
239    // If the array is contiguous (either C or F order), or simply sliced with strides,
240    // faer can handle it as long as we pass the pointer and strides.
241    // However, as_mut_ptr() requires a mutable reference.
242    // ndarray's as_ptr/as_mut_ptr works for both layouts.
243
244    let s0 = strides[0];
245    let s1 = strides[1];
246
247    // SAFETY: array.as_mut_ptr() is ndarray's logical (0, 0) pointer, and
248    // ndarray's dimensions plus signed element strides describe every initialized
249    // element of this uniquely borrowed Array2 for the returned MatMut lifetime.
250    unsafe { MatMut::from_raw_parts_mut(array.as_mut_ptr(), rows, cols, s0, s1) }
251}
252
253#[inline]
254pub fn array1_to_col_matmut(array: &mut Array1<f64>) -> MatMut<'_, f64> {
255    let len = array.len();
256    let stride = array.strides()[0];
257    // SAFETY: array.as_mut_ptr() is ndarray's logical first-element pointer, and
258    // len plus the signed element stride describe every initialized element of
259    // this uniquely borrowed Array1 for the returned len×1 MatMut lifetime.
260    unsafe {
261        MatMut::from_raw_parts_mut(
262            array.as_mut_ptr(),
263            len,
264            1,
265            stride,
266            0, // col stride irrelevant for 1 column
267        )
268    }
269}
270
271/// Compute A^T * A using faer's SIMD-optimized GEMM.
272/// This is MUCH faster than ndarray's .t().dot() for matrices where n > ~100.
273///
274/// For a matrix A of shape (n, p), this computes the (p, p) result.
275/// Uses a zero-copy view for positive-stride layouts and copies only layouts
276/// with non-positive strides.
277#[inline]
278pub fn fast_ata<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>) -> Array2<f64> {
279    let p = a.ncols();
280    let mut out = Array2::<f64>::zeros((p, p));
281    fast_ata_into(a, &mut out);
282    out
283}
284
285/// Compute A^T * A into a pre-allocated output buffer.
286/// `out` must be shaped (p, p) where A is (n, p).
287#[inline]
288pub fn fast_ata_into<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>, out: &mut Array2<f64>) {
289    use faer::Accum;
290    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
291
292    let (n, p) = a.dim();
293    assert_eq!(out.nrows(), p, "output rows must match p");
294    assert_eq!(out.ncols(), p, "output cols must match p");
295
296    if !should_use_faer_matmul(p, p, n) {
297        out.assign(&a.t().dot(a));
298        return;
299    }
300
301    let mut outview = array2_to_matmut(out);
302
303    let aview = FaerArrayView::new(a);
304    let a_ref = aview.as_ref();
305    let a_t = a_ref.transpose();
306    let par = matmul_parallelism(p, p, n);
307    tri_matmul(
308        outview.as_mut(),
309        BlockStructure::TriangularLower,
310        Accum::Replace,
311        a_t,
312        BlockStructure::Rectangular,
313        a_ref,
314        BlockStructure::Rectangular,
315        1.0,
316        par,
317    );
318    // Mirror lower triangle to upper to populate the full symmetric output.
319    for i in 0..p {
320        for j in (i + 1)..p {
321            out[[i, j]] = out[[j, i]];
322        }
323    }
324}
325
326/// Compute A^T * B using faer's SIMD-optimized GEMM.
327/// For A of shape (n, p) and B of shape (n, q), this computes the (p, q) result.
328/// Uses zero-copy views when possible.
329#[inline]
330pub fn fast_atb<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
331    a: &ArrayBase<S1, Ix2>,
332    b: &ArrayBase<S2, Ix2>,
333) -> Array2<f64> {
334    if let Some(out) =
335        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atb(a.view(), b.view()))
336    {
337        return out;
338    }
339    let (n_a, p) = a.dim();
340    let q = b.ncols();
341    fast_atb_with_parallelism(a, b, matmul_parallelism(p, q, n_a))
342}
343
344/// Compute A^T * B with an explicit faer parallelism policy for callers that
345/// are already running independent products in an outer Rayon task.
346#[inline]
347pub fn fast_atb_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
348    a: &ArrayBase<S1, Ix2>,
349    b: &ArrayBase<S2, Ix2>,
350    par: Par,
351) -> Array2<f64> {
352    use faer::linalg::matmul::matmul;
353    use faer::{Accum, Mat};
354
355    let (n_a, p) = a.dim();
356    let (n_b, q) = b.dim();
357    assert_eq!(n_a, n_b, "A and B must have same number of rows");
358
359    // For very small matrices, ndarray might be faster due to less overhead
360    if !should_use_faer_matmul(p, q, n_a) {
361        return a.t().dot(b);
362    }
363
364    let mut result = Mat::<f64>::zeros(p, q);
365
366    let aview = FaerArrayView::new(a);
367    let bview = FaerArrayView::new(b);
368    let a_ref = aview.as_ref();
369    let b_ref = bview.as_ref();
370
371    // dst = A^T * B
372    matmul(
373        result.as_mut(),
374        Accum::Replace,
375        a_ref.transpose(),
376        b_ref,
377        1.0,
378        par,
379    );
380
381    mat_to_array(result.as_ref())
382}
383
384/// Compute A * B^T using faer's SIMD-optimized GEMM.
385/// For A of shape (m, k) and B of shape (n, k), this computes the (m, n) result.
386#[inline]
387pub fn fast_abt<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
388    a: &ArrayBase<S1, Ix2>,
389    b: &ArrayBase<S2, Ix2>,
390) -> Array2<f64> {
391    use faer::linalg::matmul::matmul;
392    use faer::{Accum, Mat};
393
394    let (m, k_a) = a.dim();
395    let (n, k_b) = b.dim();
396    assert_eq!(
397        k_a, k_b,
398        "A and B must have same number of columns for A·Bᵀ"
399    );
400
401    if !should_use_faer_matmul(m, n, k_a) {
402        return a.dot(&b.t());
403    }
404
405    let mut result = Mat::<f64>::zeros(m, n);
406    let aview = FaerArrayView::new(a);
407    let bview = FaerArrayView::new(b);
408    let par = matmul_parallelism(m, n, k_a);
409    matmul(
410        result.as_mut(),
411        Accum::Replace,
412        aview.as_ref(),
413        bview.as_ref().transpose(),
414        1.0,
415        par,
416    );
417    mat_to_array(result.as_ref())
418}
419
420/// Compute A * B using faer's SIMD-optimized GEMM.
421/// For A of shape (n, p) and B of shape (p, q), this computes the (n, q) result.
422/// Uses zero-copy views when possible.
423#[inline]
424pub fn fast_ab<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
425    a: &ArrayBase<S1, Ix2>,
426    b: &ArrayBase<S2, Ix2>,
427) -> Array2<f64> {
428    if let Some(out) =
429        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_ab(a.view(), b.view()))
430    {
431        return out;
432    }
433    let n = a.nrows();
434    let q = b.ncols();
435    let mut out = Array2::<f64>::zeros((n, q));
436    fast_ab_into(a, b, &mut out);
437    out
438}
439
440/// Compute A * v using faer's SIMD-optimized GEMV.
441/// For A of shape (n, p) and v of shape (p,), this computes the (n,) result.
442#[inline]
443pub fn fast_av<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
444    a: &ArrayBase<S1, Ix2>,
445    v: &ArrayBase<S2, Ix1>,
446) -> Array1<f64> {
447    if let Some(out) =
448        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_av(a.view(), v.view()))
449    {
450        return out;
451    }
452    fast_av_impl(a, v)
453}
454
455#[inline]
456fn fast_av_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
457    a: &ArrayBase<S1, Ix2>,
458    v: &ArrayBase<S2, Ix1>,
459) -> Array1<f64> {
460    use faer::linalg::matmul::matmul;
461    use faer::{Accum, Mat};
462
463    let (n, p) = a.dim();
464    assert_eq!(p, v.len(), "A cols must match v length");
465
466    if !should_use_faer_matmul(n, 1, p) {
467        return a.dot(v);
468    }
469
470    let mut result = Mat::<f64>::zeros(n, 1);
471
472    let aview = FaerArrayView::new(a);
473    let vview = FaerColView::new(v);
474    let a_ref = aview.as_ref();
475    let v_ref = vview.as_ref();
476
477    let par = matmul_parallelism(n, 1, p);
478    matmul(result.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
479
480    let mut out = Array1::<f64>::zeros(n);
481    for i in 0..n {
482        out[i] = result[(i, 0)];
483    }
484    out
485}
486
487/// Compute A * v into a pre-allocated output buffer.
488/// `out` must be length n where A is (n, p) and v is length p.
489#[inline]
490pub fn fast_av_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
491    a: &ArrayBase<S1, Ix2>,
492    v: &ArrayBase<S2, Ix1>,
493    out: &mut Array1<f64>,
494) {
495    fast_av_into_impl(a, v, out);
496}
497
498#[inline]
499fn fast_av_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
500    a: &ArrayBase<S1, Ix2>,
501    v: &ArrayBase<S2, Ix1>,
502    out: &mut Array1<f64>,
503) {
504    use faer::Accum;
505    use faer::linalg::matmul::matmul;
506
507    let (n, p) = a.dim();
508    assert_eq!(v.len(), p, "vector length must match A cols");
509    assert_eq!(out.len(), n, "output length must match A rows");
510
511    if !should_use_faer_matmul(n, 1, p) {
512        out.assign(&a.dot(v));
513        return;
514    }
515
516    let mut outview = array1_to_col_matmut(out);
517
518    let aview = FaerArrayView::new(a);
519    let vview = FaerColView::new(v);
520    let a_ref = aview.as_ref();
521    let v_ref = vview.as_ref();
522    let par = matmul_parallelism(n, 1, p);
523    matmul(outview.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
524}
525
526/// Compute A * v into a pre-allocated `ArrayViewMut1` slice. Like
527/// [`fast_av_into`] but accepts a writable slice rather than `&mut Array1`,
528/// so callers can write directly into a sub-range of a larger buffer
529/// without intermediate allocation.
530///
531/// `out` must have length n where A is (n, p) and v is length p.
532#[inline]
533pub fn fast_av_view_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
534    a: &ArrayBase<S1, Ix2>,
535    v: &ArrayBase<S2, Ix1>,
536    out: ArrayViewMut1<'_, f64>,
537) {
538    fast_av_view_into_impl(a, v, out);
539}
540
541#[inline]
542fn fast_av_view_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
543    a: &ArrayBase<S1, Ix2>,
544    v: &ArrayBase<S2, Ix1>,
545    mut out: ArrayViewMut1<'_, f64>,
546) {
547    use faer::Accum;
548    use faer::linalg::matmul::matmul;
549
550    let (n, p) = a.dim();
551    assert_eq!(v.len(), p, "vector length must match A cols");
552    assert_eq!(out.len(), n, "output length must match A rows");
553
554    if !should_use_faer_matmul(n, 1, p) {
555        let prod = a.dot(v);
556        out.assign(&prod);
557        return;
558    }
559
560    let len = out.len();
561    let stride = out.strides()[0];
562    // SAFETY: out.as_mut_ptr() is ndarray's logical first-element pointer, and
563    // len plus the signed element stride describe every initialized element of
564    // this uniquely borrowed view for the returned len×1 MatMut lifetime.
565    let outview = unsafe {
566        MatMut::from_raw_parts_mut(
567            out.as_mut_ptr(),
568            len,
569            1,
570            stride,
571            0, // col stride irrelevant for 1 column
572        )
573    };
574
575    let aview = FaerArrayView::new(a);
576    let vview = FaerColView::new(v);
577    let a_ref = aview.as_ref();
578    let v_ref = vview.as_ref();
579    let par = matmul_parallelism(n, 1, p);
580    matmul(outview, Accum::Replace, a_ref, v_ref, 1.0, par);
581}
582
583/// Compute A^T * v using faer's SIMD-optimized GEMV.
584/// For A of shape (n, p) and v of shape (n,), this computes the (p,) result.
585#[inline]
586pub fn fast_atv<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
587    a: &ArrayBase<S1, Ix2>,
588    v: &ArrayBase<S2, Ix1>,
589) -> Array1<f64> {
590    if let Some(out) =
591        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atv(a.view(), v.view()))
592    {
593        return out;
594    }
595    fast_atv_impl(a, v)
596}
597
598#[inline]
599fn fast_atv_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
600    a: &ArrayBase<S1, Ix2>,
601    v: &ArrayBase<S2, Ix1>,
602) -> Array1<f64> {
603    use faer::Accum;
604    use faer::linalg::matmul::matmul;
605
606    let (n, p) = a.dim();
607    assert_eq!(n, v.len(), "A rows must match v length");
608
609    // For very small arrays, ndarray might be faster
610    if !should_use_faer_matmul(p, 1, n) {
611        return a.t().dot(v);
612    }
613
614    let mut out = Array1::<f64>::zeros(p);
615    let mut outview = array1_to_col_matmut(&mut out);
616
617    let aview = FaerArrayView::new(a);
618    let vview = FaerColView::new(v);
619    let a_ref = aview.as_ref();
620    let v_ref = vview.as_ref();
621
622    // dst = A^T * v (treating v as n×1 matrix)
623    let par = matmul_parallelism(p, 1, n);
624    matmul(
625        outview.as_mut(),
626        Accum::Replace,
627        a_ref.transpose(),
628        v_ref,
629        1.0,
630        par,
631    );
632
633    out
634}
635
636/// Compute A^T * v into a pre-allocated output buffer.
637/// `out` must be length p where A is (n, p) and v is length n.
638#[inline]
639pub fn fast_atv_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
640    a: &ArrayBase<S1, Ix2>,
641    v: &ArrayBase<S2, Ix1>,
642    out: &mut Array1<f64>,
643) {
644    fast_atv_into_impl(a, v, out);
645}
646
647#[inline]
648fn fast_atv_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
649    a: &ArrayBase<S1, Ix2>,
650    v: &ArrayBase<S2, Ix1>,
651    out: &mut Array1<f64>,
652) {
653    use faer::Accum;
654    use faer::linalg::matmul::matmul;
655
656    let (n, p) = a.dim();
657    assert_eq!(v.len(), n, "vector length must match A rows");
658    assert_eq!(out.len(), p, "output length must match A cols");
659
660    if !should_use_faer_matmul(p, 1, n) {
661        out.assign(&a.t().dot(v));
662        return;
663    }
664
665    let mut outview = array1_to_col_matmut(out);
666
667    let aview = FaerArrayView::new(a);
668    let vview = FaerColView::new(v);
669    let a_ref = aview.as_ref();
670    let v_ref = vview.as_ref();
671    let par = matmul_parallelism(p, 1, n);
672    matmul(
673        outview.as_mut(),
674        Accum::Replace,
675        a_ref.transpose(),
676        v_ref,
677        1.0,
678        par,
679    );
680}
681
682/// Compute A^T * diag(W) * A using streaming chunks to avoid O(n*p) allocation.
683#[inline]
684pub fn fast_xt_diag_x<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
685    x: &ArrayBase<S1, Ix2>,
686    w: &ArrayBase<S2, Ix1>,
687) -> Array2<f64> {
688    assert_eq!(
689        x.nrows(),
690        w.len(),
691        "fast_xt_diag_x row/weight length mismatch"
692    );
693    if let Some(out) =
694        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_xt_diag_x(x.view(), w.view()))
695    {
696        return out;
697    }
698    let p = x.ncols();
699    fast_xt_diag_x_with_parallelism(x, w, matmul_parallelism(p, p, x.nrows()))
700}
701
702/// Compute A^T * diag(W) * A with an explicit faer parallelism policy for
703/// callers that parallelize multiple independent Hessian blocks externally.
704#[inline]
705pub fn fast_xt_diag_x_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
706    x: &ArrayBase<S1, Ix2>,
707    w: &ArrayBase<S2, Ix1>,
708    par: Par,
709) -> Array2<f64> {
710    assert_eq!(
711        x.nrows(),
712        w.len(),
713        "fast_xt_diag_x_with_parallelism row/weight length mismatch"
714    );
715    fast_xt_diag_x_with_parallelism_impl(x, w, par)
716}
717
718#[inline]
719fn fast_xt_diag_x_with_parallelism_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
720    x: &ArrayBase<S1, Ix2>,
721    w: &ArrayBase<S2, Ix1>,
722    par: Par,
723) -> Array2<f64> {
724    use ndarray::ShapeBuilder;
725
726    let p = x.ncols();
727    // F-order result so the symmetric lower-triangle accumulation writes
728    // column-contiguously; the kernel mirrors to a full symmetric matrix.
729    let mut result = Array2::<f64>::zeros((p, p).f());
730    stream_weighted_crossprod_into(
731        x,
732        w,
733        &mut result,
734        CrossprodStructure::SymmetricLower,
735        CrossprodAccum::Replace,
736        par,
737    );
738    result
739}
740
/// Output packaging for [`stream_weighted_crossprod_into`]: selects which
/// faer product kernel computes the weighted Gram.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CrossprodStructure {
    /// Compute every entry of the (symmetric) Gram via full GEMM.
    Full,
    /// Accumulate only the lower triangle via triangular matmul (~50% fewer
    /// FLOPs), then mirror once into the upper triangle for a full symmetric
    /// result. Computes the same Gram as [`Full`](Self::Full).
    SymmetricLower,
}
751
/// Accumulation policy for [`stream_weighted_crossprod_into`]: whether the
/// weighted Gram replaces or is added to the existing contents of `out`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CrossprodAccum {
    /// Overwrite `out` with `Xᵀ·diag(W)·X`, ignoring prior contents.
    Replace,
    /// Add `Xᵀ·diag(W)·X` into the existing contents of `out`.
    Add,
}
760
761/// Shared dense weighted-Gram kernel: accumulate `Xᵀ·diag(W)·X` into `out`.
762///
763/// This is the single tuned implementation of the chunked row-scaling +
764/// matmul strategy; the matrix-returning (`fast_xt_diag_x*`) entry points and
765/// stream-in callers share it so that performance tuning, negative-weight
766/// handling, chunk sizing, and layout fixes land in exactly one place.
767///
768/// Computes the product as `Xᵀ·(W·X)` to preserve the sign of `W`: the prior
769/// `sqrt(max(0, w))`-then-Gram form clipped negative weights to zero, which
770/// corrupted observed-Hessian assembly when any block carried heavy residuals
771/// (e.g. under the logb σ link).
772///
773/// Peak working-set allocation is `chunk_rows × p × 8` bytes (~8 MB) rather
774/// than `n × p × 8` bytes for a materialized `W·X`.
775///
776/// `out` must be `p × p`. With [`CrossprodStructure::SymmetricLower`] the
777/// lower triangle is accumulated and then mirrored, so on return `out` holds
778/// the full symmetric matrix regardless of `structure`.
779pub fn stream_weighted_crossprod_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
780    x: &ArrayBase<S1, Ix2>,
781    w: &ArrayBase<S2, Ix1>,
782    out: &mut Array2<f64>,
783    structure: CrossprodStructure,
784    accum: CrossprodAccum,
785    par: Par,
786) {
787    use faer::Accum;
788    use faer::linalg::matmul::matmul;
789    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
790    use ndarray::s;
791
792    let (n, p) = x.dim();
793    assert_eq!(n, w.len(), "X rows must match W length");
794    assert_eq!(out.nrows(), p, "output rows must match X cols");
795    assert_eq!(out.ncols(), p, "output cols must match X cols");
796    if p == 0 {
797        return;
798    }
799    if n == 0 {
800        if accum == CrossprodAccum::Replace {
801            out.fill(0.0);
802        }
803        return;
804    }
805
806    if !should_use_faer_matmul(p, p, n) {
807        // Tiny products: ndarray's own GEMM avoids faer setup overhead.
808        let w_x = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
809        let gram = x.t().dot(&w_x);
810        match accum {
811            CrossprodAccum::Replace => out.assign(&gram),
812            CrossprodAccum::Add => *out += &gram,
813        }
814        return;
815    }
816
817    // Streaming chunked: peak allocation is chunk_rows × p instead of n × p.
818    const TARGET_BYTES: usize = 8 * 1024 * 1024;
819    const MIN_ROWS: usize = 512;
820    const MAX_ROWS: usize = 131_072;
821    let chunk_rows = (TARGET_BYTES / (p.max(1) * 8))
822        .clamp(MIN_ROWS, MAX_ROWS)
823        .min(n);
824
825    // Triangular accumulation requires a zero baseline in the lower triangle
826    // because each chunk's `Accum::Add` lands there; for a Replace request we
827    // zero up front and add every chunk, for an Add request the caller's
828    // contents are preserved and every chunk adds on top.
829    if accum == CrossprodAccum::Replace {
830        out.fill(0.0);
831    }
832
833    // Row-major wx_chunk so the per-row scaling loop has stride-1 writes
834    // alongside stride-1 reads from a row-major X. An F-order wx_chunk would
835    // force strided writes by `chunk_rows`, breaking vectorization and cache
836    // locality on the per-PIRLS-iter Hessian assembly. faer's matmul handles
837    // either layout via FaerArrayView.
838    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p));
839
840    let x_is_row_major = x.is_standard_layout();
841    let w_slice_opt = w.as_slice();
842
843    // Scope the faer mutable view so its borrow on `out` ends before the
844    // symmetric mirror step.
845    {
846        let mut out_view = array2_to_matmut(out);
847        for start in (0..n).step_by(chunk_rows) {
848            let rows = (n - start).min(chunk_rows);
849            {
850                let chunk_slice = wx_chunk
851                    .as_slice_mut()
852                    .expect("row-major chunk is contiguous");
853                if x_is_row_major && let (Some(x_all), Some(w_all)) = (x.as_slice(), w_slice_opt) {
854                    for local in 0..rows {
855                        let src = start + local;
856                        let wi = w_all[src];
857                        let src_off = src * p;
858                        let dst_off = local * p;
859                        let src_row = &x_all[src_off..src_off + p];
860                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
861                        for col in 0..p {
862                            dst_row[col] = src_row[col] * wi;
863                        }
864                    }
865                } else {
866                    let x_slice = x.slice(s![start..start + rows, ..]);
867                    for local in 0..rows {
868                        let wi = w[start + local];
869                        let xrow = x_slice.row(local);
870                        let dst_off = local * p;
871                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
872                        for (col, xij) in xrow.iter().enumerate() {
873                            dst_row[col] = xij * wi;
874                        }
875                    }
876                }
877            }
878            let x_slice = x.slice(s![start..start + rows, ..]);
879            let wx_slice = wx_chunk.slice(s![0..rows, ..]);
880            let x_view = FaerArrayView::new(&x_slice);
881            let wx_view = FaerArrayView::new(&wx_slice);
882            match structure {
883                CrossprodStructure::SymmetricLower => {
884                    // X^T diag(W) X is symmetric; accumulate the lower triangle
885                    // only, then mirror once after the chunk loop. ~50% fewer
886                    // FLOPs vs. full GEMM.
887                    tri_matmul(
888                        out_view.as_mut(),
889                        BlockStructure::TriangularLower,
890                        Accum::Add,
891                        x_view.as_ref().transpose(),
892                        BlockStructure::Rectangular,
893                        wx_view.as_ref(),
894                        BlockStructure::Rectangular,
895                        1.0,
896                        par,
897                    );
898                }
899                CrossprodStructure::Full => {
900                    matmul(
901                        out_view.as_mut(),
902                        Accum::Add,
903                        x_view.as_ref().transpose(),
904                        wx_view.as_ref(),
905                        1.0,
906                        par,
907                    );
908                }
909            }
910        }
911    }
912
913    if structure == CrossprodStructure::SymmetricLower {
914        // Mirror lower triangle to upper for a full symmetric output.
915        for i in 0..p {
916            for j in (i + 1)..p {
917                out[[i, j]] = out[[j, i]];
918            }
919        }
920    }
921}
922
923/// Compute A^T * diag(W) * B using streaming chunks.
924#[inline]
925pub fn fast_xt_diag_y<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
926    x: &ArrayBase<S1, Ix2>,
927    w: &ArrayBase<S2, Ix1>,
928    y: &ArrayBase<S3, Ix2>,
929) -> Array2<f64> {
930    assert_eq!(x.nrows(), y.nrows(), "fast_xt_diag_y X/Y row mismatch");
931    assert_eq!(
932        y.nrows(),
933        w.len(),
934        "fast_xt_diag_y row/weight length mismatch"
935    );
936    if let Some(out) = crate::gpu_hook::gpu_dispatch()
937        .and_then(|d| d.try_fast_xt_diag_y(x.view(), w.view(), y.view()))
938    {
939        return out;
940    }
941    fast_xt_diag_y_impl(x, w, y)
942}
943
944#[inline]
945fn fast_xt_diag_y_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
946    x: &ArrayBase<S1, Ix2>,
947    w: &ArrayBase<S2, Ix1>,
948    y: &ArrayBase<S3, Ix2>,
949) -> Array2<f64> {
950    use faer::Accum;
951    use faer::linalg::matmul::matmul;
952    use ndarray::{ShapeBuilder, s};
953
954    let (n, q) = y.dim();
955    let px = x.ncols();
956    assert_eq!(n, w.len(), "Y rows must match W length");
957    assert_eq!(n, x.nrows(), "X rows must match Y rows");
958    if n == 0 || px == 0 || q == 0 {
959        return Array2::<f64>::zeros((px, q));
960    }
961    if !should_use_faer_matmul(px, q, n) {
962        let w_y = Array2::from_shape_fn((n, q), |(i, j)| w[i] * y[[i, j]]);
963        return x.t().dot(&w_y);
964    }
965
966    // Streaming: only allocate chunk_rows × q for the weighted Y slice.
967    const TARGET_BYTES: usize = 8 * 1024 * 1024;
968    const MIN_ROWS: usize = 512;
969    const MAX_ROWS: usize = 131_072;
970    let total_cols = px + q;
971    let chunk_rows = (TARGET_BYTES / (total_cols.max(1) * 8))
972        .clamp(MIN_ROWS, MAX_ROWS)
973        .min(n);
974
975    let mut result = Array2::<f64>::zeros((px, q).f());
976    // Row-major wy_chunk — same rationale as fast_xt_diag_x: stride-1
977    // writes alongside stride-1 reads from a row-major Y.
978    let mut wy_chunk = Array2::<f64>::zeros((chunk_rows, q));
979
980    let y_is_row_major = y.is_standard_layout();
981    let w_slice_opt = w.as_slice();
982
983    {
984        let mut out_view = array2_to_matmut(&mut result);
985
986        for start in (0..n).step_by(chunk_rows) {
987            let rows = (n - start).min(chunk_rows);
988            {
989                let chunk_slice = wy_chunk
990                    .as_slice_mut()
991                    .expect("row-major chunk is contiguous");
992                if y_is_row_major && let (Some(y_all), Some(w_all)) = (y.as_slice(), w_slice_opt) {
993                    for local in 0..rows {
994                        let src = start + local;
995                        let wi = w_all[src];
996                        let src_off = src * q;
997                        let dst_off = local * q;
998                        let src_row = &y_all[src_off..src_off + q];
999                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
1000                        for col in 0..q {
1001                            dst_row[col] = src_row[col] * wi;
1002                        }
1003                    }
1004                } else {
1005                    let y_slice = y.slice(s![start..start + rows, ..]);
1006                    for local in 0..rows {
1007                        let wi = w[start + local];
1008                        let yrow = y_slice.row(local);
1009                        let dst_off = local * q;
1010                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
1011                        for (col, yij) in yrow.iter().enumerate() {
1012                            dst_row[col] = yij * wi;
1013                        }
1014                    }
1015                }
1016            }
1017            let x_slice = x.slice(s![start..start + rows, ..]);
1018            let wy_slice = wy_chunk.slice(s![0..rows, ..]);
1019            let x_view = FaerArrayView::new(&x_slice);
1020            let wy_view = FaerArrayView::new(&wy_slice);
1021            let par = matmul_parallelism(px, q, rows);
1022            matmul(
1023                out_view.as_mut(),
1024                Accum::Add,
1025                x_view.as_ref().transpose(),
1026                wy_view.as_ref(),
1027                1.0,
1028                par,
1029            );
1030        }
1031    }
1032
1033    result
1034}
1035
1036/// Compute the 2×2 block joint Hessian in a single streaming pass:
1037///   [X_a^T diag(w_aa) X_a,   X_a^T diag(w_ab) X_b]
1038///   [X_b^T diag(w_ab) X_a,   X_b^T diag(w_bb) X_b]
1039///
1040/// This reads X_a and X_b once per chunk instead of twice (saving 50% bandwidth).
1041pub fn fast_joint_hessian_2x2<
1042    S1: Data<Elem = f64>,
1043    S2: Data<Elem = f64>,
1044    S3: Data<Elem = f64>,
1045    S4: Data<Elem = f64>,
1046    S5: Data<Elem = f64>,
1047>(
1048    x_a: &ArrayBase<S1, Ix2>,
1049    x_b: &ArrayBase<S2, Ix2>,
1050    w_aa: &ArrayBase<S3, Ix1>,
1051    w_ab: &ArrayBase<S4, Ix1>,
1052    w_bb: &ArrayBase<S5, Ix1>,
1053) -> Array2<f64> {
1054    if let Some(out) = crate::gpu_hook::gpu_dispatch().and_then(|d| {
1055        d.try_fast_joint_hessian_2x2(
1056            x_a.view(),
1057            x_b.view(),
1058            w_aa.view(),
1059            w_ab.view(),
1060            w_bb.view(),
1061        )
1062    }) {
1063        return out;
1064    }
1065    fast_joint_hessian_2x2_impl(x_a, x_b, w_aa, w_ab, w_bb)
1066}
1067
/// CPU implementation of the 2×2 block joint Hessian
///   [X_a^T diag(w_aa) X_a,   X_a^T diag(w_ab) X_b]
///   [X_b^T diag(w_ab) X_a,   X_b^T diag(w_bb) X_b]
/// returned as a `(pa + pb) × (pa + pb)` matrix, with `pa = x_a.ncols()` and
/// `pb = x_b.ncols()`.
///
/// Only the two diagonal blocks and the top-right off-diagonal block are
/// computed; the bottom-left block is filled by mirroring at the end.
///
/// - Small problems (per `should_use_faer_matmul`) build fully weighted copies
///   and use ndarray `dot`.
/// - Larger problems stream over row chunks. Each chunk's weighted rows go into
///   three reusable row-major buffers, and the blocks are accumulated with faer
///   `matmul` using `Accum::Add`.
///
/// # Panics
/// When `x_b`, `w_aa`, `w_ab` or `w_bb` do not have `x_a.nrows()` rows/entries.
#[inline]
fn fast_joint_hessian_2x2_impl<
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
    S5: Data<Elem = f64>,
>(
    x_a: &ArrayBase<S1, Ix2>,
    x_b: &ArrayBase<S2, Ix2>,
    w_aa: &ArrayBase<S3, Ix1>,
    w_ab: &ArrayBase<S4, Ix1>,
    w_bb: &ArrayBase<S5, Ix1>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};

    let n = x_a.nrows();
    let pa = x_a.ncols();
    let pb = x_b.ncols();
    let total = pa + pb;
    assert_eq!(n, x_b.nrows());
    assert_eq!(n, w_aa.len());
    assert_eq!(n, w_ab.len());
    assert_eq!(n, w_bb.len());

    // No rows or no columns: the Hessian is an all-zero total × total matrix.
    if n == 0 || total == 0 {
        return Array2::<f64>::zeros((total, total));
    }

    // For small problems, fall back to separate computations
    if !should_use_faer_matmul(pa.max(pb), pa.max(pb), n) {
        let waa_xa = Array2::from_shape_fn((n, pa), |(i, j)| w_aa[i] * x_a[[i, j]]);
        let wab_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_ab[i] * x_b[[i, j]]);
        let wbb_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_bb[i] * x_b[[i, j]]);
        let mut out = Array2::<f64>::zeros((total, total));
        out.slice_mut(s![..pa, ..pa]).assign(&x_a.t().dot(&waa_xa));
        out.slice_mut(s![..pa, pa..]).assign(&x_a.t().dot(&wab_xb));
        out.slice_mut(s![pa.., pa..]).assign(&x_b.t().dot(&wbb_xb));
        // Mirror upper to lower
        for i in 0..total {
            for j in 0..i {
                out[[i, j]] = out[[j, i]];
            }
        }
        return out;
    }

    // Chunk size: aim for ~8 MiB of weighted-buffer storage (8 bytes per f64
    // column), clamped to [MIN_ROWS, MAX_ROWS] rows and never more than n.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    // Need buffers for: waa_xa(chunk×pa) + wab_xb(chunk×pb) + wbb_xb(chunk×pb)
    let cols_needed = pa + 2 * pb;
    let chunk_rows = (TARGET_BYTES / (cols_needed.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Column-major (`.f()`) output, accumulated in place through faer.
    let mut out = Array2::<f64>::zeros((total, total).f());
    // Row-major weighted buffers give the per-row scale loops stride-1
    // writes; an F-order layout would stride writes by chunk_rows across
    // `pa` / `pb` and defeat vectorization. faer's matmul handles either
    // layout.
    let mut waa_xa_buf = Array2::<f64>::zeros((chunk_rows, pa));
    let mut wab_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
    let mut wbb_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));

    // Layout facts used to pick the slice-based fast weighting path below.
    let xa_is_row_major = x_a.is_standard_layout();
    let xb_is_row_major = x_b.is_standard_layout();
    let waa_slice_opt = w_aa.as_slice();
    let wab_slice_opt = w_ab.as_slice();
    let wbb_slice_opt = w_bb.as_slice();

    {
        let mut out_mat = array2_to_matmut(&mut out);

        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than chunk_rows.
            let rows = (n - start).min(chunk_rows);
            let xa_slice = x_a.slice(s![start..start + rows, ..]);
            let xb_slice = x_b.slice(s![start..start + rows, ..]);

            // Weight X_a and X_b in a single pass through this chunk.
            {
                let waa_chunk = waa_xa_buf
                    .as_slice_mut()
                    .expect("row-major waa chunk is contiguous");
                let wab_chunk = wab_xb_buf
                    .as_slice_mut()
                    .expect("row-major wab chunk is contiguous");
                let wbb_chunk = wbb_xb_buf
                    .as_slice_mut()
                    .expect("row-major wbb chunk is contiguous");

                // Fast path: every source is a plain contiguous slice, so rows
                // can be addressed by flat offsets without ndarray indexing.
                if xa_is_row_major
                    && xb_is_row_major
                    && let (Some(xa_all), Some(xb_all)) = (x_a.as_slice(), x_b.as_slice())
                    && let (Some(waa_all), Some(wab_all), Some(wbb_all)) =
                        (waa_slice_opt, wab_slice_opt, wbb_slice_opt)
                {
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = waa_all[i];
                        let wab_i = wab_all[i];
                        let wbb_i = wbb_all[i];
                        let xa_off = i * pa;
                        let xa_row = &xa_all[xa_off..xa_off + pa];
                        let xb_off = i * pb;
                        let xb_row = &xb_all[xb_off..xb_off + pb];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        for col in 0..pa {
                            waa_row[col] = xa_row[col] * waa_i;
                        }
                        // One read of each X_b entry feeds both X_b-weighted buffers.
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        for col in 0..pb {
                            let xij = xb_row[col];
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                } else {
                    // General layouts: read rows through ndarray views.
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = w_aa[i];
                        let wab_i = w_ab[i];
                        let wbb_i = w_bb[i];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        let xa_row = xa_slice.row(local);
                        for (col, xij) in xa_row.iter().enumerate() {
                            waa_row[col] = xij * waa_i;
                        }
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        let xb_row = xb_slice.row(local);
                        for (col, xij) in xb_row.iter().enumerate() {
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                }
            }

            // Only the first `rows` rows of each buffer are valid for this chunk.
            let xa_view = FaerArrayView::new(&xa_slice);
            let xb_view = FaerArrayView::new(&xb_slice);
            let waa_xa_slice = waa_xa_buf.slice(s![0..rows, ..]);
            let wab_xb_slice = wab_xb_buf.slice(s![0..rows, ..]);
            let wbb_xb_slice = wbb_xb_buf.slice(s![0..rows, ..]);
            let waa_xa_view = FaerArrayView::new(&waa_xa_slice);
            let wab_xb_view = FaerArrayView::new(&wab_xb_slice);
            let wbb_xb_view = FaerArrayView::new(&wbb_xb_slice);

            // Block [0..pa, 0..pa]: X_a^T diag(w_aa) X_a
            matmul(
                out_mat.rb_mut().submatrix_mut(0, 0, pa, pa),
                Accum::Add,
                xa_view.as_ref().transpose(),
                waa_xa_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pa, rows),
            );
            // Block [0..pa, pa..total]: X_a^T diag(w_ab) X_b
            matmul(
                out_mat.rb_mut().submatrix_mut(0, pa, pa, pb),
                Accum::Add,
                xa_view.as_ref().transpose(),
                wab_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pb, rows),
            );
            // Block [pa..total, pa..total]: X_b^T diag(w_bb) X_b
            matmul(
                out_mat.rb_mut().submatrix_mut(pa, pa, pb, pb),
                Accum::Add,
                xb_view.as_ref().transpose(),
                wbb_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pb, pb, rows),
            );
        }
    } // out_mat dropped
    // Mirror upper triangle to lower
    for i in 0..total {
        for j in 0..i {
            out[[i, j]] = out[[j, i]];
        }
    }
    out
}
1262
1263fn mat_to_array(mat: MatRef<'_, f64>) -> Array2<f64> {
1264    let nrows = mat.nrows();
1265    let ncols = mat.ncols();
1266    let mut out = Array2::<f64>::zeros((nrows, ncols));
1267    if nrows == 0 || ncols == 0 {
1268        return out;
1269    }
1270    // ndarray is row-major by default. Write row-by-row for best cache behavior
1271    // on the output side.
1272    if let Some(out_slice) = out.as_slice_memory_order_mut() {
1273        // Row-major: out_slice[i * ncols + j] = mat[(i, j)]
1274        for i in 0..nrows {
1275            let row_start = i * ncols;
1276            for j in 0..ncols {
1277                out_slice[row_start + j] = mat[(i, j)];
1278            }
1279        }
1280    } else {
1281        for j in 0..ncols {
1282            for i in 0..nrows {
1283                out[[i, j]] = mat[(i, j)];
1284            }
1285        }
1286    }
1287    out
1288}
1289
1290/// Write faer matmul result A*B directly into a pre-allocated ndarray Array2.
1291/// Avoids the intermediate faer::Mat allocation and mat_to_array copy.
1292#[inline]
1293pub fn fast_ab_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1294    a: &ArrayBase<S1, Ix2>,
1295    b: &ArrayBase<S2, Ix2>,
1296    out: &mut Array2<f64>,
1297) {
1298    fast_ab_into_impl(a, b, out);
1299}
1300
1301#[inline]
1302fn fast_ab_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1303    a: &ArrayBase<S1, Ix2>,
1304    b: &ArrayBase<S2, Ix2>,
1305    out: &mut Array2<f64>,
1306) {
1307    use faer::Accum;
1308    use faer::linalg::matmul::matmul;
1309
1310    let (n, p) = a.dim();
1311    let (p_b, q) = b.dim();
1312    assert_eq!(p, p_b, "A and B must have compatible inner dimensions");
1313    assert_eq!(out.dim(), (n, q), "output dimensions must match A*B result");
1314
1315    if !should_use_faer_matmul(n, q, p) {
1316        out.assign(&a.dot(b));
1317        return;
1318    }
1319
1320    let aview = FaerArrayView::new(a);
1321    let bview = FaerArrayView::new(b);
1322    let a_ref = aview.as_ref();
1323    let b_ref = bview.as_ref();
1324
1325    let par = matmul_parallelism(n, q, p);
1326    let mut outview = array2_to_matmut(out);
1327    matmul(outview.as_mut(), Accum::Replace, a_ref, b_ref, 1.0, par);
1328}
1329
1330fn diag_to_array(diag: DiagRef<'_, f64>) -> Array1<f64> {
1331    let mat = diag.column_vector().as_mat();
1332    let mut out = Array1::<f64>::zeros(mat.nrows());
1333    for i in 0..mat.nrows() {
1334        out[i] = mat[(i, 0)];
1335    }
1336    out
1337}
1338
1339pub struct FaerArrayView<'a> {
1340    ptr: *const f64,
1341    rows: usize,
1342    cols: usize,
1343    row_stride: isize,
1344    col_stride: isize,
1345    owned: Option<Array2<f64>>,
1346    marker: PhantomData<&'a f64>,
1347}
1348
1349impl<'a> FaerArrayView<'a> {
1350    #[inline]
1351    pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix2>) -> Self {
1352        let (rows, cols) = array.dim();
1353        let strides = array.strides();
1354        // Guard against layouts that can alias or reverse memory traversal (e.g.
1355        // negative/zero strides). These can violate assumptions in faer kernels.
1356        // For such layouts we materialize a compact owned copy.
1357        if strides[0] <= 0 || strides[1] <= 0 {
1358            let owned = array.to_owned();
1359            let owned_strides = owned.strides();
1360            return Self {
1361                ptr: owned.as_ptr(),
1362                rows,
1363                cols,
1364                row_stride: owned_strides[0],
1365                col_stride: owned_strides[1],
1366                owned: Some(owned),
1367                marker: PhantomData,
1368            };
1369        }
1370
1371        Self {
1372            ptr: array.as_ptr(),
1373            rows,
1374            cols,
1375            row_stride: strides[0],
1376            col_stride: strides[1],
1377            owned: None,
1378            marker: PhantomData,
1379        }
1380    }
1381
1382    #[inline]
1383    pub fn as_ref(&self) -> MatRef<'_, f64> {
1384        let (ptr, rows, cols, row_stride, col_stride) = if let Some(owned) = &self.owned {
1385            let strides = owned.strides();
1386            (
1387                owned.as_ptr(),
1388                owned.nrows(),
1389                owned.ncols(),
1390                strides[0],
1391                strides[1],
1392            )
1393        } else {
1394            (
1395                self.ptr,
1396                self.rows,
1397                self.cols,
1398                self.row_stride,
1399                self.col_stride,
1400            )
1401        };
1402        // SAFETY: ptr/shape/strides come from either a live ndarray view
1403        // (positive strides, validated bounds/alignment) or the owned
1404        // compact copy held inside this wrapper — no mutable aliasing.
1405        unsafe { MatRef::from_raw_parts(ptr, rows, cols, row_stride, col_stride) }
1406    }
1407}
1408
1409pub struct FaerColView<'a> {
1410    ptr: *const f64,
1411    len: usize,
1412    stride: isize,
1413    owned: Option<Array1<f64>>,
1414    marker: PhantomData<&'a f64>,
1415}
1416
1417impl<'a> FaerColView<'a> {
1418    #[inline]
1419    pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix1>) -> Self {
1420        let len = array.len();
1421        let stride = array.strides()[0];
1422        if stride <= 0 {
1423            let owned = array.to_owned();
1424            return Self {
1425                ptr: owned.as_ptr(),
1426                len,
1427                stride: 1,
1428                owned: Some(owned),
1429                marker: PhantomData,
1430            };
1431        }
1432        Self {
1433            ptr: array.as_ptr(),
1434            len,
1435            stride,
1436            owned: None,
1437            marker: PhantomData,
1438        }
1439    }
1440
1441    #[inline]
1442    pub fn as_ref(&self) -> MatRef<'_, f64> {
1443        let (ptr, len, stride) = if let Some(owned) = &self.owned {
1444            (owned.as_ptr(), owned.len(), 1)
1445        } else {
1446            (self.ptr, self.len, self.stride)
1447        };
1448        // SAFETY: ptr/len/stride come from either a live ndarray column
1449        // (positive stride, validated bounds/alignment) or the owned
1450        // compact copy; ncols=1 so the 0 col-stride is unused.
1451        unsafe { MatRef::from_raw_parts(ptr, len, 1, stride, 0) }
1452    }
1453}
1454
1455pub trait FaerSvd {
1456    fn svd(
1457        &self,
1458        compute_u: bool,
1459        computevt: bool,
1460    ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError>;
1461}
1462
1463impl<S: Data<Elem = f64>> FaerSvd for ArrayBase<S, Ix2> {
1464    fn svd(
1465        &self,
1466        compute_u: bool,
1467        computevt: bool,
1468    ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError> {
1469        let faerview = FaerArrayView::new(self);
1470        let faer_mat = faerview.as_ref();
1471        if !compute_u && !computevt {
1472            let (rows, cols) = faer_mat.shape();
1473            let mut singular = Diag::<f64>::zeros(rows.min(cols));
1474            let par = get_global_parallelism();
1475            let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1476                rows,
1477                cols,
1478                ComputeSvdVectors::No,
1479                ComputeSvdVectors::No,
1480                par,
1481                Default::default(),
1482            ));
1483            let stack = MemStack::new(&mut mem);
1484            svd::svd(
1485                faer_mat,
1486                singular.as_mut(),
1487                None,
1488                None,
1489                par,
1490                stack,
1491                Default::default(),
1492            )
1493            .map_err(|_| FaerLinalgError::SvdNoConvergence {
1494                context: "faer SVD singular values only",
1495            })?;
1496            let singularvalues = diag_to_array(singular.as_ref());
1497            return Ok((None, singularvalues, None));
1498        }
1499
1500        let (rows, cols) = faer_mat.shape();
1501        let rank = rows.min(cols);
1502        let compute_u_flag = if compute_u {
1503            ComputeSvdVectors::Thin
1504        } else {
1505            ComputeSvdVectors::No
1506        };
1507        let computev_flag = if computevt {
1508            ComputeSvdVectors::Thin
1509        } else {
1510            ComputeSvdVectors::No
1511        };
1512
1513        let mut singular = Diag::<f64>::zeros(rows.min(cols));
1514        let mut u_storage = compute_u.then(|| Mat::<f64>::zeros(rows, rank));
1515        let mut v_storage = computevt.then(|| Mat::<f64>::zeros(cols, rank));
1516
1517        let par = get_global_parallelism();
1518        let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1519            rows,
1520            cols,
1521            compute_u_flag,
1522            computev_flag,
1523            par,
1524            Default::default(),
1525        ));
1526        let stack = MemStack::new(&mut mem);
1527
1528        svd::svd(
1529            faer_mat.as_ref(),
1530            singular.as_mut(),
1531            u_storage.as_mut().map(|mat| mat.as_mut()),
1532            v_storage.as_mut().map(|mat| mat.as_mut()),
1533            par,
1534            stack,
1535            Default::default(),
1536        )
1537        .map_err(|_| FaerLinalgError::SvdNoConvergence {
1538            context: "faer SVD with vectors",
1539        })?;
1540
1541        let singularvalues = diag_to_array(singular.as_ref());
1542        let u_opt = u_storage.map(|mat| mat_to_array(mat.as_ref()));
1543        let vt_opt = v_storage.map(|mat| {
1544            let mat_ref = mat.as_ref();
1545            let mut out = Array2::<f64>::zeros((mat_ref.ncols(), mat_ref.nrows()));
1546            for j in 0..mat_ref.nrows() {
1547                for i in 0..mat_ref.ncols() {
1548                    out[[i, j]] = mat_ref[(j, i)];
1549                }
1550            }
1551            out
1552        });
1553
1554        Ok((u_opt, singularvalues, vt_opt))
1555    }
1556}
1557
1558pub trait FaerEigh {
1559    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError>;
1560}
1561
1562impl<S: Data<Elem = f64>> FaerEigh for ArrayBase<S, Ix2> {
1563    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1564        fn try_eigh(
1565            matrix: &Array2<f64>,
1566            side: Side,
1567        ) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1568            let faerview = FaerArrayView::new(matrix);
1569            let eigen = catch_unwind(AssertUnwindSafe(|| {
1570                faerview.as_ref().self_adjoint_eigen(side)
1571            }))
1572            .map_err(|_| FaerLinalgError::FactorizationFailed {
1573                context: "self-adjoint eigendecomposition panic boundary",
1574            })?
1575            .map_err(FaerLinalgError::SelfAdjointEigen)?;
1576            let values = diag_to_array(eigen.S());
1577            let vectors = mat_to_array(eigen.U());
1578            Ok((values, vectors))
1579        }
1580
1581        let owned = self.to_owned();
1582        if owned.nrows() != owned.ncols() {
1583            return Err(FaerLinalgError::FactorizationFailed {
1584                context: "self-adjoint eigendecomposition non-square input",
1585            });
1586        }
1587        if owned.nrows() == 0 {
1588            return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
1589        }
1590        if owned.iter().any(|value| !value.is_finite()) {
1591            return Err(FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1592                context: "self-adjoint eigendecomposition input validation",
1593            });
1594        }
1595        if let Ok((evals, evecs)) = try_eigh(&owned, side)
1596            && evals.iter().all(|value| value.is_finite())
1597            && evecs.iter().all(|value| value.is_finite())
1598        {
1599            return Ok((evals, evecs));
1600        }
1601
1602        let mut repaired = owned.clone();
1603        crate::matrix::symmetrize_in_place(&mut repaired);
1604
1605        let scale = repaired
1606            .iter()
1607            .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
1608            .max(1.0);
1609        let scaled = repaired.mapv(|value| value / scale);
1610        // Relative diagonal-jitter ladder for the eigendecomposition repair: the
1611        // matrix is pre-scaled to unit max-abs, so these are fractions of its
1612        // scale. We try the unperturbed matrix first, then escalate the ridge by
1613        // two decades per attempt until the factorization yields all-finite
1614        // eigenpairs, accepting the smallest jitter that succeeds.
1615        const JITTER_SCHEDULE: [f64; 6] = [0.0, 1e-12, 1e-10, 1e-8, 1e-6, 1e-4];
1616        let jitter_schedule = JITTER_SCHEDULE;
1617        let mut last_error = FaerLinalgError::FactorizationFailed {
1618            context: "self-adjoint eigendecomposition repair attempts",
1619        };
1620
1621        for &jitter in &jitter_schedule {
1622            let mut candidate = scaled.clone();
1623            if jitter > 0.0 {
1624                let n = candidate.nrows();
1625                for i in 0..n {
1626                    candidate[[i, i]] += jitter;
1627                }
1628            }
1629
1630            match try_eigh(&candidate, side) {
1631                Ok((mut evals, evecs))
1632                    if evals.iter().all(|value| value.is_finite())
1633                        && evecs.iter().all(|value| value.is_finite()) =>
1634                {
1635                    for value in &mut evals {
1636                        *value = (*value - jitter) * scale;
1637                    }
1638                    return Ok((evals, evecs));
1639                }
1640                Ok((_, _)) => {
1641                    last_error = FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1642                        context: "self-adjoint eigendecomposition repaired output validation",
1643                    };
1644                }
1645                Err(err) => {
1646                    last_error = err;
1647                }
1648            }
1649        }
1650
1651        Err(last_error)
1652    }
1653}
1654
1655pub struct FaerCholeskyFactor {
1656    factor: solvers::Llt<f64>,
1657}
1658
1659impl FaerCholeskyFactor {
1660    pub fn solvevec(&self, rhs: &Array1<f64>) -> Array1<f64> {
1661        let mut rhs = rhs.to_owned();
1662        let mut rhsview = array1_to_col_matmut(&mut rhs);
1663        self.factor.solve_in_place(rhsview.as_mut());
1664        rhs
1665    }
1666
1667    pub fn solve_mat_in_place(&self, rhs: &mut Array2<f64>) {
1668        let mut rhsview = array2_to_matmut(rhs);
1669        self.factor.solve_in_place(rhsview.as_mut());
1670    }
1671
1672    pub fn solve_mat_into<S: Data<Elem = f64>>(
1673        &self,
1674        rhs: &ArrayBase<S, Ix2>,
1675        out: &mut Array2<f64>,
1676    ) {
1677        if out.dim() != rhs.dim() {
1678            *out = Array2::<f64>::zeros(rhs.dim());
1679        }
1680        out.assign(rhs);
1681        self.solve_mat_in_place(out);
1682    }
1683
1684    pub fn solve_mat(&self, rhs: &Array2<f64>) -> Array2<f64> {
1685        let mut out = Array2::<f64>::zeros(rhs.dim());
1686        self.solve_mat_into(rhs, &mut out);
1687        out
1688    }
1689
1690    pub fn diag(&self) -> Array1<f64> {
1691        diag_to_array(self.factor.L().diagonal())
1692    }
1693
1694    pub fn lower_triangular(&self) -> Array2<f64> {
1695        mat_to_array(self.factor.L())
1696    }
1697}
1698
1699pub trait FaerCholesky {
1700    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError>;
1701}
1702
1703impl<S: Data<Elem = f64>> FaerCholesky for ArrayBase<S, Ix2> {
1704    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError> {
1705        let faerview = FaerArrayView::new(self);
1706        let factor = faerview
1707            .as_ref()
1708            .llt(side)
1709            .map_err(FaerLinalgError::Cholesky)?;
1710        Ok(FaerCholeskyFactor { factor })
1711    }
1712}
1713
1714pub trait FaerQr {
1715    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError>;
1716}
1717
1718impl<S: Data<Elem = f64>> FaerQr for ArrayBase<S, Ix2> {
1719    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError> {
1720        let faerview = FaerArrayView::new(self);
1721        let qr = faerview.as_ref().qr();
1722        let q = qr.compute_thin_Q();
1723        let r = qr.thin_R();
1724        Ok((mat_to_array(q.as_ref()), mat_to_array(r)))
1725    }
1726}
1727
1728/// Compute an orthonormal basis for `null(a^T)` using column-pivoted QR on `a`.
1729///
1730/// This is intended for tall/skinny matrices where `a ∈ R^{m×n}` with `m >= n`.
1731/// If `A P^T = Q R`, then the trailing `m-rank(A)` columns of `Q` span
1732/// `null(A^T)`.
1733///
1734/// The trailing columns of `Q` are reconstructed by applying the stored
1735/// Householder reflector sequence to canonical basis vectors. When `A` is
1736/// numerically rank zero (e.g. an entirely unpenalized block penalty in a
1737/// parametric-only GLM), *every* reflector is degenerate — the Householder
1738/// vector of a zero column has zero norm, so faer's coefficients become
1739/// non-finite and the reconstructed basis is filled with `NaN`. Mathematically
1740/// a rank-zero `m×n` matrix has `null(A^T) = R^m`, whose canonical orthonormal
1741/// basis is the identity, so we return `I_m` directly instead of routing through
1742/// the (undefined) reflectors. This keeps every downstream consumer — REML
1743/// null-space log-determinants, identifiability audits — finite and exact for
1744/// the fully-unpenalized case. For `rank >= 1` at least one well-defined
1745/// reflector seeds the block, and the reconstruction stays finite.
1746pub fn rrqr_nullspace_basis<S: Data<Elem = f64>>(
1747    a: &ArrayBase<S, Ix2>,
1748    rank_alpha: f64,
1749) -> Result<(Array2<f64>, usize), FaerLinalgError> {
1750    let faerview = FaerArrayView::new(a);
1751    let qr = faerview.as_ref().col_piv_qr();
1752    let r = qr.thin_R();
1753    let diag_len = r.nrows().min(r.ncols());
1754    let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
1755    let tol = rank_alpha
1756        * f64::EPSILON
1757        * (a.nrows().max(a.ncols()).max(1) as f64)
1758        * leading_diag.max(1.0);
1759    let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
1760    let z = if rank >= a.nrows() {
1761        Array2::<f64>::zeros((a.nrows(), 0))
1762    } else if rank == 0 {
1763        // Numerically rank-zero input: the whole space is the null space.
1764        // Return the canonical orthonormal basis directly; the Householder
1765        // reflectors of a zero matrix are degenerate and would yield NaN.
1766        Array2::<f64>::eye(a.nrows())
1767    } else {
1768        let nullity = a.nrows() - rank;
1769        let mut selector = Mat::<f64>::zeros(a.nrows(), nullity);
1770        for j in 0..nullity {
1771            selector[(rank + j, j)] = 1.0;
1772        }
1773        let par = get_global_parallelism();
1774        faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
1775            qr.Q_basis(),
1776            qr.Q_coeff(),
1777            Conj::No,
1778            selector.as_mut(),
1779            par,
1780            MemStack::new(&mut MemBuffer::new(
1781                faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<f64>(
1782                    a.nrows(),
1783                    qr.Q_coeff().nrows(),
1784                    nullity,
1785                ),
1786            )),
1787        );
1788        mat_to_array(selector.as_ref())
1789    };
1790    Ok((z, rank))
1791}
1792
1793#[inline]
1794pub const fn default_rrqr_rank_alpha() -> f64 {
1795    RRQR_RANK_ALPHA
1796}
1797
/// Result of a column-pivoted QR with rank detection and column permutation.
///
/// `A · P = Q · R` where the permutation `P` is exposed as the forward index
/// array: column `j` of `A · P` corresponds to original column
/// `column_permutation[j]` of `A`. The suffix `column_permutation[rank..]`
/// names the columns that the pivoted QR placed past the rank threshold —
/// i.e., the columns identified as redundant. Identifiability auditors
/// (`identifiability::audit`) use that suffix to attribute `DroppedColumn`
/// entries to specific original columns.
#[derive(Debug, Clone, PartialEq)]
pub struct RrqrWithPermutation {
    /// Numerical rank: the number of `|R[i, i]|` strictly above `rank_tol`.
    pub rank: usize,
    /// Forward pivot order; a permutation of `0..ncols(A)`.
    pub column_permutation: Vec<usize>,
    /// Leading pivot magnitude `|R[0, 0]|` (`0.0` when `R` has an empty
    /// diagonal).
    pub leading_diag_abs: f64,
    /// Threshold that `|R[i, i]|` must exceed to count toward `rank`.
    pub rank_tol: f64,
}
1814
1815/// Column-pivoted rank-revealing QR returning the rank, the column permutation,
1816/// and the rank-detection tolerance. Use this when callers need to name which
1817/// columns the pivoted QR demoted past the rank threshold.
1818///
1819/// The rank cutoff matches [`rrqr_nullspace_basis`]: a column-pivoted QR is
1820/// computed on `a`; columns with `|R[i, i]| > tol` count toward the rank,
1821/// where `tol = rank_alpha · eps · max(m, n, 1) · max(|R[0, 0]|, 1)`. Returns
1822/// `Err` when `a` has zero rows.
1823pub fn rrqr_with_permutation<S: Data<Elem = f64>>(
1824    a: &ArrayBase<S, Ix2>,
1825    rank_alpha: f64,
1826) -> Result<RrqrWithPermutation, FaerLinalgError> {
1827    if a.nrows() == 0 {
1828        return Err(FaerLinalgError::FactorizationFailed {
1829            context: "rrqr_with_permutation: input has zero rows",
1830        });
1831    }
1832    let faerview = FaerArrayView::new(a);
1833    let qr = faerview.as_ref().col_piv_qr();
1834    let r = qr.thin_R();
1835    let diag_len = r.nrows().min(r.ncols());
1836    let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
1837    let tol = rank_alpha
1838        * f64::EPSILON
1839        * (a.nrows().max(a.ncols()).max(1) as f64)
1840        * leading_diag.max(1.0);
1841    let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
1842    let (forward, _inverse) = qr.P().arrays();
1843    let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
1844    Ok(RrqrWithPermutation {
1845        rank,
1846        column_permutation,
1847        leading_diag_abs: leading_diag,
1848        rank_tol: tol,
1849    })
1850}
1851
/// Result of a Gram-driven column-pivoted RRQR (see
/// [`rrqr_from_gram_with_permutation`]). Carries the same rank / permutation /
/// tolerance as [`RrqrWithPermutation`], plus a `verdict_margin` that measures
/// how far the rank cut sits from any decision cliff. A large margin means
/// squaring the design into a Gram could not have flipped any rank decision; a
/// small margin means the verdict sits near the cliff and the caller should
/// re-confirm on the full (un-squared) design to stay bit-exact.
#[derive(Debug, Clone, PartialEq)]
pub struct RrqrFromGram {
    /// Numerical rank: the number of pivots strictly above `rank_tol`.
    pub rank: usize,
    /// Forward pivot order; a permutation of `0..p`.
    pub column_permutation: Vec<usize>,
    /// Rank cutoff, scaled by the tall design's row count so it matches the
    /// tall-path tolerance.
    pub rank_tol: f64,
    /// Leading pivot magnitude `|R[0,0]|` of the square-root factor — equal to
    /// the largest column norm of the original tall design (col-piv QR pivots the
    /// largest-norm column first), so it matches the tall path's
    /// `RrqrWithPermutation::leading_diag_abs`.
    pub leading_diag_abs: f64,
    /// Minimum of three ratios, each `≫ 1` when the verdict is safe:
    /// - `min_kept_pivot / rank_tol` (`∞` when no pivot is kept);
    /// - `rank_tol / max_dropped_pivot` (`∞` when no pivot is dropped; a dropped
    ///   pivot of exactly zero is floored at `f64::MIN_POSITIVE`, giving a very
    ///   large finite value);
    /// - `min_kept_pivot / (√ε · max(|R[0,0]|, 1))`, the Gram-squaring precision
    ///   floor (`∞` when no pivot is kept).
    ///
    /// An empty (`p == 0`) Gram reports `0`.
    pub verdict_margin: f64,
}
1873
1874/// Column-pivoted rank-revealing QR computed from the design's `p × p` Gram
1875/// `G = AᵀA` (or penalty-augmented `AᵀA + SᵀS`) instead of from the tall
1876/// `m × p` design itself.
1877///
1878/// # Why this is exact (in exact arithmetic)
1879///
1880/// Column-pivoted QR selects, at each step, the not-yet-pivoted column with the
1881/// largest residual norm, where the residual is the part orthogonal to the
1882/// already-chosen columns. Those residual norms — and the resulting pivot
1883/// sequence, the diagonal magnitudes `|R[i,i]|`, and hence the rank cut — are a
1884/// function of the column *inner products* only, i.e. of the Gram `G`. Running
1885/// col-piv QR on the Cholesky factor `R₀` of `G` (`R₀ᵀR₀ = G`, `R₀` is `p × p`)
1886/// reproduces the identical pivot order and identical `|R[i,i]|` as col-piv QR
1887/// on the original `m × p` matrix, because both see the same column geometry.
1888/// This is the standard "pivoted QR depends only on the Gram" identity and lets
1889/// the joint identifiability rank verdict run in `O(p³)` instead of streaming
1890/// all `m ≈ 2·10⁵` rows again.
1891///
1892/// # Tolerance
1893///
1894/// The rank cutoff must match what the tall-matrix [`rrqr_with_permutation`]
1895/// would have used, so the caller passes `m_rows` (the row count of the
1896/// original tall design, including any appended penalty rows). The tolerance is
1897/// `rank_alpha · eps · max(m_rows, p) · max(|R[0,0]|, 1)` — bit-identical to the
1898/// tall path, since `|R[0,0]|` (the leading pivot magnitude = largest column
1899/// norm) is the same in both factorizations.
1900///
1901/// # Finite-precision guard
1902///
1903/// Forming `G = AᵀA` squares the condition number, so a rank decision that sits
1904/// right at the tolerance cliff could in principle flip. The returned
1905/// `verdict_margin` lets the caller detect that case and fall back to the exact
1906/// tall RRQR; in the overwhelmingly common well-separated case (full column
1907/// rank, smallest pivot orders of magnitude above tol) the margin is huge and
1908/// no fallback is needed.
1909pub fn rrqr_from_gram_with_permutation<S: Data<Elem = f64>>(
1910    gram: &ArrayBase<S, Ix2>,
1911    m_rows: usize,
1912    rank_alpha: f64,
1913) -> Result<RrqrFromGram, FaerLinalgError> {
1914    let p = gram.ncols();
1915    if p == 0 {
1916        return Ok(RrqrFromGram {
1917            rank: 0,
1918            column_permutation: Vec::new(),
1919            rank_tol: 0.0,
1920            leading_diag_abs: 0.0,
1921            verdict_margin: 0.0,
1922        });
1923    }
1924    if gram.nrows() != p {
1925        return Err(FaerLinalgError::FactorizationFailed {
1926            context: "rrqr_from_gram_with_permutation: Gram is not square",
1927        });
1928    }
1929    // Symmetric square-root factor F (p×p) with FᵀF = G. The Gram is PSD by
1930    // construction (AᵀA), so its eigendecomposition G = V·diag(λ)·Vᵀ gives the
1931    // factor F = diag(√λ₊)·Vᵀ (rows indexed by eigenpair, columns by original
1932    // design column). Any factor with FᵀF = G reproduces the same column
1933    // geometry, which is all col-piv QR consumes — we use the eigen square root
1934    // rather than a bare Cholesky because Cholesky fails on the numerically
1935    // semidefinite Gram that is exactly the rank-deficient case we must classify.
1936    // Tiny-negative eigenvalues from finite precision are clamped to zero.
1937    let (evals, evecs) = gram.eigh(Side::Lower)?;
1938    let mut f = Array2::<f64>::zeros((p, p));
1939    for k in 0..p {
1940        let scale = evals[k].max(0.0).sqrt();
1941        if scale == 0.0 {
1942            continue;
1943        }
1944        for i in 0..p {
1945            f[[k, i]] = scale * evecs[[i, k]];
1946        }
1947    }
1948    // Single col-piv QR on F. Its pivot order, per-pivot |R[i,i]| magnitudes,
1949    // and leading pivot equal those of col-piv QR on the original tall design
1950    // (FᵀF = G), so this reproduces the exact tall-path geometry.
1951    let faer_f = FaerArrayView::new(&f);
1952    let qr = faer_f.as_ref().col_piv_qr();
1953    let r = qr.thin_R();
1954    let diag_len = r.nrows().min(r.ncols());
1955    let pivots: Vec<f64> = (0..diag_len).map(|i| r[(i, i)].abs()).collect();
1956    let leading_diag = pivots.first().copied().unwrap_or(0.0);
1957    let (forward, _inverse) = qr.P().arrays();
1958    let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
1959    // Re-scale the tolerance from F's `max(p, p)=p` row dimension to the
1960    // original tall design's `max(m_rows, p)`, keeping the rank cut bit-
1961    // identical to what the tall [`rrqr_with_permutation`] would have produced.
1962    let tol = rank_alpha * f64::EPSILON * (m_rows.max(p).max(1) as f64) * leading_diag.max(1.0);
1963    let rank = pivots.iter().filter(|&&v| v > tol).count();
1964    let min_kept = pivots[..rank].iter().copied().fold(f64::INFINITY, f64::min);
1965    let max_dropped = pivots[rank..].iter().copied().fold(0.0f64, f64::max);
1966    // Margin: how far the verdict is from the cliff. Use the smaller of
1967    // (min_kept / tol) and (tol / max_dropped) so a near-tol dropped pivot also
1968    // shrinks the margin. A margin ≫ 1 means no rank decision could flip.
1969    let kept_margin = if rank == 0 {
1970        f64::INFINITY
1971    } else {
1972        min_kept / tol
1973    };
1974    let dropped_margin = if rank == diag_len {
1975        f64::INFINITY
1976    } else {
1977        tol / max_dropped.max(f64::MIN_POSITIVE)
1978    };
1979    // Gram-squaring precision floor. Forming `G = XᵀX` collapses the bottom half
1980    // of the spectrum: a true singular value below `√ε · σ_max` is lost in the
1981    // rounding of `G` (its squared value `σ² < ε·σ_max²` underflows the Gram's
1982    // representable range), and the eigen-square-root then RESURRECTS it as a
1983    // SPURIOUS pivot of magnitude `≈ √(ε·σ_max²) = √ε · σ_max` — orders of
1984    // magnitude ABOVE the true σ and above `tol`. That artefact makes col-piv QR
1985    // on `F` KEEP a column the tall (un-squared) QR would demote: an EXACTLY
1986    // collinear alias (true σ = 0, so `σ² = 0` floored at `≈ ε·σ_max²`) shows up
1987    // as a kept pivot near `√ε · leading`, over-ranking the design and dropping
1988    // nothing (gam#933: a callback-owned column aliased with a higher-priority
1989    // anchor was never demoted, so the reduction never ran and the MAP-uniqueness
1990    // check then fired on the raw collinear joint design). `min_kept / tol` does
1991    // NOT catch this — the spurious pivot sits comfortably above `tol`, so the
1992    // existing margin reports a falsely-confident verdict. The honest test is
1993    // whether the smallest KEPT pivot is itself near the Gram precision floor
1994    // `√ε · leading`: if so, the Gram path cannot distinguish it from a true zero
1995    // and the verdict MUST be re-confirmed on the full-precision tall design.
1996    // Encode that as a third margin term `min_kept / (√ε · leading)` so a kept
1997    // pivot in the floor regime shrinks `verdict_margin` below the caller's
1998    // fallback threshold; for a genuinely full-rank design every kept pivot is
1999    // `≫ √ε · leading` and this term is large, leaving the fast path intact.
2000    let gram_precision_floor = f64::EPSILON.sqrt() * leading_diag.max(1.0);
2001    let kept_floor_margin = if rank == 0 {
2002        f64::INFINITY
2003    } else {
2004        min_kept / gram_precision_floor.max(f64::MIN_POSITIVE)
2005    };
2006    let verdict_margin = kept_margin.min(dropped_margin).min(kept_floor_margin);
2007    Ok(RrqrFromGram {
2008        rank,
2009        column_permutation,
2010        rank_tol: tol,
2011        leading_diag_abs: leading_diag,
2012        verdict_margin,
2013    })
2014}
2015
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Local mirror of the audit's `JOINT_GRAM_RRQR_MIN_VERDICT_MARGIN` fallback
    /// threshold, used only by the regression tests below to assert the verdict
    /// margin lands on the correct side of the cliff. Kept in sync by value (1e3).
    const JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST: f64 = 1.0e3;

    /// Largest absolute value in `values`, propagating NaN.
    ///
    /// A plain `fold(0.0, |acc, &v| acc.max(v.abs()))` silently discards NaN
    /// (`f64::max` returns the non-NaN operand), so a NaN-filled result would
    /// pass an `err < tol` assertion. Returning NaN instead makes every such
    /// assertion fail loudly.
    fn nan_aware_max_abs<'a>(values: impl IntoIterator<Item = &'a f64>) -> f64 {
        values.into_iter().fold(0.0_f64, |acc, &v| {
            if acc.is_nan() || v.is_nan() {
                f64::NAN
            } else {
                acc.max(v.abs())
            }
        })
    }

    #[test]
    fn rrqr_nullspace_basis_is_orthonormal_and_annihilates_transpose() {
        let a = array![[1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 0.0],];
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 2);
        assert_eq!(z.nrows(), 4);
        assert_eq!(z.ncols(), 2);

        let gram = z.t().dot(&z);
        let ident = Array2::<f64>::eye(z.ncols());
        let gram_err = nan_aware_max_abs((&gram - &ident).iter());
        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");

        let residual = a.t().dot(&z);
        let resid_max = nan_aware_max_abs(residual.iter());
        assert!(resid_max < 1e-10, "A^T Z residual too large: {resid_max:e}");
    }

    #[test]
    fn rrqr_with_permutation_attributes_redundant_column() {
        // 3 columns, column 2 is a duplicate of column 0 → rank 2, column 2
        // is the redundant one that the pivoted QR should demote past the
        // rank threshold. (Column 1 contributes a different direction.)
        let a = array![
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0],
            [0.0, 2.0, 0.0],
            [0.0, 0.0, 0.0],
        ];
        let result =
            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(result.rank, 2);
        assert_eq!(result.column_permutation.len(), 3);
        let demoted = result.column_permutation[result.rank..].to_vec();
        assert!(
            demoted.contains(&2) || demoted.contains(&0),
            "demoted suffix should include one of the aliased columns (0 or 2), got {demoted:?}"
        );
        let mut sorted = result.column_permutation.clone();
        sorted.sort();
        assert_eq!(
            sorted,
            vec![0, 1, 2],
            "permutation must be a valid bijection on 0..n"
        );
    }

    /// Full column rank: nothing is demoted and the permutation is a bijection.
    /// (The order itself need not be the identity — col-piv QR pivots the
    /// larger-norm column 1 first here.)
    #[test]
    fn rrqr_with_permutation_full_rank_is_valid_bijection() {
        let a = array![[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
        let result =
            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(result.rank, 2);
        assert_eq!(
            result.column_permutation.len(),
            result.rank,
            "a full-rank design must demote no columns"
        );
        let mut sorted = result.column_permutation.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1]);
    }

    #[test]
    fn rrqr_with_permutation_rejects_zero_rows() {
        let a = Array2::<f64>::zeros((0, 3));
        assert!(rrqr_with_permutation(&a, default_rrqr_rank_alpha()).is_err());
    }

    #[test]
    fn rrqr_nullspace_basis_square_zero_matrix_is_finite_identity() {
        // Square zero matrix (the parametric-only penalty case): null(A^T) is
        // the whole space, so the basis must be a finite orthonormal 3x3 set.
        let a = Array2::<f64>::zeros((3, 3));
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 0);
        assert_eq!(z.dim(), (3, 3));
        assert!(
            z.iter().all(|v| v.is_finite()),
            "square zero matrix produced a non-finite null basis: {z:?}"
        );
        let gram = z.t().dot(&z);
        let ident = Array2::<f64>::eye(3);
        let gram_err = nan_aware_max_abs((&gram - &ident).iter());
        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
    }

    #[test]
    fn rrqr_nullspace_basis_detects_zero_rank_matrix() {
        let a = Array2::<f64>::zeros((5, 2));
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 0);
        assert_eq!(z.dim(), (5, 5));
        let ident = Array2::<f64>::eye(5);
        let max_err = nan_aware_max_abs((&z - &ident).iter());
        assert!(max_err < 1e-10, "zero matrix should yield identity basis");
    }

    //
    // Eigendecomposition NoConvergence on pathological matrices
    //
    // These tests lock down the hardened contract for FaerEigh::eigh:
    // non-finite input must be rejected explicitly, while finite symmetric
    // matrices still produce finite spectra.
    //

    #[test]
    fn eigh_on_nan_matrix_rejects_non_finite_input() {
        let mat = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, f64::NAN],
            [0.0, 0.0, f64::NAN, 4.0]
        ];
        let err = mat
            .eigh(Side::Lower)
            .expect_err("non-finite symmetric input must be rejected");
        assert!(matches!(
            err,
            FaerLinalgError::SelfAdjointEigenNonFiniteInput { .. }
        ));
    }

    #[test]
    fn fast_ata_matches_full_gemm_above_threshold() {
        // Pick (n, p) large enough to trigger the faer triangular path
        // (should_use_faer_matmul threshold is MIN_DIM=32, MIN_FLOP_SCALE=64*64).
        let n = 200;
        let p = 40;
        let a: Array2<f64> = Array2::from_shape_fn((n, p), |(i, j)| {
            ((i * 7 + j * 3) as f64).sin() + 0.1 * j as f64
        });
        let expected = a.t().dot(&a);
        let got = fast_ata(&a);
        let max_err = nan_aware_max_abs((&got - &expected).iter());
        assert!(max_err < 1e-10, "fast_ata mismatch: {max_err:e}");
        // Output must be fully populated and symmetric.
        for i in 0..p {
            for j in 0..p {
                assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn fast_xt_diag_x_matches_naive_above_threshold() {
        let n = 400;
        let p = 36;
        let x: Array2<f64> =
            Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.1).cos() + j as f64 * 0.05);
        let w: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64 * 0.03).sin());
        // Naive reference: X^T diag(w) X.
        let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        let expected = x.t().dot(&wx);
        let got = fast_xt_diag_x(&x, &w);
        let max_err = nan_aware_max_abs((&got - &expected).iter());
        assert!(max_err < 1e-9, "fast_xt_diag_x mismatch: {max_err:e}");
        for i in 0..p {
            for j in 0..p {
                assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn stream_weighted_crossprod_full_and_triangular_parity_with_negative_weights() {
        // The stream-in and matrix-returning `fast_xt_diag_x*` packaging modes
        // share one kernel. Both packaging modes — and both accumulation
        // modes — must reproduce the naive `Xᵀ·diag(w)·X` reference, including signed
        // (negative) weights, which the pre-unification sqrt-clip form
        // silently corrupted.
        //
        // Exercise both the streaming faer path (n large enough to clear
        // `should_use_faer_matmul`) and the tiny ndarray fallback (small n,p).
        for &(n, p) in &[(900usize, 40usize), (8usize, 3usize)] {
            let x: Array2<f64> =
                Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.07).cos() + j as f64 * 0.013);
            // Weights span both signs and zero so negative-weight handling and
            // sign preservation are genuinely tested.
            let w: Array1<f64> =
                Array1::from_shape_fn(n, |i| (i as f64 * 0.11).sin() - 0.25 * (i % 3) as f64);
            assert!(
                w.iter().any(|&v| v < 0.0),
                "weight vector must contain negatives to test sign preservation"
            );

            // Naive reference: Xᵀ diag(w) X with signed weights.
            let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
            let expected = x.t().dot(&wx);

            let par = matmul_parallelism(p, p, n);

            // Full output, Replace.
            let mut full = Array2::<f64>::ones((p, p));
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut full,
                CrossprodStructure::Full,
                CrossprodAccum::Replace,
                par,
            );

            // Triangular+mirror output, Replace. Seed with garbage to prove
            // Replace clears prior contents (incl. the upper triangle, which
            // the triangular path only reaches via the mirror).
            let mut tri = Array2::<f64>::from_elem((p, p), -7.0);
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut tri,
                CrossprodStructure::SymmetricLower,
                CrossprodAccum::Replace,
                par,
            );

            let full_err = nan_aware_max_abs((&full - &expected).iter());
            let tri_err = nan_aware_max_abs((&tri - &expected).iter());
            assert!(
                full_err < 1e-9,
                "full kernel mismatch (n={n}, p={p}): {full_err:e}"
            );
            assert!(
                tri_err < 1e-9,
                "triangular kernel mismatch (n={n}, p={p}): {tri_err:e}"
            );

            // Full and triangular packaging must agree elementwise, and both
            // must be exactly symmetric.
            for i in 0..p {
                for j in 0..p {
                    assert!(
                        (full[[i, j]] - tri[[i, j]]).abs() < 1e-12,
                        "full vs triangular disagree at ({i},{j})"
                    );
                    assert!(
                        (tri[[i, j]] - tri[[j, i]]).abs() < 1e-12,
                        "triangular output not symmetric at ({i},{j})"
                    );
                }
            }

            // Accumulation parity: Add into a pre-filled buffer must equal the
            // prior contents plus the Gram, for both structures.
            let base = Array2::<f64>::from_elem((p, p), 1.5);
            let mut add_full = base.clone();
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut add_full,
                CrossprodStructure::Full,
                CrossprodAccum::Add,
                par,
            );
            let mut add_tri = base.clone();
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut add_tri,
                CrossprodStructure::SymmetricLower,
                CrossprodAccum::Add,
                par,
            );
            let expected_add = &base + &expected;
            let add_full_err = nan_aware_max_abs((&add_full - &expected_add).iter());
            let add_tri_err = nan_aware_max_abs((&add_tri - &expected_add).iter());
            assert!(
                add_full_err < 1e-9,
                "full Add mismatch (n={n}, p={p}): {add_full_err:e}"
            );
            assert!(
                add_tri_err < 1e-9,
                "triangular Add mismatch (n={n}, p={p}): {add_tri_err:e}"
            );

            // The stream-into path (Full + Replace, as used by the matrix.rs
            // adapter) must agree with the return-style `fast_xt_diag_x`
            // adapter to within 1e-12.
            let returned = fast_xt_diag_x(&x, &w);
            let returned_err = nan_aware_max_abs((&returned - &full).iter());
            assert!(
                returned_err < 1e-12,
                "return adapter vs stream-into adapter disagree (n={n}, p={p}): {returned_err:e}"
            );
        }
    }

    #[test]
    fn eigh_succeeds_on_same_structure_without_nan() {
        // Control: the same matrix with finite values produces finite eigenvalues.
        let mat = array![[1.0, 0.5, 0.1], [0.5, 2.0, 0.3], [0.1, 0.3, 1.5]];
        let (evals, _) = mat
            .eigh(Side::Lower)
            .expect("eigh should succeed on a well-conditioned finite matrix");
        assert!(
            evals.iter().all(|&v| v.is_finite()),
            "all eigenvalues should be finite"
        );
    }

    /// gam#933 regression: the Gram-squared RRQR must NOT silently over-rank an
    /// EXACTLY collinear design. Forming `G = XᵀX` squares the spectrum, so the
    /// zero singular value of an exact alias is rounded to `≈ ε·σ_max²` in `G` and
    /// the eigen-square-root resurrects it as a SPURIOUS pivot `≈ √ε·σ_max` that
    /// sits above `tol` — col-piv QR on the Gram factor would KEEP it and report
    /// full rank. The precision-floor margin term must catch this: the smallest
    /// kept pivot is near `√ε·leading`, so `verdict_margin` collapses below the
    /// caller's fallback threshold, forcing the full-precision tall path (which
    /// sees the true zero singular value and demotes the column).
    #[test]
    fn gram_rrqr_flags_low_margin_on_exact_collinearity_so_caller_falls_back() {
        // Joint design [1, x | x, x²] with x ∈ [-1, 1]: columns 1 and 2 are an
        // EXACT duplicate (the #933 callback-owned alias), so the true rank is 3.
        let n = 48usize;
        let x: Vec<f64> = (0..n)
            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
            .collect();
        let mut a = Array2::<f64>::zeros((n, 4));
        for i in 0..n {
            a[[i, 0]] = 1.0;
            a[[i, 1]] = x[i];
            a[[i, 2]] = x[i];
            a[[i, 3]] = x[i] * x[i];
        }
        let alpha = default_rrqr_rank_alpha();

        // The tall (un-squared) RRQR is the full-precision reference: it must see
        // rank 3 and demote one of the duplicate x columns.
        let tall = rrqr_with_permutation(&a, alpha).expect("tall RRQR should succeed");
        assert_eq!(tall.rank, 3, "tall RRQR must demote the exact alias");

        // The Gram-squared RRQR must report a SMALL verdict_margin here so the
        // caller re-confirms on the tall design instead of trusting a possibly
        // over-ranked Gram verdict. (We do not assert the Gram rank itself —
        // squaring may report 3 or 4 — only that the margin signals the cliff.)
        let unit = Array1::<f64>::ones(n);
        let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
        let gram_rrqr =
            rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
        assert!(
            gram_rrqr.verdict_margin < JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
            "exact-collinearity Gram verdict must report low margin to force tall \
             fallback; got margin={:.3e} (rank={})",
            gram_rrqr.verdict_margin,
            gram_rrqr.rank,
        );
    }

    /// Companion to the regression above: a genuinely full-rank, moderately
    /// conditioned design must keep a LARGE Gram verdict margin so the fast Gram
    /// path is retained (the precision-floor term must not trip on real, small-
    /// but-nonzero singular values).
    #[test]
    fn gram_rrqr_keeps_high_margin_on_full_rank_design() {
        let n = 200usize;
        let p = 5usize;
        let mut a = Array2::<f64>::zeros((n, p));
        // Deterministic, well-separated columns (distinct low-order polynomials).
        for i in 0..n {
            let t = (i as f64) / (n as f64 - 1.0);
            a[[i, 0]] = 1.0;
            a[[i, 1]] = t;
            a[[i, 2]] = t * t;
            a[[i, 3]] = t * t * t;
            a[[i, 4]] = (t * 6.0).sin();
        }
        let alpha = default_rrqr_rank_alpha();
        let unit = Array1::<f64>::ones(n);
        let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
        let gram_rrqr =
            rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
        assert_eq!(gram_rrqr.rank, p, "full-rank design must keep all columns");
        assert!(
            gram_rrqr.verdict_margin >= JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
            "full-rank design must keep a high margin (fast Gram path); got {:.3e}",
            gram_rrqr.verdict_margin,
        );
    }
}