Skip to main content

gam_solve/arrow_schur/
reduced_solve.rs

1//! The reduced `K x K` shared-system solve: dense Schur assembly (direct and
2//! square-root BA), the Schur matvec, the Jacobi/cluster/Schwarz
3//! preconditioners, Steihaug-PCG, and the [`ArrowSchurError`] type.
4
5use super::*;
6
7/// Reduce one contiguous device tile's rows into a private `-Σ leftᵀ·right`
8/// partial (`k×k`).
9///
10/// The tile stacks its per-row `left_i` / `right_i` factors (each `d×k`) into
11/// two `(Σ_i d_i × k)` matrices and tries a single per-ordinal `AᵀB` device
12/// GEMM (`gam_gpu::try_fast_atb_on_ordinal`), which runs on the device this
13/// worker thread already bound — one big GPU GEMM per tile rather than `n` small
14/// CPU ones. When the device primitive declines (no GPU, shape below policy,
15/// transient failure) the tile reduces with the exact CPU `block_gemm_subtract`
16/// loop, so the result is unchanged. The partial is negated so the caller's
17/// `schur += partial` reproduces the serial `schur -= Σ contribution`.
18pub(crate) fn tile_schur_partial<B: BatchedBlockSolver>(
19    sys: &ArrowSchurSystem,
20    htt_factors: &ArrowFactorSlab,
21    backend: &B,
22    kind: SchurReductionKind,
23    ordinal: usize,
24    range: Range<usize>,
25) -> Result<Array2<f64>, ArrowSchurError> {
26    let k = sys.k;
27
28    // Build the per-row contribution factors once; both the GPU stacked-GEMM
29    // and the CPU fallback consume them.
30    let mut factors: Vec<(Array2<f64>, Array2<f64>)> = Vec::with_capacity(range.len());
31    let mut total_d = 0usize;
32    for i in range.clone() {
33        let (left, right) = row_schur_contribution_factors(
34            sys,
35            i,
36            &sys.rows[i],
37            htt_factors.factor(i),
38            backend,
39            kind,
40        )?;
41        total_d += left.nrows();
42        factors.push((left, right));
43    }
44
45    // Stack into (total_d × k) left/right matrices for one device AᵀB GEMM on
46    // this tile's bound ordinal. `try_fast_atb_on_ordinal` returns leftᵀ·right
47    // (k×k); negate into the partial. At an SAE-shaped whole-fit tile with
48    // n=2000 rows, k=2048 shared columns, M=12 local rows per observation, and
49    // K=8 candidate/atom batches, the stacked GEMM is
50    // 2*(n*M)*k^2 = 201_326_592_000 flops per batch, or
51    // 1_610_612_736_000 flops across K=8, so the policy work gate is cleared
52    // even though the observation count is far below the old row floor.
53    if total_d > 0 && k > 0 {
54        let mut left_stack = Array2::<f64>::zeros((total_d, k));
55        let mut right_stack = Array2::<f64>::zeros((total_d, k));
56        let mut base = 0usize;
57        for (left, right) in &factors {
58            let di = left.nrows();
59            left_stack
60                .slice_mut(ndarray::s![base..base + di, ..])
61                .assign(left);
62            right_stack
63                .slice_mut(ndarray::s![base..base + di, ..])
64                .assign(right);
65            base += di;
66        }
67        if let Some(product) =
68            gam_gpu::try_fast_atb_on_ordinal(ordinal, left_stack.view(), right_stack.view())
69        {
70            return Ok(product.mapv(|v| -v));
71        }
72    }
73
74    // CPU fallback: exact per-row block_gemm_subtract into a zero-seeded partial.
75    let mut partial = Array2::<f64>::zeros((k, k));
76    for (left, right) in &factors {
77        backend.block_gemm_subtract(&mut partial, left, right);
78    }
79    Ok(partial)
80}
81
82/// Reduce the per-row Schur contributions `Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`
83/// out of `schur` (seeded with `H_ββ + ρ_β·I`).
84///
85/// The per-row contributions are independent — exactly the "sum over independent
86/// arrow-tip blocks" axis the device pool partitions. When more than one GPU is
87/// usable, [`gam_gpu::pool::balanced_partition`] splits the `0..n` rows into
88/// per-device contiguous tiles; each tile is reduced on its own scoped thread
89/// (binding that ordinal's context so the per-row GEMM-subtract offloads to its
90/// device) into a private `k×k` partial, and the partials are summed back into
91/// `schur` in tile order. The tiles are contiguous, ordered to cover `0..n`, and
92/// folded back in that same order, so within each tile the per-row accumulation
93/// order is preserved and the only departure from the serial loop is the
94/// inter-tile reassociation of the reduction sum — the established
95/// reduction-order equivalence the device pool already operates under, well
96/// inside the Newton solve's tolerance.
97///
98/// With a single device (or no GPU) the row loop runs serially in place, which
99/// is bit-for-bit the original behaviour.
100pub(crate) fn reduce_row_schur_contributions<B: BatchedBlockSolver + Sync>(
101    sys: &ArrowSchurSystem,
102    htt_factors: &ArrowFactorSlab,
103    backend: &B,
104    kind: SchurReductionKind,
105    schur: &mut Array2<f64>,
106) -> Result<(), ArrowSchurError> {
107    let n = sys.rows.len();
108    let k = sys.k;
109
110    let tiles = gam_gpu::device_runtime::GpuRuntime::global()
111        .map(|rt| gam_gpu::pool::balanced_partition(rt, n))
112        .filter(|tiles| tiles.len() > 1);
113
114    let Some(tiles) = tiles else {
115        // Single-device / CPU. The per-row contributions `-Σ_i leftᵀ·right` fold
116        // into the `k×k` `schur` independently — the same dense-assembly axis the
117        // multi-GPU tile path partitions, and the dense-Direct analog of the
118        // per-row matvec / streaming `accumulate_chunk` loops already parallelized
119        // for #1017. At the SAE Direct-solve shape (`n` in the thousands, wide
120        // border `k`) this O(n·d·k²) reduction is the dense assembly's whole cost
121        // and was the last serial CPU step on the dense-Schur build.
122        //
123        // Fan it across rayon over fixed row chunks: each chunk reduces its rows
124        // (in row order) into a private zero-seeded `k×k` partial, then the
125        // partials are folded into `schur` in CHUNK order. The per-chunk row order
126        // and the inter-chunk fold order are both fixed independent of thread
127        // scheduling, so the f64 reduction is **bit-identical run-to-run** (the
128        // #1017 determinism gate). NOTE: bit-identical run-to-run does NOT make
129        // it bit-identical to the in-place serial loop — the chunk-boundary
130        // reassociation of the reduction sum is a genuine f64 departure (the
131        // established equivalence `accumulate_chunk` / the per-row matvec operate
132        // under, well inside the Newton solve's tolerance). It bounds candidate-
133        // to-candidate drift to that reassociation margin, so the criterion
134        // ranking is stable EXCEPT for candidates tying within the margin, where
135        // the winner can flip; it is not an exact no-move guarantee (#1211). For
136        // an exact-order guarantee, take the serial path. Stay in-place serial
137        // below the row floor and when already inside a rayon worker (the topology
138        // race fans candidates with `run_topology_race_parallel`) to avoid
139        // nested-rayon oversubscription — the same guard the matvec uses.
140        let n_rows = sys.rows.len();
141        let parallel =
142            n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
143        if parallel {
144            use rayon::prelude::*;
145            const CHUNK: usize = 64;
146            let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = (0..n_rows)
147                .into_par_iter()
148                .chunks(CHUNK)
149                .map(|idxs| {
150                    let mut partial = Array2::<f64>::zeros((k, k));
151                    for i in idxs {
152                        subtract_row_schur_contribution(
153                            sys,
154                            i,
155                            &sys.rows[i],
156                            htt_factors.factor(i),
157                            backend,
158                            kind,
159                            &mut partial,
160                        )?;
161                    }
162                    Ok(partial)
163                })
164                .collect();
165            // Deterministic ordered fold: chunk partials hold `-Σ contribution`
166            // over their rows, so `schur += partial` reproduces the serial
167            // `schur -= Σ contribution` in fixed (chunk, a, b) order.
168            for partial in &partials? {
169                for a in 0..k {
170                    for b in 0..k {
171                        schur[[a, b]] += partial[[a, b]];
172                    }
173                }
174            }
175            return Ok(());
176        }
177        // Serial in-place reduction (original order) — bit-for-bit reference.
178        for (i, row) in sys.rows.iter().enumerate() {
179            subtract_row_schur_contribution(
180                sys,
181                i,
182                row,
183                htt_factors.factor(i),
184                backend,
185                kind,
186                schur,
187            )?;
188        }
189        return Ok(());
190    };
191
192    // Multi-GPU: one private `-Σ leftᵀ·right` partial per contiguous device
193    // tile. Each tile runs on its own scoped worker thread that binds its
194    // ordinal's context and issues a single stacked AᵀB GEMM on that device, so
195    // the tiles' GEMMs overlap across the pool. Folding the partials back into
196    // the H_ββ-seeded `schur` reproduces the serial reduction (up to inter-tile
197    // reassociation).
198    let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = std::thread::scope(|scope| {
199        let handles: Vec<_> = tiles
200            .iter()
201            .map(|(ordinal, range)| {
202                let ordinal = *ordinal;
203                let range = range.clone();
204                scope.spawn(move || {
205                    // Bind this ordinal's CUDA context on this worker thread so
206                    // the per-row GPU GEMM shims issued from `tile_schur_partial`
207                    // offload to that device. A missing context or bind failure
208                    // is intentionally consumed without escalation — the shims
209                    // no-op back to CPU and the math is unchanged. Off Linux
210                    // `GpuRuntime::global()` is always `None`, so this branch
211                    // is unreachable and the bind is omitted entirely.
212                    #[cfg(target_os = "linux")]
213                    {
214                        if let Some(ctx) = gam_gpu::device_runtime::cuda_context_for(ordinal) {
215                            if ctx.bind_to_thread().is_err() {
216                                // Fall through: this tile reduces on the CPU.
217                            }
218                        }
219                    }
220                    tile_schur_partial(sys, htt_factors, backend, kind, ordinal, range)
221                })
222            })
223            .collect();
224        handles
225            .into_iter()
226            .map(|handle| {
227                handle
228                    .join()
229                    .map_err(|_| ArrowSchurError::SchurFactorFailed {
230                        reason: "schur-reduction tile thread panicked".to_string(),
231                    })?
232            })
233            .collect()
234    });
235    let partials = partials?;
236
237    // Fold partials into `schur` in tile order (contiguous, covering 0..n) so
238    // the per-tile and inter-tile accumulation order is the row order; each
239    // partial holds `-Σ contribution` over its rows, so `schur += partial`
240    // reproduces `schur -= Σ contribution`.
241    for partial in &partials {
242        for a in 0..k {
243            for b in 0..k {
244                schur[[a, b]] += partial[[a, b]];
245            }
246        }
247    }
248    Ok(())
249}
250
251pub(crate) fn build_dense_schur_direct<B: BatchedBlockSolver + Sync>(
252    sys: &ArrowSchurSystem,
253    htt_factors: &ArrowFactorSlab,
254    ridge_beta: f64,
255    backend: &B,
256) -> Result<Array2<f64>, ArrowSchurError> {
257    let k = sys.k;
258    // Materialise H_ββ via the BetaPenaltyOp trait (#296): DensePenaltyOp
259    // for the legacy dense path, structured ops for SAE / Kronecker smooths.
260    let op = sys.effective_penalty_op();
261    if op.dim() != k {
262        return Err(ArrowSchurError::SchurFactorFailed {
263            reason: "Direct BA requires a K×K shared H_ββ penalty operator".to_string(),
264        });
265    }
266    let mut schur = op.to_dense();
267    for j in 0..k {
268        schur[[j, j]] += ridge_beta;
269    }
270    reduce_row_schur_contributions(
271        sys,
272        htt_factors,
273        backend,
274        SchurReductionKind::Direct,
275        &mut schur,
276    )?;
277    symmetrize_upper_from_lower(&mut schur);
278    Ok(schur)
279}
280
281pub(crate) fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver + Sync>(
282    sys: &ArrowSchurSystem,
283    htt_factors: &ArrowFactorSlab,
284    ridge_beta: f64,
285    backend: &B,
286) -> Result<Array2<f64>, ArrowSchurError> {
287    let k = sys.k;
288    // Materialise H_ββ via the BetaPenaltyOp trait (#296).
289    let op = sys.effective_penalty_op();
290    if op.dim() != k {
291        return Err(ArrowSchurError::SchurFactorFailed {
292            reason: "Square-Root BA direct solve requires a K×K shared H_ββ penalty operator"
293                .to_string(),
294        });
295    }
296    let mut schur = op.to_dense();
297    for j in 0..k {
298        schur[[j, j]] += ridge_beta;
299    }
300    reduce_row_schur_contributions(
301        sys,
302        htt_factors,
303        backend,
304        SchurReductionKind::SqrtBa,
305        &mut schur,
306    )?;
307    symmetrize_upper_from_lower(&mut schur);
308    Ok(schur)
309}
310
311/// Certified Carson–Higham mixed-precision solve of the reduced dense Schur
312/// system `S Δβ = rhs` (#1014), specialized to the streaming/residency path.
313///
314/// Returns `Some(Δβ)` when certified mixed precision is enabled AND the κ gate
315/// admits the f32 factorization AND the f64 backward-error certificate closes;
316/// `None` in every other case so the caller falls back to the exact f64
317/// triangular solve. The f64 `factor` (whose diagonal carries the exact
318/// `log|S|`) is supplied by the caller and never re-derived here — the logdet
319/// the evidence path reads stays f64 by construction.
320///
321/// Method: store the f64 Cholesky factor as f32, solve in f32, then refine with
322/// residuals `r = rhs − S·x` computed in f64 against the f64 `S`. With
323/// `κ(S)·u_f32 < margin` the refinement contracts at rate `κ·u`, and the
324/// terminating certificate is the normwise backward error
325/// `‖r‖∞ / (‖S‖∞‖x‖∞ + ‖rhs‖∞) ≤ tol`. A non-decreasing residual or an
326/// unmet certificate after `max_refinement_steps` returns `None`.
327pub(crate) fn mixed_precision_reduced_beta(
328    schur: &Array2<f64>,
329    factor: &Array2<f64>,
330    rhs: &Array1<f64>,
331    options: &ArrowSolveOptions,
332) -> Option<Array1<f64>> {
333    let ArrowSolvePrecisionPolicy::CertifiedMixed {
334        max_refinement_steps,
335        residual_relative_tolerance,
336        kappa_unit_roundoff_margin,
337    } = options.solve_precision
338    else {
339        return None;
340    };
341    // The reduced-system mixed-precision path is the dense reduced solve only;
342    // a trust-region-truncated step takes the Steihaug branch below in f64.
343    if options.trust_region.radius.is_finite() {
344        return None;
345    }
346    let n = schur.nrows();
347    if n == 0 {
348        return None;
349    }
350
351    // κ gate: the f32 factorization is only admissible when κ(S)·u_f32 leaves
352    // the refinement contraction headroom the certificate needs.
353    let kappa = cholesky_factor_kappa_estimate(factor);
354    if !kappa.is_finite() || kappa * F32_UNIT_ROUNDOFF >= kappa_unit_roundoff_margin {
355        return None;
356    }
357
358    let factor_f32 = factor.mapv(|v| v as f32);
359    let s_inf = matrix_inf_norm(schur);
360    let rhs_inf = rhs.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
361    let certificate_tol = residual_relative_tolerance
362        .max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
363
364    // f32 solve of the seed system, then f64-residual refinement steps.
365    let mut x = cholesky_solve_lower_f32(&factor_f32, &rhs.mapv(|v| v as f32)).mapv(|v| v as f64);
366    let mut last_residual = f64::INFINITY;
367    for _ in 0..=max_refinement_steps {
368        // Residual r = rhs − S·x in f64 against the f64 model.
369        let sx = schur.dot(&x);
370        let mut r = rhs.clone();
371        r -= &sx;
372        let r_inf = r.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
373        let x_inf = x.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
374        let denom = s_inf * x_inf + rhs_inf;
375        let backward_error = if denom > 0.0 { r_inf / denom } else { 0.0 };
376        if backward_error <= certificate_tol {
377            return Some(x);
378        }
379        // Refinement must make monotone progress, else hand back to f64.
380        if !(r_inf < last_residual) {
381            return None;
382        }
383        last_residual = r_inf;
384        // Correction solve in f32 against the f32 factor: S·δ = r.
385        let delta = cholesky_solve_lower_f32(&factor_f32, &r.mapv(|v| v as f32)).mapv(|v| v as f64);
386        x += &delta;
387    }
388    None
389}
390
391/// Infinity norm (max absolute row sum) of a dense matrix.
392pub(crate) fn matrix_inf_norm(a: &Array2<f64>) -> f64 {
393    let mut max_row = 0.0_f64;
394    for row in a.rows() {
395        let s: f64 = row.iter().map(|v| v.abs()).sum();
396        if s > max_row {
397            max_row = s;
398        }
399    }
400    max_row
401}
402
403/// Spectral positive-definiteness floor for the reduced Schur complement
404/// `S` (#1026 SAE co-collapse SOLVE-path cure).
405///
406/// Reached only after the genuine Cholesky of `S` has REFUSED it (an indefinite
407/// reduced Schur: collapsed atoms drive a per-row `H_tt` near-singular, so the
408/// accumulated `Σ_i H_tβᵀ (H_tt)⁻¹ H_tβ` over-subtracts `H_ββ + ridge_β·I` into a
409/// matrix with a non-positive eigenvalue). Rather than reject and let the LM
410/// loop inflate `ridge_β` over EVERY β direction (the #1026 "crawl"), we
411/// symmetric-eigendecompose `S` and clamp every eigenvalue UP to
412/// `floor·max(λ)`. This is Levenberg–Marquardt restricted to exactly the
413/// indefinite/collapsed subspace: a well-separated positive direction
414/// (`λ ≫ floor·max λ`) keeps its EXACT eigenvalue (`λ.max(floor·max λ) = λ`), so
415/// the Newton step in the healthy β subspace is unchanged, while only the
416/// collapsed directions get the minimal positive stiffness needed for a PD
417/// solve. Returns the floored, symmetric, strictly-PD matrix, or `None` if `S`
418/// has no usable scale (non-finite / all-zero spectrum), in which case the
419/// caller keeps the strict refusal.
420///
421/// Mirrors the per-row evidence floor
422/// [`super::factorization::factor_spectral_deflated_evidence_row`]; the only
423/// difference is the floored VALUE — a small positive `floor·max λ` (Tikhonov,
424/// for an accurate solve) here, vs unit stiffness `+1` (`log 1 = 0`) there (for
425/// the quotient log-det).
426pub(crate) fn spectral_pd_floored_schur(
427    schur: &Array2<f64>,
428    relative_floor: f64,
429) -> Option<Array2<f64>> {
430    let n = schur.nrows();
431    if n == 0 || schur.ncols() != n || !(relative_floor.is_finite() && relative_floor > 0.0) {
432        return None;
433    }
434    // Symmetrise defensively (the assembled Schur is symmetric up to reduction
435    // order; the eig routine assumes exact symmetry).
436    let mut sym = Array2::<f64>::zeros((n, n));
437    for i in 0..n {
438        for j in 0..n {
439            let v = 0.5 * (schur[[i, j]] + schur[[j, i]]);
440            if !v.is_finite() {
441                return None;
442            }
443            sym[[i, j]] = v;
444        }
445    }
446    let (evals, evecs) = sym.eigh(Side::Lower).ok()?;
447    let max_abs = evals.iter().fold(
448        0.0_f64,
449        |acc, &v| if v.is_finite() { acc.max(v.abs()) } else { acc },
450    );
451    if !(max_abs.is_finite() && max_abs > 0.0) {
452        return None;
453    }
454    let floor = relative_floor * max_abs;
455    // Reconstruct `Σ_i max(λ_i, floor) v_i v_iᵀ`: clamp every eigenvalue UP to a
456    // strictly positive `floor`. Healthy positive directions (`λ ≫ floor`) are
457    // untouched; non-positive / tiny collapsed directions are lifted to exactly
458    // `floor`. The result is symmetric PD by construction.
459    let mut conditioned = Array2::<f64>::zeros((n, n));
460    for eig_idx in 0..evals.len() {
461        let lambda = evals[eig_idx];
462        let lambda_floored = if lambda.is_finite() {
463            lambda.max(floor)
464        } else {
465            floor
466        };
467        for i in 0..n {
468            let vi = evecs[[i, eig_idx]];
469            if vi == 0.0 {
470                continue;
471            }
472            for j in 0..n {
473                conditioned[[i, j]] += lambda_floored * vi * evecs[[j, eig_idx]];
474            }
475        }
476    }
477    Some(conditioned)
478}
479
480pub(crate) fn solve_dense_reduced_system(
481    schur: &Array2<f64>,
482    rhs_beta: &Array1<f64>,
483    options: &ArrowSolveOptions,
484    metric_weights: Option<&MetricWeights>,
485) -> Result<(Array1<f64>, Option<Array2<f64>>, PcgDiagnostics), ArrowSchurError> {
486    let factor = match cholesky_lower(schur) {
487        Ok(factor) => factor,
488        Err(e) => {
489            // #1026 — opt-in spectral PD-floor on the indefinite reduced Schur.
490            // When enabled (SAE solve path), condition ONLY the collapsed
491            // directions and re-factor, instead of erroring out and letting the
492            // outer LM loop inflate `ridge_β` over every β direction (the
493            // co-collapse "crawl"). Disabled (default `None`) keeps the strict
494            // refusal so BA / non-SAE callers are bit-for-bit unchanged.
495            match options.schur_pd_floor {
496                Some(relative_floor) => match spectral_pd_floored_schur(schur, relative_floor) {
497                    Some(floored) => match cholesky_lower(&floored) {
498                        Ok(factor) => {
499                            // Solve against the floored (PD) Schur. The healthy β
500                            // subspace keeps its exact eigenvalues, so its Δβ is
501                            // the exact Newton component; only the collapsed
502                            // subspace is minimally damped.
503                            let direct =
504                                mixed_precision_reduced_beta(&floored, &factor, rhs_beta, options)
505                                    .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
506                            if step_inside_trust_region(
507                                direct.view(),
508                                options.trust_region.radius,
509                                metric_weights,
510                            ) {
511                                return Ok((direct, Some(factor), PcgDiagnostics::default()));
512                            }
513                            let identity = IdentityPreconditioner;
514                            let (delta, diag) = steihaug_dense_system(
515                                &floored,
516                                rhs_beta,
517                                &identity,
518                                &ArrowPcgOptions {
519                                    max_iterations: options.trust_region.max_iterations,
520                                    relative_tolerance: options
521                                        .trust_region
522                                        .steihaug_relative_tolerance,
523                                },
524                                &options.trust_region,
525                                metric_weights,
526                            )?;
527                            return Ok((delta, Some(factor), diag));
528                        }
529                        Err(floored_err) => {
530                            return Err(ArrowSchurError::SchurFactorFailed {
531                                reason: format!(
532                                    "reduced Schur non-PD ({e}); spectral PD-floor \
533                                     reconstruction still non-PD: {floored_err}"
534                                ),
535                            });
536                        }
537                    },
538                    None => {
539                        return Err(ArrowSchurError::SchurFactorFailed {
540                            reason: format!(
541                                "reduced Schur non-PD ({e}); spectral PD-floor declined \
542                                 (no usable spectrum)"
543                            ),
544                        });
545                    }
546                },
547                None => return Err(ArrowSchurError::SchurFactorFailed { reason: e }),
548            }
549        }
550    };
551    // Ill-conditioned-but-PD Schur guard. The per-row factor checks reject
552    // any single barely-PD H_tt^(i) block, but the reduced Schur complement
553    //     S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)
554    // accumulates the (H_tt^(i))⁻¹ contributions of every row in finite
555    // precision. With many weak-but-admissible rows those terms can sum to a
556    // Schur matrix whose Cholesky succeeds yet whose condition number is far
557    // past the safe inversion regime, so `cholesky_solve_vector` yields an
558    // inaccurate Δβ that is silently propagated to the Newton step. Apply the
559    // same diagonal-ratio κ proxy used per-row to the reduced factor and treat
560    // an over-threshold estimate as a Schur-stability failure: `SchurFactorFailed`
561    // is already recoverable in `solve_with_lm_escalation_inner`, so this lifts
562    // `ridge_beta` and re-forms a better-conditioned Schur. This guard is
563    // exclusive to the dense Direct / SqrtBA path (the only caller of this
564    // function); the inexact-PCG path tolerates higher κ(S) and is unaffected.
565    // Evidence/log-det-only callers (`tolerate_ill_conditioning`) skip this
566    // rejection: the factor is genuinely PD (Cholesky above succeeded), so its
567    // diagonal still yields an exact `log|S|`, and an inaccurate Δβ is harmless
568    // because the step is discarded.
569    if !options.tolerate_ill_conditioning {
570        let schur_kappa = cholesky_factor_kappa_estimate(&factor);
571        if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
572            // #1026 — over-complete SAE dictionaries park surplus atoms dead
573            // (β_k → 0), so the reduced Schur is PD (the Cholesky above succeeded)
574            // but ILL-CONDITIONED: the dead decoder subspace carries near-zero
575            // eigenvalues while the live subspace is healthy. The kappa gate's
576            // concern is an inaccurate Δβ from accumulated (H_tt)⁻¹ contamination —
577            // but on the dead subspace the correct Δβ IS ≈0 (those atoms have no
578            // signal), so the only "inaccuracy" is in directions whose true step is
579            // zero. When the spectral PD-floor is enabled (the SAE solve path),
580            // clamp exactly those collapsed directions up to `floor·max(λ)` and
581            // solve against the floored Schur: the live subspace keeps its EXACT
582            // Newton component, the dead subspace is damped to ≈0, and κ is bounded
583            // so Δβ is accurate where it matters. This is the same conditioning the
584            // non-PD branch above applies; here it also covers the PD-but-ill-
585            // conditioned case so the LM loop does not exhaust `ridge_β` trying to
586            // (futilely) lift a fundamentally rank-deficient dead-atom subspace.
587            // Without the floor (BA / non-SAE callers) the strict refusal stands.
588            if let Some(relative_floor) = options.schur_pd_floor
589                && let Some(floored) = spectral_pd_floored_schur(schur, relative_floor)
590                && let Ok(floored_factor) = cholesky_lower(&floored)
591            {
592                let direct =
593                    mixed_precision_reduced_beta(&floored, &floored_factor, rhs_beta, options)
594                        .unwrap_or_else(|| cholesky_solve_vector(&floored_factor, rhs_beta));
595                if step_inside_trust_region(
596                    direct.view(),
597                    options.trust_region.radius,
598                    metric_weights,
599                ) {
600                    return Ok((direct, Some(floored_factor), PcgDiagnostics::default()));
601                }
602                let identity = IdentityPreconditioner;
603                let (delta, diag) = steihaug_dense_system(
604                    &floored,
605                    rhs_beta,
606                    &identity,
607                    &ArrowPcgOptions {
608                        max_iterations: options.trust_region.max_iterations,
609                        relative_tolerance: options.trust_region.steihaug_relative_tolerance,
610                    },
611                    &options.trust_region,
612                    metric_weights,
613                )?;
614                return Ok((delta, Some(floored_factor), diag));
615            }
616            return Err(ArrowSchurError::SchurFactorFailed {
617                reason: format!(
618                    "reduced Schur complement Cholesky succeeded but is ill-conditioned \
619                     (kappa_estimate={schur_kappa:e}); accumulated per-row \
620                     (H_tt)⁻¹ contamination would yield an inaccurate Δβ"
621                ),
622            });
623        }
624    }
625    // Reduced-system solve. The f64 `factor` is always retained and returned —
626    // its diagonal is the EXACT `log|S|` the evidence path reads, so the logdet
627    // stays f64 regardless of how Δβ is computed (#1014 invariant). When the
628    // streaming/residency path enabled certified mixed precision, the Δβ solve
629    // itself runs f32-then-f64-refined (κ-gated, with the f64 triangular solve
630    // as the automatic fallback); the certificate is the f64 backward error.
631    let direct = mixed_precision_reduced_beta(schur, &factor, rhs_beta, options)
632        .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
633    if step_inside_trust_region(direct.view(), options.trust_region.radius, metric_weights) {
634        return Ok((direct, Some(factor), PcgDiagnostics::default()));
635    }
636
637    // Ceres-style trust-region correction: once the dense BA solve proposes a
638    // step outside the trust ball, Steihaug-CG returns the boundary point
639    // without requiring a second dense factorization.
640    let identity = IdentityPreconditioner;
641    let (delta, diag) = steihaug_dense_system(
642        schur,
643        rhs_beta,
644        &identity,
645        &ArrowPcgOptions {
646            max_iterations: options.trust_region.max_iterations,
647            relative_tolerance: options.trust_region.steihaug_relative_tolerance,
648        },
649        &options.trust_region,
650        metric_weights,
651    )?;
652    Ok((delta, Some(factor), diag))
653}
654
655/// Solve an externally accumulated dense reduced β system
656/// `S Δβ = rhs_β` with the same LM-style ridge escalation the full-batch
657/// driver applies: on a `SchurFactorFailed` (non-PD or ill-conditioned `S`),
658/// geometrically grow a proximal ridge on `S`'s diagonal and retry.
659///
660/// Used by the SAE streaming joint fit, which accumulates `S` and `rhs_β` over
661/// re-materialized row chunks (via [`StreamingArrowSchur::take_accumulators`])
662/// and must solve the single global reduced system without a per-row
663/// `ArrowSchurSystem`. `S` is symmetrized from its lower triangle before each
664/// factorization. `base_ridge_beta` is folded into the caller's `S` already;
665/// this routine only adds the *escalation* ridge on top.
666pub fn solve_streaming_reduced_beta(
667    s_acc: &Array2<f64>,
668    rhs_beta: &Array1<f64>,
669    options: &ArrowSolveOptions,
670) -> Result<Array1<f64>, ArrowSchurError> {
671    let mut proximal_ridge = 0.0_f64;
672    let mut last_err: Option<ArrowSchurError> = None;
673    for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
674        let mut schur = s_acc.clone();
675        symmetrize_upper_from_lower(&mut schur);
676        if proximal_ridge > 0.0 {
677            for j in 0..schur.nrows() {
678                schur[[j, j]] += proximal_ridge;
679            }
680        }
681        // Reduced K-system on device: Jacobi-preconditioned CG over the dense
682        // symmetric `S`. The `O(K²)` `S·p` matvec runs device-side; only the
683        // K-vectors cross the boundary per CG iteration. This is the dominant
684        // cost of the streaming SAE joint fit at `K = 100K`. Any device-side
685        // failure (`Unavailable`, non-PD Jacobi diagonal) falls through to the
686        // CPU `solve_dense_reduced_system`, which then drives the same proximal
687        // ridge escalation. A genuine device PD failure is non-recoverable for
688        // this attempt's `schur`, so we let the CPU path re-confirm and escalate.
689        if gam_gpu::device_runtime::GpuRuntime::is_available() {
690            match crate::gpu_kernels::arrow_schur::solve_reduced_beta_pcg(
691                &schur,
692                rhs_beta,
693                options.trust_region.max_iterations,
694                options.trust_region.steihaug_relative_tolerance,
695            ) {
696                Ok(delta_beta) => return Ok(delta_beta),
697                Err(crate::gpu_kernels::arrow_schur::ArrowSchurGpuFailure::Unavailable) => {}
698                Err(_) => {
699                    // Device declined this `schur` (e.g. non-PD Jacobi diag);
700                    // let the CPU path confirm and escalate the proximal ridge.
701                }
702            }
703        }
704        match solve_dense_reduced_system(&schur, rhs_beta, options, None) {
705            Ok((delta_beta, _factor, _diag)) => return Ok(delta_beta),
706            Err(err) => {
707                let recoverable = matches!(
708                    err,
709                    ArrowSchurError::SchurFactorFailed { .. }
710                        | ArrowSchurError::PcgFailed { .. }
711                        | ArrowSchurError::UnboundedNegativeCurvature { .. }
712                );
713                last_err = Some(err);
714                if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
715                    break;
716                }
717                proximal_ridge = if proximal_ridge == 0.0 {
718                    DEFAULT_PROXIMAL_INITIAL_RIDGE
719                } else {
720                    proximal_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
721                };
722            }
723        }
724    }
725    Err(last_err.expect("escalation loop set last_err on failure"))
726}
727
728pub(crate) fn step_inside_trust_region(
729    step: ArrayView1<'_, f64>,
730    radius: f64,
731    metric_weights: Option<&MetricWeights>,
732) -> bool {
733    !radius.is_finite() || metric_norm(step, metric_weights) <= radius
734}
735
/// Minimum row count `n` at which the per-row reduced-Schur loops fan out over
/// rayon.
///
/// For fewer rows, the loop stays sequential. Those are the handful-of-rows
/// arrow systems most non-SAE callers build, where chunk dispatch plus the
/// deterministic per-chunk length-`K` reduction cost more than the parallelism
/// recovers. At the SAE LLM shape issue #1017 targets (`n` in the thousands,
/// wide border `k`), the per-row `H_βt (H_tt)⁻¹ H_tβ x` contributions are the
/// entire matvec cost and split cleanly across rows.
pub(crate) const SCHUR_MATVEC_PARALLEL_ROW_MIN: usize = 256;
743
/// Minimum border width `k` at which the dense `H_ββ` penalty-prologue GEMV is
/// parallelized.
///
/// Fanning out a `k×k` matvec only pays off once `k²` dwarfs the rayon
/// dispatch cost, which never happens for the narrow-border arrow callers. At
/// the SAE LLM border (`k` in the low thousands), the `O(k²)` prologue is about
/// 4M flops per CG iteration. It was the serial Amdahl ceiling on the otherwise
/// per-row-parallel matvec (#1017), so it clears this bar and fans out. The
/// value 512 keeps the prologue serial for every non-SAE arrow system while
/// engaging it for the wide SAE/Qwen borders.
pub(crate) const SCHUR_PROLOGUE_PARALLEL_K_MIN: usize = 512;
753
754/// Device-residency CPU analogue for the SAE reduced-Schur matvec (#1017).
755///
756/// In the production SAE joint fit the per-row cross-block factors as
757/// `H_tβ^(i) = L_i P_i`, where `L_i` (`q_i × p`) is the row's local
758/// assignment/coordinate Jacobian and `P_i` (`p × K`, sparse) gathers the
759/// active atoms' decoder blocks (`P_i x = Σ_s φ_s · x[base_s .. base_s+p]`).
760/// The reduced-Schur point-elimination contribution of one row is therefore
761///
762/// ```text
763/// S_i x = H_βt^(i) (H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i) x
764///       = P_iᵀ · [ L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i ] · P_i x
765///       = P_iᵀ G_i (P_i x),      G_i := L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i   (p×p).
766/// ```
767///
768/// The block `G_i = L_iᵀ Y_i` depends only on the assembled per-row blocks and
769/// the (already-computed, solve-stable) `H_tt` factor — NOT on the CG iterate
770/// `x`. The generic [`schur_matvec`] re-walks `apply_jbeta → apply_l →
771/// solve(d×d) → apply_l_t → scatter` on every CG iteration; this object **stages
772/// the factors `(L_i, Y_i)` once per CG solve** (the "upload X once" residency
773/// mechanism, applied on CPU to the matvec rather than a dense factorization),
774/// turning each subsequent matvec into a sparse gather → two `di×p` GEMVs →
775/// sparse scatter, with no per-iteration triangular solve and no operator-closure
776/// re-walk. It never materialises the dense `p×p` product: `di ≪ p` for SAE
777/// rows, so the factored apply is `2·support_i·p + 2·di·p` flops/row — the two
778/// `di·p` GEMVs PLUS the `support_i·p` sparse gather (`P_i x`) and `support_i·p`
779/// sparse scatter (`P_iᵀ prod`) — versus the dense `p²` block apply, and
780/// `O(n·di·p)` memory (vs `O(n·p²)` ≈ 67 GB at the Qwen shape — the dense form
781/// is OOM). For dense/full active support `support_i` can scale with the active
782/// β-columns, so the gather/scatter term is NOT negligible and is counted here.
783///
784/// Numerically identical to the generic path up to floating-point reassociation
785/// (it differentiates and accumulates the SAME quotient). It is deterministic
786/// run-to-run and within the reassociation margin of the serial path, so the
787/// criterion ranking across topology candidates is stable except for candidates
788/// separated by less than that f64 margin, where reassociation can flip the
789/// near-tie winner — it is NOT an exact no-move guarantee (#1211).
790pub(crate) struct SaeResidentReducedSchur {
791    /// Decoder output dimension `p` (the side length of every `G_i = L_iᵀ Y_i`).
792    pub(crate) p: usize,
793    /// Per-row **factored** residency: `(L_i, Y_i)`, each stored row-major as a
794    /// `di × p` slab (`L_i` = local Jacobian, `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i`).
795    /// The reduced block is `G_i = L_iᵀ Y_i` (`p×p`, symmetric PSD), but it has
796    /// rank ≤ `di` and `di ≪ p` for SAE rows (the per-row latent dim is 1–2
797    /// while `p` is the decoder block width, ~2048). Materialising the dense
798    /// `p×p` block would cost `O(n·p²)` memory (≈67 GB at the Qwen shape) and
799    /// `p²` flops per matvec/row; the factored form costs `O(n·di·p)` memory and
800    /// `2·support_i·p + 2·di·p` flops/row, applying `G_i v = L_iᵀ (Y_i v)`
801    /// (sparse gather over `support_i` atoms → `di`-length GEMV → `p`-length
802    /// GEMV → sparse scatter over `support_i` atoms). The `2·support_i·p`
803    /// gather/scatter term is part of the per-row cost — for dense/full support
804    /// `support_i` scales with active β-columns — and is not dropped. A row with
805    /// empty active support / degenerate dims gets `di = 0` and is skipped.
806    /// `(di, L_i, Y_i)` per row; `L_i`/`Y_i` are `di·p`-length row-major buffers.
807    pub(crate) rows: Vec<ResidentRowFactor>,
808    /// Per-row active atom support `(β-block base index, φ weight)`, shared with
809    /// the assembler's [`DeviceSaePcgData`] (no re-clone of the index lists).
810    pub(crate) a_phi: Arc<[Vec<(usize, f64)>]>,
811    /// #1033: per-row local Jacobian `L_i` (row-major `di × p`), SHARED via `Arc`
812    /// with the assembler's [`DeviceSaePcgData`] rather than copied into each
813    /// `ResidentRowFactor`. The staged factor previously held its own verbatim
814    /// row-major copy of `data.local_jac[row]` — a second full `O(n·di·p)` slab
815    /// for zero benefit (the bytes and the `di × p` layout are identical). The
816    /// matvec now reads `L_i = &self.local_jac[row]` directly; only the SOLVED
817    /// factor `Y_i = (H_tt+ρI)⁻¹ L_i` (genuinely new data) stays per-row. Reads
818    /// are byte-for-byte the former `rf.l` (same slab, same `r·p + c` indexing),
819    /// so the matvec/preconditioner output is bit-identical.
820    pub(crate) local_jac: Arc<[Vec<f64>]>,
821}
822
/// The per-row state staged by [`SaeResidentReducedSchur`]: the solved half of
/// the factored block `G_i = L_iᵀ Y_i`. Keeping `G_i` factored lets the matvec
/// avoid ever forming the dense `p×p` product.
///
/// Only `Y_i` lives here. The other factor, the local Jacobian `L_i`, is read
/// from the shared [`SaeResidentReducedSchur::local_jac`] slab
/// (`&local_jac[row]`).
pub(crate) struct ResidentRowFactor {
    /// Inner contraction width `di` (the row's latent dimension). A value of
    /// `0` marks a row the matvec skips.
    pub(crate) di: usize,
    /// `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i` stored row-major as `di × p`; empty
    /// whenever `di == 0`.
    pub(crate) y: Vec<f64>,
}
834
835impl SaeResidentReducedSchur {
836    /// Stage the per-row `G_i = L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i` blocks once, from
837    /// the SAE structure (`DeviceSaePcgData`: `p`, per-row `a_phi`, per-row
838    /// row-major `local_jac` = `L_i`) and the already-factored `H_tt` slab.
839    ///
840    /// Returns `None` when the structure does not match (degenerate `p`, row
841    /// count mismatch) so the caller falls back to the generic matvec. Row
842    /// builds are independent and run under the same deterministic rayon
843    /// discipline as the matvec (each `G_i` is self-contained — no cross-row
844    /// reduction — so there is no ordering subtlety).
845    /// `ridge_t` is NOT a parameter: it is already folded into the factored
846    /// blocks `htt_factors` carry (they factor `H_tt^(i) + ridge_t·I` — see
847    /// `factor_blocks`), so solving against the factor yields `(H_tt^(i)+ρ_t I)⁻¹`
848    /// exactly. The residency block is a pure function of the factor and `L_i`.
849    pub(crate) fn build<B: BatchedBlockSolver + Sync>(
850        sys: &ArrowSchurSystem,
851        htt_factors: &ArrowFactorSlab,
852        backend: &B,
853    ) -> Option<Self> {
854        let data = sys.device_sae_pcg.as_ref()?;
855        let p = data.p;
856        let n = sys.rows.len();
857        if p == 0
858            || sys.htbeta_dense_supplement
859            || data.a_phi.len() != n
860            || data.local_jac.len() != n
861        {
862            return None;
863        }
864        let empty = || ResidentRowFactor {
865            di: 0,
866            y: Vec::new(),
867        };
868        let build_row = |row: usize| -> ResidentRowFactor {
869            let di = sys.row_dims[row];
870            let jac = &data.local_jac[row];
871            // q_i = len/p; must match the row's latent dimension di.
872            if p == 0 || jac.len() != di * p || di == 0 {
873                return empty();
874            }
875            // L_i as a (di × p) matrix (row-major in `local_jac`).
876            let l_i = match ArrayView2::from_shape((di, p), jac.as_slice()) {
877                Ok(v) => v.to_owned(),
878                Err(_) => return empty(),
879            };
880            // Solve (H_tt+ρ_t I) Y = L_i for Y (di × p): one batched back-solve
881            // over the p columns against the cached factor. Stage `(L_i, Y_i)`
882            // — NOT the dense `p×p` product `G_i = L_iᵀ Y_i` — so storage and the
883            // matvec stay `O(di·p)` instead of `O(p²)` (`di ≪ p` for SAE rows).
884            let y = backend.solve_block_matrix(htt_factors.factor(row), l_i.view());
885            // Flatten the SOLVED factor to a `di × p` row-major buffer (iteration
886            // over a standard-layout view is row-major regardless of the source
887            // strides, so the hot loop can index `r*p + c` directly). `L_i` is NOT
888            // copied — the matvec reads it from the shared `local_jac` slab (it is
889            // byte-for-byte `data.local_jac[row]`).
890            let y_flat: Vec<f64> = y.iter().copied().collect();
891            ResidentRowFactor { di, y: y_flat }
892        };
893        let rows: Vec<ResidentRowFactor> =
894            if n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
895                use rayon::prelude::*;
896                (0..n).into_par_iter().map(build_row).collect()
897            } else {
898                (0..n).map(build_row).collect()
899            };
900        Some(Self {
901            p,
902            rows,
903            a_phi: data.a_phi_shared(),
904            local_jac: data.local_jac_shared(),
905        })
906    }
907
908    /// Accumulate one row's `S_i x = P_iᵀ G_i (P_i x) = P_iᵀ L_iᵀ Y_i (P_i x)`
909    /// into `acc` (length `K`). `gather`/`prod` are caller-owned length-`p`
910    /// buffers and `w` a caller-owned `≥ max_i di`-length buffer, all reused
911    /// across rows to keep the hot loop allocation-free. The matvec applies the
912    /// factored block in four steps: sparse gather `P_i x = Σ_s φ_s·x[base_s..]`
913    /// (`support_i·p` flops), `w = Y_i·(P_i x)` (`di`-length, `di·p` flops),
914    /// `prod = L_iᵀ·w` (`p`-length, `di·p` flops), and sparse scatter
915    /// `acc += P_iᵀ prod` (`support_i·p` flops) — `2·support_i·p + 2·di·p`
916    /// total, never the dense `p²` product. The gather/scatter `2·support_i·p`
917    /// term is counted: it is not dominated by the GEMVs when the active support
918    /// is wide.
919    #[inline]
920    pub(crate) fn row_into(
921        &self,
922        row: usize,
923        x: &Array1<f64>,
924        acc: &mut Array1<f64>,
925        gather: &mut [f64],
926        prod: &mut [f64],
927        w: &mut [f64],
928    ) {
929        let rf = &self.rows[row];
930        let di = rf.di;
931        if di == 0 {
932            return;
933        }
934        let p = self.p;
935        let support = &self.a_phi[row];
936        if support.is_empty() {
937            return;
938        }
939        // P_i x = Σ_s φ_s · x[base_s .. base_s+p]   (length p).
940        for v in gather.iter_mut() {
941            *v = 0.0;
942        }
943        for &(base, phi) in support {
944            if phi == 0.0 {
945                continue;
946            }
947            for j in 0..p {
948                gather[j] += phi * x[base + j];
949            }
950        }
951        // w = Y_i · (P_i x)   (di × p GEMV → length di).  Y_i row-major di×p.
952        for r in 0..di {
953            let yrow = &rf.y[r * p..r * p + p];
954            let mut s = 0.0_f64;
955            for c in 0..p {
956                s += yrow[c] * gather[c];
957            }
958            w[r] = s;
959        }
960        // prod = L_iᵀ · w   (p × di GEMV → length p).  L_i row-major di×p, so
961        // L_iᵀ[j,r] = L_i[r,j]; accumulate column-by-column over the di rows.
962        // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte the
963        // former per-row `rf.l` copy.
964        let l_i = &self.local_jac[row];
965        for v in prod.iter_mut().take(p) {
966            *v = 0.0;
967        }
968        for r in 0..di {
969            let lrow = &l_i[r * p..r * p + p];
970            let wr = w[r];
971            for j in 0..p {
972                prod[j] += lrow[j] * wr;
973            }
974        }
975        // acc += P_iᵀ prod = scatter φ_s · prod into base_s blocks.
976        for &(base, phi) in support {
977            if phi == 0.0 {
978                continue;
979            }
980            for j in 0..p {
981                acc[base + j] += phi * prod[j];
982            }
983        }
984    }
985
986    /// Max row latent dim `di` across resident rows — the size of the `w`
987    /// scratch the matvec needs for the inner `Y_i·(P_i x)` GEMV.
988    pub(crate) fn max_di(&self) -> usize {
989        self.rows.iter().map(|r| r.di).max().unwrap_or(0)
990    }
991}
992
993/// Reduced-Schur matvec `out = S·x` with an optional pre-staged SAE residency
994/// operator. When `resident` is `Some`, the per-row point-elimination term is
995/// applied through the resident `p×p` blocks (#1017 CPU residency); otherwise it
996/// falls back to the generic per-row `apply → solve → transpose` path. Both
997/// routes accumulate the SAME reduced operator
998/// `S = H_ββ + ρ_β I − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`.
999pub(crate) fn schur_matvec<B: BatchedBlockSolver + Sync>(
1000    sys: &ArrowSchurSystem,
1001    htt_factors: &ArrowFactorSlab,
1002    ridge_beta: f64,
1003    x: &Array1<f64>,
1004    out: &mut Array1<f64>,
1005    backend: &B,
1006    resident: Option<&SaeResidentReducedSchur>,
1007) {
1008    // `steihaug_cg` reuses one output buffer across iterations and requires
1009    // `matvec` to ASSIGN every entry of `out` (the contract `dense_matvec`
1010    // upholds). This routine builds `S·x` purely by accumulation
1011    // (`penalty_matvec_add`, `out[a] += ridge·x`, `out[a] -= neg_contrib`), so it
1012    // MUST clear `out` first. Without this, iteration n>0 returns `S·x` plus the
1013    // previous call's `S·p`, the PCG solves a corrupted reduced system, and the
1014    // resulting Newton step is inconsistent with the assembled gradient
1015    // (g·δ ≈ 0 — a non-descent direction that defeats the line search).
1016    out.fill(0.0);
1017    let k = sys.k;
1018    // Top-level (not nested in a rayon worker) and big enough to amortize the
1019    // fan-out: the single gate that authorizes BOTH the dense penalty-prologue
1020    // GEMV and the per-row point-elimination loop to go parallel. The topology
1021    // race fans candidates with `run_topology_race_parallel`, so inside a worker
1022    // both stay sequential (no nested-rayon oversubscription).
1023    let parallel =
1024        sys.rows.len() >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1025    // Route the penalty-side (H_ββ + ridge·I) x product through the prologue:
1026    // no Arc-clone hot-path cost when penalty_op is None (falls back to hbb
1027    // inline); the dense fallback fans across cores at the wide SAE border (#1017).
1028    {
1029        let x_slice = x.as_slice().expect("x must be contiguous");
1030        let out_slice = out.as_slice_mut().expect("out must be contiguous");
1031        sys.penalty_ridge_prologue_into(x_slice, ridge_beta, out_slice, parallel);
1032    }
1033    // The reduced-Schur point-elimination term: `out -= Σ_i H_βt^(i) (H_tt^(i))⁻¹
1034    // H_tβ^(i) x`. Each row contributes an independent length-`K` vector, so for
1035    // the SAE LLM shape (#1017) this is the matvec's whole cost and is
1036    // embarrassingly parallel. Run it under rayon over fixed row chunks, summing
1037    // the per-chunk partials in chunk order so the f64 reduction is bit-identical
1038    // run-to-run regardless of thread scheduling (the #1017 verification gate).
1039    // This is deterministic and within the chunk-reassociation margin of serial,
1040    // so the criterion ranking is stable except for candidates that tie inside
1041    // that f64 margin — not an exact no-move guarantee (#1211). Stay
1042    // sequential when already inside a rayon worker (the topology race fans
1043    // candidates with `run_topology_race_parallel`) to avoid nested-rayon
1044    // oversubscription — the same guard `HyperOperator::mul_mat` uses. The
1045    // `parallel` gate above authorizes this loop too.
1046    let p = resident.map(|r| r.p).unwrap_or(0);
1047    if parallel {
1048        use rayon::prelude::*;
1049        const CHUNK: usize = 64;
1050        let n = sys.rows.len();
1051        let partials: Vec<Array1<f64>> = (0..n)
1052            .into_par_iter()
1053            .chunks(CHUNK)
1054            .map(|idxs| {
1055                let mut acc = Array1::<f64>::zeros(k);
1056                if let Some(res) = resident {
1057                    // Resident path: each matvec is gather → factored di×p GEMVs
1058                    // → scatter, reading only the pre-staged `(L_i, Y_i)` (no
1059                    // per-iteration solve, no dense p×p block).
1060                    let mut gather = vec![0.0_f64; p];
1061                    let mut prod = vec![0.0_f64; p];
1062                    let mut w = vec![0.0_f64; res.max_di()];
1063                    for i in idxs {
1064                        res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1065                    }
1066                } else {
1067                    let mut local = Array1::<f64>::zeros(sys.d);
1068                    for i in idxs {
1069                        schur_matvec_row_into(
1070                            sys,
1071                            htt_factors,
1072                            x,
1073                            backend,
1074                            i,
1075                            &mut local,
1076                            &mut acc,
1077                        );
1078                    }
1079                }
1080                acc
1081            })
1082            .collect();
1083        // Deterministic ordered reduction: fold chunk partials left-to-right.
1084        for acc in &partials {
1085            for a in 0..k {
1086                out[a] -= acc[a];
1087            }
1088        }
1089    } else if let Some(res) = resident {
1090        let mut acc = Array1::<f64>::zeros(k);
1091        let mut gather = vec![0.0_f64; p];
1092        let mut prod = vec![0.0_f64; p];
1093        let mut w = vec![0.0_f64; res.max_di()];
1094        for i in 0..sys.rows.len() {
1095            res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1096        }
1097        for a in 0..k {
1098            out[a] -= acc[a];
1099        }
1100    } else {
1101        // Allocate scratch at max_d; per-row slice is `..di`.
1102        let mut local = Array1::<f64>::zeros(sys.d);
1103        let mut neg_contrib = Array1::<f64>::zeros(k);
1104        for i in 0..sys.rows.len() {
1105            neg_contrib.fill(0.0);
1106            schur_matvec_row_into(
1107                sys,
1108                htt_factors,
1109                x,
1110                backend,
1111                i,
1112                &mut local,
1113                &mut neg_contrib,
1114            );
1115            for a in 0..k {
1116                out[a] -= neg_contrib[a];
1117            }
1118        }
1119    }
1120}
1121
1122/// Accumulate one row's reduced-Schur point-elimination contribution
1123/// `H_βt^(i) (H_tt^(i))⁻¹ H_tβ^(i) x` (length `K`) into `acc`.
1124///
1125/// `local` is caller-owned `≥ sys.d`-length scratch (reused across rows to keep
1126/// the hot loop allocation-free); only `..di` is touched. `acc` is **added to**,
1127/// never cleared, so the caller controls whether contributions sum into a chunk
1128/// partial (parallel path) or a per-row buffer (sequential path).
1129#[inline]
1130pub(crate) fn schur_matvec_row_into<B: BatchedBlockSolver>(
1131    sys: &ArrowSchurSystem,
1132    htt_factors: &ArrowFactorSlab,
1133    x: &Array1<f64>,
1134    backend: &B,
1135    i: usize,
1136    local: &mut Array1<f64>,
1137    acc: &mut Array1<f64>,
1138) {
1139    let row = &sys.rows[i];
1140    let di = sys.row_dims[i];
1141    // H_tβ^(i) · x → local[..di], routed through sys.htbeta_matvec
1142    // when the dense block is absent.
1143    let mut local_i = local.slice_mut(ndarray::s![..di]).to_owned();
1144    local_i.fill(0.0);
1145    sys_htbeta_apply_row(sys, i, row, x.view(), &mut local_i);
1146    let solved = backend.solve_block_vector(htt_factors.factor(i), local_i.view());
1147    // H_βt^(i) · solved accumulates into acc (length k).  Routed through
1148    // sys.htbeta_matvec when needed.
1149    sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), acc);
1150}
1151
1152/// One per-term block factor for the block-Jacobi Schur preconditioner.
1153///
1154/// Carries either a dense Cholesky factor (for PD blocks ≤ 256 columns) or
1155/// the scalar inverses for that block's diagonal as a fallback.
1156#[derive(Clone)]
1157pub(crate) enum BlockFactor {
1158    /// Cholesky L stored column-major via faer. `range` identifies the
1159    /// columns in the full K-vector this block covers.
1160    Chol {
1161        factor: FaerLlt<f64>,
1162        range: Range<usize>,
1163    },
1164    /// Scalar fallback: per-element `1/s_aa` for each column in `range`.
1165    Scalar {
1166        inv: Array1<f64>,
1167        range: Range<usize>,
1168    },
1169}
1170
1171impl std::fmt::Debug for BlockFactor {
1172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1173        match self {
1174            BlockFactor::Chol { range, .. } => {
1175                write!(f, "BlockFactor::Chol {{ range: {:?} }}", range)
1176            }
1177            BlockFactor::Scalar { inv, range } => {
1178                write!(
1179                    f,
1180                    "BlockFactor::Scalar {{ inv.len: {}, range: {:?} }}",
1181                    inv.len(),
1182                    range
1183                )
1184            }
1185        }
1186    }
1187}
1188
1189/// Block-Jacobi Schur preconditioner for BA's inexact reduced-system PCG.
1190///
1191/// When [`ArrowSchurSystem::block_offsets`] is populated (via
1192/// [`ArrowSchurSystem::set_block_offsets`]) and the largest block has ≤ 256
1193/// columns, builds one small dense Schur block per term, factors it with
1194/// Cholesky (faer LLT), and applies the preconditioner as per-block
1195/// triangular solves.  Non-PD blocks fall back to scalar diagonal inversion
1196/// for that block only.  When `block_offsets` is empty or the largest block
1197/// exceeds 256 columns the preconditioner reduces to pure scalar-diagonal
1198/// Jacobi (pre-#283 behaviour), so callers that have not called
1199/// `set_block_offsets` are unaffected.
1200///
1201/// The `block_offsets` plumbing is compatible with issue #287 (custom
1202/// `ParameterBlockSpec` families): those callers supply ranges derived from
1203/// their own block layout.
1204#[derive(Debug, Clone)]
1205pub struct JacobiPreconditioner {
1206    pub(crate) blocks: Vec<BlockFactor>,
1207}
1208
/// Widest block (in columns) for which the preconditioner still attempts a
/// dense block-Jacobi factorization. Any wider block sends the whole build down
/// the scalar-diagonal path.
pub(crate) const BLOCK_JACOBI_MAX_BLOCK: usize = 256;
1211
/// Smallest Schur-complement Jacobi diagonal entry treated as positive
/// definite.
///
/// An entry at or below this value, or a non-finite one, means the reduced
/// system is not PD. The preconditioner cannot invert such an entry. Rather
/// than return a garbage scale, the PCG solve fails loudly and demands that the
/// operator be regularized.
pub(crate) const JACOBI_DIAGONAL_PD_FLOOR: f64 = 1.0e-18;
1217
1218impl JacobiPreconditioner {
1219    /// Build the block-Jacobi (or scalar fallback) preconditioner from the
1220    /// Arrow-Schur system without materializing the full dense Schur
1221    /// complement.
1222    ///
1223    /// When `sys.block_offsets` is non-empty and `max(block_size) ≤ 256`,
1224    /// each block gets a dense `b×b` Schur sub-matrix formed, factored, and
1225    /// stored.  Otherwise every column gets its own scalar entry.
1226    pub(crate) fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
1227        sys: &ArrowSchurSystem,
1228        htt_factors: &ArrowFactorSlab,
1229        ridge_beta: f64,
1230        backend: &B,
1231        resident: Option<&SaeResidentReducedSchur>,
1232    ) -> Result<Self, ArrowSchurError> {
1233        let use_block = !sys.block_offsets.is_empty()
1234            && sys
1235                .block_offsets
1236                .iter()
1237                .map(|r| r.end.saturating_sub(r.start))
1238                .max()
1239                .unwrap_or(0)
1240                <= BLOCK_JACOBI_MAX_BLOCK;
1241        if use_block {
1242            if let Some(res) = resident {
1243                Self::build_block_jacobi_resident(sys, ridge_beta, res)
1244            } else {
1245                Self::build_block_jacobi(sys, htt_factors, ridge_beta, backend)
1246            }
1247        } else if let Some(res) = resident {
1248            // #1017 — SAE residency scalar Jacobi. The generic scalar build
1249            // probes `H_tβ^(i) e_a` and re-solves `(H_tt^(i))⁻¹` once for EVERY
1250            // (row, β-column) pair: `O(n·K)` triangular solves and `O(n·K·p)`
1251            // operator-probe work per Newton step, with `K = K_atoms·p` in the
1252            // tens of thousands at LLM shapes. The reduced-Schur diagonal is the
1253            // same quotient the resident `(L_i, Y_i)` factors already carry, so
1254            // read the diagonal straight off them in one support-sparse pass —
1255            // no probe, no per-column solve.
1256            Self::build_scalar_jacobi_resident(sys, ridge_beta, res)
1257        } else {
1258            Self::build_scalar_jacobi(sys, htt_factors, ridge_beta, backend)
1259        }
1260    }
1261
    /// Build scalar-diagonal Jacobi: one `BlockFactor::Scalar` of length 1
    /// per column.  Matches pre-#283 semantics.
    ///
    /// Each column `a` stores the reciprocal of the reduced-Schur diagonal
    /// `S_aa = diag(H_ββ)_a + ridge_beta − Σ_i (H_tβ^(i) e_a)ᵀ(H_tt^(i))⁻¹(H_tβ^(i) e_a)`.
    ///
    /// Per row the cross-block term is gathered one of two ways:
    /// - when `sys.htbeta_matvec` is `None` and the row's `htbeta` slab is a
    ///   materialized `(row_dims[i] × k)` matrix, all `k` columns are solved in
    ///   one `solve_block_matrix` call and the diagonal is read off column dots;
    /// - otherwise each column is probed through `sys_htbeta_apply_row` with the
    ///   unit vector `e_a` and solved with `solve_block_vector` (one probe plus
    ///   one solve per column per row).
    ///
    /// Rows are swept in parallel (fixed 64-row chunks whose partials are folded
    /// in chunk order) when there are at least `SCHUR_MATVEC_PARALLEL_ROW_MIN`
    /// rows and the caller is not already running on a rayon worker thread;
    /// otherwise the sweep is serial.
    ///
    /// # Errors
    ///
    /// Returns `ArrowSchurError::PcgFailed` if any accumulated diagonal entry is
    /// non-finite or `<= JACOBI_DIAGONAL_PD_FLOOR`.
    pub(crate) fn build_scalar_jacobi<B: BatchedBlockSolver + Sync>(
        sys: &ArrowSchurSystem,
        htt_factors: &ArrowFactorSlab,
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let k = sys.k;
        // Extract diagonal of H_ββ via penalty_diagonal_add (#296):
        // no Arc-clone; falls back to hbb_diag or hbb[[a,a]] inline.
        let mut diag = Array1::<f64>::zeros(k);
        {
            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
            sys.penalty_diagonal_add(diag_slice);
        }
        // Seed = diag(H_ββ) + ridge; the row sweep below subtracts from it.
        for a in 0..k {
            diag[a] += ridge_beta;
        }
        // Per-row body: subtract this row's `Σ_a (H_tβ^(i)e_a)ᵀ(H_tt^(i))⁻¹
        // (H_tβ^(i)e_a)` contribution into a caller-provided length-`K` diagonal
        // accumulator (`-=`). For each column `a`, probe the cross-block (or read
        // the dense slab) and compute the scalar point-elimination quotient. The
        // `O(K)` solves per row are the build's whole cost; the row contributions
        // are independent length-`K` vectors, so a worker sums a chunk into a
        // private `diag_part` and the caller folds the partials back in chunk
        // order — bit-identical run-to-run (the #1017 preconditioner gate).
        let row_into = |i: usize, row: &ArrowRowBlock, diag_part: &mut Array1<f64>| {
            let di = sys.row_dims[i];
            // Dense-slab fast path (#1017): when the per-row cross-block is a
            // materialized `di × k` slab (no matrix-free operator), the entire
            // reduced-Schur diagonal contribution for this row is
            // `Σ_c H_tβ[c,a] · ((H_tt)⁻¹ H_tβ)[c,a]`. The generic loop below
            // re-solved `(H_tt)⁻¹` once PER COLUMN — `O(k)` block solves + `O(k)`
            // allocations per row, i.e. `O(n·k)` tiny solves per Newton step
            // (the dominant fixed per-solve cost at the SAE wide-border shape,
            // k in the tens of thousands). Solve all `k` columns in ONE batched
            // block solve instead, then take the column dots. Reassociates the
            // diagonal within the documented #1211 preconditioner margin (same as
            // the resident no-probe path), and the preconditioner only steers the
            // PCG iterate, which still terminates at the PCG tolerance.
            if sys.htbeta_matvec.is_none() && row.htbeta.dim() == (di, k) {
                let solved = backend.solve_block_matrix(htt_factors.factor(i), row.htbeta.view());
                for a in 0..k {
                    let mut acc = 0.0;
                    for c in 0..di {
                        acc += row.htbeta[[c, a]] * solved[[c, a]];
                    }
                    diag_part[a] -= acc;
                }
                return;
            }
            // Matrix-free path: probe column a. `e_a` stays all-zero between
            // columns — set the single active entry and reset it after the probe,
            // so we never pay the `O(k)` `e_a.fill(0.0)` per column (that fill was
            // `O(n·k²)`). `sys_htbeta_apply_row` zeroes `col_i` internally.
            let mut col_i = Array1::<f64>::zeros(di);
            let mut e_a = Array1::<f64>::zeros(k);
            for a in 0..k {
                e_a[a] = 1.0;
                sys_htbeta_apply_row(sys, i, row, e_a.view(), &mut col_i);
                e_a[a] = 0.0;
                let solved = backend.solve_block_vector(htt_factors.factor(i), col_i.view());
                let mut acc = 0.0;
                for c in 0..di {
                    acc += col_i[c] * solved[c];
                }
                diag_part[a] -= acc;
            }
        };
        let n = sys.rows.len();
        // Fan out only for enough rows to amortize the per-chunk length-`K`
        // partial, and never from inside a rayon worker (nesting guard).
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let partials: Vec<Array1<f64>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    // Zero-seeded partial: holds only this chunk's (negative)
                    // row terms, added into the seeded `diag` below.
                    let mut diag_part = Array1::<f64>::zeros(k);
                    for i in idxs {
                        row_into(i, &sys.rows[i], &mut diag_part);
                    }
                    diag_part
                })
                .collect();
            // Deterministic ordered reduction: fold chunk partials left-to-right.
            for part in &partials {
                for a in 0..k {
                    diag[a] += part[a];
                }
            }
        } else {
            for (i, row) in sys.rows.iter().enumerate() {
                row_into(i, row, &mut diag);
            }
        }
        // Reject a non-PD diagonal rather than inverting it; otherwise store
        // each column's reciprocal as a length-1 scalar block.
        let mut blocks = Vec::with_capacity(k);
        for a in 0..k {
            let v = diag[a];
            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
                return Err(ArrowSchurError::PcgFailed {
                    reason: format!(
                        "invalid Schur Jacobi diagonal at index {a}: {v}; \
                         operator regularization is required"
                    ),
                });
            }
            blocks.push(BlockFactor::Scalar {
                inv: Array1::from_elem(1, 1.0 / v),
                range: a..a + 1,
            });
        }
        Ok(Self { blocks })
    }
1381
    /// Build scalar-diagonal Jacobi from the pre-staged SAE residency factors
    /// `(L_i, Y_i)` (#1017).
    ///
    /// The generic [`Self::build_scalar_jacobi`] forms each reduced-Schur
    /// diagonal entry `S_aa = H_ββ,aa + ρ − Σ_i (H_tβ^(i) e_a)ᵀ(H_tt^(i))⁻¹(H_tβ^(i) e_a)`
    /// by probing the cross-block operator with the unit vector `e_a` and
    /// re-solving `(H_tt^(i))⁻¹` for every `(row, column)` pair — `O(n·K)`
    /// triangular solves per Newton step. For the SAE Kronecker cross-block the
    /// `a`-th column lives on exactly one active support entry: `a = beta_base + j`
    /// for some `(beta_base, φ) ∈ a_phi[i]` and output channel `j ∈ 0..p`, with
    /// `H_tβ^(i) e_a = φ · L_i[:, j]`. The point-elimination quotient is then
    ///
    /// ```text
    /// (H_tβ^(i) e_a)ᵀ (H_tt^(i))⁻¹ (H_tβ^(i) e_a)
    ///     = φ² · L_i[:, j]ᵀ (H_tt^(i))⁻¹ L_i[:, j]
    ///     = φ² · (L_i[:, j] · Y_i[:, j]),          Y_i := (H_tt^(i))⁻¹ L_i.
    /// ```
    ///
    /// so the whole diagonal is accumulated in ONE support-sparse pass over the
    /// resident factors — no probe, no per-column solve, the staged `Y_i` reused
    /// from the matvec residency. The result is the SAME quotient the generic
    /// path computes (up to float reassociation of the row sum), so the PCG
    /// preconditioner is unchanged up to that f64 margin. Since the preconditioner
    /// only steers the iterate (which still terminates at the PCG tolerance), the
    /// criterion ranking is stable except for candidates within that margin,
    /// where the near-tie winner can flip — not an exact no-move guarantee (#1211).
    ///
    /// The φ²-only scatter relies on the "exactly one support entry per column"
    /// property above: if two support entries of one row overlapped the same
    /// column, the true quotient would also carry their `φ_s φ_t` cross terms,
    /// which this pass does not add.
    ///
    /// # Errors
    ///
    /// Returns `ArrowSchurError::PcgFailed` if any accumulated diagonal entry is
    /// non-finite or `<= JACOBI_DIAGONAL_PD_FLOOR`.
    ///
    /// # Panics
    ///
    /// Writes `diag[beta_base + j]` for `j in 0..p` without a bounds guard, so it
    /// assumes every support entry has `beta_base + p <= sys.k` (NOTE(review):
    /// presumably guaranteed when the residency is staged — confirm).
    pub(crate) fn build_scalar_jacobi_resident(
        sys: &ArrowSchurSystem,
        ridge_beta: f64,
        resident: &SaeResidentReducedSchur,
    ) -> Result<Self, ArrowSchurError> {
        let k = sys.k;
        let p = resident.p;
        let n = resident.rows.len();
        // Seed with diag(H_ββ) + ridge — same penalty source the generic path
        // reads, so the only difference is how the point-elimination term is
        // gathered.
        let mut diag = Array1::<f64>::zeros(k);
        {
            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
            sys.penalty_diagonal_add(diag_slice);
        }
        for a in 0..k {
            diag[a] += ridge_beta;
        }
        // Per-row point-elimination diagonal: for each active support entry
        // `(beta_base, φ)` and channel `j`, subtract `φ² · L_i[:, j]·Y_i[:, j]`
        // into `diag[beta_base + j]`. `L_i`/`Y_i` are row-major `di × p`, so the
        // `j`-th column dot is `Σ_r L_i[r·p + j]·Y_i[r·p + j]`.
        //
        // The accumulation is into a SHARED `diag` (rows scatter into overlapping
        // `beta_base + j` columns), so — like the generic `build_scalar_jacobi`
        // and the `schur_matvec` row loop (#1017) — parallelism uses worker-private
        // length-`K` partials folded back in chunk order. Each chunk is a
        // contiguous ascending row range processed in ascending order, and the
        // fold runs in chunk order, so the parallel result is bit-identical
        // run-to-run regardless of thread scheduling (the #1017 determinism
        // gate). It is NOT bit-identical to the serial branch: the serial branch
        // subtracts every row straight into the seeded `diag`, whereas the
        // parallel branch first sums each chunk into a zero-seeded partial and
        // then adds the partials, which regroups the floating-point sum. The
        // preconditioner — and any criterion ranking it steers — is therefore
        // stable only up to that chunk-reassociation margin; a near-tie winner
        // inside that margin can flip (#1211).
        // This build runs once per inexact-PCG solve = O(inner-Newton-iters)
        // per fit; at the SAE LLM shape (thousands of rows, wide border `k`) the
        // per-row support sweep is the build's whole cost and was on one core.
        // The per-channel column dot `col_dot[j] = Σ_r L_i[r·p+j]·Y_i[r·p+j]`
        // (the diagonal of `G_i = L_iᵀ(H_tt)⁻¹L_i`) depends ONLY on the row `i`,
        // not on the support entry `(beta_base, φ)`. The previous loop recomputed
        // it once per support entry — a row with `m` active atoms paid `m·p`
        // column dots over `di`. Hoist it: compute the `p` column dots once per
        // row into reusable `col_dot` scratch, then each support entry is a pure
        // scatter `diag[beta_base+j] -= φ²·col_dot[j]`. Bit-for-bit identical:
        // each `col_dot[j]` is the same `r`-ascending sum, and `φ²·col_dot[j]`
        // yields identical bits whether `col_dot[j]` was just computed or cached.
        let row_into = |row: usize, diag_part: &mut [f64], col_dot: &mut [f64]| {
            let rf = &resident.rows[row];
            let di = rf.di;
            // Rows with no local dims or no active support contribute nothing.
            if di == 0 {
                return;
            }
            let support = &resident.a_phi[row];
            if support.is_empty() {
                return;
            }
            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
            // the former per-row `rf.l` copy.
            let l_i = &resident.local_jac[row];
            for (j, slot) in col_dot.iter_mut().enumerate().take(p) {
                let mut acc = 0.0_f64;
                for r in 0..di {
                    let idx = r * p + j;
                    acc += l_i[idx] * rf.y[idx];
                }
                *slot = acc;
            }
            for &(beta_base, phi) in support {
                if phi == 0.0 {
                    continue;
                }
                let phi2 = phi * phi;
                for j in 0..p {
                    diag_part[beta_base + j] -= phi2 * col_dot[j];
                }
            }
        };
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let partials: Vec<Array1<f64>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let mut diag_part = Array1::<f64>::zeros(k);
                    let mut col_dot = vec![0.0_f64; p];
                    let slice = diag_part
                        .as_slice_mut()
                        .expect("diag_part must be contiguous");
                    for i in idxs {
                        row_into(i, slice, &mut col_dot);
                    }
                    diag_part
                })
                .collect();
            // Deterministic ordered reduction: fold chunk partials left-to-right
            // (each partial already holds the per-row terms subtracted, so add
            // them into `diag` in chunk order to mirror the serial subtraction).
            for part in &partials {
                for a in 0..k {
                    diag[a] += part[a];
                }
            }
        } else {
            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
            let mut col_dot = vec![0.0_f64; p];
            for row in 0..n {
                row_into(row, diag_slice, &mut col_dot);
            }
        }
        // Reject a non-PD diagonal rather than inverting it; otherwise store
        // each column's reciprocal as a length-1 scalar block.
        let mut blocks = Vec::with_capacity(k);
        for a in 0..k {
            let v = diag[a];
            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
                return Err(ArrowSchurError::PcgFailed {
                    reason: format!(
                        "invalid SAE-resident Schur Jacobi diagonal at index {a}: {v}; \
                         operator regularization is required"
                    ),
                });
            }
            blocks.push(BlockFactor::Scalar {
                inv: Array1::from_elem(1, 1.0 / v),
                range: a..a + 1,
            });
        }
        Ok(Self { blocks })
    }
1539
1540    /// Build block-Jacobi from the pre-staged SAE residency factors `(L_i, Y_i)`.
1541    ///
1542    /// This is the block analogue of [`Self::build_scalar_jacobi_resident`].
1543    /// When SAE block offsets are small enough to select BetaBlockJacobi (for
1544    /// example per-atom decoder blocks with `basis_size·p <= 256`), the generic
1545    /// block builder materializes every row's dense `(d_i × K)` `H_tβ` by probing
1546    /// the matrix-free operator, then re-solves `(H_tt)⁻¹` for each block column.
1547    /// The resident factors already carry `G_i = L_iᵀ(H_tt)⁻¹L_i`, so each block
1548    /// is assembled by scattering only the active support pairs inside that block:
1549    ///
1550    /// ```text
1551    /// S_block -= Σ_i Σ_(s,t in block support) φ_s φ_t · G_i[channel_s, channel_t]
1552    /// ```
1553    ///
1554    /// It computes the same block-diagonal restriction as the generic path, but
1555    /// avoids the full-row `H_tβ` materialization and per-column triangular solves.
1556    pub(crate) fn build_block_jacobi_resident(
1557        sys: &ArrowSchurSystem,
1558        ridge_beta: f64,
1559        resident: &SaeResidentReducedSchur,
1560    ) -> Result<Self, ArrowSchurError> {
1561        let block_offsets = &sys.block_offsets;
1562        let p = resident.p;
1563        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
1564        for (block_idx, range) in block_offsets.iter().enumerate() {
1565            let b = range.end - range.start;
1566            let mut schur_block = Array2::<f64>::zeros((b, b));
1567            sys.penalty_block_add(
1568                BetaBlockId(block_idx),
1569                block_offsets.as_ref(),
1570                &mut schur_block,
1571            );
1572            for bi in 0..b {
1573                schur_block[[bi, bi]] += ridge_beta;
1574            }
1575            schur_blocks.push(schur_block);
1576        }
1577
1578        let row_into = |row: usize, blocks: &mut [Array2<f64>]| {
1579            let rf = &resident.rows[row];
1580            let di = rf.di;
1581            if di == 0 {
1582                return;
1583            }
1584            let support = &resident.a_phi[row];
1585            if support.is_empty() {
1586                return;
1587            }
1588            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
1589            // the former per-row `rf.l` copy.
1590            let l_i = &resident.local_jac[row];
1591            for (block_idx, range) in block_offsets.iter().enumerate() {
1592                let block = &mut blocks[block_idx];
1593                for &(base_left, phi_left) in support {
1594                    if phi_left == 0.0 {
1595                        continue;
1596                    }
1597                    let left_start = base_left.max(range.start);
1598                    let left_end = (base_left + p).min(range.end);
1599                    if left_start >= left_end {
1600                        continue;
1601                    }
1602                    for &(base_right, phi_right) in support {
1603                        if phi_right == 0.0 {
1604                            continue;
1605                        }
1606                        let right_start = base_right.max(range.start);
1607                        let right_end = (base_right + p).min(range.end);
1608                        if right_start >= right_end {
1609                            continue;
1610                        }
1611                        let phi = phi_left * phi_right;
1612                        for gi in left_start..left_end {
1613                            let li = gi - range.start;
1614                            let ch_i = gi - base_left;
1615                            for gj in right_start..right_end {
1616                                let lj = gj - range.start;
1617                                let ch_j = gj - base_right;
1618                                let mut gij = 0.0_f64;
1619                                for r in 0..di {
1620                                    gij += l_i[r * p + ch_i] * rf.y[r * p + ch_j];
1621                                }
1622                                block[[li, lj]] -= phi * gij;
1623                            }
1624                        }
1625                    }
1626                }
1627            }
1628        };
1629
1630        let n = resident.rows.len();
1631        let parallel =
1632            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1633        if parallel {
1634            use rayon::prelude::*;
1635            const CHUNK: usize = 64;
1636            let n_blocks = block_offsets.len();
1637            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
1638            let partials: Vec<Vec<Array2<f64>>> = (0..n)
1639                .into_par_iter()
1640                .chunks(CHUNK)
1641                .map(|idxs| {
1642                    let mut local: Vec<Array2<f64>> = block_dims
1643                        .iter()
1644                        .map(|&b| Array2::<f64>::zeros((b, b)))
1645                        .collect();
1646                    for i in idxs {
1647                        row_into(i, &mut local);
1648                    }
1649                    local
1650                })
1651                .collect();
1652            for local in &partials {
1653                for bidx in 0..n_blocks {
1654                    schur_blocks[bidx] += &local[bidx];
1655                }
1656            }
1657        } else {
1658            for row in 0..n {
1659                row_into(row, &mut schur_blocks);
1660            }
1661        }
1662
1663        let mut blocks = Vec::with_capacity(block_offsets.len());
1664        for (block_idx, range) in block_offsets.iter().enumerate() {
1665            let b = range.end - range.start;
1666            let schur_block = &schur_blocks[block_idx];
1667            let factor_opt = {
1668                use faer::Side;
1669                let view = FaerArrayView::new(schur_block);
1670                FaerLlt::new(view.as_ref(), Side::Lower).ok()
1671            };
1672            if let Some(llt) = factor_opt {
1673                blocks.push(BlockFactor::Chol {
1674                    factor: llt,
1675                    range: range.clone(),
1676                });
1677            } else {
1678                let mut inv = Array1::<f64>::zeros(b);
1679                for bi in 0..b {
1680                    let v = schur_block[[bi, bi]];
1681                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1682                        return Err(ArrowSchurError::PcgFailed {
1683                            reason: format!(
1684                                "SAE-resident block Jacobi scalar fallback: non-PD diagonal at \
1685                                 global index {}: {v}; regularization required",
1686                                range.start + bi
1687                            ),
1688                        });
1689                    }
1690                    inv[bi] = 1.0 / v;
1691                }
1692                blocks.push(BlockFactor::Scalar {
1693                    inv,
1694                    range: range.clone(),
1695                });
1696            }
1697        }
1698        Ok(Self { blocks })
1699    }
1700
    /// Build term-block Jacobi: one dense `b×b` Schur block per term in
    /// `sys.block_offsets`.
    ///
    /// Each block starts from the `H_ββ` restriction (via `penalty_block_add`)
    /// plus `ridge_beta` on its diagonal. Every row then subtracts
    /// `H_βt_k^(i)(H_tt^(i))⁻¹H_tβ_k^(i)`, using the row's materialized
    /// `(d_i × K)` cross-block and one `solve_block_vector` per block column.
    /// Each finished block is LLᵀ-factored. A block whose factorization fails
    /// falls back to a scalar reciprocal-diagonal block.
    ///
    /// # Errors
    ///
    /// Propagates any error from `sys_htbeta_materialize_row`. Returns
    /// `ArrowSchurError::PcgFailed` when a block's scalar fallback meets a
    /// non-finite or `<= JACOBI_DIAGONAL_PD_FLOOR` diagonal entry.
    pub(crate) fn build_block_jacobi<B: BatchedBlockSolver + Sync>(
        sys: &ArrowSchurSystem,
        htt_factors: &ArrowFactorSlab,
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let block_offsets = &sys.block_offsets;

        // Initialise every b×b Schur sub-block from H_ββ + ridge·I via
        // penalty_block_add (#296): routes to penalty_op or falls back to
        // hbb / hbb_diag inline without Arc-clone per loop iteration. These are
        // the block-diagonal restrictions of the reduced Schur complement; the
        // per-row cross-block contributions are accumulated in the row sweep
        // below.
        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let b = range.end - range.start;
            let mut schur_block = Array2::<f64>::zeros((b, b));
            sys.penalty_block_add(
                BetaBlockId(block_idx),
                block_offsets.as_ref(),
                &mut schur_block,
            );
            for bi in 0..b {
                schur_block[[bi, bi]] += ridge_beta;
            }
            schur_blocks.push(schur_block);
        }

        // Subtract Schur contributions:
        // S_kk -= H_βt_k^(i) (H_tt^(i))^{-1} H_tβ_k^(i)
        //
        // Materialize each row's (d_i × K) cross-block ONCE and scatter its
        // contribution into every block-diagonal sub-block — mirroring the
        // row-outer structure of `build_dense_schur_direct`. The previous
        // block-outer form re-materialized every row for each β-block
        // (O(n_blocks · n · K) probes); for the matrix-free softmax cross-block
        // each materialize is itself O(K²), so that nesting made the
        // preconditioner build quadratically more expensive than the direct
        // dense Schur it preconditions. sys_htbeta_materialize_row handles the
        // Kronecker / htbeta_matvec path transparently.
        //
        // Per-row body: materialize the row's `(d_i × K)` cross-block once and
        // subtract its `H_βt_k^(i)(H_tt^(i))⁻¹H_tβ_k^(i)` contribution into EACH
        // block-diagonal sub-block. Writes INTO a caller-provided `blocks`
        // accumulator (`-=`) so a rayon worker can subtract a chunk's rows into
        // a worker-private zero-seeded `Vec<Array2>` and the caller folds the
        // chunk partials back in chunk order — bit-identical run-to-run
        // regardless of thread scheduling (the #1017 verification gate). This
        // is deterministic and within the chunk-reassociation margin of serial,
        // so the preconditioner, hence the criterion ranking, is stable except
        // for near-tie candidates inside that f64 margin — not an exact no-move
        // guarantee (#1211).
        let row_into = |i: usize,
                        row: &ArrowRowBlock,
                        blocks: &mut [Array2<f64>]|
         -> Result<(), ArrowSchurError> {
            let di = sys.row_dims[i];
            let htbeta_full = sys_htbeta_materialize_row(sys, i, row)?;
            for (block_idx, range) in block_offsets.iter().enumerate() {
                let b = range.end - range.start;
                // `solved_cols[:, bj] = (H_tt^(i))⁻¹ H_tβ^(i)[:, range.start + bj]`.
                let mut solved_cols = Array2::<f64>::zeros((di, b));
                for bj in 0..b {
                    let gj = range.start + bj;
                    let rhs = htbeta_full.column(gj).to_owned();
                    let solved = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
                    for c in 0..di {
                        solved_cols[[c, bj]] = solved[c];
                    }
                }
                let schur_block = &mut blocks[block_idx];
                for bi in 0..b {
                    let gi = range.start + bi;
                    for bj in 0..b {
                        let mut acc = 0.0;
                        for c in 0..di {
                            acc += htbeta_full[[c, gi]] * solved_cols[[c, bj]];
                        }
                        schur_block[[bi, bj]] -= acc;
                    }
                }
            }
            Ok(())
        };
        // Each row materializes an `O(K²)` cross-block (Kronecker) plus `Σ_k b_k`
        // triangular solves — the preconditioner build's whole per-row cost at
        // the SAE LLM shape (#1017), and the rows are independent. Fan over fixed
        // row chunks above the threshold, staying serial for the handful-of-rows
        // non-SAE callers and inside a rayon worker (topology-race nesting guard)
        // — the same gate `schur_matvec` uses.
        let n = sys.rows.len();
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let n_blocks = block_offsets.len();
            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
            // NOTE: if several chunks fail, the fallible collect short-circuits
            // on whichever error it observes first, so the surfaced error is
            // not guaranteed to be the lowest-index row's.
            let partials: Vec<Vec<Array2<f64>>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let mut local: Vec<Array2<f64>> = block_dims
                        .iter()
                        .map(|&b| Array2::<f64>::zeros((b, b)))
                        .collect();
                    for i in idxs {
                        row_into(i, &sys.rows[i], &mut local)?;
                    }
                    Ok::<_, ArrowSchurError>(local)
                })
                .collect::<Result<Vec<_>, _>>()?;
            // Deterministic ordered reduction: fold chunk partials left-to-right.
            for local in &partials {
                for bidx in 0..n_blocks {
                    schur_blocks[bidx] += &local[bidx];
                }
            }
        } else {
            for (i, row) in sys.rows.iter().enumerate() {
                row_into(i, row, &mut schur_blocks)?;
            }
        }

        // Factor each accumulated block: LLT, with scalar-diagonal fallback for
        // a block that comes out non-PD at this ridge.
        let mut blocks = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let b = range.end - range.start;
            let schur_block = &schur_blocks[block_idx];
            let factor_opt = {
                use faer::Side;
                let view = FaerArrayView::new(schur_block);
                FaerLlt::new(view.as_ref(), Side::Lower).ok()
            };
            if let Some(llt) = factor_opt {
                blocks.push(BlockFactor::Chol {
                    factor: llt,
                    range: range.clone(),
                });
            } else {
                // Non-PD block: fall back to scalar diagonal for this block.
                let mut inv = Array1::<f64>::zeros(b);
                for bi in 0..b {
                    let v = schur_block[[bi, bi]];
                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
                        return Err(ArrowSchurError::PcgFailed {
                            reason: format!(
                                "block Jacobi scalar fallback: non-PD diagonal at \
                                 global index {}: {v}; regularization required",
                                range.start + bi
                            ),
                        });
                    }
                    inv[bi] = 1.0 / v;
                }
                blocks.push(BlockFactor::Scalar {
                    inv,
                    range: range.clone(),
                });
            }
        }
        Ok(Self { blocks })
    }
1866
1867    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
1868        let mut out = Array1::<f64>::zeros(r.len());
1869        for block in &self.blocks {
1870            match block {
1871                BlockFactor::Scalar { inv, range } => {
1872                    for (local, gi) in range.clone().enumerate() {
1873                        out[gi] = inv[local] * r[gi];
1874                    }
1875                }
1876                BlockFactor::Chol { factor, range } => {
1877                    let b = range.end - range.start;
1878                    let mut rhs = Array1::<f64>::zeros(b);
1879                    for (local, gi) in range.clone().enumerate() {
1880                        rhs[local] = r[gi];
1881                    }
1882                    use faer::linalg::solvers::Solve;
1883                    let stride = rhs.strides()[0];
1884                    let len = rhs.len();
1885                    // SAFETY: rhs is a uniquely-borrowed contiguous Array1
1886                    // with positive stride (standard layout).
1887                    let rhs_mat =
1888                        unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
1889                    let solved = factor.solve(rhs_mat);
1890                    for (local, gi) in range.clone().enumerate() {
1891                        out[gi] = solved[(local, 0)];
1892                    }
1893                }
1894            }
1895        }
1896        out
1897    }
1898}
1899
1900// ---------------------------------------------------------------------------
1901// Preconditioner ladder: SchurPreconditionerKind, ClusterJacobi,
1902// AdditiveSchwarz  (issue #299)
1903// ---------------------------------------------------------------------------
1904
/// Which Schur preconditioner to use in the inexact-PCG path.
///
/// Ladder ordered by cost / effectiveness:
/// - `Diagonal`: scalar Jacobi (pre-#283 behaviour).
/// - `BetaBlockJacobi`: block-Jacobi per `block_offsets` term (#287).
/// - `ClusterJacobi`: one dense block per beta-graph connected component.
/// - `AdditiveSchwarz { overlap }`: component + `overlap`-hop expansion,
///   overlapping columns averaged by partition-of-unity weights (full dense
///   local-inverse apply per subdomain).
/// - `DiagAssembledSchwarz { overlap }`: the cheap Schwarz variant (#299) —
///   same overlapping decomposition, but each subdomain contributes only the
///   diagonal of its local inverse `(A_k⁻¹)_ii`, assembled additively with
///   partition-of-unity weights into a single `O(K)`-apply diagonal.
/// - `BlockIncompleteCholesky`: level-0 incomplete Cholesky (#299). Within each
///   connected component of the β-coupling graph the dense reduced-Schur block
///   `S[C,C]` is assembled once, its structural-nonzero pattern is taken as the
///   level-0 fill pattern, and a no-fill incomplete Cholesky `S ≈ L̃ L̃ᵀ` is
///   formed keeping ONLY that pattern (Saad, *Iterative Methods*, IC(0)). Apply
///   is a sparse triangular forward/back solve over `nnz(S[C,C])`, so for a
///   large component with internal sparsity it is far cheaper to build and apply
///   than `ClusterJacobi`'s full dense Cholesky (which fills the whole `b×b`
///   factor) while retaining the inter-block coupling that ClusterJacobi keeps
///   but the diagonal/Schwarz tiers discard. A non-PD incomplete pivot degrades
///   that component to the scalar reciprocal diagonal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurPreconditionerKind {
    /// Scalar Jacobi on the reduced-Schur diagonal (pre-#283 behaviour).
    Diagonal,
    /// One dense block per `block_offsets` term (#287).
    BetaBlockJacobi,
    /// One dense block per connected component of the β-coupling graph.
    ClusterJacobi,
    /// Overlapping Schwarz: each component expanded by `overlap` graph hops,
    /// with a full dense local-inverse apply per subdomain and overlapping
    /// columns averaged by partition-of-unity weights.
    AdditiveSchwarz { overlap: usize },
    /// Same `overlap`-hop decomposition as `AdditiveSchwarz`, but only the
    /// diagonal of each local inverse is kept, assembled into one diagonal.
    DiagAssembledSchwarz { overlap: usize },
    /// Level-0 (no-fill) incomplete Cholesky per β-graph component (#299).
    BlockIncompleteCholesky,
}
1938
/// Size gate for climbing the preconditioner ladder past `BetaBlockJacobi`.
/// Escalation happens only when `K` is larger than this value AND PCG has used
/// up its `max_iterations`.
pub(crate) const PRECOND_ESCALATE_K_THRESHOLD: usize = 100;
1942
/// Fractional safety margin for the #1026 matrix-free Schur curvature floor.
/// This is the unbounded-PCG counterpart of the dense
/// `spectral_pd_floored_schur`.
///
/// When the unbounded SAE inner PCG meets a direction with `pᵀSp ≤ 0`, the
/// operator ridge is raised by the smallest amount that makes the curvature
/// along that direction positive again, and then by this extra fraction. That
/// way the next CG iterate lands strictly inside the positive cone rather than
/// on the `0` knife-edge.
pub(crate) const SCHUR_CURVATURE_FLOOR_MARGIN: f64 = 0.01;
/// Minimum curvature-floor ridge bump, relative to the rhs scale. It guarantees
/// a strictly positive bump even when `pᵀSp` rounds to exactly `0`.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_FLOOR: f64 = 1e-12;
/// Upper limit on the accumulated curvature-floor ridge, relative to the rhs
/// scale.
///
/// Past this point a minimal floor is considered unable to condition the
/// operator. The recoverable failure is then passed up to the outer LM loop,
/// which rebuilds the whole system with a heavier ridge. The limit is kept
/// generous so that a large collapsed over-subtraction `(H_tβ)²/H_tt` can
/// still be covered.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_CEILING: f64 = 1e12;
/// Per-attempt growth factor for the ridge escalation on a DIAGONAL refusal.
/// No `(curvature, ‖p‖²)` deficit is available there to size the lift. The
/// value matches `RIDGE_GROWTH_FACTOR` in the per-row `factor_one_row_result`.
pub(crate) const SCHUR_CURVATURE_FLOOR_DIAG_GROWTH: f64 = 10.0;
/// Maximum number of curvature-floor ridge-lift attempts before the failure is
/// deferred to the outer LM loop.
///
/// With ×10 growth per diagonal-refusal attempt, 24 attempts span 24 decades.
/// That is exactly the gap between `SCHUR_CURVATURE_FLOOR_REL_FLOOR` (1e-12)
/// and `SCHUR_CURVATURE_FLOOR_REL_CEILING` (1e12).
///
/// NOTE(review): whether the escalation really starts at the floor and stops
/// at the ceiling is decided by the PCG caller, which is not shown here —
/// confirm before relying on that pairing.
pub(crate) const SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS: usize = 24;
1968
1969/// Cholesky or scalar factor for one cluster of the beta-coefficient graph.
1970#[derive(Clone)]
1971pub(crate) enum ClusterFactor {
1972    Chol {
1973        cols: Vec<usize>,
1974        factor: FaerLlt<f64>,
1975    },
1976    Scalar {
1977        cols: Vec<usize>,
1978        inv: Vec<f64>,
1979    },
1980}
1981
1982impl std::fmt::Debug for ClusterFactor {
1983    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1984        match self {
1985            ClusterFactor::Chol { cols, .. } => {
1986                write!(f, "ClusterFactor::Chol {{ cols.len: {} }}", cols.len())
1987            }
1988            ClusterFactor::Scalar { cols, inv } => write!(
1989                f,
1990                "ClusterFactor::Scalar {{ cols.len: {}, inv.len: {} }}",
1991                cols.len(),
1992                inv.len()
1993            ),
1994        }
1995    }
1996}
1997
/// Largest cluster, in β columns, that is still assembled and factored densely.
/// Bigger clusters use the scalar reciprocal Schur diagonal instead.
pub(crate) const CLUSTER_JACOBI_MAX_CLUSTER: usize = 512;
2000
/// Largest connected component, in β columns, for which the IC(0)
/// preconditioner assembles the dense `S[C,C]` to read off its sparsity pattern.
///
/// Applying IC(0) is cheap at any size, but the pattern comes from a dense
/// `O(b²)`-memory assembly. Larger components therefore fall back to the scalar
/// reciprocal diagonal. This is the same concern as
/// `CLUSTER_JACOBI_MAX_CLUSTER`, with a higher limit because the IC(0)
/// *factor* stays sparse.
pub(crate) const IC0_MAX_COMPONENT: usize = 4096;

/// Relative magnitude below which an assembled `S[i,j]` counts as a structural
/// zero when the IC(0) level-0 pattern is derived.
///
/// The threshold is scaled by `sqrt(|S_ii|·|S_jj|)`, which makes it invariant
/// to column scaling. It removes entries that are pure FMA round-off (a truly
/// decoupled `(i,j)` pair assembles to roughly `0`), so they never enter the
/// kept fill pattern.
pub(crate) const IC0_PATTERN_REL_DROP: f64 = 1e-13;
2015
2016/// Assemble the dense `b×b` reduced-Schur block for the column set `cols`:
2017/// `S[cols, cols] = H_ββ[cols, cols] + ridge·I − Σ_i H_tβ[cols]ᵀ (H_tt^i)⁻¹ H_tβ[cols]`.
2018///
2019/// Shared by `ClusterJacobiPreconditioner::build_from_column_groups` (which
2020/// Cholesky-factors the returned block) and `DiagAssembledSchwarzPreconditioner`
2021/// (which inverts each subdomain block and keeps only its diagonal). The result
2022/// is the LOWER triangle filled by the row reduction; callers that need the full
2023/// symmetric block must `symmetrize_upper_from_lower`.
2024///
2025/// The per-row Schur contribution is fanned over fixed 64-row chunks above
2026/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` and folded left-to-right so the assembly is
2027/// bit-identical to the serial path (and run-to-run deterministic), exactly as
2028/// in `build_block_jacobi` (#1017).
2029pub(crate) fn assemble_local_schur_block<B: BatchedBlockSolver + Sync>(
2030    sys: &ArrowSchurSystem,
2031    htt_factors: &ArrowFactorSlab,
2032    ridge_beta: f64,
2033    backend: &B,
2034    cols: &[usize],
2035) -> Array2<f64> {
2036    let d = sys.d;
2037    let b = cols.len();
2038    let mut s_block = Array2::<f64>::zeros((b, b));
2039    // Initialise from H_ββ via penalty_subblock_add (#296): routes through
2040    // penalty_op or falls back to hbb / hbb_diag inline.
2041    sys.penalty_subblock_add(cols, &mut s_block);
2042    for bi in 0..b {
2043        s_block[[bi, bi]] += ridge_beta;
2044    }
2045    let cluster_row_into = |row_idx: usize, row: &ArrowRowBlock, acc: &mut Array2<f64>| {
2046        let mut col_vec = Array1::<f64>::zeros(d);
2047        let mut solved_cols = Array2::<f64>::zeros((d, b));
2048        for bj in 0..b {
2049            let gj = cols[bj];
2050            for c in 0..d {
2051                col_vec[c] = row.htbeta[[c, gj]];
2052            }
2053            let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
2054            for c in 0..d {
2055                solved_cols[[c, bj]] = solved[c];
2056            }
2057        }
2058        for bi in 0..b {
2059            let gi = cols[bi];
2060            for bj in 0..b {
2061                let mut dot = 0.0;
2062                for c in 0..d {
2063                    dot += row.htbeta[[c, gi]] * solved_cols[[c, bj]];
2064                }
2065                acc[[bi, bj]] -= dot;
2066            }
2067        }
2068    };
2069    let n = sys.rows.len();
2070    let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
2071    if parallel {
2072        use rayon::prelude::*;
2073        const CHUNK: usize = 64;
2074        let partials: Vec<Array2<f64>> = (0..n)
2075            .into_par_iter()
2076            .chunks(CHUNK)
2077            .map(|idxs| {
2078                let mut local = Array2::<f64>::zeros((b, b));
2079                for i in idxs {
2080                    cluster_row_into(i, &sys.rows[i], &mut local);
2081                }
2082                local
2083            })
2084            .collect();
2085        for local in &partials {
2086            s_block += local;
2087        }
2088    } else {
2089        for (row_idx, row) in sys.rows.iter().enumerate() {
2090            cluster_row_into(row_idx, row, &mut s_block);
2091        }
2092    }
2093    s_block
2094}
2095
2096/// Dense Schur block per connected component of the beta-coupling graph.
2097///
2098/// Nodes = beta blocks (`block_offsets`); edges = rows where two blocks
2099/// co-occur with nonzero `H_t_beta` entries. One Cholesky factor per
2100/// connected component; applied as a triangular solve.
2101#[derive(Debug, Clone)]
2102pub struct ClusterJacobiPreconditioner {
2103    pub(crate) clusters: Vec<ClusterFactor>,
2104}
2105
2106impl ClusterJacobiPreconditioner {
2107    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2108        sys: &ArrowSchurSystem,
2109        htt_factors: &ArrowFactorSlab,
2110        ridge_beta: f64,
2111        backend: &B,
2112    ) -> Result<Self, ArrowSchurError> {
2113        if sys.block_offsets.is_empty() {
2114            let cols: Vec<usize> = (0..sys.k).collect();
2115            return Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &[cols]);
2116        }
2117        let graph = BetaCouplingGraph::build(
2118            &sys.block_offsets,
2119            &sys.rows
2120                .iter()
2121                .map(|r| r.htbeta.clone())
2122                .collect::<Vec<_>>(),
2123        );
2124        let col_groups: Vec<Vec<usize>> = graph
2125            .component_partition()
2126            .iter()
2127            .map(|comp_blocks| {
2128                let mut cols: Vec<usize> = comp_blocks
2129                    .iter()
2130                    .flat_map(|&b| sys.block_offsets[b].clone())
2131                    .collect();
2132                cols.sort_unstable();
2133                cols
2134            })
2135            .collect();
2136        Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2137    }
2138
2139    pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2140        sys: &ArrowSchurSystem,
2141        htt_factors: &ArrowFactorSlab,
2142        ridge_beta: f64,
2143        backend: &B,
2144        col_groups: &[Vec<usize>],
2145    ) -> Result<Self, ArrowSchurError> {
2146        let mut clusters = Vec::with_capacity(col_groups.len());
2147        for cols in col_groups {
2148            let b = cols.len();
2149            if b == 0 {
2150                continue;
2151            }
2152            if b > CLUSTER_JACOBI_MAX_CLUSTER {
2153                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2154                clusters.push(ClusterFactor::Scalar {
2155                    cols: cols.clone(),
2156                    inv,
2157                });
2158                continue;
2159            }
2160            let mut s_block =
2161                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2162            symmetrize_upper_from_lower(&mut s_block);
2163            let factor_opt = {
2164                use faer::Side;
2165                let view = FaerArrayView::new(&s_block);
2166                FaerLlt::new(view.as_ref(), Side::Lower).ok()
2167            };
2168            if let Some(llt) = factor_opt {
2169                clusters.push(ClusterFactor::Chol {
2170                    cols: cols.clone(),
2171                    factor: llt,
2172                });
2173            } else {
2174                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2175                clusters.push(ClusterFactor::Scalar {
2176                    cols: cols.clone(),
2177                    inv,
2178                });
2179            }
2180        }
2181        Ok(Self { clusters })
2182    }
2183
2184    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2185        let mut out = Array1::<f64>::zeros(r.len());
2186        for cluster in &self.clusters {
2187            apply_cluster(cluster, r, &mut out, &ClusterApplyMode::Overwrite);
2188        }
2189        out
2190    }
2191}
2192
2193/// Additive Schwarz: base components expanded by `overlap` graph-hops;
2194/// overlapping columns averaged by partition-of-unity weights.
2195#[derive(Debug, Clone)]
2196pub struct AdditiveSchwarzPreconditioner {
2197    pub(crate) clusters: Vec<ClusterFactor>,
2198    pub(crate) weights: Vec<f64>,
2199}
2200
2201impl AdditiveSchwarzPreconditioner {
2202    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2203        sys: &ArrowSchurSystem,
2204        htt_factors: &ArrowFactorSlab,
2205        ridge_beta: f64,
2206        backend: &B,
2207        overlap: usize,
2208    ) -> Result<Self, ArrowSchurError> {
2209        if sys.block_offsets.is_empty() {
2210            let cols: Vec<usize> = (0..sys.k).collect();
2211            let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2212                sys,
2213                htt_factors,
2214                ridge_beta,
2215                backend,
2216                &[cols],
2217            )?;
2218            return Ok(Self {
2219                clusters: inner.clusters,
2220                weights: vec![1.0f64; sys.k],
2221            });
2222        }
2223        let graph = BetaCouplingGraph::build(
2224            &sys.block_offsets,
2225            &sys.rows
2226                .iter()
2227                .map(|r| r.htbeta.clone())
2228                .collect::<Vec<_>>(),
2229        );
2230        let col_groups: Vec<Vec<usize>> = graph
2231            .component_partition()
2232            .iter()
2233            .map(|seed| {
2234                let mut current = seed.clone();
2235                for _ in 0..overlap {
2236                    current = graph.expand_one_hop(&current);
2237                }
2238                let mut cols: Vec<usize> = current
2239                    .iter()
2240                    .flat_map(|&b| sys.block_offsets[b].clone())
2241                    .collect();
2242                cols.sort_unstable();
2243                cols.dedup();
2244                cols
2245            })
2246            .collect();
2247        let mut counts = vec![0u32; sys.k];
2248        for cols in &col_groups {
2249            for &gi in cols {
2250                counts[gi] += 1;
2251            }
2252        }
2253        let weights: Vec<f64> = counts
2254            .iter()
2255            .map(|&c| if c == 0 { 1.0 } else { 1.0 / c as f64 })
2256            .collect();
2257        let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2258            sys,
2259            htt_factors,
2260            ridge_beta,
2261            backend,
2262            &col_groups,
2263        )?;
2264        Ok(Self {
2265            clusters: inner.clusters,
2266            weights,
2267        })
2268    }
2269
2270    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2271        let mut out = Array1::<f64>::zeros(r.len());
2272        for cluster in &self.clusters {
2273            apply_cluster(
2274                cluster,
2275                r,
2276                &mut out,
2277                &ClusterApplyMode::Accumulate {
2278                    weights: &self.weights,
2279                },
2280            );
2281        }
2282        out
2283    }
2284}
2285
2286/// Diagonal-assembled additive Schwarz (#299).
2287///
2288/// The cheap Schwarz variant the domain-decomposition literature recommends as
2289/// the default for sparse-coupling β-graphs: instead of storing and applying a
2290/// dense Cholesky factor per overlapping subdomain (as
2291/// [`AdditiveSchwarzPreconditioner`] does), it inverts each overlapping
2292/// subdomain Schur block ONCE at build time and keeps only the **diagonal of the
2293/// local inverse** `(A_k⁻¹)_ii`. Those per-subdomain diagonal contributions are
2294/// then assembled additively across overlapping subdomains with partition-of-
2295/// unity weights into a single global diagonal `m`, applied as `out[i] = m[i]·r[i]`.
2296///
2297/// This is strictly richer than scalar Jacobi (`1/S_ii`): the local inverse
2298/// diagonal `(A_k⁻¹)_ii` folds in the off-diagonal coupling WITHIN the subdomain,
2299/// so a strongly-coupled column gets a smaller (better-damped) effective scale
2300/// than its bare reciprocal diagonal would give — while the apply stays `O(K)`
2301/// (one multiply per column), unlike the `O(Σ b_k²)` triangular solves of dense
2302/// Schwarz. For `overlap = 0` and one column per subdomain it reduces exactly to
2303/// scalar Jacobi.
2304#[derive(Debug, Clone)]
2305pub struct DiagAssembledSchwarzPreconditioner {
2306    /// Global per-column multiplier `m[i]`; `out[i] = m[i] · r[i]`.
2307    pub(crate) inv_diag: Vec<f64>,
2308}
2309
2310impl DiagAssembledSchwarzPreconditioner {
2311    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2312        sys: &ArrowSchurSystem,
2313        htt_factors: &ArrowFactorSlab,
2314        ridge_beta: f64,
2315        backend: &B,
2316        overlap: usize,
2317    ) -> Result<Self, ArrowSchurError> {
2318        // Build the overlapping subdomain column groups exactly like
2319        // AdditiveSchwarz (component partition + `overlap` graph-hop expansion),
2320        // so the two Schwarz variants decompose the β space identically and
2321        // differ only in how each subdomain's local inverse is applied.
2322        let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2323            vec![(0..sys.k).collect()]
2324        } else {
2325            let graph = BetaCouplingGraph::build(
2326                &sys.block_offsets,
2327                &sys.rows
2328                    .iter()
2329                    .map(|r| r.htbeta.clone())
2330                    .collect::<Vec<_>>(),
2331            );
2332            graph
2333                .component_partition()
2334                .iter()
2335                .map(|seed| {
2336                    let mut current = seed.clone();
2337                    for _ in 0..overlap {
2338                        current = graph.expand_one_hop(&current);
2339                    }
2340                    let mut cols: Vec<usize> = current
2341                        .iter()
2342                        .flat_map(|&b| sys.block_offsets[b].clone())
2343                        .collect();
2344                    cols.sort_unstable();
2345                    cols.dedup();
2346                    cols
2347                })
2348                .collect()
2349        };
2350        Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2351    }
2352
2353    pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2354        sys: &ArrowSchurSystem,
2355        htt_factors: &ArrowFactorSlab,
2356        ridge_beta: f64,
2357        backend: &B,
2358        col_groups: &[Vec<usize>],
2359    ) -> Result<Self, ArrowSchurError> {
2360        // Partition-of-unity weights: a column shared by `c` subdomains gets each
2361        // of its `c` diagonal contributions scaled by `1/c`, so the assembled
2362        // diagonal is a convex combination (and reduces to a single contribution
2363        // for non-overlapping columns).
2364        let mut counts = vec![0u32; sys.k];
2365        for cols in col_groups {
2366            for &gi in cols {
2367                counts[gi] += 1;
2368            }
2369        }
2370        let mut accum = vec![0.0f64; sys.k];
2371        for cols in col_groups {
2372            let b = cols.len();
2373            if b == 0 {
2374                continue;
2375            }
2376            // For large subdomains, the dense inverse is too costly; fall back to
2377            // the global scalar Schur diagonal inverse `1/S_ii` for those columns
2378            // (the diag-assembled variant then coincides with scalar Jacobi over
2379            // that subdomain, which is exactly the intended cheap degradation).
2380            if b > CLUSTER_JACOBI_MAX_CLUSTER {
2381                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2382                for (local, &gi) in cols.iter().enumerate() {
2383                    let w = if counts[gi] == 0 {
2384                        1.0
2385                    } else {
2386                        1.0 / counts[gi] as f64
2387                    };
2388                    accum[gi] += w * inv[local];
2389                }
2390                continue;
2391            }
2392            let mut s_block =
2393                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2394            symmetrize_upper_from_lower(&mut s_block);
2395            // Diagonal of the local inverse `(A_k⁻¹)_ii`, obtained by solving
2396            // `A_k X = I` through the same faer Cholesky used elsewhere; on a
2397            // non-PD local block, degrade to the scalar reciprocal diagonal.
2398            let local_inv_diag = match local_inverse_diagonal(&s_block) {
2399                Some(diag) => diag,
2400                None => {
2401                    let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2402                    inv
2403                }
2404            };
2405            for (local, &gi) in cols.iter().enumerate() {
2406                let w = if counts[gi] == 0 {
2407                    1.0
2408                } else {
2409                    1.0 / counts[gi] as f64
2410                };
2411                accum[gi] += w * local_inv_diag[local];
2412            }
2413        }
2414        // A column never covered by any subdomain (only possible for `k` columns
2415        // with no block_offsets coverage) keeps a neutral unit scale.
2416        for (gi, &c) in counts.iter().enumerate() {
2417            if c == 0 {
2418                accum[gi] = 1.0;
2419            }
2420        }
2421        for (gi, m) in accum.iter().enumerate() {
2422            if !m.is_finite() || *m <= 0.0 {
2423                return Err(ArrowSchurError::PcgFailed {
2424                    reason: format!(
2425                        "diag-assembled Schwarz: non-positive assembled diagonal at index {gi}: {m}"
2426                    ),
2427                });
2428            }
2429        }
2430        Ok(Self { inv_diag: accum })
2431    }
2432
2433    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2434        let mut out = Array1::<f64>::zeros(r.len());
2435        for (gi, &m) in self.inv_diag.iter().enumerate() {
2436            out[gi] = m * r[gi];
2437        }
2438        out
2439    }
2440}
2441
2442/// Diagonal of `A⁻¹` for a small dense SPD block `A`, via the same faer
2443/// Cholesky used by the cluster/Schwarz factors. Returns `None` if `A` is not
2444/// positive-definite (caller degrades to the scalar reciprocal diagonal).
2445pub(crate) fn local_inverse_diagonal(a: &Array2<f64>) -> Option<Vec<f64>> {
2446    let b = a.nrows();
2447    let llt = {
2448        use faer::Side;
2449        let view = FaerArrayView::new(a);
2450        FaerLlt::new(view.as_ref(), Side::Lower).ok()?
2451    };
2452    use faer::linalg::solvers::Solve;
2453    let mut diag = Vec::with_capacity(b);
2454    for col in 0..b {
2455        // Solve `A x = e_col`; the `col`-th entry of `x` is `(A⁻¹)_{col,col}`.
2456        let mut rhs = Array1::<f64>::zeros(b);
2457        rhs[col] = 1.0;
2458        let stride = rhs.strides()[0];
2459        let len = rhs.len();
2460        // SAFETY: `rhs` is a uniquely-borrowed contiguous `Array1<f64>` of `len`
2461        // elements with positive row stride; a single column never dereferences
2462        // the column stride, so `0` is sound.
2463        let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2464        let solved = llt.solve(rhs_mat);
2465        diag.push(solved[(col, 0)]);
2466    }
2467    Some(diag)
2468}
2469
2470/// How a cluster factor's contribution is written into the output vector.
2471///
2472/// `Overwrite` assigns `out[gi] = value` (non-overlapping clusters, each global
2473/// column touched by exactly one cluster). `Accumulate` adds the partition-of-unity
2474/// weighted contribution `out[gi] += weights[gi] * value` (overlapping Schwarz
2475/// clusters, where a column may belong to several clusters).
2476pub(crate) enum ClusterApplyMode<'w> {
2477    Overwrite,
2478    Accumulate { weights: &'w [f64] },
2479}
2480
2481impl ClusterApplyMode<'_> {
2482    #[inline]
2483    pub(crate) fn write(&self, out: &mut Array1<f64>, gi: usize, value: f64) {
2484        match self {
2485            ClusterApplyMode::Overwrite => out[gi] = value,
2486            ClusterApplyMode::Accumulate { weights } => out[gi] += weights[gi] * value,
2487        }
2488    }
2489}
2490
2491/// Apply a single cluster factor to the residual `r`, writing into `out`
2492/// according to `mode` (overwrite for non-overlapping clusters, weighted
2493/// accumulate for overlapping Schwarz clusters).
2494pub(crate) fn apply_cluster(
2495    cluster: &ClusterFactor,
2496    r: &Array1<f64>,
2497    out: &mut Array1<f64>,
2498    mode: &ClusterApplyMode<'_>,
2499) {
2500    match cluster {
2501        ClusterFactor::Scalar { cols, inv } => {
2502            for (local, &gi) in cols.iter().enumerate() {
2503                mode.write(out, gi, inv[local] * r[gi]);
2504            }
2505        }
2506        ClusterFactor::Chol { cols, factor } => {
2507            let b = cols.len();
2508            let mut rhs = Array1::<f64>::zeros(b);
2509            for (local, &gi) in cols.iter().enumerate() {
2510                rhs[local] = r[gi];
2511            }
2512            use faer::linalg::solvers::Solve;
2513            let stride = rhs.strides()[0];
2514            let len = rhs.len();
2515            // SAFETY: rhs is uniquely-borrowed contiguous Array1 with positive stride.
2516            let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2517            let solved = factor.solve(rhs_mat);
2518            for (local, &gi) in cols.iter().enumerate() {
2519                mode.write(out, gi, solved[(local, 0)]);
2520            }
2521        }
2522    }
2523}
2524
2525/// One connected-component factor of the block IC(0) preconditioner.
2526///
2527/// `IncompleteChol` holds a sparse lower-triangular `L̃` in column-compressed
2528/// form over the component's local indices: `col_ptr[j]..col_ptr[j+1]` indexes
2529/// into `(row_idx, val)` for column `j` (rows `>= j`, diagonal first). `cols`
2530/// maps a local index back to its global β column. `Scalar` is the non-PD /
2531/// oversized degradation, identical in meaning to [`ClusterFactor::Scalar`].
2532#[derive(Clone)]
2533pub(crate) enum Ic0Factor {
2534    IncompleteChol {
2535        cols: Vec<usize>,
2536        col_ptr: Vec<usize>,
2537        row_idx: Vec<usize>,
2538        val: Vec<f64>,
2539    },
2540    Scalar {
2541        cols: Vec<usize>,
2542        inv: Vec<f64>,
2543    },
2544}
2545
2546impl std::fmt::Debug for Ic0Factor {
2547    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2548        match self {
2549            Ic0Factor::IncompleteChol { cols, val, .. } => write!(
2550                f,
2551                "Ic0Factor::IncompleteChol {{ cols.len: {}, nnz: {} }}",
2552                cols.len(),
2553                val.len()
2554            ),
2555            Ic0Factor::Scalar { cols, .. } => {
2556                write!(f, "Ic0Factor::Scalar {{ cols.len: {} }}", cols.len())
2557            }
2558        }
2559    }
2560}
2561
2562/// Level-0 incomplete-Cholesky Schur preconditioner (#299).
2563///
2564/// One sparse incomplete-Cholesky factor per connected component of the
2565/// β-coupling graph. Within a component the dense `S[C,C]` is assembled, its
2566/// structural-nonzero pattern `P = { (i,j) : |S_ij| > drop·sqrt(S_ii S_jj) }`
2567/// is taken as the level-0 fill set, and the no-fill incomplete Cholesky
2568/// `S ≈ L̃ L̃ᵀ` is formed keeping only `P` (drop any update landing outside it).
2569/// See [`SchurPreconditionerKind::BlockIncompleteCholesky`].
2570#[derive(Debug, Clone)]
2571pub struct BlockIncompleteCholeskyPreconditioner {
2572    pub(crate) components: Vec<Ic0Factor>,
2573}
2574
2575impl BlockIncompleteCholeskyPreconditioner {
2576    pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2577        sys: &ArrowSchurSystem,
2578        htt_factors: &ArrowFactorSlab,
2579        ridge_beta: f64,
2580        backend: &B,
2581    ) -> Result<Self, ArrowSchurError> {
2582        // Column grouping mirrors ClusterJacobi: one group per connected
2583        // component of the β-coupling graph (whole-K single group when no
2584        // block_offsets are registered), so IC(0) preconditions exactly the
2585        // coupling ClusterJacobi keeps, but with a sparse (no-fill) factor.
2586        let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2587            vec![(0..sys.k).collect()]
2588        } else {
2589            let graph = BetaCouplingGraph::build(
2590                &sys.block_offsets,
2591                &sys.rows
2592                    .iter()
2593                    .map(|r| r.htbeta.clone())
2594                    .collect::<Vec<_>>(),
2595            );
2596            graph
2597                .component_partition()
2598                .iter()
2599                .map(|comp| {
2600                    let mut cols: Vec<usize> = comp
2601                        .iter()
2602                        .flat_map(|&blk| sys.block_offsets[blk].clone())
2603                        .collect();
2604                    cols.sort_unstable();
2605                    cols.dedup();
2606                    cols
2607                })
2608                .collect()
2609        };
2610
2611        let mut components = Vec::with_capacity(col_groups.len());
2612        for cols in &col_groups {
2613            let b = cols.len();
2614            if b == 0 {
2615                continue;
2616            }
2617            if b > IC0_MAX_COMPONENT {
2618                let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2619                components.push(Ic0Factor::Scalar {
2620                    cols: cols.clone(),
2621                    inv,
2622                });
2623                continue;
2624            }
2625            let mut s_block =
2626                assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2627            symmetrize_upper_from_lower(&mut s_block);
2628            match incomplete_cholesky_level0(&s_block) {
2629                Some((col_ptr, row_idx, val)) => components.push(Ic0Factor::IncompleteChol {
2630                    cols: cols.clone(),
2631                    col_ptr,
2632                    row_idx,
2633                    val,
2634                }),
2635                None => {
2636                    // Non-PD incomplete pivot: degrade this component to the
2637                    // scalar reciprocal diagonal (mirrors the ClusterJacobi
2638                    // non-PD fallback), which is always applicable for a
2639                    // PD-floored Schur diagonal.
2640                    let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2641                    components.push(Ic0Factor::Scalar {
2642                        cols: cols.clone(),
2643                        inv,
2644                    });
2645                }
2646            }
2647        }
2648        Ok(Self { components })
2649    }
2650
2651    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2652        let mut out = Array1::<f64>::zeros(r.len());
2653        for comp in &self.components {
2654            match comp {
2655                Ic0Factor::Scalar { cols, inv } => {
2656                    for (local, &gi) in cols.iter().enumerate() {
2657                        out[gi] = inv[local] * r[gi];
2658                    }
2659                }
2660                Ic0Factor::IncompleteChol {
2661                    cols,
2662                    col_ptr,
2663                    row_idx,
2664                    val,
2665                } => {
2666                    let b = cols.len();
2667                    // Gather the local residual, solve `L̃ L̃ᵀ z = r_local` by a
2668                    // sparse forward solve (`L̃ y = r`) then a sparse back solve
2669                    // (`L̃ᵀ z = y`), then scatter `z` back to global columns.
2670                    let mut z = vec![0.0f64; b];
2671                    for (local, &gi) in cols.iter().enumerate() {
2672                        z[local] = r[gi];
2673                    }
2674                    // Forward solve `L̃ y = r` (overwrite z with y). Column-major
2675                    // CSC: row_idx[col_ptr[j]] == j (diagonal stored first).
2676                    for j in 0..b {
2677                        let dstart = col_ptr[j];
2678                        let diag = val[dstart];
2679                        z[j] /= diag;
2680                        let yj = z[j];
2681                        for k in (dstart + 1)..col_ptr[j + 1] {
2682                            z[row_idx[k]] -= val[k] * yj;
2683                        }
2684                    }
2685                    // Back solve `L̃ᵀ z = y` (overwrite z). Walk columns in
2686                    // reverse; the below-diagonal entries of column j are the
2687                    // off-diagonal entries of row j of L̃ᵀ.
2688                    for j in (0..b).rev() {
2689                        let dstart = col_ptr[j];
2690                        let mut acc = z[j];
2691                        for k in (dstart + 1)..col_ptr[j + 1] {
2692                            acc -= val[k] * z[row_idx[k]];
2693                        }
2694                        z[j] = acc / val[dstart];
2695                    }
2696                    for (local, &gi) in cols.iter().enumerate() {
2697                        out[gi] = z[local];
2698                    }
2699                }
2700            }
2701        }
2702        out
2703    }
2704}
2705
2706/// Level-0 incomplete Cholesky of a dense SPD-ish block `a` (`b×b`, symmetric).
2707///
2708/// Returns the lower factor `L̃` in column-compressed (CSC) form
2709/// `(col_ptr, row_idx, val)` where each column lists its diagonal entry FIRST
2710/// followed by the strictly-below-diagonal entries, in increasing row order.
2711/// The kept pattern is the level-0 set `P` = structural nonzeros of `a` (a
2712/// relative drop threshold prunes round-off). IC(0) computes the standard
2713/// Cholesky recurrence but DROPS any value at a position outside `P`, so the
2714/// factor has exactly `nnz(tril(P))` entries — no fill. Returns `None` on a
2715/// non-positive pivot (caller degrades to scalar diagonal).
2716///
2717/// Reference: Y. Saad, *Iterative Methods for Sparse Linear Systems*, 2nd ed.,
2718/// §10.3.2 (IC(0)). This is the left-looking, pattern-restricted variant.
2719pub(crate) fn incomplete_cholesky_level0(
2720    a: &Array2<f64>,
2721) -> Option<(Vec<usize>, Vec<usize>, Vec<f64>)> {
2722    let b = a.nrows();
2723    assert_eq!(a.ncols(), b, "incomplete Cholesky needs a square block");
2724
2725    // ---- derive the level-0 lower-triangular pattern from `a` --------------
2726    // Per column j, the kept below-or-on-diagonal rows i>=j with a structurally
2727    // nonzero a[i,j]. The diagonal is always kept.
2728    let mut col_ptr = vec![0usize; b + 1];
2729    let mut row_idx: Vec<usize> = Vec::new();
2730    // value buffer, parallel to row_idx, initialised from tril(a) on the pattern
2731    let mut val: Vec<f64> = Vec::new();
2732    // For O(1) "is (i,j) in pattern + where" lookups during the recurrence, keep
2733    // a per-column map from global row -> position in that column's value slice.
2734    let mut col_pos: Vec<std::collections::HashMap<usize, usize>> = Vec::with_capacity(b);
2735    for j in 0..b {
2736        let ajj = a[[j, j]];
2737        let scale_j = ajj.abs().max(0.0).sqrt();
2738        let mut map = std::collections::HashMap::new();
2739        // diagonal first
2740        map.insert(j, val.len());
2741        row_idx.push(j);
2742        val.push(ajj);
2743        for i in (j + 1)..b {
2744            let aij = a[[i, j]];
2745            let scale_i = a[[i, i]].abs().sqrt();
2746            let thresh = IC0_PATTERN_REL_DROP * scale_i * scale_j;
2747            if aij.abs() > thresh {
2748                map.insert(i, val.len());
2749                row_idx.push(i);
2750                val.push(aij);
2751            }
2752        }
2753        col_pos.push(map);
2754        col_ptr[j + 1] = val.len();
2755    }
2756
2757    // ---- IC(0) recurrence, left-looking over columns -----------------------
2758    // For column j: subtract the contributions of all prior columns k<j that
2759    // have BOTH a nonzero at row j (so they touch the diagonal/the column) — the
2760    // multiplier L[j,k] — and a nonzero at the rows i of column j's pattern.
2761    // Any update whose target (i,j) is OUTSIDE the kept pattern is dropped.
2762    for j in 0..b {
2763        // Diagonal: a[j,j] - Σ_{k<j} L[j,k]². Each prior column k<j contributes
2764        // its row-j entry L[j,k] (looked up by row, so the column index is not
2765        // needed); columns without a row-j entry contribute nothing.
2766        let dpos = col_ptr[j];
2767        let mut diag = val[dpos];
2768        for mapk in &col_pos[..j] {
2769            if let Some(&pjk) = mapk.get(&j) {
2770                let ljk = val[pjk];
2771                diag -= ljk * ljk;
2772            }
2773        }
2774        if !diag.is_finite() || diag <= JACOBI_DIAGONAL_PD_FLOOR {
2775            return None;
2776        }
2777        let ljj = diag.sqrt();
2778        val[dpos] = ljj;
2779        // Below-diagonal of column j: L[i,j] = (a[i,j] - Σ_{k<j} L[i,k] L[j,k]) / L[j,j]
2780        for p in (dpos + 1)..col_ptr[j + 1] {
2781            let i = row_idx[p];
2782            let mut s = val[p];
2783            for mapk in &col_pos[..j] {
2784                if let (Some(&pik), Some(&pjk)) = (mapk.get(&i), mapk.get(&j)) {
2785                    s -= val[pik] * val[pjk];
2786                }
2787            }
2788            val[p] = s / ljj;
2789        }
2790    }
2791    Some((col_ptr, row_idx, val))
2792}
2793
/// Outcome of a single preconditioner tier in the #299 ladder study.
///
/// Each field is copied from the `PcgDiagnostics` of that tier's PCG run.
#[derive(Debug, Clone, Copy)]
pub struct PrecondLadderRow {
    /// Number of PCG iterations performed before the solver stopped.
    pub iterations: usize,
    /// `true` iff the PCG stop reason was `Converged`. Any other stop reason
    /// (e.g. `MaxIter`, trust-region boundary) yields `false`.
    pub converged: bool,
    /// Relative residual reported by the PCG when it stopped.
    pub final_relative_residual: f64,
}
2805
2806/// Full #299 ladder iteration study on one reduced-Schur system: run the SAME
2807/// preconditioned CG (same `rhs`, tolerances, trust radius) once per ladder tier
2808/// and report the iteration count of each. This is the public seam the
2809/// `tests/owed_299.rs` iteration-reduction gate drives — it keeps the internal
2810/// `run_pcg_with_preconditioner` / preconditioner constructors `pub(crate)`
2811/// while exposing exactly the per-tier measurement the issue asks for.
2812///
2813/// Tiers (in escalation order): scalar `Diagonal`, `BetaBlockJacobi`,
2814/// `ClusterJacobi`, `AdditiveSchwarz{overlap:1}`, `DiagAssembledSchwarz{1}`, and
2815/// `BlockIncompleteCholesky`. A tier whose build fails (e.g. non-PD reduced
2816/// Schur with no curvature floor) reports `None` for that entry; every healthy
2817/// SPD reduced system populates all six.
2818pub fn arrow_precond_ladder_iteration_study(
2819    sys: &ArrowSchurSystem,
2820    ridge_beta: f64,
2821    rhs: &Array1<f64>,
2822    pcg: &ArrowPcgOptions,
2823    trust: &ArrowTrustRegionOptions,
2824) -> Result<Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)>, ArrowSchurError> {
2825    let backend = CpuBatchedBlockSolver;
2826    let htt_factors = backend.factor_blocks(&sys.rows, 0.0, sys.d, false)?;
2827
2828    let run = |apply: &dyn Fn(&Array1<f64>) -> Array1<f64>| -> Option<PrecondLadderRow> {
2829        let (_sol, diag) = run_pcg_with_preconditioner(
2830            sys,
2831            &htt_factors,
2832            ridge_beta,
2833            rhs,
2834            |r| apply(r),
2835            pcg,
2836            trust,
2837            &backend,
2838            None,
2839            None,
2840            None,
2841        )
2842        .ok()?;
2843        Some(PrecondLadderRow {
2844            iterations: diag.iterations,
2845            converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
2846            final_relative_residual: diag.final_relative_residual,
2847        })
2848    };
2849
2850    let mut out: Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)> = Vec::with_capacity(6);
2851
2852    // Scalar Diagonal Jacobi: force the scalar path by clearing block_offsets on
2853    // a clone so the build does not pick up the per-block dense Schur blocks.
2854    let diag_row = {
2855        let mut bare = sys.clone();
2856        bare.set_block_offsets(std::sync::Arc::from([] as [Range<usize>; 0]));
2857        let bare_factors = backend.factor_blocks(&bare.rows, 0.0, bare.d, false)?;
2858        JacobiPreconditioner::from_arrow_schur(&bare, &bare_factors, ridge_beta, &backend, None)
2859            .ok()
2860            .and_then(|p| {
2861                run_pcg_with_preconditioner(
2862                    &bare,
2863                    &bare_factors,
2864                    ridge_beta,
2865                    rhs,
2866                    |r| p.apply(r),
2867                    pcg,
2868                    trust,
2869                    &backend,
2870                    None,
2871                    None,
2872                    None,
2873                )
2874                .ok()
2875                .map(|(_s, diag)| PrecondLadderRow {
2876                    iterations: diag.iterations,
2877                    converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
2878                    final_relative_residual: diag.final_relative_residual,
2879                })
2880            })
2881    };
2882    out.push((SchurPreconditionerKind::Diagonal, diag_row));
2883
2884    let block_row =
2885        JacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, None)
2886            .ok()
2887            .and_then(|p| run(&|r| p.apply(r)));
2888    out.push((SchurPreconditionerKind::BetaBlockJacobi, block_row));
2889
2890    let cluster_row =
2891        ClusterJacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend)
2892            .ok()
2893            .and_then(|p| run(&|r| p.apply(r)));
2894    out.push((SchurPreconditionerKind::ClusterJacobi, cluster_row));
2895
2896    let schwarz_row =
2897        AdditiveSchwarzPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, 1)
2898            .ok()
2899            .and_then(|p| run(&|r| p.apply(r)));
2900    out.push((
2901        SchurPreconditionerKind::AdditiveSchwarz { overlap: 1 },
2902        schwarz_row,
2903    ));
2904
2905    let diag_schwarz_row = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
2906        sys,
2907        &htt_factors,
2908        ridge_beta,
2909        &backend,
2910        1,
2911    )
2912    .ok()
2913    .and_then(|p| run(&|r| p.apply(r)));
2914    out.push((
2915        SchurPreconditionerKind::DiagAssembledSchwarz { overlap: 1 },
2916        diag_schwarz_row,
2917    ));
2918
2919    let ic0_row = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
2920        sys,
2921        &htt_factors,
2922        ridge_beta,
2923        &backend,
2924    )
2925    .ok()
2926    .and_then(|p| run(&|r| p.apply(r)));
2927    out.push((SchurPreconditionerKind::BlockIncompleteCholesky, ic0_row));
2928
2929    Ok(out)
2930}
2931
2932/// Build scalar diagonal inverses for a set of global column indices.
2933///
2934/// Used when a cluster is non-PD or exceeds `CLUSTER_JACOBI_MAX_CLUSTER`.
2935pub(crate) fn build_schur_scalar_inv<B: BatchedBlockSolver>(
2936    sys: &ArrowSchurSystem,
2937    htt_factors: &ArrowFactorSlab,
2938    ridge_beta: f64,
2939    backend: &B,
2940    cols: &[usize],
2941) -> Result<Vec<f64>, ArrowSchurError> {
2942    let d = sys.d;
2943    let mut result = Vec::with_capacity(cols.len());
2944    let mut col_vec = Array1::<f64>::zeros(d);
2945    // Extract the penalty diagonal for all K columns once, then index per-column.
2946    let mut full_diag = Array1::<f64>::zeros(sys.k);
2947    {
2948        let diag_slice = full_diag.as_slice_mut().expect("full_diag contiguous");
2949        sys.penalty_diagonal_add(diag_slice);
2950    }
2951    for &gi in cols {
2952        let mut s = full_diag[gi] + ridge_beta;
2953        for (row_idx, row) in sys.rows.iter().enumerate() {
2954            for c in 0..d {
2955                col_vec[c] = row.htbeta[[c, gi]];
2956            }
2957            let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
2958            let mut acc = 0.0;
2959            for c in 0..d {
2960                acc += col_vec[c] * solved[c];
2961            }
2962            s -= acc;
2963        }
2964        if !s.is_finite() || s <= JACOBI_DIAGONAL_PD_FLOOR {
2965            return Err(ArrowSchurError::PcgFailed {
2966                reason: format!(
2967                    "cluster Schur scalar fallback: non-PD diagonal at index {gi}: {s}"
2968                ),
2969            });
2970        }
2971        result.push(1.0 / s);
2972    }
2973    Ok(result)
2974}
2975
2976/// Inexact PCG with automatic preconditioner-ladder escalation.
2977///
2978/// Starts with `JacobiPreconditioner` (Diagonal or BetaBlockJacobi).
2979/// If PCG hits `MaxIter` and `k > PRECOND_ESCALATE_K_THRESHOLD`,
2980/// escalates to `ClusterJacobi`; if still `MaxIter`, escalates to
2981/// `AdditiveSchwarz { overlap: 1 }`.
2982pub(crate) fn steihaug_pcg_auto<B: BatchedBlockSolver + Sync>(
2983    sys: &ArrowSchurSystem,
2984    htt_factors: &ArrowFactorSlab,
2985    ridge_beta: f64,
2986    rhs: &Array1<f64>,
2987    pcg: &ArrowPcgOptions,
2988    trust: &ArrowTrustRegionOptions,
2989    backend: &B,
2990    gpu_matvec: Option<&GpuSchurMatvec>,
2991    metric_weights: Option<&MetricWeights>,
2992    curvature_floor: Option<f64>,
2993) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
2994    // #1017 CPU residency: stage the per-row reduced-Schur factors `(L_i, Y_i)`
2995    // (NOT the dense `p×p` block — `di ≪ p`, so the factored form is `O(n·di·p)`
2996    // memory and `2·support_i·p + 2·di·p` flops/row including the sparse
2997    // gather/scatter over the active support) once, up
2998    // front, when the SAE structure is installed and the matvec runs on host
2999    // (CPU). The GPU matvec carries its own residency, so skip when it is engaged.
3000    // The same staged operator is reused across the whole preconditioner ladder
3001    // (Jacobi → ClusterJacobi → AdditiveSchwarz) — built once, not per tier.
3002    let resident = if gpu_matvec.is_none() {
3003        SaeResidentReducedSchur::build(sys, htt_factors, backend)
3004    } else {
3005        None
3006    };
3007    // #1026 — curvature-floor retry on the Jacobi tier. The unbounded SAE inner
3008    // PCG (trust radius = ∞) fails on `pᵀSp ≤ 0` when the reduced Schur is
3009    // indefinite (K≥4 co-collapse: a near-singular per-row `H_tt` over-subtracts
3010    // `S`). Instead of letting that failure propagate to the outer LM loop —
3011    // which inflates `ridge_β` over EVERY β direction and makes the inner Newton
3012    // crawl — floor the OPERATOR by the minimal ridge `δ = |pᵀSp|/‖p‖² · (1+ε)`
3013    // that restores positive curvature along the offending direction, rebuild the
3014    // Jacobi preconditioner at the lifted ridge, and retry. This is the
3015    // matrix-free analogue of the dense `spectral_pd_floored_schur`: the healthy
3016    // β subspace (where curvature is already positive) is essentially untouched
3017    // by a tiny `δ`, while the collapsed direction gets exactly the stiffness it
3018    // needs to make a real descent step. A PD reduced Schur never hits `pᵀSp ≤ 0`,
3019    // so this loop is a strict no-op there (bit-for-bit unchanged). Bounded by a
3020    // small attempt cap and a relative ridge ceiling; on exhaustion the original
3021    // recoverable failure still reaches the outer LM loop.
3022    let mut effective_ridge = ridge_beta;
3023    let mut x0_diag0: Option<(Array1<f64>, PcgDiagnostics)> = None;
3024    let mut last_curvature_err: Option<ArrowSchurError> = None;
3025    let rhs_scale = metric_norm(rhs.view(), metric_weights).max(1.0);
3026    let ridge_ceiling = ridge_beta.max(SCHUR_CURVATURE_FLOOR_REL_CEILING * rhs_scale);
3027    for _attempt in 0..=SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS {
3028        // The Jacobi preconditioner build itself refuses a non-PD Schur diagonal
3029        // (`PcgFailed: invalid Schur Jacobi diagonal`) — the SAME co-collapse
3030        // signature reached BEFORE the CG loop, since `S_ii = H_ββ,ii − Σ …` goes
3031        // negative. Treat that build failure as a curvature deficit too: when the
3032        // floor is enabled, lift the ridge and retry; otherwise propagate.
3033        let jacobi = match JacobiPreconditioner::from_arrow_schur(
3034            sys,
3035            htt_factors,
3036            effective_ridge,
3037            backend,
3038            resident.as_ref(),
3039        ) {
3040            Ok(jacobi) => jacobi,
3041            Err(err @ ArrowSchurError::PcgFailed { .. }) => {
3042                if curvature_floor.is_none() {
3043                    return Err(err);
3044                }
3045                // A diagonal refusal carries no `(curvature, ‖p‖²)` deficit, and
3046                // the over-subtraction magnitude `Σ H_tβᵀ(H_tt)⁻¹H_tβ` is
3047                // unbounded relative to `rhs_scale`, so a small additive bump
3048                // would crawl. Escalate the ridge MULTIPLICATIVELY (×10, matching
3049                // the per-row `factor_one_row_result` RIDGE_GROWTH_FACTOR), seeded
3050                // at `rhs_scale`, so even a large deficit (the collapsed
3051                // `(H_tβ)²/H_tt` over-subtraction) is reached in a handful of
3052                // attempts. The ceiling + attempt cap still bound it; on
3053                // exhaustion the recoverable failure reaches the outer LM loop.
3054                let next = if effective_ridge > 0.0 {
3055                    effective_ridge * SCHUR_CURVATURE_FLOOR_DIAG_GROWTH
3056                } else {
3057                    rhs_scale
3058                };
3059                last_curvature_err = Some(err);
3060                if !next.is_finite() || next > ridge_ceiling {
3061                    break;
3062                }
3063                effective_ridge = next;
3064                continue;
3065            }
3066            Err(other) => return Err(other),
3067        };
3068        match run_pcg_with_preconditioner(
3069            sys,
3070            htt_factors,
3071            effective_ridge,
3072            rhs,
3073            |r| jacobi.apply(r),
3074            pcg,
3075            trust,
3076            backend,
3077            gpu_matvec,
3078            metric_weights,
3079            resident.as_ref(),
3080        ) {
3081            Ok(result) => {
3082                x0_diag0 = Some(result);
3083                break;
3084            }
3085            Err(ArrowSchurError::UnboundedNegativeCurvature {
3086                curvature,
3087                direction_norm_sq,
3088            }) => {
3089                // Only floor when the caller opted in (SAE solve path); otherwise
3090                // propagate the raw negative-curvature signal so BA / non-SAE
3091                // unbounded solves keep their existing failure contract.
3092                let Some(relative_floor) = curvature_floor else {
3093                    return Err(ArrowSchurError::UnboundedNegativeCurvature {
3094                        curvature,
3095                        direction_norm_sq,
3096                    });
3097                };
3098                // Minimal ridge to make `pᵀ(S+δI)p = |curvature| + δ·‖p‖² > 0`,
3099                // with a margin so the next CG iterate has strictly positive
3100                // curvature rather than sitting on the `0` knife-edge.
3101                let deficit = if direction_norm_sq > 0.0 {
3102                    curvature.abs() / direction_norm_sq
3103                } else {
3104                    0.0
3105                };
3106                let bump = (deficit * (1.0 + SCHUR_CURVATURE_FLOOR_MARGIN))
3107                    .max(relative_floor.max(SCHUR_CURVATURE_FLOOR_REL_FLOOR) * rhs_scale);
3108                let next = (effective_ridge + bump).max(effective_ridge * 2.0);
3109                last_curvature_err = Some(ArrowSchurError::UnboundedNegativeCurvature {
3110                    curvature,
3111                    direction_norm_sq,
3112                });
3113                if !next.is_finite() || next > ridge_ceiling {
3114                    break;
3115                }
3116                effective_ridge = next;
3117            }
3118            Err(other) => return Err(other),
3119        }
3120    }
3121    let (x0, diag0) = match x0_diag0 {
3122        Some(result) => result,
3123        None => {
3124            // The curvature floor could not condition the operator within the
3125            // ceiling; hand the recoverable failure to the outer LM loop, which
3126            // re-forms the system at a heavier ridge.
3127            return Err(last_curvature_err.unwrap_or(ArrowSchurError::PcgFailed {
3128                reason: "unbounded Schur PCG negative curvature unresolved by curvature floor"
3129                    .to_string(),
3130            }));
3131        }
3132    };
3133    if sys.k <= PRECOND_ESCALATE_K_THRESHOLD || diag0.stopping_reason != PcgStopReason::MaxIter {
3134        return Ok((x0, diag0));
3135    }
3136    // Escalation tiers reuse the curvature-floored `effective_ridge` so the
3137    // operator they precondition is the SAME (PD-floored) one the Jacobi tier
3138    // settled on; a still-negative-curvature signal here is handed to the outer
3139    // LM loop (it only arises if the floored Jacobi tier merely ran out of
3140    // iterations yet a coarser preconditioner still finds an indefinite
3141    // direction — rare; the LM loop re-forms at a heavier ridge).
3142    let cluster =
3143        ClusterJacobiPreconditioner::from_arrow_schur(sys, htt_factors, effective_ridge, backend)?;
3144    let (x1, diag1) = run_pcg_with_preconditioner(
3145        sys,
3146        htt_factors,
3147        effective_ridge,
3148        rhs,
3149        |r| cluster.apply(r),
3150        pcg,
3151        trust,
3152        backend,
3153        gpu_matvec,
3154        metric_weights,
3155        resident.as_ref(),
3156    )?;
3157    if diag1.stopping_reason != PcgStopReason::MaxIter {
3158        return Ok((x1, diag1));
3159    }
3160    let schwarz = AdditiveSchwarzPreconditioner::from_arrow_schur(
3161        sys,
3162        htt_factors,
3163        effective_ridge,
3164        backend,
3165        1,
3166    )?;
3167    let (x2, diag2) = run_pcg_with_preconditioner(
3168        sys,
3169        htt_factors,
3170        effective_ridge,
3171        rhs,
3172        |r| schwarz.apply(r),
3173        pcg,
3174        trust,
3175        backend,
3176        gpu_matvec,
3177        metric_weights,
3178        resident.as_ref(),
3179    )?;
3180    if diag2.stopping_reason != PcgStopReason::MaxIter {
3181        return Ok((x2, diag2));
3182    }
3183    // Final tier — diagonal-assembled additive Schwarz (#299), the cheap-apply
3184    // Schwarz variant. When the dense-block AdditiveSchwarz still ran out of
3185    // iterations its O(Σ b_k²) apply may have throttled the iteration budget on
3186    // a wide subdomain; the diag-assembled variant keeps Schwarz's overlapping
3187    // local-inverse conditioning but applies in O(K), so it can take more CG
3188    // iterations within the same wall budget. Same overlap (1) and same
3189    // curvature-floored ridge as the dense-block tier.
3190    let diag_schwarz = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
3191        sys,
3192        htt_factors,
3193        effective_ridge,
3194        backend,
3195        1,
3196    )?;
3197    let (x3, diag3) = run_pcg_with_preconditioner(
3198        sys,
3199        htt_factors,
3200        effective_ridge,
3201        rhs,
3202        |r| diag_schwarz.apply(r),
3203        pcg,
3204        trust,
3205        backend,
3206        gpu_matvec,
3207        metric_weights,
3208        resident.as_ref(),
3209    )?;
3210    if diag3.stopping_reason != PcgStopReason::MaxIter {
3211        return Ok((x3, diag3));
3212    }
3213    // Richest tier — level-0 incomplete Cholesky (#299). ClusterJacobi keeps the
3214    // full DENSE Cholesky of each component (so on a single large connected
3215    // component it fills the whole `b×b` factor and its `O(b²)` apply throttles
3216    // the CG iteration budget), while the diagonal/Schwarz tiers drop most
3217    // inter-block coupling. IC(0) keeps the component's full structural coupling
3218    // but only the level-0 (no-fill) pattern, so its sparse triangular apply is
3219    // `O(nnz(S[C,C]))` — it can take more CG iterations within the same wall
3220    // budget AND conditions the off-diagonal coupling the cheap tiers discard.
3221    // Last in the ladder so it is only paid when every cheaper tier stalled.
3222    let ic0 = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
3223        sys,
3224        htt_factors,
3225        effective_ridge,
3226        backend,
3227    )?;
3228    let (x4, diag4) = run_pcg_with_preconditioner(
3229        sys,
3230        htt_factors,
3231        effective_ridge,
3232        rhs,
3233        |r| ic0.apply(r),
3234        pcg,
3235        trust,
3236        backend,
3237        gpu_matvec,
3238        metric_weights,
3239        resident.as_ref(),
3240    )?;
3241    // All five preconditioner tiers (Jacobi -> ClusterJacobi -> AdditiveSchwarz
3242    // -> DiagAssembledSchwarz -> BlockIncompleteCholesky) exhausted their
3243    // iteration budget without driving the residual below tolerance. Returning a
3244    // truncated iterate as `Ok` would feed an arbitrarily-large-residual step
3245    // into the Newton driver, where the PCG diagnostics are discarded. Surface a
3246    // recoverable failure instead so `solve_with_lm_escalation_inner` escalates
3247    // the proximal ridge: better conditioning is precisely what a stalled PCG on
3248    // an ill-conditioned reduced system needs.
3249    if diag4.stopping_reason == PcgStopReason::MaxIter {
3250        return Err(ArrowSchurError::PcgFailed {
3251            reason: format!(
3252                "Schur PCG exhausted all preconditioner tiers (Jacobi, ClusterJacobi, \
3253                 AdditiveSchwarz, DiagAssembledSchwarz, BlockIncompleteCholesky) at MaxIter; \
3254                 final relative residual = {:e}",
3255                diag4.final_relative_residual
3256            ),
3257        });
3258    }
3259    Ok((x4, diag4))
3260}
3261
3262/// Run Steihaug-CG with a generic preconditioner closure.
3263/// Routes matvec through GPU when `gpu_matvec` is set.
3264pub(crate) fn run_pcg_with_preconditioner<ApplyPrec, B: BatchedBlockSolver + Sync>(
3265    sys: &ArrowSchurSystem,
3266    htt_factors: &ArrowFactorSlab,
3267    ridge_beta: f64,
3268    rhs: &Array1<f64>,
3269    apply_prec: ApplyPrec,
3270    pcg: &ArrowPcgOptions,
3271    trust: &ArrowTrustRegionOptions,
3272    backend: &B,
3273    gpu_matvec: Option<&GpuSchurMatvec>,
3274    metric_weights: Option<&MetricWeights>,
3275    resident: Option<&SaeResidentReducedSchur>,
3276) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3277where
3278    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3279{
3280    let max_iters = pcg.max_iterations.min(trust.max_iterations);
3281    let tol = pcg
3282        .relative_tolerance
3283        .max(trust.steihaug_relative_tolerance);
3284    if let Some(gpu_mv) = gpu_matvec {
3285        let gpu_mv = Arc::clone(gpu_mv);
3286        steihaug_cg(
3287            rhs,
3288            move |p, out| gpu_mv(p, out),
3289            apply_prec,
3290            max_iters,
3291            tol,
3292            trust.radius,
3293            metric_weights,
3294        )
3295    } else {
3296        steihaug_cg(
3297            rhs,
3298            |p, out| schur_matvec(sys, htt_factors, ridge_beta, p, out, backend, resident),
3299            apply_prec,
3300            max_iters,
3301            tol,
3302            trust.radius,
3303            metric_weights,
3304        )
3305    }
3306}
3307
3308#[derive(Debug, Clone, Copy)]
3309pub(crate) struct IdentityPreconditioner;
3310
3311impl IdentityPreconditioner {
3312    pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
3313        r.clone()
3314    }
3315}
3316
3317pub(crate) fn steihaug_dense_system(
3318    schur: &Array2<f64>,
3319    rhs: &Array1<f64>,
3320    preconditioner: &IdentityPreconditioner,
3321    pcg: &ArrowPcgOptions,
3322    trust: &ArrowTrustRegionOptions,
3323    metric_weights: Option<&MetricWeights>,
3324) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3325    steihaug_cg(
3326        rhs,
3327        |p, out| dense_matvec(schur, p, out),
3328        |r| preconditioner.apply(r),
3329        pcg.max_iterations,
3330        pcg.relative_tolerance,
3331        trust.radius,
3332        metric_weights,
3333    )
3334}
3335
3336pub(crate) fn steihaug_cg<MatVec, ApplyPrec>(
3337    rhs: &Array1<f64>,
3338    mut matvec: MatVec,
3339    mut apply_preconditioner: ApplyPrec,
3340    max_iterations: usize,
3341    relative_tolerance: f64,
3342    trust_radius: f64,
3343    metric_weights: Option<&MetricWeights>,
3344) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3345where
3346    MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
3347    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3348{
3349    let n = rhs.len();
3350    if let Some(weights) = metric_weights {
3351        assert_eq!(
3352            weights.len(),
3353            n,
3354            "Steihaug-CG metric weight length must match solve dimension"
3355        );
3356    }
3357    let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
3358        trust_radius
3359    } else {
3360        f64::INFINITY
3361    };
3362    let rhs_norm = metric_norm(rhs.view(), metric_weights);
3363    if rhs_norm == 0.0 {
3364        return Ok((Array1::<f64>::zeros(n), PcgDiagnostics::default()));
3365    }
3366    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(PCG_ABSOLUTE_TOLERANCE_FLOOR);
3367    let mut x = Array1::<f64>::zeros(n);
3368    let mut r = rhs.clone();
3369    let mut z = apply_preconditioner(&r);
3370    let mut diag = PcgDiagnostics {
3371        precond_apply_calls: 1,
3372        ..PcgDiagnostics::default()
3373    };
3374    let mut p = z.clone();
3375    let mut rz = metric_dot(&r, &z, metric_weights);
3376    if rz <= 0.0 || !rz.is_finite() {
3377        if radius.is_finite() {
3378            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3379            diag.stopping_reason = PcgStopReason::TrustRegion;
3380            return Ok((step_to_trust_boundary(&x, &r, radius, metric_weights), diag));
3381        }
3382        // Unbounded (radius = ∞) non-positive preconditioned residual: the
3383        // reduced Schur is indefinite at the very first direction. Surface the
3384        // typed curvature-floor signal so `steihaug_pcg_auto` floors the
3385        // operator minimally and retries, instead of failing into a global
3386        // `ridge_β` ramp. `rz = rᵀM⁻¹r` is a preconditioner-metric curvature;
3387        // report it with the residual norm² as the direction scale.
3388        return Err(ArrowSchurError::UnboundedNegativeCurvature {
3389            curvature: rz,
3390            direction_norm_sq: metric_dot(&r, &r, metric_weights),
3391        });
3392    }
3393    if metric_norm(r.view(), metric_weights) <= tol {
3394        diag.final_relative_residual = 0.0;
3395        diag.stopping_reason = PcgStopReason::Converged;
3396        return Ok((x, diag));
3397    }
3398    let mut ap = Array1::<f64>::zeros(n);
3399    // Reused candidate scratch — avoid per-iteration clone of x.
3400    let mut candidate = Array1::<f64>::zeros(n);
3401    for _ in 0..max_iterations {
3402        matvec(&p, &mut ap);
3403        diag.matvec_calls += 1;
3404        diag.iterations += 1;
3405        let pap = metric_dot(&p, &ap, metric_weights);
3406        if pap <= 0.0 || !pap.is_finite() {
3407            if radius.is_finite() {
3408                diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3409                diag.stopping_reason = PcgStopReason::TrustRegion;
3410                return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3411            }
3412            // Unbounded negative curvature `pᵀSp ≤ 0`: the reduced Schur is
3413            // indefinite along `p` (the #1026 co-collapse direction). Surface
3414            // the typed signal carrying `pᵀSp` and `‖p‖²` so the caller floors
3415            // the operator by the minimal ridge `δ = |pᵀSp|/‖p‖²` (which makes
3416            // `pᵀ(S+δI)p = 0⁺`) plus a margin, and retries.
3417            return Err(ArrowSchurError::UnboundedNegativeCurvature {
3418                curvature: pap,
3419                direction_norm_sq: metric_dot(&p, &p, metric_weights),
3420            });
3421        }
3422        let alpha = rz / pap;
3423        for i in 0..n {
3424            candidate[i] = x[i] + alpha * p[i];
3425        }
3426        if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
3427            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3428            diag.stopping_reason = PcgStopReason::TrustRegion;
3429            return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3430        }
3431        x.assign(&candidate);
3432        for i in 0..n {
3433            r[i] -= alpha * ap[i];
3434        }
3435        if metric_norm(r.view(), metric_weights) <= tol {
3436            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3437            diag.stopping_reason = PcgStopReason::Converged;
3438            return Ok((x, diag));
3439        }
3440        z = apply_preconditioner(&r);
3441        diag.precond_apply_calls += 1;
3442        let rz_next = metric_dot(&r, &z, metric_weights);
3443        if rz_next <= 0.0 || !rz_next.is_finite() {
3444            return Err(ArrowSchurError::PcgFailed {
3445                reason: "non-positive or non-finite PCG residual".to_string(),
3446            });
3447        }
3448        let beta = rz_next / rz;
3449        for i in 0..n {
3450            p[i] = z[i] + beta * p[i];
3451        }
3452        rz = rz_next;
3453    }
3454    diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3455    diag.stopping_reason = PcgStopReason::MaxIter;
3456    Ok((x, diag))
3457}
3458
3459pub(crate) fn step_to_trust_boundary(
3460    x: &Array1<f64>,
3461    p: &Array1<f64>,
3462    radius: f64,
3463    metric_weights: Option<&MetricWeights>,
3464) -> Array1<f64> {
3465    let pp = metric_dot(p, p, metric_weights);
3466    if pp == 0.0 {
3467        return x.clone();
3468    }
3469    let xp = metric_dot(x, p, metric_weights);
3470    let xx = metric_dot(x, x, metric_weights);
3471    let disc = (xp * xp + pp * (radius * radius - xx)).max(0.0);
3472    let tau = (-xp + disc.sqrt()) / pp;
3473    let mut out = x.clone();
3474    for i in 0..out.len() {
3475        out[i] += tau * p[i];
3476    }
3477    out
3478}
3479
3480pub(crate) fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
3481    let n = a.nrows();
3482    for i in 0..n {
3483        let mut acc = 0.0;
3484        for j in 0..n {
3485            acc += a[[i, j]] * x[j];
3486        }
3487        out[i] = acc;
3488    }
3489}
3490
3491pub(crate) fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
3492    let mut acc = 0.0;
3493    for i in 0..a.len() {
3494        acc += a[i] * b[i];
3495    }
3496    acc
3497}
3498
3499pub(crate) fn metric_dot(
3500    a: &Array1<f64>,
3501    b: &Array1<f64>,
3502    metric_weights: Option<&MetricWeights>,
3503) -> f64 {
3504    assert_eq!(a.len(), b.len());
3505    match metric_weights {
3506        Some(weights) => {
3507            assert_eq!(weights.len(), a.len());
3508            let mut acc = 0.0;
3509            for i in 0..a.len() {
3510                acc += weights[i] * a[i] * b[i];
3511            }
3512            acc
3513        }
3514        None => dot(a, b),
3515    }
3516}
3517
3518pub(crate) fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
3519    let mut acc = 0.0;
3520    match metric_weights {
3521        Some(weights) => {
3522            assert_eq!(weights.len(), v.len());
3523            for i in 0..v.len() {
3524                acc += weights[i] * v[i] * v[i];
3525            }
3526        }
3527        None => {
3528            for x in v.iter() {
3529                acc += x * x;
3530            }
3531        }
3532    }
3533    acc.sqrt()
3534}
3535
3536pub(crate) fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
3537    let n = a.nrows().min(a.ncols());
3538    for i in 0..n {
3539        for j in 0..i {
3540            let v = 0.5 * (a[[i, j]] + a[[j, i]]);
3541            a[[i, j]] = v;
3542            a[[j, i]] = v;
3543        }
3544    }
3545}
3546
3547/// Errors raised by [`ArrowSchurSystem::solve`].
3548#[derive(Debug, Clone)]
3549pub enum ArrowSchurError {
3550    /// A per-row `H_tt^(i)` block was not positive-definite at the
3551    /// supplied ridge. Indicates an under-regularized latent block —
3552    /// typically a gauge-free fit without an identifiability penalty.
3553    PerRowFactorFailed { row: usize, reason: String },
3554    /// A per-row `H_tt^(i)` block factored, but the Cholesky factor failed
3555    /// the safe-inversion guard for the Schur reduction. This can be either
3556    /// an excessive diagonal-ratio condition-number estimate or a numerically
3557    /// tiny pivot relative to the row block scale. Cholesky technically
3558    /// succeeded, but the inverse used in
3559    /// `S = H_ββ − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)` is contaminated
3560    /// by spectral terms on the order of `κ_i`; functionally
3561    /// equivalent to a PSD-fail for Schur stability. The LM outer
3562    /// wrapper escalates `ridge_t` identically to `PerRowFactorFailed`.
3563    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
3564    /// The Schur complement was not positive-definite. Indicates a
3565    /// near-collinear decoder or a degenerate weighting; the LM outer
3566    /// wrapper should escalate `ridge_beta` and retry.
3567    SchurFactorFailed { reason: String },
3568    /// The BA inexact-step PCG solve failed before producing a usable
3569    /// Steihaug trust-region step.
3570    PcgFailed { reason: String },
3571    /// The UNBOUNDED (trust-radius = ∞) Schur PCG encountered negative
3572    /// curvature `pᵀSp ≤ 0` (or a non-positive preconditioned residual): the
3573    /// reduced Schur is indefinite, the #1026 K≥4 co-collapse signature where
3574    /// a near-singular per-row `H_tt` over-subtracts `S`. With no trust radius
3575    /// there is no boundary to step to, so CG cannot proceed. `curvature` is
3576    /// the offending `pᵀSp` and `direction_norm_sq` the `‖p‖²` of the
3577    /// negative-curvature direction; the caller floors the operator with the
3578    /// minimal ridge `δ = (|curvature|/‖p‖² )·(1+ε)` that restores positive
3579    /// curvature along `p` and retries (matrix-free analogue of the dense
3580    /// `spectral_pd_floored_schur`), rather than blindly inflating `ridge_β`.
3581    UnboundedNegativeCurvature {
3582        curvature: f64,
3583        direction_norm_sq: f64,
3584    },
3585    /// Adaptive proximal damping could not produce an Armijo-accepted
3586    /// nonlinear step.
3587    AdaptiveCorrectionFailed { reason: String },
3588}
3589
3590impl std::fmt::Display for ArrowSchurError {
3591    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3592        match self {
3593            ArrowSchurError::PerRowFactorFailed { row, reason } => write!(
3594                f,
3595                "arrow-Schur: per-row H_tt^({row}) Cholesky failed: {reason}"
3596            ),
3597            ArrowSchurError::PerRowFactorIllConditioned {
3598                row,
3599                kappa_estimate,
3600            } => write!(
3601                f,
3602                "arrow-Schur: per-row H_tt^({row}) Cholesky succeeded but failed \
3603                 the safe-inversion guard (kappa_estimate={kappa_estimate:e}); \
3604                 Schur reduction would be numerically contaminated"
3605            ),
3606            ArrowSchurError::SchurFactorFailed { reason } => {
3607                write!(f, "arrow-Schur: Schur complement Cholesky failed: {reason}")
3608            }
3609            ArrowSchurError::PcgFailed { reason } => {
3610                write!(f, "arrow-Schur: Schur PCG failed: {reason}")
3611            }
3612            ArrowSchurError::UnboundedNegativeCurvature {
3613                curvature,
3614                direction_norm_sq,
3615            } => write!(
3616                f,
3617                "arrow-Schur: unbounded Schur PCG hit negative curvature pᵀSp={curvature:e} \
3618                 (‖p‖²={direction_norm_sq:e}); reduced Schur is indefinite (co-collapse), \
3619                 retry with a curvature-floor ridge"
3620            ),
3621            ArrowSchurError::AdaptiveCorrectionFailed { reason } => {
3622                write!(
3623                    f,
3624                    "arrow-Schur: adaptive proximal correction failed: {reason}"
3625                )
3626            }
3627        }
3628    }
3629}
3630
3631impl std::error::Error for ArrowSchurError {}
3632
3633// ---------------------------------------------------------------------------
3634// Cholesky helpers (kept local to avoid a new public-API dependency on the
3635// linalg crate. The systems here are tiny per-row (d × d, d ∈ {1..16}) and
3636// modest at the Schur level (K × K, K ∈ {basis size}). For production SAE
3637// scales the Schur factor should switch to faer; this module's `cholesky_lower`
3638// is the obvious replacement site.)
3639// ---------------------------------------------------------------------------
3640
3641pub(crate) fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
3642    let n = a.nrows();
3643    if a.ncols() != n {
3644        return Err(format!("cholesky_lower: non-square {}×{}", n, a.ncols()));
3645    }
3646    if let Some((idx, _)) = a.iter().enumerate().find(|(_, v)| !v.is_finite()) {
3647        return Err(format!(
3648            "cholesky_lower: non-finite entry at linear index {idx}"
3649        ));
3650    }
3651
3652    let mut maybe_device = a.clone();
3653    if gam_gpu::try_cholesky_lower_inplace(&mut maybe_device).is_some() {
3654        return Ok(maybe_device);
3655    }
3656
3657    let mut l = Array2::<f64>::zeros((n, n));
3658    for i in 0..n {
3659        for j in 0..=i {
3660            let mut sum = a[[i, j]];
3661            for kk in 0..j {
3662                sum -= l[[i, kk]] * l[[j, kk]];
3663            }
3664            if i == j {
3665                if !sum.is_finite() || sum <= 0.0 {
3666                    return Err(format!(
3667                        "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
3668                    ));
3669                }
3670                l[[i, j]] = sum.sqrt();
3671            } else {
3672                l[[i, j]] = sum / l[[j, j]];
3673            }
3674        }
3675    }
3676    Ok(l)
3677}