Skip to main content

gam_solve/arrow_schur/
reduced_solve.rs

1//! The reduced `K x K` shared-system solve: dense Schur assembly (direct and
2//! square-root BA), the Schur matvec, the Jacobi/cluster/Schwarz
3//! preconditioners, Steihaug-PCG, and the [`ArrowSchurError`] type.
4
5use super::*;
6
/// Ceiling, in bytes, on the host allocation for a dense reduced Schur
/// `k × k` f64 matrix (#1017).
///
/// The dense builders compare `k · k · 8` bytes against this value and return
/// a loud `SchurFactorFailed` instead of attempting an allocation that could
/// OOM-kill the host. The ceiling is 8 GiB, i.e. exactly `k = 32768`
/// (`32768² · 8 B`). An SAE border of `k ≤ 5120` needs about 0.2 GiB and is
/// well under it; the qwen LLM border `k = 98304` would need 72 GiB (≈ 77 GB)
/// and is rejected as matrix-free-only.
pub(crate) const DENSE_SCHUR_BYTES_BUDGET: u128 = 8 << 30;
13
14/// Reduce one contiguous device tile's rows into a private `-Σ leftᵀ·right`
15/// partial (`k×k`).
16///
17/// The tile stacks its per-row `left_i` / `right_i` factors (each `d×k`) into
18/// two `(Σ_i d_i × k)` matrices and tries a single per-ordinal `AᵀB` device
19/// GEMM (`gam_gpu::try_fast_atb_on_ordinal`), which runs on the device this
20/// worker thread already bound — one big GPU GEMM per tile rather than `n` small
21/// CPU ones. When the device primitive declines (no GPU, shape below policy,
22/// transient failure) the tile reduces with the exact CPU `block_gemm_subtract`
23/// loop, so the result is unchanged. The partial is negated so the caller's
24/// `schur += partial` reproduces the serial `schur -= Σ contribution`.
25pub(crate) fn tile_schur_partial<B: BatchedBlockSolver>(
26    sys: &ArrowSchurSystem,
27    htt_factors: &ArrowFactorSlab,
28    backend: &B,
29    kind: SchurReductionKind,
30    ordinal: usize,
31    range: Range<usize>,
32) -> Result<Array2<f64>, ArrowSchurError> {
33    let k = sys.k;
34
35    // Build the per-row contribution factors once; both the GPU stacked-GEMM
36    // and the CPU fallback consume them.
37    let mut factors: Vec<(Array2<f64>, Array2<f64>)> = Vec::with_capacity(range.len());
38    let mut total_d = 0usize;
39    for i in range.clone() {
40        let (left, right) = row_schur_contribution_factors(
41            sys,
42            i,
43            &sys.rows[i],
44            htt_factors.factor(i),
45            backend,
46            kind,
47        )?;
48        total_d += left.nrows();
49        factors.push((left, right));
50    }
51
52    // Stack into (total_d × k) left/right matrices for one device AᵀB GEMM on
53    // this tile's bound ordinal. `try_fast_atb_on_ordinal` returns leftᵀ·right
54    // (k×k); negate into the partial. At an SAE-shaped whole-fit tile with
55    // n=2000 rows, k=2048 shared columns, M=12 local rows per observation, and
56    // K=8 candidate/atom batches, the stacked GEMM is
57    // 2*(n*M)*k^2 = 201_326_592_000 flops per batch, or
58    // 1_610_612_736_000 flops across K=8, so the policy work gate is cleared
59    // even though the observation count is far below the old row floor.
60    if total_d > 0 && k > 0 {
61        let mut left_stack = Array2::<f64>::zeros((total_d, k));
62        let mut right_stack = Array2::<f64>::zeros((total_d, k));
63        let mut base = 0usize;
64        for (left, right) in &factors {
65            let di = left.nrows();
66            left_stack
67                .slice_mut(ndarray::s![base..base + di, ..])
68                .assign(left);
69            right_stack
70                .slice_mut(ndarray::s![base..base + di, ..])
71                .assign(right);
72            base += di;
73        }
74        if let Some(product) =
75            gam_gpu::try_fast_atb_on_ordinal(ordinal, left_stack.view(), right_stack.view())
76        {
77            return Ok(product.mapv(|v| -v));
78        }
79    }
80
81    // CPU fallback: exact per-row block_gemm_subtract into a zero-seeded partial.
82    let mut partial = Array2::<f64>::zeros((k, k));
83    for (left, right) in &factors {
84        backend.block_gemm_subtract(&mut partial, left, right);
85    }
86    Ok(partial)
87}
88
89/// Reduce the per-row Schur contributions `Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`
90/// out of `schur` (seeded with `H_ββ + ρ_β·I`).
91///
92/// The per-row contributions are independent — exactly the "sum over independent
93/// arrow-tip blocks" axis the device pool partitions. When more than one GPU is
94/// usable, [`gam_gpu::pool::balanced_partition`] splits the `0..n` rows into
95/// per-device contiguous tiles; each tile is reduced on its own scoped thread
96/// (binding that ordinal's context so the per-row GEMM-subtract offloads to its
97/// device) into a private `k×k` partial, and the partials are summed back into
98/// `schur` in tile order. The tiles are contiguous, ordered to cover `0..n`, and
99/// folded back in that same order, so within each tile the per-row accumulation
100/// order is preserved and the only departure from the serial loop is the
101/// inter-tile reassociation of the reduction sum — the established
102/// reduction-order equivalence the device pool already operates under, well
103/// inside the Newton solve's tolerance.
104///
105/// With a single device (or no GPU) the row loop runs serially in place, which
106/// is bit-for-bit the original behaviour.
107pub(crate) fn reduce_row_schur_contributions<B: BatchedBlockSolver + Sync>(
108    sys: &ArrowSchurSystem,
109    htt_factors: &ArrowFactorSlab,
110    backend: &B,
111    kind: SchurReductionKind,
112    schur: &mut Array2<f64>,
113) -> Result<(), ArrowSchurError> {
114    let n = sys.rows.len();
115    let k = sys.k;
116
117    let tiles = gam_gpu::device_runtime::GpuRuntime::global()
118        .map(|rt| gam_gpu::pool::balanced_partition(rt, n))
119        .filter(|tiles| tiles.len() > 1);
120
121    let Some(tiles) = tiles else {
122        // Single-device / CPU. The per-row contributions `-Σ_i leftᵀ·right` fold
123        // into the `k×k` `schur` independently — the same dense-assembly axis the
124        // multi-GPU tile path partitions, and the dense-Direct analog of the
125        // per-row matvec / streaming `accumulate_chunk` loops already parallelized
126        // for #1017. At the SAE Direct-solve shape (`n` in the thousands, wide
127        // border `k`) this O(n·d·k²) reduction is the dense assembly's whole cost
128        // and was the last serial CPU step on the dense-Schur build.
129        //
130        // Fan it across rayon over fixed row chunks: each chunk reduces its rows
131        // (in row order) into a private zero-seeded `k×k` partial, then the
132        // partials are folded into `schur` in CHUNK order. The per-chunk row order
133        // and the inter-chunk fold order are both fixed independent of thread
134        // scheduling, so the f64 reduction is **bit-identical run-to-run** (the
135        // #1017 determinism gate). NOTE: bit-identical run-to-run does NOT make
136        // it bit-identical to the in-place serial loop — the chunk-boundary
137        // reassociation of the reduction sum is a genuine f64 departure (the
138        // established equivalence `accumulate_chunk` / the per-row matvec operate
139        // under, well inside the Newton solve's tolerance). It bounds candidate-
140        // to-candidate drift to that reassociation margin, so the criterion
141        // ranking is stable EXCEPT for candidates tying within the margin, where
142        // the winner can flip; it is not an exact no-move guarantee (#1211). For
143        // an exact-order guarantee, take the serial path. Stay in-place serial
144        // below the row floor and when already inside a rayon worker (the topology
145        // race fans candidates with `run_topology_race_parallel`) to avoid
146        // nested-rayon oversubscription — the same guard the matvec uses.
147        let n_rows = sys.rows.len();
148        let parallel =
149            n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
150        if parallel {
151            use rayon::prelude::*;
152            const CHUNK: usize = 64;
153            let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = (0..n_rows)
154                .into_par_iter()
155                .chunks(CHUNK)
156                .map(|idxs| {
157                    let mut partial = Array2::<f64>::zeros((k, k));
158                    for i in idxs {
159                        subtract_row_schur_contribution(
160                            sys,
161                            i,
162                            &sys.rows[i],
163                            htt_factors.factor(i),
164                            backend,
165                            kind,
166                            &mut partial,
167                        )?;
168                    }
169                    Ok(partial)
170                })
171                .collect();
172            // Deterministic ordered fold: chunk partials hold `-Σ contribution`
173            // over their rows, so `schur += partial` reproduces the serial
174            // `schur -= Σ contribution` in fixed (chunk, a, b) order.
175            for partial in &partials? {
176                for a in 0..k {
177                    for b in 0..k {
178                        schur[[a, b]] += partial[[a, b]];
179                    }
180                }
181            }
182            return Ok(());
183        }
184        // Serial in-place reduction (original order) — bit-for-bit reference.
185        for (i, row) in sys.rows.iter().enumerate() {
186            subtract_row_schur_contribution(
187                sys,
188                i,
189                row,
190                htt_factors.factor(i),
191                backend,
192                kind,
193                schur,
194            )?;
195        }
196        return Ok(());
197    };
198
199    // Multi-GPU: one private `-Σ leftᵀ·right` partial per contiguous device
200    // tile. Each tile runs on its own scoped worker thread that binds its
201    // ordinal's context and issues a single stacked AᵀB GEMM on that device, so
202    // the tiles' GEMMs overlap across the pool. Folding the partials back into
203    // the H_ββ-seeded `schur` reproduces the serial reduction (up to inter-tile
204    // reassociation).
205    let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = std::thread::scope(|scope| {
206        let handles: Vec<_> = tiles
207            .iter()
208            .map(|(ordinal, range)| {
209                let ordinal = *ordinal;
210                let range = range.clone();
211                scope.spawn(move || {
212                    // Bind this ordinal's CUDA context on this worker thread so
213                    // the per-row GPU GEMM shims issued from `tile_schur_partial`
214                    // offload to that device. A missing context or bind failure
215                    // is intentionally consumed without escalation — the shims
216                    // no-op back to CPU and the math is unchanged. Off Linux
217                    // `GpuRuntime::global()` is always `None`, so this branch
218                    // is unreachable and the bind is omitted entirely.
219                    #[cfg(target_os = "linux")]
220                    {
221                        if let Some(ctx) = gam_gpu::device_runtime::cuda_context_for(ordinal) {
222                            if ctx.bind_to_thread().is_err() {
223                                // Fall through: this tile reduces on the CPU.
224                            }
225                        }
226                    }
227                    tile_schur_partial(sys, htt_factors, backend, kind, ordinal, range)
228                })
229            })
230            .collect();
231        handles
232            .into_iter()
233            .map(|handle| {
234                handle
235                    .join()
236                    .map_err(|_| ArrowSchurError::SchurFactorFailed {
237                        reason: "schur-reduction tile thread panicked".to_string(),
238                    })?
239            })
240            .collect()
241    });
242    let partials = partials?;
243
244    // Fold partials into `schur` in tile order (contiguous, covering 0..n) so
245    // the per-tile and inter-tile accumulation order is the row order; each
246    // partial holds `-Σ contribution` over its rows, so `schur += partial`
247    // reproduces `schur -= Σ contribution`.
248    for partial in &partials {
249        for a in 0..k {
250            for b in 0..k {
251                schur[[a, b]] += partial[[a, b]];
252            }
253        }
254    }
255    Ok(())
256}
257
258pub(crate) fn build_dense_schur_direct<B: BatchedBlockSolver + Sync>(
259    sys: &ArrowSchurSystem,
260    htt_factors: &ArrowFactorSlab,
261    ridge_beta: f64,
262    backend: &B,
263) -> Result<Array2<f64>, ArrowSchurError> {
264    let k = sys.k;
265    // Materialise H_ββ via the BetaPenaltyOp trait (#296): DensePenaltyOp
266    // for the legacy dense path, structured ops for SAE / Kronecker smooths.
267    let op = sys.effective_penalty_op();
268    if op.dim() != k {
269        return Err(ArrowSchurError::SchurFactorFailed {
270            reason: "Direct BA requires a K×K shared H_ββ penalty operator".to_string(),
271        });
272    }
273    // Fail LOUD, never OOM-kill (#1017): the dense reduced Schur is `k × k` f64.
274    // At SAE LLM borders (qwen `k = 98304` ⇒ 77 GiB) materialising it would crash
275    // the host. The matrix-free device PCG already solves the *step* without it
276    // (`try_device_arrow_direct_sae_pcg`); only the joint-Hessian log-det still
277    // routes here. A matrix-free determinant-lemma log-det (the proper follow-up)
278    // is not yet wired, so refuse the allocation with an actionable error rather
279    // than degrading silently into an OOM. The budget is generous so every
280    // currently-feasible border (k ≤ 5120 ⇒ 0.2 GiB) is unaffected.
281    let dense_bytes = (k as u128).saturating_mul(k as u128).saturating_mul(8);
282    if dense_bytes > DENSE_SCHUR_BYTES_BUDGET {
283        return Err(ArrowSchurError::SchurFactorFailed {
284            reason: format!(
285                "dense reduced Schur is {k}×{k} f64 = {} MiB, exceeding the {} MiB host budget; \
286                 this border is matrix-free-only (the device PCG solves the step without the dense \
287                 Schur) and a matrix-free determinant-lemma log-det is the required follow-up",
288                dense_bytes / (1024 * 1024),
289                DENSE_SCHUR_BYTES_BUDGET / (1024 * 1024),
290            ),
291        });
292    }
293    let mut schur = op.to_dense();
294    for j in 0..k {
295        schur[[j, j]] += ridge_beta;
296    }
297    reduce_row_schur_contributions(
298        sys,
299        htt_factors,
300        backend,
301        SchurReductionKind::Direct,
302        &mut schur,
303    )?;
304    symmetrize_upper_from_lower(&mut schur);
305    Ok(schur)
306}
307
308pub(crate) fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver + Sync>(
309    sys: &ArrowSchurSystem,
310    htt_factors: &ArrowFactorSlab,
311    ridge_beta: f64,
312    backend: &B,
313) -> Result<Array2<f64>, ArrowSchurError> {
314    let k = sys.k;
315    // Materialise H_ββ via the BetaPenaltyOp trait (#296).
316    let op = sys.effective_penalty_op();
317    if op.dim() != k {
318        return Err(ArrowSchurError::SchurFactorFailed {
319            reason: "Square-Root BA direct solve requires a K×K shared H_ββ penalty operator"
320                .to_string(),
321        });
322    }
323    // Same fail-loud host-memory contract as the Direct reduction (#1017).  The
324    // square-root BA route still materialises the same dense `k×k` reduced
325    // Schur; letting this path bypass the budget would preserve an OOM-class
326    // fallback even after Direct learned to refuse matrix-free-only borders.
327    let dense_bytes = (k as u128).saturating_mul(k as u128).saturating_mul(8);
328    if dense_bytes > DENSE_SCHUR_BYTES_BUDGET {
329        return Err(ArrowSchurError::SchurFactorFailed {
330            reason: format!(
331                "square-root BA dense reduced Schur is {k}×{k} f64 = {} MiB, exceeding the \
332                 {} MiB host budget; this border is matrix-free-only",
333                dense_bytes / (1024 * 1024),
334                DENSE_SCHUR_BYTES_BUDGET / (1024 * 1024),
335            ),
336        });
337    }
338    let mut schur = op.to_dense();
339    for j in 0..k {
340        schur[[j, j]] += ridge_beta;
341    }
342    reduce_row_schur_contributions(
343        sys,
344        htt_factors,
345        backend,
346        SchurReductionKind::SqrtBa,
347        &mut schur,
348    )?;
349    symmetrize_upper_from_lower(&mut schur);
350    Ok(schur)
351}
352
353/// Certified Carson–Higham mixed-precision solve of the reduced dense Schur
354/// system `S Δβ = rhs` (#1014), specialized to the streaming/residency path.
355///
356/// Returns `Some(Δβ)` when certified mixed precision is enabled AND the κ gate
357/// admits the f32 factorization AND the f64 backward-error certificate closes;
358/// `None` in every other case so the caller falls back to the exact f64
359/// triangular solve. The f64 `factor` (whose diagonal carries the exact
360/// `log|S|`) is supplied by the caller and never re-derived here — the logdet
361/// the evidence path reads stays f64 by construction.
362///
363/// Method: store the f64 Cholesky factor as f32, solve in f32, then refine with
364/// residuals `r = rhs − S·x` computed in f64 against the f64 `S`. With
365/// `κ(S)·u_f32 < margin` the refinement contracts at rate `κ·u`, and the
366/// terminating certificate is the normwise backward error
367/// `‖r‖∞ / (‖S‖∞‖x‖∞ + ‖rhs‖∞) ≤ tol`. A non-decreasing residual or an
368/// unmet certificate after `max_refinement_steps` returns `None`.
369pub(crate) fn mixed_precision_reduced_beta(
370    schur: &Array2<f64>,
371    factor: &Array2<f64>,
372    rhs: &Array1<f64>,
373    options: &ArrowSolveOptions,
374) -> Option<Array1<f64>> {
375    let ArrowSolvePrecisionPolicy::CertifiedMixed {
376        max_refinement_steps,
377        residual_relative_tolerance,
378        kappa_unit_roundoff_margin,
379    } = options.solve_precision
380    else {
381        return None;
382    };
383    // The reduced-system mixed-precision path is the dense reduced solve only;
384    // a trust-region-truncated step takes the Steihaug branch below in f64.
385    if options.trust_region.radius.is_finite() {
386        return None;
387    }
388    let n = schur.nrows();
389    if n == 0 {
390        return None;
391    }
392
393    // κ gate: the f32 factorization is only admissible when κ(S)·u_f32 leaves
394    // the refinement contraction headroom the certificate needs.
395    let kappa = cholesky_factor_kappa_estimate(factor);
396    if !kappa.is_finite() || kappa * F32_UNIT_ROUNDOFF >= kappa_unit_roundoff_margin {
397        return None;
398    }
399
400    let factor_f32 = factor.mapv(|v| v as f32);
401    let s_inf = matrix_inf_norm(schur);
402    let rhs_inf = rhs.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
403    let certificate_tol = residual_relative_tolerance
404        .max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
405
406    // f32 solve of the seed system, then f64-residual refinement steps.
407    let mut x = cholesky_solve_lower_f32(&factor_f32, &rhs.mapv(|v| v as f32)).mapv(|v| v as f64);
408    let mut last_residual = f64::INFINITY;
409    for _ in 0..=max_refinement_steps {
410        // Residual r = rhs − S·x in f64 against the f64 model.
411        let sx = schur.dot(&x);
412        let mut r = rhs.clone();
413        r -= &sx;
414        let r_inf = r.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
415        let x_inf = x.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
416        let denom = s_inf * x_inf + rhs_inf;
417        let backward_error = if denom > 0.0 { r_inf / denom } else { 0.0 };
418        if backward_error <= certificate_tol {
419            return Some(x);
420        }
421        // Refinement must make monotone progress, else hand back to f64.
422        if !(r_inf < last_residual) {
423            return None;
424        }
425        last_residual = r_inf;
426        // Correction solve in f32 against the f32 factor: S·δ = r.
427        let delta = cholesky_solve_lower_f32(&factor_f32, &r.mapv(|v| v as f32)).mapv(|v| v as f64);
428        x += &delta;
429    }
430    None
431}
432
433/// Infinity norm (max absolute row sum) of a dense matrix.
434pub(crate) fn matrix_inf_norm(a: &Array2<f64>) -> f64 {
435    let mut max_row = 0.0_f64;
436    for row in a.rows() {
437        let s: f64 = row.iter().map(|v| v.abs()).sum();
438        if s > max_row {
439            max_row = s;
440        }
441    }
442    max_row
443}
444
445/// Spectral positive-definiteness floor for the reduced Schur complement
446/// `S` (#1026 SAE co-collapse SOLVE-path cure).
447///
448/// Reached only after the genuine Cholesky of `S` has REFUSED it (an indefinite
449/// reduced Schur: collapsed atoms drive a per-row `H_tt` near-singular, so the
450/// accumulated `Σ_i H_tβᵀ (H_tt)⁻¹ H_tβ` over-subtracts `H_ββ + ridge_β·I` into a
451/// matrix with a non-positive eigenvalue). Rather than reject and let the LM
452/// loop inflate `ridge_β` over EVERY β direction (the #1026 "crawl"), we
453/// symmetric-eigendecompose `S` and clamp every eigenvalue UP to
454/// `floor·max(λ)`. This is Levenberg–Marquardt restricted to exactly the
455/// indefinite/collapsed subspace: a well-separated positive direction
456/// (`λ ≫ floor·max λ`) keeps its EXACT eigenvalue (`λ.max(floor·max λ) = λ`), so
457/// the Newton step in the healthy β subspace is unchanged, while only the
458/// collapsed directions get the minimal positive stiffness needed for a PD
459/// solve. Returns the floored, symmetric, strictly-PD matrix, or `None` if `S`
460/// has no usable scale (non-finite / all-zero spectrum), in which case the
461/// caller keeps the strict refusal.
462///
463/// Mirrors the per-row evidence floor
464/// [`super::factorization::factor_spectral_deflated_evidence_row`]; the only
465/// difference is the floored VALUE — a small positive `floor·max λ` (Tikhonov,
466/// for an accurate solve) here, vs unit stiffness `+1` (`log 1 = 0`) there (for
467/// the quotient log-det).
468pub(crate) fn spectral_pd_floored_schur(
469    schur: &Array2<f64>,
470    relative_floor: f64,
471) -> Option<Array2<f64>> {
472    let n = schur.nrows();
473    if n == 0 || schur.ncols() != n || !(relative_floor.is_finite() && relative_floor > 0.0) {
474        return None;
475    }
476    // Symmetrise defensively (the assembled Schur is symmetric up to reduction
477    // order; the eig routine assumes exact symmetry).
478    let mut sym = Array2::<f64>::zeros((n, n));
479    for i in 0..n {
480        for j in 0..n {
481            let v = 0.5 * (schur[[i, j]] + schur[[j, i]]);
482            if !v.is_finite() {
483                return None;
484            }
485            sym[[i, j]] = v;
486        }
487    }
488    let (evals, evecs) = sym.eigh(Side::Lower).ok()?;
489    let max_abs = evals.iter().fold(
490        0.0_f64,
491        |acc, &v| if v.is_finite() { acc.max(v.abs()) } else { acc },
492    );
493    if !(max_abs.is_finite() && max_abs > 0.0) {
494        return None;
495    }
496    let floor = relative_floor * max_abs;
497    // Reconstruct `Σ_i max(λ_i, floor) v_i v_iᵀ`: clamp every eigenvalue UP to a
498    // strictly positive `floor`. Healthy positive directions (`λ ≫ floor`) are
499    // untouched; non-positive / tiny collapsed directions are lifted to exactly
500    // `floor`. The result is symmetric PD by construction.
501    let mut conditioned = Array2::<f64>::zeros((n, n));
502    for eig_idx in 0..evals.len() {
503        let lambda = evals[eig_idx];
504        let lambda_floored = if lambda.is_finite() {
505            lambda.max(floor)
506        } else {
507            floor
508        };
509        for i in 0..n {
510            let vi = evecs[[i, eig_idx]];
511            if vi == 0.0 {
512                continue;
513            }
514            for j in 0..n {
515                conditioned[[i, j]] += lambda_floored * vi * evecs[[j, eig_idx]];
516            }
517        }
518    }
519    Some(conditioned)
520}
521
522/// Unit-stiffness quotient conditioning for the *reduced* evidence Schur block.
523///
524/// `spectral_pd_floored_schur` is the right object for Newton steps: it is a
525/// Levenberg-Marquardt floor that damps collapsed decoder directions just enough
526/// to compute a stable `Δβ`.  The Laplace evidence path is different.  Once the
527/// reduced Schur is being used only for a log determinant, a non-positive (or
528/// numerically null) reduced direction is a quotient/null direction, just like
529/// the per-row `H_tt` spectral-deflation case.  It must contribute the
530/// ρ-independent constant `log 1 = 0`, not `log(floor·max λ)`: the latter is a
531/// ρ-dependent Occam reward for collapsed/redundant decoders and can make the
532/// outer REML sweep prefer a worse planted-manifold optimum.
533pub(crate) fn spectral_unit_deflated_schur(
534    schur: &Array2<f64>,
535    relative_floor: f64,
536) -> Option<Array2<f64>> {
537    let n = schur.nrows();
538    if n == 0 || schur.ncols() != n || !(relative_floor.is_finite() && relative_floor > 0.0) {
539        return None;
540    }
541    let mut sym = Array2::<f64>::zeros((n, n));
542    for i in 0..n {
543        for j in 0..n {
544            let v = 0.5 * (schur[[i, j]] + schur[[j, i]]);
545            if !v.is_finite() {
546                return None;
547            }
548            sym[[i, j]] = v;
549        }
550    }
551    let (evals, evecs) = sym.eigh(Side::Lower).ok()?;
552    let max_abs = evals.iter().fold(
553        0.0_f64,
554        |acc, &v| if v.is_finite() { acc.max(v.abs()) } else { acc },
555    );
556    if !(max_abs.is_finite() && max_abs > 0.0) {
557        return None;
558    }
559    let floor = relative_floor * max_abs;
560    let deflate_floor = floor * (1.0 - SPECTRAL_DEFLATION_HYSTERESIS_FRACTION);
561    let mut conditioned = Array2::<f64>::zeros((n, n));
562    for eig_idx in 0..evals.len() {
563        let lambda = evals[eig_idx];
564        let lambda_conditioned = if !lambda.is_finite() || lambda <= 0.0 || lambda < deflate_floor {
565            1.0
566        } else {
567            lambda.max(floor)
568        };
569        for i in 0..n {
570            let vi = evecs[[i, eig_idx]];
571            if vi == 0.0 {
572                continue;
573            }
574            for j in 0..n {
575                conditioned[[i, j]] += lambda_conditioned * vi * evecs[[j, eig_idx]];
576            }
577        }
578    }
579    Some(conditioned)
580}
581
582pub(crate) fn factor_dense_reduced_schur(
583    schur: &Array2<f64>,
584    schur_pd_floor: Option<f64>,
585    unit_deflate_null_directions: bool,
586) -> Result<(Array2<f64>, Option<Array2<f64>>), ArrowSchurError> {
587    let (factor, floored_schur) = match cholesky_lower(schur) {
588        Ok(factor) => (factor, None),
589        Err(e) => {
590            // #1026/#1038 — every dense reduced-Schur factorization in the SAE
591            // path must honor the same opt-in spectral floor. Otherwise
592            // auxiliary entry points (mixed precision and cross-row IBP
593            // preconditioning) can reject the collapsed dead-atom subspace even
594            // though the main direct solve would floor it and continue.
595            //
596            // #1803 — Newton-step callers use the Levenberg-Marquardt PD floor
597            // (`spectral_pd_floored_schur`) so `Δβ` is stable. Evidence/log-det
598            // callers (`unit_deflate_null_directions`) instead deflate
599            // quotient/null directions to unit stiffness so they contribute the
600            // ρ-independent `log 1 = 0` to the Laplace normaliser rather than a
601            // ρ-dependent Occam reward for collapsed decoders.
602            match schur_pd_floor {
603                Some(relative_floor) => match if unit_deflate_null_directions {
604                    spectral_unit_deflated_schur(schur, relative_floor)
605                } else {
606                    spectral_pd_floored_schur(schur, relative_floor)
607                } {
608                    Some(floored) => (
609                        cholesky_lower(&floored).map_err(|floored_err| {
610                            ArrowSchurError::SchurFactorFailed {
611                                reason: format!(
612                                    "reduced Schur non-PD ({e}); spectral PD-floor \
613                                     reconstruction still non-PD: {floored_err}"
614                                ),
615                            }
616                        })?,
617                        Some(floored),
618                    ),
619                    None => {
620                        return Err(ArrowSchurError::SchurFactorFailed {
621                            reason: format!(
622                                "reduced Schur non-PD ({e}); spectral PD-floor declined \
623                                 (no usable spectrum)"
624                            ),
625                        });
626                    }
627                },
628                None => return Err(ArrowSchurError::SchurFactorFailed { reason: e }),
629            }
630        }
631    };
632    Ok((factor, floored_schur))
633}
634
635pub(crate) fn solve_dense_reduced_system(
636    schur: &Array2<f64>,
637    rhs_beta: &Array1<f64>,
638    options: &ArrowSolveOptions,
639    metric_weights: Option<&MetricWeights>,
640) -> Result<(Array1<f64>, Option<Array2<f64>>, PcgDiagnostics), ArrowSchurError> {
641    let (factor, floored_schur) =
642        factor_dense_reduced_schur(schur, options.schur_pd_floor, options.tolerate_ill_conditioning)?;
643    if let Some(floored) = floored_schur {
644        let direct = mixed_precision_reduced_beta(&floored, &factor, rhs_beta, options)
645            .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
646        if step_inside_trust_region(direct.view(), options.trust_region.radius, metric_weights) {
647            return Ok((direct, Some(factor), PcgDiagnostics::default()));
648        }
649        let identity = IdentityPreconditioner;
650        let (delta, diag) = steihaug_dense_system(
651            &floored,
652            rhs_beta,
653            &identity,
654            &ArrowPcgOptions {
655                max_iterations: options.trust_region.max_iterations,
656                relative_tolerance: options.trust_region.steihaug_relative_tolerance,
657            },
658            &options.trust_region,
659            metric_weights,
660        )?;
661        return Ok((delta, Some(factor), diag));
662    }
663    // Ill-conditioned-but-PD Schur guard. The per-row factor checks reject
664    // any single barely-PD H_tt^(i) block, but the reduced Schur complement
665    //     S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)
666    // accumulates the (H_tt^(i))⁻¹ contributions of every row in finite
667    // precision. With many weak-but-admissible rows those terms can sum to a
668    // Schur matrix whose Cholesky succeeds yet whose condition number is far
669    // past the safe inversion regime, so `cholesky_solve_vector` yields an
670    // inaccurate Δβ that is silently propagated to the Newton step. Apply the
671    // same diagonal-ratio κ proxy used per-row to the reduced factor and treat
672    // an over-threshold estimate as a Schur-stability failure: `SchurFactorFailed`
673    // is already recoverable in `solve_with_lm_escalation_inner`, so this lifts
674    // `ridge_beta` and re-forms a better-conditioned Schur. This guard is
675    // exclusive to the dense Direct / SqrtBA path (the only caller of this
676    // function); the inexact-PCG path tolerates higher κ(S) and is unaffected.
677    // Evidence/log-det-only callers (`tolerate_ill_conditioning`) skip this
678    // rejection: the factor is genuinely PD (Cholesky above succeeded), so its
679    // diagonal still yields an exact `log|S|`, and an inaccurate Δβ is harmless
680    // because the step is discarded.
681    if !options.tolerate_ill_conditioning {
682        let schur_kappa = cholesky_factor_kappa_estimate(&factor);
683        if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
684            // #1026 — over-complete SAE dictionaries park surplus atoms dead
685            // (β_k → 0), so the reduced Schur is PD (the Cholesky above succeeded)
686            // but ILL-CONDITIONED: the dead decoder subspace carries near-zero
687            // eigenvalues while the live subspace is healthy. The kappa gate's
688            // concern is an inaccurate Δβ from accumulated (H_tt)⁻¹ contamination —
689            // but on the dead subspace the correct Δβ IS ≈0 (those atoms have no
690            // signal), so the only "inaccuracy" is in directions whose true step is
691            // zero. When the spectral PD-floor is enabled (the SAE solve path),
692            // clamp exactly those collapsed directions up to `floor·max(λ)` and
693            // solve against the floored Schur: the live subspace keeps its EXACT
694            // Newton component, the dead subspace is damped to ≈0, and κ is bounded
695            // so Δβ is accurate where it matters. This is the same conditioning the
696            // non-PD branch above applies; here it also covers the PD-but-ill-
697            // conditioned case so the LM loop does not exhaust `ridge_β` trying to
698            // (futilely) lift a fundamentally rank-deficient dead-atom subspace.
699            // Without the floor (BA / non-SAE callers) the strict refusal stands.
700            if let Some(relative_floor) = options.schur_pd_floor
701                && let Some(floored) = spectral_pd_floored_schur(schur, relative_floor)
702                && let Ok(floored_factor) = cholesky_lower(&floored)
703            {
704                let direct =
705                    mixed_precision_reduced_beta(&floored, &floored_factor, rhs_beta, options)
706                        .unwrap_or_else(|| cholesky_solve_vector(&floored_factor, rhs_beta));
707                if step_inside_trust_region(
708                    direct.view(),
709                    options.trust_region.radius,
710                    metric_weights,
711                ) {
712                    return Ok((direct, Some(floored_factor), PcgDiagnostics::default()));
713                }
714                let identity = IdentityPreconditioner;
715                let (delta, diag) = steihaug_dense_system(
716                    &floored,
717                    rhs_beta,
718                    &identity,
719                    &ArrowPcgOptions {
720                        max_iterations: options.trust_region.max_iterations,
721                        relative_tolerance: options.trust_region.steihaug_relative_tolerance,
722                    },
723                    &options.trust_region,
724                    metric_weights,
725                )?;
726                return Ok((delta, Some(floored_factor), diag));
727            }
728            return Err(ArrowSchurError::SchurFactorFailed {
729                reason: format!(
730                    "reduced Schur complement Cholesky succeeded but is ill-conditioned \
731                     (kappa_estimate={schur_kappa:e}); accumulated per-row \
732                     (H_tt)⁻¹ contamination would yield an inaccurate Δβ"
733                ),
734            });
735        }
736    }
737    // Reduced-system solve. The f64 `factor` is always retained and returned —
738    // its diagonal is the EXACT `log|S|` the evidence path reads, so the logdet
739    // stays f64 regardless of how Δβ is computed (#1014 invariant). When the
740    // streaming/residency path enabled certified mixed precision, the Δβ solve
741    // itself runs f32-then-f64-refined (κ-gated, with the f64 triangular solve
742    // as the automatic fallback); the certificate is the f64 backward error.
743    let direct = mixed_precision_reduced_beta(schur, &factor, rhs_beta, options)
744        .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
745    if step_inside_trust_region(direct.view(), options.trust_region.radius, metric_weights) {
746        return Ok((direct, Some(factor), PcgDiagnostics::default()));
747    }
748
749    // Ceres-style trust-region correction: once the dense BA solve proposes a
750    // step outside the trust ball, Steihaug-CG returns the boundary point
751    // without requiring a second dense factorization.
752    let identity = IdentityPreconditioner;
753    let (delta, diag) = steihaug_dense_system(
754        schur,
755        rhs_beta,
756        &identity,
757        &ArrowPcgOptions {
758            max_iterations: options.trust_region.max_iterations,
759            relative_tolerance: options.trust_region.steihaug_relative_tolerance,
760        },
761        &options.trust_region,
762        metric_weights,
763    )?;
764    Ok((delta, Some(factor), diag))
765}
766
767/// Solve an externally accumulated dense reduced β system
768/// `S Δβ = rhs_β` with the same LM-style ridge escalation the full-batch
769/// driver applies: on a `SchurFactorFailed` (non-PD or ill-conditioned `S`),
770/// geometrically grow a proximal ridge on `S`'s diagonal and retry.
771///
772/// Used by the SAE streaming joint fit, which accumulates `S` and `rhs_β` over
773/// re-materialized row chunks (via [`StreamingArrowSchur::take_accumulators`])
774/// and must solve the single global reduced system without a per-row
775/// `ArrowSchurSystem`. `S` is symmetrized from its lower triangle before each
776/// factorization. `base_ridge_beta` is folded into the caller's `S` already;
777/// this routine only adds the *escalation* ridge on top.
778pub fn solve_streaming_reduced_beta(
779    s_acc: &Array2<f64>,
780    rhs_beta: &Array1<f64>,
781    options: &ArrowSolveOptions,
782) -> Result<Array1<f64>, ArrowSchurError> {
783    let mut proximal_ridge = 0.0_f64;
784    let mut last_err: Option<ArrowSchurError> = None;
785    for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
786        let mut schur = s_acc.clone();
787        symmetrize_upper_from_lower(&mut schur);
788        if proximal_ridge > 0.0 {
789            for j in 0..schur.nrows() {
790                schur[[j, j]] += proximal_ridge;
791            }
792        }
793        // Reduced K-system on device: Jacobi-preconditioned CG over the dense
794        // symmetric `S`. The `O(K²)` `S·p` matvec runs device-side; only the
795        // K-vectors cross the boundary per CG iteration. This is the dominant
796        // cost of the streaming SAE joint fit at `K = 100K`. Any device-side
797        // failure (`Unavailable`, non-PD Jacobi diagonal) falls through to the
798        // CPU `solve_dense_reduced_system`, which then drives the same proximal
799        // ridge escalation. A genuine device PD failure is non-recoverable for
800        // this attempt's `schur`, so we let the CPU path re-confirm and escalate.
801        if gam_gpu::device_runtime::GpuRuntime::is_available() {
802            match crate::gpu_kernels::arrow_schur::solve_reduced_beta_pcg(
803                &schur,
804                rhs_beta,
805                options.trust_region.max_iterations,
806                options.trust_region.steihaug_relative_tolerance,
807            ) {
808                Ok(delta_beta) => return Ok(delta_beta),
809                Err(crate::gpu_kernels::arrow_schur::ArrowSchurGpuFailure::Unavailable) => {}
810                Err(_) => {
811                    // Device declined this `schur` (e.g. non-PD Jacobi diag);
812                    // let the CPU path confirm and escalate the proximal ridge.
813                }
814            }
815        }
816        match solve_dense_reduced_system(&schur, rhs_beta, options, None) {
817            Ok((delta_beta, _factor, _diag)) => return Ok(delta_beta),
818            Err(err) => {
819                let recoverable = matches!(
820                    err,
821                    ArrowSchurError::SchurFactorFailed { .. }
822                        | ArrowSchurError::PcgFailed { .. }
823                        | ArrowSchurError::UnboundedNegativeCurvature { .. }
824                );
825                last_err = Some(err);
826                if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
827                    break;
828                }
829                proximal_ridge = if proximal_ridge == 0.0 {
830                    DEFAULT_PROXIMAL_INITIAL_RIDGE
831                } else {
832                    proximal_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
833                };
834            }
835        }
836    }
837    Err(last_err.expect("escalation loop set last_err on failure"))
838}
839
840pub(crate) fn step_inside_trust_region(
841    step: ArrayView1<'_, f64>,
842    radius: f64,
843    metric_weights: Option<&MetricWeights>,
844) -> bool {
845    !radius.is_finite() || metric_norm(step, metric_weights) <= radius
846}
847
/// Minimum row count at which the per-row Schur loop is allowed to go parallel.
///
/// With fewer rows the rayon fan-out (chunk dispatch plus the deterministic
/// per-chunk length-`K` reduction) costs more than it saves — the case for the
/// handful-of-rows arrow systems that dominate the non-SAE callers. At or above
/// it — the SAE LLM shape named by issue #1017 (`n` in the thousands, wide
/// border `k`) — the per-row `H_βt (H_tt)⁻¹ H_tβ x` contributions are the
/// matvec's whole cost and parallelize cleanly.
pub(crate) const SCHUR_MATVEC_PARALLEL_ROW_MIN: usize = 256;
855
/// Minimum border width `k` at which the dense `H_ββ` penalty-prologue GEMV
/// is allowed to go parallel.
///
/// Parallelizing a `k×k` matvec only pays once `k²` dwarfs the rayon fan-out,
/// which never happens for the arrow callers with narrow borders. At the SAE
/// LLM border (`k` in the low thousands) the `O(k²)` prologue is ≈4M
/// flops/CG-iteration and was the serial Amdahl ceiling on the otherwise
/// per-row-parallel matvec (#1017), so it crosses this threshold and fans out.
/// 512 keeps the prologue serial for every non-SAE arrow system while engaging
/// it for the wide SAE/Qwen borders the issue targets.
pub(crate) const SCHUR_PROLOGUE_PARALLEL_K_MIN: usize = 512;
865
866/// Device-residency CPU analogue for the SAE reduced-Schur matvec (#1017).
867///
868/// In the production SAE joint fit the per-row cross-block factors as
869/// `H_tβ^(i) = L_i P_i`, where `L_i` (`q_i × p`) is the row's local
870/// assignment/coordinate Jacobian and `P_i` (`p × K`, sparse) gathers the
871/// active atoms' decoder blocks (`P_i x = Σ_s φ_s · x[base_s .. base_s+p]`).
872/// The reduced-Schur point-elimination contribution of one row is therefore
873///
874/// ```text
875/// S_i x = H_βt^(i) (H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i) x
876///       = P_iᵀ · [ L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i ] · P_i x
877///       = P_iᵀ G_i (P_i x),      G_i := L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i   (p×p).
878/// ```
879///
880/// The block `G_i = L_iᵀ Y_i` depends only on the assembled per-row blocks and
881/// the (already-computed, solve-stable) `H_tt` factor — NOT on the CG iterate
882/// `x`. The generic [`schur_matvec`] re-walks `apply_jbeta → apply_l →
883/// solve(d×d) → apply_l_t → scatter` on every CG iteration; this object **stages
884/// the factors `(L_i, Y_i)` once per CG solve** (the "upload X once" residency
885/// mechanism, applied on CPU to the matvec rather than a dense factorization),
886/// turning each subsequent matvec into a sparse gather → two `di×p` GEMVs →
887/// sparse scatter, with no per-iteration triangular solve and no operator-closure
888/// re-walk. It never materialises the dense `p×p` product: `di ≪ p` for SAE
889/// rows, so the factored apply is `2·support_i·p + 2·di·p` flops/row — the two
890/// `di·p` GEMVs PLUS the `support_i·p` sparse gather (`P_i x`) and `support_i·p`
891/// sparse scatter (`P_iᵀ prod`) — versus the dense `p²` block apply, and
892/// `O(n·di·p)` memory (vs `O(n·p²)` ≈ 67 GB at the Qwen shape — the dense form
893/// is OOM). For dense/full active support `support_i` can scale with the active
894/// β-columns, so the gather/scatter term is NOT negligible and is counted here.
895///
896/// Numerically identical to the generic path up to floating-point reassociation
897/// (it differentiates and accumulates the SAME quotient). It is deterministic
898/// run-to-run and within the reassociation margin of the serial path, so the
899/// criterion ranking across topology candidates is stable except for candidates
900/// separated by less than that f64 margin, where reassociation can flip the
901/// near-tie winner — it is NOT an exact no-move guarantee (#1211).
902pub(crate) struct SaeResidentReducedSchur {
903    /// Decoder output dimension `p` (the side length of every `G_i = L_iᵀ Y_i`).
904    pub(crate) p: usize,
905    /// Per-row **factored** residency: `(L_i, Y_i)`, each stored row-major as a
906    /// `di × p` slab (`L_i` = local Jacobian, `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i`).
907    /// The reduced block is `G_i = L_iᵀ Y_i` (`p×p`, symmetric PSD), but it has
908    /// rank ≤ `di` and `di ≪ p` for SAE rows (the per-row latent dim is 1–2
909    /// while `p` is the decoder block width, ~2048). Materialising the dense
910    /// `p×p` block would cost `O(n·p²)` memory (≈67 GB at the Qwen shape) and
911    /// `p²` flops per matvec/row; the factored form costs `O(n·di·p)` memory and
912    /// `2·support_i·p + 2·di·p` flops/row, applying `G_i v = L_iᵀ (Y_i v)`
913    /// (sparse gather over `support_i` atoms → `di`-length GEMV → `p`-length
914    /// GEMV → sparse scatter over `support_i` atoms). The `2·support_i·p`
915    /// gather/scatter term is part of the per-row cost — for dense/full support
916    /// `support_i` scales with active β-columns — and is not dropped. A row with
917    /// empty active support / degenerate dims gets `di = 0` and is skipped.
918    /// `(di, L_i, Y_i)` per row; `L_i`/`Y_i` are `di·p`-length row-major buffers.
919    pub(crate) rows: Vec<ResidentRowFactor>,
920    /// Per-row active atom support `(β-block base index, φ weight)`, shared with
921    /// the assembler's [`DeviceSaePcgData`] (no re-clone of the index lists).
922    pub(crate) a_phi: Arc<[Vec<(usize, f64)>]>,
923    /// #1033: per-row local Jacobian `L_i` (row-major `di × p`), SHARED via `Arc`
924    /// with the assembler's [`DeviceSaePcgData`] rather than copied into each
925    /// `ResidentRowFactor`. The staged factor previously held its own verbatim
926    /// row-major copy of `data.local_jac[row]` — a second full `O(n·di·p)` slab
927    /// for zero benefit (the bytes and the `di × p` layout are identical). The
928    /// matvec now reads `L_i = &self.local_jac[row]` directly; only the SOLVED
929    /// factor `Y_i = (H_tt+ρI)⁻¹ L_i` (genuinely new data) stays per-row. Reads
930    /// are byte-for-byte the former `rf.l` (same slab, same `r·p + c` indexing),
931    /// so the matvec/preconditioner output is bit-identical.
932    pub(crate) local_jac: Arc<[Vec<f64>]>,
933}
934
/// One row's staged residency entry for the factored block `G_i = L_iᵀ Y_i`.
///
/// The block is kept as its `di×p` factors so the matvec never forms the dense
/// `p×p` product. Only the solved factor `Y_i` lives here; the local Jacobian
/// `L_i` is shared through [`SaeResidentReducedSchur::local_jac`]
/// (`&local_jac[row]`). See [`SaeResidentReducedSchur`].
pub(crate) struct ResidentRowFactor {
    /// Row latent dimension `di`, the inner contraction width. A value of `0`
    /// marks the row as skipped.
    pub(crate) di: usize,
    /// `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i`, flattened row-major as `di × p`. Empty
    /// when `di == 0`.
    pub(crate) y: Vec<f64>,
}
946
947impl SaeResidentReducedSchur {
948    /// Stage the per-row `G_i = L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i` blocks once, from
949    /// the SAE structure (`DeviceSaePcgData`: `p`, per-row `a_phi`, per-row
950    /// row-major `local_jac` = `L_i`) and the already-factored `H_tt` slab.
951    ///
952    /// Returns `None` when the structure does not match (degenerate `p`, row
953    /// count mismatch) so the caller falls back to the generic matvec. Row
954    /// builds are independent and run under the same deterministic rayon
955    /// discipline as the matvec (each `G_i` is self-contained — no cross-row
956    /// reduction — so there is no ordering subtlety).
957    /// `ridge_t` is NOT a parameter: it is already folded into the factored
958    /// blocks `htt_factors` carry (they factor `H_tt^(i) + ridge_t·I` — see
959    /// `factor_blocks`), so solving against the factor yields `(H_tt^(i)+ρ_t I)⁻¹`
960    /// exactly. The residency block is a pure function of the factor and `L_i`.
961    pub(crate) fn build<B: BatchedBlockSolver + Sync>(
962        sys: &ArrowSchurSystem,
963        htt_factors: &ArrowFactorSlab,
964        backend: &B,
965    ) -> Option<Self> {
966        let data = sys.device_sae_pcg.as_ref()?;
967        let p = data.p;
968        let n = sys.rows.len();
969        if p == 0
970            || sys.htbeta_dense_supplement
971            || data.a_phi.len() != n
972            || data.local_jac.len() != n
973        {
974            return None;
975        }
976        let empty = || ResidentRowFactor {
977            di: 0,
978            y: Vec::new(),
979        };
980        let build_row = |row: usize| -> ResidentRowFactor {
981            let di = sys.row_dims[row];
982            let jac = &data.local_jac[row];
983            // q_i = len/p; must match the row's latent dimension di.
984            if p == 0 || jac.len() != di * p || di == 0 {
985                return empty();
986            }
987            // L_i as a (di × p) matrix (row-major in `local_jac`).
988            let l_i = match ArrayView2::from_shape((di, p), jac.as_slice()) {
989                Ok(v) => v.to_owned(),
990                Err(_) => return empty(),
991            };
992            // Solve (H_tt+ρ_t I) Y = L_i for Y (di × p): one batched back-solve
993            // over the p columns against the cached factor. Stage `(L_i, Y_i)`
994            // — NOT the dense `p×p` product `G_i = L_iᵀ Y_i` — so storage and the
995            // matvec stay `O(di·p)` instead of `O(p²)` (`di ≪ p` for SAE rows).
996            let y = backend.solve_block_matrix(htt_factors.factor(row), l_i.view());
997            // Flatten the SOLVED factor to a `di × p` row-major buffer (iteration
998            // over a standard-layout view is row-major regardless of the source
999            // strides, so the hot loop can index `r*p + c` directly). `L_i` is NOT
1000            // copied — the matvec reads it from the shared `local_jac` slab (it is
1001            // byte-for-byte `data.local_jac[row]`).
1002            let y_flat: Vec<f64> = y.iter().copied().collect();
1003            ResidentRowFactor { di, y: y_flat }
1004        };
1005        let rows: Vec<ResidentRowFactor> =
1006            if n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1007                use rayon::prelude::*;
1008                (0..n).into_par_iter().map(build_row).collect()
1009            } else {
1010                (0..n).map(build_row).collect()
1011            };
1012        Some(Self {
1013            p,
1014            rows,
1015            a_phi: data.a_phi_shared(),
1016            local_jac: data.local_jac_shared(),
1017        })
1018    }
1019
1020    /// Accumulate one row's `S_i x = P_iᵀ G_i (P_i x) = P_iᵀ L_iᵀ Y_i (P_i x)`
1021    /// into `acc` (length `K`). `gather`/`prod` are caller-owned length-`p`
1022    /// buffers and `w` a caller-owned `≥ max_i di`-length buffer, all reused
1023    /// across rows to keep the hot loop allocation-free. The matvec applies the
1024    /// factored block in four steps: sparse gather `P_i x = Σ_s φ_s·x[base_s..]`
1025    /// (`support_i·p` flops), `w = Y_i·(P_i x)` (`di`-length, `di·p` flops),
1026    /// `prod = L_iᵀ·w` (`p`-length, `di·p` flops), and sparse scatter
1027    /// `acc += P_iᵀ prod` (`support_i·p` flops) — `2·support_i·p + 2·di·p`
1028    /// total, never the dense `p²` product. The gather/scatter `2·support_i·p`
1029    /// term is counted: it is not dominated by the GEMVs when the active support
1030    /// is wide.
1031    #[inline]
1032    pub(crate) fn row_into(
1033        &self,
1034        row: usize,
1035        x: &Array1<f64>,
1036        acc: &mut Array1<f64>,
1037        gather: &mut [f64],
1038        prod: &mut [f64],
1039        w: &mut [f64],
1040    ) {
1041        let rf = &self.rows[row];
1042        let di = rf.di;
1043        if di == 0 {
1044            return;
1045        }
1046        let p = self.p;
1047        let support = &self.a_phi[row];
1048        if support.is_empty() {
1049            return;
1050        }
1051        // Slice `x`/`acc` ONCE so the per-support gather/scatter (the dominant
1052        // `support·p` terms for wide active support) run over contiguous `f64`
1053        // slices — the compiler can prove unit stride and emit vectorized FMA,
1054        // where the former `x[base+j]`/`acc[base+j]` ndarray element indexing
1055        // forced a per-element strided lookup + bounds check that blocked
1056        // autovectorization. Every accumulation order is unchanged, so the
1057        // result is bit-identical to the ndarray-indexed form.
1058        let x_slice = x.as_slice().expect("resident matvec x must be contiguous");
1059        // P_i x = Σ_s φ_s · x[base_s .. base_s+p]   (length p).
1060        let gather = &mut gather[..p];
1061        for v in gather.iter_mut() {
1062            *v = 0.0;
1063        }
1064        for &(base, phi) in support {
1065            if phi == 0.0 {
1066                continue;
1067            }
1068            let xrow = &x_slice[base..base + p];
1069            for (g, &xv) in gather.iter_mut().zip(xrow) {
1070                *g += phi * xv;
1071            }
1072        }
1073        // w = Y_i · (P_i x)   (di × p GEMV → length di).  Y_i row-major di×p.
1074        for r in 0..di {
1075            let yrow = &rf.y[r * p..r * p + p];
1076            let mut s = 0.0_f64;
1077            for (&yv, &gv) in yrow.iter().zip(gather.iter()) {
1078                s += yv * gv;
1079            }
1080            w[r] = s;
1081        }
1082        // prod = L_iᵀ · w   (p × di GEMV → length p).  L_i row-major di×p, so
1083        // L_iᵀ[j,r] = L_i[r,j]; accumulate column-by-column over the di rows.
1084        // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte the
1085        // former per-row `rf.l` copy.
1086        let l_i = &self.local_jac[row];
1087        let prod = &mut prod[..p];
1088        for v in prod.iter_mut() {
1089            *v = 0.0;
1090        }
1091        for r in 0..di {
1092            let lrow = &l_i[r * p..r * p + p];
1093            let wr = w[r];
1094            for (pj, &lj) in prod.iter_mut().zip(lrow) {
1095                *pj += lj * wr;
1096            }
1097        }
1098        // acc += P_iᵀ prod = scatter φ_s · prod into base_s blocks.
1099        let acc_slice = acc
1100            .as_slice_mut()
1101            .expect("resident matvec acc must be contiguous");
1102        for &(base, phi) in support {
1103            if phi == 0.0 {
1104                continue;
1105            }
1106            let arow = &mut acc_slice[base..base + p];
1107            for (a, &pv) in arow.iter_mut().zip(prod.iter()) {
1108                *a += phi * pv;
1109            }
1110        }
1111    }
1112
1113    /// Max row latent dim `di` across resident rows — the size of the `w`
1114    /// scratch the matvec needs for the inner `Y_i·(P_i x)` GEMV.
1115    pub(crate) fn max_di(&self) -> usize {
1116        self.rows.iter().map(|r| r.di).max().unwrap_or(0)
1117    }
1118}
1119
1120/// Reduced-Schur matvec `out = S·x` with an optional pre-staged SAE residency
1121/// operator. When `resident` is `Some`, the per-row point-elimination term is
1122/// applied through the resident `p×p` blocks (#1017 CPU residency); otherwise it
1123/// falls back to the generic per-row `apply → solve → transpose` path. Both
1124/// routes accumulate the SAME reduced operator
1125/// `S = H_ββ + ρ_β I − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`.
1126pub(crate) fn schur_matvec<B: BatchedBlockSolver + Sync>(
1127    sys: &ArrowSchurSystem,
1128    htt_factors: &ArrowFactorSlab,
1129    ridge_beta: f64,
1130    x: &Array1<f64>,
1131    out: &mut Array1<f64>,
1132    backend: &B,
1133    resident: Option<&SaeResidentReducedSchur>,
1134) {
1135    // `steihaug_cg` reuses one output buffer across iterations and requires
1136    // `matvec` to ASSIGN every entry of `out` (the contract `dense_matvec`
1137    // upholds). This routine builds `S·x` purely by accumulation
1138    // (`penalty_matvec_add`, `out[a] += ridge·x`, `out[a] -= neg_contrib`), so it
1139    // MUST clear `out` first. Without this, iteration n>0 returns `S·x` plus the
1140    // previous call's `S·p`, the PCG solves a corrupted reduced system, and the
1141    // resulting Newton step is inconsistent with the assembled gradient
1142    // (g·δ ≈ 0 — a non-descent direction that defeats the line search).
1143    out.fill(0.0);
1144    let k = sys.k;
1145    // Top-level (not nested in a rayon worker) and big enough to amortize the
1146    // fan-out: the single gate that authorizes BOTH the dense penalty-prologue
1147    // GEMV and the per-row point-elimination loop to go parallel. The topology
1148    // race fans candidates with `run_topology_race_parallel`, so inside a worker
1149    // both stay sequential (no nested-rayon oversubscription).
1150    let parallel =
1151        sys.rows.len() >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1152    // Route the penalty-side (H_ββ + ridge·I) x product through the prologue:
1153    // no Arc-clone hot-path cost when penalty_op is None (falls back to hbb
1154    // inline); the dense fallback fans across cores at the wide SAE border (#1017).
1155    {
1156        let x_slice = x.as_slice().expect("x must be contiguous");
1157        let out_slice = out.as_slice_mut().expect("out must be contiguous");
1158        sys.penalty_ridge_prologue_into(x_slice, ridge_beta, out_slice, parallel);
1159    }
1160    // The reduced-Schur point-elimination term: `out -= Σ_i H_βt^(i) (H_tt^(i))⁻¹
1161    // H_tβ^(i) x`. Each row contributes an independent length-`K` vector, so for
1162    // the SAE LLM shape (#1017) this is the matvec's whole cost and is
1163    // embarrassingly parallel. Run it under rayon over fixed row chunks, summing
1164    // the per-chunk partials in chunk order so the f64 reduction is bit-identical
1165    // run-to-run regardless of thread scheduling (the #1017 verification gate).
1166    // This is deterministic and within the chunk-reassociation margin of serial,
1167    // so the criterion ranking is stable except for candidates that tie inside
1168    // that f64 margin — not an exact no-move guarantee (#1211). Stay
1169    // sequential when already inside a rayon worker (the topology race fans
1170    // candidates with `run_topology_race_parallel`) to avoid nested-rayon
1171    // oversubscription — the same guard `HyperOperator::mul_mat` uses. The
1172    // `parallel` gate above authorizes this loop too.
1173    let p = resident.map(|r| r.p).unwrap_or(0);
1174    if parallel {
1175        use rayon::prelude::*;
1176        const CHUNK: usize = 64;
1177        let n = sys.rows.len();
1178        let partials: Vec<Array1<f64>> = (0..n)
1179            .into_par_iter()
1180            .chunks(CHUNK)
1181            .map(|idxs| {
1182                let mut acc = Array1::<f64>::zeros(k);
1183                if let Some(res) = resident {
1184                    // Resident path: each matvec is gather → factored di×p GEMVs
1185                    // → scatter, reading only the pre-staged `(L_i, Y_i)` (no
1186                    // per-iteration solve, no dense p×p block).
1187                    let mut gather = vec![0.0_f64; p];
1188                    let mut prod = vec![0.0_f64; p];
1189                    let mut w = vec![0.0_f64; res.max_di()];
1190                    for i in idxs {
1191                        res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1192                    }
1193                } else {
1194                    let mut local = Array1::<f64>::zeros(sys.d);
1195                    for i in idxs {
1196                        schur_matvec_row_into(
1197                            sys,
1198                            htt_factors,
1199                            x,
1200                            backend,
1201                            i,
1202                            &mut local,
1203                            &mut acc,
1204                        );
1205                    }
1206                }
1207                acc
1208            })
1209            .collect();
1210        // Deterministic ordered reduction: fold chunk partials left-to-right.
1211        for acc in &partials {
1212            for a in 0..k {
1213                out[a] -= acc[a];
1214            }
1215        }
1216    } else if let Some(res) = resident {
1217        let mut acc = Array1::<f64>::zeros(k);
1218        let mut gather = vec![0.0_f64; p];
1219        let mut prod = vec![0.0_f64; p];
1220        let mut w = vec![0.0_f64; res.max_di()];
1221        for i in 0..sys.rows.len() {
1222            res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1223        }
1224        for a in 0..k {
1225            out[a] -= acc[a];
1226        }
1227    } else {
1228        // Allocate scratch at max_d; per-row slice is `..di`.
1229        let mut local = Array1::<f64>::zeros(sys.d);
1230        let mut neg_contrib = Array1::<f64>::zeros(k);
1231        for i in 0..sys.rows.len() {
1232            neg_contrib.fill(0.0);
1233            schur_matvec_row_into(
1234                sys,
1235                htt_factors,
1236                x,
1237                backend,
1238                i,
1239                &mut local,
1240                &mut neg_contrib,
1241            );
1242            for a in 0..k {
1243                out[a] -= neg_contrib[a];
1244            }
1245        }
1246    }
1247}
1248
1249/// Matrix-free reduced-Schur log-determinant `log|S|` via Stochastic Lanczos
1250/// Quadrature on the exact `schur_matvec` apply `v ↦ S·v`, where
1251/// `S = (H_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹H_tβ^(i)` is the SPD
1252/// reduced Schur. **The dense `k×k` `S` is NEVER formed.**
1253///
1254/// This is the memory-matrix-free evidence path for the massive-K manifold SAE.
1255/// The dense evidence routes assemble `S` explicitly (`O(k²)` ≈ 8 GB at the
1256/// K=32k border) and Cholesky-factor it (`O(k³/3)`) purely to read `Σ 2·log Lᵢᵢ`;
1257/// that dense assembly + factor is the massive-K wall (both dense evidence
1258/// routes REFUSE above the in-core budget). Here peak memory is `O(k)` — the SLQ
1259/// Rademacher probe and Lanczos basis vectors — and the cost is
1260/// `O(num_probes·lanczos_steps · matvec)`, each matvec the same `O(n·d·k)`
1261/// reduced-Schur apply the PCG hot loop already runs. Deterministic for a fixed
1262/// `(sys, htt_factors, ρ_β, resident, num_probes, lanczos_steps, seed)` so the
1263/// REML evidence outer loop stays reproducible.
1264///
1265/// `htt_factors` are the per-row `(H_tt^(i)+ρ_t I)` Cholesky factors; `resident`
1266/// is the optional pre-staged SAE residency operator (`None` for the framed /
1267/// closure `H_tβ` path). SLQ is an ESTIMATE — the same accuracy contract the
1268/// device seam already accepts for `k ≥ SCHUR_SLQ_LOGDET_MIN_DIM`; callers that
1269/// need the exact dense log-det at small `k` must stay on the dense route.
1270///
1271/// Crate-internal because the `resident` parameter carries the `pub(crate)`
1272/// [`SaeResidentReducedSchur`] operator; cross-crate callers use the
1273/// [`matrix_free_arrow_evidence_log_det`] convenience, which stages residency
1274/// internally and exposes no crate-private type.
1275pub(crate) fn slq_reduced_schur_log_det<B: BatchedBlockSolver + Sync>(
1276    sys: &ArrowSchurSystem,
1277    htt_factors: &ArrowFactorSlab,
1278    ridge_beta: f64,
1279    backend: &B,
1280    resident: Option<&SaeResidentReducedSchur>,
1281    num_probes: usize,
1282    lanczos_steps: usize,
1283    seed: u64,
1284) -> SlqLogDet {
1285    let k = sys.k;
1286    slq_logdet(
1287        k,
1288        |v| {
1289            // `schur_matvec` clears and fully assigns `out`, so a fresh zeroed
1290            // buffer per apply is correct; the probes fan across rayon workers
1291            // (in `slq_logdet`), and `schur_matvec`'s own row parallelism is
1292            // guarded off inside a worker, so there is no nested oversubscription.
1293            let x = v.to_owned();
1294            let mut out = Array1::<f64>::zeros(k);
1295            schur_matvec(
1296                sys,
1297                htt_factors,
1298                ridge_beta,
1299                &x,
1300                &mut out,
1301                backend,
1302                resident,
1303            );
1304            out
1305        },
1306        num_probes,
1307        lanczos_steps,
1308        seed,
1309    )
1310}
1311
1312/// One-call matrix-free arrow evidence log-determinant for an assembled system.
1313///
1314/// Factors the per-row `H_tt^(i)+ρ_t I` blocks (accumulating
1315/// `log_det_tt = Σ_i Σ_axis 2·log Lᵢᵢ` from the Cholesky diagonals — the cheap
1316/// `O(n·d³)` t-tier term), stages the SAE residency operator when the system
1317/// carries `device_sae_pcg` full-`B` data, and estimates `log|S|` via
1318/// [`slq_reduced_schur_log_det`] with NO dense `k×k` Schur formed at any point.
1319///
1320/// Returns `(log_det_tt, log|S| SLQ estimate)`; the undamped joint evidence
1321/// log-det the Laplace normaliser needs is their sum. Uses the identical
1322/// [`factor_blocks_for_system`] the dense Direct evidence path uses (same gauge
1323/// deflation), so `log_det_tt` matches the dense convention exactly and only the
1324/// `k×k` Schur term is replaced by its matrix-free SLQ estimate.
1325pub fn matrix_free_arrow_evidence_log_det(
1326    sys: &ArrowSchurSystem,
1327    ridge_t: f64,
1328    ridge_beta: f64,
1329    options: &ArrowSolveOptions,
1330    num_probes: usize,
1331    lanczos_steps: usize,
1332    seed: u64,
1333) -> Result<(f64, SlqLogDet), ArrowSchurError> {
1334    let backend = CpuBatchedBlockSolver;
1335    let factorization = factor_blocks_for_system(sys, ridge_t, options, &backend)?;
1336    let htt_factors = factorization.factors;
1337    let mut log_det_tt = 0.0_f64;
1338    for row in 0..htt_factors.len() {
1339        let factor = htt_factors.factor(row);
1340        for axis in 0..factor.nrows() {
1341            log_det_tt += 2.0 * factor[[axis, axis]].ln();
1342        }
1343    }
1344    let resident = SaeResidentReducedSchur::build(sys, &htt_factors, &backend);
1345    let slq = slq_reduced_schur_log_det(
1346        sys,
1347        &htt_factors,
1348        ridge_beta,
1349        &backend,
1350        resident.as_ref(),
1351        num_probes,
1352        lanczos_steps,
1353        seed,
1354    );
1355    Ok((log_det_tt, slq))
1356}
1357
1358/// Accumulate one row's reduced-Schur point-elimination contribution
1359/// `H_βt^(i) (H_tt^(i))⁻¹ H_tβ^(i) x` (length `K`) into `acc`.
1360///
1361/// `local` is caller-owned `≥ sys.d`-length scratch (reused across rows to keep
1362/// the hot loop allocation-free); only `..di` is touched. `acc` is **added to**,
1363/// never cleared, so the caller controls whether contributions sum into a chunk
1364/// partial (parallel path) or a per-row buffer (sequential path).
1365#[inline]
1366pub(crate) fn schur_matvec_row_into<B: BatchedBlockSolver>(
1367    sys: &ArrowSchurSystem,
1368    htt_factors: &ArrowFactorSlab,
1369    x: &Array1<f64>,
1370    backend: &B,
1371    i: usize,
1372    local: &mut Array1<f64>,
1373    acc: &mut Array1<f64>,
1374) {
1375    let row = &sys.rows[i];
1376    let di = sys.row_dims[i];
1377    // H_tβ^(i) · x → local[..di], routed through sys.htbeta_matvec
1378    // when the dense block is absent.
1379    let mut local_i = local.slice_mut(ndarray::s![..di]).to_owned();
1380    local_i.fill(0.0);
1381    sys_htbeta_apply_row(sys, i, row, x.view(), &mut local_i);
1382    let solved = backend.solve_block_vector(htt_factors.factor(i), local_i.view());
1383    // H_βt^(i) · solved accumulates into acc (length k).  Routed through
1384    // sys.htbeta_matvec when needed.
1385    sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), acc);
1386}
1387
1388/// One per-term block factor for the block-Jacobi Schur preconditioner.
1389///
1390/// Carries either a dense Cholesky factor (for PD blocks ≤ 256 columns) or
1391/// the scalar inverses for that block's diagonal as a fallback.
1392#[derive(Clone)]
1393pub(crate) enum BlockFactor {
1394    /// Cholesky L stored column-major via faer. `range` identifies the
1395    /// columns in the full K-vector this block covers.
1396    Chol {
1397        factor: FaerLlt<f64>,
1398        range: Range<usize>,
1399    },
1400    /// Scalar fallback: per-element `1/s_aa` for each column in `range`.
1401    Scalar {
1402        inv: Array1<f64>,
1403        range: Range<usize>,
1404    },
1405}
1406
1407impl std::fmt::Debug for BlockFactor {
1408    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1409        match self {
1410            BlockFactor::Chol { range, .. } => {
1411                write!(f, "BlockFactor::Chol {{ range: {:?} }}", range)
1412            }
1413            BlockFactor::Scalar { inv, range } => {
1414                write!(
1415                    f,
1416                    "BlockFactor::Scalar {{ inv.len: {}, range: {:?} }}",
1417                    inv.len(),
1418                    range
1419                )
1420            }
1421        }
1422    }
1423}
1424
1425/// Block-Jacobi Schur preconditioner for BA's inexact reduced-system PCG.
1426///
1427/// When [`ArrowSchurSystem::block_offsets`] is populated (via
1428/// [`ArrowSchurSystem::set_block_offsets`]) and the largest block has ≤ 256
1429/// columns, builds one small dense Schur block per term, factors it with
1430/// Cholesky (faer LLT), and applies the preconditioner as per-block
1431/// triangular solves.  Non-PD blocks fall back to scalar diagonal inversion
1432/// for that block only.  When `block_offsets` is empty or the largest block
1433/// exceeds 256 columns the preconditioner reduces to pure scalar-diagonal
1434/// Jacobi (pre-#283 behaviour), so callers that have not called
1435/// `set_block_offsets` are unaffected.
1436///
1437/// The `block_offsets` plumbing is compatible with issue #287 (custom
1438/// `ParameterBlockSpec` families): those callers supply ranges derived from
1439/// their own block layout.
1440#[derive(Debug, Clone)]
1441pub struct JacobiPreconditioner {
1442    pub(crate) blocks: Vec<BlockFactor>,
1443}
1444
/// Maximum block size for which we attempt dense block-Jacobi factorization.
///
/// Measured in columns of a `block_offsets` range. The limit is inclusive:
/// `from_arrow_schur` selects the block builders when the widest block is
/// `<=` this value and falls back to scalar Jacobi otherwise.
pub(crate) const BLOCK_JACOBI_MAX_BLOCK: usize = 256;
1447
/// Positive-definiteness floor on a Schur-complement Jacobi diagonal entry.
/// A diagonal at or below this value (or non-finite) signals a non-PD reduced
/// system: the preconditioner cannot invert it, so the PCG solve fails loudly
/// and demands operator regularization rather than returning a garbage scale.
///
/// The builders in this file test `!v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR`
/// and return `ArrowSchurError::PcgFailed` on the first entry that trips it.
pub(crate) const JACOBI_DIAGONAL_PD_FLOOR: f64 = 1e-18;
1453
1454impl JacobiPreconditioner {
1455    /// Build the block-Jacobi (or scalar fallback) preconditioner from the
1456    /// Arrow-Schur system without materializing the full dense Schur
1457    /// complement.
1458    ///
1459    /// When `sys.block_offsets` is non-empty and `max(block_size) ≤ 256`,
1460    /// each block gets a dense `b×b` Schur sub-matrix formed, factored, and
1461    /// stored.  Otherwise every column gets its own scalar entry.
1462    pub(crate) fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
1463        sys: &ArrowSchurSystem,
1464        htt_factors: &ArrowFactorSlab,
1465        ridge_beta: f64,
1466        backend: &B,
1467        resident: Option<&SaeResidentReducedSchur>,
1468    ) -> Result<Self, ArrowSchurError> {
1469        let use_block = !sys.block_offsets.is_empty()
1470            && sys
1471                .block_offsets
1472                .iter()
1473                .map(|r| r.end.saturating_sub(r.start))
1474                .max()
1475                .unwrap_or(0)
1476                <= BLOCK_JACOBI_MAX_BLOCK;
1477        if use_block {
1478            if let Some(res) = resident {
1479                Self::build_block_jacobi_resident(sys, ridge_beta, res)
1480            } else {
1481                Self::build_block_jacobi(sys, htt_factors, ridge_beta, backend)
1482            }
1483        } else if let Some(res) = resident {
1484            // #1017 — SAE residency scalar Jacobi. The generic scalar build
1485            // probes `H_tβ^(i) e_a` and re-solves `(H_tt^(i))⁻¹` once for EVERY
1486            // (row, β-column) pair: `O(n·K)` triangular solves and `O(n·K·p)`
1487            // operator-probe work per Newton step, with `K = K_atoms·p` in the
1488            // tens of thousands at LLM shapes. The reduced-Schur diagonal is the
1489            // same quotient the resident `(L_i, Y_i)` factors already carry, so
1490            // read the diagonal straight off them in one support-sparse pass —
1491            // no probe, no per-column solve.
1492            Self::build_scalar_jacobi_resident(sys, ridge_beta, res)
1493        } else {
1494            Self::build_scalar_jacobi(sys, htt_factors, ridge_beta, backend)
1495        }
1496    }
1497
1498    /// Build scalar-diagonal Jacobi: one `BlockFactor::Scalar` of length 1
1499    /// per column.  Matches pre-#283 semantics.
1500    ///
1501    /// When `sys.htbeta_matvec` is set and per-row `htbeta` slabs are absent,
1502    /// each column is probed via the matvec (one call per column per row).
1503    pub(crate) fn build_scalar_jacobi<B: BatchedBlockSolver + Sync>(
1504        sys: &ArrowSchurSystem,
1505        htt_factors: &ArrowFactorSlab,
1506        ridge_beta: f64,
1507        backend: &B,
1508    ) -> Result<Self, ArrowSchurError> {
1509        let k = sys.k;
1510        // Extract diagonal of H_ββ via penalty_diagonal_add (#296):
1511        // no Arc-clone; falls back to hbb_diag or hbb[[a,a]] inline.
1512        let mut diag = Array1::<f64>::zeros(k);
1513        {
1514            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1515            sys.penalty_diagonal_add(diag_slice);
1516        }
1517        for a in 0..k {
1518            diag[a] += ridge_beta;
1519        }
1520        // Per-row body: subtract this row's `Σ_a (H_tβ^(i)e_a)ᵀ(H_tt^(i))⁻¹
1521        // (H_tβ^(i)e_a)` contribution into a caller-provided length-`K` diagonal
1522        // accumulator (`-=`). For each column `a`, probe the cross-block (or read
1523        // the dense slab) and compute the scalar point-elimination quotient. The
1524        // `O(K)` solves per row are the build's whole cost; the row contributions
1525        // are independent length-`K` vectors, so a worker sums a chunk into a
1526        // private `diag_part` and the caller folds the partials back in chunk
1527        // order — bit-identical run-to-run (the #1017 preconditioner gate).
1528        let row_into = |i: usize, row: &ArrowRowBlock, diag_part: &mut Array1<f64>| {
1529            let di = sys.row_dims[i];
1530            // Dense-slab fast path (#1017): when the per-row cross-block is a
1531            // materialized `di × k` slab (no matrix-free operator), the entire
1532            // reduced-Schur diagonal contribution for this row is
1533            // `Σ_c H_tβ[c,a] · ((H_tt)⁻¹ H_tβ)[c,a]`. The generic loop below
1534            // re-solved `(H_tt)⁻¹` once PER COLUMN — `O(k)` block solves + `O(k)`
1535            // allocations per row, i.e. `O(n·k)` tiny solves per Newton step
1536            // (the dominant fixed per-solve cost at the SAE wide-border shape,
1537            // k in the tens of thousands). Solve all `k` columns in ONE batched
1538            // block solve instead, then take the column dots. Reassociates the
1539            // diagonal within the documented #1211 preconditioner margin (same as
1540            // the resident no-probe path), and the preconditioner only steers the
1541            // PCG iterate, which still terminates at the PCG tolerance.
1542            if sys.htbeta_matvec.is_none() && row.htbeta.dim() == (di, k) {
1543                let solved = backend.solve_block_matrix(htt_factors.factor(i), row.htbeta.view());
1544                for a in 0..k {
1545                    let mut acc = 0.0;
1546                    for c in 0..di {
1547                        acc += row.htbeta[[c, a]] * solved[[c, a]];
1548                    }
1549                    diag_part[a] -= acc;
1550                }
1551                return;
1552            }
1553            // Matrix-free path: probe column a. `e_a` stays all-zero between
1554            // columns — set the single active entry and reset it after the probe,
1555            // so we never pay the `O(k)` `e_a.fill(0.0)` per column (that fill was
1556            // `O(n·k²)`). `sys_htbeta_apply_row` zeroes `col_i` internally.
1557            let mut col_i = Array1::<f64>::zeros(di);
1558            let mut e_a = Array1::<f64>::zeros(k);
1559            for a in 0..k {
1560                e_a[a] = 1.0;
1561                sys_htbeta_apply_row(sys, i, row, e_a.view(), &mut col_i);
1562                e_a[a] = 0.0;
1563                let solved = backend.solve_block_vector(htt_factors.factor(i), col_i.view());
1564                let mut acc = 0.0;
1565                for c in 0..di {
1566                    acc += col_i[c] * solved[c];
1567                }
1568                diag_part[a] -= acc;
1569            }
1570        };
1571        let n = sys.rows.len();
1572        let parallel =
1573            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1574        if parallel {
1575            use rayon::prelude::*;
1576            const CHUNK: usize = 64;
1577            let partials: Vec<Array1<f64>> = (0..n)
1578                .into_par_iter()
1579                .chunks(CHUNK)
1580                .map(|idxs| {
1581                    let mut diag_part = Array1::<f64>::zeros(k);
1582                    for i in idxs {
1583                        row_into(i, &sys.rows[i], &mut diag_part);
1584                    }
1585                    diag_part
1586                })
1587                .collect();
1588            // Deterministic ordered reduction: fold chunk partials left-to-right.
1589            for part in &partials {
1590                for a in 0..k {
1591                    diag[a] += part[a];
1592                }
1593            }
1594        } else {
1595            for (i, row) in sys.rows.iter().enumerate() {
1596                row_into(i, row, &mut diag);
1597            }
1598        }
1599        let mut blocks = Vec::with_capacity(k);
1600        for a in 0..k {
1601            let v = diag[a];
1602            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1603                return Err(ArrowSchurError::PcgFailed {
1604                    reason: format!(
1605                        "invalid Schur Jacobi diagonal at index {a}: {v}; \
1606                         operator regularization is required"
1607                    ),
1608                });
1609            }
1610            blocks.push(BlockFactor::Scalar {
1611                inv: Array1::from_elem(1, 1.0 / v),
1612                range: a..a + 1,
1613            });
1614        }
1615        Ok(Self { blocks })
1616    }
1617
1618    /// Build scalar-diagonal Jacobi from the pre-staged SAE residency factors
1619    /// `(L_i, Y_i)` (#1017).
1620    ///
1621    /// The generic [`Self::build_scalar_jacobi`] forms each reduced-Schur
1622    /// diagonal entry `S_aa = H_ββ,aa + ρ − Σ_i (H_tβ^(i) e_a)ᵀ(H_tt^(i))⁻¹(H_tβ^(i) e_a)`
1623    /// by probing the cross-block operator with the unit vector `e_a` and
1624    /// re-solving `(H_tt^(i))⁻¹` for every `(row, column)` pair — `O(n·K)`
1625    /// triangular solves per Newton step. For the SAE Kronecker cross-block the
1626    /// `a`-th column lives on exactly one active support entry: `a = beta_base + j`
1627    /// for some `(beta_base, φ) ∈ a_phi[i]` and output channel `j ∈ 0..p`, with
1628    /// `H_tβ^(i) e_a = φ · L_i[:, j]`. The point-elimination quotient is then
1629    ///
1630    /// ```text
1631    /// (H_tβ^(i) e_a)ᵀ (H_tt^(i))⁻¹ (H_tβ^(i) e_a)
1632    ///     = φ² · L_i[:, j]ᵀ (H_tt^(i))⁻¹ L_i[:, j]
1633    ///     = φ² · (L_i[:, j] · Y_i[:, j]),          Y_i := (H_tt^(i))⁻¹ L_i.
1634    /// ```
1635    ///
1636    /// so the whole diagonal is accumulated in ONE support-sparse pass over the
1637    /// resident factors — no probe, no per-column solve, the staged `Y_i` reused
1638    /// from the matvec residency. The result is the SAME quotient the generic
1639    /// path computes (up to float reassociation of the row sum), so the PCG
1640    /// preconditioner is unchanged up to that f64 margin. Since the preconditioner
1641    /// only steers the iterate (which still terminates at the PCG tolerance), the
1642    /// criterion ranking is stable except for candidates within that margin,
1643    /// where the near-tie winner can flip — not an exact no-move guarantee (#1211).
1644    pub(crate) fn build_scalar_jacobi_resident(
1645        sys: &ArrowSchurSystem,
1646        ridge_beta: f64,
1647        resident: &SaeResidentReducedSchur,
1648    ) -> Result<Self, ArrowSchurError> {
1649        let k = sys.k;
1650        let p = resident.p;
1651        let n = resident.rows.len();
1652        // Seed with diag(H_ββ) + ridge — same penalty source the generic path
1653        // reads, so the only difference is how the point-elimination term is
1654        // gathered.
1655        let mut diag = Array1::<f64>::zeros(k);
1656        {
1657            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1658            sys.penalty_diagonal_add(diag_slice);
1659        }
1660        for a in 0..k {
1661            diag[a] += ridge_beta;
1662        }
1663        // Per-row point-elimination diagonal: for each active support entry
1664        // `(beta_base, φ)` and channel `j`, subtract `φ² · L_i[:, j]·Y_i[:, j]`
1665        // into `diag[beta_base + j]`. `L_i`/`Y_i` are row-major `di × p`, so the
1666        // `j`-th column dot is `Σ_r L_i[r·p + j]·Y_i[r·p + j]`.
1667        //
1668        // The accumulation is into a SHARED `diag` (rows scatter into overlapping
1669        // `beta_base + j` columns), so — like the generic `build_scalar_jacobi`
1670        // and the `schur_matvec` row loop (#1017) — parallelism uses worker-private
1671        // length-`K` partials folded back in chunk order: each chunk is a
1672        // contiguous ascending row range and rows within it stay ascending, so the
1673        // chunk-ordered fold reproduces the serial `row = 0..n` subtraction order
1674        // bit-for-bit run-to-run (the #1017 determinism gate). Run-to-run
1675        // bit-identity does not extend to bit-identity with the in-place serial
1676        // accumulation, so the preconditioner — and any criterion ranking it
1677        // steers — is stable only up to the chunk-reassociation margin; a near-tie
1678        // winner inside that margin can flip (#1211).
1679        // This build runs once per inexact-PCG solve = O(inner-Newton-iters)
1680        // per fit; at the SAE LLM shape (thousands of rows, wide border `k`) the
1681        // per-row support sweep is the build's whole cost and was on one core.
1682        // The per-channel column dot `col_dot[j] = Σ_r L_i[r·p+j]·Y_i[r·p+j]`
1683        // (the diagonal of `G_i = L_iᵀ(H_tt)⁻¹L_i`) depends ONLY on the row `i`,
1684        // not on the support entry `(beta_base, φ)`. The previous loop recomputed
1685        // it once per support entry — a row with `m` active atoms paid `m·p`
1686        // column dots over `di`. Hoist it: compute the `p` column dots once per
1687        // row into reusable `col_dot` scratch, then each support entry is a pure
1688        // scatter `diag[beta_base+j] -= φ²·col_dot[j]`. Bit-for-bit identical:
1689        // each `col_dot[j]` is the same `r`-ascending sum, and `φ²·col_dot[j]`
1690        // yields identical bits whether `col_dot[j]` was just computed or cached.
1691        let row_into = |row: usize, diag_part: &mut [f64], col_dot: &mut [f64]| {
1692            let rf = &resident.rows[row];
1693            let di = rf.di;
1694            if di == 0 {
1695                return;
1696            }
1697            let support = &resident.a_phi[row];
1698            if support.is_empty() {
1699                return;
1700            }
1701            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
1702            // the former per-row `rf.l` copy.
1703            let l_i = &resident.local_jac[row];
1704            for (j, slot) in col_dot.iter_mut().enumerate().take(p) {
1705                let mut acc = 0.0_f64;
1706                for r in 0..di {
1707                    let idx = r * p + j;
1708                    acc += l_i[idx] * rf.y[idx];
1709                }
1710                *slot = acc;
1711            }
1712            for &(beta_base, phi) in support {
1713                if phi == 0.0 {
1714                    continue;
1715                }
1716                let phi2 = phi * phi;
1717                for j in 0..p {
1718                    diag_part[beta_base + j] -= phi2 * col_dot[j];
1719                }
1720            }
1721        };
1722        let parallel =
1723            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1724        if parallel {
1725            use rayon::prelude::*;
1726            const CHUNK: usize = 64;
1727            let partials: Vec<Array1<f64>> = (0..n)
1728                .into_par_iter()
1729                .chunks(CHUNK)
1730                .map(|idxs| {
1731                    let mut diag_part = Array1::<f64>::zeros(k);
1732                    let mut col_dot = vec![0.0_f64; p];
1733                    let slice = diag_part
1734                        .as_slice_mut()
1735                        .expect("diag_part must be contiguous");
1736                    for i in idxs {
1737                        row_into(i, slice, &mut col_dot);
1738                    }
1739                    diag_part
1740                })
1741                .collect();
1742            // Deterministic ordered reduction: fold chunk partials left-to-right
1743            // (each partial already holds the per-row terms subtracted, so add
1744            // them into `diag` in chunk order to mirror the serial subtraction).
1745            for part in &partials {
1746                for a in 0..k {
1747                    diag[a] += part[a];
1748                }
1749            }
1750        } else {
1751            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1752            let mut col_dot = vec![0.0_f64; p];
1753            for row in 0..n {
1754                row_into(row, diag_slice, &mut col_dot);
1755            }
1756        }
1757        let mut blocks = Vec::with_capacity(k);
1758        for a in 0..k {
1759            let v = diag[a];
1760            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1761                return Err(ArrowSchurError::PcgFailed {
1762                    reason: format!(
1763                        "invalid SAE-resident Schur Jacobi diagonal at index {a}: {v}; \
1764                         operator regularization is required"
1765                    ),
1766                });
1767            }
1768            blocks.push(BlockFactor::Scalar {
1769                inv: Array1::from_elem(1, 1.0 / v),
1770                range: a..a + 1,
1771            });
1772        }
1773        Ok(Self { blocks })
1774    }
1775
1776    /// Build block-Jacobi from the pre-staged SAE residency factors `(L_i, Y_i)`.
1777    ///
1778    /// This is the block analogue of [`Self::build_scalar_jacobi_resident`].
1779    /// When SAE block offsets are small enough to select BetaBlockJacobi (for
1780    /// example per-atom decoder blocks with `basis_size·p <= 256`), the generic
1781    /// block builder materializes every row's dense `(d_i × K)` `H_tβ` by probing
1782    /// the matrix-free operator, then re-solves `(H_tt)⁻¹` for each block column.
1783    /// The resident factors already carry `G_i = L_iᵀ(H_tt)⁻¹L_i`, so each block
1784    /// is assembled by scattering only the active support pairs inside that block:
1785    ///
1786    /// ```text
1787    /// S_block -= Σ_i Σ_(s,t in block support) φ_s φ_t · G_i[channel_s, channel_t]
1788    /// ```
1789    ///
1790    /// It computes the same block-diagonal restriction as the generic path, but
1791    /// avoids the full-row `H_tβ` materialization and per-column triangular solves.
1792    pub(crate) fn build_block_jacobi_resident(
1793        sys: &ArrowSchurSystem,
1794        ridge_beta: f64,
1795        resident: &SaeResidentReducedSchur,
1796    ) -> Result<Self, ArrowSchurError> {
1797        let block_offsets = &sys.block_offsets;
1798        let p = resident.p;
1799        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
1800        for (block_idx, range) in block_offsets.iter().enumerate() {
1801            let b = range.end - range.start;
1802            let mut schur_block = Array2::<f64>::zeros((b, b));
1803            sys.penalty_block_add(
1804                BetaBlockId(block_idx),
1805                block_offsets.as_ref(),
1806                &mut schur_block,
1807            );
1808            for bi in 0..b {
1809                schur_block[[bi, bi]] += ridge_beta;
1810            }
1811            schur_blocks.push(schur_block);
1812        }
1813
1814        let row_into = |row: usize, blocks: &mut [Array2<f64>]| {
1815            let rf = &resident.rows[row];
1816            let di = rf.di;
1817            if di == 0 {
1818                return;
1819            }
1820            let support = &resident.a_phi[row];
1821            if support.is_empty() {
1822                return;
1823            }
1824            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
1825            // the former per-row `rf.l` copy.
1826            let l_i = &resident.local_jac[row];
1827            for (block_idx, range) in block_offsets.iter().enumerate() {
1828                let block = &mut blocks[block_idx];
1829                for &(base_left, phi_left) in support {
1830                    if phi_left == 0.0 {
1831                        continue;
1832                    }
1833                    let left_start = base_left.max(range.start);
1834                    let left_end = (base_left + p).min(range.end);
1835                    if left_start >= left_end {
1836                        continue;
1837                    }
1838                    for &(base_right, phi_right) in support {
1839                        if phi_right == 0.0 {
1840                            continue;
1841                        }
1842                        let right_start = base_right.max(range.start);
1843                        let right_end = (base_right + p).min(range.end);
1844                        if right_start >= right_end {
1845                            continue;
1846                        }
1847                        let phi = phi_left * phi_right;
1848                        for gi in left_start..left_end {
1849                            let li = gi - range.start;
1850                            let ch_i = gi - base_left;
1851                            for gj in right_start..right_end {
1852                                let lj = gj - range.start;
1853                                let ch_j = gj - base_right;
1854                                let mut gij = 0.0_f64;
1855                                for r in 0..di {
1856                                    gij += l_i[r * p + ch_i] * rf.y[r * p + ch_j];
1857                                }
1858                                block[[li, lj]] -= phi * gij;
1859                            }
1860                        }
1861                    }
1862                }
1863            }
1864        };
1865
1866        let n = resident.rows.len();
1867        let parallel =
1868            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1869        if parallel {
1870            use rayon::prelude::*;
1871            const CHUNK: usize = 64;
1872            let n_blocks = block_offsets.len();
1873            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
1874            let partials: Vec<Vec<Array2<f64>>> = (0..n)
1875                .into_par_iter()
1876                .chunks(CHUNK)
1877                .map(|idxs| {
1878                    let mut local: Vec<Array2<f64>> = block_dims
1879                        .iter()
1880                        .map(|&b| Array2::<f64>::zeros((b, b)))
1881                        .collect();
1882                    for i in idxs {
1883                        row_into(i, &mut local);
1884                    }
1885                    local
1886                })
1887                .collect();
1888            for local in &partials {
1889                for bidx in 0..n_blocks {
1890                    schur_blocks[bidx] += &local[bidx];
1891                }
1892            }
1893        } else {
1894            for row in 0..n {
1895                row_into(row, &mut schur_blocks);
1896            }
1897        }
1898
1899        let mut blocks = Vec::with_capacity(block_offsets.len());
1900        for (block_idx, range) in block_offsets.iter().enumerate() {
1901            let b = range.end - range.start;
1902            let schur_block = &schur_blocks[block_idx];
1903            let factor_opt = {
1904                use faer::Side;
1905                let view = FaerArrayView::new(schur_block);
1906                FaerLlt::new(view.as_ref(), Side::Lower).ok()
1907            };
1908            if let Some(llt) = factor_opt {
1909                blocks.push(BlockFactor::Chol {
1910                    factor: llt,
1911                    range: range.clone(),
1912                });
1913            } else {
1914                let mut inv = Array1::<f64>::zeros(b);
1915                for bi in 0..b {
1916                    let v = schur_block[[bi, bi]];
1917                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1918                        return Err(ArrowSchurError::PcgFailed {
1919                            reason: format!(
1920                                "SAE-resident block Jacobi scalar fallback: non-PD diagonal at \
1921                                 global index {}: {v}; regularization required",
1922                                range.start + bi
1923                            ),
1924                        });
1925                    }
1926                    inv[bi] = 1.0 / v;
1927                }
1928                blocks.push(BlockFactor::Scalar {
1929                    inv,
1930                    range: range.clone(),
1931                });
1932            }
1933        }
1934        Ok(Self { blocks })
1935    }
1936
1937    /// Build term-block Jacobi: one dense `b×b` Schur block per term in
1938    /// `sys.block_offsets`.
1939    pub(crate) fn build_block_jacobi<B: BatchedBlockSolver + Sync>(
1940        sys: &ArrowSchurSystem,
1941        htt_factors: &ArrowFactorSlab,
1942        ridge_beta: f64,
1943        backend: &B,
1944    ) -> Result<Self, ArrowSchurError> {
1945        let block_offsets = &sys.block_offsets;
1946
1947        // Initialise every b×b Schur sub-block from H_ββ + ridge·I via
1948        // penalty_block_add (#296): routes to penalty_op or falls back to
1949        // hbb / hbb_diag inline without Arc-clone per loop iteration. These are
1950        // the block-diagonal restrictions of the reduced Schur complement; the
1951        // per-row cross-block contributions are accumulated in the row sweep
1952        // below.
1953        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
1954        for (block_idx, range) in block_offsets.iter().enumerate() {
1955            let b = range.end - range.start;
1956            let mut schur_block = Array2::<f64>::zeros((b, b));
1957            sys.penalty_block_add(
1958                BetaBlockId(block_idx),
1959                block_offsets.as_ref(),
1960                &mut schur_block,
1961            );
1962            for bi in 0..b {
1963                schur_block[[bi, bi]] += ridge_beta;
1964            }
1965            schur_blocks.push(schur_block);
1966        }
1967
1968        // Subtract Schur contributions:
1969        // S_kk -= H_βt_k^(i) (H_tt^(i))^{-1} H_tβ_k^(i)
1970        //
1971        // Materialize each row's (d_i × K) cross-block ONCE and scatter its
1972        // contribution into every block-diagonal sub-block — mirroring the
1973        // row-outer structure of `build_dense_schur_direct`. The previous
1974        // block-outer form re-materialized every row for each β-block
1975        // (O(n_blocks · n · K) probes); for the matrix-free softmax cross-block
1976        // each materialize is itself O(K²), so that nesting made the
1977        // preconditioner build quadratically more expensive than the direct
1978        // dense Schur it preconditions. sys_htbeta_materialize_row handles the
1979        // Kronecker / htbeta_matvec path transparently.
1980        // Per-row body: materialize the row's `(d_i × K)` cross-block once and
1981        // subtract its `H_βt_k^(i)(H_tt^(i))⁻¹H_tβ_k^(i)` contribution into EACH
1982        // block-diagonal sub-block. Writes INTO a caller-provided `blocks`
1983        // accumulator (`-=`) so a rayon worker can subtract a chunk's rows into
1984        // a worker-private zero-seeded `Vec<Array2>` and the caller folds the
1985        // chunk partials back in chunk order — bit-identical run-to-run
1986        // regardless of thread scheduling (the #1017 verification gate). This
1987        // is deterministic and within the chunk-reassociation margin of serial,
1988        // so the preconditioner, hence the criterion ranking, is stable except
1989        // for near-tie candidates inside that f64 margin — not an exact no-move
1990        // guarantee (#1211).
1991        let row_into = |i: usize,
1992                        row: &ArrowRowBlock,
1993                        blocks: &mut [Array2<f64>]|
1994         -> Result<(), ArrowSchurError> {
1995            let di = sys.row_dims[i];
1996            let htbeta_full = sys_htbeta_materialize_row(sys, i, row)?;
1997            for (block_idx, range) in block_offsets.iter().enumerate() {
1998                let b = range.end - range.start;
1999                let mut solved_cols = Array2::<f64>::zeros((di, b));
2000                for bj in 0..b {
2001                    let gj = range.start + bj;
2002                    let rhs = htbeta_full.column(gj).to_owned();
2003                    let solved = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
2004                    for c in 0..di {
2005                        solved_cols[[c, bj]] = solved[c];
2006                    }
2007                }
2008                let schur_block = &mut blocks[block_idx];
2009                for bi in 0..b {
2010                    let gi = range.start + bi;
2011                    for bj in 0..b {
2012                        let mut acc = 0.0;
2013                        for c in 0..di {
2014                            acc += htbeta_full[[c, gi]] * solved_cols[[c, bj]];
2015                        }
2016                        schur_block[[bi, bj]] -= acc;
2017                    }
2018                }
2019            }
2020            Ok(())
2021        };
2022        // Each row materializes an `O(K²)` cross-block (Kronecker) plus `Σ_k b_k`
2023        // triangular solves — the preconditioner build's whole per-row cost at
2024        // the SAE LLM shape (#1017), and the rows are independent. Fan over fixed
2025        // row chunks above the threshold, staying serial for the handful-of-rows
2026        // non-SAE callers and inside a rayon worker (topology-race nesting guard)
2027        // — the same gate `schur_matvec` uses.
2028        let n = sys.rows.len();
2029        let parallel =
2030            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
2031        if parallel {
2032            use rayon::prelude::*;
2033            const CHUNK: usize = 64;
2034            let n_blocks = block_offsets.len();
2035            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
2036            let partials: Vec<Vec<Array2<f64>>> = (0..n)
2037                .into_par_iter()
2038                .chunks(CHUNK)
2039                .map(|idxs| {
2040                    let mut local: Vec<Array2<f64>> = block_dims
2041                        .iter()
2042                        .map(|&b| Array2::<f64>::zeros((b, b)))
2043                        .collect();
2044                    for i in idxs {
2045                        row_into(i, &sys.rows[i], &mut local)?;
2046                    }
2047                    Ok::<_, ArrowSchurError>(local)
2048                })
2049                .collect::<Result<Vec<_>, _>>()?;
2050            // Deterministic ordered reduction: fold chunk partials left-to-right.
2051            for local in &partials {
2052                for bidx in 0..n_blocks {
2053                    schur_blocks[bidx] += &local[bidx];
2054                }
2055            }
2056        } else {
2057            for (i, row) in sys.rows.iter().enumerate() {
2058                row_into(i, row, &mut schur_blocks)?;
2059            }
2060        }
2061
2062        // Factor each accumulated block: LLT, with scalar-diagonal fallback for
2063        // a block that comes out non-PD at this ridge.
2064        let mut blocks = Vec::with_capacity(block_offsets.len());
2065        for (block_idx, range) in block_offsets.iter().enumerate() {
2066            let b = range.end - range.start;
2067            let schur_block = &schur_blocks[block_idx];
2068            let factor_opt = {
2069                use faer::Side;
2070                let view = FaerArrayView::new(schur_block);
2071                FaerLlt::new(view.as_ref(), Side::Lower).ok()
2072            };
2073            if let Some(llt) = factor_opt {
2074                blocks.push(BlockFactor::Chol {
2075                    factor: llt,
2076                    range: range.clone(),
2077                });
2078            } else {
2079                // Non-PD block: fall back to scalar diagonal for this block.
2080                let mut inv = Array1::<f64>::zeros(b);
2081                for bi in 0..b {
2082                    let v = schur_block[[bi, bi]];
2083                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
2084                        return Err(ArrowSchurError::PcgFailed {
2085                            reason: format!(
2086                                "block Jacobi scalar fallback: non-PD diagonal at \
2087                                 global index {}: {v}; regularization required",
2088                                range.start + bi
2089                            ),
2090                        });
2091                    }
2092                    inv[bi] = 1.0 / v;
2093                }
2094                blocks.push(BlockFactor::Scalar {
2095                    inv,
2096                    range: range.clone(),
2097                });
2098            }
2099        }
2100        Ok(Self { blocks })
2101    }
2102
2103    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2104        let mut out = Array1::<f64>::zeros(r.len());
2105        for block in &self.blocks {
2106            match block {
2107                BlockFactor::Scalar { inv, range } => {
2108                    for (local, gi) in range.clone().enumerate() {
2109                        out[gi] = inv[local] * r[gi];
2110                    }
2111                }
2112                BlockFactor::Chol { factor, range } => {
2113                    let b = range.end - range.start;
2114                    let mut rhs = Array1::<f64>::zeros(b);
2115                    for (local, gi) in range.clone().enumerate() {
2116                        rhs[local] = r[gi];
2117                    }
2118                    use faer::linalg::solvers::Solve;
2119                    let stride = rhs.strides()[0];
2120                    let len = rhs.len();
2121                    // SAFETY: rhs is a uniquely-borrowed contiguous Array1
2122                    // with positive stride (standard layout).
2123                    let rhs_mat =
2124                        unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2125                    let solved = factor.solve(rhs_mat);
2126                    for (local, gi) in range.clone().enumerate() {
2127                        out[gi] = solved[(local, 0)];
2128                    }
2129                }
2130            }
2131        }
2132        out
2133    }
2134}
2135
2136// ---------------------------------------------------------------------------
2137// Preconditioner ladder: SchurPreconditionerKind, ClusterJacobi,
2138// AdditiveSchwarz  (issue #299)
2139// ---------------------------------------------------------------------------
2140
/// Which Schur preconditioner to use in the inexact-PCG path.
///
/// The variants form a ladder ordered by cost / effectiveness; see each
/// variant for what it keeps of the reduced Schur complement `S`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurPreconditionerKind {
    /// Scalar Jacobi (pre-#283 behaviour).
    Diagonal,
    /// Block-Jacobi with one block per `block_offsets` term (#287).
    BetaBlockJacobi,
    /// One dense block per connected component of the beta-coupling graph.
    ClusterJacobi,
    /// Each component expanded by `overlap` graph hops; overlapping columns
    /// are averaged by partition-of-unity weights, with a full dense
    /// local-inverse apply per subdomain.
    AdditiveSchwarz { overlap: usize },
    /// The cheap Schwarz variant (#299): same overlapping decomposition as
    /// `AdditiveSchwarz`, but each subdomain contributes only the diagonal of
    /// its local inverse `(A_k⁻¹)_ii`, assembled additively with
    /// partition-of-unity weights into a single `O(K)`-apply diagonal.
    DiagAssembledSchwarz { overlap: usize },
    /// Level-0 incomplete Cholesky (#299). Within each connected component of
    /// the β-coupling graph the dense reduced-Schur block `S[C,C]` is assembled
    /// once, its structural-nonzero pattern is taken as the level-0 fill
    /// pattern, and a no-fill incomplete Cholesky `S ≈ L̃ L̃ᵀ` is formed keeping
    /// ONLY that pattern (Saad, *Iterative Methods*, IC(0)). Apply is a sparse
    /// triangular forward/back solve over `nnz(S[C,C])`, so for a large
    /// component with internal sparsity it is far cheaper to build and apply
    /// than `ClusterJacobi`'s full dense Cholesky (which fills the whole `b×b`
    /// factor) while retaining the inter-block coupling that `ClusterJacobi`
    /// keeps but the diagonal/Schwarz tiers discard. A non-PD incomplete pivot
    /// degrades that component to the scalar reciprocal diagonal.
    BlockIncompleteCholesky,
}
2174
/// Escalate beyond BetaBlockJacobi only when K exceeds this value and PCG
/// exhausted `max_iterations`.
pub(crate) const PRECOND_ESCALATE_K_THRESHOLD: usize = 100;

/// #1026 matrix-free Schur curvature-floor (the unbounded-PCG analogue of the
/// dense `spectral_pd_floored_schur`). On `pᵀSp ≤ 0` in the unbounded SAE inner
/// PCG, the operator ridge is lifted by the minimal amount that restores
/// positive curvature along the offending direction, plus this fractional
/// margin (so the next CG iterate sits strictly inside the positive cone, not on
/// the `0` knife-edge).
pub(crate) const SCHUR_CURVATURE_FLOOR_MARGIN: f64 = 1.0e-2;
/// Lower bound on the curvature-floor ridge bump, relative to the rhs scale, so
/// a `pᵀSp` that rounds to exactly `0` still gets a strictly positive bump.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_FLOOR: f64 = 1.0e-12;
/// Ceiling on the accumulated curvature-floor ridge, relative to the rhs scale.
/// Beyond this the operator is treated as un-conditionable by a minimal floor
/// and the recoverable failure is handed to the outer LM loop (which re-forms
/// the whole system at a heavier ridge). Generous so that a large collapsed
/// over-subtraction `(H_tβ)²/H_tt` is still reachable.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_CEILING: f64 = 1.0e12;
/// Multiplicative growth for the DIAGONAL-refusal ridge escalation (no
/// `(curvature, ‖p‖²)` deficit is available there), matching the per-row
/// `factor_one_row_result` `RIDGE_GROWTH_FACTOR`.
pub(crate) const SCHUR_CURVATURE_FLOOR_DIAG_GROWTH: f64 = 10.0;
/// Max curvature-floor ridge-lift attempts before deferring to the outer LM
/// loop. The diagonal-refusal path grows ×10 per attempt, so this bounds the
/// reachable ridge at `rhs_scale · 10^(attempts)` — ample for any realistic
/// over-subtraction while still bounded.
///
/// NOTE(review): 24 equals `log10(REL_CEILING / REL_FLOOR)`, i.e. the number of
/// ×10 steps from the relative floor (1e-12) to the relative ceiling (1e12).
/// Presumably chosen so the attempt cap and the ceiling coincide on the
/// diagonal-refusal path — confirm against the escalation loop before changing
/// any of the three constants independently.
pub(crate) const SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS: usize = 24;
2204
/// Cholesky or scalar factor for one cluster of the beta-coefficient graph.
///
/// Both variants carry `cols`, the global column indices the cluster covers.
#[derive(Clone)]
pub(crate) enum ClusterFactor {
    /// Dense factorization of the cluster's reduced-Schur block.
    Chol {
        /// Global column indices of the cluster.
        cols: Vec<usize>,
        /// LLT factor of the assembled (and symmetrized) `S[cols, cols]`.
        factor: FaerLlt<f64>,
    },
    /// Scalar fallback, used when the cluster is larger than
    /// `CLUSTER_JACOBI_MAX_CLUSTER` or its dense block fails LLT.
    Scalar {
        /// Global column indices of the cluster.
        cols: Vec<usize>,
        /// Per-column scale as returned by `build_schur_scalar_inv` for `cols`.
        inv: Vec<f64>,
    },
}
2217
2218impl std::fmt::Debug for ClusterFactor {
2219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2220        match self {
2221            ClusterFactor::Chol { cols, .. } => {
2222                write!(f, "ClusterFactor::Chol {{ cols.len: {} }}", cols.len())
2223            }
2224            ClusterFactor::Scalar { cols, inv } => write!(
2225                f,
2226                "ClusterFactor::Scalar {{ cols.len: {}, inv.len: {} }}",
2227                cols.len(),
2228                inv.len()
2229            ),
2230        }
2231    }
2232}
2233
/// Maximum columns per cluster before scalar fallback.
///
/// A column group with more than this many columns is not assembled densely:
/// the cluster / Schwarz builders call `build_schur_scalar_inv` for it instead.
/// At the limit a dense `f64` block is `512² · 8 B = 2 MiB`.
pub(crate) const CLUSTER_JACOBI_MAX_CLUSTER: usize = 512;

/// Maximum columns in a single connected component for which the IC(0)
/// preconditioner assembles the dense `S[C,C]` to derive its sparsity pattern.
/// IC(0) is cheap to APPLY at any size, but the pattern is read from the dense
/// assembly, which is `O(b²)` memory (`4096² · 8 B = 128 MiB` at the limit);
/// beyond this the component falls back to the scalar reciprocal diagonal (the
/// same ceiling concern as `CLUSTER_JACOBI_MAX_CLUSTER`, lifted because the
/// IC(0) FACTOR is sparse).
pub(crate) const IC0_MAX_COMPONENT: usize = 4096;

/// Relative threshold below which an assembled `S[i,j]` is treated as a
/// structural zero when deriving the IC(0) level-0 pattern. Scaled by
/// `sqrt(|S_ii|·|S_jj|)` so it is invariant to column scaling; this prunes
/// entries that are pure FMA round-off (a genuinely decoupled `(i,j)` pair
/// assembles to ~0) so they do not enter the kept fill pattern.
pub(crate) const IC0_PATTERN_REL_DROP: f64 = 1.0e-13;
2251
2252/// Assemble the dense `b×b` reduced-Schur block for the column set `cols`:
2253/// `S[cols, cols] = H_ββ[cols, cols] + ridge·I − Σ_i H_tβ[cols]ᵀ (H_tt^i)⁻¹ H_tβ[cols]`.
2254///
2255/// Shared by `ClusterJacobiPreconditioner::build_from_column_groups` (which
2256/// Cholesky-factors the returned block) and `DiagAssembledSchwarzPreconditioner`
2257/// (which inverts each subdomain block and keeps only its diagonal). The result
2258/// is the LOWER triangle filled by the row reduction; callers that need the full
2259/// symmetric block must `symmetrize_upper_from_lower`.
2260///
2261/// The per-row Schur contribution is fanned over fixed 64-row chunks above
2262/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` and folded left-to-right so the assembly is
2263/// bit-identical to the serial path (and run-to-run deterministic), exactly as
2264/// in `build_block_jacobi` (#1017).
2265pub(crate) fn assemble_local_schur_block<B: BatchedBlockSolver + Sync>(
2266    sys: &ArrowSchurSystem,
2267    htt_factors: &ArrowFactorSlab,
2268    ridge_beta: f64,
2269    backend: &B,
2270    cols: &[usize],
2271) -> Array2<f64> {
2272    let d = sys.d;
2273    let b = cols.len();
2274    let mut s_block = Array2::<f64>::zeros((b, b));
2275    // Initialise from H_ββ via penalty_subblock_add (#296): routes through
2276    // penalty_op or falls back to hbb / hbb_diag inline.
2277    sys.penalty_subblock_add(cols, &mut s_block);
2278    for bi in 0..b {
2279        s_block[[bi, bi]] += ridge_beta;
2280    }
2281    let cluster_row_into = |row_idx: usize, row: &ArrowRowBlock, acc: &mut Array2<f64>| {
2282        let mut col_vec = Array1::<f64>::zeros(d);
2283        let mut solved_cols = Array2::<f64>::zeros((d, b));
2284        for bj in 0..b {
2285            let gj = cols[bj];
2286            for c in 0..d {
2287                col_vec[c] = row.htbeta[[c, gj]];
2288            }
2289            let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
2290            for c in 0..d {
2291                solved_cols[[c, bj]] = solved[c];
2292            }
2293        }
2294        for bi in 0..b {
2295            let gi = cols[bi];
2296            for bj in 0..b {
2297                let mut dot = 0.0;
2298                for c in 0..d {
2299                    dot += row.htbeta[[c, gi]] * solved_cols[[c, bj]];
2300                }
2301                acc[[bi, bj]] -= dot;
2302            }
2303        }
2304    };
2305    let n = sys.rows.len();
2306    let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
2307    if parallel {
2308        use rayon::prelude::*;
2309        const CHUNK: usize = 64;
2310        let partials: Vec<Array2<f64>> = (0..n)
2311            .into_par_iter()
2312            .chunks(CHUNK)
2313            .map(|idxs| {
2314                let mut local = Array2::<f64>::zeros((b, b));
2315                for i in idxs {
2316                    cluster_row_into(i, &sys.rows[i], &mut local);
2317                }
2318                local
2319            })
2320            .collect();
2321        for local in &partials {
2322            s_block += local;
2323        }
2324    } else {
2325        for (row_idx, row) in sys.rows.iter().enumerate() {
2326            cluster_row_into(row_idx, row, &mut s_block);
2327        }
2328    }
2329    s_block
2330}
2331
/// Dense Schur block per connected component of the beta-coupling graph.
///
/// Nodes = beta blocks (`block_offsets`); edges = rows where two blocks
/// co-occur with nonzero `H_t_beta` entries. One Cholesky factor per
/// connected component; applied as a triangular solve.
///
/// A component with more than `CLUSTER_JACOBI_MAX_CLUSTER` columns, or whose
/// dense block fails LLT, is stored as a `ClusterFactor::Scalar` instead.
#[derive(Debug, Clone)]
pub struct ClusterJacobiPreconditioner {
    /// One factor per non-empty column group, in the order the groups were
    /// supplied to `build_from_column_groups`.
    pub(crate) clusters: Vec<ClusterFactor>,
}
2341
2342impl ClusterJacobiPreconditioner {
2343    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2344        sys: &ArrowSchurSystem,
2345        htt_factors: &ArrowFactorSlab,
2346        ridge_beta: f64,
2347        backend: &B,
2348    ) -> Result<Self, ArrowSchurError> {
2349        if sys.block_offsets.is_empty() {
2350            let cols: Vec<usize> = (0..sys.k).collect();
2351            return Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &[cols]);
2352        }
2353        let graph = BetaCouplingGraph::build(
2354            &sys.block_offsets,
2355            &sys.rows
2356                .iter()
2357                .map(|r| r.htbeta.clone())
2358                .collect::<Vec<_>>(),
2359        );
2360        let col_groups: Vec<Vec<usize>> = graph
2361            .component_partition()
2362            .iter()
2363            .map(|comp_blocks| {
2364                let mut cols: Vec<usize> = comp_blocks
2365                    .iter()
2366                    .flat_map(|&b| sys.block_offsets[b].clone())
2367                    .collect();
2368                cols.sort_unstable();
2369                cols
2370            })
2371            .collect();
2372        Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2373    }
2374
2375    pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2376        sys: &ArrowSchurSystem,
2377        htt_factors: &ArrowFactorSlab,
2378        ridge_beta: f64,
2379        backend: &B,
2380        col_groups: &[Vec<usize>],
2381    ) -> Result<Self, ArrowSchurError> {
2382        let mut clusters = Vec::with_capacity(col_groups.len());
2383        for cols in col_groups {
2384            let b = cols.len();
2385            if b == 0 {
2386                continue;
2387            }
2388            if b > CLUSTER_JACOBI_MAX_CLUSTER {
2389                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2390                clusters.push(ClusterFactor::Scalar {
2391                    cols: cols.clone(),
2392                    inv,
2393                });
2394                continue;
2395            }
2396            let mut s_block =
2397                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2398            symmetrize_upper_from_lower(&mut s_block);
2399            let factor_opt = {
2400                use faer::Side;
2401                let view = FaerArrayView::new(&s_block);
2402                FaerLlt::new(view.as_ref(), Side::Lower).ok()
2403            };
2404            if let Some(llt) = factor_opt {
2405                clusters.push(ClusterFactor::Chol {
2406                    cols: cols.clone(),
2407                    factor: llt,
2408                });
2409            } else {
2410                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2411                clusters.push(ClusterFactor::Scalar {
2412                    cols: cols.clone(),
2413                    inv,
2414                });
2415            }
2416        }
2417        Ok(Self { clusters })
2418    }
2419
2420    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2421        let mut out = Array1::<f64>::zeros(r.len());
2422        for cluster in &self.clusters {
2423            apply_cluster(cluster, r, &mut out, &ClusterApplyMode::Overwrite);
2424        }
2425        out
2426    }
2427}
2428
/// Additive Schwarz: base components expanded by `overlap` graph-hops;
/// overlapping columns averaged by partition-of-unity weights.
#[derive(Debug, Clone)]
pub struct AdditiveSchwarzPreconditioner {
    /// One factor per (expanded) subdomain, built by
    /// `ClusterJacobiPreconditioner::build_from_column_groups`.
    pub(crate) clusters: Vec<ClusterFactor>,
    /// Per-column weight of length `sys.k`: `1 / c` for a column covered by
    /// `c` subdomains, `1.0` for a column covered by none.
    pub(crate) weights: Vec<f64>,
}
2436
2437impl AdditiveSchwarzPreconditioner {
2438    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2439        sys: &ArrowSchurSystem,
2440        htt_factors: &ArrowFactorSlab,
2441        ridge_beta: f64,
2442        backend: &B,
2443        overlap: usize,
2444    ) -> Result<Self, ArrowSchurError> {
2445        if sys.block_offsets.is_empty() {
2446            let cols: Vec<usize> = (0..sys.k).collect();
2447            let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2448                sys,
2449                htt_factors,
2450                ridge_beta,
2451                backend,
2452                &[cols],
2453            )?;
2454            return Ok(Self {
2455                clusters: inner.clusters,
2456                weights: vec![1.0f64; sys.k],
2457            });
2458        }
2459        let graph = BetaCouplingGraph::build(
2460            &sys.block_offsets,
2461            &sys.rows
2462                .iter()
2463                .map(|r| r.htbeta.clone())
2464                .collect::<Vec<_>>(),
2465        );
2466        let col_groups: Vec<Vec<usize>> = graph
2467            .component_partition()
2468            .iter()
2469            .map(|seed| {
2470                let mut current = seed.clone();
2471                for _ in 0..overlap {
2472                    current = graph.expand_one_hop(&current);
2473                }
2474                let mut cols: Vec<usize> = current
2475                    .iter()
2476                    .flat_map(|&b| sys.block_offsets[b].clone())
2477                    .collect();
2478                cols.sort_unstable();
2479                cols.dedup();
2480                cols
2481            })
2482            .collect();
2483        let mut counts = vec![0u32; sys.k];
2484        for cols in &col_groups {
2485            for &gi in cols {
2486                counts[gi] += 1;
2487            }
2488        }
2489        let weights: Vec<f64> = counts
2490            .iter()
2491            .map(|&c| if c == 0 { 1.0 } else { 1.0 / c as f64 })
2492            .collect();
2493        let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2494            sys,
2495            htt_factors,
2496            ridge_beta,
2497            backend,
2498            &col_groups,
2499        )?;
2500        Ok(Self {
2501            clusters: inner.clusters,
2502            weights,
2503        })
2504    }
2505
2506    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2507        let mut out = Array1::<f64>::zeros(r.len());
2508        for cluster in &self.clusters {
2509            apply_cluster(
2510                cluster,
2511                r,
2512                &mut out,
2513                &ClusterApplyMode::Accumulate {
2514                    weights: &self.weights,
2515                },
2516            );
2517        }
2518        out
2519    }
2520}
2521
/// Diagonal-assembled additive Schwarz (#299).
///
/// The cheap Schwarz variant the domain-decomposition literature recommends as
/// the default for sparse-coupling β-graphs: instead of storing and applying a
/// dense Cholesky factor per overlapping subdomain (as
/// [`AdditiveSchwarzPreconditioner`] does), it inverts each overlapping
/// subdomain Schur block ONCE at build time and keeps only the **diagonal of the
/// local inverse** `(A_k⁻¹)_ii`. Those per-subdomain diagonal contributions are
/// then assembled additively across overlapping subdomains with partition-of-
/// unity weights into a single global diagonal `m`, applied as `out[i] = m[i]·r[i]`.
///
/// This is strictly richer than scalar Jacobi (`1/S_ii`): the local inverse
/// diagonal `(A_k⁻¹)_ii` folds in the off-diagonal coupling WITHIN the subdomain.
/// For an SPD block `(A_k⁻¹)_ii = 1 / (A_ii − a_iᵀ A_{−i}⁻¹ a_i) ≥ 1 / A_ii`, so a
/// strongly-coupled column gets a LARGER effective scale than its bare
/// reciprocal diagonal would give (the reciprocal of its curvature after
/// eliminating the rest of the subdomain, rather than of its raw diagonal) —
/// while the apply stays `O(K)` (one multiply per column), unlike the
/// `O(Σ b_k²)` triangular solves of dense Schwarz. For `overlap = 0` and one
/// column per subdomain it reduces exactly to scalar Jacobi.
#[derive(Debug, Clone)]
pub struct DiagAssembledSchwarzPreconditioner {
    /// Global per-column multiplier `m[i]`; `out[i] = m[i] · r[i]`.
    pub(crate) inv_diag: Vec<f64>,
}
2545
2546impl DiagAssembledSchwarzPreconditioner {
2547    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2548        sys: &ArrowSchurSystem,
2549        htt_factors: &ArrowFactorSlab,
2550        ridge_beta: f64,
2551        backend: &B,
2552        overlap: usize,
2553    ) -> Result<Self, ArrowSchurError> {
2554        // Build the overlapping subdomain column groups exactly like
2555        // AdditiveSchwarz (component partition + `overlap` graph-hop expansion),
2556        // so the two Schwarz variants decompose the β space identically and
2557        // differ only in how each subdomain's local inverse is applied.
2558        let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2559            vec![(0..sys.k).collect()]
2560        } else {
2561            let graph = BetaCouplingGraph::build(
2562                &sys.block_offsets,
2563                &sys.rows
2564                    .iter()
2565                    .map(|r| r.htbeta.clone())
2566                    .collect::<Vec<_>>(),
2567            );
2568            graph
2569                .component_partition()
2570                .iter()
2571                .map(|seed| {
2572                    let mut current = seed.clone();
2573                    for _ in 0..overlap {
2574                        current = graph.expand_one_hop(&current);
2575                    }
2576                    let mut cols: Vec<usize> = current
2577                        .iter()
2578                        .flat_map(|&b| sys.block_offsets[b].clone())
2579                        .collect();
2580                    cols.sort_unstable();
2581                    cols.dedup();
2582                    cols
2583                })
2584                .collect()
2585        };
2586        Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2587    }
2588
2589    pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2590        sys: &ArrowSchurSystem,
2591        htt_factors: &ArrowFactorSlab,
2592        ridge_beta: f64,
2593        backend: &B,
2594        col_groups: &[Vec<usize>],
2595    ) -> Result<Self, ArrowSchurError> {
2596        // Partition-of-unity weights: a column shared by `c` subdomains gets each
2597        // of its `c` diagonal contributions scaled by `1/c`, so the assembled
2598        // diagonal is a convex combination (and reduces to a single contribution
2599        // for non-overlapping columns).
2600        let mut counts = vec![0u32; sys.k];
2601        for cols in col_groups {
2602            for &gi in cols {
2603                counts[gi] += 1;
2604            }
2605        }
2606        let mut accum = vec![0.0f64; sys.k];
2607        for cols in col_groups {
2608            let b = cols.len();
2609            if b == 0 {
2610                continue;
2611            }
2612            // For large subdomains, the dense inverse is too costly; fall back to
2613            // the global scalar Schur diagonal inverse `1/S_ii` for those columns
2614            // (the diag-assembled variant then coincides with scalar Jacobi over
2615            // that subdomain, which is exactly the intended cheap degradation).
2616            if b > CLUSTER_JACOBI_MAX_CLUSTER {
2617                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2618                for (local, &gi) in cols.iter().enumerate() {
2619                    let w = if counts[gi] == 0 {
2620                        1.0
2621                    } else {
2622                        1.0 / counts[gi] as f64
2623                    };
2624                    accum[gi] += w * inv[local];
2625                }
2626                continue;
2627            }
2628            let mut s_block =
2629                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2630            symmetrize_upper_from_lower(&mut s_block);
2631            // Diagonal of the local inverse `(A_k⁻¹)_ii`, obtained by solving
2632            // `A_k X = I` through the same faer Cholesky used elsewhere; on a
2633            // non-PD local block, degrade to the scalar reciprocal diagonal.
2634            let local_inv_diag = match local_inverse_diagonal(&s_block) {
2635                Some(diag) => diag,
2636                None => {
2637                    let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2638                    inv
2639                }
2640            };
2641            for (local, &gi) in cols.iter().enumerate() {
2642                let w = if counts[gi] == 0 {
2643                    1.0
2644                } else {
2645                    1.0 / counts[gi] as f64
2646                };
2647                accum[gi] += w * local_inv_diag[local];
2648            }
2649        }
2650        // A column never covered by any subdomain (only possible for `k` columns
2651        // with no block_offsets coverage) keeps a neutral unit scale.
2652        for (gi, &c) in counts.iter().enumerate() {
2653            if c == 0 {
2654                accum[gi] = 1.0;
2655            }
2656        }
2657        for (gi, m) in accum.iter().enumerate() {
2658            if !m.is_finite() || *m <= 0.0 {
2659                return Err(ArrowSchurError::PcgFailed {
2660                    reason: format!(
2661                        "diag-assembled Schwarz: non-positive assembled diagonal at index {gi}: {m}"
2662                    ),
2663                });
2664            }
2665        }
2666        Ok(Self { inv_diag: accum })
2667    }
2668
2669    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2670        let mut out = Array1::<f64>::zeros(r.len());
2671        for (gi, &m) in self.inv_diag.iter().enumerate() {
2672            out[gi] = m * r[gi];
2673        }
2674        out
2675    }
2676}
2677
2678/// Diagonal of `A⁻¹` for a small dense SPD block `A`, via the same faer
2679/// Cholesky used by the cluster/Schwarz factors. Returns `None` if `A` is not
2680/// positive-definite (caller degrades to the scalar reciprocal diagonal).
2681pub(crate) fn local_inverse_diagonal(a: &Array2<f64>) -> Option<Vec<f64>> {
2682    let b = a.nrows();
2683    let llt = {
2684        use faer::Side;
2685        let view = FaerArrayView::new(a);
2686        FaerLlt::new(view.as_ref(), Side::Lower).ok()?
2687    };
2688    use faer::linalg::solvers::Solve;
2689    let mut diag = Vec::with_capacity(b);
2690    for col in 0..b {
2691        // Solve `A x = e_col`; the `col`-th entry of `x` is `(A⁻¹)_{col,col}`.
2692        let mut rhs = Array1::<f64>::zeros(b);
2693        rhs[col] = 1.0;
2694        let stride = rhs.strides()[0];
2695        let len = rhs.len();
2696        // SAFETY: `rhs` is a uniquely-borrowed contiguous `Array1<f64>` of `len`
2697        // elements with positive row stride; a single column never dereferences
2698        // the column stride, so `0` is sound.
2699        let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2700        let solved = llt.solve(rhs_mat);
2701        diag.push(solved[(col, 0)]);
2702    }
2703    Some(diag)
2704}
2705
/// How a cluster factor's contribution is written into the output vector.
///
/// The mode is a cheap tag (at most one borrowed slice), so it derives
/// `Debug`, `Clone` and `Copy` like the other small value types in this
/// module.
#[derive(Debug, Clone, Copy)]
pub(crate) enum ClusterApplyMode<'w> {
    /// Assign `out[gi] = value`. Intended for non-overlapping clusters, where
    /// each global column is touched by exactly one cluster.
    Overwrite,
    /// Add the partition-of-unity weighted contribution
    /// `out[gi] += weights[gi] * value`. Intended for overlapping Schwarz
    /// clusters, where a column may belong to several clusters.
    Accumulate {
        /// Per-column weights, indexed by the GLOBAL column index `gi`.
        weights: &'w [f64],
    },
}
2716
2717impl ClusterApplyMode<'_> {
2718    #[inline]
2719    pub(crate) fn write(&self, out: &mut Array1<f64>, gi: usize, value: f64) {
2720        match self {
2721            ClusterApplyMode::Overwrite => out[gi] = value,
2722            ClusterApplyMode::Accumulate { weights } => out[gi] += weights[gi] * value,
2723        }
2724    }
2725}
2726
2727/// Apply a single cluster factor to the residual `r`, writing into `out`
2728/// according to `mode` (overwrite for non-overlapping clusters, weighted
2729/// accumulate for overlapping Schwarz clusters).
2730pub(crate) fn apply_cluster(
2731    cluster: &ClusterFactor,
2732    r: &Array1<f64>,
2733    out: &mut Array1<f64>,
2734    mode: &ClusterApplyMode<'_>,
2735) {
2736    match cluster {
2737        ClusterFactor::Scalar { cols, inv } => {
2738            for (local, &gi) in cols.iter().enumerate() {
2739                mode.write(out, gi, inv[local] * r[gi]);
2740            }
2741        }
2742        ClusterFactor::Chol { cols, factor } => {
2743            let b = cols.len();
2744            let mut rhs = Array1::<f64>::zeros(b);
2745            for (local, &gi) in cols.iter().enumerate() {
2746                rhs[local] = r[gi];
2747            }
2748            use faer::linalg::solvers::Solve;
2749            let stride = rhs.strides()[0];
2750            let len = rhs.len();
2751            // SAFETY: rhs is uniquely-borrowed contiguous Array1 with positive stride.
2752            let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2753            let solved = factor.solve(rhs_mat);
2754            for (local, &gi) in cols.iter().enumerate() {
2755                mode.write(out, gi, solved[(local, 0)]);
2756            }
2757        }
2758    }
2759}
2760
/// Factor for a single connected component of the block IC(0) preconditioner.
///
/// Two representations exist: a sparse incomplete-Cholesky factor, and a
/// scalar reciprocal diagonal used when the component is oversized or its
/// incomplete factorization hits a non-PD pivot (same meaning as
/// [`ClusterFactor::Scalar`]).
#[derive(Clone)]
pub(crate) enum Ic0Factor {
    /// Sparse lower-triangular `L̃` in column-compressed form over the
    /// component's LOCAL indices.
    IncompleteChol {
        /// Local index -> global β column.
        cols: Vec<usize>,
        /// Column pointers: entries of column `j` live at
        /// `col_ptr[j]..col_ptr[j + 1]` in `row_idx` / `val`.
        col_ptr: Vec<usize>,
        /// Local row index of each stored entry (rows `>= j`, diagonal first).
        row_idx: Vec<usize>,
        /// Stored values of `L̃`, parallel to `row_idx`.
        val: Vec<f64>,
    },
    /// Scalar degradation: one reciprocal per column.
    Scalar {
        /// Global β columns of the component.
        cols: Vec<usize>,
        /// Reciprocal diagonal entries, parallel to `cols`.
        inv: Vec<f64>,
    },
}

/// Compact `Debug`: prints sizes only, never the (potentially large) index and
/// value buffers.
impl std::fmt::Debug for Ic0Factor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Scalar { cols, .. } => {
                let n_cols = cols.len();
                write!(f, "Ic0Factor::Scalar {{ cols.len: {} }}", n_cols)
            }
            Self::IncompleteChol { cols, val, .. } => {
                let (n_cols, nnz) = (cols.len(), val.len());
                write!(
                    f,
                    "Ic0Factor::IncompleteChol {{ cols.len: {}, nnz: {} }}",
                    n_cols, nnz
                )
            }
        }
    }
}
2797
2798/// Level-0 incomplete-Cholesky Schur preconditioner (#299).
2799///
2800/// One sparse incomplete-Cholesky factor per connected component of the
2801/// β-coupling graph. Within a component the dense `S[C,C]` is assembled, its
2802/// structural-nonzero pattern `P = { (i,j) : |S_ij| > drop·sqrt(S_ii S_jj) }`
2803/// is taken as the level-0 fill set, and the no-fill incomplete Cholesky
2804/// `S ≈ L̃ L̃ᵀ` is formed keeping only `P` (drop any update landing outside it).
2805/// See [`SchurPreconditionerKind::BlockIncompleteCholesky`].
2806#[derive(Debug, Clone)]
2807pub struct BlockIncompleteCholeskyPreconditioner {
2808    pub(crate) components: Vec<Ic0Factor>,
2809}
2810
2811impl BlockIncompleteCholeskyPreconditioner {
2812    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2813        sys: &ArrowSchurSystem,
2814        htt_factors: &ArrowFactorSlab,
2815        ridge_beta: f64,
2816        backend: &B,
2817    ) -> Result<Self, ArrowSchurError> {
2818        // Column grouping mirrors ClusterJacobi: one group per connected
2819        // component of the β-coupling graph (whole-K single group when no
2820        // block_offsets are registered), so IC(0) preconditions exactly the
2821        // coupling ClusterJacobi keeps, but with a sparse (no-fill) factor.
2822        let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2823            vec![(0..sys.k).collect()]
2824        } else {
2825            let graph = BetaCouplingGraph::build(
2826                &sys.block_offsets,
2827                &sys.rows
2828                    .iter()
2829                    .map(|r| r.htbeta.clone())
2830                    .collect::<Vec<_>>(),
2831            );
2832            graph
2833                .component_partition()
2834                .iter()
2835                .map(|comp| {
2836                    let mut cols: Vec<usize> = comp
2837                        .iter()
2838                        .flat_map(|&blk| sys.block_offsets[blk].clone())
2839                        .collect();
2840                    cols.sort_unstable();
2841                    cols.dedup();
2842                    cols
2843                })
2844                .collect()
2845        };
2846
2847        let mut components = Vec::with_capacity(col_groups.len());
2848        for cols in &col_groups {
2849            let b = cols.len();
2850            if b == 0 {
2851                continue;
2852            }
2853            if b > IC0_MAX_COMPONENT {
2854                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2855                components.push(Ic0Factor::Scalar {
2856                    cols: cols.clone(),
2857                    inv,
2858                });
2859                continue;
2860            }
2861            let mut s_block =
2862                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2863            symmetrize_upper_from_lower(&mut s_block);
2864            match incomplete_cholesky_level0(&s_block) {
2865                Some((col_ptr, row_idx, val)) => components.push(Ic0Factor::IncompleteChol {
2866                    cols: cols.clone(),
2867                    col_ptr,
2868                    row_idx,
2869                    val,
2870                }),
2871                None => {
2872                    // Non-PD incomplete pivot: degrade this component to the
2873                    // scalar reciprocal diagonal (mirrors the ClusterJacobi
2874                    // non-PD fallback), which is always applicable for a
2875                    // PD-floored Schur diagonal.
2876                    let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2877                    components.push(Ic0Factor::Scalar {
2878                        cols: cols.clone(),
2879                        inv,
2880                    });
2881                }
2882            }
2883        }
2884        Ok(Self { components })
2885    }
2886
2887    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2888        let mut out = Array1::<f64>::zeros(r.len());
2889        for comp in &self.components {
2890            match comp {
2891                Ic0Factor::Scalar { cols, inv } => {
2892                    for (local, &gi) in cols.iter().enumerate() {
2893                        out[gi] = inv[local] * r[gi];
2894                    }
2895                }
2896                Ic0Factor::IncompleteChol {
2897                    cols,
2898                    col_ptr,
2899                    row_idx,
2900                    val,
2901                } => {
2902                    let b = cols.len();
2903                    // Gather the local residual, solve `L̃ L̃ᵀ z = r_local` by a
2904                    // sparse forward solve (`L̃ y = r`) then a sparse back solve
2905                    // (`L̃ᵀ z = y`), then scatter `z` back to global columns.
2906                    let mut z = vec![0.0f64; b];
2907                    for (local, &gi) in cols.iter().enumerate() {
2908                        z[local] = r[gi];
2909                    }
2910                    // Forward solve `L̃ y = r` (overwrite z with y). Column-major
2911                    // CSC: row_idx[col_ptr[j]] == j (diagonal stored first).
2912                    for j in 0..b {
2913                        let dstart = col_ptr[j];
2914                        let diag = val[dstart];
2915                        z[j] /= diag;
2916                        let yj = z[j];
2917                        for k in (dstart + 1)..col_ptr[j + 1] {
2918                            z[row_idx[k]] -= val[k] * yj;
2919                        }
2920                    }
2921                    // Back solve `L̃ᵀ z = y` (overwrite z). Walk columns in
2922                    // reverse; the below-diagonal entries of column j are the
2923                    // off-diagonal entries of row j of L̃ᵀ.
2924                    for j in (0..b).rev() {
2925                        let dstart = col_ptr[j];
2926                        let mut acc = z[j];
2927                        for k in (dstart + 1)..col_ptr[j + 1] {
2928                            acc -= val[k] * z[row_idx[k]];
2929                        }
2930                        z[j] = acc / val[dstart];
2931                    }
2932                    for (local, &gi) in cols.iter().enumerate() {
2933                        out[gi] = z[local];
2934                    }
2935                }
2936            }
2937        }
2938        out
2939    }
2940}
2941
2942/// Level-0 incomplete Cholesky of a dense SPD-ish block `a` (`b×b`, symmetric).
2943///
2944/// Returns the lower factor `L̃` in column-compressed (CSC) form
2945/// `(col_ptr, row_idx, val)` where each column lists its diagonal entry FIRST
2946/// followed by the strictly-below-diagonal entries, in increasing row order.
2947/// The kept pattern is the level-0 set `P` = structural nonzeros of `a` (a
2948/// relative drop threshold prunes round-off). IC(0) computes the standard
2949/// Cholesky recurrence but DROPS any value at a position outside `P`, so the
2950/// factor has exactly `nnz(tril(P))` entries — no fill. Returns `None` on a
2951/// non-positive pivot (caller degrades to scalar diagonal).
2952///
2953/// Reference: Y. Saad, *Iterative Methods for Sparse Linear Systems*, 2nd ed.,
2954/// §10.3.2 (IC(0)). This is the left-looking, pattern-restricted variant.
2955pub(crate) fn incomplete_cholesky_level0(
2956    a: &Array2<f64>,
2957) -> Option<(Vec<usize>, Vec<usize>, Vec<f64>)> {
2958    let b = a.nrows();
2959    assert_eq!(a.ncols(), b, "incomplete Cholesky needs a square block");
2960
2961    // ---- derive the level-0 lower-triangular pattern from `a` --------------
2962    // Per column j, the kept below-or-on-diagonal rows i>=j with a structurally
2963    // nonzero a[i,j]. The diagonal is always kept.
2964    let mut col_ptr = vec![0usize; b + 1];
2965    let mut row_idx: Vec<usize> = Vec::new();
2966    // value buffer, parallel to row_idx, initialised from tril(a) on the pattern
2967    let mut val: Vec<f64> = Vec::new();
2968    // For O(1) "is (i,j) in pattern + where" lookups during the recurrence, keep
2969    // a per-column map from global row -> position in that column's value slice.
2970    let mut col_pos: Vec<std::collections::HashMap<usize, usize>> = Vec::with_capacity(b);
2971    for j in 0..b {
2972        let ajj = a[[j, j]];
2973        let scale_j = ajj.abs().max(0.0).sqrt();
2974        let mut map = std::collections::HashMap::new();
2975        // diagonal first
2976        map.insert(j, val.len());
2977        row_idx.push(j);
2978        val.push(ajj);
2979        for i in (j + 1)..b {
2980            let aij = a[[i, j]];
2981            let scale_i = a[[i, i]].abs().sqrt();
2982            let thresh = IC0_PATTERN_REL_DROP * scale_i * scale_j;
2983            if aij.abs() > thresh {
2984                map.insert(i, val.len());
2985                row_idx.push(i);
2986                val.push(aij);
2987            }
2988        }
2989        col_pos.push(map);
2990        col_ptr[j + 1] = val.len();
2991    }
2992
2993    // ---- IC(0) recurrence, left-looking over columns -----------------------
2994    // For column j: subtract the contributions of all prior columns k<j that
2995    // have BOTH a nonzero at row j (so they touch the diagonal/the column) — the
2996    // multiplier L[j,k] — and a nonzero at the rows i of column j's pattern.
2997    // Any update whose target (i,j) is OUTSIDE the kept pattern is dropped.
2998    for j in 0..b {
2999        // Diagonal: a[j,j] - Σ_{k<j} L[j,k]². Each prior column k<j contributes
3000        // its row-j entry L[j,k] (looked up by row, so the column index is not
3001        // needed); columns without a row-j entry contribute nothing.
3002        let dpos = col_ptr[j];
3003        let mut diag = val[dpos];
3004        for mapk in &col_pos[..j] {
3005            if let Some(&pjk) = mapk.get(&j) {
3006                let ljk = val[pjk];
3007                diag -= ljk * ljk;
3008            }
3009        }
3010        if !diag.is_finite() || diag <= JACOBI_DIAGONAL_PD_FLOOR {
3011            return None;
3012        }
3013        let ljj = diag.sqrt();
3014        val[dpos] = ljj;
3015        // Below-diagonal of column j: L[i,j] = (a[i,j] - Σ_{k<j} L[i,k] L[j,k]) / L[j,j]
3016        for p in (dpos + 1)..col_ptr[j + 1] {
3017            let i = row_idx[p];
3018            let mut s = val[p];
3019            for mapk in &col_pos[..j] {
3020                if let (Some(&pik), Some(&pjk)) = (mapk.get(&i), mapk.get(&j)) {
3021                    s -= val[pik] * val[pjk];
3022                }
3023            }
3024            val[p] = s / ljj;
3025        }
3026    }
3027    Some((col_ptr, row_idx, val))
3028}
3029
/// Result of running the PCG with a single preconditioner tier in the #299
/// preconditioner-ladder iteration study.
#[derive(Debug, Clone, Copy)]
pub struct PrecondLadderRow {
    /// Number of PCG iterations performed, either until convergence or until
    /// the `MaxIter` cutoff.
    pub iterations: usize,
    /// `true` when the PCG stopped because it converged, `false` when it
    /// stopped for any other reason (`MaxIter`, negative curvature, ...).
    pub converged: bool,
    /// Final relative residual as reported by the PCG.
    pub final_relative_residual: f64,
}
3041
3042/// Full #299 ladder iteration study on one reduced-Schur system: run the SAME
3043/// preconditioned CG (same `rhs`, tolerances, trust radius) once per ladder tier
3044/// and report the iteration count of each. This is the public seam the
3045/// `tests/owed_299.rs` iteration-reduction gate drives — it keeps the internal
3046/// `run_pcg_with_preconditioner` / preconditioner constructors `pub(crate)`
3047/// while exposing exactly the per-tier measurement the issue asks for.
3048///
3049/// Tiers (in escalation order): scalar `Diagonal`, `BetaBlockJacobi`,
3050/// `ClusterJacobi`, `AdditiveSchwarz{overlap:1}`, `DiagAssembledSchwarz{1}`, and
3051/// `BlockIncompleteCholesky`. A tier whose build fails (e.g. non-PD reduced
3052/// Schur with no curvature floor) reports `None` for that entry; every healthy
3053/// SPD reduced system populates all six.
3054pub fn arrow_precond_ladder_iteration_study(
3055    sys: &ArrowSchurSystem,
3056    ridge_beta: f64,
3057    rhs: &Array1<f64>,
3058    pcg: &ArrowPcgOptions,
3059    trust: &ArrowTrustRegionOptions,
3060) -> Result<Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)>, ArrowSchurError> {
3061    let backend = CpuBatchedBlockSolver;
3062    let htt_factors = backend.factor_blocks(&sys.rows, 0.0, sys.d, false)?;
3063
3064    let run = |apply: &dyn Fn(&Array1<f64>) -> Array1<f64>| -> Option<PrecondLadderRow> {
3065        let (_sol, diag) = run_pcg_with_preconditioner(
3066            sys,
3067            &htt_factors,
3068            ridge_beta,
3069            rhs,
3070            |r| apply(r),
3071            pcg,
3072            trust,
3073            &backend,
3074            None,
3075            None,
3076            None,
3077        )
3078        .ok()?;
3079        Some(PrecondLadderRow {
3080            iterations: diag.iterations,
3081            converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
3082            final_relative_residual: diag.final_relative_residual,
3083        })
3084    };
3085
3086    let mut out: Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)> = Vec::with_capacity(6);
3087
3088    // Scalar Diagonal Jacobi: force the scalar path by clearing block_offsets on
3089    // a clone so the build does not pick up the per-block dense Schur blocks.
3090    let diag_row = {
3091        let mut bare = sys.clone();
3092        bare.set_block_offsets(std::sync::Arc::from([] as [Range<usize>; 0]));
3093        let bare_factors = backend.factor_blocks(&bare.rows, 0.0, bare.d, false)?;
3094        JacobiPreconditioner::from_arrow_schur(&bare, &bare_factors, ridge_beta, &backend, None)
3095            .ok()
3096            .and_then(|p| {
3097                run_pcg_with_preconditioner(
3098                    &bare,
3099                    &bare_factors,
3100                    ridge_beta,
3101                    rhs,
3102                    |r| p.apply(r),
3103                    pcg,
3104                    trust,
3105                    &backend,
3106                    None,
3107                    None,
3108                    None,
3109                )
3110                .ok()
3111                .map(|(_s, diag)| PrecondLadderRow {
3112                    iterations: diag.iterations,
3113                    converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
3114                    final_relative_residual: diag.final_relative_residual,
3115                })
3116            })
3117    };
3118    out.push((SchurPreconditionerKind::Diagonal, diag_row));
3119
3120    let block_row =
3121        JacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, None)
3122            .ok()
3123            .and_then(|p| run(&|r| p.apply(r)));
3124    out.push((SchurPreconditionerKind::BetaBlockJacobi, block_row));
3125
3126    let cluster_row =
3127        ClusterJacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend)
3128            .ok()
3129            .and_then(|p| run(&|r| p.apply(r)));
3130    out.push((SchurPreconditionerKind::ClusterJacobi, cluster_row));
3131
3132    let schwarz_row =
3133        AdditiveSchwarzPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, 1)
3134            .ok()
3135            .and_then(|p| run(&|r| p.apply(r)));
3136    out.push((
3137        SchurPreconditionerKind::AdditiveSchwarz { overlap: 1 },
3138        schwarz_row,
3139    ));
3140
3141    let diag_schwarz_row = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
3142        sys,
3143        &htt_factors,
3144        ridge_beta,
3145        &backend,
3146        1,
3147    )
3148    .ok()
3149    .and_then(|p| run(&|r| p.apply(r)));
3150    out.push((
3151        SchurPreconditionerKind::DiagAssembledSchwarz { overlap: 1 },
3152        diag_schwarz_row,
3153    ));
3154
3155    let ic0_row = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
3156        sys,
3157        &htt_factors,
3158        ridge_beta,
3159        &backend,
3160    )
3161    .ok()
3162    .and_then(|p| run(&|r| p.apply(r)));
3163    out.push((SchurPreconditionerKind::BlockIncompleteCholesky, ic0_row));
3164
3165    Ok(out)
3166}
3167
3168/// Build scalar diagonal inverses for a set of global column indices.
3169///
3170/// Used when a cluster is non-PD or exceeds `CLUSTER_JACOBI_MAX_CLUSTER`.
3171pub(crate) fn build_schur_scalar_inv<B: BatchedBlockSolver>(
3172    sys: &ArrowSchurSystem,
3173    htt_factors: &ArrowFactorSlab,
3174    ridge_beta: f64,
3175    backend: &B,
3176    cols: &[usize],
3177) -> Result<Vec<f64>, ArrowSchurError> {
3178    let d = sys.d;
3179    let mut result = Vec::with_capacity(cols.len());
3180    let mut col_vec = Array1::<f64>::zeros(d);
3181    // Extract the penalty diagonal for all K columns once, then index per-column.
3182    let mut full_diag = Array1::<f64>::zeros(sys.k);
3183    {
3184        let diag_slice = full_diag.as_slice_mut().expect("full_diag contiguous");
3185        sys.penalty_diagonal_add(diag_slice);
3186    }
3187    for &gi in cols {
3188        let mut s = full_diag[gi] + ridge_beta;
3189        for (row_idx, row) in sys.rows.iter().enumerate() {
3190            for c in 0..d {
3191                col_vec[c] = row.htbeta[[c, gi]];
3192            }
3193            let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
3194            let mut acc = 0.0;
3195            for c in 0..d {
3196                acc += col_vec[c] * solved[c];
3197            }
3198            s -= acc;
3199        }
3200        if !s.is_finite() || s <= JACOBI_DIAGONAL_PD_FLOOR {
3201            return Err(ArrowSchurError::PcgFailed {
3202                reason: format!(
3203                    "cluster Schur scalar fallback: non-PD diagonal at index {gi}: {s}"
3204                ),
3205            });
3206        }
3207        result.push(1.0 / s);
3208    }
3209    Ok(result)
3210}
3211
3212/// Inexact PCG with automatic preconditioner-ladder escalation.
3213///
3214/// Starts with `JacobiPreconditioner` (Diagonal or BetaBlockJacobi).
3215/// If PCG hits `MaxIter` and `k > PRECOND_ESCALATE_K_THRESHOLD`,
3216/// escalates to `ClusterJacobi`; if still `MaxIter`, escalates to
3217/// `AdditiveSchwarz { overlap: 1 }`.
3218pub(crate) fn steihaug_pcg_auto<B: BatchedBlockSolver + Sync>(
3219    sys: &ArrowSchurSystem,
3220    htt_factors: &ArrowFactorSlab,
3221    ridge_beta: f64,
3222    rhs: &Array1<f64>,
3223    pcg: &ArrowPcgOptions,
3224    trust: &ArrowTrustRegionOptions,
3225    backend: &B,
3226    gpu_matvec: Option<&GpuSchurMatvec>,
3227    metric_weights: Option<&MetricWeights>,
3228    curvature_floor: Option<f64>,
3229) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3230    // #1017 CPU residency: stage the per-row reduced-Schur factors `(L_i, Y_i)`
3231    // (NOT the dense `p×p` block — `di ≪ p`, so the factored form is `O(n·di·p)`
3232    // memory and `2·support_i·p + 2·di·p` flops/row including the sparse
3233    // gather/scatter over the active support) once, up
3234    // front, when the SAE structure is installed and the matvec runs on host
3235    // (CPU). The GPU matvec carries its own residency, so skip when it is engaged.
3236    // The same staged operator is reused across the whole preconditioner ladder
3237    // (Jacobi → ClusterJacobi → AdditiveSchwarz) — built once, not per tier.
3238    let resident = if gpu_matvec.is_none() {
3239        SaeResidentReducedSchur::build(sys, htt_factors, backend)
3240    } else {
3241        None
3242    };
3243    // #1026 — curvature-floor retry on the Jacobi tier. The unbounded SAE inner
3244    // PCG (trust radius = ∞) fails on `pᵀSp ≤ 0` when the reduced Schur is
3245    // indefinite (K≥4 co-collapse: a near-singular per-row `H_tt` over-subtracts
3246    // `S`). Instead of letting that failure propagate to the outer LM loop —
3247    // which inflates `ridge_β` over EVERY β direction and makes the inner Newton
3248    // crawl — floor the OPERATOR by the minimal ridge `δ = |pᵀSp|/‖p‖² · (1+ε)`
3249    // that restores positive curvature along the offending direction, rebuild the
3250    // Jacobi preconditioner at the lifted ridge, and retry. This is the
3251    // matrix-free analogue of the dense `spectral_pd_floored_schur`: the healthy
3252    // β subspace (where curvature is already positive) is essentially untouched
3253    // by a tiny `δ`, while the collapsed direction gets exactly the stiffness it
3254    // needs to make a real descent step. A PD reduced Schur never hits `pᵀSp ≤ 0`,
3255    // so this loop is a strict no-op there (bit-for-bit unchanged). Bounded by a
3256    // small attempt cap and a relative ridge ceiling; on exhaustion the original
3257    // recoverable failure still reaches the outer LM loop.
3258    let mut effective_ridge = ridge_beta;
3259    let mut x0_diag0: Option<(Array1<f64>, PcgDiagnostics)> = None;
3260    let mut last_curvature_err: Option<ArrowSchurError> = None;
3261    let rhs_scale = metric_norm(rhs.view(), metric_weights).max(1.0);
3262    let ridge_ceiling = ridge_beta.max(SCHUR_CURVATURE_FLOOR_REL_CEILING * rhs_scale);
3263    for _attempt in 0..=SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS {
3264        // The Jacobi preconditioner build itself refuses a non-PD Schur diagonal
3265        // (`PcgFailed: invalid Schur Jacobi diagonal`) — the SAME co-collapse
3266        // signature reached BEFORE the CG loop, since `S_ii = H_ββ,ii − Σ …` goes
3267        // negative. Treat that build failure as a curvature deficit too: when the
3268        // floor is enabled, lift the ridge and retry; otherwise propagate.
3269        let jacobi = match JacobiPreconditioner::from_arrow_schur(
3270            sys,
3271            htt_factors,
3272            effective_ridge,
3273            backend,
3274            resident.as_ref(),
3275        ) {
3276            Ok(jacobi) => jacobi,
3277            Err(err @ ArrowSchurError::PcgFailed { .. }) => {
3278                if curvature_floor.is_none() {
3279                    return Err(err);
3280                }
3281                // A diagonal refusal carries no `(curvature, ‖p‖²)` deficit, and
3282                // the over-subtraction magnitude `Σ H_tβᵀ(H_tt)⁻¹H_tβ` is
3283                // unbounded relative to `rhs_scale`, so a small additive bump
3284                // would crawl. Escalate the ridge MULTIPLICATIVELY (×10, matching
3285                // the per-row `factor_one_row_result` RIDGE_GROWTH_FACTOR), seeded
3286                // at `rhs_scale`, so even a large deficit (the collapsed
3287                // `(H_tβ)²/H_tt` over-subtraction) is reached in a handful of
3288                // attempts. The ceiling + attempt cap still bound it; on
3289                // exhaustion the recoverable failure reaches the outer LM loop.
3290                // Jump straight to a meaningful scale on the FIRST refusal rather
3291                // than crawling ×10 from a tiny `ridge_beta`: each rebuild is a full
3292                // block-Jacobi factorization (the massive-K preconditioner hotspot),
3293                // and a large collapsed deficit (`Σ H_tβᵀ(H_tt)⁻¹H_tβ` over-subtraction,
3294                // O(1)-scale) otherwise costs ~log10(deficit / ridge_beta) rebuilds.
3295                // Seeding the first bump at `rhs_scale` covers it in one or two, then
3296                // escalates multiplicatively; the ceiling + attempt cap still bound it.
3297                let next = if effective_ridge > 0.0 {
3298                    (effective_ridge * SCHUR_CURVATURE_FLOOR_DIAG_GROWTH).max(rhs_scale)
3299                } else {
3300                    rhs_scale
3301                };
3302                last_curvature_err = Some(err);
3303                if !next.is_finite() || next > ridge_ceiling {
3304                    break;
3305                }
3306                effective_ridge = next;
3307                continue;
3308            }
3309            Err(other) => return Err(other),
3310        };
3311        match run_pcg_with_preconditioner(
3312            sys,
3313            htt_factors,
3314            effective_ridge,
3315            rhs,
3316            |r| jacobi.apply(r),
3317            pcg,
3318            trust,
3319            backend,
3320            gpu_matvec,
3321            metric_weights,
3322            resident.as_ref(),
3323        ) {
3324            Ok(result) => {
3325                x0_diag0 = Some(result);
3326                break;
3327            }
3328            Err(ArrowSchurError::UnboundedNegativeCurvature {
3329                curvature,
3330                direction_norm_sq,
3331            }) => {
3332                // Only floor when the caller opted in (SAE solve path); otherwise
3333                // propagate the raw negative-curvature signal so BA / non-SAE
3334                // unbounded solves keep their existing failure contract.
3335                let Some(relative_floor) = curvature_floor else {
3336                    return Err(ArrowSchurError::UnboundedNegativeCurvature {
3337                        curvature,
3338                        direction_norm_sq,
3339                    });
3340                };
3341                // Minimal ridge to make `pᵀ(S+δI)p = |curvature| + δ·‖p‖² > 0`,
3342                // with a margin so the next CG iterate has strictly positive
3343                // curvature rather than sitting on the `0` knife-edge.
3344                let deficit = if direction_norm_sq > 0.0 {
3345                    curvature.abs() / direction_norm_sq
3346                } else {
3347                    0.0
3348                };
3349                let bump = (deficit * (1.0 + SCHUR_CURVATURE_FLOOR_MARGIN))
3350                    .max(relative_floor.max(SCHUR_CURVATURE_FLOOR_REL_FLOOR) * rhs_scale);
3351                let next = (effective_ridge + bump).max(effective_ridge * 2.0);
3352                last_curvature_err = Some(ArrowSchurError::UnboundedNegativeCurvature {
3353                    curvature,
3354                    direction_norm_sq,
3355                });
3356                if !next.is_finite() || next > ridge_ceiling {
3357                    break;
3358                }
3359                effective_ridge = next;
3360            }
3361            Err(other) => return Err(other),
3362        }
3363    }
3364    let (x0, diag0) = match x0_diag0 {
3365        Some(result) => result,
3366        None => {
3367            // The curvature floor could not condition the operator within the
3368            // ceiling; hand the recoverable failure to the outer LM loop, which
3369            // re-forms the system at a heavier ridge.
3370            return Err(last_curvature_err.unwrap_or(ArrowSchurError::PcgFailed {
3371                reason: "unbounded Schur PCG negative curvature unresolved by curvature floor"
3372                    .to_string(),
3373            }));
3374        }
3375    };
3376    if sys.k <= PRECOND_ESCALATE_K_THRESHOLD || diag0.stopping_reason != PcgStopReason::MaxIter {
3377        return Ok((x0, diag0));
3378    }
3379    // Escalation tiers reuse the curvature-floored `effective_ridge` so the
3380    // operator they precondition is the SAME (PD-floored) one the Jacobi tier
3381    // settled on; a still-negative-curvature signal here is handed to the outer
3382    // LM loop (it only arises if the floored Jacobi tier merely ran out of
3383    // iterations yet a coarser preconditioner still finds an indefinite
3384    // direction — rare; the LM loop re-forms at a heavier ridge).
3385    let cluster =
3386        ClusterJacobiPreconditioner::from_arrow_schur(sys, htt_factors, effective_ridge, backend)?;
3387    let (x1, diag1) = run_pcg_with_preconditioner(
3388        sys,
3389        htt_factors,
3390        effective_ridge,
3391        rhs,
3392        |r| cluster.apply(r),
3393        pcg,
3394        trust,
3395        backend,
3396        gpu_matvec,
3397        metric_weights,
3398        resident.as_ref(),
3399    )?;
3400    if diag1.stopping_reason != PcgStopReason::MaxIter {
3401        return Ok((x1, diag1));
3402    }
3403    let schwarz = AdditiveSchwarzPreconditioner::from_arrow_schur(
3404        sys,
3405        htt_factors,
3406        effective_ridge,
3407        backend,
3408        1,
3409    )?;
3410    let (x2, diag2) = run_pcg_with_preconditioner(
3411        sys,
3412        htt_factors,
3413        effective_ridge,
3414        rhs,
3415        |r| schwarz.apply(r),
3416        pcg,
3417        trust,
3418        backend,
3419        gpu_matvec,
3420        metric_weights,
3421        resident.as_ref(),
3422    )?;
3423    if diag2.stopping_reason != PcgStopReason::MaxIter {
3424        return Ok((x2, diag2));
3425    }
3426    // Final tier — diagonal-assembled additive Schwarz (#299), the cheap-apply
3427    // Schwarz variant. When the dense-block AdditiveSchwarz still ran out of
3428    // iterations its O(Σ b_k²) apply may have throttled the iteration budget on
3429    // a wide subdomain; the diag-assembled variant keeps Schwarz's overlapping
3430    // local-inverse conditioning but applies in O(K), so it can take more CG
3431    // iterations within the same wall budget. Same overlap (1) and same
3432    // curvature-floored ridge as the dense-block tier.
3433    let diag_schwarz = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
3434        sys,
3435        htt_factors,
3436        effective_ridge,
3437        backend,
3438        1,
3439    )?;
3440    let (x3, diag3) = run_pcg_with_preconditioner(
3441        sys,
3442        htt_factors,
3443        effective_ridge,
3444        rhs,
3445        |r| diag_schwarz.apply(r),
3446        pcg,
3447        trust,
3448        backend,
3449        gpu_matvec,
3450        metric_weights,
3451        resident.as_ref(),
3452    )?;
3453    if diag3.stopping_reason != PcgStopReason::MaxIter {
3454        return Ok((x3, diag3));
3455    }
3456    // Richest tier — level-0 incomplete Cholesky (#299). ClusterJacobi keeps the
3457    // full DENSE Cholesky of each component (so on a single large connected
3458    // component it fills the whole `b×b` factor and its `O(b²)` apply throttles
3459    // the CG iteration budget), while the diagonal/Schwarz tiers drop most
3460    // inter-block coupling. IC(0) keeps the component's full structural coupling
3461    // but only the level-0 (no-fill) pattern, so its sparse triangular apply is
3462    // `O(nnz(S[C,C]))` — it can take more CG iterations within the same wall
3463    // budget AND conditions the off-diagonal coupling the cheap tiers discard.
3464    // Last in the ladder so it is only paid when every cheaper tier stalled.
3465    let ic0 = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
3466        sys,
3467        htt_factors,
3468        effective_ridge,
3469        backend,
3470    )?;
3471    let (x4, diag4) = run_pcg_with_preconditioner(
3472        sys,
3473        htt_factors,
3474        effective_ridge,
3475        rhs,
3476        |r| ic0.apply(r),
3477        pcg,
3478        trust,
3479        backend,
3480        gpu_matvec,
3481        metric_weights,
3482        resident.as_ref(),
3483    )?;
3484    // All five preconditioner tiers (Jacobi -> ClusterJacobi -> AdditiveSchwarz
3485    // -> DiagAssembledSchwarz -> BlockIncompleteCholesky) exhausted their
3486    // iteration budget without driving the residual below tolerance. Returning a
3487    // truncated iterate as `Ok` would feed an arbitrarily-large-residual step
3488    // into the Newton driver, where the PCG diagnostics are discarded. Surface a
3489    // recoverable failure instead so `solve_with_lm_escalation_inner` escalates
3490    // the proximal ridge: better conditioning is precisely what a stalled PCG on
3491    // an ill-conditioned reduced system needs.
3492    if diag4.stopping_reason == PcgStopReason::MaxIter {
3493        return Err(ArrowSchurError::PcgFailed {
3494            reason: format!(
3495                "Schur PCG exhausted all preconditioner tiers (Jacobi, ClusterJacobi, \
3496                 AdditiveSchwarz, DiagAssembledSchwarz, BlockIncompleteCholesky) at MaxIter; \
3497                 final relative residual = {:e}",
3498                diag4.final_relative_residual
3499            ),
3500        });
3501    }
3502    Ok((x4, diag4))
3503}
3504
3505/// Run Steihaug-CG with a generic preconditioner closure.
3506/// Routes matvec through GPU when `gpu_matvec` is set.
3507pub(crate) fn run_pcg_with_preconditioner<ApplyPrec, B: BatchedBlockSolver + Sync>(
3508    sys: &ArrowSchurSystem,
3509    htt_factors: &ArrowFactorSlab,
3510    ridge_beta: f64,
3511    rhs: &Array1<f64>,
3512    apply_prec: ApplyPrec,
3513    pcg: &ArrowPcgOptions,
3514    trust: &ArrowTrustRegionOptions,
3515    backend: &B,
3516    gpu_matvec: Option<&GpuSchurMatvec>,
3517    metric_weights: Option<&MetricWeights>,
3518    resident: Option<&SaeResidentReducedSchur>,
3519) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3520where
3521    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3522{
3523    let max_iters = pcg.max_iterations.min(trust.max_iterations);
3524    let tol = pcg
3525        .relative_tolerance
3526        .max(trust.steihaug_relative_tolerance);
3527    if let Some(gpu_mv) = gpu_matvec {
3528        let gpu_mv = Arc::clone(gpu_mv);
3529        steihaug_cg(
3530            rhs,
3531            move |p, out| gpu_mv(p, out),
3532            apply_prec,
3533            max_iters,
3534            tol,
3535            trust.radius,
3536            metric_weights,
3537        )
3538    } else {
3539        steihaug_cg(
3540            rhs,
3541            |p, out| schur_matvec(sys, htt_factors, ridge_beta, p, out, backend, resident),
3542            apply_prec,
3543            max_iters,
3544            tol,
3545            trust.radius,
3546            metric_weights,
3547        )
3548    }
3549}
3550
3551#[derive(Debug, Clone, Copy)]
3552pub(crate) struct IdentityPreconditioner;
3553
3554impl IdentityPreconditioner {
3555    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
3556        r.clone()
3557    }
3558}
3559
3560pub(crate) fn steihaug_dense_system(
3561    schur: &Array2<f64>,
3562    rhs: &Array1<f64>,
3563    preconditioner: &IdentityPreconditioner,
3564    pcg: &ArrowPcgOptions,
3565    trust: &ArrowTrustRegionOptions,
3566    metric_weights: Option<&MetricWeights>,
3567) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3568    steihaug_cg(
3569        rhs,
3570        |p, out| dense_matvec(schur, p, out),
3571        |r| preconditioner.apply(r),
3572        pcg.max_iterations,
3573        pcg.relative_tolerance,
3574        trust.radius,
3575        metric_weights,
3576    )
3577}
3578
3579pub(crate) fn steihaug_cg<MatVec, ApplyPrec>(
3580    rhs: &Array1<f64>,
3581    mut matvec: MatVec,
3582    mut apply_preconditioner: ApplyPrec,
3583    max_iterations: usize,
3584    relative_tolerance: f64,
3585    trust_radius: f64,
3586    metric_weights: Option<&MetricWeights>,
3587) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3588where
3589    MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
3590    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3591{
3592    let n = rhs.len();
3593    if let Some(weights) = metric_weights {
3594        assert_eq!(
3595            weights.len(),
3596            n,
3597            "Steihaug-CG metric weight length must match solve dimension"
3598        );
3599    }
3600    let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
3601        trust_radius
3602    } else {
3603        f64::INFINITY
3604    };
3605    let rhs_norm = metric_norm(rhs.view(), metric_weights);
3606    if rhs_norm == 0.0 {
3607        return Ok((Array1::<f64>::zeros(n), PcgDiagnostics::default()));
3608    }
3609    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(PCG_ABSOLUTE_TOLERANCE_FLOOR);
3610    let mut x = Array1::<f64>::zeros(n);
3611    let mut r = rhs.clone();
3612    let mut z = apply_preconditioner(&r);
3613    let mut diag = PcgDiagnostics {
3614        precond_apply_calls: 1,
3615        ..PcgDiagnostics::default()
3616    };
3617    let mut p = z.clone();
3618    let mut rz = metric_dot(&r, &z, metric_weights);
3619    if rz <= 0.0 || !rz.is_finite() {
3620        if radius.is_finite() {
3621            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3622            diag.stopping_reason = PcgStopReason::TrustRegion;
3623            return Ok((step_to_trust_boundary(&x, &r, radius, metric_weights), diag));
3624        }
3625        // Unbounded (radius = ∞) non-positive preconditioned residual: the
3626        // reduced Schur is indefinite at the very first direction. Surface the
3627        // typed curvature-floor signal so `steihaug_pcg_auto` floors the
3628        // operator minimally and retries, instead of failing into a global
3629        // `ridge_β` ramp. `rz = rᵀM⁻¹r` is a preconditioner-metric curvature;
3630        // report it with the residual norm² as the direction scale.
3631        return Err(ArrowSchurError::UnboundedNegativeCurvature {
3632            curvature: rz,
3633            direction_norm_sq: metric_dot(&r, &r, metric_weights),
3634        });
3635    }
3636    if metric_norm(r.view(), metric_weights) <= tol {
3637        diag.final_relative_residual = 0.0;
3638        diag.stopping_reason = PcgStopReason::Converged;
3639        return Ok((x, diag));
3640    }
3641    let mut ap = Array1::<f64>::zeros(n);
3642    // Reused candidate scratch — avoid per-iteration clone of x.
3643    let mut candidate = Array1::<f64>::zeros(n);
3644    for _ in 0..max_iterations {
3645        matvec(&p, &mut ap);
3646        diag.matvec_calls += 1;
3647        diag.iterations += 1;
3648        let pap = metric_dot(&p, &ap, metric_weights);
3649        if pap <= 0.0 || !pap.is_finite() {
3650            if radius.is_finite() {
3651                diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3652                diag.stopping_reason = PcgStopReason::TrustRegion;
3653                return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3654            }
3655            // Unbounded negative curvature `pᵀSp ≤ 0`: the reduced Schur is
3656            // indefinite along `p` (the #1026 co-collapse direction). Surface
3657            // the typed signal carrying `pᵀSp` and `‖p‖²` so the caller floors
3658            // the operator by the minimal ridge `δ = |pᵀSp|/‖p‖²` (which makes
3659            // `pᵀ(S+δI)p = 0⁺`) plus a margin, and retries.
3660            return Err(ArrowSchurError::UnboundedNegativeCurvature {
3661                curvature: pap,
3662                direction_norm_sq: metric_dot(&p, &p, metric_weights),
3663            });
3664        }
3665        let alpha = rz / pap;
3666        for i in 0..n {
3667            candidate[i] = x[i] + alpha * p[i];
3668        }
3669        if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
3670            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3671            diag.stopping_reason = PcgStopReason::TrustRegion;
3672            return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3673        }
3674        x.assign(&candidate);
3675        for i in 0..n {
3676            r[i] -= alpha * ap[i];
3677        }
3678        if metric_norm(r.view(), metric_weights) <= tol {
3679            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3680            diag.stopping_reason = PcgStopReason::Converged;
3681            return Ok((x, diag));
3682        }
3683        z = apply_preconditioner(&r);
3684        diag.precond_apply_calls += 1;
3685        let rz_next = metric_dot(&r, &z, metric_weights);
3686        if rz_next <= 0.0 || !rz_next.is_finite() {
3687            return Err(ArrowSchurError::PcgFailed {
3688                reason: "non-positive or non-finite PCG residual".to_string(),
3689            });
3690        }
3691        let beta = rz_next / rz;
3692        for i in 0..n {
3693            p[i] = z[i] + beta * p[i];
3694        }
3695        rz = rz_next;
3696    }
3697    diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3698    diag.stopping_reason = PcgStopReason::MaxIter;
3699    Ok((x, diag))
3700}
3701
3702pub(crate) fn step_to_trust_boundary(
3703    x: &Array1<f64>,
3704    p: &Array1<f64>,
3705    radius: f64,
3706    metric_weights: Option<&MetricWeights>,
3707) -> Array1<f64> {
3708    let pp = metric_dot(p, p, metric_weights);
3709    if pp == 0.0 {
3710        return x.clone();
3711    }
3712    let xp = metric_dot(x, p, metric_weights);
3713    let xx = metric_dot(x, x, metric_weights);
3714    let disc = (xp * xp + pp * (radius * radius - xx)).max(0.0);
3715    let tau = (-xp + disc.sqrt()) / pp;
3716    let mut out = x.clone();
3717    for i in 0..out.len() {
3718        out[i] += tau * p[i];
3719    }
3720    out
3721}
3722
3723pub(crate) fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
3724    let n = a.nrows();
3725    for i in 0..n {
3726        let mut acc = 0.0;
3727        for j in 0..n {
3728            acc += a[[i, j]] * x[j];
3729        }
3730        out[i] = acc;
3731    }
3732}
3733
3734pub(crate) fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
3735    let mut acc = 0.0;
3736    for i in 0..a.len() {
3737        acc += a[i] * b[i];
3738    }
3739    acc
3740}
3741
3742pub(crate) fn metric_dot(
3743    a: &Array1<f64>,
3744    b: &Array1<f64>,
3745    metric_weights: Option<&MetricWeights>,
3746) -> f64 {
3747    assert_eq!(a.len(), b.len());
3748    match metric_weights {
3749        Some(weights) => {
3750            assert_eq!(weights.len(), a.len());
3751            let mut acc = 0.0;
3752            for i in 0..a.len() {
3753                acc += weights[i] * a[i] * b[i];
3754            }
3755            acc
3756        }
3757        None => dot(a, b),
3758    }
3759}
3760
3761pub(crate) fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
3762    let mut acc = 0.0;
3763    match metric_weights {
3764        Some(weights) => {
3765            assert_eq!(weights.len(), v.len());
3766            for i in 0..v.len() {
3767                acc += weights[i] * v[i] * v[i];
3768            }
3769        }
3770        None => {
3771            for x in v.iter() {
3772                acc += x * x;
3773            }
3774        }
3775    }
3776    acc.sqrt()
3777}
3778
3779pub(crate) fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
3780    let n = a.nrows().min(a.ncols());
3781    for i in 0..n {
3782        for j in 0..i {
3783            let v = 0.5 * (a[[i, j]] + a[[j, i]]);
3784            a[[i, j]] = v;
3785            a[[j, i]] = v;
3786        }
3787    }
3788}
3789
/// Errors raised by [`ArrowSchurSystem::solve`].
#[derive(Debug, Clone)]
pub enum ArrowSchurError {
    /// The per-row block `H_tt^(row)` was not positive-definite at the ridge
    /// it was factored with. This points at an under-regularized latent
    /// block, typically a gauge-free fit with no identifiability penalty.
    PerRowFactorFailed { row: usize, reason: String },
    /// The per-row block `H_tt^(row)` did factor, but its Cholesky factor was
    /// rejected by the safe-inversion guard used for the Schur reduction:
    /// either the diagonal-ratio condition estimate was excessive or a pivot
    /// was numerically tiny relative to the row block's scale. The inverse
    /// entering `S = H_ββ − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)` would carry
    /// spectral error on the order of `κ_i`, so for Schur stability this is
    /// treated like a PSD failure; the LM outer wrapper escalates `ridge_t`
    /// exactly as it does for `PerRowFactorFailed`.
    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
    /// The Schur complement was not positive-definite, indicating a
    /// near-collinear decoder or a degenerate weighting. The LM outer wrapper
    /// should escalate `ridge_beta` and retry.
    SchurFactorFailed { reason: String },
    /// The BA inexact-step PCG solve failed before it produced a usable
    /// Steihaug trust-region step.
    PcgFailed { reason: String },
    /// The UNBOUNDED (trust-radius = ∞) Schur PCG met negative curvature
    /// `pᵀSp ≤ 0`, or a non-positive preconditioned residual. The reduced
    /// Schur is indefinite — the #1026 K≥4 co-collapse signature, where a
    /// near-singular per-row `H_tt` over-subtracts `S` — and with no trust
    /// radius there is no boundary to step to, so CG cannot continue.
    ///
    /// `curvature` is the offending `pᵀSp` and `direction_norm_sq` is `‖p‖²`
    /// for that direction. The caller floors the operator with the minimal
    /// ridge `δ = (|curvature|/‖p‖²)·(1+ε)` that restores positive curvature
    /// along `p` and retries (the matrix-free analogue of the dense
    /// `spectral_pd_floored_schur`) instead of blindly inflating `ridge_β`.
    UnboundedNegativeCurvature {
        curvature: f64,
        direction_norm_sq: f64,
    },
    /// Adaptive proximal damping could not produce an Armijo-accepted
    /// nonlinear step.
    AdaptiveCorrectionFailed { reason: String },
}

impl std::fmt::Display for ArrowSchurError {
    /// Render a one-line, human-readable message prefixed with `arrow-Schur:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PerRowFactorFailed { row, reason } => f.write_fmt(format_args!(
                "arrow-Schur: per-row H_tt^({row}) Cholesky failed: {reason}"
            )),
            Self::PerRowFactorIllConditioned {
                row,
                kappa_estimate,
            } => f.write_fmt(format_args!(
                "arrow-Schur: per-row H_tt^({row}) Cholesky succeeded but failed \
                 the safe-inversion guard (kappa_estimate={kappa_estimate:e}); \
                 Schur reduction would be numerically contaminated"
            )),
            Self::SchurFactorFailed { reason } => f.write_fmt(format_args!(
                "arrow-Schur: Schur complement Cholesky failed: {reason}"
            )),
            Self::PcgFailed { reason } => {
                f.write_fmt(format_args!("arrow-Schur: Schur PCG failed: {reason}"))
            }
            Self::UnboundedNegativeCurvature {
                curvature,
                direction_norm_sq,
            } => f.write_fmt(format_args!(
                "arrow-Schur: unbounded Schur PCG hit negative curvature pᵀSp={curvature:e} \
                 (‖p‖²={direction_norm_sq:e}); reduced Schur is indefinite (co-collapse), \
                 retry with a curvature-floor ridge"
            )),
            Self::AdaptiveCorrectionFailed { reason } => f.write_fmt(format_args!(
                "arrow-Schur: adaptive proximal correction failed: {reason}"
            )),
        }
    }
}

impl std::error::Error for ArrowSchurError {}
3875
// ---------------------------------------------------------------------------
// Cholesky helpers. Kept local to avoid a new public-API dependency on the
// linalg crate: the systems here are tiny per row (d × d, d ∈ {1..16}) and
// modest at the Schur level (K × K, K ∈ {basis size}). For production SAE
// scales the Schur factor should switch to faer; this module's
// `cholesky_lower` is the obvious replacement site.
// ---------------------------------------------------------------------------
3883
3884pub(crate) fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
3885    let n = a.nrows();
3886    if a.ncols() != n {
3887        return Err(format!("cholesky_lower: non-square {}×{}", n, a.ncols()));
3888    }
3889    if let Some((idx, _)) = a.iter().enumerate().find(|(_, v)| !v.is_finite()) {
3890        return Err(format!(
3891            "cholesky_lower: non-finite entry at linear index {idx}"
3892        ));
3893    }
3894
3895    let mut maybe_device = a.clone();
3896    if gam_gpu::try_cholesky_lower_inplace(&mut maybe_device).is_some() {
3897        return Ok(maybe_device);
3898    }
3899
3900    let mut l = Array2::<f64>::zeros((n, n));
3901    for i in 0..n {
3902        for j in 0..=i {
3903            let mut sum = a[[i, j]];
3904            for kk in 0..j {
3905                sum -= l[[i, kk]] * l[[j, kk]];
3906            }
3907            if i == j {
3908                if !sum.is_finite() || sum <= 0.0 {
3909                    return Err(format!(
3910                        "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
3911                    ));
3912                }
3913                l[[i, j]] = sum.sqrt();
3914            } else {
3915                l[[i, j]] = sum / l[[j, j]];
3916            }
3917        }
3918    }
3919    Ok(l)
3920}