Skip to main content

gam_linalg/
faer_ndarray.rs

1use dyn_stack::{MemBuffer, MemStack};
2use faer::diag::{Diag, DiagRef};
3use faer::linalg::solvers::{self, Solve};
4pub use faer::linalg::solvers::{
5    Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
6};
7use faer::linalg::svd::{self, ComputeSvdVectors};
8use faer::prelude::ReborrowMut;
9use faer::{Conj, Mat, MatMut, MatRef, Par, Side, Unbind, get_global_parallelism};
10use ndarray::{Array1, Array2, ArrayBase, ArrayViewMut1, Data, Ix1, Ix2};
11use std::marker::PhantomData;
12use std::panic::{AssertUnwindSafe, catch_unwind};
13use thiserror::Error;
14
15const RRQR_RANK_ALPHA: f64 = 100.0;
16
thread_local! {
    /// Per-thread count of live `NestedParallelGuard`s. A non-zero value means
    /// this thread is currently executing inside at least one data-parallel
    /// row region, in which case `effective_global_parallelism` returns
    /// `Par::Seq` and the row-major GEMV kernels stay serial.
    static NESTED_PARALLEL_DEPTH: std::cell::Cell<usize> = const { std::cell::Cell::new(0) };
}
20
21struct NestedParallelGuard;
22
23impl NestedParallelGuard {
24    #[inline]
25    fn enter() -> Self {
26        NESTED_PARALLEL_DEPTH.with(|depth| depth.set(depth.get().saturating_add(1)));
27        Self
28    }
29}
30
31impl Drop for NestedParallelGuard {
32    #[inline]
33    fn drop(&mut self) {
34        NESTED_PARALLEL_DEPTH.with(|depth| depth.set(depth.get().saturating_sub(1)));
35    }
36}
37
38/// Run `body` with the current thread marked as inside a data-parallel row
39/// region, so any faer GEMM it issues (directly or transitively) pins to
40/// `Par::Seq` via [`effective_global_parallelism`] instead of re-fanning the
41/// global Rayon pool. The guard is held for exactly the duration of `body` and
42/// dropped on return — including early `?` returns from inside `body`, since the
43/// guard lives in this function's frame.
44///
45/// Call this from the per-chunk/per-row closure of an `into_par_iter` whose body
46/// performs GEMM, to prevent the Rayon-pool × faer-pool oversubscription.
47#[inline]
48pub fn with_nested_parallel<T>(body: impl FnOnce() -> T) -> T {
49    let guard = NestedParallelGuard::enter();
50    let out = body();
51    drop(guard);
52    out
53}
54
55/// `true` when the current thread is inside at least one [`NestedParallelGuard`]
56/// scope, i.e. a parallel row reduction is already in flight on this thread.
57#[inline]
58pub fn in_nested_parallel_region() -> bool {
59    NESTED_PARALLEL_DEPTH.with(|depth| depth.get() > 0)
60}
61
62/// faer parallelism policy that respects nested data-parallel regions: returns
63/// faer's global policy at the top level, but `Par::Seq` once a
64/// [`NestedParallelGuard`] is active so a GEMM issued from inside a parallel row
65/// fan-out does not multiply the live thread count against the outer pool.
66///
67/// Use this in place of `faer::get_global_parallelism()` for any matmul that can
68/// be reached from inside a row-parallel closure.
69#[inline]
70pub fn effective_global_parallelism() -> Par {
71    if in_nested_parallel_region() {
72        Par::Seq
73    } else {
74        get_global_parallelism()
75    }
76}
77
/// Errors reported by the faer-backed dense linear-algebra helpers in this module.
#[derive(Debug, Error)]
pub enum FaerLinalgError {
    /// A factorization did not succeed. `context` is a static label for the
    /// failing operation (presumably supplied by the caller; confirm at call sites).
    #[error("Factorization failed in {context}")]
    FactorizationFailed { context: &'static str },
    /// A singular value decomposition did not converge; `context` labels the operation.
    #[error("SVD failed to converge in {context}")]
    SvdNoConvergence { context: &'static str },
    /// The input to a self-adjoint eigendecomposition contained non-finite
    /// values (NaN or ±inf); `context` labels the operation.
    #[error("Self-adjoint eigendecomposition input contains non-finite values in {context}")]
    SelfAdjointEigenNonFiniteInput { context: &'static str },
    /// faer's self-adjoint eigendecomposition returned an error.
    #[error("Self-adjoint eigendecomposition failed: {0:?}")]
    SelfAdjointEigen(solvers::EvdError),
    /// faer's Cholesky (LLT) factorization returned an error.
    #[error("Cholesky factorization failed: {0:?}")]
    Cholesky(solvers::LltError),
    /// faer's LDLT factorization returned an error. `factorize_symmetricwith_fallback`
    /// also reports this variant (carrying the LDLT error) when its final LBLT
    /// fallback panics.
    #[error("LDLT factorization failed: {0:?}")]
    Ldlt(solvers::LdltError),
}
93
/// A dense symmetric factorization, tagged by which faer algorithm produced it.
/// `factorize_symmetricwith_fallback` tries the variants in declaration order.
pub enum FaerSymmetricFactor {
    /// Cholesky (LLT) factorization: the first choice.
    Llt(FaerLlt<f64>),
    /// LDLT factorization, used when Cholesky construction failed.
    Ldlt(FaerLdlt<f64>),
    /// Pivoted LBLT factorization: the last resort after LLT and LDLT failed.
    Lblt(FaerLblt<f64>),
}
99
100#[inline]
101pub fn cholesky_factor_logdet(factor: MatRef<'_, f64>) -> f64 {
102    2.0 * diagonal_log_sum(factor.diagonal())
103}
104
105#[inline]
106fn diagonal_log_sum(diagonal: DiagRef<'_, f64>) -> f64 {
107    diagonal
108        .column_vector()
109        .iter()
110        .map(|&x| x.ln())
111        .sum::<f64>()
112}
113
114impl FaerSymmetricFactor {
115    /// Returns the dimension of the factorized square matrix.
116    #[inline]
117    pub fn n(&self) -> usize {
118        use faer::linalg::solvers::ShapeCore;
119        match self {
120            FaerSymmetricFactor::Llt(f) => f.nrows(),
121            FaerSymmetricFactor::Ldlt(f) => f.nrows(),
122            FaerSymmetricFactor::Lblt(f) => f.nrows(),
123        }
124    }
125
126    #[inline]
127    pub fn solve(&self, rhs: MatRef<'_, f64>) -> Mat<f64> {
128        match self {
129            FaerSymmetricFactor::Llt(f) => f.solve(rhs),
130            FaerSymmetricFactor::Ldlt(f) => f.solve(rhs),
131            FaerSymmetricFactor::Lblt(f) => f.solve(rhs),
132        }
133    }
134
135    #[inline]
136    pub fn solve_in_place(&self, rhs: MatMut<'_, f64>) {
137        match self {
138            FaerSymmetricFactor::Llt(f) => f.solve_in_place(rhs),
139            FaerSymmetricFactor::Ldlt(f) => f.solve_in_place(rhs),
140            FaerSymmetricFactor::Lblt(f) => f.solve_in_place(rhs),
141        }
142    }
143}
144
145impl crate::matrix::FactorizedSystem for FaerSymmetricFactor {
146    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
147        let mut out = rhs.clone();
148        let mut out_mat = array1_to_col_matmut(&mut out);
149        self.solve_in_place(out_mat.as_mut());
150        if !out.iter().all(|v| v.is_finite()) {
151            return Err("symmetric factor solve produced non-finite values".to_string());
152        }
153        Ok(out)
154    }
155
156    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
157        let mut out = Array2::<f64>::zeros(rhs.raw_dim());
158        for j in 0..rhs.ncols() {
159            for i in 0..rhs.nrows() {
160                out[[i, j]] = rhs[[i, j]];
161            }
162        }
163        let mut out_mat = array2_to_matmut(&mut out);
164        self.solve_in_place(out_mat.as_mut());
165        if !out.iter().all(|v| v.is_finite()) {
166            return Err("symmetric factor multi-solve produced non-finite values".to_string());
167        }
168        Ok(out)
169    }
170
171    fn logdet(&self) -> f64 {
172        match self {
173            FaerSymmetricFactor::Llt(f) => cholesky_factor_logdet(f.L()),
174            FaerSymmetricFactor::Ldlt(f) => diagonal_log_sum(f.D()),
175            FaerSymmetricFactor::Lblt(..) => {
176                // lblt doesn't easily expose diagonal determinant. Fallback to sparse or other representations if needed, but typically Lblt is indefinite!
177                // Actually faer doesn't easily expose lblt logdet since it has 2x2 blocks.
178                // For our ML systems, if we dropped to LBLT, the matrix was indefinite and logdet is ill-defined (or complex).
179                f64::NAN
180            }
181        }
182    }
183}
184
185/// Factorize a symmetric system with LLT -> LDLT -> LBLT fallback.
186#[inline]
187pub fn factorize_symmetricwith_fallback(
188    matrix: MatRef<'_, f64>,
189    side: Side,
190) -> Result<FaerSymmetricFactor, FaerLinalgError> {
191    if let Ok(llt) = FaerLlt::new(matrix, side) {
192        return Ok(FaerSymmetricFactor::Llt(llt));
193    }
194    let ldlt_err = match FaerLdlt::new(matrix, side) {
195        Ok(ldlt) => return Ok(FaerSymmetricFactor::Ldlt(ldlt)),
196        Err(err) => err,
197    };
198    let lblt = catch_unwind(AssertUnwindSafe(|| FaerLblt::new(matrix, side)))
199        .map_err(|_| FaerLinalgError::Ldlt(ldlt_err))?;
200    Ok(FaerSymmetricFactor::Lblt(lblt))
201}
202
/// Central dispatch policy for dense products of shape `(m × k) · (k × n)`.
///
/// Returns `true` (use faer GEMM/GEMV) only when at least one dimension is
/// 32 or more **and** the work `m·n·k` (saturating) is at least `64²`.
/// Otherwise ndarray's `dot` is preferred, to avoid faer's setup overhead on
/// tiny products.
#[inline]
const fn should_use_faer_matmul(m: usize, n: usize, k: usize) -> bool {
    const MIN_DIM: usize = 32;
    const MIN_FLOP_SCALE: usize = 64 * 64;
    let has_long_dim = m >= MIN_DIM || n >= MIN_DIM || k >= MIN_DIM;
    if !has_long_dim {
        return false;
    }
    let work = m.saturating_mul(n).saturating_mul(k);
    work >= MIN_FLOP_SCALE
}
213
214#[inline]
215pub fn matmul_parallelism(m: usize, n: usize, k: usize) -> Par {
216    // Prefer a work-based policy over per-dimension thresholds.
217    // Tall/skinny products (e.g. N x p with large N, modest p) should still
218    // parallelize when total work is high.
219    const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
220    const PAR_MIN_LONG_DIM: usize = 256;
221    let flop_scale = m.saturating_mul(n).saturating_mul(k);
222    let long_dim = m.max(n).max(k);
223    if flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM {
224        // `effective_global_parallelism` collapses to `Par::Seq` when this GEMM
225        // is reached from inside a `NestedParallelGuard` row region, preventing
226        // the Rayon-pool × faer-pool multiplicative oversubscription.
227        effective_global_parallelism()
228    } else {
229        Par::Seq
230    }
231}
232
233#[inline]
234pub fn array2_to_matmut(array: &mut Array2<f64>) -> MatMut<'_, f64> {
235    let (rows, cols) = array.dim();
236    let strides = array.strides();
237
238    // Check if we can get a pointer.
239    // If the array is contiguous (either C or F order), or simply sliced with strides,
240    // faer can handle it as long as we pass the pointer and strides.
241    // However, as_mut_ptr() requires a mutable reference.
242    // ndarray's as_ptr/as_mut_ptr works for both layouts.
243
244    let s0 = strides[0];
245    let s1 = strides[1];
246
247    // SAFETY: array.as_mut_ptr() is ndarray's logical (0, 0) pointer, and
248    // ndarray's dimensions plus signed element strides describe every initialized
249    // element of this uniquely borrowed Array2 for the returned MatMut lifetime.
250    unsafe { MatMut::from_raw_parts_mut(array.as_mut_ptr(), rows, cols, s0, s1) }
251}
252
253#[inline]
254pub fn array1_to_col_matmut(array: &mut Array1<f64>) -> MatMut<'_, f64> {
255    let len = array.len();
256    let stride = array.strides()[0];
257    // SAFETY: array.as_mut_ptr() is ndarray's logical first-element pointer, and
258    // len plus the signed element stride describe every initialized element of
259    // this uniquely borrowed Array1 for the returned len×1 MatMut lifetime.
260    unsafe {
261        MatMut::from_raw_parts_mut(
262            array.as_mut_ptr(),
263            len,
264            1,
265            stride,
266            0, // col stride irrelevant for 1 column
267        )
268    }
269}
270
271/// Compute A^T * A using faer's SIMD-optimized GEMM.
272/// This is MUCH faster than ndarray's .t().dot() for matrices where n > ~100.
273///
274/// For a matrix A of shape (n, p), this computes the (p, p) result.
275/// Uses a zero-copy view for positive-stride layouts and copies only layouts
276/// with non-positive strides.
277#[inline]
278pub fn fast_ata<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>) -> Array2<f64> {
279    let p = a.ncols();
280    let mut out = Array2::<f64>::zeros((p, p));
281    fast_ata_into(a, &mut out);
282    out
283}
284
285/// Compute A^T * A into a pre-allocated output buffer.
286/// `out` must be shaped (p, p) where A is (n, p).
287#[inline]
288pub fn fast_ata_into<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>, out: &mut Array2<f64>) {
289    use faer::Accum;
290    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
291
292    let (n, p) = a.dim();
293    assert_eq!(out.nrows(), p, "output rows must match p");
294    assert_eq!(out.ncols(), p, "output cols must match p");
295
296    if !should_use_faer_matmul(p, p, n) {
297        out.assign(&a.t().dot(a));
298        return;
299    }
300
301    let mut outview = array2_to_matmut(out);
302
303    let aview = FaerArrayView::new(a);
304    let a_ref = aview.as_ref();
305    let a_t = a_ref.transpose();
306    let par = matmul_parallelism(p, p, n);
307    tri_matmul(
308        outview.as_mut(),
309        BlockStructure::TriangularLower,
310        Accum::Replace,
311        a_t,
312        BlockStructure::Rectangular,
313        a_ref,
314        BlockStructure::Rectangular,
315        1.0,
316        par,
317    );
318    // Mirror lower triangle to upper to populate the full symmetric output.
319    for i in 0..p {
320        for j in (i + 1)..p {
321            out[[i, j]] = out[[j, i]];
322        }
323    }
324}
325
326/// Compute A^T * B using faer's SIMD-optimized GEMM.
327/// For A of shape (n, p) and B of shape (n, q), this computes the (p, q) result.
328/// Uses zero-copy views when possible.
329#[inline]
330pub fn fast_atb<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
331    a: &ArrayBase<S1, Ix2>,
332    b: &ArrayBase<S2, Ix2>,
333) -> Array2<f64> {
334    if let Some(out) =
335        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atb(a.view(), b.view()))
336    {
337        return out;
338    }
339    let (n_a, p) = a.dim();
340    let q = b.ncols();
341    fast_atb_with_parallelism(a, b, matmul_parallelism(p, q, n_a))
342}
343
344/// Compute A^T * B with an explicit faer parallelism policy for callers that
345/// are already running independent products in an outer Rayon task.
346#[inline]
347pub fn fast_atb_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
348    a: &ArrayBase<S1, Ix2>,
349    b: &ArrayBase<S2, Ix2>,
350    par: Par,
351) -> Array2<f64> {
352    use faer::linalg::matmul::matmul;
353    use faer::{Accum, Mat};
354
355    let (n_a, p) = a.dim();
356    let (n_b, q) = b.dim();
357    assert_eq!(n_a, n_b, "A and B must have same number of rows");
358
359    // For very small matrices, ndarray might be faster due to less overhead
360    if !should_use_faer_matmul(p, q, n_a) {
361        return a.t().dot(b);
362    }
363
364    let mut result = Mat::<f64>::zeros(p, q);
365
366    let aview = FaerArrayView::new(a);
367    let bview = FaerArrayView::new(b);
368    let a_ref = aview.as_ref();
369    let b_ref = bview.as_ref();
370
371    // dst = A^T * B
372    matmul(
373        result.as_mut(),
374        Accum::Replace,
375        a_ref.transpose(),
376        b_ref,
377        1.0,
378        par,
379    );
380
381    mat_to_array(result.as_ref())
382}
383
384/// Compute A * B^T using faer's SIMD-optimized GEMM.
385/// For A of shape (m, k) and B of shape (n, k), this computes the (m, n) result.
386#[inline]
387pub fn fast_abt<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
388    a: &ArrayBase<S1, Ix2>,
389    b: &ArrayBase<S2, Ix2>,
390) -> Array2<f64> {
391    use faer::linalg::matmul::matmul;
392    use faer::{Accum, Mat};
393
394    let (m, k_a) = a.dim();
395    let (n, k_b) = b.dim();
396    assert_eq!(
397        k_a, k_b,
398        "A and B must have same number of columns for A·Bᵀ"
399    );
400
401    if !should_use_faer_matmul(m, n, k_a) {
402        return a.dot(&b.t());
403    }
404
405    let mut result = Mat::<f64>::zeros(m, n);
406    let aview = FaerArrayView::new(a);
407    let bview = FaerArrayView::new(b);
408    let par = matmul_parallelism(m, n, k_a);
409    matmul(
410        result.as_mut(),
411        Accum::Replace,
412        aview.as_ref(),
413        bview.as_ref().transpose(),
414        1.0,
415        par,
416    );
417    mat_to_array(result.as_ref())
418}
419
420/// Compute A * B using faer's SIMD-optimized GEMM.
421/// For A of shape (n, p) and B of shape (p, q), this computes the (n, q) result.
422/// Uses zero-copy views when possible.
423#[inline]
424pub fn fast_ab<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
425    a: &ArrayBase<S1, Ix2>,
426    b: &ArrayBase<S2, Ix2>,
427) -> Array2<f64> {
428    if let Some(out) =
429        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_ab(a.view(), b.view()))
430    {
431        return out;
432    }
433    let n = a.nrows();
434    let q = b.ncols();
435    let mut out = Array2::<f64>::zeros((n, q));
436    fast_ab_into(a, b, &mut out);
437    out
438}
439
440// ────────────────────────────────────────────────────────────────────────
441// Compensated / blocked SIMD reduction kernels for the GEMV hot paths.
442//
443// `fast_av` (η = Xβ) and `fast_atv` (Xᵀr — e.g. the penalized-likelihood
444// gradient and REML score) are reduction-bound: every output entry is a sum
445// of products over a long axis. faer's generic GEMM serves them as degenerate
446// single-RHS-column matmuls, whose blocking/setup cost is poorly amortized by
447// one column. For the dominant row-major-contiguous case we use tight hand
448// kernels that are simultaneously
449//   * faster — several independent FMA accumulators expose the
450//     instruction-level parallelism the backend lowers to packed AVX
451//     `vfmadd` lanes, and the row work fans out across the Rayon pool; and
//   * more accurate — the per-row dots behind `fast_av` are compensated
//     (Dot2-style): each product's rounding error is recovered exactly with
//     `f64::mul_add` and each running-sum addition's with a branchless
//     two-sum. The long Xᵀr reduction fuses every `v[i]·X[i,j]` into its
//     accumulator with `mul_add` and is split into fixed-size row blocks whose
//     partials are combined pairwise, turning the naive O(n·ε) error growth
//     into ~O((block + log(n/block))·ε).
457//
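// Operands that are not contiguous in standard (row-major) layout, empty
// operands, and (for the `_into` variants) non-contiguous output buffers fall
// back to ndarray `dot` for tiny products or faer GEMM otherwise. The numerics
// therefore only change (improve) on the common standard-layout inputs.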
460// ────────────────────────────────────────────────────────────────────────
461
/// Number of independent accumulator lanes interleaved by [`fma_dot`].
/// Splitting the running sum across lanes breaks the single loop-carried
/// dependency, so consecutive multiply-adds can overlap. Eight `f64` lanes are
/// meant to span two 256-bit (`f64x4`) vectors; this is a throughput
/// assumption and hardware-dependent.
const FMA_LANES: usize = 8;

/// Work threshold on `n·p` (saturating) below which the row-major GEMV kernels
/// stay serial. At or above it they fan out across the Rayon pool, but only
/// when the thread is not already inside a nested parallel row region and the
/// pool has more than one thread (see [`kernel_should_parallelize`]).
const KERNEL_PAR_MIN_FLOP: usize = 1 << 18; // 262_144

/// Output rows handled by each Rayon task in [`fast_av_rowmajor_into`]'s
/// parallel path. Chosen large enough to amortize task overhead over many
/// short row dots.
const AV_PAR_CHUNK_ROWS: usize = 1024;

/// Rows per reduction block in [`fast_atv_rowmajor_into`]. Each block sums its
/// rows into a private length-`p` partial, and the partials combine pairwise.
/// The long-axis rounding error therefore grows with the block size plus the
/// log of the block count, rather than with `n`. The blocking is the same in
/// the serial and parallel paths, so the result does not depend on whether the
/// kernel fanned out.
const ATV_BLOCK_ROWS: usize = 512;
480
481#[inline]
482fn kernel_should_parallelize(n: usize, p: usize) -> bool {
483    !in_nested_parallel_region()
484        && n.saturating_mul(p) >= KERNEL_PAR_MIN_FLOP
485        && rayon::current_num_threads() > 1
486}
487
488/// Compensated dot product (the Ogita–Rump–Oishi *Dot2* error-free transform)
489/// of two equal-length contiguous slices, evaluated over [`FMA_LANES`]
490/// independent compensated accumulators.
491///
492/// For each term the product is split into its rounded value plus the *exact*
493/// product error via `mul_add` (`two_prod`), and added into the running sum via
494/// a branchless `two_sum`, with both rounding errors folded into a
495/// per-lane compensation. The result carries roughly twice the working
496/// precision: its error-vs-truth is bounded by `u·|result| + O(n·u²)·|x|ᵀ|y|`
497/// versus the naive recurrence's `O(n·u)·|x|ᵀ|y|`, i.e. strictly — often by
498/// many orders of magnitude — more accurate. The eight independent lanes keep
499/// the FMA pipelines saturated, and on the GEMV hot paths the extra arithmetic
500/// is hidden under the memory traffic of streaming `X`, so accuracy rises with
501/// no throughput cost.
502#[inline(always)]
503fn fma_dot(a: &[f64], b: &[f64]) -> f64 {
504    assert_eq!(a.len(), b.len());
505    let mut sum = [0.0f64; FMA_LANES];
506    let mut comp = [0.0f64; FMA_LANES];
507    let mut ca = a.chunks_exact(FMA_LANES);
508    let mut cb = b.chunks_exact(FMA_LANES);
509    for (xa, xb) in ca.by_ref().zip(cb.by_ref()) {
510        for l in 0..FMA_LANES {
511            let x = xa[l];
512            let y = xb[l];
513            // two_prod: p = round(x·y), ep = exact error x·y − p.
514            let p = x * y;
515            let ep = x.mul_add(y, -p);
516            // two_sum: s = round(sum + p), es = exact error.
517            let s = sum[l] + p;
518            let bb = s - sum[l];
519            let es = (sum[l] - (s - bb)) + (p - bb);
520            sum[l] = s;
521            comp[l] += ep + es;
522        }
523    }
524    // Compensated remainder lane (length < FMA_LANES).
525    let mut sr = 0.0f64;
526    let mut cr = 0.0f64;
527    for (&x, &y) in ca.remainder().iter().zip(cb.remainder().iter()) {
528        let p = x * y;
529        let ep = x.mul_add(y, -p);
530        let s = sr + p;
531        let bb = s - sr;
532        let es = (sr - (s - bb)) + (p - bb);
533        sr = s;
534        cr += ep + es;
535    }
536    // Fold each lane's compensation back in, then reduce the (few) lanes.
537    let mut total = sr + cr;
538    for l in 0..FMA_LANES {
539        total += sum[l] + comp[l];
540    }
541    total
542}
543
544/// `out[i] = Σ_j X[i,j]·v[j]` for row-major-contiguous `x_all` (len `n·p`) and
545/// `v` (len `p`). Each output row is an independent [`fma_dot`]; rows fan out
546/// in chunks across the Rayon pool when the work is large.
547fn fast_av_rowmajor_into(x_all: &[f64], v: &[f64], n: usize, p: usize, out: &mut [f64]) {
548    assert_eq!(x_all.len(), n * p);
549    assert_eq!(v.len(), p);
550    assert_eq!(out.len(), n);
551    if kernel_should_parallelize(n, p) {
552        use rayon::prelude::*;
553        out.par_chunks_mut(AV_PAR_CHUNK_ROWS)
554            .enumerate()
555            .for_each(|(c, chunk)| {
556                let base = c * AV_PAR_CHUNK_ROWS;
557                for (k, o) in chunk.iter_mut().enumerate() {
558                    let i = base + k;
559                    *o = fma_dot(&x_all[i * p..i * p + p], v);
560                }
561            });
562    } else {
563        for (i, o) in out.iter_mut().enumerate() {
564            *o = fma_dot(&x_all[i * p..i * p + p], v);
565        }
566    }
567}
568
/// Pairwise (tree) sum of equal-length partial vectors, written into `out`.
///
/// The parts are split in half recursively (`mid = len / 2`), and each
/// internal node computes `left + right` element-wise. An empty `parts`
/// zero-fills `out`.
///
/// Allocation: the left half is accumulated directly in `out`, and pairs of
/// leaves are added without any scratch. Only nodes with more than two parts
/// allocate a single buffer for their right half, instead of two buffers at
/// every node. The rounding tree (and thus the result) is unchanged.
///
/// # Panics
/// If any part's length differs from `out.len()`.
fn pairwise_sum_into(parts: &[Vec<f64>], out: &mut [f64]) {
    match parts {
        [] => out.fill(0.0),
        [only] => out.copy_from_slice(only),
        [lhs, rhs] => {
            assert_eq!(lhs.len(), out.len(), "partial length must match output");
            assert_eq!(rhs.len(), out.len(), "partial length must match output");
            for ((o, &l), &r) in out.iter_mut().zip(lhs.iter()).zip(rhs.iter()) {
                *o = l + r;
            }
        }
        _ => {
            let mid = parts.len() / 2;
            // Left subtree accumulates in place; only the right needs scratch.
            pairwise_sum_into(&parts[..mid], out);
            let mut right = vec![0.0f64; out.len()];
            pairwise_sum_into(&parts[mid..], &mut right);
            for (o, &r) in out.iter_mut().zip(right.iter()) {
                *o += r;
            }
        }
    }
}
587
588/// `out[j] = Σ_i v[i]·X[i,j]` for row-major-contiguous `x_all` (len `n·p`).
589///
590/// Rows are grouped into [`ATV_BLOCK_ROWS`] blocks; each block FMA-accumulates
591/// its rows into a private partial vector (fused `v[i]·X[i,j]`), and the block
592/// partials are combined pairwise. This blocked/pairwise reduction is both
593/// better-conditioned than a single running sum over all `n` rows and trivially
594/// parallel across blocks.
595fn fast_atv_rowmajor_into(x_all: &[f64], v: &[f64], n: usize, p: usize, out: &mut [f64]) {
596    assert_eq!(x_all.len(), n * p);
597    assert_eq!(v.len(), n);
598    assert_eq!(out.len(), p);
599    let nblocks = n.div_ceil(ATV_BLOCK_ROWS);
600
601    let block_partial = |b: usize| -> Vec<f64> {
602        let start = b * ATV_BLOCK_ROWS;
603        let end = (start + ATV_BLOCK_ROWS).min(n);
604        let mut acc = vec![0.0f64; p];
605        for i in start..end {
606            let vi = v[i];
607            let row = &x_all[i * p..i * p + p];
608            for (a, &xij) in acc.iter_mut().zip(row.iter()) {
609                *a = xij.mul_add(vi, *a);
610            }
611        }
612        acc
613    };
614
615    let partials: Vec<Vec<f64>> = if kernel_should_parallelize(n, p) {
616        use rayon::prelude::*;
617        (0..nblocks).into_par_iter().map(block_partial).collect()
618    } else {
619        (0..nblocks).map(block_partial).collect()
620    };
621
622    pairwise_sum_into(&partials, out);
623}
624
625/// Compute A * v using faer's SIMD-optimized GEMV.
626/// For A of shape (n, p) and v of shape (p,), this computes the (n,) result.
627#[inline]
628pub fn fast_av<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
629    a: &ArrayBase<S1, Ix2>,
630    v: &ArrayBase<S2, Ix1>,
631) -> Array1<f64> {
632    if let Some(out) =
633        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_av(a.view(), v.view()))
634    {
635        return out;
636    }
637    fast_av_impl(a, v)
638}
639
640#[inline]
641fn fast_av_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
642    a: &ArrayBase<S1, Ix2>,
643    v: &ArrayBase<S2, Ix1>,
644) -> Array1<f64> {
645    use faer::linalg::matmul::matmul;
646    use faer::{Accum, Mat};
647
648    let (n, p) = a.dim();
649    assert_eq!(p, v.len(), "A cols must match v length");
650
651    // Row-major-contiguous fast path: tight multi-lane FMA dot per row, both
652    // faster (ILP / Rayon fan-out) and more accurate (fused products, pairwise
653    // lane reduction) than the degenerate single-column faer GEMV.
654    if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
655        && n != 0
656        && p != 0
657    {
658        let mut out = Array1::<f64>::zeros(n);
659        fast_av_rowmajor_into(
660            x_all,
661            vs,
662            n,
663            p,
664            out.as_slice_mut().expect("fresh Array1 is contiguous"),
665        );
666        return out;
667    }
668
669    if !should_use_faer_matmul(n, 1, p) {
670        return a.dot(v);
671    }
672
673    let mut result = Mat::<f64>::zeros(n, 1);
674
675    let aview = FaerArrayView::new(a);
676    let vview = FaerColView::new(v);
677    let a_ref = aview.as_ref();
678    let v_ref = vview.as_ref();
679
680    let par = matmul_parallelism(n, 1, p);
681    matmul(result.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
682
683    let mut out = Array1::<f64>::zeros(n);
684    for i in 0..n {
685        out[i] = result[(i, 0)];
686    }
687    out
688}
689
690/// Compute A * v into a pre-allocated output buffer.
691/// `out` must be length n where A is (n, p) and v is length p.
692#[inline]
693pub fn fast_av_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
694    a: &ArrayBase<S1, Ix2>,
695    v: &ArrayBase<S2, Ix1>,
696    out: &mut Array1<f64>,
697) {
698    fast_av_into_impl(a, v, out);
699}
700
701#[inline]
702fn fast_av_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
703    a: &ArrayBase<S1, Ix2>,
704    v: &ArrayBase<S2, Ix1>,
705    out: &mut Array1<f64>,
706) {
707    use faer::Accum;
708    use faer::linalg::matmul::matmul;
709
710    let (n, p) = a.dim();
711    assert_eq!(v.len(), p, "vector length must match A cols");
712    assert_eq!(out.len(), n, "output length must match A rows");
713
714    if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
715        && n != 0
716        && p != 0
717        && let Some(out_s) = out.as_slice_mut()
718    {
719        fast_av_rowmajor_into(x_all, vs, n, p, out_s);
720        return;
721    }
722
723    if !should_use_faer_matmul(n, 1, p) {
724        out.assign(&a.dot(v));
725        return;
726    }
727
728    let mut outview = array1_to_col_matmut(out);
729
730    let aview = FaerArrayView::new(a);
731    let vview = FaerColView::new(v);
732    let a_ref = aview.as_ref();
733    let v_ref = vview.as_ref();
734    let par = matmul_parallelism(n, 1, p);
735    matmul(outview.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
736}
737
738/// Compute A * v into a pre-allocated `ArrayViewMut1` slice. Like
739/// [`fast_av_into`] but accepts a writable slice rather than `&mut Array1`,
740/// so callers can write directly into a sub-range of a larger buffer
741/// without intermediate allocation.
742///
743/// `out` must have length n where A is (n, p) and v is length p.
744#[inline]
745pub fn fast_av_view_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
746    a: &ArrayBase<S1, Ix2>,
747    v: &ArrayBase<S2, Ix1>,
748    out: ArrayViewMut1<'_, f64>,
749) {
750    fast_av_view_into_impl(a, v, out);
751}
752
753#[inline]
754fn fast_av_view_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
755    a: &ArrayBase<S1, Ix2>,
756    v: &ArrayBase<S2, Ix1>,
757    mut out: ArrayViewMut1<'_, f64>,
758) {
759    use faer::Accum;
760    use faer::linalg::matmul::matmul;
761
762    let (n, p) = a.dim();
763    assert_eq!(v.len(), p, "vector length must match A cols");
764    assert_eq!(out.len(), n, "output length must match A rows");
765
766    if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
767        && n != 0
768        && p != 0
769        && let Some(out_s) = out.as_slice_mut()
770    {
771        fast_av_rowmajor_into(x_all, vs, n, p, out_s);
772        return;
773    }
774
775    if !should_use_faer_matmul(n, 1, p) {
776        let prod = a.dot(v);
777        out.assign(&prod);
778        return;
779    }
780
781    let len = out.len();
782    let stride = out.strides()[0];
783    // SAFETY: out.as_mut_ptr() is ndarray's logical first-element pointer, and
784    // len plus the signed element stride describe every initialized element of
785    // this uniquely borrowed view for the returned len×1 MatMut lifetime.
786    let outview = unsafe {
787        MatMut::from_raw_parts_mut(
788            out.as_mut_ptr(),
789            len,
790            1,
791            stride,
792            0, // col stride irrelevant for 1 column
793        )
794    };
795
796    let aview = FaerArrayView::new(a);
797    let vview = FaerColView::new(v);
798    let a_ref = aview.as_ref();
799    let v_ref = vview.as_ref();
800    let par = matmul_parallelism(n, 1, p);
801    matmul(outview, Accum::Replace, a_ref, v_ref, 1.0, par);
802}
803
804/// Compute A^T * v using faer's SIMD-optimized GEMV.
805/// For A of shape (n, p) and v of shape (n,), this computes the (p,) result.
806#[inline]
807pub fn fast_atv<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
808    a: &ArrayBase<S1, Ix2>,
809    v: &ArrayBase<S2, Ix1>,
810) -> Array1<f64> {
811    if let Some(out) =
812        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atv(a.view(), v.view()))
813    {
814        return out;
815    }
816    fast_atv_impl(a, v)
817}
818
819#[inline]
820fn fast_atv_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
821    a: &ArrayBase<S1, Ix2>,
822    v: &ArrayBase<S2, Ix1>,
823) -> Array1<f64> {
824    use faer::Accum;
825    use faer::linalg::matmul::matmul;
826
827    let (n, p) = a.dim();
828    assert_eq!(n, v.len(), "A rows must match v length");
829
830    // Row-major-contiguous fast path: blocked + pairwise FMA reduction over the
831    // long n-axis. Lower error-vs-truth than a single running sum and parallel
832    // across row blocks.
833    if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
834        && n != 0
835        && p != 0
836    {
837        let mut out = Array1::<f64>::zeros(p);
838        fast_atv_rowmajor_into(
839            x_all,
840            vs,
841            n,
842            p,
843            out.as_slice_mut().expect("fresh Array1 is contiguous"),
844        );
845        return out;
846    }
847
848    // For very small arrays, ndarray might be faster
849    if !should_use_faer_matmul(p, 1, n) {
850        return a.t().dot(v);
851    }
852
853    let mut out = Array1::<f64>::zeros(p);
854    let mut outview = array1_to_col_matmut(&mut out);
855
856    let aview = FaerArrayView::new(a);
857    let vview = FaerColView::new(v);
858    let a_ref = aview.as_ref();
859    let v_ref = vview.as_ref();
860
861    // dst = A^T * v (treating v as n×1 matrix)
862    let par = matmul_parallelism(p, 1, n);
863    matmul(
864        outview.as_mut(),
865        Accum::Replace,
866        a_ref.transpose(),
867        v_ref,
868        1.0,
869        par,
870    );
871
872    out
873}
874
875/// Compute A^T * v into a pre-allocated output buffer.
876/// `out` must be length p where A is (n, p) and v is length n.
877#[inline]
878pub fn fast_atv_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
879    a: &ArrayBase<S1, Ix2>,
880    v: &ArrayBase<S2, Ix1>,
881    out: &mut Array1<f64>,
882) {
883    fast_atv_into_impl(a, v, out);
884}
885
886#[inline]
887fn fast_atv_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
888    a: &ArrayBase<S1, Ix2>,
889    v: &ArrayBase<S2, Ix1>,
890    out: &mut Array1<f64>,
891) {
892    use faer::Accum;
893    use faer::linalg::matmul::matmul;
894
895    let (n, p) = a.dim();
896    assert_eq!(v.len(), n, "vector length must match A rows");
897    assert_eq!(out.len(), p, "output length must match A cols");
898
899    if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
900        && n != 0
901        && p != 0
902        && let Some(out_s) = out.as_slice_mut()
903    {
904        fast_atv_rowmajor_into(x_all, vs, n, p, out_s);
905        return;
906    }
907
908    if !should_use_faer_matmul(p, 1, n) {
909        out.assign(&a.t().dot(v));
910        return;
911    }
912
913    let mut outview = array1_to_col_matmut(out);
914
915    let aview = FaerArrayView::new(a);
916    let vview = FaerColView::new(v);
917    let a_ref = aview.as_ref();
918    let v_ref = vview.as_ref();
919    let par = matmul_parallelism(p, 1, n);
920    matmul(
921        outview.as_mut(),
922        Accum::Replace,
923        a_ref.transpose(),
924        v_ref,
925        1.0,
926        par,
927    );
928}
929
930/// Compute A^T * diag(W) * A using streaming chunks to avoid O(n*p) allocation.
931#[inline]
932pub fn fast_xt_diag_x<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
933    x: &ArrayBase<S1, Ix2>,
934    w: &ArrayBase<S2, Ix1>,
935) -> Array2<f64> {
936    assert_eq!(
937        x.nrows(),
938        w.len(),
939        "fast_xt_diag_x row/weight length mismatch"
940    );
941    if let Some(out) =
942        crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_xt_diag_x(x.view(), w.view()))
943    {
944        return out;
945    }
946    let p = x.ncols();
947    fast_xt_diag_x_with_parallelism(x, w, matmul_parallelism(p, p, x.nrows()))
948}
949
950/// Compute A^T * diag(W) * A with an explicit faer parallelism policy for
951/// callers that parallelize multiple independent Hessian blocks externally.
952#[inline]
953pub fn fast_xt_diag_x_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
954    x: &ArrayBase<S1, Ix2>,
955    w: &ArrayBase<S2, Ix1>,
956    par: Par,
957) -> Array2<f64> {
958    assert_eq!(
959        x.nrows(),
960        w.len(),
961        "fast_xt_diag_x_with_parallelism row/weight length mismatch"
962    );
963    fast_xt_diag_x_with_parallelism_impl(x, w, par)
964}
965
966#[inline]
967fn fast_xt_diag_x_with_parallelism_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
968    x: &ArrayBase<S1, Ix2>,
969    w: &ArrayBase<S2, Ix1>,
970    par: Par,
971) -> Array2<f64> {
972    use ndarray::ShapeBuilder;
973
974    let p = x.ncols();
975    // F-order result so the symmetric lower-triangle accumulation writes
976    // column-contiguously; the kernel mirrors to a full symmetric matrix.
977    let mut result = Array2::<f64>::zeros((p, p).f());
978    stream_weighted_crossprod_into(
979        x,
980        w,
981        &mut result,
982        CrossprodStructure::SymmetricLower,
983        CrossprodAccum::Replace,
984        par,
985    );
986    result
987}
988
/// How [`stream_weighted_crossprod_into`] computes the `p × p` Gram matrix.
///
/// Either way the function leaves a full symmetric matrix in `out`; the
/// variants differ only in how much arithmetic is spent getting there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossprodStructure {
    /// Fill every entry with one rectangular GEMM per row chunk.
    Full,
    /// Accumulate just the lower triangle (diagonal included) with a
    /// triangular matmul, roughly half the FLOPs of [`Full`](Self::Full). The
    /// upper triangle is copied from it once at the end, so the output is
    /// mathematically identical.
    SymmetricLower,
}
999
/// Whether [`stream_weighted_crossprod_into`] overwrites its output or adds
/// to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossprodAccum {
    /// Discard whatever `out` held and store `Xᵀ·diag(W)·X`.
    Replace,
    /// Sum `Xᵀ·diag(W)·X` onto the current contents of `out`.
    Add,
}
1008
/// Shared dense weighted-Gram kernel: accumulate `Xᵀ·diag(W)·X` into `out`.
///
/// This is the single tuned implementation of the chunked row-scaling +
/// matmul strategy; the matrix-returning (`fast_xt_diag_x*`) entry points and
/// stream-in callers share it so that performance tuning, negative-weight
/// handling, chunk sizing, and layout fixes land in exactly one place.
///
/// Computes the product as `Xᵀ·(W·X)` to preserve the sign of `W`: the prior
/// `sqrt(max(0, w))`-then-Gram form clipped negative weights to zero, which
/// corrupted observed-Hessian assembly when any block carried heavy residuals
/// (e.g. under the logb σ link).
///
/// Memory: on the faer path the scratch buffer holds `chunk_rows × p` doubles.
/// `chunk_rows` targets 8 MiB but is clamped to `[512, 131_072]` rows and then
/// to `n`. Because of the 512-row floor, the buffer exceeds 8 MiB when
/// `p > 2048`. Since `chunk_rows ≤ n`, it is never larger than a materialized
/// `n × p` `W·X`. The tiny-product fallback (taken when `should_use_faer_matmul`
/// declines) does materialize the full `n × p` `W·X`.
///
/// `out` must be `p × p`. With [`CrossprodStructure::SymmetricLower`] the
/// lower triangle is accumulated and then mirrored, so on return `out` holds
/// the full symmetric matrix regardless of `structure`.
///
/// Caveat for `SymmetricLower` combined with [`CrossprodAccum::Add`]: the
/// mirror step overwrites the upper triangle with the accumulated lower
/// triangle, so whatever the caller had above the diagonal is discarded. The
/// result equals `prior + Xᵀ·diag(W)·X` only when the prior `out` was
/// symmetric. Use [`CrossprodStructure::Full`] to add into a non-symmetric
/// `out`.
///
/// `par` is only consulted on the faer path; the tiny-product fallback uses
/// ndarray's `dot` and ignores it.
///
/// Early exits: `p == 0` returns immediately (`out` is `0 × 0`). `n == 0`
/// zeroes `out` for `Replace` and leaves it untouched for `Add`.
///
/// # Panics
/// If `w.len() != x.nrows()` or `out` is not `x.ncols() × x.ncols()`.
pub fn stream_weighted_crossprod_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    out: &mut Array2<f64>,
    structure: CrossprodStructure,
    accum: CrossprodAccum,
    par: Par,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
    use ndarray::s;

    let (n, p) = x.dim();
    assert_eq!(n, w.len(), "X rows must match W length");
    assert_eq!(out.nrows(), p, "output rows must match X cols");
    assert_eq!(out.ncols(), p, "output cols must match X cols");
    if p == 0 {
        return;
    }
    if n == 0 {
        // Empty sum: Replace means the result is the zero matrix; Add is a no-op.
        if accum == CrossprodAccum::Replace {
            out.fill(0.0);
        }
        return;
    }

    if !should_use_faer_matmul(p, p, n) {
        // Tiny products: ndarray's own GEMM avoids faer setup overhead.
        let w_x = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        let gram = x.t().dot(&w_x);
        match accum {
            CrossprodAccum::Replace => out.assign(&gram),
            CrossprodAccum::Add => *out += &gram,
        }
        return;
    }

    // Streaming chunked: peak allocation is chunk_rows × p instead of n × p.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    // Rows per chunk so that one chunk of W·X is ~TARGET_BYTES of f64s,
    // bounded below by MIN_ROWS (so the buffer can exceed TARGET_BYTES for
    // very wide X) and never more than n.
    let chunk_rows = (TARGET_BYTES / (p.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Triangular accumulation requires a zero baseline in the lower triangle
    // because each chunk's `Accum::Add` lands there; for a Replace request we
    // zero up front and add every chunk, for an Add request the caller's
    // contents are preserved and every chunk adds on top.
    if accum == CrossprodAccum::Replace {
        out.fill(0.0);
    }

    // Row-major wx_chunk so the per-row scaling loop has stride-1 writes
    // alongside stride-1 reads from a row-major X. An F-order wx_chunk would
    // force strided writes by `chunk_rows`, breaking vectorization and cache
    // locality on the per-PIRLS-iter Hessian assembly. faer's matmul handles
    // either layout via FaerArrayView.
    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p));

    let x_is_row_major = x.is_standard_layout();
    let w_slice_opt = w.as_slice();

    // Scope the faer mutable view so its borrow on `out` ends before the
    // symmetric mirror step.
    {
        let mut out_view = array2_to_matmut(out);
        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than chunk_rows; only its first
            // `rows` rows of wx_chunk are written and used.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wx_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                // Fast path: flat slices for both X and W, so each row is a
                // plain contiguous slice; otherwise fall back to indexed views.
                if x_is_row_major && let (Some(x_all), Some(w_all)) = (x.as_slice(), w_slice_opt) {
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * p;
                        let dst_off = local * p;
                        let src_row = &x_all[src_off..src_off + p];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for col in 0..p {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    let x_slice = x.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let xrow = x_slice.row(local);
                        let dst_off = local * p;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for (col, xij) in xrow.iter().enumerate() {
                            dst_row[col] = xij * wi;
                        }
                    }
                }
            }
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wx_slice = wx_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wx_view = FaerArrayView::new(&wx_slice);
            match structure {
                CrossprodStructure::SymmetricLower => {
                    // X^T diag(W) X is symmetric; accumulate the lower triangle
                    // only, then mirror once after the chunk loop. ~50% fewer
                    // FLOPs vs. full GEMM.
                    tri_matmul(
                        out_view.as_mut(),
                        BlockStructure::TriangularLower,
                        Accum::Add,
                        x_view.as_ref().transpose(),
                        BlockStructure::Rectangular,
                        wx_view.as_ref(),
                        BlockStructure::Rectangular,
                        1.0,
                        par,
                    );
                }
                CrossprodStructure::Full => {
                    matmul(
                        out_view.as_mut(),
                        Accum::Add,
                        x_view.as_ref().transpose(),
                        wx_view.as_ref(),
                        1.0,
                        par,
                    );
                }
            }
        }
    }

    if structure == CrossprodStructure::SymmetricLower {
        // Mirror lower triangle to upper for a full symmetric output.
        // Note: this overwrites any prior upper-triangle contents (relevant
        // for CrossprodAccum::Add with a non-symmetric `out`).
        for i in 0..p {
            for j in (i + 1)..p {
                out[[i, j]] = out[[j, i]];
            }
        }
    }
}
1170
1171/// Compute A^T * diag(W) * B using streaming chunks.
1172#[inline]
1173pub fn fast_xt_diag_y<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
1174    x: &ArrayBase<S1, Ix2>,
1175    w: &ArrayBase<S2, Ix1>,
1176    y: &ArrayBase<S3, Ix2>,
1177) -> Array2<f64> {
1178    assert_eq!(x.nrows(), y.nrows(), "fast_xt_diag_y X/Y row mismatch");
1179    assert_eq!(
1180        y.nrows(),
1181        w.len(),
1182        "fast_xt_diag_y row/weight length mismatch"
1183    );
1184    if let Some(out) = crate::gpu_hook::gpu_dispatch()
1185        .and_then(|d| d.try_fast_xt_diag_y(x.view(), w.view(), y.view()))
1186    {
1187        return out;
1188    }
1189    fast_xt_diag_y_impl(x, w, y)
1190}
1191
/// CPU kernel behind [`fast_xt_diag_y`]: returns `Xᵀ·diag(W)·Y` as a
/// `px × q` array, where `x` is `(n, px)` and `y` is `(n, q)`.
///
/// - Any empty dimension returns a zero `px × q` array.
/// - Tiny products (per `should_use_faer_matmul`) materialize the full
///   `n × q` `W·Y` and use ndarray's `dot`.
/// - Otherwise the rows are streamed in chunks. Each chunk of `Y` is scaled
///   row-wise by `W` into a reusable row-major scratch buffer, and
///   `X_chunkᵀ · (W·Y)_chunk` is added into a column-major result with faer.
///
/// # Panics
/// If `w.len()` or `x.nrows()` differs from `y.nrows()`.
#[inline]
fn fast_xt_diag_y_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    y: &ArrayBase<S3, Ix2>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};

    let (n, q) = y.dim();
    let px = x.ncols();
    assert_eq!(n, w.len(), "Y rows must match W length");
    assert_eq!(n, x.nrows(), "X rows must match Y rows");
    if n == 0 || px == 0 || q == 0 {
        return Array2::<f64>::zeros((px, q));
    }
    if !should_use_faer_matmul(px, q, n) {
        let w_y = Array2::from_shape_fn((n, q), |(i, j)| w[i] * y[[i, j]]);
        return x.t().dot(&w_y);
    }

    // Streaming: only allocate chunk_rows × q for the weighted Y slice.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    // NOTE(review): the row budget is computed from px + q columns although
    // only the q-column wy_chunk is allocated, so the scratch buffer stays
    // below TARGET_BYTES (conservative, not a correctness issue).
    let total_cols = px + q;
    let chunk_rows = (TARGET_BYTES / (total_cols.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    let mut result = Array2::<f64>::zeros((px, q).f());
    // Row-major wy_chunk — same rationale as fast_xt_diag_x: stride-1
    // writes alongside stride-1 reads from a row-major Y.
    let mut wy_chunk = Array2::<f64>::zeros((chunk_rows, q));

    let y_is_row_major = y.is_standard_layout();
    let w_slice_opt = w.as_slice();

    // Scope the faer mutable view so its borrow on `result` ends before return.
    {
        let mut out_view = array2_to_matmut(&mut result);

        for start in (0..n).step_by(chunk_rows) {
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wy_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                // Fast path: flat slices for Y and W; otherwise indexed views.
                if y_is_row_major && let (Some(y_all), Some(w_all)) = (y.as_slice(), w_slice_opt) {
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * q;
                        let dst_off = local * q;
                        let src_row = &y_all[src_off..src_off + q];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for col in 0..q {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    let y_slice = y.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let yrow = y_slice.row(local);
                        let dst_off = local * q;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for (col, yij) in yrow.iter().enumerate() {
                            dst_row[col] = yij * wi;
                        }
                    }
                }
            }
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wy_slice = wy_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wy_view = FaerArrayView::new(&wy_slice);
            // Parallelism is chosen per chunk from this chunk's row count.
            let par = matmul_parallelism(px, q, rows);
            matmul(
                out_view.as_mut(),
                Accum::Add,
                x_view.as_ref().transpose(),
                wy_view.as_ref(),
                1.0,
                par,
            );
        }
    }

    result
}
1283
1284/// Compute the 2×2 block joint Hessian in a single streaming pass:
1285///   [X_a^T diag(w_aa) X_a,   X_a^T diag(w_ab) X_b]
1286///   [X_b^T diag(w_ab) X_a,   X_b^T diag(w_bb) X_b]
1287///
1288/// This reads X_a and X_b once per chunk instead of twice (saving 50% bandwidth).
1289pub fn fast_joint_hessian_2x2<
1290    S1: Data<Elem = f64>,
1291    S2: Data<Elem = f64>,
1292    S3: Data<Elem = f64>,
1293    S4: Data<Elem = f64>,
1294    S5: Data<Elem = f64>,
1295>(
1296    x_a: &ArrayBase<S1, Ix2>,
1297    x_b: &ArrayBase<S2, Ix2>,
1298    w_aa: &ArrayBase<S3, Ix1>,
1299    w_ab: &ArrayBase<S4, Ix1>,
1300    w_bb: &ArrayBase<S5, Ix1>,
1301) -> Array2<f64> {
1302    if let Some(out) = crate::gpu_hook::gpu_dispatch().and_then(|d| {
1303        d.try_fast_joint_hessian_2x2(
1304            x_a.view(),
1305            x_b.view(),
1306            w_aa.view(),
1307            w_ab.view(),
1308            w_bb.view(),
1309        )
1310    }) {
1311        return out;
1312    }
1313    fast_joint_hessian_2x2_impl(x_a, x_b, w_aa, w_ab, w_bb)
1314}
1315
/// CPU kernel behind [`fast_joint_hessian_2x2`]: assembles the full
/// `(pa + pb) × (pa + pb)` joint Hessian
///
/// ```text
/// [ X_aᵀ diag(w_aa) X_a   X_aᵀ diag(w_ab) X_b ]
/// [ (mirrored)            X_bᵀ diag(w_bb) X_b ]
/// ```
///
/// Only the upper block triangle is computed: the two diagonal blocks and the
/// top-right cross block, each accumulated chunk by chunk with full GEMM. The
/// final mirror loop then copies the entire upper triangle onto the lower
/// triangle. That fills the bottom-left cross block and also replaces the
/// lower halves of the diagonal blocks with their upper halves.
///
/// Small problems (per `should_use_faer_matmul` on the wider block) skip
/// streaming and materialize the three `n × p` weighted copies directly.
///
/// # Panics
/// If `x_b`, `w_aa`, `w_ab`, or `w_bb` does not have `x_a.nrows()`
/// rows/entries (plain `assert_eq!`, no custom message).
#[inline]
fn fast_joint_hessian_2x2_impl<
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
    S5: Data<Elem = f64>,
>(
    x_a: &ArrayBase<S1, Ix2>,
    x_b: &ArrayBase<S2, Ix2>,
    w_aa: &ArrayBase<S3, Ix1>,
    w_ab: &ArrayBase<S4, Ix1>,
    w_bb: &ArrayBase<S5, Ix1>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};

    let n = x_a.nrows();
    let pa = x_a.ncols();
    let pb = x_b.ncols();
    let total = pa + pb;
    assert_eq!(n, x_b.nrows());
    assert_eq!(n, w_aa.len());
    assert_eq!(n, w_ab.len());
    assert_eq!(n, w_bb.len());

    if n == 0 || total == 0 {
        return Array2::<f64>::zeros((total, total));
    }

    // For small problems, fall back to separate computations
    if !should_use_faer_matmul(pa.max(pb), pa.max(pb), n) {
        let waa_xa = Array2::from_shape_fn((n, pa), |(i, j)| w_aa[i] * x_a[[i, j]]);
        let wab_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_ab[i] * x_b[[i, j]]);
        let wbb_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_bb[i] * x_b[[i, j]]);
        let mut out = Array2::<f64>::zeros((total, total));
        out.slice_mut(s![..pa, ..pa]).assign(&x_a.t().dot(&waa_xa));
        out.slice_mut(s![..pa, pa..]).assign(&x_a.t().dot(&wab_xb));
        out.slice_mut(s![pa.., pa..]).assign(&x_b.t().dot(&wbb_xb));
        // Mirror upper to lower
        for i in 0..total {
            for j in 0..i {
                out[[i, j]] = out[[j, i]];
            }
        }
        return out;
    }

    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    // Need buffers for: waa_xa(chunk×pa) + wab_xb(chunk×pb) + wbb_xb(chunk×pb)
    let cols_needed = pa + 2 * pb;
    let chunk_rows = (TARGET_BYTES / (cols_needed.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    let mut out = Array2::<f64>::zeros((total, total).f());
    // Row-major weighted buffers so the per-row scale loops have stride-1
    // writes (the previous F-order layout strided writes by chunk_rows
    // across `pa` / `pb`, gutting vectorization on the per-PIRLS-iter
    // joint Hessian assembly). faer's matmul handles either layout.
    let mut waa_xa_buf = Array2::<f64>::zeros((chunk_rows, pa));
    let mut wab_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
    let mut wbb_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));

    let xa_is_row_major = x_a.is_standard_layout();
    let xb_is_row_major = x_b.is_standard_layout();
    let waa_slice_opt = w_aa.as_slice();
    let wab_slice_opt = w_ab.as_slice();
    let wbb_slice_opt = w_bb.as_slice();

    {
        let mut out_mat = array2_to_matmut(&mut out);

        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be short; only the first `rows` buffer rows
            // are written and used.
            let rows = (n - start).min(chunk_rows);
            let xa_slice = x_a.slice(s![start..start + rows, ..]);
            let xb_slice = x_b.slice(s![start..start + rows, ..]);

            // Weight X_a and X_b in a single pass through this chunk.
            {
                let waa_chunk = waa_xa_buf
                    .as_slice_mut()
                    .expect("row-major waa chunk is contiguous");
                let wab_chunk = wab_xb_buf
                    .as_slice_mut()
                    .expect("row-major wab chunk is contiguous");
                let wbb_chunk = wbb_xb_buf
                    .as_slice_mut()
                    .expect("row-major wbb chunk is contiguous");

                // Fast path: every input available as a flat slice, so each
                // row is a plain contiguous slice; otherwise indexed views.
                if xa_is_row_major
                    && xb_is_row_major
                    && let (Some(xa_all), Some(xb_all)) = (x_a.as_slice(), x_b.as_slice())
                    && let (Some(waa_all), Some(wab_all), Some(wbb_all)) =
                        (waa_slice_opt, wab_slice_opt, wbb_slice_opt)
                {
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = waa_all[i];
                        let wab_i = wab_all[i];
                        let wbb_i = wbb_all[i];
                        let xa_off = i * pa;
                        let xa_row = &xa_all[xa_off..xa_off + pa];
                        let xb_off = i * pb;
                        let xb_row = &xb_all[xb_off..xb_off + pb];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        for col in 0..pa {
                            waa_row[col] = xa_row[col] * waa_i;
                        }
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        // One read of each X_b element feeds both weighted copies.
                        for col in 0..pb {
                            let xij = xb_row[col];
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                } else {
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = w_aa[i];
                        let wab_i = w_ab[i];
                        let wbb_i = w_bb[i];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        let xa_row = xa_slice.row(local);
                        for (col, xij) in xa_row.iter().enumerate() {
                            waa_row[col] = xij * waa_i;
                        }
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        let xb_row = xb_slice.row(local);
                        for (col, xij) in xb_row.iter().enumerate() {
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                }
            }

            let xa_view = FaerArrayView::new(&xa_slice);
            let xb_view = FaerArrayView::new(&xb_slice);
            let waa_xa_slice = waa_xa_buf.slice(s![0..rows, ..]);
            let wab_xb_slice = wab_xb_buf.slice(s![0..rows, ..]);
            let wbb_xb_slice = wbb_xb_buf.slice(s![0..rows, ..]);
            let waa_xa_view = FaerArrayView::new(&waa_xa_slice);
            let wab_xb_view = FaerArrayView::new(&wab_xb_slice);
            let wbb_xb_view = FaerArrayView::new(&wbb_xb_slice);

            // Block [0..pa, 0..pa]: X_a^T diag(w_aa) X_a
            matmul(
                out_mat.rb_mut().submatrix_mut(0, 0, pa, pa),
                Accum::Add,
                xa_view.as_ref().transpose(),
                waa_xa_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pa, rows),
            );
            // Block [0..pa, pa..total]: X_a^T diag(w_ab) X_b
            matmul(
                out_mat.rb_mut().submatrix_mut(0, pa, pa, pb),
                Accum::Add,
                xa_view.as_ref().transpose(),
                wab_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pb, rows),
            );
            // Block [pa..total, pa..total]: X_b^T diag(w_bb) X_b
            matmul(
                out_mat.rb_mut().submatrix_mut(pa, pa, pb, pb),
                Accum::Add,
                xb_view.as_ref().transpose(),
                wbb_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pb, pb, rows),
            );
        }
    } // out_mat dropped
    // Mirror upper triangle to lower
    for i in 0..total {
        for j in 0..i {
            out[[i, j]] = out[[j, i]];
        }
    }
    out
}
1510
1511fn mat_to_array(mat: MatRef<'_, f64>) -> Array2<f64> {
1512    let nrows = mat.nrows();
1513    let ncols = mat.ncols();
1514    let mut out = Array2::<f64>::zeros((nrows, ncols));
1515    if nrows == 0 || ncols == 0 {
1516        return out;
1517    }
1518    // ndarray is row-major by default. Write row-by-row for best cache behavior
1519    // on the output side.
1520    if let Some(out_slice) = out.as_slice_memory_order_mut() {
1521        // Row-major: out_slice[i * ncols + j] = mat[(i, j)]
1522        for i in 0..nrows {
1523            let row_start = i * ncols;
1524            for j in 0..ncols {
1525                out_slice[row_start + j] = mat[(i, j)];
1526            }
1527        }
1528    } else {
1529        for j in 0..ncols {
1530            for i in 0..nrows {
1531                out[[i, j]] = mat[(i, j)];
1532            }
1533        }
1534    }
1535    out
1536}
1537
1538/// Write faer matmul result A*B directly into a pre-allocated ndarray Array2.
1539/// Avoids the intermediate faer::Mat allocation and mat_to_array copy.
1540#[inline]
1541pub fn fast_ab_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1542    a: &ArrayBase<S1, Ix2>,
1543    b: &ArrayBase<S2, Ix2>,
1544    out: &mut Array2<f64>,
1545) {
1546    fast_ab_into_impl(a, b, out);
1547}
1548
1549#[inline]
1550fn fast_ab_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1551    a: &ArrayBase<S1, Ix2>,
1552    b: &ArrayBase<S2, Ix2>,
1553    out: &mut Array2<f64>,
1554) {
1555    use faer::Accum;
1556    use faer::linalg::matmul::matmul;
1557
1558    let (n, p) = a.dim();
1559    let (p_b, q) = b.dim();
1560    assert_eq!(p, p_b, "A and B must have compatible inner dimensions");
1561    assert_eq!(out.dim(), (n, q), "output dimensions must match A*B result");
1562
1563    if !should_use_faer_matmul(n, q, p) {
1564        out.assign(&a.dot(b));
1565        return;
1566    }
1567
1568    let aview = FaerArrayView::new(a);
1569    let bview = FaerArrayView::new(b);
1570    let a_ref = aview.as_ref();
1571    let b_ref = bview.as_ref();
1572
1573    let par = matmul_parallelism(n, q, p);
1574    let mut outview = array2_to_matmut(out);
1575    matmul(outview.as_mut(), Accum::Replace, a_ref, b_ref, 1.0, par);
1576}
1577
1578fn diag_to_array(diag: DiagRef<'_, f64>) -> Array1<f64> {
1579    let mat = diag.column_vector().as_mat();
1580    let mut out = Array1::<f64>::zeros(mat.nrows());
1581    for i in 0..mat.nrows() {
1582        out[i] = mat[(i, 0)];
1583    }
1584    out
1585}
1586
/// Read-only adapter that exposes an ndarray 2-D array to faer as a
/// [`MatRef`].
///
/// When `owned` is `None`, the raw parts below describe the caller's array,
/// which is borrowed for `'a`. When `owned` is `Some`, `FaerArrayView::new`
/// made a private copy of the input (it does so for arrays with a
/// non-positive stride), and `FaerArrayView::as_ref` reads from that copy
/// instead of from the raw parts.
pub struct FaerArrayView<'a> {
    /// Pointer to the first logical element (`[0, 0]`).
    ptr: *const f64,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Element stride between consecutive rows (ndarray `strides()[0]`).
    row_stride: isize,
    /// Element stride between consecutive columns (ndarray `strides()[1]`).
    col_stride: isize,
    /// Private copy used for layouts that are not borrowed directly.
    owned: Option<Array2<f64>>,
    /// Ties the borrowed case to the source array's lifetime.
    marker: PhantomData<&'a f64>,
}
1596
1597impl<'a> FaerArrayView<'a> {
1598    #[inline]
1599    pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix2>) -> Self {
1600        let (rows, cols) = array.dim();
1601        let strides = array.strides();
1602        // Guard against layouts that can alias or reverse memory traversal (e.g.
1603        // negative/zero strides). These can violate assumptions in faer kernels.
1604        // For such layouts we materialize a compact owned copy.
1605        if strides[0] <= 0 || strides[1] <= 0 {
1606            let owned = array.to_owned();
1607            let owned_strides = owned.strides();
1608            return Self {
1609                ptr: owned.as_ptr(),
1610                rows,
1611                cols,
1612                row_stride: owned_strides[0],
1613                col_stride: owned_strides[1],
1614                owned: Some(owned),
1615                marker: PhantomData,
1616            };
1617        }
1618
1619        Self {
1620            ptr: array.as_ptr(),
1621            rows,
1622            cols,
1623            row_stride: strides[0],
1624            col_stride: strides[1],
1625            owned: None,
1626            marker: PhantomData,
1627        }
1628    }
1629
1630    #[inline]
1631    pub fn as_ref(&self) -> MatRef<'_, f64> {
1632        let (ptr, rows, cols, row_stride, col_stride) = if let Some(owned) = &self.owned {
1633            let strides = owned.strides();
1634            (
1635                owned.as_ptr(),
1636                owned.nrows(),
1637                owned.ncols(),
1638                strides[0],
1639                strides[1],
1640            )
1641        } else {
1642            (
1643                self.ptr,
1644                self.rows,
1645                self.cols,
1646                self.row_stride,
1647                self.col_stride,
1648            )
1649        };
1650        // SAFETY: ptr/shape/strides come from either a live ndarray view
1651        // (positive strides, validated bounds/alignment) or the owned
1652        // compact copy held inside this wrapper — no mutable aliasing.
1653        unsafe { MatRef::from_raw_parts(ptr, rows, cols, row_stride, col_stride) }
1654    }
1655}
1656
/// Read-only adapter that exposes an ndarray 1-D array to faer as an
/// `len × 1` [`MatRef`].
///
/// When `owned` is `None`, `ptr`/`len`/`stride` describe the caller's array,
/// borrowed for `'a`. When it is `Some`, `FaerColView::new` copied the input
/// (it does so for a non-positive stride), and `FaerColView::as_ref` reads
/// the copy with unit stride.
pub struct FaerColView<'a> {
    /// Pointer to the first logical element.
    ptr: *const f64,
    /// Number of elements.
    len: usize,
    /// Element stride between consecutive entries.
    stride: isize,
    /// Private copy used for layouts that are not borrowed directly.
    owned: Option<Array1<f64>>,
    /// Ties the borrowed case to the source array's lifetime.
    marker: PhantomData<&'a f64>,
}
1664
1665impl<'a> FaerColView<'a> {
1666    #[inline]
1667    pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix1>) -> Self {
1668        let len = array.len();
1669        let stride = array.strides()[0];
1670        if stride <= 0 {
1671            let owned = array.to_owned();
1672            return Self {
1673                ptr: owned.as_ptr(),
1674                len,
1675                stride: 1,
1676                owned: Some(owned),
1677                marker: PhantomData,
1678            };
1679        }
1680        Self {
1681            ptr: array.as_ptr(),
1682            len,
1683            stride,
1684            owned: None,
1685            marker: PhantomData,
1686        }
1687    }
1688
1689    #[inline]
1690    pub fn as_ref(&self) -> MatRef<'_, f64> {
1691        let (ptr, len, stride) = if let Some(owned) = &self.owned {
1692            (owned.as_ptr(), owned.len(), 1)
1693        } else {
1694            (self.ptr, self.len, self.stride)
1695        };
1696        // SAFETY: ptr/len/stride come from either a live ndarray column
1697        // (positive stride, validated bounds/alignment) or the owned
1698        // compact copy; ncols=1 so the 0 col-stride is unused.
1699        unsafe { MatRef::from_raw_parts(ptr, len, 1, stride, 0) }
1700    }
1701}
1702
1703pub trait FaerSvd {
1704    fn svd(
1705        &self,
1706        compute_u: bool,
1707        computevt: bool,
1708    ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError>;
1709}
1710
1711impl<S: Data<Elem = f64>> FaerSvd for ArrayBase<S, Ix2> {
1712    fn svd(
1713        &self,
1714        compute_u: bool,
1715        computevt: bool,
1716    ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError> {
1717        let faerview = FaerArrayView::new(self);
1718        let faer_mat = faerview.as_ref();
1719        if !compute_u && !computevt {
1720            let (rows, cols) = faer_mat.shape();
1721            let mut singular = Diag::<f64>::zeros(rows.min(cols));
1722            let par = get_global_parallelism();
1723            let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1724                rows,
1725                cols,
1726                ComputeSvdVectors::No,
1727                ComputeSvdVectors::No,
1728                par,
1729                Default::default(),
1730            ));
1731            let stack = MemStack::new(&mut mem);
1732            svd::svd(
1733                faer_mat,
1734                singular.as_mut(),
1735                None,
1736                None,
1737                par,
1738                stack,
1739                Default::default(),
1740            )
1741            .map_err(|_| FaerLinalgError::SvdNoConvergence {
1742                context: "faer SVD singular values only",
1743            })?;
1744            let singularvalues = diag_to_array(singular.as_ref());
1745            return Ok((None, singularvalues, None));
1746        }
1747
1748        let (rows, cols) = faer_mat.shape();
1749        let rank = rows.min(cols);
1750        let compute_u_flag = if compute_u {
1751            ComputeSvdVectors::Thin
1752        } else {
1753            ComputeSvdVectors::No
1754        };
1755        let computev_flag = if computevt {
1756            ComputeSvdVectors::Thin
1757        } else {
1758            ComputeSvdVectors::No
1759        };
1760
1761        let mut singular = Diag::<f64>::zeros(rows.min(cols));
1762        let mut u_storage = compute_u.then(|| Mat::<f64>::zeros(rows, rank));
1763        let mut v_storage = computevt.then(|| Mat::<f64>::zeros(cols, rank));
1764
1765        let par = get_global_parallelism();
1766        let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1767            rows,
1768            cols,
1769            compute_u_flag,
1770            computev_flag,
1771            par,
1772            Default::default(),
1773        ));
1774        let stack = MemStack::new(&mut mem);
1775
1776        svd::svd(
1777            faer_mat.as_ref(),
1778            singular.as_mut(),
1779            u_storage.as_mut().map(|mat| mat.as_mut()),
1780            v_storage.as_mut().map(|mat| mat.as_mut()),
1781            par,
1782            stack,
1783            Default::default(),
1784        )
1785        .map_err(|_| FaerLinalgError::SvdNoConvergence {
1786            context: "faer SVD with vectors",
1787        })?;
1788
1789        let singularvalues = diag_to_array(singular.as_ref());
1790        let u_opt = u_storage.map(|mat| mat_to_array(mat.as_ref()));
1791        let vt_opt = v_storage.map(|mat| {
1792            let mat_ref = mat.as_ref();
1793            let mut out = Array2::<f64>::zeros((mat_ref.ncols(), mat_ref.nrows()));
1794            for j in 0..mat_ref.nrows() {
1795                for i in 0..mat_ref.ncols() {
1796                    out[[i, j]] = mat_ref[(j, i)];
1797                }
1798            }
1799            out
1800        });
1801
1802        Ok((u_opt, singularvalues, vt_opt))
1803    }
1804}
1805
1806pub trait FaerEigh {
1807    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError>;
1808}
1809
1810impl<S: Data<Elem = f64>> FaerEigh for ArrayBase<S, Ix2> {
1811    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1812        fn try_eigh(
1813            matrix: &Array2<f64>,
1814            side: Side,
1815        ) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1816            let faerview = FaerArrayView::new(matrix);
1817            let eigen = catch_unwind(AssertUnwindSafe(|| {
1818                faerview.as_ref().self_adjoint_eigen(side)
1819            }))
1820            .map_err(|_| FaerLinalgError::FactorizationFailed {
1821                context: "self-adjoint eigendecomposition panic boundary",
1822            })?
1823            .map_err(FaerLinalgError::SelfAdjointEigen)?;
1824            let values = diag_to_array(eigen.S());
1825            let vectors = mat_to_array(eigen.U());
1826            Ok((values, vectors))
1827        }
1828
1829        let owned = self.to_owned();
1830        if owned.nrows() != owned.ncols() {
1831            return Err(FaerLinalgError::FactorizationFailed {
1832                context: "self-adjoint eigendecomposition non-square input",
1833            });
1834        }
1835        if owned.nrows() == 0 {
1836            return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
1837        }
1838        if owned.iter().any(|value| !value.is_finite()) {
1839            return Err(FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1840                context: "self-adjoint eigendecomposition input validation",
1841            });
1842        }
1843        if let Ok((evals, evecs)) = try_eigh(&owned, side)
1844            && evals.iter().all(|value| value.is_finite())
1845            && evecs.iter().all(|value| value.is_finite())
1846        {
1847            return Ok((evals, evecs));
1848        }
1849
1850        let mut repaired = owned.clone();
1851        crate::matrix::symmetrize_in_place(&mut repaired);
1852
1853        let scale = repaired
1854            .iter()
1855            .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
1856            .max(1.0);
1857        let scaled = repaired.mapv(|value| value / scale);
1858        // Relative diagonal-jitter ladder for the eigendecomposition repair: the
1859        // matrix is pre-scaled to unit max-abs, so these are fractions of its
1860        // scale. We try the unperturbed matrix first, then escalate the ridge by
1861        // two decades per attempt until the factorization yields all-finite
1862        // eigenpairs, accepting the smallest jitter that succeeds.
1863        const JITTER_SCHEDULE: [f64; 6] = [0.0, 1e-12, 1e-10, 1e-8, 1e-6, 1e-4];
1864        let jitter_schedule = JITTER_SCHEDULE;
1865        let mut last_error = FaerLinalgError::FactorizationFailed {
1866            context: "self-adjoint eigendecomposition repair attempts",
1867        };
1868
1869        for &jitter in &jitter_schedule {
1870            let mut candidate = scaled.clone();
1871            if jitter > 0.0 {
1872                let n = candidate.nrows();
1873                for i in 0..n {
1874                    candidate[[i, i]] += jitter;
1875                }
1876            }
1877
1878            match try_eigh(&candidate, side) {
1879                Ok((mut evals, evecs))
1880                    if evals.iter().all(|value| value.is_finite())
1881                        && evecs.iter().all(|value| value.is_finite()) =>
1882                {
1883                    for value in &mut evals {
1884                        *value = (*value - jitter) * scale;
1885                    }
1886                    return Ok((evals, evecs));
1887                }
1888                Ok((_, _)) => {
1889                    last_error = FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1890                        context: "self-adjoint eigendecomposition repaired output validation",
1891                    };
1892                }
1893                Err(err) => {
1894                    last_error = err;
1895                }
1896            }
1897        }
1898
1899        Err(last_error)
1900    }
1901}
1902
1903pub struct FaerCholeskyFactor {
1904    factor: solvers::Llt<f64>,
1905}
1906
1907impl FaerCholeskyFactor {
1908    pub fn solvevec(&self, rhs: &Array1<f64>) -> Array1<f64> {
1909        let mut rhs = rhs.to_owned();
1910        let mut rhsview = array1_to_col_matmut(&mut rhs);
1911        self.factor.solve_in_place(rhsview.as_mut());
1912        rhs
1913    }
1914
1915    pub fn solve_mat_in_place(&self, rhs: &mut Array2<f64>) {
1916        let mut rhsview = array2_to_matmut(rhs);
1917        self.factor.solve_in_place(rhsview.as_mut());
1918    }
1919
1920    pub fn solve_mat_into<S: Data<Elem = f64>>(
1921        &self,
1922        rhs: &ArrayBase<S, Ix2>,
1923        out: &mut Array2<f64>,
1924    ) {
1925        if out.dim() != rhs.dim() {
1926            *out = Array2::<f64>::zeros(rhs.dim());
1927        }
1928        out.assign(rhs);
1929        self.solve_mat_in_place(out);
1930    }
1931
1932    pub fn solve_mat(&self, rhs: &Array2<f64>) -> Array2<f64> {
1933        let mut out = Array2::<f64>::zeros(rhs.dim());
1934        self.solve_mat_into(rhs, &mut out);
1935        out
1936    }
1937
1938    pub fn diag(&self) -> Array1<f64> {
1939        diag_to_array(self.factor.L().diagonal())
1940    }
1941
1942    pub fn lower_triangular(&self) -> Array2<f64> {
1943        mat_to_array(self.factor.L())
1944    }
1945}
1946
1947pub trait FaerCholesky {
1948    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError>;
1949}
1950
1951impl<S: Data<Elem = f64>> FaerCholesky for ArrayBase<S, Ix2> {
1952    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError> {
1953        let faerview = FaerArrayView::new(self);
1954        let factor = faerview
1955            .as_ref()
1956            .llt(side)
1957            .map_err(FaerLinalgError::Cholesky)?;
1958        Ok(FaerCholeskyFactor { factor })
1959    }
1960}
1961
1962pub trait FaerQr {
1963    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError>;
1964}
1965
1966impl<S: Data<Elem = f64>> FaerQr for ArrayBase<S, Ix2> {
1967    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError> {
1968        let faerview = FaerArrayView::new(self);
1969        let qr = faerview.as_ref().qr();
1970        let q = qr.compute_thin_Q();
1971        let r = qr.thin_R();
1972        Ok((mat_to_array(q.as_ref()), mat_to_array(r)))
1973    }
1974}
1975
1976/// Compute an orthonormal basis for `null(a^T)` using column-pivoted QR on `a`.
1977///
1978/// This is intended for tall/skinny matrices where `a ∈ R^{m×n}` with `m >= n`.
1979/// If `A P^T = Q R`, then the trailing `m-rank(A)` columns of `Q` span
1980/// `null(A^T)`.
1981///
1982/// The trailing columns of `Q` are reconstructed by applying the stored
1983/// Householder reflector sequence to canonical basis vectors. When `A` is
1984/// numerically rank zero (e.g. an entirely unpenalized block penalty in a
1985/// parametric-only GLM), *every* reflector is degenerate — the Householder
1986/// vector of a zero column has zero norm, so faer's coefficients become
1987/// non-finite and the reconstructed basis is filled with `NaN`. Mathematically
1988/// a rank-zero `m×n` matrix has `null(A^T) = R^m`, whose canonical orthonormal
1989/// basis is the identity, so we return `I_m` directly instead of routing through
1990/// the (undefined) reflectors. This keeps every downstream consumer — REML
1991/// null-space log-determinants, identifiability audits — finite and exact for
1992/// the fully-unpenalized case. For `rank >= 1` at least one well-defined
1993/// reflector seeds the block, and the reconstruction stays finite.
1994pub fn rrqr_nullspace_basis<S: Data<Elem = f64>>(
1995    a: &ArrayBase<S, Ix2>,
1996    rank_alpha: f64,
1997) -> Result<(Array2<f64>, usize), FaerLinalgError> {
1998    let faerview = FaerArrayView::new(a);
1999    let qr = faerview.as_ref().col_piv_qr();
2000    let r = qr.thin_R();
2001    let diag_len = r.nrows().min(r.ncols());
2002    let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
2003    let tol = rank_alpha
2004        * f64::EPSILON
2005        * (a.nrows().max(a.ncols()).max(1) as f64)
2006        * leading_diag.max(1.0);
2007    let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
2008    let z = if rank >= a.nrows() {
2009        Array2::<f64>::zeros((a.nrows(), 0))
2010    } else if rank == 0 {
2011        // Numerically rank-zero input: the whole space is the null space.
2012        // Return the canonical orthonormal basis directly; the Householder
2013        // reflectors of a zero matrix are degenerate and would yield NaN.
2014        Array2::<f64>::eye(a.nrows())
2015    } else {
2016        let nullity = a.nrows() - rank;
2017        let mut selector = Mat::<f64>::zeros(a.nrows(), nullity);
2018        for j in 0..nullity {
2019            selector[(rank + j, j)] = 1.0;
2020        }
2021        let par = get_global_parallelism();
2022        faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
2023            qr.Q_basis(),
2024            qr.Q_coeff(),
2025            Conj::No,
2026            selector.as_mut(),
2027            par,
2028            MemStack::new(&mut MemBuffer::new(
2029                faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<f64>(
2030                    a.nrows(),
2031                    qr.Q_coeff().nrows(),
2032                    nullity,
2033                ),
2034            )),
2035        );
2036        mat_to_array(selector.as_ref())
2037    };
2038    Ok((z, rank))
2039}
2040
/// Default `rank_alpha` multiplier for the RRQR rank tolerance
/// `rank_alpha · eps · max(m, n, 1) · max(|R[0, 0]|, 1)` used by the
/// `rrqr_*` functions in this module; returns the module constant
/// `RRQR_RANK_ALPHA` (`100.0`).
#[inline]
pub const fn default_rrqr_rank_alpha() -> f64 {
    RRQR_RANK_ALPHA
}
2045
/// Result of a column-pivoted QR with rank detection and column permutation.
///
/// `A · P = Q · R` where the permutation `P` is exposed as the forward index
/// array: column `j` of `A · P` corresponds to original column
/// `column_permutation[j]` of `A`.
///
/// With rank `r`, the suffix `column_permutation[r..]` (length `n - r`) names
/// every column outside the rank-`r` leading block:
/// - positions `r..min(m, n)` are pivots whose `|R[i, i]|` fell at or below the
///   rank tolerance;
/// - when `n > m`, positions `min(m, n)..n` are columns the factorization never
///   gave a diagonal entry.
///
/// Together these are the columns identified as redundant. Identifiability
/// auditors (`identifiability::audit`) use that suffix to attribute
/// `DroppedColumn` entries to specific original columns.
pub struct RrqrWithPermutation {
    /// Number of diagonal entries with `|R[i, i]| > rank_tol`.
    pub rank: usize,
    /// Forward column permutation (length `n`), as described above.
    pub column_permutation: Vec<usize>,
    /// `|R[0, 0]|`, or `0.0` when `R` has no diagonal.
    pub leading_diag_abs: f64,
    /// Rank tolerance `rank_alpha · eps · max(m, n, 1) · max(|R[0, 0]|, 1)`.
    pub rank_tol: f64,
}
2062
2063/// Column-pivoted rank-revealing QR returning the rank, the column permutation,
2064/// and the rank-detection tolerance. Use this when callers need to name which
2065/// columns the pivoted QR demoted past the rank threshold.
2066///
2067/// The rank cutoff matches [`rrqr_nullspace_basis`]: a column-pivoted QR is
2068/// computed on `a`; columns with `|R[i, i]| > tol` count toward the rank,
2069/// where `tol = rank_alpha · eps · max(m, n, 1) · max(|R[0, 0]|, 1)`. Returns
2070/// `Err` when `a` has zero rows.
2071pub fn rrqr_with_permutation<S: Data<Elem = f64>>(
2072    a: &ArrayBase<S, Ix2>,
2073    rank_alpha: f64,
2074) -> Result<RrqrWithPermutation, FaerLinalgError> {
2075    if a.nrows() == 0 {
2076        return Err(FaerLinalgError::FactorizationFailed {
2077            context: "rrqr_with_permutation: input has zero rows",
2078        });
2079    }
2080    let faerview = FaerArrayView::new(a);
2081    let qr = faerview.as_ref().col_piv_qr();
2082    let r = qr.thin_R();
2083    let diag_len = r.nrows().min(r.ncols());
2084    let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
2085    let tol = rank_alpha
2086        * f64::EPSILON
2087        * (a.nrows().max(a.ncols()).max(1) as f64)
2088        * leading_diag.max(1.0);
2089    let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
2090    let (forward, _inverse) = qr.P().arrays();
2091    let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
2092    Ok(RrqrWithPermutation {
2093        rank,
2094        column_permutation,
2095        leading_diag_abs: leading_diag,
2096        rank_tol: tol,
2097    })
2098}
2099
/// Result of a Gram-driven column-pivoted RRQR (see
/// [`rrqr_from_gram_with_permutation`]). Carries the same rank / permutation /
/// tolerance as [`RrqrWithPermutation`], plus a `verdict_margin` that measures
/// how unambiguous the rank cut is.
///
/// A large margin means squaring the design into a Gram could not have flipped
/// any rank decision. A small margin means the verdict sits near the tolerance
/// cliff, or a kept pivot sits near the Gram-squaring precision floor. In that
/// case the caller should re-confirm on the full (un-squared) design to stay
/// bit-exact.
pub struct RrqrFromGram {
    /// Number of pivots with `|R[i,i]| > rank_tol`.
    pub rank: usize,
    /// Forward column permutation: entry `j` is the original column placed at
    /// pivot position `j`.
    pub column_permutation: Vec<usize>,
    /// Rank tolerance `rank_alpha · eps · max(m_rows, p, 1) · max(|R[0,0]|, 1)`
    /// (`0` for an empty Gram).
    pub rank_tol: f64,
    /// Leading pivot magnitude `|R[0,0]|` of the square-root factor — equal to
    /// the largest column norm of the original tall design (col-piv QR pivots the
    /// largest-norm column first), so it matches the tall path's
    /// `RrqrWithPermutation::leading_diag_abs`.
    pub leading_diag_abs: f64,
    /// Minimum of three ratios, each taken as `∞` when it does not apply:
    /// - `min_kept / rank_tol`: smallest kept pivot vs. the tolerance (`∞` when
    ///   rank is 0);
    /// - `rank_tol / max_dropped`: tolerance vs. the largest dropped pivot (`∞`
    ///   when no pivot was dropped; the divisor is floored at
    ///   `f64::MIN_POSITIVE`);
    /// - `min_kept / (√eps · max(|R[0,0]|, 1))`: smallest kept pivot vs. the
    ///   Gram-squaring precision floor (`∞` when rank is 0).
    ///
    /// An empty (`p = 0`) Gram reports `0`.
    pub verdict_margin: f64,
}
2121
2122/// Column-pivoted rank-revealing QR computed from the design's `p × p` Gram
2123/// `G = AᵀA` (or penalty-augmented `AᵀA + SᵀS`) instead of from the tall
2124/// `m × p` design itself.
2125///
2126/// # Why this is exact (in exact arithmetic)
2127///
2128/// Column-pivoted QR selects, at each step, the not-yet-pivoted column with the
2129/// largest residual norm, where the residual is the part orthogonal to the
2130/// already-chosen columns. Those residual norms — and the resulting pivot
2131/// sequence, the diagonal magnitudes `|R[i,i]|`, and hence the rank cut — are a
2132/// function of the column *inner products* only, i.e. of the Gram `G`. Running
2133/// col-piv QR on the Cholesky factor `R₀` of `G` (`R₀ᵀR₀ = G`, `R₀` is `p × p`)
2134/// reproduces the identical pivot order and identical `|R[i,i]|` as col-piv QR
2135/// on the original `m × p` matrix, because both see the same column geometry.
2136/// This is the standard "pivoted QR depends only on the Gram" identity and lets
2137/// the joint identifiability rank verdict run in `O(p³)` instead of streaming
2138/// all `m ≈ 2·10⁵` rows again.
2139///
2140/// # Tolerance
2141///
2142/// The rank cutoff must match what the tall-matrix [`rrqr_with_permutation`]
2143/// would have used, so the caller passes `m_rows` (the row count of the
2144/// original tall design, including any appended penalty rows). The tolerance is
2145/// `rank_alpha · eps · max(m_rows, p) · max(|R[0,0]|, 1)` — bit-identical to the
2146/// tall path, since `|R[0,0]|` (the leading pivot magnitude = largest column
2147/// norm) is the same in both factorizations.
2148///
2149/// # Finite-precision guard
2150///
2151/// Forming `G = AᵀA` squares the condition number, so a rank decision that sits
2152/// right at the tolerance cliff could in principle flip. The returned
2153/// `verdict_margin` lets the caller detect that case and fall back to the exact
2154/// tall RRQR; in the overwhelmingly common well-separated case (full column
2155/// rank, smallest pivot orders of magnitude above tol) the margin is huge and
2156/// no fallback is needed.
2157pub fn rrqr_from_gram_with_permutation<S: Data<Elem = f64>>(
2158    gram: &ArrayBase<S, Ix2>,
2159    m_rows: usize,
2160    rank_alpha: f64,
2161) -> Result<RrqrFromGram, FaerLinalgError> {
2162    let p = gram.ncols();
2163    if p == 0 {
2164        return Ok(RrqrFromGram {
2165            rank: 0,
2166            column_permutation: Vec::new(),
2167            rank_tol: 0.0,
2168            leading_diag_abs: 0.0,
2169            verdict_margin: 0.0,
2170        });
2171    }
2172    if gram.nrows() != p {
2173        return Err(FaerLinalgError::FactorizationFailed {
2174            context: "rrqr_from_gram_with_permutation: Gram is not square",
2175        });
2176    }
2177    // Symmetric square-root factor F (p×p) with FᵀF = G. The Gram is PSD by
2178    // construction (AᵀA), so its eigendecomposition G = V·diag(λ)·Vᵀ gives the
2179    // factor F = diag(√λ₊)·Vᵀ (rows indexed by eigenpair, columns by original
2180    // design column). Any factor with FᵀF = G reproduces the same column
2181    // geometry, which is all col-piv QR consumes — we use the eigen square root
2182    // rather than a bare Cholesky because Cholesky fails on the numerically
2183    // semidefinite Gram that is exactly the rank-deficient case we must classify.
2184    // Tiny-negative eigenvalues from finite precision are clamped to zero.
2185    let (evals, evecs) = gram.eigh(Side::Lower)?;
2186    let mut f = Array2::<f64>::zeros((p, p));
2187    for k in 0..p {
2188        let scale = evals[k].max(0.0).sqrt();
2189        if scale == 0.0 {
2190            continue;
2191        }
2192        for i in 0..p {
2193            f[[k, i]] = scale * evecs[[i, k]];
2194        }
2195    }
2196    // Single col-piv QR on F. Its pivot order, per-pivot |R[i,i]| magnitudes,
2197    // and leading pivot equal those of col-piv QR on the original tall design
2198    // (FᵀF = G), so this reproduces the exact tall-path geometry.
2199    let faer_f = FaerArrayView::new(&f);
2200    let qr = faer_f.as_ref().col_piv_qr();
2201    let r = qr.thin_R();
2202    let diag_len = r.nrows().min(r.ncols());
2203    let pivots: Vec<f64> = (0..diag_len).map(|i| r[(i, i)].abs()).collect();
2204    let leading_diag = pivots.first().copied().unwrap_or(0.0);
2205    let (forward, _inverse) = qr.P().arrays();
2206    let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
2207    // Re-scale the tolerance from F's `max(p, p)=p` row dimension to the
2208    // original tall design's `max(m_rows, p)`, keeping the rank cut bit-
2209    // identical to what the tall [`rrqr_with_permutation`] would have produced.
2210    let tol = rank_alpha * f64::EPSILON * (m_rows.max(p).max(1) as f64) * leading_diag.max(1.0);
2211    let rank = pivots.iter().filter(|&&v| v > tol).count();
2212    let min_kept = pivots[..rank].iter().copied().fold(f64::INFINITY, f64::min);
2213    let max_dropped = pivots[rank..].iter().copied().fold(0.0f64, f64::max);
2214    // Margin: how far the verdict is from the cliff. Use the smaller of
2215    // (min_kept / tol) and (tol / max_dropped) so a near-tol dropped pivot also
2216    // shrinks the margin. A margin ≫ 1 means no rank decision could flip.
2217    let kept_margin = if rank == 0 {
2218        f64::INFINITY
2219    } else {
2220        min_kept / tol
2221    };
2222    let dropped_margin = if rank == diag_len {
2223        f64::INFINITY
2224    } else {
2225        tol / max_dropped.max(f64::MIN_POSITIVE)
2226    };
2227    // Gram-squaring precision floor. Forming `G = XᵀX` collapses the bottom half
2228    // of the spectrum: a true singular value below `√ε · σ_max` is lost in the
2229    // rounding of `G` (its squared value `σ² < ε·σ_max²` underflows the Gram's
2230    // representable range), and the eigen-square-root then RESURRECTS it as a
2231    // SPURIOUS pivot of magnitude `≈ √(ε·σ_max²) = √ε · σ_max` — orders of
2232    // magnitude ABOVE the true σ and above `tol`. That artefact makes col-piv QR
2233    // on `F` KEEP a column the tall (un-squared) QR would demote: an EXACTLY
2234    // collinear alias (true σ = 0, so `σ² = 0` floored at `≈ ε·σ_max²`) shows up
2235    // as a kept pivot near `√ε · leading`, over-ranking the design and dropping
2236    // nothing (gam#933: a callback-owned column aliased with a higher-priority
2237    // anchor was never demoted, so the reduction never ran and the MAP-uniqueness
2238    // check then fired on the raw collinear joint design). `min_kept / tol` does
2239    // NOT catch this — the spurious pivot sits comfortably above `tol`, so the
2240    // existing margin reports a falsely-confident verdict. The honest test is
2241    // whether the smallest KEPT pivot is itself near the Gram precision floor
2242    // `√ε · leading`: if so, the Gram path cannot distinguish it from a true zero
2243    // and the verdict MUST be re-confirmed on the full-precision tall design.
2244    // Encode that as a third margin term `min_kept / (√ε · leading)` so a kept
2245    // pivot in the floor regime shrinks `verdict_margin` below the caller's
2246    // fallback threshold; for a genuinely full-rank design every kept pivot is
2247    // `≫ √ε · leading` and this term is large, leaving the fast path intact.
2248    let gram_precision_floor = f64::EPSILON.sqrt() * leading_diag.max(1.0);
2249    let kept_floor_margin = if rank == 0 {
2250        f64::INFINITY
2251    } else {
2252        min_kept / gram_precision_floor.max(f64::MIN_POSITIVE)
2253    };
2254    let verdict_margin = kept_margin.min(dropped_margin).min(kept_floor_margin);
2255    Ok(RrqrFromGram {
2256        rank,
2257        column_permutation,
2258        rank_tol: tol,
2259        leading_diag_abs: leading_diag,
2260        verdict_margin,
2261    })
2262}
2263
2264#[cfg(test)]
2265mod tests {
2266    use super::*;
2267    use ndarray::{array, s};
2268
2269    /// Local mirror of the audit's `JOINT_GRAM_RRQR_MIN_VERDICT_MARGIN` fallback
2270    /// threshold, used only by the regression tests below to assert the verdict
2271    /// margin lands on the correct side of the cliff. Kept in sync by value (1e3).
2272    const JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST: f64 = 1.0e3;
2273
2274    #[test]
2275    fn rrqr_nullspace_basis_is_orthonormal_and_annihilates_transpose() {
2276        let a = array![[1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 0.0],];
2277        let (z, rank) =
2278            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2279        assert_eq!(rank, 2);
2280        assert_eq!(z.nrows(), 4);
2281        assert_eq!(z.ncols(), 2);
2282
2283        let gram = z.t().dot(&z);
2284        let ident = Array2::<f64>::eye(z.ncols());
2285        let gram_err = (&gram - &ident)
2286            .iter()
2287            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2288        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
2289
2290        let residual = a.t().dot(&z);
2291        let resid_max = residual.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2292        assert!(resid_max < 1e-10, "A^T Z residual too large: {resid_max:e}");
2293    }
2294
2295    #[test]
2296    fn rrqr_with_permutation_attributes_redundant_column() {
2297        // 3 columns, column 2 is a duplicate of column 0 → rank 2, column 2
2298        // is the redundant one that the pivoted QR should demote past the
2299        // rank threshold. (Column 1 contributes a different direction.)
2300        let a = array![
2301            [1.0, 0.0, 1.0],
2302            [1.0, 0.0, 1.0],
2303            [0.0, 2.0, 0.0],
2304            [0.0, 0.0, 0.0],
2305        ];
2306        let result =
2307            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2308        assert_eq!(result.rank, 2);
2309        assert_eq!(result.column_permutation.len(), 3);
2310        let demoted = result.column_permutation[result.rank..].to_vec();
2311        assert!(
2312            demoted.contains(&2) || demoted.contains(&0),
2313            "demoted suffix should include one of the aliased columns (0 or 2), got {demoted:?}"
2314        );
2315        let mut sorted = result.column_permutation.clone();
2316        sorted.sort();
2317        assert_eq!(
2318            sorted,
2319            vec![0, 1, 2],
2320            "permutation must be a valid bijection on 0..n"
2321        );
2322    }
2323
2324    #[test]
2325    fn rrqr_with_permutation_full_rank_returns_identity_like_order() {
2326        let a = array![[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
2327        let result =
2328            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2329        assert_eq!(result.rank, 2);
2330        let mut sorted = result.column_permutation.clone();
2331        sorted.sort();
2332        assert_eq!(sorted, vec![0, 1]);
2333    }
2334
2335    #[test]
2336    fn rrqr_with_permutation_rejects_zero_rows() {
2337        let a = Array2::<f64>::zeros((0, 3));
2338        assert!(rrqr_with_permutation(&a, default_rrqr_rank_alpha()).is_err());
2339    }
2340
2341    #[test]
2342    fn rrqr_nullspace_basis_square_zero_matrix_is_finite_identity() {
2343        // Square zero matrix (the parametric-only penalty case): null(A^T) is
2344        // the whole space, so the basis must be a finite orthonormal 3x3 set.
2345        let a = Array2::<f64>::zeros((3, 3));
2346        let (z, rank) =
2347            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2348        assert_eq!(rank, 0);
2349        assert_eq!(z.dim(), (3, 3));
2350        assert!(
2351            z.iter().all(|v| v.is_finite()),
2352            "square zero matrix produced a non-finite null basis: {z:?}"
2353        );
2354        let gram = z.t().dot(&z);
2355        let ident = Array2::<f64>::eye(3);
2356        let gram_err = (&gram - &ident)
2357            .iter()
2358            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2359        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
2360    }
2361
2362    #[test]
2363    fn rrqr_nullspace_basis_detectszero_rank_matrix() {
2364        let a = Array2::<f64>::zeros((5, 2));
2365        let (z, rank) =
2366            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2367        assert_eq!(rank, 0);
2368        assert_eq!(z.dim(), (5, 5));
2369        let ident = Array2::<f64>::eye(5);
2370        let max_err = (&z.slice(s![.., ..5]).to_owned() - &ident)
2371            .iter()
2372            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2373        assert!(max_err < 1e-10, "zero matrix should yield identity basis");
2374    }
2375
2376    //
2377    // Eigendecomposition NoConvergence on pathological matrices
2378    //
2379    // These tests lock down the hardened contract for FaerEigh::eigh:
2380    // non-finite input must be rejected explicitly, while finite symmetric
2381    // matrices still produce finite spectra.
2382    //
2383
2384    #[test]
2385    fn eigh_on_nan_matrix_rejects_non_finite_input() {
2386        let mat = array![
2387            [1.0, 0.0, 0.0, 0.0],
2388            [0.0, 2.0, 0.0, 0.0],
2389            [0.0, 0.0, 3.0, f64::NAN],
2390            [0.0, 0.0, f64::NAN, 4.0]
2391        ];
2392        let err = mat
2393            .eigh(Side::Lower)
2394            .expect_err("non-finite symmetric input must be rejected");
2395        assert!(matches!(
2396            err,
2397            FaerLinalgError::SelfAdjointEigenNonFiniteInput { .. }
2398        ));
2399    }
2400
2401    #[test]
2402    fn fast_ata_matches_full_gemm_above_threshold() {
2403        // Pick (n, p) large enough to trigger the faer triangular path
2404        // (should_use_faer_matmul threshold is MIN_DIM=32, MIN_FLOP_SCALE=64*64).
2405        let n = 200;
2406        let p = 40;
2407        let a: Array2<f64> = Array2::from_shape_fn((n, p), |(i, j)| {
2408            ((i * 7 + j * 3) as f64).sin() + 0.1 * j as f64
2409        });
2410        let expected = a.t().dot(&a);
2411        let got = fast_ata(&a);
2412        let max_err = (&got - &expected)
2413            .iter()
2414            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2415        assert!(max_err < 1e-10, "fast_ata mismatch: {max_err:e}");
2416        // Output must be fully populated and symmetric.
2417        for i in 0..p {
2418            for j in 0..p {
2419                assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
2420            }
2421        }
2422    }
2423
2424    #[test]
2425    fn fast_xt_diag_x_matches_naive_above_threshold() {
2426        let n = 400;
2427        let p = 36;
2428        let x: Array2<f64> =
2429            Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.1).cos() + j as f64 * 0.05);
2430        let w: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64 * 0.03).sin());
2431        // Naive reference: X^T diag(w) X.
2432        let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
2433        let expected = x.t().dot(&wx);
2434        let got = fast_xt_diag_x(&x, &w);
2435        let max_err = (&got - &expected)
2436            .iter()
2437            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2438        assert!(max_err < 1e-9, "fast_xt_diag_x mismatch: {max_err:e}");
2439        for i in 0..p {
2440            for j in 0..p {
2441                assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
2442            }
2443        }
2444    }
2445
2446    #[test]
2447    fn stream_weighted_crossprod_full_and_triangular_parity_with_negative_weights() {
2448        // The stream-in and matrix-returning `fast_xt_diag_x*` packaging modes
2449        // share one kernel. Both packaging modes — and both accumulation
2450        // modes — must reproduce the naive `Xᵀ·diag(w)·X` reference, including signed
2451        // (negative) weights, which the pre-unification sqrt-clip form
2452        // silently corrupted.
2453        //
2454        // Exercise both the streaming faer path (n large enough to clear
2455        // `should_use_faer_matmul`) and the tiny ndarray fallback (small n,p).
2456        for &(n, p) in &[(900usize, 40usize), (8usize, 3usize)] {
2457            let x: Array2<f64> =
2458                Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.07).cos() + j as f64 * 0.013);
2459            // Weights span both signs and zero so negative-weight handling and
2460            // sign preservation are genuinely tested.
2461            let w: Array1<f64> =
2462                Array1::from_shape_fn(n, |i| (i as f64 * 0.11).sin() - 0.25 * (i % 3) as f64);
2463            assert!(
2464                w.iter().any(|&v| v < 0.0),
2465                "weight vector must contain negatives to test sign preservation"
2466            );
2467
2468            // Naive reference: Xᵀ diag(w) X with signed weights.
2469            let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
2470            let expected = x.t().dot(&wx);
2471
2472            let par = matmul_parallelism(p, p, n);
2473
2474            // Full output, Replace.
2475            let mut full = Array2::<f64>::ones((p, p));
2476            stream_weighted_crossprod_into(
2477                &x,
2478                &w,
2479                &mut full,
2480                CrossprodStructure::Full,
2481                CrossprodAccum::Replace,
2482                par,
2483            );
2484
2485            // Triangular+mirror output, Replace. Seed with garbage to prove
2486            // Replace clears prior contents (incl. the upper triangle, which
2487            // the triangular path only reaches via the mirror).
2488            let mut tri = Array2::<f64>::from_elem((p, p), -7.0);
2489            stream_weighted_crossprod_into(
2490                &x,
2491                &w,
2492                &mut tri,
2493                CrossprodStructure::SymmetricLower,
2494                CrossprodAccum::Replace,
2495                par,
2496            );
2497
2498            let full_err = (&full - &expected)
2499                .iter()
2500                .fold(0.0_f64, |a, &v| a.max(v.abs()));
2501            let tri_err = (&tri - &expected)
2502                .iter()
2503                .fold(0.0_f64, |a, &v| a.max(v.abs()));
2504            assert!(
2505                full_err < 1e-9,
2506                "full kernel mismatch (n={n}, p={p}): {full_err:e}"
2507            );
2508            assert!(
2509                tri_err < 1e-9,
2510                "triangular kernel mismatch (n={n}, p={p}): {tri_err:e}"
2511            );
2512
2513            // Full and triangular packaging must agree elementwise, and both
2514            // must be exactly symmetric.
2515            for i in 0..p {
2516                for j in 0..p {
2517                    assert!(
2518                        (full[[i, j]] - tri[[i, j]]).abs() < 1e-12,
2519                        "full vs triangular disagree at ({i},{j})"
2520                    );
2521                    assert!(
2522                        (tri[[i, j]] - tri[[j, i]]).abs() < 1e-12,
2523                        "triangular output not symmetric at ({i},{j})"
2524                    );
2525                }
2526            }
2527
2528            // Accumulation parity: Add into a pre-filled buffer must equal the
2529            // prior contents plus the Gram, for both structures.
2530            let base = Array2::<f64>::from_elem((p, p), 1.5);
2531            let mut add_full = base.clone();
2532            stream_weighted_crossprod_into(
2533                &x,
2534                &w,
2535                &mut add_full,
2536                CrossprodStructure::Full,
2537                CrossprodAccum::Add,
2538                par,
2539            );
2540            let mut add_tri = base.clone();
2541            stream_weighted_crossprod_into(
2542                &x,
2543                &w,
2544                &mut add_tri,
2545                CrossprodStructure::SymmetricLower,
2546                CrossprodAccum::Add,
2547                par,
2548            );
2549            let expected_add = &base + &expected;
2550            let add_full_err = (&add_full - &expected_add)
2551                .iter()
2552                .fold(0.0_f64, |a, &v| a.max(v.abs()));
2553            let add_tri_err = (&add_tri - &expected_add)
2554                .iter()
2555                .fold(0.0_f64, |a, &v| a.max(v.abs()));
2556            assert!(
2557                add_full_err < 1e-9,
2558                "full Add mismatch (n={n}, p={p}): {add_full_err:e}"
2559            );
2560            assert!(
2561                add_tri_err < 1e-9,
2562                "triangular Add mismatch (n={n}, p={p}): {add_tri_err:e}"
2563            );
2564
2565            // The matrix.rs adapter (Full + Replace into a zeroed buffer) must
2566            // match the faer_ndarray return-style adapter bit-for-functionally.
2567            let returned = fast_xt_diag_x(&x, &w);
2568            let returned_err = (&returned - &full)
2569                .iter()
2570                .fold(0.0_f64, |a, &v| a.max(v.abs()));
2571            assert!(
2572                returned_err < 1e-12,
2573                "return adapter vs stream-into adapter disagree (n={n}, p={p}): {returned_err:e}"
2574            );
2575        }
2576    }
2577
2578    #[test]
2579    fn eigh_succeeds_on_same_structure_without_nan() {
2580        // Control: the same matrix with finite values produces finite eigenvalues.
2581        let mat = array![[1.0, 0.5, 0.1], [0.5, 2.0, 0.3], [0.1, 0.3, 1.5]];
2582        let (evals, _) = mat
2583            .eigh(Side::Lower)
2584            .expect("eigh should succeed on a well-conditioned finite matrix");
2585        assert!(
2586            evals.iter().all(|&v| v.is_finite()),
2587            "all eigenvalues should be finite"
2588        );
2589    }
2590
2591    /// gam#933 regression: the Gram-squared RRQR must NOT silently over-rank an
2592    /// EXACTLY collinear design. The invariant is: either the Gram path finds the
2593    /// correct rank (3) by itself — because the precision-floor logic demotes the
2594    /// spurious near-zero pivot before it reaches the kept set — OR, if it
2595    /// over-ranks (reports 4), the `verdict_margin` must collapse below the
2596    /// caller's fallback threshold so the full-precision tall path is used
2597    /// instead. Both outcomes prevent the original gam#933 bug (silent rank=4
2598    /// with high-confidence margin that the caller trusts without verification).
2599    #[test]
2600    fn gram_rrqr_flags_low_margin_on_exact_collinearity_so_caller_falls_back() {
2601        // Joint design [1, x | x, x²] with x ∈ [-1, 1]: columns 1 and 2 are an
2602        // EXACT duplicate (the #933 callback-owned alias), so the true rank is 3.
2603        let n = 48usize;
2604        let x: Vec<f64> = (0..n)
2605            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
2606            .collect();
2607        let mut a = Array2::<f64>::zeros((n, 4));
2608        for i in 0..n {
2609            a[[i, 0]] = 1.0;
2610            a[[i, 1]] = x[i];
2611            a[[i, 2]] = x[i];
2612            a[[i, 3]] = x[i] * x[i];
2613        }
2614        let alpha = default_rrqr_rank_alpha();
2615
2616        // The tall (un-squared) RRQR is the full-precision reference: it must see
2617        // rank 3 and demote one of the duplicate x columns.
2618        let tall = rrqr_with_permutation(&a, alpha).expect("tall RRQR should succeed");
2619        assert_eq!(tall.rank, 3, "tall RRQR must demote the exact alias");
2620
2621        // The Gram-squared RRQR must satisfy the gam#933 invariant:
2622        //   rank == 3 (correct result)  OR  verdict_margin < threshold (force fallback)
2623        //
2624        // The precision-floor margin term was designed to catch the case where
2625        // squaring the spectrum resurrects a spurious kept pivot near √ε·σ_max.
2626        // When the eigen-square-root approach correctly demotes that pivot
2627        // (yielding rank=3 without spurious kept columns), the margin is
2628        // legitimately high — trusting the Gram result is then safe and correct.
2629        // When it over-ranks (rank=4), the floor margin must be low so the
2630        // caller falls back to the tall RRQR and gets the right answer.
2631        let unit = Array1::<f64>::ones(n);
2632        let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
2633        let gram_rrqr =
2634            rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
2635        let ok = gram_rrqr.rank == 3
2636            || gram_rrqr.verdict_margin < JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST;
2637        assert!(
2638            ok,
2639            "gam#933: Gram RRQR must either find correct rank=3 OR signal low margin \
2640             (< {:.0e}) to force the tall fallback; got rank={} margin={:.3e}",
2641            JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
2642            gram_rrqr.rank,
2643            gram_rrqr.verdict_margin,
2644        );
2645    }
2646
2647    /// Companion to the regression above: a genuinely full-rank, moderately
2648    /// conditioned design must keep a LARGE Gram verdict margin so the fast Gram
2649    /// path is retained (the precision-floor term must not trip on real, small-
2650    /// but-nonzero singular values).
2651    #[test]
2652    fn gram_rrqr_keeps_high_margin_on_full_rank_design() {
2653        let n = 200usize;
2654        let p = 5usize;
2655        let mut a = Array2::<f64>::zeros((n, p));
2656        // Deterministic, well-separated columns (distinct low-order polynomials).
2657        for i in 0..n {
2658            let t = (i as f64) / (n as f64 - 1.0);
2659            a[[i, 0]] = 1.0;
2660            a[[i, 1]] = t;
2661            a[[i, 2]] = t * t;
2662            a[[i, 3]] = t * t * t;
2663            a[[i, 4]] = (t * 6.0).sin();
2664        }
2665        let alpha = default_rrqr_rank_alpha();
2666        let unit = Array1::<f64>::ones(n);
2667        let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
2668        let gram_rrqr =
2669            rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
2670        assert_eq!(gram_rrqr.rank, p, "full-rank design must keep all columns");
2671        assert!(
2672            gram_rrqr.verdict_margin >= JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
2673            "full-rank design must keep a high margin (fast Gram path); got {:.3e}",
2674            gram_rrqr.verdict_margin,
2675        );
2676    }
2677
2678    // ── fast_ab / fast_atb / fast_abt / fast_av / fast_atv / fast_xt_diag_y ──
2679
2680    fn max_abs_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
2681        assert_eq!(a.dim(), b.dim(), "shape mismatch in max_abs_diff");
2682        a.iter().zip(b.iter()).fold(0.0_f64, |acc, (&x, &y)| acc.max((x - y).abs()))
2683    }
2684
2685    fn max_abs_diff_1d(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
2686        assert_eq!(a.len(), b.len(), "len mismatch in max_abs_diff_1d");
2687        a.iter().zip(b.iter()).fold(0.0_f64, |acc, (&x, &y)| acc.max((x - y).abs()))
2688    }
2689
2690    /// `fast_ab(A, B)` matches `A.dot(&B)` for small (ndarray-path) matrices.
2691    #[test]
2692    fn fast_ab_small_matches_ndarray_dot() {
2693        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
2694        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
2695        let got = fast_ab(&a, &b);
2696        let want = a.dot(&b);
2697        assert!(max_abs_diff(&got, &want) < 1e-12, "fast_ab small mismatch");
2698        assert_eq!(got.dim(), (2, 2));
2699    }
2700
2701    /// `fast_ab` on larger matrices (faer path) agrees with ndarray dot.
2702    #[test]
2703    fn fast_ab_large_matches_ndarray_dot() {
2704        let n = 50usize;
2705        let p = 40usize;
2706        let q = 35usize;
2707        let mut a = Array2::<f64>::zeros((n, p));
2708        let mut b = Array2::<f64>::zeros((p, q));
2709        let mut state = 0xDEAD_BEEF_1234_5678u64;
2710        let next = |s: &mut u64| -> f64 {
2711            *s ^= *s << 13;
2712            *s ^= *s >> 7;
2713            *s ^= *s << 17;
2714            ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2715        };
2716        for v in a.iter_mut() { *v = next(&mut state); }
2717        for v in b.iter_mut() { *v = next(&mut state); }
2718        let got = fast_ab(&a, &b);
2719        let want = a.dot(&b);
2720        assert!(max_abs_diff(&got, &want) < 1e-9, "fast_ab large mismatch");
2721    }
2722
2723    /// `fast_atb(A, B)` = A^T * B for small matrices (ndarray path).
2724    #[test]
2725    fn fast_atb_small_matches_ndarray_dot() {
2726        let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
2727        let b = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];
2728        let got = fast_atb(&a, &b);
2729        let want = a.t().dot(&b);
2730        assert!(max_abs_diff(&got, &want) < 1e-12, "fast_atb small mismatch");
2731        assert_eq!(got.dim(), (2, 3));
2732    }
2733
2734    /// `fast_atb` on larger matrices (faer path) agrees with ndarray.
2735    #[test]
2736    fn fast_atb_large_matches_ndarray_dot() {
2737        let n = 50usize;
2738        let p = 40usize;
2739        let q = 35usize;
2740        let mut a = Array2::<f64>::zeros((n, p));
2741        let mut b = Array2::<f64>::zeros((n, q));
2742        let mut state = 0xCAFE_BABE_9876_5432u64;
2743        let next = |s: &mut u64| -> f64 {
2744            *s ^= *s << 13;
2745            *s ^= *s >> 7;
2746            *s ^= *s << 17;
2747            ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2748        };
2749        for v in a.iter_mut() { *v = next(&mut state); }
2750        for v in b.iter_mut() { *v = next(&mut state); }
2751        let got = fast_atb(&a, &b);
2752        let want = a.t().dot(&b);
2753        assert!(max_abs_diff(&got, &want) < 1e-9, "fast_atb large mismatch");
2754    }
2755
2756    /// `fast_abt(A, B)` = A * B^T for small matrices (ndarray path).
2757    #[test]
2758    fn fast_abt_small_matches_ndarray_dot() {
2759        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
2760        let b = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
2761        let got = fast_abt(&a, &b);
2762        let want = a.dot(&b.t());
2763        assert!(max_abs_diff(&got, &want) < 1e-12, "fast_abt small mismatch");
2764        assert_eq!(got.dim(), (2, 2));
2765    }
2766
2767    /// `fast_av(A, v)` = A * v for small (ndarray path) and larger (faer path).
2768    #[test]
2769    fn fast_av_small_matches_ndarray_dot() {
2770        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
2771        let v = array![1.0, -1.0, 2.0];
2772        let got = fast_av(&a, &v);
2773        let want = a.dot(&v);
2774        assert!(max_abs_diff_1d(&got, &want) < 1e-12, "fast_av small mismatch");
2775        // 1*1 + 2*(-1) + 3*2 = 1-2+6 = 5
2776        assert!((got[0] - 5.0).abs() < 1e-12, "fast_av[0] should be 5");
2777        // 4*1 + 5*(-1) + 6*2 = 4-5+12 = 11
2778        assert!((got[1] - 11.0).abs() < 1e-12, "fast_av[1] should be 11");
2779    }
2780
2781    /// `fast_av` on larger matrices (faer path) agrees with ndarray.
2782    #[test]
2783    fn fast_av_large_matches_ndarray_dot() {
2784        let n = 50usize;
2785        let p = 40usize;
2786        let mut a = Array2::<f64>::zeros((n, p));
2787        let mut v = Array1::<f64>::zeros(p);
2788        let mut state = 0xFEED_FACE_ABCD_EF01u64;
2789        let next = |s: &mut u64| -> f64 {
2790            *s ^= *s << 13;
2791            *s ^= *s >> 7;
2792            *s ^= *s << 17;
2793            ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2794        };
2795        for v in a.iter_mut() { *v = next(&mut state); }
2796        for x in v.iter_mut() { *x = next(&mut state); }
2797        let got = fast_av(&a, &v);
2798        let want = a.dot(&v);
2799        assert!(max_abs_diff_1d(&got, &want) < 1e-9, "fast_av large mismatch");
2800    }
2801
2802    /// `fast_atv(A, v)` = A^T * v for small matrices (ndarray path).
2803    #[test]
2804    fn fast_atv_small_matches_ndarray_dot() {
2805        let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
2806        let v = array![1.0, 0.0, -1.0];
2807        let got = fast_atv(&a, &v);
2808        let want = a.t().dot(&v);
2809        // A^T * v = [1*1+3*0+5*(-1), 2*1+4*0+6*(-1)] = [-4, -4]
2810        assert!(max_abs_diff_1d(&got, &want) < 1e-12, "fast_atv small mismatch");
2811        assert!((got[0] - (-4.0)).abs() < 1e-12, "fast_atv[0]");
2812        assert!((got[1] - (-4.0)).abs() < 1e-12, "fast_atv[1]");
2813    }
2814
2815    /// `fast_atv` on larger matrices (faer path) agrees with ndarray.
2816    #[test]
2817    fn fast_atv_large_matches_ndarray_dot() {
2818        let n = 50usize;
2819        let p = 40usize;
2820        let mut a = Array2::<f64>::zeros((n, p));
2821        let mut v = Array1::<f64>::zeros(n);
2822        let mut state = 0x1234_ABCD_5678_EF90u64;
2823        let next = |s: &mut u64| -> f64 {
2824            *s ^= *s << 13;
2825            *s ^= *s >> 7;
2826            *s ^= *s << 17;
2827            ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2828        };
2829        for x in a.iter_mut() { *x = next(&mut state); }
2830        for x in v.iter_mut() { *x = next(&mut state); }
2831        let got = fast_atv(&a, &v);
2832        let want = a.t().dot(&v);
2833        assert!(max_abs_diff_1d(&got, &want) < 1e-9, "fast_atv large mismatch");
2834    }
2835
2836    /// `fast_xt_diag_y(X, d, Y)` = X^T * diag(d) * Y, verified against
2837    /// a manual triple-product for small inputs.
2838    #[test]
2839    fn fast_xt_diag_y_small_matches_manual() {
2840        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
2841        let d = array![2.0, 0.5, 1.0];
2842        let y = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];
2843        let got = fast_xt_diag_y(&x, &d, &y);
2844        // Manual: X^T * diag(d) * Y
2845        let diag_y = {
2846            let mut dy = Array2::<f64>::zeros(y.dim());
2847            for i in 0..3 {
2848                for j in 0..3 {
2849                    dy[[i, j]] = d[i] * y[[i, j]];
2850                }
2851            }
2852            dy
2853        };
2854        let want = x.t().dot(&diag_y);
2855        assert!(max_abs_diff(&got, &want) < 1e-12, "fast_xt_diag_y small mismatch");
2856        assert_eq!(got.dim(), (2, 3));
2857    }
2858
2859    // ── Compensated-reduction accuracy oracle ────────────────────────────
2860    //
2861    // Truth is an error-free (exact-expansion / double-double) reference. We
2862    // assert the production GEMV kernels are pointwise no less accurate than —
2863    // and in aggregate strictly better than — a naive sequential sum.
2864
2865    #[inline]
2866    fn two_prod(a: f64, b: f64) -> (f64, f64) {
2867        let p = a * b;
2868        let e = a.mul_add(b, -p);
2869        (p, e)
2870    }
2871
2872    #[inline]
2873    fn two_sum(a: f64, b: f64) -> (f64, f64) {
2874        let s = a + b;
2875        let bb = s - a;
2876        let e = (a - (s - bb)) + (b - bb);
2877        (s, e)
2878    }
2879
2880    /// Shewchuk grow-expansion: add `q` to the non-overlapping expansion `e`.
2881    fn grow_expansion(e: &mut Vec<f64>, mut q: f64) {
2882        for h in e.iter_mut() {
2883            let (s, err) = two_sum(*h, q);
2884            *h = err;
2885            q = s;
2886        }
2887        if q != 0.0 {
2888            e.push(q);
2889        }
2890    }
2891
2892    /// Exact dot product (correctly rounded to `f64`) via an error-free
2893    /// expansion of every `two_prod` component. O(n²) — for short reference
2894    /// vectors only — but a true gold standard, strictly more precise than any
2895    /// double-precision accumulator under test.
2896    fn exact_dot(a: &[f64], b: &[f64]) -> f64 {
2897        let mut e: Vec<f64> = Vec::new();
2898        for (&x, &y) in a.iter().zip(b.iter()) {
2899            let (p, ep) = two_prod(x, y);
2900            grow_expansion(&mut e, p);
2901            grow_expansion(&mut e, ep);
2902        }
2903        // Components are non-overlapping and ascending in magnitude; summing
2904        // smallest-first yields the correctly rounded total.
2905        e.iter().fold(0.0f64, |acc, &c| acc + c)
2906    }
2907
2908    /// High-precision reference dot via compensated (double-double) summation.
2909    /// Cheap (O(n)) — used where naive's error is enormous so ~2u precision is
2910    /// already far more accurate than the baseline under test.
2911    fn dd_dot(a: &[f64], b: &[f64]) -> f64 {
2912        let (mut s, mut c) = (0.0f64, 0.0f64);
2913        for (&x, &y) in a.iter().zip(b.iter()) {
2914            let (p, ep) = two_prod(x, y);
2915            let (s2, es) = two_sum(s, p);
2916            s = s2;
2917            c += ep + es;
2918        }
2919        s + c
2920    }
2921
2922    fn naive_dot(a: &[f64], b: &[f64]) -> f64 {
2923        let mut acc = 0.0f64;
2924        for (&x, &y) in a.iter().zip(b.iter()) {
2925            acc += x * y;
2926        }
2927        acc
2928    }
2929
2930    /// Catastrophic-cancellation generator: large opposing terms plus small
2931    /// ones, so the naive running sum loses many bits to cancellation.
2932    fn ill_conditioned_pair(len: usize, seed: u64) -> (Vec<f64>, Vec<f64>) {
2933        let mut s = seed | 1;
2934        let mut next = || {
2935            s ^= s << 13;
2936            s ^= s >> 7;
2937            s ^= s << 17;
2938            (s >> 11) as f64 / ((1u64 << 53) as f64) - 0.5
2939        };
2940        let mut a = Vec::with_capacity(len);
2941        let mut b = Vec::with_capacity(len);
2942        for i in 0..len {
2943            // Span ~16 orders of magnitude with alternating signs.
2944            let scale = 10f64.powi((i % 17) as i32 - 8);
2945            let sign = if i % 2 == 0 { 1.0 } else { -1.0 };
2946            a.push(sign * next() * scale);
2947            b.push(next() * scale);
2948        }
2949        (a, b)
2950    }
2951
2952    /// `fma_dot` (compensated Dot2) error-vs-truth never exceeds the naive
2953    /// sum's and is strictly lower on the ill-conditioned ensemble in aggregate.
2954    #[test]
2955    fn fma_dot_beats_naive_accuracy() {
2956        let mut fma_total = 0.0f64;
2957        let mut naive_total = 0.0f64;
2958        let mut strict_wins = 0;
2959        for seed in 0..64u64 {
2960            let len = 200 + (seed as usize % 57);
2961            let (a, b) = ill_conditioned_pair(len, 0x9E37_79B9 ^ seed.wrapping_mul(2654435761));
2962            let truth = exact_dot(&a, &b);
2963            let fe = (super::fma_dot(&a, &b) - truth).abs();
2964            let ne = (naive_dot(&a, &b) - truth).abs();
2965            // Compensated (Dot2) summation is pointwise no less accurate than
2966            // the naive recurrence. The floor term tolerates a few-ulp tie when
2967            // both already sit at the round-to-nearest limit (well-conditioned).
2968            let floor = 8.0 * f64::EPSILON * truth.abs();
2969            assert!(
2970                fe <= ne * (1.0 + 1e-6) + floor,
2971                "fma_dot worse than naive: seed={seed} fma_err={fe:.3e} naive_err={ne:.3e}",
2972            );
2973            if fe < ne {
2974                strict_wins += 1;
2975            }
2976            fma_total += fe;
2977            naive_total += ne;
2978        }
2979        assert!(
2980            fma_total < naive_total,
2981            "fma_dot aggregate error {fma_total:.3e} not below naive {naive_total:.3e}",
2982        );
2983        assert!(
2984            strict_wins >= 40,
2985            "expected fma_dot to strictly win the majority; only {strict_wins}/64",
2986        );
2987    }
2988
2989    /// `fast_atv` blocked+pairwise reduction is strictly more accurate than a
2990    /// naive running column-sum on a long, ill-conditioned `n`-axis.
2991    #[test]
2992    fn fast_atv_blocked_beats_naive_accuracy() {
2993        let n = 200_003usize;
2994        let p = 3usize;
2995        let mut s = 0xD1B5_4A32u64;
2996        let mut next = || {
2997            s ^= s << 13;
2998            s ^= s >> 7;
2999            s ^= s << 17;
3000            (s >> 11) as f64 / ((1u64 << 53) as f64) - 0.5
3001        };
3002        let mut x = Array2::<f64>::zeros((n, p));
3003        let mut v = Array1::<f64>::zeros(n);
3004        for i in 0..n {
3005            let scale = 10f64.powi((i % 17) as i32 - 8);
3006            v[i] = if i % 2 == 0 { scale } else { -scale } * next();
3007            for j in 0..p {
3008                x[[i, j]] = next() * scale;
3009            }
3010        }
3011        let got = fast_atv(&x, &v);
3012        // Per-column truth and naive baseline.
3013        for j in 0..p {
3014            let col: Vec<f64> = (0..n).map(|i| x[[i, j]]).collect();
3015            let vv: Vec<f64> = v.to_vec();
3016            let truth = dd_dot(&col, &vv);
3017            let naive = naive_dot(&col, &vv);
3018            let ge = (got[j] - truth).abs();
3019            let ne = (naive - truth).abs();
3020            assert!(
3021                ge <= ne + f64::MIN_POSITIVE,
3022                "col {j}: blocked err {ge:.3e} exceeds naive {ne:.3e}",
3023            );
3024        }
3025    }
3026
3027    /// Non-contiguous (transposed-view) operands take the faer fallback and
3028    /// still match ndarray, proving the kernel gate is layout-safe.
3029    #[test]
3030    fn fast_av_strided_input_matches_ndarray() {
3031        let mut base = Array2::<f64>::zeros((40, 60));
3032        let mut s = 0x0BAD_F00Du64;
3033        let mut next = || {
3034            s ^= s << 13;
3035            s ^= s >> 7;
3036            s ^= s << 17;
3037            (s >> 11) as f64 / ((1u64 << 53) as f64) - 0.5
3038        };
3039        for x in base.iter_mut() {
3040            *x = next();
3041        }
3042        // A transposed view of `base` is (60, 40), non-row-major-contiguous.
3043        let a = base.t();
3044        let mut v = Array1::<f64>::zeros(40);
3045        for x in v.iter_mut() {
3046            *x = next();
3047        }
3048        let got = fast_av(&a, &v);
3049        let want = a.dot(&v);
3050        assert!(
3051            max_abs_diff_1d(&got, &want) < 1e-11,
3052            "strided fast_av mismatch (fallback path)",
3053        );
3054    }
3055}