Skip to main content

gam_solve/arrow_schur/
reduced_solve.rs

1//! The reduced `K x K` shared-system solve: dense Schur assembly (direct and
2//! square-root BA), the Schur matvec, the Jacobi/cluster/Schwarz
3//! preconditioners, Steihaug-PCG, and the [`ArrowSchurError`] type.
4
5use super::*;
6
/// Upper bound, in bytes, on the host memory a dense reduced Schur `k × k` f64
/// matrix may occupy (#1017). Dense assembly beyond this limit is refused with
/// a loud `SchurFactorFailed` instead of letting the host be OOM-killed.
///
/// 8 GiB corresponds to `k = 32768` (`k² · 8 B = 2³³ B`). Every
/// currently-feasible SAE border (k ≤ 5120 ⇒ 0.2 GiB) sits well below it, while
/// the qwen LLM border (k = 98304 ⇒ 72 GiB, i.e. ≈ 77 GB) is correctly rejected
/// as matrix-free-only.
pub(crate) const DENSE_SCHUR_BYTES_BUDGET: u128 = 8 << 30;
13
14/// Reduce one contiguous device tile's rows into a private `-Σ leftᵀ·right`
15/// partial (`k×k`).
16///
17/// The tile stacks its per-row `left_i` / `right_i` factors (each `d×k`) into
18/// two `(Σ_i d_i × k)` matrices and tries a single per-ordinal `AᵀB` device
19/// GEMM (`gam_gpu::try_fast_atb_on_ordinal`), which runs on the device this
20/// worker thread already bound — one big GPU GEMM per tile rather than `n` small
21/// CPU ones. When the device primitive declines (no GPU, shape below policy,
22/// transient failure) the tile reduces with the exact CPU `block_gemm_subtract`
23/// loop, so the result is unchanged. The partial is negated so the caller's
24/// `schur += partial` reproduces the serial `schur -= Σ contribution`.
25pub(crate) fn tile_schur_partial<B: BatchedBlockSolver>(
26    sys: &ArrowSchurSystem,
27    htt_factors: &ArrowFactorSlab,
28    backend: &B,
29    kind: SchurReductionKind,
30    ordinal: usize,
31    range: Range<usize>,
32) -> Result<Array2<f64>, ArrowSchurError> {
33    let k = sys.k;
34
35    // Build the per-row contribution factors once; both the GPU stacked-GEMM
36    // and the CPU fallback consume them.
37    let mut factors: Vec<(Array2<f64>, Array2<f64>)> = Vec::with_capacity(range.len());
38    let mut total_d = 0usize;
39    for i in range.clone() {
40        let (left, right) = row_schur_contribution_factors(
41            sys,
42            i,
43            &sys.rows[i],
44            htt_factors.factor(i),
45            backend,
46            kind,
47        )?;
48        total_d += left.nrows();
49        factors.push((left, right));
50    }
51
52    // Stack into (total_d × k) left/right matrices for one device AᵀB GEMM on
53    // this tile's bound ordinal. `try_fast_atb_on_ordinal` returns leftᵀ·right
54    // (k×k); negate into the partial. At an SAE-shaped whole-fit tile with
55    // n=2000 rows, k=2048 shared columns, M=12 local rows per observation, and
56    // K=8 candidate/atom batches, the stacked GEMM is
57    // 2*(n*M)*k^2 = 201_326_592_000 flops per batch, or
58    // 1_610_612_736_000 flops across K=8, so the policy work gate is cleared
59    // even though the observation count is far below the old row floor.
60    if total_d > 0 && k > 0 {
61        let mut left_stack = Array2::<f64>::zeros((total_d, k));
62        let mut right_stack = Array2::<f64>::zeros((total_d, k));
63        let mut base = 0usize;
64        for (left, right) in &factors {
65            let di = left.nrows();
66            left_stack
67                .slice_mut(ndarray::s![base..base + di, ..])
68                .assign(left);
69            right_stack
70                .slice_mut(ndarray::s![base..base + di, ..])
71                .assign(right);
72            base += di;
73        }
74        if let Some(product) =
75            gam_gpu::try_fast_atb_on_ordinal(ordinal, left_stack.view(), right_stack.view())
76        {
77            return Ok(product.mapv(|v| -v));
78        }
79    }
80
81    // CPU fallback: exact per-row block_gemm_subtract into a zero-seeded partial.
82    let mut partial = Array2::<f64>::zeros((k, k));
83    for (left, right) in &factors {
84        backend.block_gemm_subtract(&mut partial, left, right);
85    }
86    Ok(partial)
87}
88
89/// Reduce the per-row Schur contributions `Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`
90/// out of `schur` (seeded with `H_ββ + ρ_β·I`).
91///
92/// The per-row contributions are independent — exactly the "sum over independent
93/// arrow-tip blocks" axis the device pool partitions. When more than one GPU is
94/// usable, [`gam_gpu::pool::balanced_partition`] splits the `0..n` rows into
95/// per-device contiguous tiles; each tile is reduced on its own scoped thread
96/// (binding that ordinal's context so the per-row GEMM-subtract offloads to its
97/// device) into a private `k×k` partial, and the partials are summed back into
98/// `schur` in tile order. The tiles are contiguous, ordered to cover `0..n`, and
99/// folded back in that same order, so within each tile the per-row accumulation
100/// order is preserved and the only departure from the serial loop is the
101/// inter-tile reassociation of the reduction sum — the established
102/// reduction-order equivalence the device pool already operates under, well
103/// inside the Newton solve's tolerance.
104///
105/// With a single device (or no GPU) the row loop runs serially in place, which
106/// is bit-for-bit the original behaviour.
107pub(crate) fn reduce_row_schur_contributions<B: BatchedBlockSolver + Sync>(
108    sys: &ArrowSchurSystem,
109    htt_factors: &ArrowFactorSlab,
110    backend: &B,
111    kind: SchurReductionKind,
112    schur: &mut Array2<f64>,
113) -> Result<(), ArrowSchurError> {
114    let n = sys.rows.len();
115    let k = sys.k;
116
117    let tiles = gam_gpu::device_runtime::GpuRuntime::global()
118        .map(|rt| gam_gpu::pool::balanced_partition(rt, n))
119        .filter(|tiles| tiles.len() > 1);
120
121    let Some(tiles) = tiles else {
122        // Single-device / CPU. The per-row contributions `-Σ_i leftᵀ·right` fold
123        // into the `k×k` `schur` independently — the same dense-assembly axis the
124        // multi-GPU tile path partitions, and the dense-Direct analog of the
125        // per-row matvec / streaming `accumulate_chunk` loops already parallelized
126        // for #1017. At the SAE Direct-solve shape (`n` in the thousands, wide
127        // border `k`) this O(n·d·k²) reduction is the dense assembly's whole cost
128        // and was the last serial CPU step on the dense-Schur build.
129        //
130        // Fan it across rayon over fixed row chunks: each chunk reduces its rows
131        // (in row order) into a private zero-seeded `k×k` partial, then the
132        // partials are folded into `schur` in CHUNK order. The per-chunk row order
133        // and the inter-chunk fold order are both fixed independent of thread
134        // scheduling, so the f64 reduction is **bit-identical run-to-run** (the
135        // #1017 determinism gate). NOTE: bit-identical run-to-run does NOT make
136        // it bit-identical to the in-place serial loop — the chunk-boundary
137        // reassociation of the reduction sum is a genuine f64 departure (the
138        // established equivalence `accumulate_chunk` / the per-row matvec operate
139        // under, well inside the Newton solve's tolerance). It bounds candidate-
140        // to-candidate drift to that reassociation margin, so the criterion
141        // ranking is stable EXCEPT for candidates tying within the margin, where
142        // the winner can flip; it is not an exact no-move guarantee (#1211). For
143        // an exact-order guarantee, take the serial path. Stay in-place serial
144        // below the row floor and when already inside a rayon worker (the topology
145        // race fans candidates with `run_topology_race_parallel`) to avoid
146        // nested-rayon oversubscription — the same guard the matvec uses.
147        let n_rows = sys.rows.len();
148        let parallel =
149            n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
150        if parallel {
151            use rayon::prelude::*;
152            const CHUNK: usize = 64;
153            let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = (0..n_rows)
154                .into_par_iter()
155                .chunks(CHUNK)
156                .map(|idxs| {
157                    let mut partial = Array2::<f64>::zeros((k, k));
158                    for i in idxs {
159                        subtract_row_schur_contribution(
160                            sys,
161                            i,
162                            &sys.rows[i],
163                            htt_factors.factor(i),
164                            backend,
165                            kind,
166                            &mut partial,
167                        )?;
168                    }
169                    Ok(partial)
170                })
171                .collect();
172            // Deterministic ordered fold: chunk partials hold `-Σ contribution`
173            // over their rows, so `schur += partial` reproduces the serial
174            // `schur -= Σ contribution` in fixed (chunk, a, b) order.
175            for partial in &partials? {
176                for a in 0..k {
177                    for b in 0..k {
178                        schur[[a, b]] += partial[[a, b]];
179                    }
180                }
181            }
182            return Ok(());
183        }
184        // Serial in-place reduction (original order) — bit-for-bit reference.
185        for (i, row) in sys.rows.iter().enumerate() {
186            subtract_row_schur_contribution(
187                sys,
188                i,
189                row,
190                htt_factors.factor(i),
191                backend,
192                kind,
193                schur,
194            )?;
195        }
196        return Ok(());
197    };
198
199    // Multi-GPU: one private `-Σ leftᵀ·right` partial per contiguous device
200    // tile. Each tile runs on its own scoped worker thread that binds its
201    // ordinal's context and issues a single stacked AᵀB GEMM on that device, so
202    // the tiles' GEMMs overlap across the pool. Folding the partials back into
203    // the H_ββ-seeded `schur` reproduces the serial reduction (up to inter-tile
204    // reassociation).
205    let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = std::thread::scope(|scope| {
206        let handles: Vec<_> = tiles
207            .iter()
208            .map(|(ordinal, range)| {
209                let ordinal = *ordinal;
210                let range = range.clone();
211                scope.spawn(move || {
212                    // Bind this ordinal's CUDA context on this worker thread so
213                    // the per-row GPU GEMM shims issued from `tile_schur_partial`
214                    // offload to that device. A missing context or bind failure
215                    // is intentionally consumed without escalation — the shims
216                    // no-op back to CPU and the math is unchanged. Off Linux
217                    // `GpuRuntime::global()` is always `None`, so this branch
218                    // is unreachable and the bind is omitted entirely.
219                    #[cfg(target_os = "linux")]
220                    {
221                        if let Some(ctx) = gam_gpu::device_runtime::cuda_context_for(ordinal) {
222                            if ctx.bind_to_thread().is_err() {
223                                // Fall through: this tile reduces on the CPU.
224                            }
225                        }
226                    }
227                    tile_schur_partial(sys, htt_factors, backend, kind, ordinal, range)
228                })
229            })
230            .collect();
231        handles
232            .into_iter()
233            .map(|handle| {
234                handle
235                    .join()
236                    .map_err(|_| ArrowSchurError::SchurFactorFailed {
237                        reason: "schur-reduction tile thread panicked".to_string(),
238                    })?
239            })
240            .collect()
241    });
242    let partials = partials?;
243
244    // Fold partials into `schur` in tile order (contiguous, covering 0..n) so
245    // the per-tile and inter-tile accumulation order is the row order; each
246    // partial holds `-Σ contribution` over its rows, so `schur += partial`
247    // reproduces `schur -= Σ contribution`.
248    for partial in &partials {
249        for a in 0..k {
250            for b in 0..k {
251                schur[[a, b]] += partial[[a, b]];
252            }
253        }
254    }
255    Ok(())
256}
257
258pub(crate) fn build_dense_schur_direct<B: BatchedBlockSolver + Sync>(
259    sys: &ArrowSchurSystem,
260    htt_factors: &ArrowFactorSlab,
261    ridge_beta: f64,
262    backend: &B,
263) -> Result<Array2<f64>, ArrowSchurError> {
264    let k = sys.k;
265    // Materialise H_ββ via the BetaPenaltyOp trait (#296): DensePenaltyOp
266    // for the legacy dense path, structured ops for SAE / Kronecker smooths.
267    let op = sys.effective_penalty_op();
268    if op.dim() != k {
269        return Err(ArrowSchurError::SchurFactorFailed {
270            reason: "Direct BA requires a K×K shared H_ββ penalty operator".to_string(),
271        });
272    }
273    // Fail LOUD, never OOM-kill (#1017): the dense reduced Schur is `k × k` f64.
274    // At SAE LLM borders (qwen `k = 98304` ⇒ 77 GiB) materialising it would crash
275    // the host. The matrix-free device PCG already solves the *step* without it
276    // (`try_device_arrow_direct_sae_pcg`); only the joint-Hessian log-det still
277    // routes here. A matrix-free determinant-lemma log-det (the proper follow-up)
278    // is not yet wired, so refuse the allocation with an actionable error rather
279    // than degrading silently into an OOM. The budget is generous so every
280    // currently-feasible border (k ≤ 5120 ⇒ 0.2 GiB) is unaffected.
281    let dense_bytes = (k as u128).saturating_mul(k as u128).saturating_mul(8);
282    if dense_bytes > DENSE_SCHUR_BYTES_BUDGET {
283        return Err(ArrowSchurError::SchurFactorFailed {
284            reason: format!(
285                "dense reduced Schur is {k}×{k} f64 = {} MiB, exceeding the {} MiB host budget; \
286                 this border is matrix-free-only (the device PCG solves the step without the dense \
287                 Schur) and a matrix-free determinant-lemma log-det is the required follow-up",
288                dense_bytes / (1024 * 1024),
289                DENSE_SCHUR_BYTES_BUDGET / (1024 * 1024),
290            ),
291        });
292    }
293    let mut schur = op.to_dense();
294    for j in 0..k {
295        schur[[j, j]] += ridge_beta;
296    }
297    reduce_row_schur_contributions(
298        sys,
299        htt_factors,
300        backend,
301        SchurReductionKind::Direct,
302        &mut schur,
303    )?;
304    symmetrize_upper_from_lower(&mut schur);
305    Ok(schur)
306}
307
308pub(crate) fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver + Sync>(
309    sys: &ArrowSchurSystem,
310    htt_factors: &ArrowFactorSlab,
311    ridge_beta: f64,
312    backend: &B,
313) -> Result<Array2<f64>, ArrowSchurError> {
314    let k = sys.k;
315    // Materialise H_ββ via the BetaPenaltyOp trait (#296).
316    let op = sys.effective_penalty_op();
317    if op.dim() != k {
318        return Err(ArrowSchurError::SchurFactorFailed {
319            reason: "Square-Root BA direct solve requires a K×K shared H_ββ penalty operator"
320                .to_string(),
321        });
322    }
323    let mut schur = op.to_dense();
324    for j in 0..k {
325        schur[[j, j]] += ridge_beta;
326    }
327    reduce_row_schur_contributions(
328        sys,
329        htt_factors,
330        backend,
331        SchurReductionKind::SqrtBa,
332        &mut schur,
333    )?;
334    symmetrize_upper_from_lower(&mut schur);
335    Ok(schur)
336}
337
338/// Certified Carson–Higham mixed-precision solve of the reduced dense Schur
339/// system `S Δβ = rhs` (#1014), specialized to the streaming/residency path.
340///
341/// Returns `Some(Δβ)` when certified mixed precision is enabled AND the κ gate
342/// admits the f32 factorization AND the f64 backward-error certificate closes;
343/// `None` in every other case so the caller falls back to the exact f64
344/// triangular solve. The f64 `factor` (whose diagonal carries the exact
345/// `log|S|`) is supplied by the caller and never re-derived here — the logdet
346/// the evidence path reads stays f64 by construction.
347///
348/// Method: store the f64 Cholesky factor as f32, solve in f32, then refine with
349/// residuals `r = rhs − S·x` computed in f64 against the f64 `S`. With
350/// `κ(S)·u_f32 < margin` the refinement contracts at rate `κ·u`, and the
351/// terminating certificate is the normwise backward error
352/// `‖r‖∞ / (‖S‖∞‖x‖∞ + ‖rhs‖∞) ≤ tol`. A non-decreasing residual or an
353/// unmet certificate after `max_refinement_steps` returns `None`.
354pub(crate) fn mixed_precision_reduced_beta(
355    schur: &Array2<f64>,
356    factor: &Array2<f64>,
357    rhs: &Array1<f64>,
358    options: &ArrowSolveOptions,
359) -> Option<Array1<f64>> {
360    let ArrowSolvePrecisionPolicy::CertifiedMixed {
361        max_refinement_steps,
362        residual_relative_tolerance,
363        kappa_unit_roundoff_margin,
364    } = options.solve_precision
365    else {
366        return None;
367    };
368    // The reduced-system mixed-precision path is the dense reduced solve only;
369    // a trust-region-truncated step takes the Steihaug branch below in f64.
370    if options.trust_region.radius.is_finite() {
371        return None;
372    }
373    let n = schur.nrows();
374    if n == 0 {
375        return None;
376    }
377
378    // κ gate: the f32 factorization is only admissible when κ(S)·u_f32 leaves
379    // the refinement contraction headroom the certificate needs.
380    let kappa = cholesky_factor_kappa_estimate(factor);
381    if !kappa.is_finite() || kappa * F32_UNIT_ROUNDOFF >= kappa_unit_roundoff_margin {
382        return None;
383    }
384
385    let factor_f32 = factor.mapv(|v| v as f32);
386    let s_inf = matrix_inf_norm(schur);
387    let rhs_inf = rhs.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
388    let certificate_tol = residual_relative_tolerance
389        .max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
390
391    // f32 solve of the seed system, then f64-residual refinement steps.
392    let mut x = cholesky_solve_lower_f32(&factor_f32, &rhs.mapv(|v| v as f32)).mapv(|v| v as f64);
393    let mut last_residual = f64::INFINITY;
394    for _ in 0..=max_refinement_steps {
395        // Residual r = rhs − S·x in f64 against the f64 model.
396        let sx = schur.dot(&x);
397        let mut r = rhs.clone();
398        r -= &sx;
399        let r_inf = r.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
400        let x_inf = x.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
401        let denom = s_inf * x_inf + rhs_inf;
402        let backward_error = if denom > 0.0 { r_inf / denom } else { 0.0 };
403        if backward_error <= certificate_tol {
404            return Some(x);
405        }
406        // Refinement must make monotone progress, else hand back to f64.
407        if !(r_inf < last_residual) {
408            return None;
409        }
410        last_residual = r_inf;
411        // Correction solve in f32 against the f32 factor: S·δ = r.
412        let delta = cholesky_solve_lower_f32(&factor_f32, &r.mapv(|v| v as f32)).mapv(|v| v as f64);
413        x += &delta;
414    }
415    None
416}
417
418/// Infinity norm (max absolute row sum) of a dense matrix.
419pub(crate) fn matrix_inf_norm(a: &Array2<f64>) -> f64 {
420    let mut max_row = 0.0_f64;
421    for row in a.rows() {
422        let s: f64 = row.iter().map(|v| v.abs()).sum();
423        if s > max_row {
424            max_row = s;
425        }
426    }
427    max_row
428}
429
430/// Spectral positive-definiteness floor for the reduced Schur complement
431/// `S` (#1026 SAE co-collapse SOLVE-path cure).
432///
433/// Reached only after the genuine Cholesky of `S` has REFUSED it (an indefinite
434/// reduced Schur: collapsed atoms drive a per-row `H_tt` near-singular, so the
435/// accumulated `Σ_i H_tβᵀ (H_tt)⁻¹ H_tβ` over-subtracts `H_ββ + ridge_β·I` into a
436/// matrix with a non-positive eigenvalue). Rather than reject and let the LM
437/// loop inflate `ridge_β` over EVERY β direction (the #1026 "crawl"), we
438/// symmetric-eigendecompose `S` and clamp every eigenvalue UP to
439/// `floor·max(λ)`. This is Levenberg–Marquardt restricted to exactly the
440/// indefinite/collapsed subspace: a well-separated positive direction
441/// (`λ ≫ floor·max λ`) keeps its EXACT eigenvalue (`λ.max(floor·max λ) = λ`), so
442/// the Newton step in the healthy β subspace is unchanged, while only the
443/// collapsed directions get the minimal positive stiffness needed for a PD
444/// solve. Returns the floored, symmetric, strictly-PD matrix, or `None` if `S`
445/// has no usable scale (non-finite / all-zero spectrum), in which case the
446/// caller keeps the strict refusal.
447///
448/// Mirrors the per-row evidence floor
449/// [`super::factorization::factor_spectral_deflated_evidence_row`]; the only
450/// difference is the floored VALUE — a small positive `floor·max λ` (Tikhonov,
451/// for an accurate solve) here, vs unit stiffness `+1` (`log 1 = 0`) there (for
452/// the quotient log-det).
453pub(crate) fn spectral_pd_floored_schur(
454    schur: &Array2<f64>,
455    relative_floor: f64,
456) -> Option<Array2<f64>> {
457    let n = schur.nrows();
458    if n == 0 || schur.ncols() != n || !(relative_floor.is_finite() && relative_floor > 0.0) {
459        return None;
460    }
461    // Symmetrise defensively (the assembled Schur is symmetric up to reduction
462    // order; the eig routine assumes exact symmetry).
463    let mut sym = Array2::<f64>::zeros((n, n));
464    for i in 0..n {
465        for j in 0..n {
466            let v = 0.5 * (schur[[i, j]] + schur[[j, i]]);
467            if !v.is_finite() {
468                return None;
469            }
470            sym[[i, j]] = v;
471        }
472    }
473    let (evals, evecs) = sym.eigh(Side::Lower).ok()?;
474    let max_abs = evals.iter().fold(
475        0.0_f64,
476        |acc, &v| if v.is_finite() { acc.max(v.abs()) } else { acc },
477    );
478    if !(max_abs.is_finite() && max_abs > 0.0) {
479        return None;
480    }
481    let floor = relative_floor * max_abs;
482    // Reconstruct `Σ_i max(λ_i, floor) v_i v_iᵀ`: clamp every eigenvalue UP to a
483    // strictly positive `floor`. Healthy positive directions (`λ ≫ floor`) are
484    // untouched; non-positive / tiny collapsed directions are lifted to exactly
485    // `floor`. The result is symmetric PD by construction.
486    let mut conditioned = Array2::<f64>::zeros((n, n));
487    for eig_idx in 0..evals.len() {
488        let lambda = evals[eig_idx];
489        let lambda_floored = if lambda.is_finite() {
490            lambda.max(floor)
491        } else {
492            floor
493        };
494        for i in 0..n {
495            let vi = evecs[[i, eig_idx]];
496            if vi == 0.0 {
497                continue;
498            }
499            for j in 0..n {
500                conditioned[[i, j]] += lambda_floored * vi * evecs[[j, eig_idx]];
501            }
502        }
503    }
504    Some(conditioned)
505}
506
507pub(crate) fn solve_dense_reduced_system(
508    schur: &Array2<f64>,
509    rhs_beta: &Array1<f64>,
510    options: &ArrowSolveOptions,
511    metric_weights: Option<&MetricWeights>,
512) -> Result<(Array1<f64>, Option<Array2<f64>>, PcgDiagnostics), ArrowSchurError> {
513    let factor = match cholesky_lower(schur) {
514        Ok(factor) => factor,
515        Err(e) => {
516            // #1026 — opt-in spectral PD-floor on the indefinite reduced Schur.
517            // When enabled (SAE solve path), condition ONLY the collapsed
518            // directions and re-factor, instead of erroring out and letting the
519            // outer LM loop inflate `ridge_β` over every β direction (the
520            // co-collapse "crawl"). Disabled (default `None`) keeps the strict
521            // refusal so BA / non-SAE callers are bit-for-bit unchanged.
522            match options.schur_pd_floor {
523                Some(relative_floor) => match spectral_pd_floored_schur(schur, relative_floor) {
524                    Some(floored) => match cholesky_lower(&floored) {
525                        Ok(factor) => {
526                            // Solve against the floored (PD) Schur. The healthy β
527                            // subspace keeps its exact eigenvalues, so its Δβ is
528                            // the exact Newton component; only the collapsed
529                            // subspace is minimally damped.
530                            let direct =
531                                mixed_precision_reduced_beta(&floored, &factor, rhs_beta, options)
532                                    .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
533                            if step_inside_trust_region(
534                                direct.view(),
535                                options.trust_region.radius,
536                                metric_weights,
537                            ) {
538                                return Ok((direct, Some(factor), PcgDiagnostics::default()));
539                            }
540                            let identity = IdentityPreconditioner;
541                            let (delta, diag) = steihaug_dense_system(
542                                &floored,
543                                rhs_beta,
544                                &identity,
545                                &ArrowPcgOptions {
546                                    max_iterations: options.trust_region.max_iterations,
547                                    relative_tolerance: options
548                                        .trust_region
549                                        .steihaug_relative_tolerance,
550                                },
551                                &options.trust_region,
552                                metric_weights,
553                            )?;
554                            return Ok((delta, Some(factor), diag));
555                        }
556                        Err(floored_err) => {
557                            return Err(ArrowSchurError::SchurFactorFailed {
558                                reason: format!(
559                                    "reduced Schur non-PD ({e}); spectral PD-floor \
560                                     reconstruction still non-PD: {floored_err}"
561                                ),
562                            });
563                        }
564                    },
565                    None => {
566                        return Err(ArrowSchurError::SchurFactorFailed {
567                            reason: format!(
568                                "reduced Schur non-PD ({e}); spectral PD-floor declined \
569                                 (no usable spectrum)"
570                            ),
571                        });
572                    }
573                },
574                None => return Err(ArrowSchurError::SchurFactorFailed { reason: e }),
575            }
576        }
577    };
578    // Ill-conditioned-but-PD Schur guard. The per-row factor checks reject
579    // any single barely-PD H_tt^(i) block, but the reduced Schur complement
580    //     S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)
581    // accumulates the (H_tt^(i))⁻¹ contributions of every row in finite
582    // precision. With many weak-but-admissible rows those terms can sum to a
583    // Schur matrix whose Cholesky succeeds yet whose condition number is far
584    // past the safe inversion regime, so `cholesky_solve_vector` yields an
585    // inaccurate Δβ that is silently propagated to the Newton step. Apply the
586    // same diagonal-ratio κ proxy used per-row to the reduced factor and treat
587    // an over-threshold estimate as a Schur-stability failure: `SchurFactorFailed`
588    // is already recoverable in `solve_with_lm_escalation_inner`, so this lifts
589    // `ridge_beta` and re-forms a better-conditioned Schur. This guard is
590    // exclusive to the dense Direct / SqrtBA path (the only caller of this
591    // function); the inexact-PCG path tolerates higher κ(S) and is unaffected.
592    // Evidence/log-det-only callers (`tolerate_ill_conditioning`) skip this
593    // rejection: the factor is genuinely PD (Cholesky above succeeded), so its
594    // diagonal still yields an exact `log|S|`, and an inaccurate Δβ is harmless
595    // because the step is discarded.
596    if !options.tolerate_ill_conditioning {
597        let schur_kappa = cholesky_factor_kappa_estimate(&factor);
598        if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
599            // #1026 — over-complete SAE dictionaries park surplus atoms dead
600            // (β_k → 0), so the reduced Schur is PD (the Cholesky above succeeded)
601            // but ILL-CONDITIONED: the dead decoder subspace carries near-zero
602            // eigenvalues while the live subspace is healthy. The kappa gate's
603            // concern is an inaccurate Δβ from accumulated (H_tt)⁻¹ contamination —
604            // but on the dead subspace the correct Δβ IS ≈0 (those atoms have no
605            // signal), so the only "inaccuracy" is in directions whose true step is
606            // zero. When the spectral PD-floor is enabled (the SAE solve path),
607            // clamp exactly those collapsed directions up to `floor·max(λ)` and
608            // solve against the floored Schur: the live subspace keeps its EXACT
609            // Newton component, the dead subspace is damped to ≈0, and κ is bounded
610            // so Δβ is accurate where it matters. This is the same conditioning the
611            // non-PD branch above applies; here it also covers the PD-but-ill-
612            // conditioned case so the LM loop does not exhaust `ridge_β` trying to
613            // (futilely) lift a fundamentally rank-deficient dead-atom subspace.
614            // Without the floor (BA / non-SAE callers) the strict refusal stands.
615            if let Some(relative_floor) = options.schur_pd_floor
616                && let Some(floored) = spectral_pd_floored_schur(schur, relative_floor)
617                && let Ok(floored_factor) = cholesky_lower(&floored)
618            {
619                let direct =
620                    mixed_precision_reduced_beta(&floored, &floored_factor, rhs_beta, options)
621                        .unwrap_or_else(|| cholesky_solve_vector(&floored_factor, rhs_beta));
622                if step_inside_trust_region(
623                    direct.view(),
624                    options.trust_region.radius,
625                    metric_weights,
626                ) {
627                    return Ok((direct, Some(floored_factor), PcgDiagnostics::default()));
628                }
629                let identity = IdentityPreconditioner;
630                let (delta, diag) = steihaug_dense_system(
631                    &floored,
632                    rhs_beta,
633                    &identity,
634                    &ArrowPcgOptions {
635                        max_iterations: options.trust_region.max_iterations,
636                        relative_tolerance: options.trust_region.steihaug_relative_tolerance,
637                    },
638                    &options.trust_region,
639                    metric_weights,
640                )?;
641                return Ok((delta, Some(floored_factor), diag));
642            }
643            return Err(ArrowSchurError::SchurFactorFailed {
644                reason: format!(
645                    "reduced Schur complement Cholesky succeeded but is ill-conditioned \
646                     (kappa_estimate={schur_kappa:e}); accumulated per-row \
647                     (H_tt)⁻¹ contamination would yield an inaccurate Δβ"
648                ),
649            });
650        }
651    }
652    // Reduced-system solve. The f64 `factor` is always retained and returned —
653    // its diagonal is the EXACT `log|S|` the evidence path reads, so the logdet
654    // stays f64 regardless of how Δβ is computed (#1014 invariant). When the
655    // streaming/residency path enabled certified mixed precision, the Δβ solve
656    // itself runs f32-then-f64-refined (κ-gated, with the f64 triangular solve
657    // as the automatic fallback); the certificate is the f64 backward error.
658    let direct = mixed_precision_reduced_beta(schur, &factor, rhs_beta, options)
659        .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
660    if step_inside_trust_region(direct.view(), options.trust_region.radius, metric_weights) {
661        return Ok((direct, Some(factor), PcgDiagnostics::default()));
662    }
663
664    // Ceres-style trust-region correction: once the dense BA solve proposes a
665    // step outside the trust ball, Steihaug-CG returns the boundary point
666    // without requiring a second dense factorization.
667    let identity = IdentityPreconditioner;
668    let (delta, diag) = steihaug_dense_system(
669        schur,
670        rhs_beta,
671        &identity,
672        &ArrowPcgOptions {
673            max_iterations: options.trust_region.max_iterations,
674            relative_tolerance: options.trust_region.steihaug_relative_tolerance,
675        },
676        &options.trust_region,
677        metric_weights,
678    )?;
679    Ok((delta, Some(factor), diag))
680}
681
682/// Solve an externally accumulated dense reduced β system
683/// `S Δβ = rhs_β` with the same LM-style ridge escalation the full-batch
684/// driver applies: on a `SchurFactorFailed` (non-PD or ill-conditioned `S`),
685/// geometrically grow a proximal ridge on `S`'s diagonal and retry.
686///
687/// Used by the SAE streaming joint fit, which accumulates `S` and `rhs_β` over
688/// re-materialized row chunks (via [`StreamingArrowSchur::take_accumulators`])
689/// and must solve the single global reduced system without a per-row
690/// `ArrowSchurSystem`. `S` is symmetrized from its lower triangle before each
691/// factorization. `base_ridge_beta` is folded into the caller's `S` already;
692/// this routine only adds the *escalation* ridge on top.
693pub fn solve_streaming_reduced_beta(
694    s_acc: &Array2<f64>,
695    rhs_beta: &Array1<f64>,
696    options: &ArrowSolveOptions,
697) -> Result<Array1<f64>, ArrowSchurError> {
698    let mut proximal_ridge = 0.0_f64;
699    let mut last_err: Option<ArrowSchurError> = None;
700    for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
701        let mut schur = s_acc.clone();
702        symmetrize_upper_from_lower(&mut schur);
703        if proximal_ridge > 0.0 {
704            for j in 0..schur.nrows() {
705                schur[[j, j]] += proximal_ridge;
706            }
707        }
708        // Reduced K-system on device: Jacobi-preconditioned CG over the dense
709        // symmetric `S`. The `O(K²)` `S·p` matvec runs device-side; only the
710        // K-vectors cross the boundary per CG iteration. This is the dominant
711        // cost of the streaming SAE joint fit at `K = 100K`. Any device-side
712        // failure (`Unavailable`, non-PD Jacobi diagonal) falls through to the
713        // CPU `solve_dense_reduced_system`, which then drives the same proximal
714        // ridge escalation. A genuine device PD failure is non-recoverable for
715        // this attempt's `schur`, so we let the CPU path re-confirm and escalate.
716        if gam_gpu::device_runtime::GpuRuntime::is_available() {
717            match crate::gpu_kernels::arrow_schur::solve_reduced_beta_pcg(
718                &schur,
719                rhs_beta,
720                options.trust_region.max_iterations,
721                options.trust_region.steihaug_relative_tolerance,
722            ) {
723                Ok(delta_beta) => return Ok(delta_beta),
724                Err(crate::gpu_kernels::arrow_schur::ArrowSchurGpuFailure::Unavailable) => {}
725                Err(_) => {
726                    // Device declined this `schur` (e.g. non-PD Jacobi diag);
727                    // let the CPU path confirm and escalate the proximal ridge.
728                }
729            }
730        }
731        match solve_dense_reduced_system(&schur, rhs_beta, options, None) {
732            Ok((delta_beta, _factor, _diag)) => return Ok(delta_beta),
733            Err(err) => {
734                let recoverable = matches!(
735                    err,
736                    ArrowSchurError::SchurFactorFailed { .. }
737                        | ArrowSchurError::PcgFailed { .. }
738                        | ArrowSchurError::UnboundedNegativeCurvature { .. }
739                );
740                last_err = Some(err);
741                if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
742                    break;
743                }
744                proximal_ridge = if proximal_ridge == 0.0 {
745                    DEFAULT_PROXIMAL_INITIAL_RIDGE
746                } else {
747                    proximal_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
748                };
749            }
750        }
751    }
752    Err(last_err.expect("escalation loop set last_err on failure"))
753}
754
755pub(crate) fn step_inside_trust_region(
756    step: ArrayView1<'_, f64>,
757    radius: f64,
758    metric_weights: Option<&MetricWeights>,
759) -> bool {
760    !radius.is_finite() || metric_norm(step, metric_weights) <= radius
761}
762
/// Row-count threshold for running the per-row Schur work under rayon.
///
/// Below this row count the per-row Schur loop stays sequential: the rayon
/// fan-out (chunk dispatch + the deterministic per-chunk length-`K` reduction)
/// costs more than it saves for the handful-of-rows arrow systems that dominate
/// the non-SAE callers. At or above it — the SAE LLM shape (`n` in the
/// thousands, wide border `k`) that issue #1017 names — the per-row
/// `H_βt (H_tt)⁻¹ H_tβ x` contributions are the matvec's whole cost and
/// parallelize cleanly.
///
/// The gate is inclusive (`n >= SCHUR_MATVEC_PARALLEL_ROW_MIN`) and, in both
/// `schur_matvec` and `SaeResidentReducedSchur::build`, is combined with a
/// `rayon::current_thread_index().is_none()` check so the fan-out never nests
/// inside an existing rayon worker.
pub(crate) const SCHUR_MATVEC_PARALLEL_ROW_MIN: usize = 256;

/// Border-width threshold for parallelizing the dense `H_ββ` penalty prologue.
///
/// Below this border width `k` the dense `H_ββ` penalty-prologue GEMV stays
/// sequential: parallelizing a `k×k` matvec only pays once `k²` is large enough
/// to dwarf the rayon fan-out, which for the arrow callers with narrow borders
/// it never is. At the SAE LLM border (`k` in the low thousands) the `O(k²)`
/// prologue is ≈4M flops/CG-iteration and was the serial Amdahl ceiling on the
/// otherwise per-row-parallel matvec (#1017), so it crosses this threshold and
/// fans out. 512 keeps the prologue serial for every non-SAE arrow system while
/// engaging it for the wide SAE/Qwen borders the issue targets.
///
/// NOTE(review): the code that compares `k` against this constant is not in
/// this part of the file — presumably the penalty prologue that `schur_matvec`
/// calls with its `parallel` flag; confirm at the use site.
pub(crate) const SCHUR_PROLOGUE_PARALLEL_K_MIN: usize = 512;
780
/// Device-residency CPU analogue for the SAE reduced-Schur matvec (#1017).
///
/// In the production SAE joint fit the per-row cross-block factors as
/// `H_tβ^(i) = L_i P_i`, where `L_i` (`q_i × p`) is the row's local
/// assignment/coordinate Jacobian and `P_i` (`p × K`, sparse) gathers the
/// active atoms' decoder blocks (`P_i x = Σ_s φ_s · x[base_s .. base_s+p]`).
/// The reduced-Schur point-elimination contribution of one row is therefore
///
/// ```text
/// S_i x = H_βt^(i) (H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i) x
///       = P_iᵀ · [ L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i ] · P_i x
///       = P_iᵀ G_i (P_i x),      G_i := L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i   (p×p).
/// ```
///
/// The block `G_i = L_iᵀ Y_i` depends only on the assembled per-row blocks and
/// the (already-computed, solve-stable) `H_tt` factor — NOT on the CG iterate
/// `x`. The generic [`schur_matvec`] re-walks `apply_jbeta → apply_l →
/// solve(d×d) → apply_l_t → scatter` on every CG iteration; this object **stages
/// the solved factor `Y_i` once per CG solve** next to the shared `L_i` slab
/// (the "upload X once" residency mechanism, applied on CPU to the matvec
/// rather than a dense factorization), turning each subsequent matvec into a
/// sparse gather → two `di×p` GEMVs → sparse scatter, with no per-iteration
/// triangular solve and no operator-closure re-walk. It never materialises the
/// dense `p×p` product: `di ≪ p` for SAE rows, so the factored apply is
/// `2·support_i·p + 2·di·p` flops/row — the two `di·p` GEMVs PLUS the
/// `support_i·p` sparse gather (`P_i x`) and `support_i·p` sparse scatter
/// (`P_iᵀ prod`) — versus the dense `p²` block apply, and `O(n·di·p)` memory
/// (vs `O(n·p²)` ≈ 67 GB at the Qwen shape — the dense form is OOM). For
/// dense/full active support `support_i` can scale with the active β-columns,
/// so the gather/scatter term is NOT negligible and is counted here.
///
/// Numerically identical to the generic path up to floating-point reassociation
/// (it accumulates the SAME quotient). It is deterministic run-to-run and
/// within the reassociation margin of the serial path, so the criterion ranking
/// across topology candidates is stable except for candidates separated by less
/// than that f64 margin, where reassociation can flip the near-tie winner — it
/// is NOT an exact no-move guarantee (#1211).
pub(crate) struct SaeResidentReducedSchur {
    /// Decoder output dimension `p` (the side length of every `G_i = L_iᵀ Y_i`).
    pub(crate) p: usize,
    /// Per-row staged residency: one [`ResidentRowFactor`] per row, holding the
    /// row latent dimension `di` and the SOLVED factor
    /// `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i` as a row-major `di·p`-length buffer. The
    /// matching `L_i` is NOT stored here; it is read from [`Self::local_jac`]
    /// at the same row index.
    ///
    /// The reduced block is `G_i = L_iᵀ Y_i` (`p×p`, symmetric PSD), but it has
    /// rank ≤ `di` and `di ≪ p` for SAE rows (the per-row latent dim is 1–2
    /// while `p` is the decoder block width, ~2048). Materialising the dense
    /// `p×p` block would cost `O(n·p²)` memory (≈67 GB at the Qwen shape) and
    /// `p²` flops per matvec/row; the factored form costs `O(n·di·p)` memory and
    /// `2·support_i·p + 2·di·p` flops/row, applying `G_i v = L_iᵀ (Y_i v)`
    /// (sparse gather over `support_i` atoms → `di`-length GEMV → `p`-length
    /// GEMV → sparse scatter over `support_i` atoms). The `2·support_i·p`
    /// gather/scatter term is part of the per-row cost — for dense/full support
    /// `support_i` scales with active β-columns — and is not dropped. A row
    /// whose dimensions are degenerate (`di == 0`, or a Jacobian slab whose
    /// length is not `di·p`) is staged with `di = 0` and skipped by the matvec;
    /// a row with empty active support is skipped by the matvec as well.
    pub(crate) rows: Vec<ResidentRowFactor>,
    /// Per-row active atom support `(β-block base index, φ weight)`, shared with
    /// the assembler's [`DeviceSaePcgData`] (no re-clone of the index lists).
    pub(crate) a_phi: Arc<[Vec<(usize, f64)>]>,
    /// #1033: per-row local Jacobian `L_i` (row-major `di × p`), SHARED via `Arc`
    /// with the assembler's [`DeviceSaePcgData`] rather than copied into each
    /// `ResidentRowFactor`. The staged factor previously held its own verbatim
    /// row-major copy of `data.local_jac[row]` — a second full `O(n·di·p)` slab
    /// for zero benefit (the bytes and the `di × p` layout are identical). The
    /// matvec now reads `L_i = &self.local_jac[row]` directly; only the SOLVED
    /// factor `Y_i = (H_tt+ρI)⁻¹ L_i` (genuinely new data) stays per-row. Reads
    /// are byte-for-byte the former `rf.l` (same slab, same `r·p + c` indexing),
    /// so the matvec/preconditioner output is bit-identical.
    pub(crate) local_jac: Arc<[Vec<f64>]>,
}
849
/// Factored per-row residency block: `G_i = L_iᵀ Y_i` kept as its `di×p` factors
/// so the matvec never materialises the dense `p×p` product. The local Jacobian
/// factor `L_i` is NOT stored here — it is shared via
/// [`SaeResidentReducedSchur::local_jac`] (`&local_jac[row]`); only the solved
/// `Y_i` is per-row. See [`SaeResidentReducedSchur`].
///
/// A value with `di == 0` and an empty `y` marks a row that the resident
/// matvec skips.
pub(crate) struct ResidentRowFactor {
    /// Row latent dimension `di` (the inner contraction width). `0` ⇒ skipped.
    pub(crate) di: usize,
    /// `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i` row-major `di × p` (entry `(r, c)` lives at
    /// `r·p + c`). Empty when `di == 0`.
    pub(crate) y: Vec<f64>,
}
861
862impl SaeResidentReducedSchur {
863    /// Stage the per-row `G_i = L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i` blocks once, from
864    /// the SAE structure (`DeviceSaePcgData`: `p`, per-row `a_phi`, per-row
865    /// row-major `local_jac` = `L_i`) and the already-factored `H_tt` slab.
866    ///
867    /// Returns `None` when the structure does not match (degenerate `p`, row
868    /// count mismatch) so the caller falls back to the generic matvec. Row
869    /// builds are independent and run under the same deterministic rayon
870    /// discipline as the matvec (each `G_i` is self-contained — no cross-row
871    /// reduction — so there is no ordering subtlety).
872    /// `ridge_t` is NOT a parameter: it is already folded into the factored
873    /// blocks `htt_factors` carry (they factor `H_tt^(i) + ridge_t·I` — see
874    /// `factor_blocks`), so solving against the factor yields `(H_tt^(i)+ρ_t I)⁻¹`
875    /// exactly. The residency block is a pure function of the factor and `L_i`.
876    pub(crate) fn build<B: BatchedBlockSolver + Sync>(
877        sys: &ArrowSchurSystem,
878        htt_factors: &ArrowFactorSlab,
879        backend: &B,
880    ) -> Option<Self> {
881        let data = sys.device_sae_pcg.as_ref()?;
882        let p = data.p;
883        let n = sys.rows.len();
884        if p == 0
885            || sys.htbeta_dense_supplement
886            || data.a_phi.len() != n
887            || data.local_jac.len() != n
888        {
889            return None;
890        }
891        let empty = || ResidentRowFactor {
892            di: 0,
893            y: Vec::new(),
894        };
895        let build_row = |row: usize| -> ResidentRowFactor {
896            let di = sys.row_dims[row];
897            let jac = &data.local_jac[row];
898            // q_i = len/p; must match the row's latent dimension di.
899            if p == 0 || jac.len() != di * p || di == 0 {
900                return empty();
901            }
902            // L_i as a (di × p) matrix (row-major in `local_jac`).
903            let l_i = match ArrayView2::from_shape((di, p), jac.as_slice()) {
904                Ok(v) => v.to_owned(),
905                Err(_) => return empty(),
906            };
907            // Solve (H_tt+ρ_t I) Y = L_i for Y (di × p): one batched back-solve
908            // over the p columns against the cached factor. Stage `(L_i, Y_i)`
909            // — NOT the dense `p×p` product `G_i = L_iᵀ Y_i` — so storage and the
910            // matvec stay `O(di·p)` instead of `O(p²)` (`di ≪ p` for SAE rows).
911            let y = backend.solve_block_matrix(htt_factors.factor(row), l_i.view());
912            // Flatten the SOLVED factor to a `di × p` row-major buffer (iteration
913            // over a standard-layout view is row-major regardless of the source
914            // strides, so the hot loop can index `r*p + c` directly). `L_i` is NOT
915            // copied — the matvec reads it from the shared `local_jac` slab (it is
916            // byte-for-byte `data.local_jac[row]`).
917            let y_flat: Vec<f64> = y.iter().copied().collect();
918            ResidentRowFactor { di, y: y_flat }
919        };
920        let rows: Vec<ResidentRowFactor> =
921            if n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
922                use rayon::prelude::*;
923                (0..n).into_par_iter().map(build_row).collect()
924            } else {
925                (0..n).map(build_row).collect()
926            };
927        Some(Self {
928            p,
929            rows,
930            a_phi: data.a_phi_shared(),
931            local_jac: data.local_jac_shared(),
932        })
933    }
934
935    /// Accumulate one row's `S_i x = P_iᵀ G_i (P_i x) = P_iᵀ L_iᵀ Y_i (P_i x)`
936    /// into `acc` (length `K`). `gather`/`prod` are caller-owned length-`p`
937    /// buffers and `w` a caller-owned `≥ max_i di`-length buffer, all reused
938    /// across rows to keep the hot loop allocation-free. The matvec applies the
939    /// factored block in four steps: sparse gather `P_i x = Σ_s φ_s·x[base_s..]`
940    /// (`support_i·p` flops), `w = Y_i·(P_i x)` (`di`-length, `di·p` flops),
941    /// `prod = L_iᵀ·w` (`p`-length, `di·p` flops), and sparse scatter
942    /// `acc += P_iᵀ prod` (`support_i·p` flops) — `2·support_i·p + 2·di·p`
943    /// total, never the dense `p²` product. The gather/scatter `2·support_i·p`
944    /// term is counted: it is not dominated by the GEMVs when the active support
945    /// is wide.
946    #[inline]
947    pub(crate) fn row_into(
948        &self,
949        row: usize,
950        x: &Array1<f64>,
951        acc: &mut Array1<f64>,
952        gather: &mut [f64],
953        prod: &mut [f64],
954        w: &mut [f64],
955    ) {
956        let rf = &self.rows[row];
957        let di = rf.di;
958        if di == 0 {
959            return;
960        }
961        let p = self.p;
962        let support = &self.a_phi[row];
963        if support.is_empty() {
964            return;
965        }
966        // P_i x = Σ_s φ_s · x[base_s .. base_s+p]   (length p).
967        for v in gather.iter_mut() {
968            *v = 0.0;
969        }
970        for &(base, phi) in support {
971            if phi == 0.0 {
972                continue;
973            }
974            for j in 0..p {
975                gather[j] += phi * x[base + j];
976            }
977        }
978        // w = Y_i · (P_i x)   (di × p GEMV → length di).  Y_i row-major di×p.
979        for r in 0..di {
980            let yrow = &rf.y[r * p..r * p + p];
981            let mut s = 0.0_f64;
982            for c in 0..p {
983                s += yrow[c] * gather[c];
984            }
985            w[r] = s;
986        }
987        // prod = L_iᵀ · w   (p × di GEMV → length p).  L_i row-major di×p, so
988        // L_iᵀ[j,r] = L_i[r,j]; accumulate column-by-column over the di rows.
989        // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte the
990        // former per-row `rf.l` copy.
991        let l_i = &self.local_jac[row];
992        for v in prod.iter_mut().take(p) {
993            *v = 0.0;
994        }
995        for r in 0..di {
996            let lrow = &l_i[r * p..r * p + p];
997            let wr = w[r];
998            for j in 0..p {
999                prod[j] += lrow[j] * wr;
1000            }
1001        }
1002        // acc += P_iᵀ prod = scatter φ_s · prod into base_s blocks.
1003        for &(base, phi) in support {
1004            if phi == 0.0 {
1005                continue;
1006            }
1007            for j in 0..p {
1008                acc[base + j] += phi * prod[j];
1009            }
1010        }
1011    }
1012
1013    /// Max row latent dim `di` across resident rows — the size of the `w`
1014    /// scratch the matvec needs for the inner `Y_i·(P_i x)` GEMV.
1015    pub(crate) fn max_di(&self) -> usize {
1016        self.rows.iter().map(|r| r.di).max().unwrap_or(0)
1017    }
1018}
1019
1020/// Reduced-Schur matvec `out = S·x` with an optional pre-staged SAE residency
1021/// operator. When `resident` is `Some`, the per-row point-elimination term is
1022/// applied through the resident `p×p` blocks (#1017 CPU residency); otherwise it
1023/// falls back to the generic per-row `apply → solve → transpose` path. Both
1024/// routes accumulate the SAME reduced operator
1025/// `S = H_ββ + ρ_β I − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`.
1026pub(crate) fn schur_matvec<B: BatchedBlockSolver + Sync>(
1027    sys: &ArrowSchurSystem,
1028    htt_factors: &ArrowFactorSlab,
1029    ridge_beta: f64,
1030    x: &Array1<f64>,
1031    out: &mut Array1<f64>,
1032    backend: &B,
1033    resident: Option<&SaeResidentReducedSchur>,
1034) {
1035    // `steihaug_cg` reuses one output buffer across iterations and requires
1036    // `matvec` to ASSIGN every entry of `out` (the contract `dense_matvec`
1037    // upholds). This routine builds `S·x` purely by accumulation
1038    // (`penalty_matvec_add`, `out[a] += ridge·x`, `out[a] -= neg_contrib`), so it
1039    // MUST clear `out` first. Without this, iteration n>0 returns `S·x` plus the
1040    // previous call's `S·p`, the PCG solves a corrupted reduced system, and the
1041    // resulting Newton step is inconsistent with the assembled gradient
1042    // (g·δ ≈ 0 — a non-descent direction that defeats the line search).
1043    out.fill(0.0);
1044    let k = sys.k;
1045    // Top-level (not nested in a rayon worker) and big enough to amortize the
1046    // fan-out: the single gate that authorizes BOTH the dense penalty-prologue
1047    // GEMV and the per-row point-elimination loop to go parallel. The topology
1048    // race fans candidates with `run_topology_race_parallel`, so inside a worker
1049    // both stay sequential (no nested-rayon oversubscription).
1050    let parallel =
1051        sys.rows.len() >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1052    // Route the penalty-side (H_ββ + ridge·I) x product through the prologue:
1053    // no Arc-clone hot-path cost when penalty_op is None (falls back to hbb
1054    // inline); the dense fallback fans across cores at the wide SAE border (#1017).
1055    {
1056        let x_slice = x.as_slice().expect("x must be contiguous");
1057        let out_slice = out.as_slice_mut().expect("out must be contiguous");
1058        sys.penalty_ridge_prologue_into(x_slice, ridge_beta, out_slice, parallel);
1059    }
1060    // The reduced-Schur point-elimination term: `out -= Σ_i H_βt^(i) (H_tt^(i))⁻¹
1061    // H_tβ^(i) x`. Each row contributes an independent length-`K` vector, so for
1062    // the SAE LLM shape (#1017) this is the matvec's whole cost and is
1063    // embarrassingly parallel. Run it under rayon over fixed row chunks, summing
1064    // the per-chunk partials in chunk order so the f64 reduction is bit-identical
1065    // run-to-run regardless of thread scheduling (the #1017 verification gate).
1066    // This is deterministic and within the chunk-reassociation margin of serial,
1067    // so the criterion ranking is stable except for candidates that tie inside
1068    // that f64 margin — not an exact no-move guarantee (#1211). Stay
1069    // sequential when already inside a rayon worker (the topology race fans
1070    // candidates with `run_topology_race_parallel`) to avoid nested-rayon
1071    // oversubscription — the same guard `HyperOperator::mul_mat` uses. The
1072    // `parallel` gate above authorizes this loop too.
1073    let p = resident.map(|r| r.p).unwrap_or(0);
1074    if parallel {
1075        use rayon::prelude::*;
1076        const CHUNK: usize = 64;
1077        let n = sys.rows.len();
1078        let partials: Vec<Array1<f64>> = (0..n)
1079            .into_par_iter()
1080            .chunks(CHUNK)
1081            .map(|idxs| {
1082                let mut acc = Array1::<f64>::zeros(k);
1083                if let Some(res) = resident {
1084                    // Resident path: each matvec is gather → factored di×p GEMVs
1085                    // → scatter, reading only the pre-staged `(L_i, Y_i)` (no
1086                    // per-iteration solve, no dense p×p block).
1087                    let mut gather = vec![0.0_f64; p];
1088                    let mut prod = vec![0.0_f64; p];
1089                    let mut w = vec![0.0_f64; res.max_di()];
1090                    for i in idxs {
1091                        res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1092                    }
1093                } else {
1094                    let mut local = Array1::<f64>::zeros(sys.d);
1095                    for i in idxs {
1096                        schur_matvec_row_into(
1097                            sys,
1098                            htt_factors,
1099                            x,
1100                            backend,
1101                            i,
1102                            &mut local,
1103                            &mut acc,
1104                        );
1105                    }
1106                }
1107                acc
1108            })
1109            .collect();
1110        // Deterministic ordered reduction: fold chunk partials left-to-right.
1111        for acc in &partials {
1112            for a in 0..k {
1113                out[a] -= acc[a];
1114            }
1115        }
1116    } else if let Some(res) = resident {
1117        let mut acc = Array1::<f64>::zeros(k);
1118        let mut gather = vec![0.0_f64; p];
1119        let mut prod = vec![0.0_f64; p];
1120        let mut w = vec![0.0_f64; res.max_di()];
1121        for i in 0..sys.rows.len() {
1122            res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1123        }
1124        for a in 0..k {
1125            out[a] -= acc[a];
1126        }
1127    } else {
1128        // Allocate scratch at max_d; per-row slice is `..di`.
1129        let mut local = Array1::<f64>::zeros(sys.d);
1130        let mut neg_contrib = Array1::<f64>::zeros(k);
1131        for i in 0..sys.rows.len() {
1132            neg_contrib.fill(0.0);
1133            schur_matvec_row_into(
1134                sys,
1135                htt_factors,
1136                x,
1137                backend,
1138                i,
1139                &mut local,
1140                &mut neg_contrib,
1141            );
1142            for a in 0..k {
1143                out[a] -= neg_contrib[a];
1144            }
1145        }
1146    }
1147}
1148
/// Accumulate one row's reduced-Schur point-elimination contribution
/// `H_βt^(i) (H_tt^(i))⁻¹ H_tβ^(i) x` (length `K`) into `acc`.
///
/// `local` is caller-owned scratch that must hold at least `di` entries for
/// row `i` (callers size it at `sys.d`). `acc` is **added to**, never cleared,
/// so the caller controls whether contributions sum into a chunk partial
/// (parallel path) or a per-row buffer (sequential path).
///
/// NOTE(review): the scratch is not actually reused today. The
/// `slice_mut(..di).to_owned()` below copies the slice into a freshly
/// allocated array, so every call allocates a `di`-length vector and `local`
/// itself is never written. Making the hot loop allocation-free would require
/// `sys_htbeta_apply_row` to accept a mutable view — confirm its signature
/// before changing this.
#[inline]
pub(crate) fn schur_matvec_row_into<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    x: &Array1<f64>,
    backend: &B,
    i: usize,
    local: &mut Array1<f64>,
    acc: &mut Array1<f64>,
) {
    let row = &sys.rows[i];
    let di = sys.row_dims[i];
    // H_tβ^(i) · x → local_i (length di), routed through sys.htbeta_matvec
    // when the dense block is absent. `local_i` is an OWNED copy of
    // `local[..di]` (see the review note above); it is zeroed before the apply.
    let mut local_i = local.slice_mut(ndarray::s![..di]).to_owned();
    local_i.fill(0.0);
    sys_htbeta_apply_row(sys, i, row, x.view(), &mut local_i);
    // solved = (H_tt^(i))⁻¹ · local_i against the row's cached factor.
    let solved = backend.solve_block_vector(htt_factors.factor(i), local_i.view());
    // H_βt^(i) · solved accumulates into acc (length k).  Routed through
    // sys.htbeta_matvec when needed.
    sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), acc);
}
1178
/// One per-term block factor for the block-Jacobi Schur preconditioner.
///
/// Carries either a dense Cholesky factor (for PD blocks ≤ 256 columns) or
/// the scalar inverses for that block's diagonal as a fallback. Both variants
/// record the `range` of columns of the full K-vector they apply to.
///
/// `Debug` is implemented by hand for this type and prints only the variant
/// name, the `range`, and (for `Scalar`) the length of `inv` — never the
/// factor contents.
#[derive(Clone)]
pub(crate) enum BlockFactor {
    /// Cholesky L stored column-major via faer. `range` identifies the
    /// columns in the full K-vector this block covers.
    Chol {
        /// faer LLT factorization of this block's Schur sub-matrix.
        factor: FaerLlt<f64>,
        /// Columns of the full K-vector covered by this block.
        range: Range<usize>,
    },
    /// Scalar fallback: per-element `1/s_aa` for each column in `range`.
    Scalar {
        /// Reciprocal diagonal entries, one per column in `range`.
        inv: Array1<f64>,
        /// Columns of the full K-vector covered by this block.
        range: Range<usize>,
    },
}
1197
1198impl std::fmt::Debug for BlockFactor {
1199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1200        match self {
1201            BlockFactor::Chol { range, .. } => {
1202                write!(f, "BlockFactor::Chol {{ range: {:?} }}", range)
1203            }
1204            BlockFactor::Scalar { inv, range } => {
1205                write!(
1206                    f,
1207                    "BlockFactor::Scalar {{ inv.len: {}, range: {:?} }}",
1208                    inv.len(),
1209                    range
1210                )
1211            }
1212        }
1213    }
1214}
1215
/// Block-Jacobi Schur preconditioner for BA's inexact reduced-system PCG.
///
/// When [`ArrowSchurSystem::block_offsets`] is populated (via
/// [`ArrowSchurSystem::set_block_offsets`]) and the largest block has ≤ 256
/// columns, builds one small dense Schur block per term, factors it with
/// Cholesky (faer LLT), and applies the preconditioner as per-block
/// triangular solves.  Non-PD blocks fall back to scalar diagonal inversion
/// for that block only.  When `block_offsets` is empty or the largest block
/// exceeds 256 columns the preconditioner reduces to pure scalar-diagonal
/// Jacobi (pre-#283 behaviour), so callers that have not called
/// `set_block_offsets` are unaffected.
///
/// The `block_offsets` plumbing is compatible with issue #287 (custom
/// `ParameterBlockSpec` families): those callers supply ranges derived from
/// their own block layout.
#[derive(Debug, Clone)]
pub struct JacobiPreconditioner {
    /// The per-block factors; each entry records the column range of the full
    /// K-vector it covers (see [`BlockFactor`]).
    pub(crate) blocks: Vec<BlockFactor>,
}
1235
/// Maximum block size for which we attempt dense block-Jacobi factorization.
///
/// The limit is inclusive: `JacobiPreconditioner::from_arrow_schur` takes the
/// block route only when the widest range in `block_offsets` is
/// `<= BLOCK_JACOBI_MAX_BLOCK`; a single wider block sends the whole
/// preconditioner down the scalar-diagonal route.
pub(crate) const BLOCK_JACOBI_MAX_BLOCK: usize = 256;

/// Positive-definiteness floor on a Schur-complement Jacobi diagonal entry.
/// A diagonal at or below this value (or non-finite) signals a non-PD reduced
/// system: the preconditioner cannot invert it, so the PCG solve fails loudly
/// and demands operator regularization rather than returning a garbage scale.
///
/// NOTE(review): the comparison against this floor is not in this part of the
/// file; confirm the "at or below" boundary at the use site.
pub(crate) const JACOBI_DIAGONAL_PD_FLOOR: f64 = 1e-18;
1244
1245impl JacobiPreconditioner {
1246    /// Build the block-Jacobi (or scalar fallback) preconditioner from the
1247    /// Arrow-Schur system without materializing the full dense Schur
1248    /// complement.
1249    ///
1250    /// When `sys.block_offsets` is non-empty and `max(block_size) ≤ 256`,
1251    /// each block gets a dense `b×b` Schur sub-matrix formed, factored, and
1252    /// stored.  Otherwise every column gets its own scalar entry.
1253    pub(crate) fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
1254        sys: &ArrowSchurSystem,
1255        htt_factors: &ArrowFactorSlab,
1256        ridge_beta: f64,
1257        backend: &B,
1258        resident: Option<&SaeResidentReducedSchur>,
1259    ) -> Result<Self, ArrowSchurError> {
1260        let use_block = !sys.block_offsets.is_empty()
1261            && sys
1262                .block_offsets
1263                .iter()
1264                .map(|r| r.end.saturating_sub(r.start))
1265                .max()
1266                .unwrap_or(0)
1267                <= BLOCK_JACOBI_MAX_BLOCK;
1268        if use_block {
1269            if let Some(res) = resident {
1270                Self::build_block_jacobi_resident(sys, ridge_beta, res)
1271            } else {
1272                Self::build_block_jacobi(sys, htt_factors, ridge_beta, backend)
1273            }
1274        } else if let Some(res) = resident {
1275            // #1017 — SAE residency scalar Jacobi. The generic scalar build
1276            // probes `H_tβ^(i) e_a` and re-solves `(H_tt^(i))⁻¹` once for EVERY
1277            // (row, β-column) pair: `O(n·K)` triangular solves and `O(n·K·p)`
1278            // operator-probe work per Newton step, with `K = K_atoms·p` in the
1279            // tens of thousands at LLM shapes. The reduced-Schur diagonal is the
1280            // same quotient the resident `(L_i, Y_i)` factors already carry, so
1281            // read the diagonal straight off them in one support-sparse pass —
1282            // no probe, no per-column solve.
1283            Self::build_scalar_jacobi_resident(sys, ridge_beta, res)
1284        } else {
1285            Self::build_scalar_jacobi(sys, htt_factors, ridge_beta, backend)
1286        }
1287    }
1288
1289    /// Build scalar-diagonal Jacobi: one `BlockFactor::Scalar` of length 1
1290    /// per column.  Matches pre-#283 semantics.
1291    ///
1292    /// When `sys.htbeta_matvec` is set and per-row `htbeta` slabs are absent,
1293    /// each column is probed via the matvec (one call per column per row).
1294    pub(crate) fn build_scalar_jacobi<B: BatchedBlockSolver + Sync>(
1295        sys: &ArrowSchurSystem,
1296        htt_factors: &ArrowFactorSlab,
1297        ridge_beta: f64,
1298        backend: &B,
1299    ) -> Result<Self, ArrowSchurError> {
1300        let k = sys.k;
1301        // Extract diagonal of H_ββ via penalty_diagonal_add (#296):
1302        // no Arc-clone; falls back to hbb_diag or hbb[[a,a]] inline.
1303        let mut diag = Array1::<f64>::zeros(k);
1304        {
1305            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1306            sys.penalty_diagonal_add(diag_slice);
1307        }
1308        for a in 0..k {
1309            diag[a] += ridge_beta;
1310        }
1311        // Per-row body: subtract this row's `Σ_a (H_tβ^(i)e_a)ᵀ(H_tt^(i))⁻¹
1312        // (H_tβ^(i)e_a)` contribution into a caller-provided length-`K` diagonal
1313        // accumulator (`-=`). For each column `a`, probe the cross-block (or read
1314        // the dense slab) and compute the scalar point-elimination quotient. The
1315        // `O(K)` solves per row are the build's whole cost; the row contributions
1316        // are independent length-`K` vectors, so a worker sums a chunk into a
1317        // private `diag_part` and the caller folds the partials back in chunk
1318        // order — bit-identical run-to-run (the #1017 preconditioner gate).
1319        let row_into = |i: usize, row: &ArrowRowBlock, diag_part: &mut Array1<f64>| {
1320            let di = sys.row_dims[i];
1321            // Dense-slab fast path (#1017): when the per-row cross-block is a
1322            // materialized `di × k` slab (no matrix-free operator), the entire
1323            // reduced-Schur diagonal contribution for this row is
1324            // `Σ_c H_tβ[c,a] · ((H_tt)⁻¹ H_tβ)[c,a]`. The generic loop below
1325            // re-solved `(H_tt)⁻¹` once PER COLUMN — `O(k)` block solves + `O(k)`
1326            // allocations per row, i.e. `O(n·k)` tiny solves per Newton step
1327            // (the dominant fixed per-solve cost at the SAE wide-border shape,
1328            // k in the tens of thousands). Solve all `k` columns in ONE batched
1329            // block solve instead, then take the column dots. Reassociates the
1330            // diagonal within the documented #1211 preconditioner margin (same as
1331            // the resident no-probe path), and the preconditioner only steers the
1332            // PCG iterate, which still terminates at the PCG tolerance.
1333            if sys.htbeta_matvec.is_none() && row.htbeta.dim() == (di, k) {
1334                let solved = backend.solve_block_matrix(htt_factors.factor(i), row.htbeta.view());
1335                for a in 0..k {
1336                    let mut acc = 0.0;
1337                    for c in 0..di {
1338                        acc += row.htbeta[[c, a]] * solved[[c, a]];
1339                    }
1340                    diag_part[a] -= acc;
1341                }
1342                return;
1343            }
1344            // Matrix-free path: probe column a. `e_a` stays all-zero between
1345            // columns — set the single active entry and reset it after the probe,
1346            // so we never pay the `O(k)` `e_a.fill(0.0)` per column (that fill was
1347            // `O(n·k²)`). `sys_htbeta_apply_row` zeroes `col_i` internally.
1348            let mut col_i = Array1::<f64>::zeros(di);
1349            let mut e_a = Array1::<f64>::zeros(k);
1350            for a in 0..k {
1351                e_a[a] = 1.0;
1352                sys_htbeta_apply_row(sys, i, row, e_a.view(), &mut col_i);
1353                e_a[a] = 0.0;
1354                let solved = backend.solve_block_vector(htt_factors.factor(i), col_i.view());
1355                let mut acc = 0.0;
1356                for c in 0..di {
1357                    acc += col_i[c] * solved[c];
1358                }
1359                diag_part[a] -= acc;
1360            }
1361        };
1362        let n = sys.rows.len();
1363        let parallel =
1364            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1365        if parallel {
1366            use rayon::prelude::*;
1367            const CHUNK: usize = 64;
1368            let partials: Vec<Array1<f64>> = (0..n)
1369                .into_par_iter()
1370                .chunks(CHUNK)
1371                .map(|idxs| {
1372                    let mut diag_part = Array1::<f64>::zeros(k);
1373                    for i in idxs {
1374                        row_into(i, &sys.rows[i], &mut diag_part);
1375                    }
1376                    diag_part
1377                })
1378                .collect();
1379            // Deterministic ordered reduction: fold chunk partials left-to-right.
1380            for part in &partials {
1381                for a in 0..k {
1382                    diag[a] += part[a];
1383                }
1384            }
1385        } else {
1386            for (i, row) in sys.rows.iter().enumerate() {
1387                row_into(i, row, &mut diag);
1388            }
1389        }
1390        let mut blocks = Vec::with_capacity(k);
1391        for a in 0..k {
1392            let v = diag[a];
1393            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1394                return Err(ArrowSchurError::PcgFailed {
1395                    reason: format!(
1396                        "invalid Schur Jacobi diagonal at index {a}: {v}; \
1397                         operator regularization is required"
1398                    ),
1399                });
1400            }
1401            blocks.push(BlockFactor::Scalar {
1402                inv: Array1::from_elem(1, 1.0 / v),
1403                range: a..a + 1,
1404            });
1405        }
1406        Ok(Self { blocks })
1407    }
1408
1409    /// Build scalar-diagonal Jacobi from the pre-staged SAE residency factors
1410    /// `(L_i, Y_i)` (#1017).
1411    ///
1412    /// The generic [`Self::build_scalar_jacobi`] forms each reduced-Schur
1413    /// diagonal entry `S_aa = H_ββ,aa + ρ − Σ_i (H_tβ^(i) e_a)ᵀ(H_tt^(i))⁻¹(H_tβ^(i) e_a)`
1414    /// by probing the cross-block operator with the unit vector `e_a` and
1415    /// re-solving `(H_tt^(i))⁻¹` for every `(row, column)` pair — `O(n·K)`
1416    /// triangular solves per Newton step. For the SAE Kronecker cross-block the
1417    /// `a`-th column lives on exactly one active support entry: `a = beta_base + j`
1418    /// for some `(beta_base, φ) ∈ a_phi[i]` and output channel `j ∈ 0..p`, with
1419    /// `H_tβ^(i) e_a = φ · L_i[:, j]`. The point-elimination quotient is then
1420    ///
1421    /// ```text
1422    /// (H_tβ^(i) e_a)ᵀ (H_tt^(i))⁻¹ (H_tβ^(i) e_a)
1423    ///     = φ² · L_i[:, j]ᵀ (H_tt^(i))⁻¹ L_i[:, j]
1424    ///     = φ² · (L_i[:, j] · Y_i[:, j]),          Y_i := (H_tt^(i))⁻¹ L_i.
1425    /// ```
1426    ///
1427    /// so the whole diagonal is accumulated in ONE support-sparse pass over the
1428    /// resident factors — no probe, no per-column solve, the staged `Y_i` reused
1429    /// from the matvec residency. The result is the SAME quotient the generic
1430    /// path computes (up to float reassociation of the row sum), so the PCG
1431    /// preconditioner is unchanged up to that f64 margin. Since the preconditioner
1432    /// only steers the iterate (which still terminates at the PCG tolerance), the
1433    /// criterion ranking is stable except for candidates within that margin,
1434    /// where the near-tie winner can flip — not an exact no-move guarantee (#1211).
1435    pub(crate) fn build_scalar_jacobi_resident(
1436        sys: &ArrowSchurSystem,
1437        ridge_beta: f64,
1438        resident: &SaeResidentReducedSchur,
1439    ) -> Result<Self, ArrowSchurError> {
1440        let k = sys.k;
1441        let p = resident.p;
1442        let n = resident.rows.len();
1443        // Seed with diag(H_ββ) + ridge — same penalty source the generic path
1444        // reads, so the only difference is how the point-elimination term is
1445        // gathered.
1446        let mut diag = Array1::<f64>::zeros(k);
1447        {
1448            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1449            sys.penalty_diagonal_add(diag_slice);
1450        }
1451        for a in 0..k {
1452            diag[a] += ridge_beta;
1453        }
1454        // Per-row point-elimination diagonal: for each active support entry
1455        // `(beta_base, φ)` and channel `j`, subtract `φ² · L_i[:, j]·Y_i[:, j]`
1456        // into `diag[beta_base + j]`. `L_i`/`Y_i` are row-major `di × p`, so the
1457        // `j`-th column dot is `Σ_r L_i[r·p + j]·Y_i[r·p + j]`.
1458        //
1459        // The accumulation is into a SHARED `diag` (rows scatter into overlapping
1460        // `beta_base + j` columns), so — like the generic `build_scalar_jacobi`
1461        // and the `schur_matvec` row loop (#1017) — parallelism uses worker-private
1462        // length-`K` partials folded back in chunk order: each chunk is a
1463        // contiguous ascending row range and rows within it stay ascending, so the
1464        // chunk-ordered fold reproduces the serial `row = 0..n` subtraction order
1465        // bit-for-bit run-to-run (the #1017 determinism gate). Run-to-run
1466        // bit-identity does not extend to bit-identity with the in-place serial
1467        // accumulation, so the preconditioner — and any criterion ranking it
1468        // steers — is stable only up to the chunk-reassociation margin; a near-tie
1469        // winner inside that margin can flip (#1211).
1470        // This build runs once per inexact-PCG solve = O(inner-Newton-iters)
1471        // per fit; at the SAE LLM shape (thousands of rows, wide border `k`) the
1472        // per-row support sweep is the build's whole cost and was on one core.
1473        // The per-channel column dot `col_dot[j] = Σ_r L_i[r·p+j]·Y_i[r·p+j]`
1474        // (the diagonal of `G_i = L_iᵀ(H_tt)⁻¹L_i`) depends ONLY on the row `i`,
1475        // not on the support entry `(beta_base, φ)`. The previous loop recomputed
1476        // it once per support entry — a row with `m` active atoms paid `m·p`
1477        // column dots over `di`. Hoist it: compute the `p` column dots once per
1478        // row into reusable `col_dot` scratch, then each support entry is a pure
1479        // scatter `diag[beta_base+j] -= φ²·col_dot[j]`. Bit-for-bit identical:
1480        // each `col_dot[j]` is the same `r`-ascending sum, and `φ²·col_dot[j]`
1481        // yields identical bits whether `col_dot[j]` was just computed or cached.
1482        let row_into = |row: usize, diag_part: &mut [f64], col_dot: &mut [f64]| {
1483            let rf = &resident.rows[row];
1484            let di = rf.di;
1485            if di == 0 {
1486                return;
1487            }
1488            let support = &resident.a_phi[row];
1489            if support.is_empty() {
1490                return;
1491            }
1492            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
1493            // the former per-row `rf.l` copy.
1494            let l_i = &resident.local_jac[row];
1495            for (j, slot) in col_dot.iter_mut().enumerate().take(p) {
1496                let mut acc = 0.0_f64;
1497                for r in 0..di {
1498                    let idx = r * p + j;
1499                    acc += l_i[idx] * rf.y[idx];
1500                }
1501                *slot = acc;
1502            }
1503            for &(beta_base, phi) in support {
1504                if phi == 0.0 {
1505                    continue;
1506                }
1507                let phi2 = phi * phi;
1508                for j in 0..p {
1509                    diag_part[beta_base + j] -= phi2 * col_dot[j];
1510                }
1511            }
1512        };
1513        let parallel =
1514            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1515        if parallel {
1516            use rayon::prelude::*;
1517            const CHUNK: usize = 64;
1518            let partials: Vec<Array1<f64>> = (0..n)
1519                .into_par_iter()
1520                .chunks(CHUNK)
1521                .map(|idxs| {
1522                    let mut diag_part = Array1::<f64>::zeros(k);
1523                    let mut col_dot = vec![0.0_f64; p];
1524                    let slice = diag_part
1525                        .as_slice_mut()
1526                        .expect("diag_part must be contiguous");
1527                    for i in idxs {
1528                        row_into(i, slice, &mut col_dot);
1529                    }
1530                    diag_part
1531                })
1532                .collect();
1533            // Deterministic ordered reduction: fold chunk partials left-to-right
1534            // (each partial already holds the per-row terms subtracted, so add
1535            // them into `diag` in chunk order to mirror the serial subtraction).
1536            for part in &partials {
1537                for a in 0..k {
1538                    diag[a] += part[a];
1539                }
1540            }
1541        } else {
1542            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1543            let mut col_dot = vec![0.0_f64; p];
1544            for row in 0..n {
1545                row_into(row, diag_slice, &mut col_dot);
1546            }
1547        }
1548        let mut blocks = Vec::with_capacity(k);
1549        for a in 0..k {
1550            let v = diag[a];
1551            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1552                return Err(ArrowSchurError::PcgFailed {
1553                    reason: format!(
1554                        "invalid SAE-resident Schur Jacobi diagonal at index {a}: {v}; \
1555                         operator regularization is required"
1556                    ),
1557                });
1558            }
1559            blocks.push(BlockFactor::Scalar {
1560                inv: Array1::from_elem(1, 1.0 / v),
1561                range: a..a + 1,
1562            });
1563        }
1564        Ok(Self { blocks })
1565    }
1566
1567    /// Build block-Jacobi from the pre-staged SAE residency factors `(L_i, Y_i)`.
1568    ///
1569    /// This is the block analogue of [`Self::build_scalar_jacobi_resident`].
1570    /// When SAE block offsets are small enough to select BetaBlockJacobi (for
1571    /// example per-atom decoder blocks with `basis_size·p <= 256`), the generic
1572    /// block builder materializes every row's dense `(d_i × K)` `H_tβ` by probing
1573    /// the matrix-free operator, then re-solves `(H_tt)⁻¹` for each block column.
1574    /// The resident factors already carry `G_i = L_iᵀ(H_tt)⁻¹L_i`, so each block
1575    /// is assembled by scattering only the active support pairs inside that block:
1576    ///
1577    /// ```text
1578    /// S_block -= Σ_i Σ_(s,t in block support) φ_s φ_t · G_i[channel_s, channel_t]
1579    /// ```
1580    ///
1581    /// It computes the same block-diagonal restriction as the generic path, but
1582    /// avoids the full-row `H_tβ` materialization and per-column triangular solves.
1583    pub(crate) fn build_block_jacobi_resident(
1584        sys: &ArrowSchurSystem,
1585        ridge_beta: f64,
1586        resident: &SaeResidentReducedSchur,
1587    ) -> Result<Self, ArrowSchurError> {
1588        let block_offsets = &sys.block_offsets;
1589        let p = resident.p;
1590        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
1591        for (block_idx, range) in block_offsets.iter().enumerate() {
1592            let b = range.end - range.start;
1593            let mut schur_block = Array2::<f64>::zeros((b, b));
1594            sys.penalty_block_add(
1595                BetaBlockId(block_idx),
1596                block_offsets.as_ref(),
1597                &mut schur_block,
1598            );
1599            for bi in 0..b {
1600                schur_block[[bi, bi]] += ridge_beta;
1601            }
1602            schur_blocks.push(schur_block);
1603        }
1604
1605        let row_into = |row: usize, blocks: &mut [Array2<f64>]| {
1606            let rf = &resident.rows[row];
1607            let di = rf.di;
1608            if di == 0 {
1609                return;
1610            }
1611            let support = &resident.a_phi[row];
1612            if support.is_empty() {
1613                return;
1614            }
1615            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
1616            // the former per-row `rf.l` copy.
1617            let l_i = &resident.local_jac[row];
1618            for (block_idx, range) in block_offsets.iter().enumerate() {
1619                let block = &mut blocks[block_idx];
1620                for &(base_left, phi_left) in support {
1621                    if phi_left == 0.0 {
1622                        continue;
1623                    }
1624                    let left_start = base_left.max(range.start);
1625                    let left_end = (base_left + p).min(range.end);
1626                    if left_start >= left_end {
1627                        continue;
1628                    }
1629                    for &(base_right, phi_right) in support {
1630                        if phi_right == 0.0 {
1631                            continue;
1632                        }
1633                        let right_start = base_right.max(range.start);
1634                        let right_end = (base_right + p).min(range.end);
1635                        if right_start >= right_end {
1636                            continue;
1637                        }
1638                        let phi = phi_left * phi_right;
1639                        for gi in left_start..left_end {
1640                            let li = gi - range.start;
1641                            let ch_i = gi - base_left;
1642                            for gj in right_start..right_end {
1643                                let lj = gj - range.start;
1644                                let ch_j = gj - base_right;
1645                                let mut gij = 0.0_f64;
1646                                for r in 0..di {
1647                                    gij += l_i[r * p + ch_i] * rf.y[r * p + ch_j];
1648                                }
1649                                block[[li, lj]] -= phi * gij;
1650                            }
1651                        }
1652                    }
1653                }
1654            }
1655        };
1656
1657        let n = resident.rows.len();
1658        let parallel =
1659            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1660        if parallel {
1661            use rayon::prelude::*;
1662            const CHUNK: usize = 64;
1663            let n_blocks = block_offsets.len();
1664            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
1665            let partials: Vec<Vec<Array2<f64>>> = (0..n)
1666                .into_par_iter()
1667                .chunks(CHUNK)
1668                .map(|idxs| {
1669                    let mut local: Vec<Array2<f64>> = block_dims
1670                        .iter()
1671                        .map(|&b| Array2::<f64>::zeros((b, b)))
1672                        .collect();
1673                    for i in idxs {
1674                        row_into(i, &mut local);
1675                    }
1676                    local
1677                })
1678                .collect();
1679            for local in &partials {
1680                for bidx in 0..n_blocks {
1681                    schur_blocks[bidx] += &local[bidx];
1682                }
1683            }
1684        } else {
1685            for row in 0..n {
1686                row_into(row, &mut schur_blocks);
1687            }
1688        }
1689
1690        let mut blocks = Vec::with_capacity(block_offsets.len());
1691        for (block_idx, range) in block_offsets.iter().enumerate() {
1692            let b = range.end - range.start;
1693            let schur_block = &schur_blocks[block_idx];
1694            let factor_opt = {
1695                use faer::Side;
1696                let view = FaerArrayView::new(schur_block);
1697                FaerLlt::new(view.as_ref(), Side::Lower).ok()
1698            };
1699            if let Some(llt) = factor_opt {
1700                blocks.push(BlockFactor::Chol {
1701                    factor: llt,
1702                    range: range.clone(),
1703                });
1704            } else {
1705                let mut inv = Array1::<f64>::zeros(b);
1706                for bi in 0..b {
1707                    let v = schur_block[[bi, bi]];
1708                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1709                        return Err(ArrowSchurError::PcgFailed {
1710                            reason: format!(
1711                                "SAE-resident block Jacobi scalar fallback: non-PD diagonal at \
1712                                 global index {}: {v}; regularization required",
1713                                range.start + bi
1714                            ),
1715                        });
1716                    }
1717                    inv[bi] = 1.0 / v;
1718                }
1719                blocks.push(BlockFactor::Scalar {
1720                    inv,
1721                    range: range.clone(),
1722                });
1723            }
1724        }
1725        Ok(Self { blocks })
1726    }
1727
1728    /// Build term-block Jacobi: one dense `b×b` Schur block per term in
1729    /// `sys.block_offsets`.
1730    pub(crate) fn build_block_jacobi<B: BatchedBlockSolver + Sync>(
1731        sys: &ArrowSchurSystem,
1732        htt_factors: &ArrowFactorSlab,
1733        ridge_beta: f64,
1734        backend: &B,
1735    ) -> Result<Self, ArrowSchurError> {
1736        let block_offsets = &sys.block_offsets;
1737
1738        // Initialise every b×b Schur sub-block from H_ββ + ridge·I via
1739        // penalty_block_add (#296): routes to penalty_op or falls back to
1740        // hbb / hbb_diag inline without Arc-clone per loop iteration. These are
1741        // the block-diagonal restrictions of the reduced Schur complement; the
1742        // per-row cross-block contributions are accumulated in the row sweep
1743        // below.
1744        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
1745        for (block_idx, range) in block_offsets.iter().enumerate() {
1746            let b = range.end - range.start;
1747            let mut schur_block = Array2::<f64>::zeros((b, b));
1748            sys.penalty_block_add(
1749                BetaBlockId(block_idx),
1750                block_offsets.as_ref(),
1751                &mut schur_block,
1752            );
1753            for bi in 0..b {
1754                schur_block[[bi, bi]] += ridge_beta;
1755            }
1756            schur_blocks.push(schur_block);
1757        }
1758
1759        // Subtract Schur contributions:
1760        // S_kk -= H_βt_k^(i) (H_tt^(i))^{-1} H_tβ_k^(i)
1761        //
1762        // Materialize each row's (d_i × K) cross-block ONCE and scatter its
1763        // contribution into every block-diagonal sub-block — mirroring the
1764        // row-outer structure of `build_dense_schur_direct`. The previous
1765        // block-outer form re-materialized every row for each β-block
1766        // (O(n_blocks · n · K) probes); for the matrix-free softmax cross-block
1767        // each materialize is itself O(K²), so that nesting made the
1768        // preconditioner build quadratically more expensive than the direct
1769        // dense Schur it preconditions. sys_htbeta_materialize_row handles the
1770        // Kronecker / htbeta_matvec path transparently.
1771        // Per-row body: materialize the row's `(d_i × K)` cross-block once and
1772        // subtract its `H_βt_k^(i)(H_tt^(i))⁻¹H_tβ_k^(i)` contribution into EACH
1773        // block-diagonal sub-block. Writes INTO a caller-provided `blocks`
1774        // accumulator (`-=`) so a rayon worker can subtract a chunk's rows into
1775        // a worker-private zero-seeded `Vec<Array2>` and the caller folds the
1776        // chunk partials back in chunk order — bit-identical run-to-run
1777        // regardless of thread scheduling (the #1017 verification gate). This
1778        // is deterministic and within the chunk-reassociation margin of serial,
1779        // so the preconditioner, hence the criterion ranking, is stable except
1780        // for near-tie candidates inside that f64 margin — not an exact no-move
1781        // guarantee (#1211).
1782        let row_into = |i: usize,
1783                        row: &ArrowRowBlock,
1784                        blocks: &mut [Array2<f64>]|
1785         -> Result<(), ArrowSchurError> {
1786            let di = sys.row_dims[i];
1787            let htbeta_full = sys_htbeta_materialize_row(sys, i, row)?;
1788            for (block_idx, range) in block_offsets.iter().enumerate() {
1789                let b = range.end - range.start;
1790                let mut solved_cols = Array2::<f64>::zeros((di, b));
1791                for bj in 0..b {
1792                    let gj = range.start + bj;
1793                    let rhs = htbeta_full.column(gj).to_owned();
1794                    let solved = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
1795                    for c in 0..di {
1796                        solved_cols[[c, bj]] = solved[c];
1797                    }
1798                }
1799                let schur_block = &mut blocks[block_idx];
1800                for bi in 0..b {
1801                    let gi = range.start + bi;
1802                    for bj in 0..b {
1803                        let mut acc = 0.0;
1804                        for c in 0..di {
1805                            acc += htbeta_full[[c, gi]] * solved_cols[[c, bj]];
1806                        }
1807                        schur_block[[bi, bj]] -= acc;
1808                    }
1809                }
1810            }
1811            Ok(())
1812        };
1813        // Each row materializes an `O(K²)` cross-block (Kronecker) plus `Σ_k b_k`
1814        // triangular solves — the preconditioner build's whole per-row cost at
1815        // the SAE LLM shape (#1017), and the rows are independent. Fan over fixed
1816        // row chunks above the threshold, staying serial for the handful-of-rows
1817        // non-SAE callers and inside a rayon worker (topology-race nesting guard)
1818        // — the same gate `schur_matvec` uses.
1819        let n = sys.rows.len();
1820        let parallel =
1821            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1822        if parallel {
1823            use rayon::prelude::*;
1824            const CHUNK: usize = 64;
1825            let n_blocks = block_offsets.len();
1826            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
1827            let partials: Vec<Vec<Array2<f64>>> = (0..n)
1828                .into_par_iter()
1829                .chunks(CHUNK)
1830                .map(|idxs| {
1831                    let mut local: Vec<Array2<f64>> = block_dims
1832                        .iter()
1833                        .map(|&b| Array2::<f64>::zeros((b, b)))
1834                        .collect();
1835                    for i in idxs {
1836                        row_into(i, &sys.rows[i], &mut local)?;
1837                    }
1838                    Ok::<_, ArrowSchurError>(local)
1839                })
1840                .collect::<Result<Vec<_>, _>>()?;
1841            // Deterministic ordered reduction: fold chunk partials left-to-right.
1842            for local in &partials {
1843                for bidx in 0..n_blocks {
1844                    schur_blocks[bidx] += &local[bidx];
1845                }
1846            }
1847        } else {
1848            for (i, row) in sys.rows.iter().enumerate() {
1849                row_into(i, row, &mut schur_blocks)?;
1850            }
1851        }
1852
1853        // Factor each accumulated block: LLT, with scalar-diagonal fallback for
1854        // a block that comes out non-PD at this ridge.
1855        let mut blocks = Vec::with_capacity(block_offsets.len());
1856        for (block_idx, range) in block_offsets.iter().enumerate() {
1857            let b = range.end - range.start;
1858            let schur_block = &schur_blocks[block_idx];
1859            let factor_opt = {
1860                use faer::Side;
1861                let view = FaerArrayView::new(schur_block);
1862                FaerLlt::new(view.as_ref(), Side::Lower).ok()
1863            };
1864            if let Some(llt) = factor_opt {
1865                blocks.push(BlockFactor::Chol {
1866                    factor: llt,
1867                    range: range.clone(),
1868                });
1869            } else {
1870                // Non-PD block: fall back to scalar diagonal for this block.
1871                let mut inv = Array1::<f64>::zeros(b);
1872                for bi in 0..b {
1873                    let v = schur_block[[bi, bi]];
1874                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1875                        return Err(ArrowSchurError::PcgFailed {
1876                            reason: format!(
1877                                "block Jacobi scalar fallback: non-PD diagonal at \
1878                                 global index {}: {v}; regularization required",
1879                                range.start + bi
1880                            ),
1881                        });
1882                    }
1883                    inv[bi] = 1.0 / v;
1884                }
1885                blocks.push(BlockFactor::Scalar {
1886                    inv,
1887                    range: range.clone(),
1888                });
1889            }
1890        }
1891        Ok(Self { blocks })
1892    }
1893
1894    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
1895        let mut out = Array1::<f64>::zeros(r.len());
1896        for block in &self.blocks {
1897            match block {
1898                BlockFactor::Scalar { inv, range } => {
1899                    for (local, gi) in range.clone().enumerate() {
1900                        out[gi] = inv[local] * r[gi];
1901                    }
1902                }
1903                BlockFactor::Chol { factor, range } => {
1904                    let b = range.end - range.start;
1905                    let mut rhs = Array1::<f64>::zeros(b);
1906                    for (local, gi) in range.clone().enumerate() {
1907                        rhs[local] = r[gi];
1908                    }
1909                    use faer::linalg::solvers::Solve;
1910                    let stride = rhs.strides()[0];
1911                    let len = rhs.len();
1912                    // SAFETY: rhs is a uniquely-borrowed contiguous Array1
1913                    // with positive stride (standard layout).
1914                    let rhs_mat =
1915                        unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
1916                    let solved = factor.solve(rhs_mat);
1917                    for (local, gi) in range.clone().enumerate() {
1918                        out[gi] = solved[(local, 0)];
1919                    }
1920                }
1921            }
1922        }
1923        out
1924    }
1925}
1926
1927// ---------------------------------------------------------------------------
1928// Preconditioner ladder: SchurPreconditionerKind, ClusterJacobi,
1929// AdditiveSchwarz  (issue #299)
1930// ---------------------------------------------------------------------------
1931
/// Selects the Schur preconditioner used by the inexact-PCG path.
///
/// The variants form a ladder ordered by cost / effectiveness, cheapest
/// first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurPreconditionerKind {
    /// Scalar Jacobi (the pre-#283 behaviour).
    Diagonal,
    /// Block-Jacobi with one block per `block_offsets` term (#287).
    BetaBlockJacobi,
    /// One dense block per connected component of the beta-coupling graph.
    ClusterJacobi,
    /// Components grown by `overlap` graph hops. Columns shared by several
    /// subdomains are averaged with partition-of-unity weights, and each
    /// subdomain applies its full dense local inverse.
    AdditiveSchwarz { overlap: usize },
    /// The cheap Schwarz variant (#299). It uses the same overlapping
    /// decomposition as `AdditiveSchwarz`, but every subdomain contributes
    /// only the diagonal of its local inverse `(A_k⁻¹)_ii`. The contributions
    /// are summed with partition-of-unity weights into a single diagonal whose
    /// apply is `O(K)`.
    DiagAssembledSchwarz { overlap: usize },
    /// Level-0 incomplete Cholesky (#299; Saad, *Iterative Methods*, IC(0)).
    ///
    /// For each connected component `C` of the β-coupling graph the dense
    /// reduced-Schur block `S[C,C]` is assembled once. Its structural-nonzero
    /// pattern becomes the level-0 fill pattern, and a no-fill incomplete
    /// Cholesky `S ≈ L̃ L̃ᵀ` is formed that keeps ONLY that pattern.
    ///
    /// The apply is a sparse triangular forward/back solve over
    /// `nnz(S[C,C])`. For a large component with internal sparsity this is
    /// far cheaper to build and apply than `ClusterJacobi`'s dense Cholesky,
    /// which fills the whole `b×b` factor. It still keeps the inter-block
    /// coupling that `ClusterJacobi` keeps and the diagonal/Schwarz tiers
    /// discard.
    ///
    /// A non-PD incomplete pivot degrades that component to the scalar
    /// reciprocal diagonal.
    BlockIncompleteCholesky,
}
1965
/// Size gate for preconditioner escalation: tiers beyond `BetaBlockJacobi`
/// are tried only when K is larger than this value and PCG has used up its
/// `max_iterations`.
pub(crate) const PRECOND_ESCALATE_K_THRESHOLD: usize = 100;
1969
/// Fractional safety margin of the #1026 matrix-free Schur curvature floor
/// (the unbounded-PCG counterpart of the dense `spectral_pd_floored_schur`).
///
/// When the unbounded SAE inner PCG meets `pᵀSp ≤ 0`, the operator ridge is
/// raised by the smallest amount that restores positive curvature along the
/// offending direction, plus this fraction. The next CG iterate then lies
/// strictly inside the positive cone rather than on the `0` knife-edge.
pub(crate) const SCHUR_CURVATURE_FLOOR_MARGIN: f64 = 1e-2;
/// Smallest curvature-floor ridge bump, relative to the rhs scale. It
/// guarantees a strictly positive bump even when `pᵀSp` rounds to exactly `0`.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_FLOOR: f64 = 1e-12;
/// Largest accumulated curvature-floor ridge, relative to the rhs scale.
///
/// Past this ceiling the operator is treated as un-conditionable by a minimal
/// floor, and the recoverable failure is handed to the outer LM loop, which
/// re-forms the whole system at a heavier ridge. The value is generous so a
/// large collapsed over-subtraction `(H_tβ)²/H_tt` stays reachable.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_CEILING: f64 = 1e12;
/// Multiplicative ridge growth for the DIAGONAL-refusal escalation, where no
/// `(curvature, ‖p‖²)` deficit is available. Matches the per-row
/// `factor_one_row_result` `RIDGE_GROWTH_FACTOR`.
pub(crate) const SCHUR_CURVATURE_FLOOR_DIAG_GROWTH: f64 = 10.0;
/// Number of curvature-floor ridge-lift attempts allowed before deferring to
/// the outer LM loop.
///
/// The diagonal-refusal path grows ×10 per attempt, so the reachable ridge is
/// bounded by `rhs_scale · 10^(attempts)`: ample for any realistic
/// over-subtraction, yet still bounded.
pub(crate) const SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS: usize = 24;
1995
1996/// Cholesky or scalar factor for one cluster of the beta-coefficient graph.
1997#[derive(Clone)]
1998pub(crate) enum ClusterFactor {
1999    Chol {
2000        cols: Vec<usize>,
2001        factor: FaerLlt<f64>,
2002    },
2003    Scalar {
2004        cols: Vec<usize>,
2005        inv: Vec<f64>,
2006    },
2007}
2008
2009impl std::fmt::Debug for ClusterFactor {
2010    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2011        match self {
2012            ClusterFactor::Chol { cols, .. } => {
2013                write!(f, "ClusterFactor::Chol {{ cols.len: {} }}", cols.len())
2014            }
2015            ClusterFactor::Scalar { cols, inv } => write!(
2016                f,
2017                "ClusterFactor::Scalar {{ cols.len: {}, inv.len: {} }}",
2018                cols.len(),
2019                inv.len()
2020            ),
2021        }
2022    }
2023}
2024
/// Largest cluster (in columns) that is still factored densely; bigger
/// clusters take the scalar fallback.
pub(crate) const CLUSTER_JACOBI_MAX_CLUSTER: usize = 512;

/// Largest connected component (in columns) for which the IC(0)
/// preconditioner assembles the dense `S[C,C]` to read off its sparsity
/// pattern.
///
/// Applying IC(0) is cheap at any size, but the pattern comes from the dense
/// assembly, which needs `O(b²)` memory. Larger components fall back to the
/// scalar reciprocal diagonal. This is the same ceiling concern as
/// `CLUSTER_JACOBI_MAX_CLUSTER`, raised here because the IC(0) FACTOR itself
/// is sparse.
pub(crate) const IC0_MAX_COMPONENT: usize = 4096;

/// Relative magnitude below which an assembled `S[i,j]` counts as a
/// structural zero when the IC(0) level-0 pattern is derived.
///
/// The threshold is scaled by `sqrt(|S_ii|·|S_jj|)`, which makes it invariant
/// to column scaling. It drops entries that are pure FMA round-off (a truly
/// decoupled `(i,j)` pair assembles to ~0) so they stay out of the kept fill
/// pattern.
pub(crate) const IC0_PATTERN_REL_DROP: f64 = 1e-13;
2042
2043/// Assemble the dense `b×b` reduced-Schur block for the column set `cols`:
2044/// `S[cols, cols] = H_ββ[cols, cols] + ridge·I − Σ_i H_tβ[cols]ᵀ (H_tt^i)⁻¹ H_tβ[cols]`.
2045///
2046/// Shared by `ClusterJacobiPreconditioner::build_from_column_groups` (which
2047/// Cholesky-factors the returned block) and `DiagAssembledSchwarzPreconditioner`
2048/// (which inverts each subdomain block and keeps only its diagonal). The result
2049/// is the LOWER triangle filled by the row reduction; callers that need the full
2050/// symmetric block must `symmetrize_upper_from_lower`.
2051///
2052/// The per-row Schur contribution is fanned over fixed 64-row chunks above
2053/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` and folded left-to-right so the assembly is
2054/// bit-identical to the serial path (and run-to-run deterministic), exactly as
2055/// in `build_block_jacobi` (#1017).
2056pub(crate) fn assemble_local_schur_block<B: BatchedBlockSolver + Sync>(
2057    sys: &ArrowSchurSystem,
2058    htt_factors: &ArrowFactorSlab,
2059    ridge_beta: f64,
2060    backend: &B,
2061    cols: &[usize],
2062) -> Array2<f64> {
2063    let d = sys.d;
2064    let b = cols.len();
2065    let mut s_block = Array2::<f64>::zeros((b, b));
2066    // Initialise from H_ββ via penalty_subblock_add (#296): routes through
2067    // penalty_op or falls back to hbb / hbb_diag inline.
2068    sys.penalty_subblock_add(cols, &mut s_block);
2069    for bi in 0..b {
2070        s_block[[bi, bi]] += ridge_beta;
2071    }
2072    let cluster_row_into = |row_idx: usize, row: &ArrowRowBlock, acc: &mut Array2<f64>| {
2073        let mut col_vec = Array1::<f64>::zeros(d);
2074        let mut solved_cols = Array2::<f64>::zeros((d, b));
2075        for bj in 0..b {
2076            let gj = cols[bj];
2077            for c in 0..d {
2078                col_vec[c] = row.htbeta[[c, gj]];
2079            }
2080            let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
2081            for c in 0..d {
2082                solved_cols[[c, bj]] = solved[c];
2083            }
2084        }
2085        for bi in 0..b {
2086            let gi = cols[bi];
2087            for bj in 0..b {
2088                let mut dot = 0.0;
2089                for c in 0..d {
2090                    dot += row.htbeta[[c, gi]] * solved_cols[[c, bj]];
2091                }
2092                acc[[bi, bj]] -= dot;
2093            }
2094        }
2095    };
2096    let n = sys.rows.len();
2097    let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
2098    if parallel {
2099        use rayon::prelude::*;
2100        const CHUNK: usize = 64;
2101        let partials: Vec<Array2<f64>> = (0..n)
2102            .into_par_iter()
2103            .chunks(CHUNK)
2104            .map(|idxs| {
2105                let mut local = Array2::<f64>::zeros((b, b));
2106                for i in idxs {
2107                    cluster_row_into(i, &sys.rows[i], &mut local);
2108                }
2109                local
2110            })
2111            .collect();
2112        for local in &partials {
2113            s_block += local;
2114        }
2115    } else {
2116        for (row_idx, row) in sys.rows.iter().enumerate() {
2117            cluster_row_into(row_idx, row, &mut s_block);
2118        }
2119    }
2120    s_block
2121}
2122
2123/// Dense Schur block per connected component of the beta-coupling graph.
2124///
2125/// Nodes = beta blocks (`block_offsets`); edges = rows where two blocks
2126/// co-occur with nonzero `H_t_beta` entries. One Cholesky factor per
2127/// connected component; applied as a triangular solve.
2128#[derive(Debug, Clone)]
2129pub struct ClusterJacobiPreconditioner {
2130    pub(crate) clusters: Vec<ClusterFactor>,
2131}
2132
2133impl ClusterJacobiPreconditioner {
2134    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2135        sys: &ArrowSchurSystem,
2136        htt_factors: &ArrowFactorSlab,
2137        ridge_beta: f64,
2138        backend: &B,
2139    ) -> Result<Self, ArrowSchurError> {
2140        if sys.block_offsets.is_empty() {
2141            let cols: Vec<usize> = (0..sys.k).collect();
2142            return Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &[cols]);
2143        }
2144        let graph = BetaCouplingGraph::build(
2145            &sys.block_offsets,
2146            &sys.rows
2147                .iter()
2148                .map(|r| r.htbeta.clone())
2149                .collect::<Vec<_>>(),
2150        );
2151        let col_groups: Vec<Vec<usize>> = graph
2152            .component_partition()
2153            .iter()
2154            .map(|comp_blocks| {
2155                let mut cols: Vec<usize> = comp_blocks
2156                    .iter()
2157                    .flat_map(|&b| sys.block_offsets[b].clone())
2158                    .collect();
2159                cols.sort_unstable();
2160                cols
2161            })
2162            .collect();
2163        Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2164    }
2165
2166    pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2167        sys: &ArrowSchurSystem,
2168        htt_factors: &ArrowFactorSlab,
2169        ridge_beta: f64,
2170        backend: &B,
2171        col_groups: &[Vec<usize>],
2172    ) -> Result<Self, ArrowSchurError> {
2173        let mut clusters = Vec::with_capacity(col_groups.len());
2174        for cols in col_groups {
2175            let b = cols.len();
2176            if b == 0 {
2177                continue;
2178            }
2179            if b > CLUSTER_JACOBI_MAX_CLUSTER {
2180                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2181                clusters.push(ClusterFactor::Scalar {
2182                    cols: cols.clone(),
2183                    inv,
2184                });
2185                continue;
2186            }
2187            let mut s_block =
2188                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2189            symmetrize_upper_from_lower(&mut s_block);
2190            let factor_opt = {
2191                use faer::Side;
2192                let view = FaerArrayView::new(&s_block);
2193                FaerLlt::new(view.as_ref(), Side::Lower).ok()
2194            };
2195            if let Some(llt) = factor_opt {
2196                clusters.push(ClusterFactor::Chol {
2197                    cols: cols.clone(),
2198                    factor: llt,
2199                });
2200            } else {
2201                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2202                clusters.push(ClusterFactor::Scalar {
2203                    cols: cols.clone(),
2204                    inv,
2205                });
2206            }
2207        }
2208        Ok(Self { clusters })
2209    }
2210
2211    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2212        let mut out = Array1::<f64>::zeros(r.len());
2213        for cluster in &self.clusters {
2214            apply_cluster(cluster, r, &mut out, &ClusterApplyMode::Overwrite);
2215        }
2216        out
2217    }
2218}
2219
2220/// Additive Schwarz: base components expanded by `overlap` graph-hops;
2221/// overlapping columns averaged by partition-of-unity weights.
2222#[derive(Debug, Clone)]
2223pub struct AdditiveSchwarzPreconditioner {
2224    pub(crate) clusters: Vec<ClusterFactor>,
2225    pub(crate) weights: Vec<f64>,
2226}
2227
2228impl AdditiveSchwarzPreconditioner {
2229    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2230        sys: &ArrowSchurSystem,
2231        htt_factors: &ArrowFactorSlab,
2232        ridge_beta: f64,
2233        backend: &B,
2234        overlap: usize,
2235    ) -> Result<Self, ArrowSchurError> {
2236        if sys.block_offsets.is_empty() {
2237            let cols: Vec<usize> = (0..sys.k).collect();
2238            let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2239                sys,
2240                htt_factors,
2241                ridge_beta,
2242                backend,
2243                &[cols],
2244            )?;
2245            return Ok(Self {
2246                clusters: inner.clusters,
2247                weights: vec![1.0f64; sys.k],
2248            });
2249        }
2250        let graph = BetaCouplingGraph::build(
2251            &sys.block_offsets,
2252            &sys.rows
2253                .iter()
2254                .map(|r| r.htbeta.clone())
2255                .collect::<Vec<_>>(),
2256        );
2257        let col_groups: Vec<Vec<usize>> = graph
2258            .component_partition()
2259            .iter()
2260            .map(|seed| {
2261                let mut current = seed.clone();
2262                for _ in 0..overlap {
2263                    current = graph.expand_one_hop(&current);
2264                }
2265                let mut cols: Vec<usize> = current
2266                    .iter()
2267                    .flat_map(|&b| sys.block_offsets[b].clone())
2268                    .collect();
2269                cols.sort_unstable();
2270                cols.dedup();
2271                cols
2272            })
2273            .collect();
2274        let mut counts = vec![0u32; sys.k];
2275        for cols in &col_groups {
2276            for &gi in cols {
2277                counts[gi] += 1;
2278            }
2279        }
2280        let weights: Vec<f64> = counts
2281            .iter()
2282            .map(|&c| if c == 0 { 1.0 } else { 1.0 / c as f64 })
2283            .collect();
2284        let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2285            sys,
2286            htt_factors,
2287            ridge_beta,
2288            backend,
2289            &col_groups,
2290        )?;
2291        Ok(Self {
2292            clusters: inner.clusters,
2293            weights,
2294        })
2295    }
2296
2297    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2298        let mut out = Array1::<f64>::zeros(r.len());
2299        for cluster in &self.clusters {
2300            apply_cluster(
2301                cluster,
2302                r,
2303                &mut out,
2304                &ClusterApplyMode::Accumulate {
2305                    weights: &self.weights,
2306                },
2307            );
2308        }
2309        out
2310    }
2311}
2312
2313/// Diagonal-assembled additive Schwarz (#299).
2314///
2315/// The cheap Schwarz variant the domain-decomposition literature recommends as
2316/// the default for sparse-coupling β-graphs: instead of storing and applying a
2317/// dense Cholesky factor per overlapping subdomain (as
2318/// [`AdditiveSchwarzPreconditioner`] does), it inverts each overlapping
2319/// subdomain Schur block ONCE at build time and keeps only the **diagonal of the
2320/// local inverse** `(A_k⁻¹)_ii`. Those per-subdomain diagonal contributions are
2321/// then assembled additively across overlapping subdomains with partition-of-
2322/// unity weights into a single global diagonal `m`, applied as `out[i] = m[i]·r[i]`.
2323///
2324/// This is strictly richer than scalar Jacobi (`1/S_ii`): the local inverse
2325/// diagonal `(A_k⁻¹)_ii` folds in the off-diagonal coupling WITHIN the subdomain,
2326/// so a strongly-coupled column gets a smaller (better-damped) effective scale
2327/// than its bare reciprocal diagonal would give — while the apply stays `O(K)`
2328/// (one multiply per column), unlike the `O(Σ b_k²)` triangular solves of dense
2329/// Schwarz. For `overlap = 0` and one column per subdomain it reduces exactly to
2330/// scalar Jacobi.
2331#[derive(Debug, Clone)]
2332pub struct DiagAssembledSchwarzPreconditioner {
2333    /// Global per-column multiplier `m[i]`; `out[i] = m[i] · r[i]`.
2334    pub(crate) inv_diag: Vec<f64>,
2335}
2336
2337impl DiagAssembledSchwarzPreconditioner {
2338    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2339        sys: &ArrowSchurSystem,
2340        htt_factors: &ArrowFactorSlab,
2341        ridge_beta: f64,
2342        backend: &B,
2343        overlap: usize,
2344    ) -> Result<Self, ArrowSchurError> {
2345        // Build the overlapping subdomain column groups exactly like
2346        // AdditiveSchwarz (component partition + `overlap` graph-hop expansion),
2347        // so the two Schwarz variants decompose the β space identically and
2348        // differ only in how each subdomain's local inverse is applied.
2349        let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2350            vec![(0..sys.k).collect()]
2351        } else {
2352            let graph = BetaCouplingGraph::build(
2353                &sys.block_offsets,
2354                &sys.rows
2355                    .iter()
2356                    .map(|r| r.htbeta.clone())
2357                    .collect::<Vec<_>>(),
2358            );
2359            graph
2360                .component_partition()
2361                .iter()
2362                .map(|seed| {
2363                    let mut current = seed.clone();
2364                    for _ in 0..overlap {
2365                        current = graph.expand_one_hop(&current);
2366                    }
2367                    let mut cols: Vec<usize> = current
2368                        .iter()
2369                        .flat_map(|&b| sys.block_offsets[b].clone())
2370                        .collect();
2371                    cols.sort_unstable();
2372                    cols.dedup();
2373                    cols
2374                })
2375                .collect()
2376        };
2377        Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2378    }
2379
2380    pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2381        sys: &ArrowSchurSystem,
2382        htt_factors: &ArrowFactorSlab,
2383        ridge_beta: f64,
2384        backend: &B,
2385        col_groups: &[Vec<usize>],
2386    ) -> Result<Self, ArrowSchurError> {
2387        // Partition-of-unity weights: a column shared by `c` subdomains gets each
2388        // of its `c` diagonal contributions scaled by `1/c`, so the assembled
2389        // diagonal is a convex combination (and reduces to a single contribution
2390        // for non-overlapping columns).
2391        let mut counts = vec![0u32; sys.k];
2392        for cols in col_groups {
2393            for &gi in cols {
2394                counts[gi] += 1;
2395            }
2396        }
2397        let mut accum = vec![0.0f64; sys.k];
2398        for cols in col_groups {
2399            let b = cols.len();
2400            if b == 0 {
2401                continue;
2402            }
2403            // For large subdomains, the dense inverse is too costly; fall back to
2404            // the global scalar Schur diagonal inverse `1/S_ii` for those columns
2405            // (the diag-assembled variant then coincides with scalar Jacobi over
2406            // that subdomain, which is exactly the intended cheap degradation).
2407            if b > CLUSTER_JACOBI_MAX_CLUSTER {
2408                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2409                for (local, &gi) in cols.iter().enumerate() {
2410                    let w = if counts[gi] == 0 {
2411                        1.0
2412                    } else {
2413                        1.0 / counts[gi] as f64
2414                    };
2415                    accum[gi] += w * inv[local];
2416                }
2417                continue;
2418            }
2419            let mut s_block =
2420                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2421            symmetrize_upper_from_lower(&mut s_block);
2422            // Diagonal of the local inverse `(A_k⁻¹)_ii`, obtained by solving
2423            // `A_k X = I` through the same faer Cholesky used elsewhere; on a
2424            // non-PD local block, degrade to the scalar reciprocal diagonal.
2425            let local_inv_diag = match local_inverse_diagonal(&s_block) {
2426                Some(diag) => diag,
2427                None => {
2428                    let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2429                    inv
2430                }
2431            };
2432            for (local, &gi) in cols.iter().enumerate() {
2433                let w = if counts[gi] == 0 {
2434                    1.0
2435                } else {
2436                    1.0 / counts[gi] as f64
2437                };
2438                accum[gi] += w * local_inv_diag[local];
2439            }
2440        }
2441        // A column never covered by any subdomain (only possible for `k` columns
2442        // with no block_offsets coverage) keeps a neutral unit scale.
2443        for (gi, &c) in counts.iter().enumerate() {
2444            if c == 0 {
2445                accum[gi] = 1.0;
2446            }
2447        }
2448        for (gi, m) in accum.iter().enumerate() {
2449            if !m.is_finite() || *m <= 0.0 {
2450                return Err(ArrowSchurError::PcgFailed {
2451                    reason: format!(
2452                        "diag-assembled Schwarz: non-positive assembled diagonal at index {gi}: {m}"
2453                    ),
2454                });
2455            }
2456        }
2457        Ok(Self { inv_diag: accum })
2458    }
2459
2460    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2461        let mut out = Array1::<f64>::zeros(r.len());
2462        for (gi, &m) in self.inv_diag.iter().enumerate() {
2463            out[gi] = m * r[gi];
2464        }
2465        out
2466    }
2467}
2468
2469/// Diagonal of `A⁻¹` for a small dense SPD block `A`, via the same faer
2470/// Cholesky used by the cluster/Schwarz factors. Returns `None` if `A` is not
2471/// positive-definite (caller degrades to the scalar reciprocal diagonal).
2472pub(crate) fn local_inverse_diagonal(a: &Array2<f64>) -> Option<Vec<f64>> {
2473    let b = a.nrows();
2474    let llt = {
2475        use faer::Side;
2476        let view = FaerArrayView::new(a);
2477        FaerLlt::new(view.as_ref(), Side::Lower).ok()?
2478    };
2479    use faer::linalg::solvers::Solve;
2480    let mut diag = Vec::with_capacity(b);
2481    for col in 0..b {
2482        // Solve `A x = e_col`; the `col`-th entry of `x` is `(A⁻¹)_{col,col}`.
2483        let mut rhs = Array1::<f64>::zeros(b);
2484        rhs[col] = 1.0;
2485        let stride = rhs.strides()[0];
2486        let len = rhs.len();
2487        // SAFETY: `rhs` is a uniquely-borrowed contiguous `Array1<f64>` of `len`
2488        // elements with positive row stride; a single column never dereferences
2489        // the column stride, so `0` is sound.
2490        let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2491        let solved = llt.solve(rhs_mat);
2492        diag.push(solved[(col, 0)]);
2493    }
2494    Some(diag)
2495}
2496
2497/// How a cluster factor's contribution is written into the output vector.
2498///
2499/// `Overwrite` assigns `out[gi] = value` (non-overlapping clusters, each global
2500/// column touched by exactly one cluster). `Accumulate` adds the partition-of-unity
2501/// weighted contribution `out[gi] += weights[gi] * value` (overlapping Schwarz
2502/// clusters, where a column may belong to several clusters).
2503pub(crate) enum ClusterApplyMode<'w> {
2504    Overwrite,
2505    Accumulate { weights: &'w [f64] },
2506}
2507
2508impl ClusterApplyMode<'_> {
2509    #[inline]
2510    pub(crate) fn write(&self, out: &mut Array1<f64>, gi: usize, value: f64) {
2511        match self {
2512            ClusterApplyMode::Overwrite => out[gi] = value,
2513            ClusterApplyMode::Accumulate { weights } => out[gi] += weights[gi] * value,
2514        }
2515    }
2516}
2517
2518/// Apply a single cluster factor to the residual `r`, writing into `out`
2519/// according to `mode` (overwrite for non-overlapping clusters, weighted
2520/// accumulate for overlapping Schwarz clusters).
2521pub(crate) fn apply_cluster(
2522    cluster: &ClusterFactor,
2523    r: &Array1<f64>,
2524    out: &mut Array1<f64>,
2525    mode: &ClusterApplyMode<'_>,
2526) {
2527    match cluster {
2528        ClusterFactor::Scalar { cols, inv } => {
2529            for (local, &gi) in cols.iter().enumerate() {
2530                mode.write(out, gi, inv[local] * r[gi]);
2531            }
2532        }
2533        ClusterFactor::Chol { cols, factor } => {
2534            let b = cols.len();
2535            let mut rhs = Array1::<f64>::zeros(b);
2536            for (local, &gi) in cols.iter().enumerate() {
2537                rhs[local] = r[gi];
2538            }
2539            use faer::linalg::solvers::Solve;
2540            let stride = rhs.strides()[0];
2541            let len = rhs.len();
2542            // SAFETY: rhs is uniquely-borrowed contiguous Array1 with positive stride.
2543            let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2544            let solved = factor.solve(rhs_mat);
2545            for (local, &gi) in cols.iter().enumerate() {
2546                mode.write(out, gi, solved[(local, 0)]);
2547            }
2548        }
2549    }
2550}
2551
/// Factor of one connected component in the block IC(0) preconditioner.
///
/// In both variants `cols` maps a local index back to its global β column.
#[derive(Clone)]
pub(crate) enum Ic0Factor {
    /// Sparse lower-triangular `L̃` in column-compressed form over the
    /// component's local indices. The entries of column `j` are at positions
    /// `col_ptr[j]..col_ptr[j+1]` of `(row_idx, val)`, with rows `>= j` and
    /// the diagonal entry first.
    IncompleteChol {
        cols: Vec<usize>,
        col_ptr: Vec<usize>,
        row_idx: Vec<usize>,
        val: Vec<f64>,
    },
    /// Degradation for a non-PD or oversized component; same meaning as
    /// [`ClusterFactor::Scalar`].
    Scalar { cols: Vec<usize>, inv: Vec<f64> },
}

/// Compact `Debug` output: the column count, plus the number of stored
/// values for the sparse variant. The values themselves are not printed.
impl std::fmt::Debug for Ic0Factor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Scalar { cols, .. } => {
                let n_cols = cols.len();
                write!(f, "Ic0Factor::Scalar {{ cols.len: {} }}", n_cols)
            }
            Self::IncompleteChol { cols, val, .. } => {
                let (n_cols, nnz) = (cols.len(), val.len());
                write!(
                    f,
                    "Ic0Factor::IncompleteChol {{ cols.len: {}, nnz: {} }}",
                    n_cols, nnz
                )
            }
        }
    }
}
2588
2589/// Level-0 incomplete-Cholesky Schur preconditioner (#299).
2590///
2591/// One sparse incomplete-Cholesky factor per connected component of the
2592/// β-coupling graph. Within a component the dense `S[C,C]` is assembled, its
2593/// structural-nonzero pattern `P = { (i,j) : |S_ij| > drop·sqrt(S_ii S_jj) }`
2594/// is taken as the level-0 fill set, and the no-fill incomplete Cholesky
2595/// `S ≈ L̃ L̃ᵀ` is formed keeping only `P` (drop any update landing outside it).
2596/// See [`SchurPreconditionerKind::BlockIncompleteCholesky`].
2597#[derive(Debug, Clone)]
2598pub struct BlockIncompleteCholeskyPreconditioner {
2599    pub(crate) components: Vec<Ic0Factor>,
2600}
2601
2602impl BlockIncompleteCholeskyPreconditioner {
2603    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2604        sys: &ArrowSchurSystem,
2605        htt_factors: &ArrowFactorSlab,
2606        ridge_beta: f64,
2607        backend: &B,
2608    ) -> Result<Self, ArrowSchurError> {
2609        // Column grouping mirrors ClusterJacobi: one group per connected
2610        // component of the β-coupling graph (whole-K single group when no
2611        // block_offsets are registered), so IC(0) preconditions exactly the
2612        // coupling ClusterJacobi keeps, but with a sparse (no-fill) factor.
2613        let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2614            vec![(0..sys.k).collect()]
2615        } else {
2616            let graph = BetaCouplingGraph::build(
2617                &sys.block_offsets,
2618                &sys.rows
2619                    .iter()
2620                    .map(|r| r.htbeta.clone())
2621                    .collect::<Vec<_>>(),
2622            );
2623            graph
2624                .component_partition()
2625                .iter()
2626                .map(|comp| {
2627                    let mut cols: Vec<usize> = comp
2628                        .iter()
2629                        .flat_map(|&blk| sys.block_offsets[blk].clone())
2630                        .collect();
2631                    cols.sort_unstable();
2632                    cols.dedup();
2633                    cols
2634                })
2635                .collect()
2636        };
2637
2638        let mut components = Vec::with_capacity(col_groups.len());
2639        for cols in &col_groups {
2640            let b = cols.len();
2641            if b == 0 {
2642                continue;
2643            }
2644            if b > IC0_MAX_COMPONENT {
2645                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2646                components.push(Ic0Factor::Scalar {
2647                    cols: cols.clone(),
2648                    inv,
2649                });
2650                continue;
2651            }
2652            let mut s_block =
2653                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2654            symmetrize_upper_from_lower(&mut s_block);
2655            match incomplete_cholesky_level0(&s_block) {
2656                Some((col_ptr, row_idx, val)) => components.push(Ic0Factor::IncompleteChol {
2657                    cols: cols.clone(),
2658                    col_ptr,
2659                    row_idx,
2660                    val,
2661                }),
2662                None => {
2663                    // Non-PD incomplete pivot: degrade this component to the
2664                    // scalar reciprocal diagonal (mirrors the ClusterJacobi
2665                    // non-PD fallback), which is always applicable for a
2666                    // PD-floored Schur diagonal.
2667                    let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2668                    components.push(Ic0Factor::Scalar {
2669                        cols: cols.clone(),
2670                        inv,
2671                    });
2672                }
2673            }
2674        }
2675        Ok(Self { components })
2676    }
2677
2678    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2679        let mut out = Array1::<f64>::zeros(r.len());
2680        for comp in &self.components {
2681            match comp {
2682                Ic0Factor::Scalar { cols, inv } => {
2683                    for (local, &gi) in cols.iter().enumerate() {
2684                        out[gi] = inv[local] * r[gi];
2685                    }
2686                }
2687                Ic0Factor::IncompleteChol {
2688                    cols,
2689                    col_ptr,
2690                    row_idx,
2691                    val,
2692                } => {
2693                    let b = cols.len();
2694                    // Gather the local residual, solve `L̃ L̃ᵀ z = r_local` by a
2695                    // sparse forward solve (`L̃ y = r`) then a sparse back solve
2696                    // (`L̃ᵀ z = y`), then scatter `z` back to global columns.
2697                    let mut z = vec![0.0f64; b];
2698                    for (local, &gi) in cols.iter().enumerate() {
2699                        z[local] = r[gi];
2700                    }
2701                    // Forward solve `L̃ y = r` (overwrite z with y). Column-major
2702                    // CSC: row_idx[col_ptr[j]] == j (diagonal stored first).
2703                    for j in 0..b {
2704                        let dstart = col_ptr[j];
2705                        let diag = val[dstart];
2706                        z[j] /= diag;
2707                        let yj = z[j];
2708                        for k in (dstart + 1)..col_ptr[j + 1] {
2709                            z[row_idx[k]] -= val[k] * yj;
2710                        }
2711                    }
2712                    // Back solve `L̃ᵀ z = y` (overwrite z). Walk columns in
2713                    // reverse; the below-diagonal entries of column j are the
2714                    // off-diagonal entries of row j of L̃ᵀ.
2715                    for j in (0..b).rev() {
2716                        let dstart = col_ptr[j];
2717                        let mut acc = z[j];
2718                        for k in (dstart + 1)..col_ptr[j + 1] {
2719                            acc -= val[k] * z[row_idx[k]];
2720                        }
2721                        z[j] = acc / val[dstart];
2722                    }
2723                    for (local, &gi) in cols.iter().enumerate() {
2724                        out[gi] = z[local];
2725                    }
2726                }
2727            }
2728        }
2729        out
2730    }
2731}
2732
2733/// Level-0 incomplete Cholesky of a dense SPD-ish block `a` (`b×b`, symmetric).
2734///
2735/// Returns the lower factor `L̃` in column-compressed (CSC) form
2736/// `(col_ptr, row_idx, val)` where each column lists its diagonal entry FIRST
2737/// followed by the strictly-below-diagonal entries, in increasing row order.
2738/// The kept pattern is the level-0 set `P` = structural nonzeros of `a` (a
2739/// relative drop threshold prunes round-off). IC(0) computes the standard
2740/// Cholesky recurrence but DROPS any value at a position outside `P`, so the
2741/// factor has exactly `nnz(tril(P))` entries — no fill. Returns `None` on a
2742/// non-positive pivot (caller degrades to scalar diagonal).
2743///
2744/// Reference: Y. Saad, *Iterative Methods for Sparse Linear Systems*, 2nd ed.,
2745/// §10.3.2 (IC(0)). This is the left-looking, pattern-restricted variant.
2746pub(crate) fn incomplete_cholesky_level0(
2747    a: &Array2<f64>,
2748) -> Option<(Vec<usize>, Vec<usize>, Vec<f64>)> {
2749    let b = a.nrows();
2750    assert_eq!(a.ncols(), b, "incomplete Cholesky needs a square block");
2751
2752    // ---- derive the level-0 lower-triangular pattern from `a` --------------
2753    // Per column j, the kept below-or-on-diagonal rows i>=j with a structurally
2754    // nonzero a[i,j]. The diagonal is always kept.
2755    let mut col_ptr = vec![0usize; b + 1];
2756    let mut row_idx: Vec<usize> = Vec::new();
2757    // value buffer, parallel to row_idx, initialised from tril(a) on the pattern
2758    let mut val: Vec<f64> = Vec::new();
2759    // For O(1) "is (i,j) in pattern + where" lookups during the recurrence, keep
2760    // a per-column map from global row -> position in that column's value slice.
2761    let mut col_pos: Vec<std::collections::HashMap<usize, usize>> = Vec::with_capacity(b);
2762    for j in 0..b {
2763        let ajj = a[[j, j]];
2764        let scale_j = ajj.abs().max(0.0).sqrt();
2765        let mut map = std::collections::HashMap::new();
2766        // diagonal first
2767        map.insert(j, val.len());
2768        row_idx.push(j);
2769        val.push(ajj);
2770        for i in (j + 1)..b {
2771            let aij = a[[i, j]];
2772            let scale_i = a[[i, i]].abs().sqrt();
2773            let thresh = IC0_PATTERN_REL_DROP * scale_i * scale_j;
2774            if aij.abs() > thresh {
2775                map.insert(i, val.len());
2776                row_idx.push(i);
2777                val.push(aij);
2778            }
2779        }
2780        col_pos.push(map);
2781        col_ptr[j + 1] = val.len();
2782    }
2783
2784    // ---- IC(0) recurrence, left-looking over columns -----------------------
2785    // For column j: subtract the contributions of all prior columns k<j that
2786    // have BOTH a nonzero at row j (so they touch the diagonal/the column) — the
2787    // multiplier L[j,k] — and a nonzero at the rows i of column j's pattern.
2788    // Any update whose target (i,j) is OUTSIDE the kept pattern is dropped.
2789    for j in 0..b {
2790        // Diagonal: a[j,j] - Σ_{k<j} L[j,k]². Each prior column k<j contributes
2791        // its row-j entry L[j,k] (looked up by row, so the column index is not
2792        // needed); columns without a row-j entry contribute nothing.
2793        let dpos = col_ptr[j];
2794        let mut diag = val[dpos];
2795        for mapk in &col_pos[..j] {
2796            if let Some(&pjk) = mapk.get(&j) {
2797                let ljk = val[pjk];
2798                diag -= ljk * ljk;
2799            }
2800        }
2801        if !diag.is_finite() || diag <= JACOBI_DIAGONAL_PD_FLOOR {
2802            return None;
2803        }
2804        let ljj = diag.sqrt();
2805        val[dpos] = ljj;
2806        // Below-diagonal of column j: L[i,j] = (a[i,j] - Σ_{k<j} L[i,k] L[j,k]) / L[j,j]
2807        for p in (dpos + 1)..col_ptr[j + 1] {
2808            let i = row_idx[p];
2809            let mut s = val[p];
2810            for mapk in &col_pos[..j] {
2811                if let (Some(&pik), Some(&pjk)) = (mapk.get(&i), mapk.get(&j)) {
2812                    s -= val[pik] * val[pjk];
2813                }
2814            }
2815            val[p] = s / ljj;
2816        }
2817    }
2818    Some((col_ptr, row_idx, val))
2819}
2820
/// Measurement for a single preconditioner tier in the #299
/// preconditioner-ladder iteration study: how many PCG iterations the tier
/// needed and how the run ended.
#[derive(Clone, Copy, Debug)]
pub struct PrecondLadderRow {
    /// Number of PCG iterations the run reported (up to the `MaxIter` cutoff).
    pub iterations: usize,
    /// `true` only when the PCG stop reason was `Converged`; any other stop
    /// reason (e.g. `MaxIter`) is reported as `false`.
    pub converged: bool,
    /// Final relative residual reported by the PCG diagnostics.
    pub final_relative_residual: f64,
}
2832
2833/// Full #299 ladder iteration study on one reduced-Schur system: run the SAME
2834/// preconditioned CG (same `rhs`, tolerances, trust radius) once per ladder tier
2835/// and report the iteration count of each. This is the public seam the
2836/// `tests/owed_299.rs` iteration-reduction gate drives — it keeps the internal
2837/// `run_pcg_with_preconditioner` / preconditioner constructors `pub(crate)`
2838/// while exposing exactly the per-tier measurement the issue asks for.
2839///
2840/// Tiers (in escalation order): scalar `Diagonal`, `BetaBlockJacobi`,
2841/// `ClusterJacobi`, `AdditiveSchwarz{overlap:1}`, `DiagAssembledSchwarz{1}`, and
2842/// `BlockIncompleteCholesky`. A tier whose build fails (e.g. non-PD reduced
2843/// Schur with no curvature floor) reports `None` for that entry; every healthy
2844/// SPD reduced system populates all six.
2845pub fn arrow_precond_ladder_iteration_study(
2846    sys: &ArrowSchurSystem,
2847    ridge_beta: f64,
2848    rhs: &Array1<f64>,
2849    pcg: &ArrowPcgOptions,
2850    trust: &ArrowTrustRegionOptions,
2851) -> Result<Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)>, ArrowSchurError> {
2852    let backend = CpuBatchedBlockSolver;
2853    let htt_factors = backend.factor_blocks(&sys.rows, 0.0, sys.d, false)?;
2854
2855    let run = |apply: &dyn Fn(&Array1<f64>) -> Array1<f64>| -> Option<PrecondLadderRow> {
2856        let (_sol, diag) = run_pcg_with_preconditioner(
2857            sys,
2858            &htt_factors,
2859            ridge_beta,
2860            rhs,
2861            |r| apply(r),
2862            pcg,
2863            trust,
2864            &backend,
2865            None,
2866            None,
2867            None,
2868        )
2869        .ok()?;
2870        Some(PrecondLadderRow {
2871            iterations: diag.iterations,
2872            converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
2873            final_relative_residual: diag.final_relative_residual,
2874        })
2875    };
2876
2877    let mut out: Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)> = Vec::with_capacity(6);
2878
2879    // Scalar Diagonal Jacobi: force the scalar path by clearing block_offsets on
2880    // a clone so the build does not pick up the per-block dense Schur blocks.
2881    let diag_row = {
2882        let mut bare = sys.clone();
2883        bare.set_block_offsets(std::sync::Arc::from([] as [Range<usize>; 0]));
2884        let bare_factors = backend.factor_blocks(&bare.rows, 0.0, bare.d, false)?;
2885        JacobiPreconditioner::from_arrow_schur(&bare, &bare_factors, ridge_beta, &backend, None)
2886            .ok()
2887            .and_then(|p| {
2888                run_pcg_with_preconditioner(
2889                    &bare,
2890                    &bare_factors,
2891                    ridge_beta,
2892                    rhs,
2893                    |r| p.apply(r),
2894                    pcg,
2895                    trust,
2896                    &backend,
2897                    None,
2898                    None,
2899                    None,
2900                )
2901                .ok()
2902                .map(|(_s, diag)| PrecondLadderRow {
2903                    iterations: diag.iterations,
2904                    converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
2905                    final_relative_residual: diag.final_relative_residual,
2906                })
2907            })
2908    };
2909    out.push((SchurPreconditionerKind::Diagonal, diag_row));
2910
2911    let block_row =
2912        JacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, None)
2913            .ok()
2914            .and_then(|p| run(&|r| p.apply(r)));
2915    out.push((SchurPreconditionerKind::BetaBlockJacobi, block_row));
2916
2917    let cluster_row =
2918        ClusterJacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend)
2919            .ok()
2920            .and_then(|p| run(&|r| p.apply(r)));
2921    out.push((SchurPreconditionerKind::ClusterJacobi, cluster_row));
2922
2923    let schwarz_row =
2924        AdditiveSchwarzPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, 1)
2925            .ok()
2926            .and_then(|p| run(&|r| p.apply(r)));
2927    out.push((
2928        SchurPreconditionerKind::AdditiveSchwarz { overlap: 1 },
2929        schwarz_row,
2930    ));
2931
2932    let diag_schwarz_row = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
2933        sys,
2934        &htt_factors,
2935        ridge_beta,
2936        &backend,
2937        1,
2938    )
2939    .ok()
2940    .and_then(|p| run(&|r| p.apply(r)));
2941    out.push((
2942        SchurPreconditionerKind::DiagAssembledSchwarz { overlap: 1 },
2943        diag_schwarz_row,
2944    ));
2945
2946    let ic0_row = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
2947        sys,
2948        &htt_factors,
2949        ridge_beta,
2950        &backend,
2951    )
2952    .ok()
2953    .and_then(|p| run(&|r| p.apply(r)));
2954    out.push((SchurPreconditionerKind::BlockIncompleteCholesky, ic0_row));
2955
2956    Ok(out)
2957}
2958
2959/// Build scalar diagonal inverses for a set of global column indices.
2960///
2961/// Used when a cluster is non-PD or exceeds `CLUSTER_JACOBI_MAX_CLUSTER`.
2962pub(crate) fn build_schur_scalar_inv<B: BatchedBlockSolver>(
2963    sys: &ArrowSchurSystem,
2964    htt_factors: &ArrowFactorSlab,
2965    ridge_beta: f64,
2966    backend: &B,
2967    cols: &[usize],
2968) -> Result<Vec<f64>, ArrowSchurError> {
2969    let d = sys.d;
2970    let mut result = Vec::with_capacity(cols.len());
2971    let mut col_vec = Array1::<f64>::zeros(d);
2972    // Extract the penalty diagonal for all K columns once, then index per-column.
2973    let mut full_diag = Array1::<f64>::zeros(sys.k);
2974    {
2975        let diag_slice = full_diag.as_slice_mut().expect("full_diag contiguous");
2976        sys.penalty_diagonal_add(diag_slice);
2977    }
2978    for &gi in cols {
2979        let mut s = full_diag[gi] + ridge_beta;
2980        for (row_idx, row) in sys.rows.iter().enumerate() {
2981            for c in 0..d {
2982                col_vec[c] = row.htbeta[[c, gi]];
2983            }
2984            let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
2985            let mut acc = 0.0;
2986            for c in 0..d {
2987                acc += col_vec[c] * solved[c];
2988            }
2989            s -= acc;
2990        }
2991        if !s.is_finite() || s <= JACOBI_DIAGONAL_PD_FLOOR {
2992            return Err(ArrowSchurError::PcgFailed {
2993                reason: format!(
2994                    "cluster Schur scalar fallback: non-PD diagonal at index {gi}: {s}"
2995                ),
2996            });
2997        }
2998        result.push(1.0 / s);
2999    }
3000    Ok(result)
3001}
3002
3003/// Inexact PCG with automatic preconditioner-ladder escalation.
3004///
3005/// Starts with `JacobiPreconditioner` (Diagonal or BetaBlockJacobi).
3006/// If PCG hits `MaxIter` and `k > PRECOND_ESCALATE_K_THRESHOLD`,
3007/// escalates to `ClusterJacobi`; if still `MaxIter`, escalates to
3008/// `AdditiveSchwarz { overlap: 1 }`.
3009pub(crate) fn steihaug_pcg_auto<B: BatchedBlockSolver + Sync>(
3010    sys: &ArrowSchurSystem,
3011    htt_factors: &ArrowFactorSlab,
3012    ridge_beta: f64,
3013    rhs: &Array1<f64>,
3014    pcg: &ArrowPcgOptions,
3015    trust: &ArrowTrustRegionOptions,
3016    backend: &B,
3017    gpu_matvec: Option<&GpuSchurMatvec>,
3018    metric_weights: Option<&MetricWeights>,
3019    curvature_floor: Option<f64>,
3020) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3021    // #1017 CPU residency: stage the per-row reduced-Schur factors `(L_i, Y_i)`
3022    // (NOT the dense `p×p` block — `di ≪ p`, so the factored form is `O(n·di·p)`
3023    // memory and `2·support_i·p + 2·di·p` flops/row including the sparse
3024    // gather/scatter over the active support) once, up
3025    // front, when the SAE structure is installed and the matvec runs on host
3026    // (CPU). The GPU matvec carries its own residency, so skip when it is engaged.
3027    // The same staged operator is reused across the whole preconditioner ladder
3028    // (Jacobi → ClusterJacobi → AdditiveSchwarz) — built once, not per tier.
3029    let resident = if gpu_matvec.is_none() {
3030        SaeResidentReducedSchur::build(sys, htt_factors, backend)
3031    } else {
3032        None
3033    };
3034    // #1026 — curvature-floor retry on the Jacobi tier. The unbounded SAE inner
3035    // PCG (trust radius = ∞) fails on `pᵀSp ≤ 0` when the reduced Schur is
3036    // indefinite (K≥4 co-collapse: a near-singular per-row `H_tt` over-subtracts
3037    // `S`). Instead of letting that failure propagate to the outer LM loop —
3038    // which inflates `ridge_β` over EVERY β direction and makes the inner Newton
3039    // crawl — floor the OPERATOR by the minimal ridge `δ = |pᵀSp|/‖p‖² · (1+ε)`
3040    // that restores positive curvature along the offending direction, rebuild the
3041    // Jacobi preconditioner at the lifted ridge, and retry. This is the
3042    // matrix-free analogue of the dense `spectral_pd_floored_schur`: the healthy
3043    // β subspace (where curvature is already positive) is essentially untouched
3044    // by a tiny `δ`, while the collapsed direction gets exactly the stiffness it
3045    // needs to make a real descent step. A PD reduced Schur never hits `pᵀSp ≤ 0`,
3046    // so this loop is a strict no-op there (bit-for-bit unchanged). Bounded by a
3047    // small attempt cap and a relative ridge ceiling; on exhaustion the original
3048    // recoverable failure still reaches the outer LM loop.
3049    let mut effective_ridge = ridge_beta;
3050    let mut x0_diag0: Option<(Array1<f64>, PcgDiagnostics)> = None;
3051    let mut last_curvature_err: Option<ArrowSchurError> = None;
3052    let rhs_scale = metric_norm(rhs.view(), metric_weights).max(1.0);
3053    let ridge_ceiling = ridge_beta.max(SCHUR_CURVATURE_FLOOR_REL_CEILING * rhs_scale);
3054    for _attempt in 0..=SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS {
3055        // The Jacobi preconditioner build itself refuses a non-PD Schur diagonal
3056        // (`PcgFailed: invalid Schur Jacobi diagonal`) — the SAME co-collapse
3057        // signature reached BEFORE the CG loop, since `S_ii = H_ββ,ii − Σ …` goes
3058        // negative. Treat that build failure as a curvature deficit too: when the
3059        // floor is enabled, lift the ridge and retry; otherwise propagate.
3060        let jacobi = match JacobiPreconditioner::from_arrow_schur(
3061            sys,
3062            htt_factors,
3063            effective_ridge,
3064            backend,
3065            resident.as_ref(),
3066        ) {
3067            Ok(jacobi) => jacobi,
3068            Err(err @ ArrowSchurError::PcgFailed { .. }) => {
3069                if curvature_floor.is_none() {
3070                    return Err(err);
3071                }
3072                // A diagonal refusal carries no `(curvature, ‖p‖²)` deficit, and
3073                // the over-subtraction magnitude `Σ H_tβᵀ(H_tt)⁻¹H_tβ` is
3074                // unbounded relative to `rhs_scale`, so a small additive bump
3075                // would crawl. Escalate the ridge MULTIPLICATIVELY (×10, matching
3076                // the per-row `factor_one_row_result` RIDGE_GROWTH_FACTOR), seeded
3077                // at `rhs_scale`, so even a large deficit (the collapsed
3078                // `(H_tβ)²/H_tt` over-subtraction) is reached in a handful of
3079                // attempts. The ceiling + attempt cap still bound it; on
3080                // exhaustion the recoverable failure reaches the outer LM loop.
3081                let next = if effective_ridge > 0.0 {
3082                    effective_ridge * SCHUR_CURVATURE_FLOOR_DIAG_GROWTH
3083                } else {
3084                    rhs_scale
3085                };
3086                last_curvature_err = Some(err);
3087                if !next.is_finite() || next > ridge_ceiling {
3088                    break;
3089                }
3090                effective_ridge = next;
3091                continue;
3092            }
3093            Err(other) => return Err(other),
3094        };
3095        match run_pcg_with_preconditioner(
3096            sys,
3097            htt_factors,
3098            effective_ridge,
3099            rhs,
3100            |r| jacobi.apply(r),
3101            pcg,
3102            trust,
3103            backend,
3104            gpu_matvec,
3105            metric_weights,
3106            resident.as_ref(),
3107        ) {
3108            Ok(result) => {
3109                x0_diag0 = Some(result);
3110                break;
3111            }
3112            Err(ArrowSchurError::UnboundedNegativeCurvature {
3113                curvature,
3114                direction_norm_sq,
3115            }) => {
3116                // Only floor when the caller opted in (SAE solve path); otherwise
3117                // propagate the raw negative-curvature signal so BA / non-SAE
3118                // unbounded solves keep their existing failure contract.
3119                let Some(relative_floor) = curvature_floor else {
3120                    return Err(ArrowSchurError::UnboundedNegativeCurvature {
3121                        curvature,
3122                        direction_norm_sq,
3123                    });
3124                };
3125                // Minimal ridge to make `pᵀ(S+δI)p = |curvature| + δ·‖p‖² > 0`,
3126                // with a margin so the next CG iterate has strictly positive
3127                // curvature rather than sitting on the `0` knife-edge.
3128                let deficit = if direction_norm_sq > 0.0 {
3129                    curvature.abs() / direction_norm_sq
3130                } else {
3131                    0.0
3132                };
3133                let bump = (deficit * (1.0 + SCHUR_CURVATURE_FLOOR_MARGIN))
3134                    .max(relative_floor.max(SCHUR_CURVATURE_FLOOR_REL_FLOOR) * rhs_scale);
3135                let next = (effective_ridge + bump).max(effective_ridge * 2.0);
3136                last_curvature_err = Some(ArrowSchurError::UnboundedNegativeCurvature {
3137                    curvature,
3138                    direction_norm_sq,
3139                });
3140                if !next.is_finite() || next > ridge_ceiling {
3141                    break;
3142                }
3143                effective_ridge = next;
3144            }
3145            Err(other) => return Err(other),
3146        }
3147    }
3148    let (x0, diag0) = match x0_diag0 {
3149        Some(result) => result,
3150        None => {
3151            // The curvature floor could not condition the operator within the
3152            // ceiling; hand the recoverable failure to the outer LM loop, which
3153            // re-forms the system at a heavier ridge.
3154            return Err(last_curvature_err.unwrap_or(ArrowSchurError::PcgFailed {
3155                reason: "unbounded Schur PCG negative curvature unresolved by curvature floor"
3156                    .to_string(),
3157            }));
3158        }
3159    };
3160    if sys.k <= PRECOND_ESCALATE_K_THRESHOLD || diag0.stopping_reason != PcgStopReason::MaxIter {
3161        return Ok((x0, diag0));
3162    }
3163    // Escalation tiers reuse the curvature-floored `effective_ridge` so the
3164    // operator they precondition is the SAME (PD-floored) one the Jacobi tier
3165    // settled on; a still-negative-curvature signal here is handed to the outer
3166    // LM loop (it only arises if the floored Jacobi tier merely ran out of
3167    // iterations yet a coarser preconditioner still finds an indefinite
3168    // direction — rare; the LM loop re-forms at a heavier ridge).
3169    let cluster =
3170        ClusterJacobiPreconditioner::from_arrow_schur(sys, htt_factors, effective_ridge, backend)?;
3171    let (x1, diag1) = run_pcg_with_preconditioner(
3172        sys,
3173        htt_factors,
3174        effective_ridge,
3175        rhs,
3176        |r| cluster.apply(r),
3177        pcg,
3178        trust,
3179        backend,
3180        gpu_matvec,
3181        metric_weights,
3182        resident.as_ref(),
3183    )?;
3184    if diag1.stopping_reason != PcgStopReason::MaxIter {
3185        return Ok((x1, diag1));
3186    }
3187    let schwarz = AdditiveSchwarzPreconditioner::from_arrow_schur(
3188        sys,
3189        htt_factors,
3190        effective_ridge,
3191        backend,
3192        1,
3193    )?;
3194    let (x2, diag2) = run_pcg_with_preconditioner(
3195        sys,
3196        htt_factors,
3197        effective_ridge,
3198        rhs,
3199        |r| schwarz.apply(r),
3200        pcg,
3201        trust,
3202        backend,
3203        gpu_matvec,
3204        metric_weights,
3205        resident.as_ref(),
3206    )?;
3207    if diag2.stopping_reason != PcgStopReason::MaxIter {
3208        return Ok((x2, diag2));
3209    }
3210    // Final tier — diagonal-assembled additive Schwarz (#299), the cheap-apply
3211    // Schwarz variant. When the dense-block AdditiveSchwarz still ran out of
3212    // iterations its O(Σ b_k²) apply may have throttled the iteration budget on
3213    // a wide subdomain; the diag-assembled variant keeps Schwarz's overlapping
3214    // local-inverse conditioning but applies in O(K), so it can take more CG
3215    // iterations within the same wall budget. Same overlap (1) and same
3216    // curvature-floored ridge as the dense-block tier.
3217    let diag_schwarz = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
3218        sys,
3219        htt_factors,
3220        effective_ridge,
3221        backend,
3222        1,
3223    )?;
3224    let (x3, diag3) = run_pcg_with_preconditioner(
3225        sys,
3226        htt_factors,
3227        effective_ridge,
3228        rhs,
3229        |r| diag_schwarz.apply(r),
3230        pcg,
3231        trust,
3232        backend,
3233        gpu_matvec,
3234        metric_weights,
3235        resident.as_ref(),
3236    )?;
3237    if diag3.stopping_reason != PcgStopReason::MaxIter {
3238        return Ok((x3, diag3));
3239    }
3240    // Richest tier — level-0 incomplete Cholesky (#299). ClusterJacobi keeps the
3241    // full DENSE Cholesky of each component (so on a single large connected
3242    // component it fills the whole `b×b` factor and its `O(b²)` apply throttles
3243    // the CG iteration budget), while the diagonal/Schwarz tiers drop most
3244    // inter-block coupling. IC(0) keeps the component's full structural coupling
3245    // but only the level-0 (no-fill) pattern, so its sparse triangular apply is
3246    // `O(nnz(S[C,C]))` — it can take more CG iterations within the same wall
3247    // budget AND conditions the off-diagonal coupling the cheap tiers discard.
3248    // Last in the ladder so it is only paid when every cheaper tier stalled.
3249    let ic0 = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
3250        sys,
3251        htt_factors,
3252        effective_ridge,
3253        backend,
3254    )?;
3255    let (x4, diag4) = run_pcg_with_preconditioner(
3256        sys,
3257        htt_factors,
3258        effective_ridge,
3259        rhs,
3260        |r| ic0.apply(r),
3261        pcg,
3262        trust,
3263        backend,
3264        gpu_matvec,
3265        metric_weights,
3266        resident.as_ref(),
3267    )?;
3268    // All five preconditioner tiers (Jacobi -> ClusterJacobi -> AdditiveSchwarz
3269    // -> DiagAssembledSchwarz -> BlockIncompleteCholesky) exhausted their
3270    // iteration budget without driving the residual below tolerance. Returning a
3271    // truncated iterate as `Ok` would feed an arbitrarily-large-residual step
3272    // into the Newton driver, where the PCG diagnostics are discarded. Surface a
3273    // recoverable failure instead so `solve_with_lm_escalation_inner` escalates
3274    // the proximal ridge: better conditioning is precisely what a stalled PCG on
3275    // an ill-conditioned reduced system needs.
3276    if diag4.stopping_reason == PcgStopReason::MaxIter {
3277        return Err(ArrowSchurError::PcgFailed {
3278            reason: format!(
3279                "Schur PCG exhausted all preconditioner tiers (Jacobi, ClusterJacobi, \
3280                 AdditiveSchwarz, DiagAssembledSchwarz, BlockIncompleteCholesky) at MaxIter; \
3281                 final relative residual = {:e}",
3282                diag4.final_relative_residual
3283            ),
3284        });
3285    }
3286    Ok((x4, diag4))
3287}
3288
3289/// Run Steihaug-CG with a generic preconditioner closure.
3290/// Routes matvec through GPU when `gpu_matvec` is set.
3291pub(crate) fn run_pcg_with_preconditioner<ApplyPrec, B: BatchedBlockSolver + Sync>(
3292    sys: &ArrowSchurSystem,
3293    htt_factors: &ArrowFactorSlab,
3294    ridge_beta: f64,
3295    rhs: &Array1<f64>,
3296    apply_prec: ApplyPrec,
3297    pcg: &ArrowPcgOptions,
3298    trust: &ArrowTrustRegionOptions,
3299    backend: &B,
3300    gpu_matvec: Option<&GpuSchurMatvec>,
3301    metric_weights: Option<&MetricWeights>,
3302    resident: Option<&SaeResidentReducedSchur>,
3303) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3304where
3305    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3306{
3307    let max_iters = pcg.max_iterations.min(trust.max_iterations);
3308    let tol = pcg
3309        .relative_tolerance
3310        .max(trust.steihaug_relative_tolerance);
3311    if let Some(gpu_mv) = gpu_matvec {
3312        let gpu_mv = Arc::clone(gpu_mv);
3313        steihaug_cg(
3314            rhs,
3315            move |p, out| gpu_mv(p, out),
3316            apply_prec,
3317            max_iters,
3318            tol,
3319            trust.radius,
3320            metric_weights,
3321        )
3322    } else {
3323        steihaug_cg(
3324            rhs,
3325            |p, out| schur_matvec(sys, htt_factors, ridge_beta, p, out, backend, resident),
3326            apply_prec,
3327            max_iters,
3328            tol,
3329            trust.radius,
3330            metric_weights,
3331        )
3332    }
3333}
3334
/// Stateless preconditioner whose apply returns its input unchanged.
#[derive(Clone, Copy, Debug)]
pub(crate) struct IdentityPreconditioner;
3337
3338impl IdentityPreconditioner {
3339    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
3340        r.clone()
3341    }
3342}
3343
3344pub(crate) fn steihaug_dense_system(
3345    schur: &Array2<f64>,
3346    rhs: &Array1<f64>,
3347    preconditioner: &IdentityPreconditioner,
3348    pcg: &ArrowPcgOptions,
3349    trust: &ArrowTrustRegionOptions,
3350    metric_weights: Option<&MetricWeights>,
3351) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3352    steihaug_cg(
3353        rhs,
3354        |p, out| dense_matvec(schur, p, out),
3355        |r| preconditioner.apply(r),
3356        pcg.max_iterations,
3357        pcg.relative_tolerance,
3358        trust.radius,
3359        metric_weights,
3360    )
3361}
3362
3363pub(crate) fn steihaug_cg<MatVec, ApplyPrec>(
3364    rhs: &Array1<f64>,
3365    mut matvec: MatVec,
3366    mut apply_preconditioner: ApplyPrec,
3367    max_iterations: usize,
3368    relative_tolerance: f64,
3369    trust_radius: f64,
3370    metric_weights: Option<&MetricWeights>,
3371) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3372where
3373    MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
3374    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3375{
3376    let n = rhs.len();
3377    if let Some(weights) = metric_weights {
3378        assert_eq!(
3379            weights.len(),
3380            n,
3381            "Steihaug-CG metric weight length must match solve dimension"
3382        );
3383    }
3384    let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
3385        trust_radius
3386    } else {
3387        f64::INFINITY
3388    };
3389    let rhs_norm = metric_norm(rhs.view(), metric_weights);
3390    if rhs_norm == 0.0 {
3391        return Ok((Array1::<f64>::zeros(n), PcgDiagnostics::default()));
3392    }
3393    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(PCG_ABSOLUTE_TOLERANCE_FLOOR);
3394    let mut x = Array1::<f64>::zeros(n);
3395    let mut r = rhs.clone();
3396    let mut z = apply_preconditioner(&r);
3397    let mut diag = PcgDiagnostics {
3398        precond_apply_calls: 1,
3399        ..PcgDiagnostics::default()
3400    };
3401    let mut p = z.clone();
3402    let mut rz = metric_dot(&r, &z, metric_weights);
3403    if rz <= 0.0 || !rz.is_finite() {
3404        if radius.is_finite() {
3405            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3406            diag.stopping_reason = PcgStopReason::TrustRegion;
3407            return Ok((step_to_trust_boundary(&x, &r, radius, metric_weights), diag));
3408        }
3409        // Unbounded (radius = ∞) non-positive preconditioned residual: the
3410        // reduced Schur is indefinite at the very first direction. Surface the
3411        // typed curvature-floor signal so `steihaug_pcg_auto` floors the
3412        // operator minimally and retries, instead of failing into a global
3413        // `ridge_β` ramp. `rz = rᵀM⁻¹r` is a preconditioner-metric curvature;
3414        // report it with the residual norm² as the direction scale.
3415        return Err(ArrowSchurError::UnboundedNegativeCurvature {
3416            curvature: rz,
3417            direction_norm_sq: metric_dot(&r, &r, metric_weights),
3418        });
3419    }
3420    if metric_norm(r.view(), metric_weights) <= tol {
3421        diag.final_relative_residual = 0.0;
3422        diag.stopping_reason = PcgStopReason::Converged;
3423        return Ok((x, diag));
3424    }
3425    let mut ap = Array1::<f64>::zeros(n);
3426    // Reused candidate scratch — avoid per-iteration clone of x.
3427    let mut candidate = Array1::<f64>::zeros(n);
3428    for _ in 0..max_iterations {
3429        matvec(&p, &mut ap);
3430        diag.matvec_calls += 1;
3431        diag.iterations += 1;
3432        let pap = metric_dot(&p, &ap, metric_weights);
3433        if pap <= 0.0 || !pap.is_finite() {
3434            if radius.is_finite() {
3435                diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3436                diag.stopping_reason = PcgStopReason::TrustRegion;
3437                return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3438            }
3439            // Unbounded negative curvature `pᵀSp ≤ 0`: the reduced Schur is
3440            // indefinite along `p` (the #1026 co-collapse direction). Surface
3441            // the typed signal carrying `pᵀSp` and `‖p‖²` so the caller floors
3442            // the operator by the minimal ridge `δ = |pᵀSp|/‖p‖²` (which makes
3443            // `pᵀ(S+δI)p = 0⁺`) plus a margin, and retries.
3444            return Err(ArrowSchurError::UnboundedNegativeCurvature {
3445                curvature: pap,
3446                direction_norm_sq: metric_dot(&p, &p, metric_weights),
3447            });
3448        }
3449        let alpha = rz / pap;
3450        for i in 0..n {
3451            candidate[i] = x[i] + alpha * p[i];
3452        }
3453        if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
3454            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3455            diag.stopping_reason = PcgStopReason::TrustRegion;
3456            return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3457        }
3458        x.assign(&candidate);
3459        for i in 0..n {
3460            r[i] -= alpha * ap[i];
3461        }
3462        if metric_norm(r.view(), metric_weights) <= tol {
3463            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3464            diag.stopping_reason = PcgStopReason::Converged;
3465            return Ok((x, diag));
3466        }
3467        z = apply_preconditioner(&r);
3468        diag.precond_apply_calls += 1;
3469        let rz_next = metric_dot(&r, &z, metric_weights);
3470        if rz_next <= 0.0 || !rz_next.is_finite() {
3471            return Err(ArrowSchurError::PcgFailed {
3472                reason: "non-positive or non-finite PCG residual".to_string(),
3473            });
3474        }
3475        let beta = rz_next / rz;
3476        for i in 0..n {
3477            p[i] = z[i] + beta * p[i];
3478        }
3479        rz = rz_next;
3480    }
3481    diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3482    diag.stopping_reason = PcgStopReason::MaxIter;
3483    Ok((x, diag))
3484}
3485
3486pub(crate) fn step_to_trust_boundary(
3487    x: &Array1<f64>,
3488    p: &Array1<f64>,
3489    radius: f64,
3490    metric_weights: Option<&MetricWeights>,
3491) -> Array1<f64> {
3492    let pp = metric_dot(p, p, metric_weights);
3493    if pp == 0.0 {
3494        return x.clone();
3495    }
3496    let xp = metric_dot(x, p, metric_weights);
3497    let xx = metric_dot(x, x, metric_weights);
3498    let disc = (xp * xp + pp * (radius * radius - xx)).max(0.0);
3499    let tau = (-xp + disc.sqrt()) / pp;
3500    let mut out = x.clone();
3501    for i in 0..out.len() {
3502        out[i] += tau * p[i];
3503    }
3504    out
3505}
3506
3507pub(crate) fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
3508    let n = a.nrows();
3509    for i in 0..n {
3510        let mut acc = 0.0;
3511        for j in 0..n {
3512            acc += a[[i, j]] * x[j];
3513        }
3514        out[i] = acc;
3515    }
3516}
3517
3518pub(crate) fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
3519    let mut acc = 0.0;
3520    for i in 0..a.len() {
3521        acc += a[i] * b[i];
3522    }
3523    acc
3524}
3525
3526pub(crate) fn metric_dot(
3527    a: &Array1<f64>,
3528    b: &Array1<f64>,
3529    metric_weights: Option<&MetricWeights>,
3530) -> f64 {
3531    assert_eq!(a.len(), b.len());
3532    match metric_weights {
3533        Some(weights) => {
3534            assert_eq!(weights.len(), a.len());
3535            let mut acc = 0.0;
3536            for i in 0..a.len() {
3537                acc += weights[i] * a[i] * b[i];
3538            }
3539            acc
3540        }
3541        None => dot(a, b),
3542    }
3543}
3544
3545pub(crate) fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
3546    let mut acc = 0.0;
3547    match metric_weights {
3548        Some(weights) => {
3549            assert_eq!(weights.len(), v.len());
3550            for i in 0..v.len() {
3551                acc += weights[i] * v[i] * v[i];
3552            }
3553        }
3554        None => {
3555            for x in v.iter() {
3556                acc += x * x;
3557            }
3558        }
3559    }
3560    acc.sqrt()
3561}
3562
3563pub(crate) fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
3564    let n = a.nrows().min(a.ncols());
3565    for i in 0..n {
3566        for j in 0..i {
3567            let v = 0.5 * (a[[i, j]] + a[[j, i]]);
3568            a[[i, j]] = v;
3569            a[[j, i]] = v;
3570        }
3571    }
3572}
3573
/// Failure modes reported by [`ArrowSchurSystem::solve`].
#[derive(Debug, Clone)]
pub enum ArrowSchurError {
    /// The per-row block `H_tt^(i)` at index `row` was not positive-definite
    /// at the supplied ridge. This points at an under-regularized latent
    /// block, typically a gauge-free fit with no identifiability penalty.
    PerRowFactorFailed { row: usize, reason: String },
    /// The per-row block `H_tt^(i)` did factor, but its Cholesky factor was
    /// rejected by the safe-inversion guard of the Schur reduction: either the
    /// diagonal-ratio condition-number estimate was excessive, or a pivot was
    /// numerically tiny relative to the row block scale. The inverse entering
    /// `S = H_ββ − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)` would carry spectral
    /// terms on the order of `κ_i`, which for Schur stability is as bad as a
    /// PSD failure. The LM outer wrapper escalates `ridge_t` exactly as it
    /// does for `PerRowFactorFailed`.
    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
    /// The Schur complement itself was not positive-definite, indicating a
    /// near-collinear decoder or a degenerate weighting. The LM outer wrapper
    /// should escalate `ridge_beta` and retry.
    SchurFactorFailed { reason: String },
    /// The BA inexact-step PCG solve broke down before it produced a usable
    /// Steihaug trust-region step.
    PcgFailed { reason: String },
    /// The UNBOUNDED (trust-radius = ∞) Schur PCG met negative curvature
    /// `pᵀSp ≤ 0` (or a non-positive preconditioned residual): the reduced
    /// Schur is indefinite — the #1026 K≥4 co-collapse signature, where a
    /// near-singular per-row `H_tt` over-subtracts `S`. Without a trust radius
    /// there is no boundary to step to, so CG cannot continue. `curvature`
    /// holds the offending `pᵀSp` and `direction_norm_sq` the `‖p‖²` of that
    /// direction. The caller floors the operator with the minimal ridge
    /// `δ = (|curvature|/‖p‖²)·(1+ε)` that restores positive curvature along
    /// `p` and retries (the matrix-free analogue of the dense
    /// `spectral_pd_floored_schur`) instead of blindly inflating `ridge_β`.
    UnboundedNegativeCurvature {
        curvature: f64,
        direction_norm_sq: f64,
    },
    /// Adaptive proximal damping failed to produce an Armijo-accepted
    /// nonlinear step.
    AdaptiveCorrectionFailed { reason: String },
}

impl std::fmt::Display for ArrowSchurError {
    /// Render a single-line, `arrow-Schur:`-prefixed description of the error.
    /// Floating-point payloads are printed in lower-exponent (`{:e}`) form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PerRowFactorFailed { row, reason } => {
                write!(
                    f,
                    "arrow-Schur: per-row H_tt^({row}) Cholesky failed: {reason}"
                )
            }
            Self::PerRowFactorIllConditioned {
                row,
                kappa_estimate,
            } => {
                write!(
                    f,
                    "arrow-Schur: per-row H_tt^({row}) Cholesky succeeded but failed \
                     the safe-inversion guard (kappa_estimate={kappa_estimate:e}); \
                     Schur reduction would be numerically contaminated"
                )
            }
            Self::SchurFactorFailed { reason } => {
                write!(f, "arrow-Schur: Schur complement Cholesky failed: {reason}")
            }
            Self::PcgFailed { reason } => {
                write!(f, "arrow-Schur: Schur PCG failed: {reason}")
            }
            Self::UnboundedNegativeCurvature {
                curvature,
                direction_norm_sq,
            } => {
                write!(
                    f,
                    "arrow-Schur: unbounded Schur PCG hit negative curvature pᵀSp={curvature:e} \
                     (‖p‖²={direction_norm_sq:e}); reduced Schur is indefinite (co-collapse), \
                     retry with a curvature-floor ridge"
                )
            }
            Self::AdaptiveCorrectionFailed { reason } => {
                write!(
                    f,
                    "arrow-Schur: adaptive proximal correction failed: {reason}"
                )
            }
        }
    }
}

impl std::error::Error for ArrowSchurError {}
3659
3660// ---------------------------------------------------------------------------
3661// Cholesky helpers (kept local to avoid a new public-API dependency on the
3662// linalg crate. The systems here are tiny per-row (d × d, d ∈ {1..16}) and
3663// modest at the Schur level (K × K, K ∈ {basis size}). For production SAE
3664// scales the Schur factor should switch to faer; this module's `cholesky_lower`
3665// is the obvious replacement site.)
3666// ---------------------------------------------------------------------------
3667
3668pub(crate) fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
3669    let n = a.nrows();
3670    if a.ncols() != n {
3671        return Err(format!("cholesky_lower: non-square {}×{}", n, a.ncols()));
3672    }
3673    if let Some((idx, _)) = a.iter().enumerate().find(|(_, v)| !v.is_finite()) {
3674        return Err(format!(
3675            "cholesky_lower: non-finite entry at linear index {idx}"
3676        ));
3677    }
3678
3679    let mut maybe_device = a.clone();
3680    if gam_gpu::try_cholesky_lower_inplace(&mut maybe_device).is_some() {
3681        return Ok(maybe_device);
3682    }
3683
3684    let mut l = Array2::<f64>::zeros((n, n));
3685    for i in 0..n {
3686        for j in 0..=i {
3687            let mut sum = a[[i, j]];
3688            for kk in 0..j {
3689                sum -= l[[i, kk]] * l[[j, kk]];
3690            }
3691            if i == j {
3692                if !sum.is_finite() || sum <= 0.0 {
3693                    return Err(format!(
3694                        "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
3695                    ));
3696                }
3697                l[[i, j]] = sum.sqrt();
3698            } else {
3699                l[[i, j]] = sum / l[[j, j]];
3700            }
3701        }
3702    }
3703    Ok(l)
3704}