gam_solve/arrow_schur/reduced_solve.rs
1//! The reduced `K x K` shared-system solve: dense Schur assembly (direct and
2//! square-root BA), the Schur matvec, the Jacobi/cluster/Schwarz
3//! preconditioners, Steihaug-PCG, and the [`ArrowSchurError`] type.
4
5use super::*;
6
7/// Host budget for a dense reduced Schur `k × k` f64 matrix (#1017). Above this
8/// the dense assembly is refused with a loud `SchurFactorFailed` rather than
9/// OOM-killing the host. 8 GiB ⇒ `k ≈ 32768`; every currently-feasible SAE border
10/// (k ≤ 5120 ⇒ 0.2 GiB) is well under it, while the qwen LLM border (k = 98304 ⇒
11/// 77 GiB) is correctly rejected as matrix-free-only.
12pub(crate) const DENSE_SCHUR_BYTES_BUDGET: u128 = 8 * 1024 * 1024 * 1024;
13
14/// Reduce one contiguous device tile's rows into a private `-Σ leftᵀ·right`
15/// partial (`k×k`).
16///
17/// The tile stacks its per-row `left_i` / `right_i` factors (each `d×k`) into
18/// two `(Σ_i d_i × k)` matrices and tries a single per-ordinal `AᵀB` device
19/// GEMM (`gam_gpu::try_fast_atb_on_ordinal`), which runs on the device this
20/// worker thread already bound — one big GPU GEMM per tile rather than `n` small
21/// CPU ones. When the device primitive declines (no GPU, shape below policy,
22/// transient failure) the tile reduces with the exact CPU `block_gemm_subtract`
23/// loop, so the result is unchanged. The partial is negated so the caller's
24/// `schur += partial` reproduces the serial `schur -= Σ contribution`.
25pub(crate) fn tile_schur_partial<B: BatchedBlockSolver>(
26 sys: &ArrowSchurSystem,
27 htt_factors: &ArrowFactorSlab,
28 backend: &B,
29 kind: SchurReductionKind,
30 ordinal: usize,
31 range: Range<usize>,
32) -> Result<Array2<f64>, ArrowSchurError> {
33 let k = sys.k;
34
35 // Build the per-row contribution factors once; both the GPU stacked-GEMM
36 // and the CPU fallback consume them.
37 let mut factors: Vec<(Array2<f64>, Array2<f64>)> = Vec::with_capacity(range.len());
38 let mut total_d = 0usize;
39 for i in range.clone() {
40 let (left, right) = row_schur_contribution_factors(
41 sys,
42 i,
43 &sys.rows[i],
44 htt_factors.factor(i),
45 backend,
46 kind,
47 )?;
48 total_d += left.nrows();
49 factors.push((left, right));
50 }
51
52 // Stack into (total_d × k) left/right matrices for one device AᵀB GEMM on
53 // this tile's bound ordinal. `try_fast_atb_on_ordinal` returns leftᵀ·right
54 // (k×k); negate into the partial. At an SAE-shaped whole-fit tile with
55 // n=2000 rows, k=2048 shared columns, M=12 local rows per observation, and
56 // K=8 candidate/atom batches, the stacked GEMM is
57 // 2*(n*M)*k^2 = 201_326_592_000 flops per batch, or
58 // 1_610_612_736_000 flops across K=8, so the policy work gate is cleared
59 // even though the observation count is far below the old row floor.
60 if total_d > 0 && k > 0 {
61 let mut left_stack = Array2::<f64>::zeros((total_d, k));
62 let mut right_stack = Array2::<f64>::zeros((total_d, k));
63 let mut base = 0usize;
64 for (left, right) in &factors {
65 let di = left.nrows();
66 left_stack
67 .slice_mut(ndarray::s![base..base + di, ..])
68 .assign(left);
69 right_stack
70 .slice_mut(ndarray::s![base..base + di, ..])
71 .assign(right);
72 base += di;
73 }
74 if let Some(product) =
75 gam_gpu::try_fast_atb_on_ordinal(ordinal, left_stack.view(), right_stack.view())
76 {
77 return Ok(product.mapv(|v| -v));
78 }
79 }
80
81 // CPU fallback: exact per-row block_gemm_subtract into a zero-seeded partial.
82 let mut partial = Array2::<f64>::zeros((k, k));
83 for (left, right) in &factors {
84 backend.block_gemm_subtract(&mut partial, left, right);
85 }
86 Ok(partial)
87}
88
89/// Reduce the per-row Schur contributions `Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`
90/// out of `schur` (seeded with `H_ββ + ρ_β·I`).
91///
92/// The per-row contributions are independent — exactly the "sum over independent
93/// arrow-tip blocks" axis the device pool partitions. When more than one GPU is
94/// usable, [`gam_gpu::pool::balanced_partition`] splits the `0..n` rows into
95/// per-device contiguous tiles; each tile is reduced on its own scoped thread
96/// (binding that ordinal's context so the per-row GEMM-subtract offloads to its
97/// device) into a private `k×k` partial, and the partials are summed back into
98/// `schur` in tile order. The tiles are contiguous, ordered to cover `0..n`, and
99/// folded back in that same order, so within each tile the per-row accumulation
100/// order is preserved and the only departure from the serial loop is the
101/// inter-tile reassociation of the reduction sum — the established
102/// reduction-order equivalence the device pool already operates under, well
103/// inside the Newton solve's tolerance.
104///
105/// With a single device (or no GPU) the row loop runs serially in place, which
106/// is bit-for-bit the original behaviour.
107pub(crate) fn reduce_row_schur_contributions<B: BatchedBlockSolver + Sync>(
108 sys: &ArrowSchurSystem,
109 htt_factors: &ArrowFactorSlab,
110 backend: &B,
111 kind: SchurReductionKind,
112 schur: &mut Array2<f64>,
113) -> Result<(), ArrowSchurError> {
114 let n = sys.rows.len();
115 let k = sys.k;
116
117 let tiles = gam_gpu::device_runtime::GpuRuntime::global()
118 .map(|rt| gam_gpu::pool::balanced_partition(rt, n))
119 .filter(|tiles| tiles.len() > 1);
120
121 let Some(tiles) = tiles else {
122 // Single-device / CPU. The per-row contributions `-Σ_i leftᵀ·right` fold
123 // into the `k×k` `schur` independently — the same dense-assembly axis the
124 // multi-GPU tile path partitions, and the dense-Direct analog of the
125 // per-row matvec / streaming `accumulate_chunk` loops already parallelized
126 // for #1017. At the SAE Direct-solve shape (`n` in the thousands, wide
127 // border `k`) this O(n·d·k²) reduction is the dense assembly's whole cost
128 // and was the last serial CPU step on the dense-Schur build.
129 //
130 // Fan it across rayon over fixed row chunks: each chunk reduces its rows
131 // (in row order) into a private zero-seeded `k×k` partial, then the
132 // partials are folded into `schur` in CHUNK order. The per-chunk row order
133 // and the inter-chunk fold order are both fixed independent of thread
134 // scheduling, so the f64 reduction is **bit-identical run-to-run** (the
135 // #1017 determinism gate). NOTE: bit-identical run-to-run does NOT make
136 // it bit-identical to the in-place serial loop — the chunk-boundary
137 // reassociation of the reduction sum is a genuine f64 departure (the
138 // established equivalence `accumulate_chunk` / the per-row matvec operate
139 // under, well inside the Newton solve's tolerance). It bounds candidate-
140 // to-candidate drift to that reassociation margin, so the criterion
141 // ranking is stable EXCEPT for candidates tying within the margin, where
142 // the winner can flip; it is not an exact no-move guarantee (#1211). For
143 // an exact-order guarantee, take the serial path. Stay in-place serial
144 // below the row floor and when already inside a rayon worker (the topology
145 // race fans candidates with `run_topology_race_parallel`) to avoid
146 // nested-rayon oversubscription — the same guard the matvec uses.
147 let n_rows = sys.rows.len();
148 let parallel =
149 n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
150 if parallel {
151 use rayon::prelude::*;
152 const CHUNK: usize = 64;
153 let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = (0..n_rows)
154 .into_par_iter()
155 .chunks(CHUNK)
156 .map(|idxs| {
157 let mut partial = Array2::<f64>::zeros((k, k));
158 for i in idxs {
159 subtract_row_schur_contribution(
160 sys,
161 i,
162 &sys.rows[i],
163 htt_factors.factor(i),
164 backend,
165 kind,
166 &mut partial,
167 )?;
168 }
169 Ok(partial)
170 })
171 .collect();
172 // Deterministic ordered fold: chunk partials hold `-Σ contribution`
173 // over their rows, so `schur += partial` reproduces the serial
174 // `schur -= Σ contribution` in fixed (chunk, a, b) order.
175 for partial in &partials? {
176 for a in 0..k {
177 for b in 0..k {
178 schur[[a, b]] += partial[[a, b]];
179 }
180 }
181 }
182 return Ok(());
183 }
184 // Serial in-place reduction (original order) — bit-for-bit reference.
185 for (i, row) in sys.rows.iter().enumerate() {
186 subtract_row_schur_contribution(
187 sys,
188 i,
189 row,
190 htt_factors.factor(i),
191 backend,
192 kind,
193 schur,
194 )?;
195 }
196 return Ok(());
197 };
198
199 // Multi-GPU: one private `-Σ leftᵀ·right` partial per contiguous device
200 // tile. Each tile runs on its own scoped worker thread that binds its
201 // ordinal's context and issues a single stacked AᵀB GEMM on that device, so
202 // the tiles' GEMMs overlap across the pool. Folding the partials back into
203 // the H_ββ-seeded `schur` reproduces the serial reduction (up to inter-tile
204 // reassociation).
205 let partials: Result<Vec<Array2<f64>>, ArrowSchurError> = std::thread::scope(|scope| {
206 let handles: Vec<_> = tiles
207 .iter()
208 .map(|(ordinal, range)| {
209 let ordinal = *ordinal;
210 let range = range.clone();
211 scope.spawn(move || {
212 // Bind this ordinal's CUDA context on this worker thread so
213 // the per-row GPU GEMM shims issued from `tile_schur_partial`
214 // offload to that device. A missing context or bind failure
215 // is intentionally consumed without escalation — the shims
216 // no-op back to CPU and the math is unchanged. Off Linux
217 // `GpuRuntime::global()` is always `None`, so this branch
218 // is unreachable and the bind is omitted entirely.
219 #[cfg(target_os = "linux")]
220 {
221 if let Some(ctx) = gam_gpu::device_runtime::cuda_context_for(ordinal) {
222 if ctx.bind_to_thread().is_err() {
223 // Fall through: this tile reduces on the CPU.
224 }
225 }
226 }
227 tile_schur_partial(sys, htt_factors, backend, kind, ordinal, range)
228 })
229 })
230 .collect();
231 handles
232 .into_iter()
233 .map(|handle| {
234 handle
235 .join()
236 .map_err(|_| ArrowSchurError::SchurFactorFailed {
237 reason: "schur-reduction tile thread panicked".to_string(),
238 })?
239 })
240 .collect()
241 });
242 let partials = partials?;
243
244 // Fold partials into `schur` in tile order (contiguous, covering 0..n) so
245 // the per-tile and inter-tile accumulation order is the row order; each
246 // partial holds `-Σ contribution` over its rows, so `schur += partial`
247 // reproduces `schur -= Σ contribution`.
248 for partial in &partials {
249 for a in 0..k {
250 for b in 0..k {
251 schur[[a, b]] += partial[[a, b]];
252 }
253 }
254 }
255 Ok(())
256}
257
258pub(crate) fn build_dense_schur_direct<B: BatchedBlockSolver + Sync>(
259 sys: &ArrowSchurSystem,
260 htt_factors: &ArrowFactorSlab,
261 ridge_beta: f64,
262 backend: &B,
263) -> Result<Array2<f64>, ArrowSchurError> {
264 let k = sys.k;
265 // Materialise H_ββ via the BetaPenaltyOp trait (#296): DensePenaltyOp
266 // for the legacy dense path, structured ops for SAE / Kronecker smooths.
267 let op = sys.effective_penalty_op();
268 if op.dim() != k {
269 return Err(ArrowSchurError::SchurFactorFailed {
270 reason: "Direct BA requires a K×K shared H_ββ penalty operator".to_string(),
271 });
272 }
273 // Fail LOUD, never OOM-kill (#1017): the dense reduced Schur is `k × k` f64.
274 // At SAE LLM borders (qwen `k = 98304` ⇒ 77 GiB) materialising it would crash
275 // the host. The matrix-free device PCG already solves the *step* without it
276 // (`try_device_arrow_direct_sae_pcg`); only the joint-Hessian log-det still
277 // routes here. A matrix-free determinant-lemma log-det (the proper follow-up)
278 // is not yet wired, so refuse the allocation with an actionable error rather
279 // than degrading silently into an OOM. The budget is generous so every
280 // currently-feasible border (k ≤ 5120 ⇒ 0.2 GiB) is unaffected.
281 let dense_bytes = (k as u128).saturating_mul(k as u128).saturating_mul(8);
282 if dense_bytes > DENSE_SCHUR_BYTES_BUDGET {
283 return Err(ArrowSchurError::SchurFactorFailed {
284 reason: format!(
285 "dense reduced Schur is {k}×{k} f64 = {} MiB, exceeding the {} MiB host budget; \
286 this border is matrix-free-only (the device PCG solves the step without the dense \
287 Schur) and a matrix-free determinant-lemma log-det is the required follow-up",
288 dense_bytes / (1024 * 1024),
289 DENSE_SCHUR_BYTES_BUDGET / (1024 * 1024),
290 ),
291 });
292 }
293 let mut schur = op.to_dense();
294 for j in 0..k {
295 schur[[j, j]] += ridge_beta;
296 }
297 reduce_row_schur_contributions(
298 sys,
299 htt_factors,
300 backend,
301 SchurReductionKind::Direct,
302 &mut schur,
303 )?;
304 symmetrize_upper_from_lower(&mut schur);
305 Ok(schur)
306}
307
308pub(crate) fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver + Sync>(
309 sys: &ArrowSchurSystem,
310 htt_factors: &ArrowFactorSlab,
311 ridge_beta: f64,
312 backend: &B,
313) -> Result<Array2<f64>, ArrowSchurError> {
314 let k = sys.k;
315 // Materialise H_ββ via the BetaPenaltyOp trait (#296).
316 let op = sys.effective_penalty_op();
317 if op.dim() != k {
318 return Err(ArrowSchurError::SchurFactorFailed {
319 reason: "Square-Root BA direct solve requires a K×K shared H_ββ penalty operator"
320 .to_string(),
321 });
322 }
323 let mut schur = op.to_dense();
324 for j in 0..k {
325 schur[[j, j]] += ridge_beta;
326 }
327 reduce_row_schur_contributions(
328 sys,
329 htt_factors,
330 backend,
331 SchurReductionKind::SqrtBa,
332 &mut schur,
333 )?;
334 symmetrize_upper_from_lower(&mut schur);
335 Ok(schur)
336}
337
338/// Certified Carson–Higham mixed-precision solve of the reduced dense Schur
339/// system `S Δβ = rhs` (#1014), specialized to the streaming/residency path.
340///
341/// Returns `Some(Δβ)` when certified mixed precision is enabled AND the κ gate
342/// admits the f32 factorization AND the f64 backward-error certificate closes;
343/// `None` in every other case so the caller falls back to the exact f64
344/// triangular solve. The f64 `factor` (whose diagonal carries the exact
345/// `log|S|`) is supplied by the caller and never re-derived here — the logdet
346/// the evidence path reads stays f64 by construction.
347///
348/// Method: store the f64 Cholesky factor as f32, solve in f32, then refine with
349/// residuals `r = rhs − S·x` computed in f64 against the f64 `S`. With
350/// `κ(S)·u_f32 < margin` the refinement contracts at rate `κ·u`, and the
351/// terminating certificate is the normwise backward error
352/// `‖r‖∞ / (‖S‖∞‖x‖∞ + ‖rhs‖∞) ≤ tol`. A non-decreasing residual or an
353/// unmet certificate after `max_refinement_steps` returns `None`.
354pub(crate) fn mixed_precision_reduced_beta(
355 schur: &Array2<f64>,
356 factor: &Array2<f64>,
357 rhs: &Array1<f64>,
358 options: &ArrowSolveOptions,
359) -> Option<Array1<f64>> {
360 let ArrowSolvePrecisionPolicy::CertifiedMixed {
361 max_refinement_steps,
362 residual_relative_tolerance,
363 kappa_unit_roundoff_margin,
364 } = options.solve_precision
365 else {
366 return None;
367 };
368 // The reduced-system mixed-precision path is the dense reduced solve only;
369 // a trust-region-truncated step takes the Steihaug branch below in f64.
370 if options.trust_region.radius.is_finite() {
371 return None;
372 }
373 let n = schur.nrows();
374 if n == 0 {
375 return None;
376 }
377
378 // κ gate: the f32 factorization is only admissible when κ(S)·u_f32 leaves
379 // the refinement contraction headroom the certificate needs.
380 let kappa = cholesky_factor_kappa_estimate(factor);
381 if !kappa.is_finite() || kappa * F32_UNIT_ROUNDOFF >= kappa_unit_roundoff_margin {
382 return None;
383 }
384
385 let factor_f32 = factor.mapv(|v| v as f32);
386 let s_inf = matrix_inf_norm(schur);
387 let rhs_inf = rhs.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
388 let certificate_tol = residual_relative_tolerance
389 .max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
390
391 // f32 solve of the seed system, then f64-residual refinement steps.
392 let mut x = cholesky_solve_lower_f32(&factor_f32, &rhs.mapv(|v| v as f32)).mapv(|v| v as f64);
393 let mut last_residual = f64::INFINITY;
394 for _ in 0..=max_refinement_steps {
395 // Residual r = rhs − S·x in f64 against the f64 model.
396 let sx = schur.dot(&x);
397 let mut r = rhs.clone();
398 r -= &sx;
399 let r_inf = r.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
400 let x_inf = x.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
401 let denom = s_inf * x_inf + rhs_inf;
402 let backward_error = if denom > 0.0 { r_inf / denom } else { 0.0 };
403 if backward_error <= certificate_tol {
404 return Some(x);
405 }
406 // Refinement must make monotone progress, else hand back to f64.
407 if !(r_inf < last_residual) {
408 return None;
409 }
410 last_residual = r_inf;
411 // Correction solve in f32 against the f32 factor: S·δ = r.
412 let delta = cholesky_solve_lower_f32(&factor_f32, &r.mapv(|v| v as f32)).mapv(|v| v as f64);
413 x += δ
414 }
415 None
416}
417
418/// Infinity norm (max absolute row sum) of a dense matrix.
419pub(crate) fn matrix_inf_norm(a: &Array2<f64>) -> f64 {
420 let mut max_row = 0.0_f64;
421 for row in a.rows() {
422 let s: f64 = row.iter().map(|v| v.abs()).sum();
423 if s > max_row {
424 max_row = s;
425 }
426 }
427 max_row
428}
429
430/// Spectral positive-definiteness floor for the reduced Schur complement
431/// `S` (#1026 SAE co-collapse SOLVE-path cure).
432///
433/// Reached only after the genuine Cholesky of `S` has REFUSED it (an indefinite
434/// reduced Schur: collapsed atoms drive a per-row `H_tt` near-singular, so the
435/// accumulated `Σ_i H_tβᵀ (H_tt)⁻¹ H_tβ` over-subtracts `H_ββ + ridge_β·I` into a
436/// matrix with a non-positive eigenvalue). Rather than reject and let the LM
437/// loop inflate `ridge_β` over EVERY β direction (the #1026 "crawl"), we
438/// symmetric-eigendecompose `S` and clamp every eigenvalue UP to
439/// `floor·max(λ)`. This is Levenberg–Marquardt restricted to exactly the
440/// indefinite/collapsed subspace: a well-separated positive direction
441/// (`λ ≫ floor·max λ`) keeps its EXACT eigenvalue (`λ.max(floor·max λ) = λ`), so
442/// the Newton step in the healthy β subspace is unchanged, while only the
443/// collapsed directions get the minimal positive stiffness needed for a PD
444/// solve. Returns the floored, symmetric, strictly-PD matrix, or `None` if `S`
445/// has no usable scale (non-finite / all-zero spectrum), in which case the
446/// caller keeps the strict refusal.
447///
448/// Mirrors the per-row evidence floor
449/// [`super::factorization::factor_spectral_deflated_evidence_row`]; the only
450/// difference is the floored VALUE — a small positive `floor·max λ` (Tikhonov,
451/// for an accurate solve) here, vs unit stiffness `+1` (`log 1 = 0`) there (for
452/// the quotient log-det).
453pub(crate) fn spectral_pd_floored_schur(
454 schur: &Array2<f64>,
455 relative_floor: f64,
456) -> Option<Array2<f64>> {
457 let n = schur.nrows();
458 if n == 0 || schur.ncols() != n || !(relative_floor.is_finite() && relative_floor > 0.0) {
459 return None;
460 }
461 // Symmetrise defensively (the assembled Schur is symmetric up to reduction
462 // order; the eig routine assumes exact symmetry).
463 let mut sym = Array2::<f64>::zeros((n, n));
464 for i in 0..n {
465 for j in 0..n {
466 let v = 0.5 * (schur[[i, j]] + schur[[j, i]]);
467 if !v.is_finite() {
468 return None;
469 }
470 sym[[i, j]] = v;
471 }
472 }
473 let (evals, evecs) = sym.eigh(Side::Lower).ok()?;
474 let max_abs = evals.iter().fold(
475 0.0_f64,
476 |acc, &v| if v.is_finite() { acc.max(v.abs()) } else { acc },
477 );
478 if !(max_abs.is_finite() && max_abs > 0.0) {
479 return None;
480 }
481 let floor = relative_floor * max_abs;
482 // Reconstruct `Σ_i max(λ_i, floor) v_i v_iᵀ`: clamp every eigenvalue UP to a
483 // strictly positive `floor`. Healthy positive directions (`λ ≫ floor`) are
484 // untouched; non-positive / tiny collapsed directions are lifted to exactly
485 // `floor`. The result is symmetric PD by construction.
486 let mut conditioned = Array2::<f64>::zeros((n, n));
487 for eig_idx in 0..evals.len() {
488 let lambda = evals[eig_idx];
489 let lambda_floored = if lambda.is_finite() {
490 lambda.max(floor)
491 } else {
492 floor
493 };
494 for i in 0..n {
495 let vi = evecs[[i, eig_idx]];
496 if vi == 0.0 {
497 continue;
498 }
499 for j in 0..n {
500 conditioned[[i, j]] += lambda_floored * vi * evecs[[j, eig_idx]];
501 }
502 }
503 }
504 Some(conditioned)
505}
506
507pub(crate) fn solve_dense_reduced_system(
508 schur: &Array2<f64>,
509 rhs_beta: &Array1<f64>,
510 options: &ArrowSolveOptions,
511 metric_weights: Option<&MetricWeights>,
512) -> Result<(Array1<f64>, Option<Array2<f64>>, PcgDiagnostics), ArrowSchurError> {
513 let factor = match cholesky_lower(schur) {
514 Ok(factor) => factor,
515 Err(e) => {
516 // #1026 — opt-in spectral PD-floor on the indefinite reduced Schur.
517 // When enabled (SAE solve path), condition ONLY the collapsed
518 // directions and re-factor, instead of erroring out and letting the
519 // outer LM loop inflate `ridge_β` over every β direction (the
520 // co-collapse "crawl"). Disabled (default `None`) keeps the strict
521 // refusal so BA / non-SAE callers are bit-for-bit unchanged.
522 match options.schur_pd_floor {
523 Some(relative_floor) => match spectral_pd_floored_schur(schur, relative_floor) {
524 Some(floored) => match cholesky_lower(&floored) {
525 Ok(factor) => {
526 // Solve against the floored (PD) Schur. The healthy β
527 // subspace keeps its exact eigenvalues, so its Δβ is
528 // the exact Newton component; only the collapsed
529 // subspace is minimally damped.
530 let direct =
531 mixed_precision_reduced_beta(&floored, &factor, rhs_beta, options)
532 .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
533 if step_inside_trust_region(
534 direct.view(),
535 options.trust_region.radius,
536 metric_weights,
537 ) {
538 return Ok((direct, Some(factor), PcgDiagnostics::default()));
539 }
540 let identity = IdentityPreconditioner;
541 let (delta, diag) = steihaug_dense_system(
542 &floored,
543 rhs_beta,
544 &identity,
545 &ArrowPcgOptions {
546 max_iterations: options.trust_region.max_iterations,
547 relative_tolerance: options
548 .trust_region
549 .steihaug_relative_tolerance,
550 },
551 &options.trust_region,
552 metric_weights,
553 )?;
554 return Ok((delta, Some(factor), diag));
555 }
556 Err(floored_err) => {
557 return Err(ArrowSchurError::SchurFactorFailed {
558 reason: format!(
559 "reduced Schur non-PD ({e}); spectral PD-floor \
560 reconstruction still non-PD: {floored_err}"
561 ),
562 });
563 }
564 },
565 None => {
566 return Err(ArrowSchurError::SchurFactorFailed {
567 reason: format!(
568 "reduced Schur non-PD ({e}); spectral PD-floor declined \
569 (no usable spectrum)"
570 ),
571 });
572 }
573 },
574 None => return Err(ArrowSchurError::SchurFactorFailed { reason: e }),
575 }
576 }
577 };
578 // Ill-conditioned-but-PD Schur guard. The per-row factor checks reject
579 // any single barely-PD H_tt^(i) block, but the reduced Schur complement
580 // S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)
581 // accumulates the (H_tt^(i))⁻¹ contributions of every row in finite
582 // precision. With many weak-but-admissible rows those terms can sum to a
583 // Schur matrix whose Cholesky succeeds yet whose condition number is far
584 // past the safe inversion regime, so `cholesky_solve_vector` yields an
585 // inaccurate Δβ that is silently propagated to the Newton step. Apply the
586 // same diagonal-ratio κ proxy used per-row to the reduced factor and treat
587 // an over-threshold estimate as a Schur-stability failure: `SchurFactorFailed`
588 // is already recoverable in `solve_with_lm_escalation_inner`, so this lifts
589 // `ridge_beta` and re-forms a better-conditioned Schur. This guard is
590 // exclusive to the dense Direct / SqrtBA path (the only caller of this
591 // function); the inexact-PCG path tolerates higher κ(S) and is unaffected.
592 // Evidence/log-det-only callers (`tolerate_ill_conditioning`) skip this
593 // rejection: the factor is genuinely PD (Cholesky above succeeded), so its
594 // diagonal still yields an exact `log|S|`, and an inaccurate Δβ is harmless
595 // because the step is discarded.
596 if !options.tolerate_ill_conditioning {
597 let schur_kappa = cholesky_factor_kappa_estimate(&factor);
598 if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
599 // #1026 — over-complete SAE dictionaries park surplus atoms dead
600 // (β_k → 0), so the reduced Schur is PD (the Cholesky above succeeded)
601 // but ILL-CONDITIONED: the dead decoder subspace carries near-zero
602 // eigenvalues while the live subspace is healthy. The kappa gate's
603 // concern is an inaccurate Δβ from accumulated (H_tt)⁻¹ contamination —
604 // but on the dead subspace the correct Δβ IS ≈0 (those atoms have no
605 // signal), so the only "inaccuracy" is in directions whose true step is
606 // zero. When the spectral PD-floor is enabled (the SAE solve path),
607 // clamp exactly those collapsed directions up to `floor·max(λ)` and
608 // solve against the floored Schur: the live subspace keeps its EXACT
609 // Newton component, the dead subspace is damped to ≈0, and κ is bounded
610 // so Δβ is accurate where it matters. This is the same conditioning the
611 // non-PD branch above applies; here it also covers the PD-but-ill-
612 // conditioned case so the LM loop does not exhaust `ridge_β` trying to
613 // (futilely) lift a fundamentally rank-deficient dead-atom subspace.
614 // Without the floor (BA / non-SAE callers) the strict refusal stands.
615 if let Some(relative_floor) = options.schur_pd_floor
616 && let Some(floored) = spectral_pd_floored_schur(schur, relative_floor)
617 && let Ok(floored_factor) = cholesky_lower(&floored)
618 {
619 let direct =
620 mixed_precision_reduced_beta(&floored, &floored_factor, rhs_beta, options)
621 .unwrap_or_else(|| cholesky_solve_vector(&floored_factor, rhs_beta));
622 if step_inside_trust_region(
623 direct.view(),
624 options.trust_region.radius,
625 metric_weights,
626 ) {
627 return Ok((direct, Some(floored_factor), PcgDiagnostics::default()));
628 }
629 let identity = IdentityPreconditioner;
630 let (delta, diag) = steihaug_dense_system(
631 &floored,
632 rhs_beta,
633 &identity,
634 &ArrowPcgOptions {
635 max_iterations: options.trust_region.max_iterations,
636 relative_tolerance: options.trust_region.steihaug_relative_tolerance,
637 },
638 &options.trust_region,
639 metric_weights,
640 )?;
641 return Ok((delta, Some(floored_factor), diag));
642 }
643 return Err(ArrowSchurError::SchurFactorFailed {
644 reason: format!(
645 "reduced Schur complement Cholesky succeeded but is ill-conditioned \
646 (kappa_estimate={schur_kappa:e}); accumulated per-row \
647 (H_tt)⁻¹ contamination would yield an inaccurate Δβ"
648 ),
649 });
650 }
651 }
652 // Reduced-system solve. The f64 `factor` is always retained and returned —
653 // its diagonal is the EXACT `log|S|` the evidence path reads, so the logdet
654 // stays f64 regardless of how Δβ is computed (#1014 invariant). When the
655 // streaming/residency path enabled certified mixed precision, the Δβ solve
656 // itself runs f32-then-f64-refined (κ-gated, with the f64 triangular solve
657 // as the automatic fallback); the certificate is the f64 backward error.
658 let direct = mixed_precision_reduced_beta(schur, &factor, rhs_beta, options)
659 .unwrap_or_else(|| cholesky_solve_vector(&factor, rhs_beta));
660 if step_inside_trust_region(direct.view(), options.trust_region.radius, metric_weights) {
661 return Ok((direct, Some(factor), PcgDiagnostics::default()));
662 }
663
664 // Ceres-style trust-region correction: once the dense BA solve proposes a
665 // step outside the trust ball, Steihaug-CG returns the boundary point
666 // without requiring a second dense factorization.
667 let identity = IdentityPreconditioner;
668 let (delta, diag) = steihaug_dense_system(
669 schur,
670 rhs_beta,
671 &identity,
672 &ArrowPcgOptions {
673 max_iterations: options.trust_region.max_iterations,
674 relative_tolerance: options.trust_region.steihaug_relative_tolerance,
675 },
676 &options.trust_region,
677 metric_weights,
678 )?;
679 Ok((delta, Some(factor), diag))
680}
681
682/// Solve an externally accumulated dense reduced β system
683/// `S Δβ = rhs_β` with the same LM-style ridge escalation the full-batch
684/// driver applies: on a `SchurFactorFailed` (non-PD or ill-conditioned `S`),
685/// geometrically grow a proximal ridge on `S`'s diagonal and retry.
686///
687/// Used by the SAE streaming joint fit, which accumulates `S` and `rhs_β` over
688/// re-materialized row chunks (via [`StreamingArrowSchur::take_accumulators`])
689/// and must solve the single global reduced system without a per-row
690/// `ArrowSchurSystem`. `S` is symmetrized from its lower triangle before each
691/// factorization. `base_ridge_beta` is folded into the caller's `S` already;
692/// this routine only adds the *escalation* ridge on top.
693pub fn solve_streaming_reduced_beta(
694 s_acc: &Array2<f64>,
695 rhs_beta: &Array1<f64>,
696 options: &ArrowSolveOptions,
697) -> Result<Array1<f64>, ArrowSchurError> {
698 let mut proximal_ridge = 0.0_f64;
699 let mut last_err: Option<ArrowSchurError> = None;
700 for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
701 let mut schur = s_acc.clone();
702 symmetrize_upper_from_lower(&mut schur);
703 if proximal_ridge > 0.0 {
704 for j in 0..schur.nrows() {
705 schur[[j, j]] += proximal_ridge;
706 }
707 }
708 // Reduced K-system on device: Jacobi-preconditioned CG over the dense
709 // symmetric `S`. The `O(K²)` `S·p` matvec runs device-side; only the
710 // K-vectors cross the boundary per CG iteration. This is the dominant
711 // cost of the streaming SAE joint fit at `K = 100K`. Any device-side
712 // failure (`Unavailable`, non-PD Jacobi diagonal) falls through to the
713 // CPU `solve_dense_reduced_system`, which then drives the same proximal
714 // ridge escalation. A genuine device PD failure is non-recoverable for
715 // this attempt's `schur`, so we let the CPU path re-confirm and escalate.
716 if gam_gpu::device_runtime::GpuRuntime::is_available() {
717 match crate::gpu_kernels::arrow_schur::solve_reduced_beta_pcg(
718 &schur,
719 rhs_beta,
720 options.trust_region.max_iterations,
721 options.trust_region.steihaug_relative_tolerance,
722 ) {
723 Ok(delta_beta) => return Ok(delta_beta),
724 Err(crate::gpu_kernels::arrow_schur::ArrowSchurGpuFailure::Unavailable) => {}
725 Err(_) => {
726 // Device declined this `schur` (e.g. non-PD Jacobi diag);
727 // let the CPU path confirm and escalate the proximal ridge.
728 }
729 }
730 }
731 match solve_dense_reduced_system(&schur, rhs_beta, options, None) {
732 Ok((delta_beta, _factor, _diag)) => return Ok(delta_beta),
733 Err(err) => {
734 let recoverable = matches!(
735 err,
736 ArrowSchurError::SchurFactorFailed { .. }
737 | ArrowSchurError::PcgFailed { .. }
738 | ArrowSchurError::UnboundedNegativeCurvature { .. }
739 );
740 last_err = Some(err);
741 if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
742 break;
743 }
744 proximal_ridge = if proximal_ridge == 0.0 {
745 DEFAULT_PROXIMAL_INITIAL_RIDGE
746 } else {
747 proximal_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
748 };
749 }
750 }
751 }
752 Err(last_err.expect("escalation loop set last_err on failure"))
753}
754
755pub(crate) fn step_inside_trust_region(
756 step: ArrayView1<'_, f64>,
757 radius: f64,
758 metric_weights: Option<&MetricWeights>,
759) -> bool {
760 !radius.is_finite() || metric_norm(step, metric_weights) <= radius
761}
762
763/// Below this row count the per-row Schur loop stays sequential: the rayon
764/// fan-out (chunk dispatch + the deterministic per-chunk length-`K` reduction)
765/// costs more than it saves for the handful-of-rows arrow systems that dominate
766/// the non-SAE callers. Above it — the SAE LLM shape (`n` in the thousands,
767/// wide border `k`) that issue #1017 names — the per-row `H_βt (H_tt)⁻¹ H_tβ x`
768/// contributions are the matvec's whole cost and parallelize cleanly.
769pub(crate) const SCHUR_MATVEC_PARALLEL_ROW_MIN: usize = 256;
770
/// Minimum border width `k` at which the dense `H_ββ` penalty-prologue GEMV
/// is worth parallelizing.
///
/// A `k×k` matvec only repays the rayon fan-out once `k²` dwarfs the dispatch
/// overhead, which never happens for the narrow-border arrow callers. At the
/// SAE LLM border (`k` in the low thousands) the `O(k²)` prologue is about 4M
/// flops per CG iteration. It was the serial Amdahl ceiling on an otherwise
/// row-parallel matvec (#1017). 512 (`1 << 9`) keeps the prologue serial for
/// every non-SAE arrow system and engages it for the wide SAE/Qwen borders.
pub(crate) const SCHUR_PROLOGUE_PARALLEL_K_MIN: usize = 1 << 9;
780
781/// Device-residency CPU analogue for the SAE reduced-Schur matvec (#1017).
782///
783/// In the production SAE joint fit the per-row cross-block factors as
784/// `H_tβ^(i) = L_i P_i`, where `L_i` (`q_i × p`) is the row's local
785/// assignment/coordinate Jacobian and `P_i` (`p × K`, sparse) gathers the
786/// active atoms' decoder blocks (`P_i x = Σ_s φ_s · x[base_s .. base_s+p]`).
787/// The reduced-Schur point-elimination contribution of one row is therefore
788///
789/// ```text
790/// S_i x = H_βt^(i) (H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i) x
791/// = P_iᵀ · [ L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i ] · P_i x
792/// = P_iᵀ G_i (P_i x), G_i := L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i (p×p).
793/// ```
794///
795/// The block `G_i = L_iᵀ Y_i` depends only on the assembled per-row blocks and
796/// the (already-computed, solve-stable) `H_tt` factor — NOT on the CG iterate
797/// `x`. The generic [`schur_matvec`] re-walks `apply_jbeta → apply_l →
798/// solve(d×d) → apply_l_t → scatter` on every CG iteration; this object **stages
799/// the factors `(L_i, Y_i)` once per CG solve** (the "upload X once" residency
800/// mechanism, applied on CPU to the matvec rather than a dense factorization),
801/// turning each subsequent matvec into a sparse gather → two `di×p` GEMVs →
802/// sparse scatter, with no per-iteration triangular solve and no operator-closure
803/// re-walk. It never materialises the dense `p×p` product: `di ≪ p` for SAE
804/// rows, so the factored apply is `2·support_i·p + 2·di·p` flops/row — the two
805/// `di·p` GEMVs PLUS the `support_i·p` sparse gather (`P_i x`) and `support_i·p`
806/// sparse scatter (`P_iᵀ prod`) — versus the dense `p²` block apply, and
807/// `O(n·di·p)` memory (vs `O(n·p²)` ≈ 67 GB at the Qwen shape — the dense form
808/// is OOM). For dense/full active support `support_i` can scale with the active
809/// β-columns, so the gather/scatter term is NOT negligible and is counted here.
810///
811/// Numerically identical to the generic path up to floating-point reassociation
812/// (it differentiates and accumulates the SAME quotient). It is deterministic
813/// run-to-run and within the reassociation margin of the serial path, so the
814/// criterion ranking across topology candidates is stable except for candidates
815/// separated by less than that f64 margin, where reassociation can flip the
816/// near-tie winner — it is NOT an exact no-move guarantee (#1211).
817pub(crate) struct SaeResidentReducedSchur {
818 /// Decoder output dimension `p` (the side length of every `G_i = L_iᵀ Y_i`).
819 pub(crate) p: usize,
820 /// Per-row **factored** residency: `(L_i, Y_i)`, each stored row-major as a
821 /// `di × p` slab (`L_i` = local Jacobian, `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i`).
822 /// The reduced block is `G_i = L_iᵀ Y_i` (`p×p`, symmetric PSD), but it has
823 /// rank ≤ `di` and `di ≪ p` for SAE rows (the per-row latent dim is 1–2
824 /// while `p` is the decoder block width, ~2048). Materialising the dense
825 /// `p×p` block would cost `O(n·p²)` memory (≈67 GB at the Qwen shape) and
826 /// `p²` flops per matvec/row; the factored form costs `O(n·di·p)` memory and
827 /// `2·support_i·p + 2·di·p` flops/row, applying `G_i v = L_iᵀ (Y_i v)`
828 /// (sparse gather over `support_i` atoms → `di`-length GEMV → `p`-length
829 /// GEMV → sparse scatter over `support_i` atoms). The `2·support_i·p`
830 /// gather/scatter term is part of the per-row cost — for dense/full support
831 /// `support_i` scales with active β-columns — and is not dropped. A row with
832 /// empty active support / degenerate dims gets `di = 0` and is skipped.
833 /// `(di, L_i, Y_i)` per row; `L_i`/`Y_i` are `di·p`-length row-major buffers.
834 pub(crate) rows: Vec<ResidentRowFactor>,
835 /// Per-row active atom support `(β-block base index, φ weight)`, shared with
836 /// the assembler's [`DeviceSaePcgData`] (no re-clone of the index lists).
837 pub(crate) a_phi: Arc<[Vec<(usize, f64)>]>,
838 /// #1033: per-row local Jacobian `L_i` (row-major `di × p`), SHARED via `Arc`
839 /// with the assembler's [`DeviceSaePcgData`] rather than copied into each
840 /// `ResidentRowFactor`. The staged factor previously held its own verbatim
841 /// row-major copy of `data.local_jac[row]` — a second full `O(n·di·p)` slab
842 /// for zero benefit (the bytes and the `di × p` layout are identical). The
843 /// matvec now reads `L_i = &self.local_jac[row]` directly; only the SOLVED
844 /// factor `Y_i = (H_tt+ρI)⁻¹ L_i` (genuinely new data) stays per-row. Reads
845 /// are byte-for-byte the former `rf.l` (same slab, same `r·p + c` indexing),
846 /// so the matvec/preconditioner output is bit-identical.
847 pub(crate) local_jac: Arc<[Vec<f64>]>,
848}
849
/// One row's staged residency data for [`SaeResidentReducedSchur`].
///
/// The row's reduced block `G_i = L_iᵀ Y_i` is represented only through its
/// `di × p` factors, never as the dense `p×p` product. This struct holds just
/// the solved half `Y_i`. The local Jacobian `L_i` is read from the shared
/// `SaeResidentReducedSchur::local_jac` slab (`&local_jac[row]`).
pub(crate) struct ResidentRowFactor {
    /// Solved factor `Y_i = (H_tt^(i)+ρ_t I)⁻¹ L_i`, row-major `di × p`.
    /// Left empty for a skipped row (`di == 0`).
    pub(crate) y: Vec<f64>,
    /// The row's latent dimension `di`, i.e. the inner contraction width of
    /// `L_iᵀ Y_i`. A value of `0` marks the row as skipped.
    pub(crate) di: usize,
}
861
862impl SaeResidentReducedSchur {
863 /// Stage the per-row `G_i = L_iᵀ (H_tt^(i)+ρ_t I)⁻¹ L_i` blocks once, from
864 /// the SAE structure (`DeviceSaePcgData`: `p`, per-row `a_phi`, per-row
865 /// row-major `local_jac` = `L_i`) and the already-factored `H_tt` slab.
866 ///
867 /// Returns `None` when the structure does not match (degenerate `p`, row
868 /// count mismatch) so the caller falls back to the generic matvec. Row
869 /// builds are independent and run under the same deterministic rayon
870 /// discipline as the matvec (each `G_i` is self-contained — no cross-row
871 /// reduction — so there is no ordering subtlety).
872 /// `ridge_t` is NOT a parameter: it is already folded into the factored
873 /// blocks `htt_factors` carry (they factor `H_tt^(i) + ridge_t·I` — see
874 /// `factor_blocks`), so solving against the factor yields `(H_tt^(i)+ρ_t I)⁻¹`
875 /// exactly. The residency block is a pure function of the factor and `L_i`.
876 pub(crate) fn build<B: BatchedBlockSolver + Sync>(
877 sys: &ArrowSchurSystem,
878 htt_factors: &ArrowFactorSlab,
879 backend: &B,
880 ) -> Option<Self> {
881 let data = sys.device_sae_pcg.as_ref()?;
882 let p = data.p;
883 let n = sys.rows.len();
884 if p == 0
885 || sys.htbeta_dense_supplement
886 || data.a_phi.len() != n
887 || data.local_jac.len() != n
888 {
889 return None;
890 }
891 let empty = || ResidentRowFactor {
892 di: 0,
893 y: Vec::new(),
894 };
895 let build_row = |row: usize| -> ResidentRowFactor {
896 let di = sys.row_dims[row];
897 let jac = &data.local_jac[row];
898 // q_i = len/p; must match the row's latent dimension di.
899 if p == 0 || jac.len() != di * p || di == 0 {
900 return empty();
901 }
902 // L_i as a (di × p) matrix (row-major in `local_jac`).
903 let l_i = match ArrayView2::from_shape((di, p), jac.as_slice()) {
904 Ok(v) => v.to_owned(),
905 Err(_) => return empty(),
906 };
907 // Solve (H_tt+ρ_t I) Y = L_i for Y (di × p): one batched back-solve
908 // over the p columns against the cached factor. Stage `(L_i, Y_i)`
909 // — NOT the dense `p×p` product `G_i = L_iᵀ Y_i` — so storage and the
910 // matvec stay `O(di·p)` instead of `O(p²)` (`di ≪ p` for SAE rows).
911 let y = backend.solve_block_matrix(htt_factors.factor(row), l_i.view());
912 // Flatten the SOLVED factor to a `di × p` row-major buffer (iteration
913 // over a standard-layout view is row-major regardless of the source
914 // strides, so the hot loop can index `r*p + c` directly). `L_i` is NOT
915 // copied — the matvec reads it from the shared `local_jac` slab (it is
916 // byte-for-byte `data.local_jac[row]`).
917 let y_flat: Vec<f64> = y.iter().copied().collect();
918 ResidentRowFactor { di, y: y_flat }
919 };
920 let rows: Vec<ResidentRowFactor> =
921 if n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
922 use rayon::prelude::*;
923 (0..n).into_par_iter().map(build_row).collect()
924 } else {
925 (0..n).map(build_row).collect()
926 };
927 Some(Self {
928 p,
929 rows,
930 a_phi: data.a_phi_shared(),
931 local_jac: data.local_jac_shared(),
932 })
933 }
934
935 /// Accumulate one row's `S_i x = P_iᵀ G_i (P_i x) = P_iᵀ L_iᵀ Y_i (P_i x)`
936 /// into `acc` (length `K`). `gather`/`prod` are caller-owned length-`p`
937 /// buffers and `w` a caller-owned `≥ max_i di`-length buffer, all reused
938 /// across rows to keep the hot loop allocation-free. The matvec applies the
939 /// factored block in four steps: sparse gather `P_i x = Σ_s φ_s·x[base_s..]`
940 /// (`support_i·p` flops), `w = Y_i·(P_i x)` (`di`-length, `di·p` flops),
941 /// `prod = L_iᵀ·w` (`p`-length, `di·p` flops), and sparse scatter
942 /// `acc += P_iᵀ prod` (`support_i·p` flops) — `2·support_i·p + 2·di·p`
943 /// total, never the dense `p²` product. The gather/scatter `2·support_i·p`
944 /// term is counted: it is not dominated by the GEMVs when the active support
945 /// is wide.
946 #[inline]
947 pub(crate) fn row_into(
948 &self,
949 row: usize,
950 x: &Array1<f64>,
951 acc: &mut Array1<f64>,
952 gather: &mut [f64],
953 prod: &mut [f64],
954 w: &mut [f64],
955 ) {
956 let rf = &self.rows[row];
957 let di = rf.di;
958 if di == 0 {
959 return;
960 }
961 let p = self.p;
962 let support = &self.a_phi[row];
963 if support.is_empty() {
964 return;
965 }
966 // Slice `x`/`acc` ONCE so the per-support gather/scatter (the dominant
967 // `support·p` terms for wide active support) run over contiguous `f64`
968 // slices — the compiler can prove unit stride and emit vectorized FMA,
969 // where the former `x[base+j]`/`acc[base+j]` ndarray element indexing
970 // forced a per-element strided lookup + bounds check that blocked
971 // autovectorization. Every accumulation order is unchanged, so the
972 // result is bit-identical to the ndarray-indexed form.
973 let x_slice = x.as_slice().expect("resident matvec x must be contiguous");
974 // P_i x = Σ_s φ_s · x[base_s .. base_s+p] (length p).
975 let gather = &mut gather[..p];
976 for v in gather.iter_mut() {
977 *v = 0.0;
978 }
979 for &(base, phi) in support {
980 if phi == 0.0 {
981 continue;
982 }
983 let xrow = &x_slice[base..base + p];
984 for (g, &xv) in gather.iter_mut().zip(xrow) {
985 *g += phi * xv;
986 }
987 }
988 // w = Y_i · (P_i x) (di × p GEMV → length di). Y_i row-major di×p.
989 for r in 0..di {
990 let yrow = &rf.y[r * p..r * p + p];
991 let mut s = 0.0_f64;
992 for (&yv, &gv) in yrow.iter().zip(gather.iter()) {
993 s += yv * gv;
994 }
995 w[r] = s;
996 }
997 // prod = L_iᵀ · w (p × di GEMV → length p). L_i row-major di×p, so
998 // L_iᵀ[j,r] = L_i[r,j]; accumulate column-by-column over the di rows.
999 // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte the
1000 // former per-row `rf.l` copy.
1001 let l_i = &self.local_jac[row];
1002 let prod = &mut prod[..p];
1003 for v in prod.iter_mut() {
1004 *v = 0.0;
1005 }
1006 for r in 0..di {
1007 let lrow = &l_i[r * p..r * p + p];
1008 let wr = w[r];
1009 for (pj, &lj) in prod.iter_mut().zip(lrow) {
1010 *pj += lj * wr;
1011 }
1012 }
1013 // acc += P_iᵀ prod = scatter φ_s · prod into base_s blocks.
1014 let acc_slice = acc.as_slice_mut().expect("resident matvec acc must be contiguous");
1015 for &(base, phi) in support {
1016 if phi == 0.0 {
1017 continue;
1018 }
1019 let arow = &mut acc_slice[base..base + p];
1020 for (a, &pv) in arow.iter_mut().zip(prod.iter()) {
1021 *a += phi * pv;
1022 }
1023 }
1024 }
1025
1026 /// Max row latent dim `di` across resident rows — the size of the `w`
1027 /// scratch the matvec needs for the inner `Y_i·(P_i x)` GEMV.
1028 pub(crate) fn max_di(&self) -> usize {
1029 self.rows.iter().map(|r| r.di).max().unwrap_or(0)
1030 }
1031}
1032
1033/// Reduced-Schur matvec `out = S·x` with an optional pre-staged SAE residency
1034/// operator. When `resident` is `Some`, the per-row point-elimination term is
1035/// applied through the resident `p×p` blocks (#1017 CPU residency); otherwise it
1036/// falls back to the generic per-row `apply → solve → transpose` path. Both
1037/// routes accumulate the SAME reduced operator
1038/// `S = H_ββ + ρ_β I − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`.
1039pub(crate) fn schur_matvec<B: BatchedBlockSolver + Sync>(
1040 sys: &ArrowSchurSystem,
1041 htt_factors: &ArrowFactorSlab,
1042 ridge_beta: f64,
1043 x: &Array1<f64>,
1044 out: &mut Array1<f64>,
1045 backend: &B,
1046 resident: Option<&SaeResidentReducedSchur>,
1047) {
1048 // `steihaug_cg` reuses one output buffer across iterations and requires
1049 // `matvec` to ASSIGN every entry of `out` (the contract `dense_matvec`
1050 // upholds). This routine builds `S·x` purely by accumulation
1051 // (`penalty_matvec_add`, `out[a] += ridge·x`, `out[a] -= neg_contrib`), so it
1052 // MUST clear `out` first. Without this, iteration n>0 returns `S·x` plus the
1053 // previous call's `S·p`, the PCG solves a corrupted reduced system, and the
1054 // resulting Newton step is inconsistent with the assembled gradient
1055 // (g·δ ≈ 0 — a non-descent direction that defeats the line search).
1056 out.fill(0.0);
1057 let k = sys.k;
1058 // Top-level (not nested in a rayon worker) and big enough to amortize the
1059 // fan-out: the single gate that authorizes BOTH the dense penalty-prologue
1060 // GEMV and the per-row point-elimination loop to go parallel. The topology
1061 // race fans candidates with `run_topology_race_parallel`, so inside a worker
1062 // both stay sequential (no nested-rayon oversubscription).
1063 let parallel =
1064 sys.rows.len() >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1065 // Route the penalty-side (H_ββ + ridge·I) x product through the prologue:
1066 // no Arc-clone hot-path cost when penalty_op is None (falls back to hbb
1067 // inline); the dense fallback fans across cores at the wide SAE border (#1017).
1068 {
1069 let x_slice = x.as_slice().expect("x must be contiguous");
1070 let out_slice = out.as_slice_mut().expect("out must be contiguous");
1071 sys.penalty_ridge_prologue_into(x_slice, ridge_beta, out_slice, parallel);
1072 }
1073 // The reduced-Schur point-elimination term: `out -= Σ_i H_βt^(i) (H_tt^(i))⁻¹
1074 // H_tβ^(i) x`. Each row contributes an independent length-`K` vector, so for
1075 // the SAE LLM shape (#1017) this is the matvec's whole cost and is
1076 // embarrassingly parallel. Run it under rayon over fixed row chunks, summing
1077 // the per-chunk partials in chunk order so the f64 reduction is bit-identical
1078 // run-to-run regardless of thread scheduling (the #1017 verification gate).
1079 // This is deterministic and within the chunk-reassociation margin of serial,
1080 // so the criterion ranking is stable except for candidates that tie inside
1081 // that f64 margin — not an exact no-move guarantee (#1211). Stay
1082 // sequential when already inside a rayon worker (the topology race fans
1083 // candidates with `run_topology_race_parallel`) to avoid nested-rayon
1084 // oversubscription — the same guard `HyperOperator::mul_mat` uses. The
1085 // `parallel` gate above authorizes this loop too.
1086 let p = resident.map(|r| r.p).unwrap_or(0);
1087 if parallel {
1088 use rayon::prelude::*;
1089 const CHUNK: usize = 64;
1090 let n = sys.rows.len();
1091 let partials: Vec<Array1<f64>> = (0..n)
1092 .into_par_iter()
1093 .chunks(CHUNK)
1094 .map(|idxs| {
1095 let mut acc = Array1::<f64>::zeros(k);
1096 if let Some(res) = resident {
1097 // Resident path: each matvec is gather → factored di×p GEMVs
1098 // → scatter, reading only the pre-staged `(L_i, Y_i)` (no
1099 // per-iteration solve, no dense p×p block).
1100 let mut gather = vec![0.0_f64; p];
1101 let mut prod = vec![0.0_f64; p];
1102 let mut w = vec![0.0_f64; res.max_di()];
1103 for i in idxs {
1104 res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1105 }
1106 } else {
1107 let mut local = Array1::<f64>::zeros(sys.d);
1108 for i in idxs {
1109 schur_matvec_row_into(
1110 sys,
1111 htt_factors,
1112 x,
1113 backend,
1114 i,
1115 &mut local,
1116 &mut acc,
1117 );
1118 }
1119 }
1120 acc
1121 })
1122 .collect();
1123 // Deterministic ordered reduction: fold chunk partials left-to-right.
1124 for acc in &partials {
1125 for a in 0..k {
1126 out[a] -= acc[a];
1127 }
1128 }
1129 } else if let Some(res) = resident {
1130 let mut acc = Array1::<f64>::zeros(k);
1131 let mut gather = vec![0.0_f64; p];
1132 let mut prod = vec![0.0_f64; p];
1133 let mut w = vec![0.0_f64; res.max_di()];
1134 for i in 0..sys.rows.len() {
1135 res.row_into(i, x, &mut acc, &mut gather, &mut prod, &mut w);
1136 }
1137 for a in 0..k {
1138 out[a] -= acc[a];
1139 }
1140 } else {
1141 // Allocate scratch at max_d; per-row slice is `..di`.
1142 let mut local = Array1::<f64>::zeros(sys.d);
1143 let mut neg_contrib = Array1::<f64>::zeros(k);
1144 for i in 0..sys.rows.len() {
1145 neg_contrib.fill(0.0);
1146 schur_matvec_row_into(
1147 sys,
1148 htt_factors,
1149 x,
1150 backend,
1151 i,
1152 &mut local,
1153 &mut neg_contrib,
1154 );
1155 for a in 0..k {
1156 out[a] -= neg_contrib[a];
1157 }
1158 }
1159 }
1160}
1161
1162/// Matrix-free reduced-Schur log-determinant `log|S|` via Stochastic Lanczos
1163/// Quadrature on the exact `schur_matvec` apply `v ↦ S·v`, where
1164/// `S = (H_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹H_tβ^(i)` is the SPD
1165/// reduced Schur. **The dense `k×k` `S` is NEVER formed.**
1166///
1167/// This is the memory-matrix-free evidence path for the massive-K manifold SAE.
1168/// The dense evidence routes assemble `S` explicitly (`O(k²)` ≈ 8 GB at the
1169/// K=32k border) and Cholesky-factor it (`O(k³/3)`) purely to read `Σ 2·log Lᵢᵢ`;
1170/// that dense assembly + factor is the massive-K wall (both dense evidence
1171/// routes REFUSE above the in-core budget). Here peak memory is `O(k)` — the SLQ
1172/// Rademacher probe and Lanczos basis vectors — and the cost is
1173/// `O(num_probes·lanczos_steps · matvec)`, each matvec the same `O(n·d·k)`
1174/// reduced-Schur apply the PCG hot loop already runs. Deterministic for a fixed
1175/// `(sys, htt_factors, ρ_β, resident, num_probes, lanczos_steps, seed)` so the
1176/// REML evidence outer loop stays reproducible.
1177///
1178/// `htt_factors` are the per-row `(H_tt^(i)+ρ_t I)` Cholesky factors; `resident`
1179/// is the optional pre-staged SAE residency operator (`None` for the framed /
1180/// closure `H_tβ` path). SLQ is an ESTIMATE — the same accuracy contract the
1181/// device seam already accepts for `k ≥ SCHUR_SLQ_LOGDET_MIN_DIM`; callers that
1182/// need the exact dense log-det at small `k` must stay on the dense route.
1183///
1184/// Crate-internal because the `resident` parameter carries the `pub(crate)`
1185/// [`SaeResidentReducedSchur`] operator; cross-crate callers use the
1186/// [`matrix_free_arrow_evidence_log_det`] convenience, which stages residency
1187/// internally and exposes no crate-private type.
1188pub(crate) fn slq_reduced_schur_log_det<B: BatchedBlockSolver + Sync>(
1189 sys: &ArrowSchurSystem,
1190 htt_factors: &ArrowFactorSlab,
1191 ridge_beta: f64,
1192 backend: &B,
1193 resident: Option<&SaeResidentReducedSchur>,
1194 num_probes: usize,
1195 lanczos_steps: usize,
1196 seed: u64,
1197) -> SlqLogDet {
1198 let k = sys.k;
1199 slq_logdet(
1200 k,
1201 |v| {
1202 // `schur_matvec` clears and fully assigns `out`, so a fresh zeroed
1203 // buffer per apply is correct; the probes fan across rayon workers
1204 // (in `slq_logdet`), and `schur_matvec`'s own row parallelism is
1205 // guarded off inside a worker, so there is no nested oversubscription.
1206 let x = v.to_owned();
1207 let mut out = Array1::<f64>::zeros(k);
1208 schur_matvec(sys, htt_factors, ridge_beta, &x, &mut out, backend, resident);
1209 out
1210 },
1211 num_probes,
1212 lanczos_steps,
1213 seed,
1214 )
1215}
1216
1217/// One-call matrix-free arrow evidence log-determinant for an assembled system.
1218///
1219/// Factors the per-row `H_tt^(i)+ρ_t I` blocks (accumulating
1220/// `log_det_tt = Σ_i Σ_axis 2·log Lᵢᵢ` from the Cholesky diagonals — the cheap
1221/// `O(n·d³)` t-tier term), stages the SAE residency operator when the system
1222/// carries `device_sae_pcg` full-`B` data, and estimates `log|S|` via
1223/// [`slq_reduced_schur_log_det`] with NO dense `k×k` Schur formed at any point.
1224///
1225/// Returns `(log_det_tt, log|S| SLQ estimate)`; the undamped joint evidence
1226/// log-det the Laplace normaliser needs is their sum. Uses the identical
1227/// [`factor_blocks_for_system`] the dense Direct evidence path uses (same gauge
1228/// deflation), so `log_det_tt` matches the dense convention exactly and only the
1229/// `k×k` Schur term is replaced by its matrix-free SLQ estimate.
1230pub fn matrix_free_arrow_evidence_log_det(
1231 sys: &ArrowSchurSystem,
1232 ridge_t: f64,
1233 ridge_beta: f64,
1234 options: &ArrowSolveOptions,
1235 num_probes: usize,
1236 lanczos_steps: usize,
1237 seed: u64,
1238) -> Result<(f64, SlqLogDet), ArrowSchurError> {
1239 let backend = CpuBatchedBlockSolver;
1240 let factorization = factor_blocks_for_system(sys, ridge_t, options, &backend)?;
1241 let htt_factors = factorization.factors;
1242 let mut log_det_tt = 0.0_f64;
1243 for row in 0..htt_factors.len() {
1244 let factor = htt_factors.factor(row);
1245 for axis in 0..factor.nrows() {
1246 log_det_tt += 2.0 * factor[[axis, axis]].ln();
1247 }
1248 }
1249 let resident = SaeResidentReducedSchur::build(sys, &htt_factors, &backend);
1250 let slq = slq_reduced_schur_log_det(
1251 sys,
1252 &htt_factors,
1253 ridge_beta,
1254 &backend,
1255 resident.as_ref(),
1256 num_probes,
1257 lanczos_steps,
1258 seed,
1259 );
1260 Ok((log_det_tt, slq))
1261}
1262
1263/// Accumulate one row's reduced-Schur point-elimination contribution
1264/// `H_βt^(i) (H_tt^(i))⁻¹ H_tβ^(i) x` (length `K`) into `acc`.
1265///
1266/// `local` is caller-owned `≥ sys.d`-length scratch (reused across rows to keep
1267/// the hot loop allocation-free); only `..di` is touched. `acc` is **added to**,
1268/// never cleared, so the caller controls whether contributions sum into a chunk
1269/// partial (parallel path) or a per-row buffer (sequential path).
1270#[inline]
1271pub(crate) fn schur_matvec_row_into<B: BatchedBlockSolver>(
1272 sys: &ArrowSchurSystem,
1273 htt_factors: &ArrowFactorSlab,
1274 x: &Array1<f64>,
1275 backend: &B,
1276 i: usize,
1277 local: &mut Array1<f64>,
1278 acc: &mut Array1<f64>,
1279) {
1280 let row = &sys.rows[i];
1281 let di = sys.row_dims[i];
1282 // H_tβ^(i) · x → local[..di], routed through sys.htbeta_matvec
1283 // when the dense block is absent.
1284 let mut local_i = local.slice_mut(ndarray::s![..di]).to_owned();
1285 local_i.fill(0.0);
1286 sys_htbeta_apply_row(sys, i, row, x.view(), &mut local_i);
1287 let solved = backend.solve_block_vector(htt_factors.factor(i), local_i.view());
1288 // H_βt^(i) · solved accumulates into acc (length k). Routed through
1289 // sys.htbeta_matvec when needed.
1290 sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), acc);
1291}
1292
1293/// One per-term block factor for the block-Jacobi Schur preconditioner.
1294///
1295/// Carries either a dense Cholesky factor (for PD blocks ≤ 256 columns) or
1296/// the scalar inverses for that block's diagonal as a fallback.
1297#[derive(Clone)]
1298pub(crate) enum BlockFactor {
1299 /// Cholesky L stored column-major via faer. `range` identifies the
1300 /// columns in the full K-vector this block covers.
1301 Chol {
1302 factor: FaerLlt<f64>,
1303 range: Range<usize>,
1304 },
1305 /// Scalar fallback: per-element `1/s_aa` for each column in `range`.
1306 Scalar {
1307 inv: Array1<f64>,
1308 range: Range<usize>,
1309 },
1310}
1311
1312impl std::fmt::Debug for BlockFactor {
1313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1314 match self {
1315 BlockFactor::Chol { range, .. } => {
1316 write!(f, "BlockFactor::Chol {{ range: {:?} }}", range)
1317 }
1318 BlockFactor::Scalar { inv, range } => {
1319 write!(
1320 f,
1321 "BlockFactor::Scalar {{ inv.len: {}, range: {:?} }}",
1322 inv.len(),
1323 range
1324 )
1325 }
1326 }
1327 }
1328}
1329
1330/// Block-Jacobi Schur preconditioner for BA's inexact reduced-system PCG.
1331///
1332/// When [`ArrowSchurSystem::block_offsets`] is populated (via
1333/// [`ArrowSchurSystem::set_block_offsets`]) and the largest block has ≤ 256
1334/// columns, builds one small dense Schur block per term, factors it with
1335/// Cholesky (faer LLT), and applies the preconditioner as per-block
1336/// triangular solves. Non-PD blocks fall back to scalar diagonal inversion
1337/// for that block only. When `block_offsets` is empty or the largest block
1338/// exceeds 256 columns the preconditioner reduces to pure scalar-diagonal
1339/// Jacobi (pre-#283 behaviour), so callers that have not called
1340/// `set_block_offsets` are unaffected.
1341///
1342/// The `block_offsets` plumbing is compatible with issue #287 (custom
1343/// `ParameterBlockSpec` families): those callers supply ranges derived from
1344/// their own block layout.
1345#[derive(Debug, Clone)]
1346pub struct JacobiPreconditioner {
1347 pub(crate) blocks: Vec<BlockFactor>,
1348}
1349
/// Widest block (in columns) that still gets a dense block-Jacobi
/// factorization; anything wider falls back to scalar Jacobi. 256 (`1 << 8`).
pub(crate) const BLOCK_JACOBI_MAX_BLOCK: usize = 1 << 8;
1352
/// Smallest acceptable Schur-complement Jacobi diagonal entry.
///
/// An entry that is non-finite or `<=` this floor means the reduced system is
/// not positive definite. The preconditioner cannot invert it, so the PCG
/// build fails loudly and asks for operator regularization instead of
/// producing a meaningless scale.
pub(crate) const JACOBI_DIAGONAL_PD_FLOOR: f64 = 1.0e-18;
1358
1359impl JacobiPreconditioner {
1360 /// Build the block-Jacobi (or scalar fallback) preconditioner from the
1361 /// Arrow-Schur system without materializing the full dense Schur
1362 /// complement.
1363 ///
1364 /// When `sys.block_offsets` is non-empty and `max(block_size) ≤ 256`,
1365 /// each block gets a dense `b×b` Schur sub-matrix formed, factored, and
1366 /// stored. Otherwise every column gets its own scalar entry.
1367 pub(crate) fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
1368 sys: &ArrowSchurSystem,
1369 htt_factors: &ArrowFactorSlab,
1370 ridge_beta: f64,
1371 backend: &B,
1372 resident: Option<&SaeResidentReducedSchur>,
1373 ) -> Result<Self, ArrowSchurError> {
1374 let use_block = !sys.block_offsets.is_empty()
1375 && sys
1376 .block_offsets
1377 .iter()
1378 .map(|r| r.end.saturating_sub(r.start))
1379 .max()
1380 .unwrap_or(0)
1381 <= BLOCK_JACOBI_MAX_BLOCK;
1382 if use_block {
1383 if let Some(res) = resident {
1384 Self::build_block_jacobi_resident(sys, ridge_beta, res)
1385 } else {
1386 Self::build_block_jacobi(sys, htt_factors, ridge_beta, backend)
1387 }
1388 } else if let Some(res) = resident {
1389 // #1017 — SAE residency scalar Jacobi. The generic scalar build
1390 // probes `H_tβ^(i) e_a` and re-solves `(H_tt^(i))⁻¹` once for EVERY
1391 // (row, β-column) pair: `O(n·K)` triangular solves and `O(n·K·p)`
1392 // operator-probe work per Newton step, with `K = K_atoms·p` in the
1393 // tens of thousands at LLM shapes. The reduced-Schur diagonal is the
1394 // same quotient the resident `(L_i, Y_i)` factors already carry, so
1395 // read the diagonal straight off them in one support-sparse pass —
1396 // no probe, no per-column solve.
1397 Self::build_scalar_jacobi_resident(sys, ridge_beta, res)
1398 } else {
1399 Self::build_scalar_jacobi(sys, htt_factors, ridge_beta, backend)
1400 }
1401 }
1402
1403 /// Build scalar-diagonal Jacobi: one `BlockFactor::Scalar` of length 1
1404 /// per column. Matches pre-#283 semantics.
1405 ///
1406 /// When `sys.htbeta_matvec` is set and per-row `htbeta` slabs are absent,
1407 /// each column is probed via the matvec (one call per column per row).
1408 pub(crate) fn build_scalar_jacobi<B: BatchedBlockSolver + Sync>(
1409 sys: &ArrowSchurSystem,
1410 htt_factors: &ArrowFactorSlab,
1411 ridge_beta: f64,
1412 backend: &B,
1413 ) -> Result<Self, ArrowSchurError> {
1414 let k = sys.k;
1415 // Extract diagonal of H_ββ via penalty_diagonal_add (#296):
1416 // no Arc-clone; falls back to hbb_diag or hbb[[a,a]] inline.
1417 let mut diag = Array1::<f64>::zeros(k);
1418 {
1419 let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
1420 sys.penalty_diagonal_add(diag_slice);
1421 }
1422 for a in 0..k {
1423 diag[a] += ridge_beta;
1424 }
1425 // Per-row body: subtract this row's `Σ_a (H_tβ^(i)e_a)ᵀ(H_tt^(i))⁻¹
1426 // (H_tβ^(i)e_a)` contribution into a caller-provided length-`K` diagonal
1427 // accumulator (`-=`). For each column `a`, probe the cross-block (or read
1428 // the dense slab) and compute the scalar point-elimination quotient. The
1429 // `O(K)` solves per row are the build's whole cost; the row contributions
1430 // are independent length-`K` vectors, so a worker sums a chunk into a
1431 // private `diag_part` and the caller folds the partials back in chunk
1432 // order — bit-identical run-to-run (the #1017 preconditioner gate).
1433 let row_into = |i: usize, row: &ArrowRowBlock, diag_part: &mut Array1<f64>| {
1434 let di = sys.row_dims[i];
1435 // Dense-slab fast path (#1017): when the per-row cross-block is a
1436 // materialized `di × k` slab (no matrix-free operator), the entire
1437 // reduced-Schur diagonal contribution for this row is
1438 // `Σ_c H_tβ[c,a] · ((H_tt)⁻¹ H_tβ)[c,a]`. The generic loop below
1439 // re-solved `(H_tt)⁻¹` once PER COLUMN — `O(k)` block solves + `O(k)`
1440 // allocations per row, i.e. `O(n·k)` tiny solves per Newton step
1441 // (the dominant fixed per-solve cost at the SAE wide-border shape,
1442 // k in the tens of thousands). Solve all `k` columns in ONE batched
1443 // block solve instead, then take the column dots. Reassociates the
1444 // diagonal within the documented #1211 preconditioner margin (same as
1445 // the resident no-probe path), and the preconditioner only steers the
1446 // PCG iterate, which still terminates at the PCG tolerance.
1447 if sys.htbeta_matvec.is_none() && row.htbeta.dim() == (di, k) {
1448 let solved = backend.solve_block_matrix(htt_factors.factor(i), row.htbeta.view());
1449 for a in 0..k {
1450 let mut acc = 0.0;
1451 for c in 0..di {
1452 acc += row.htbeta[[c, a]] * solved[[c, a]];
1453 }
1454 diag_part[a] -= acc;
1455 }
1456 return;
1457 }
1458 // Matrix-free path: probe column a. `e_a` stays all-zero between
1459 // columns — set the single active entry and reset it after the probe,
1460 // so we never pay the `O(k)` `e_a.fill(0.0)` per column (that fill was
1461 // `O(n·k²)`). `sys_htbeta_apply_row` zeroes `col_i` internally.
1462 let mut col_i = Array1::<f64>::zeros(di);
1463 let mut e_a = Array1::<f64>::zeros(k);
1464 for a in 0..k {
1465 e_a[a] = 1.0;
1466 sys_htbeta_apply_row(sys, i, row, e_a.view(), &mut col_i);
1467 e_a[a] = 0.0;
1468 let solved = backend.solve_block_vector(htt_factors.factor(i), col_i.view());
1469 let mut acc = 0.0;
1470 for c in 0..di {
1471 acc += col_i[c] * solved[c];
1472 }
1473 diag_part[a] -= acc;
1474 }
1475 };
1476 let n = sys.rows.len();
1477 let parallel =
1478 n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1479 if parallel {
1480 use rayon::prelude::*;
1481 const CHUNK: usize = 64;
1482 let partials: Vec<Array1<f64>> = (0..n)
1483 .into_par_iter()
1484 .chunks(CHUNK)
1485 .map(|idxs| {
1486 let mut diag_part = Array1::<f64>::zeros(k);
1487 for i in idxs {
1488 row_into(i, &sys.rows[i], &mut diag_part);
1489 }
1490 diag_part
1491 })
1492 .collect();
1493 // Deterministic ordered reduction: fold chunk partials left-to-right.
1494 for part in &partials {
1495 for a in 0..k {
1496 diag[a] += part[a];
1497 }
1498 }
1499 } else {
1500 for (i, row) in sys.rows.iter().enumerate() {
1501 row_into(i, row, &mut diag);
1502 }
1503 }
1504 let mut blocks = Vec::with_capacity(k);
1505 for a in 0..k {
1506 let v = diag[a];
1507 if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
1508 return Err(ArrowSchurError::PcgFailed {
1509 reason: format!(
1510 "invalid Schur Jacobi diagonal at index {a}: {v}; \
1511 operator regularization is required"
1512 ),
1513 });
1514 }
1515 blocks.push(BlockFactor::Scalar {
1516 inv: Array1::from_elem(1, 1.0 / v),
1517 range: a..a + 1,
1518 });
1519 }
1520 Ok(Self { blocks })
1521 }
1522
    /// Build scalar-diagonal Jacobi from the pre-staged SAE residency factors
    /// `(L_i, Y_i)` (#1017).
    ///
    /// On its matrix-free path the generic [`Self::build_scalar_jacobi`] forms
    /// each reduced-Schur diagonal entry
    /// `S_aa = H_ββ,aa + ρ − Σ_i (H_tβ^(i) e_a)ᵀ(H_tt^(i))⁻¹(H_tβ^(i) e_a)`
    /// by probing the cross-block operator with the unit vector `e_a` and
    /// re-solving `(H_tt^(i))⁻¹` for every `(row, column)` pair — `O(n·K)`
    /// triangular solves per Newton step. For the SAE Kronecker cross-block the
    /// `a`-th column lives on exactly one active support entry: `a = beta_base + j`
    /// for some `(beta_base, φ) ∈ a_phi[i]` and output channel `j ∈ 0..p`, with
    /// `H_tβ^(i) e_a = φ · L_i[:, j]`. The point-elimination quotient is then
    ///
    /// ```text
    /// (H_tβ^(i) e_a)ᵀ (H_tt^(i))⁻¹ (H_tβ^(i) e_a)
    ///     = φ² · L_i[:, j]ᵀ (H_tt^(i))⁻¹ L_i[:, j]
    ///     = φ² · (L_i[:, j] · Y_i[:, j]),      Y_i := (H_tt^(i))⁻¹ L_i.
    /// ```
    ///
    /// so the whole diagonal is accumulated in ONE support-sparse pass over the
    /// resident factors — no probe, no per-column solve, the staged `Y_i` reused
    /// from the matvec residency. The result is the SAME quotient the generic
    /// path computes (up to float reassociation of the row sum), so the PCG
    /// preconditioner is unchanged up to that f64 margin. Since the preconditioner
    /// only steers the iterate (which still terminates at the PCG tolerance), the
    /// criterion ranking is stable except for candidates within that margin,
    /// where the near-tie winner can flip — not an exact no-move guarantee (#1211).
    ///
    /// Returns one `BlockFactor::Scalar` of width 1 per β column, each holding
    /// `1 / S_aa`.
    ///
    /// # Errors
    ///
    /// [`ArrowSchurError::PcgFailed`] if any assembled diagonal entry is
    /// non-finite or `<= JACOBI_DIAGONAL_PD_FLOOR` (operator regularization is
    /// then required).
    ///
    /// # Panics
    ///
    /// The scatter writes `diag[beta_base + j]` for `j in 0..p`, so every support
    /// entry must satisfy `beta_base + p <= sys.k`; an out-of-range entry panics
    /// on the index.
    pub(crate) fn build_scalar_jacobi_resident(
        sys: &ArrowSchurSystem,
        ridge_beta: f64,
        resident: &SaeResidentReducedSchur,
    ) -> Result<Self, ArrowSchurError> {
        let k = sys.k;
        let p = resident.p;
        let n = resident.rows.len();
        // Seed with diag(H_ββ) + ridge — same penalty source the generic path
        // reads, so the only difference is how the point-elimination term is
        // gathered.
        let mut diag = Array1::<f64>::zeros(k);
        {
            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
            sys.penalty_diagonal_add(diag_slice);
        }
        for a in 0..k {
            diag[a] += ridge_beta;
        }
        // Per-row point-elimination diagonal: for each active support entry
        // `(beta_base, φ)` and channel `j`, subtract `φ² · L_i[:, j]·Y_i[:, j]`
        // into `diag[beta_base + j]`. `L_i`/`Y_i` are row-major `di × p`, so the
        // `j`-th column dot is `Σ_r L_i[r·p + j]·Y_i[r·p + j]`.
        //
        // Rows scatter into overlapping `beta_base + j` columns, so the
        // accumulation target is SHARED. Parallelism therefore mirrors the
        // generic `build_scalar_jacobi` and the `schur_matvec` row loop (#1017):
        // each rayon worker subtracts a contiguous ascending 64-row chunk into a
        // private zero-seeded length-`K` partial, and the caller adds those
        // partials back in chunk order. That makes the parallel result
        // bit-identical RUN-TO-RUN (the #1017 determinism gate). It is NOT
        // bit-identical to the serial in-place branch, because each chunk starts
        // its own sum from zero. The preconditioner, and any criterion ranking it
        // steers, is therefore stable only up to that chunk-reassociation margin;
        // a near-tie winner inside the margin can flip (#1211).
        //
        // Cost: this build runs once per inexact-PCG solve, i.e.
        // O(inner-Newton-iters) times per fit. At the SAE LLM shape (thousands of
        // rows, wide border `k`) the per-row support sweep is the build's entire
        // cost, hence the row-parallel fan-out below.
        //
        // The per-channel column dot `col_dot[j] = Σ_r L_i[r·p+j]·Y_i[r·p+j]`
        // (the diagonal of `G_i = L_iᵀ(H_tt)⁻¹L_i`) depends ONLY on the row `i`,
        // not on the support entry `(beta_base, φ)`. It is computed once per row
        // into reusable `col_dot` scratch, so each support entry is a pure
        // scatter `diag[beta_base+j] -= φ²·col_dot[j]`. Computing it per support
        // entry instead would cost `m·p` column dots over `di` for a row with
        // `m` active atoms. The hoist is bit-for-bit neutral: each `col_dot[j]`
        // is the same `r`-ascending sum, and `φ²·col_dot[j]` yields identical
        // bits whether `col_dot[j]` was just computed or cached.
        let row_into = |row: usize, diag_part: &mut [f64], col_dot: &mut [f64]| {
            let rf = &resident.rows[row];
            let di = rf.di;
            // Rows with no local parameters or no active atoms contribute nothing.
            if di == 0 {
                return;
            }
            let support = &resident.a_phi[row];
            if support.is_empty() {
                return;
            }
            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
            // the former per-row `rf.l` copy.
            let l_i = &resident.local_jac[row];
            for (j, slot) in col_dot.iter_mut().enumerate().take(p) {
                let mut acc = 0.0_f64;
                for r in 0..di {
                    let idx = r * p + j;
                    acc += l_i[idx] * rf.y[idx];
                }
                *slot = acc;
            }
            for &(beta_base, phi) in support {
                // An inactive atom contributes φ² = 0; skipping it avoids the p-wide scatter.
                if phi == 0.0 {
                    continue;
                }
                let phi2 = phi * phi;
                for j in 0..p {
                    diag_part[beta_base + j] -= phi2 * col_dot[j];
                }
            }
        };
        // Fan out only for enough rows to amortize rayon, and never from inside an
        // existing rayon worker (nesting guard, same gate as `schur_matvec`).
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let partials: Vec<Array1<f64>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let mut diag_part = Array1::<f64>::zeros(k);
                    let mut col_dot = vec![0.0_f64; p];
                    let slice = diag_part
                        .as_slice_mut()
                        .expect("diag_part must be contiguous");
                    for i in idxs {
                        row_into(i, slice, &mut col_dot);
                    }
                    diag_part
                })
                .collect();
            // Deterministic ordered reduction: fold chunk partials left-to-right
            // (each partial already holds the per-row terms subtracted, so add
            // them into `diag` in chunk order to mirror the serial subtraction).
            for part in &partials {
                for a in 0..k {
                    diag[a] += part[a];
                }
            }
        } else {
            let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
            let mut col_dot = vec![0.0_f64; p];
            for row in 0..n {
                row_into(row, diag_slice, &mut col_dot);
            }
        }
        // Invert the diagonal, refusing non-finite or non-positive (below-floor)
        // entries: a Jacobi preconditioner must be SPD for PCG.
        let mut blocks = Vec::with_capacity(k);
        for a in 0..k {
            let v = diag[a];
            if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
                return Err(ArrowSchurError::PcgFailed {
                    reason: format!(
                        "invalid SAE-resident Schur Jacobi diagonal at index {a}: {v}; \
                         operator regularization is required"
                    ),
                });
            }
            blocks.push(BlockFactor::Scalar {
                inv: Array1::from_elem(1, 1.0 / v),
                range: a..a + 1,
            });
        }
        Ok(Self { blocks })
    }
1680
    /// Build block-Jacobi from the pre-staged SAE residency factors `(L_i, Y_i)`.
    ///
    /// This is the block analogue of [`Self::build_scalar_jacobi_resident`].
    /// When SAE block offsets are small enough to select BetaBlockJacobi (for
    /// example per-atom decoder blocks with `basis_size·p <= 256`), the generic
    /// block builder materializes every row's dense `(d_i × K)` `H_tβ` by probing
    /// the matrix-free operator, then re-solves `(H_tt)⁻¹` for each block column.
    ///
    /// The resident factors instead carry `L_i` (the shared `local_jac[i]` slab)
    /// and `Y_i = (H_tt)⁻¹ L_i`, both row-major `d_i × p`. Any entry of
    /// `G_i = L_iᵀ(H_tt)⁻¹L_i` is therefore a single column dot
    /// `G_i[s, t] = Σ_r L_i[r·p + s] · Y_i[r·p + t]`, formed on the fly. Each block
    /// is assembled by scattering only the active support pairs that land inside it:
    ///
    /// ```text
    /// S_block -= Σ_i Σ_(s,t in block support) φ_s φ_t · G_i[channel_s, channel_t]
    /// ```
    ///
    /// It computes the same block-diagonal restriction as the generic path, but
    /// avoids the full-row `H_tβ` materialization and per-column triangular solves.
    ///
    /// Each accumulated block is LLᵀ-factored; a block that fails Cholesky falls
    /// back to its scalar reciprocal diagonal.
    ///
    /// # Errors
    ///
    /// [`ArrowSchurError::PcgFailed`] when a block takes the scalar fallback and
    /// one of its diagonal entries is non-finite or `<= JACOBI_DIAGONAL_PD_FLOOR`.
    pub(crate) fn build_block_jacobi_resident(
        sys: &ArrowSchurSystem,
        ridge_beta: f64,
        resident: &SaeResidentReducedSchur,
    ) -> Result<Self, ArrowSchurError> {
        let block_offsets = &sys.block_offsets;
        let p = resident.p;
        // Seed every block with its H_ββ restriction plus ridge·I.
        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let b = range.end - range.start;
            let mut schur_block = Array2::<f64>::zeros((b, b));
            sys.penalty_block_add(
                BetaBlockId(block_idx),
                block_offsets.as_ref(),
                &mut schur_block,
            );
            for bi in 0..b {
                schur_block[[bi, bi]] += ridge_beta;
            }
            schur_blocks.push(schur_block);
        }

        // Per-row body: subtract this row's support-pair contributions into
        // every block. A support entry `(base, φ)` covers global columns
        // `base..base + p`; it is clipped to the block's `range`, and pairs
        // whose clipped window is empty are skipped. Within a block the
        // per-cell contributions arrive in (left support, right support)
        // order, so the result is order-deterministic.
        let row_into = |row: usize, blocks: &mut [Array2<f64>]| {
            let rf = &resident.rows[row];
            let di = rf.di;
            if di == 0 {
                return;
            }
            let support = &resident.a_phi[row];
            if support.is_empty() {
                return;
            }
            // `L_i` is the shared `local_jac[row]` slab (#1033) — byte-for-byte
            // the former per-row `rf.l` copy.
            let l_i = &resident.local_jac[row];
            for (block_idx, range) in block_offsets.iter().enumerate() {
                let block = &mut blocks[block_idx];
                for &(base_left, phi_left) in support {
                    if phi_left == 0.0 {
                        continue;
                    }
                    // Intersection of `base_left..base_left + p` with the block.
                    let left_start = base_left.max(range.start);
                    let left_end = (base_left + p).min(range.end);
                    if left_start >= left_end {
                        continue;
                    }
                    for &(base_right, phi_right) in support {
                        if phi_right == 0.0 {
                            continue;
                        }
                        let right_start = base_right.max(range.start);
                        let right_end = (base_right + p).min(range.end);
                        if right_start >= right_end {
                            continue;
                        }
                        let phi = phi_left * phi_right;
                        for gi in left_start..left_end {
                            // `li`/`lj` are block-local indices; `ch_i`/`ch_j`
                            // are output channels within the support entry.
                            let li = gi - range.start;
                            let ch_i = gi - base_left;
                            for gj in right_start..right_end {
                                let lj = gj - range.start;
                                let ch_j = gj - base_right;
                                // G_i[ch_i, ch_j] = L_i[:, ch_i] · Y_i[:, ch_j].
                                let mut gij = 0.0_f64;
                                for r in 0..di {
                                    gij += l_i[r * p + ch_i] * rf.y[r * p + ch_j];
                                }
                                block[[li, lj]] -= phi * gij;
                            }
                        }
                    }
                }
            }
        };

        // Same parallel gate and fixed-chunk, worker-private, chunk-ordered fold
        // as the other builders: bit-identical run-to-run, within the
        // chunk-reassociation margin of the serial branch (#1017/#1211).
        let n = resident.rows.len();
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let n_blocks = block_offsets.len();
            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
            let partials: Vec<Vec<Array2<f64>>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let mut local: Vec<Array2<f64>> = block_dims
                        .iter()
                        .map(|&b| Array2::<f64>::zeros((b, b)))
                        .collect();
                    for i in idxs {
                        row_into(i, &mut local);
                    }
                    local
                })
                .collect();
            // Deterministic ordered reduction: fold chunk partials left-to-right.
            for local in &partials {
                for bidx in 0..n_blocks {
                    schur_blocks[bidx] += &local[bidx];
                }
            }
        } else {
            for row in 0..n {
                row_into(row, &mut schur_blocks);
            }
        }

        // Factor each block: Cholesky when it is PD, else the scalar reciprocal
        // diagonal (which must itself be finite and above the PD floor).
        let mut blocks = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let b = range.end - range.start;
            let schur_block = &schur_blocks[block_idx];
            let factor_opt = {
                use faer::Side;
                let view = FaerArrayView::new(schur_block);
                FaerLlt::new(view.as_ref(), Side::Lower).ok()
            };
            if let Some(llt) = factor_opt {
                blocks.push(BlockFactor::Chol {
                    factor: llt,
                    range: range.clone(),
                });
            } else {
                let mut inv = Array1::<f64>::zeros(b);
                for bi in 0..b {
                    let v = schur_block[[bi, bi]];
                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
                        return Err(ArrowSchurError::PcgFailed {
                            reason: format!(
                                "SAE-resident block Jacobi scalar fallback: non-PD diagonal at \
                                 global index {}: {v}; regularization required",
                                range.start + bi
                            ),
                        });
                    }
                    inv[bi] = 1.0 / v;
                }
                blocks.push(BlockFactor::Scalar {
                    inv,
                    range: range.clone(),
                });
            }
        }
        Ok(Self { blocks })
    }
1841
    /// Build term-block Jacobi: one dense `b×b` Schur block per term in
    /// `sys.block_offsets`.
    ///
    /// Each block is the block-diagonal restriction
    /// `S_kk = H_ββ,kk + ridge·I − Σ_i H_βt_k^(i) (H_tt^(i))⁻¹ H_tβ_k^(i)`,
    /// LLᵀ-factored when PD and otherwise replaced by its scalar reciprocal
    /// diagonal.
    ///
    /// # Errors
    ///
    /// - Any error returned by `sys_htbeta_materialize_row` for a row. On the
    ///   parallel path, rayon's `Result` collection does not guarantee WHICH
    ///   row's error is reported if several chunks fail.
    /// - [`ArrowSchurError::PcgFailed`] when a block takes the scalar fallback
    ///   and one of its diagonal entries is non-finite or
    ///   `<= JACOBI_DIAGONAL_PD_FLOOR`.
    pub(crate) fn build_block_jacobi<B: BatchedBlockSolver + Sync>(
        sys: &ArrowSchurSystem,
        htt_factors: &ArrowFactorSlab,
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let block_offsets = &sys.block_offsets;

        // Initialise every b×b Schur sub-block from H_ββ + ridge·I via
        // penalty_block_add (#296): routes to penalty_op or falls back to
        // hbb / hbb_diag inline without Arc-clone per loop iteration. These are
        // the block-diagonal restrictions of the reduced Schur complement; the
        // per-row cross-block contributions are accumulated in the row sweep
        // below.
        let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let b = range.end - range.start;
            let mut schur_block = Array2::<f64>::zeros((b, b));
            sys.penalty_block_add(
                BetaBlockId(block_idx),
                block_offsets.as_ref(),
                &mut schur_block,
            );
            for bi in 0..b {
                schur_block[[bi, bi]] += ridge_beta;
            }
            schur_blocks.push(schur_block);
        }

        // Per-row body for the Schur subtraction
        //     S_kk -= H_βt_k^(i) (H_tt^(i))^{-1} H_tβ_k^(i)
        //
        // Each row's (d_i × K) cross-block is materialized ONCE and scattered
        // into every block-diagonal sub-block, mirroring the row-outer structure
        // of `build_dense_schur_direct`. A block-outer nesting would
        // re-materialize every row for each β-block (O(n_blocks · n · K)
        // probes). For the matrix-free softmax cross-block each materialize is
        // itself O(K²), which would make the preconditioner build quadratically
        // more expensive than the direct dense Schur it preconditions.
        // `sys_htbeta_materialize_row` handles the Kronecker / htbeta_matvec
        // path transparently.
        //
        // The body writes INTO a caller-provided `blocks` accumulator (`-=`).
        // A rayon worker subtracts a chunk's rows into a worker-private
        // zero-seeded `Vec<Array2>`, and the caller folds the chunk partials
        // back in chunk order. The result is bit-identical run-to-run
        // regardless of thread scheduling (the #1017 verification gate). It is
        // within the chunk-reassociation margin of the serial branch rather
        // than bit-identical to it, so the preconditioner, and hence the
        // criterion ranking, is stable except for near-tie candidates inside
        // that f64 margin. It is not an exact no-move guarantee (#1211).
        let row_into = |i: usize,
                        row: &ArrowRowBlock,
                        blocks: &mut [Array2<f64>]|
         -> Result<(), ArrowSchurError> {
            let di = sys.row_dims[i];
            let htbeta_full = sys_htbeta_materialize_row(sys, i, row)?;
            for (block_idx, range) in block_offsets.iter().enumerate() {
                let b = range.end - range.start;
                // Columns of (H_tt^(i))⁻¹ H_tβ^(i) restricted to this block.
                let mut solved_cols = Array2::<f64>::zeros((di, b));
                for bj in 0..b {
                    let gj = range.start + bj;
                    // Owned (contiguous) copy of the strided column for the solver.
                    let rhs = htbeta_full.column(gj).to_owned();
                    let solved = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
                    for c in 0..di {
                        solved_cols[[c, bj]] = solved[c];
                    }
                }
                let schur_block = &mut blocks[block_idx];
                for bi in 0..b {
                    let gi = range.start + bi;
                    for bj in 0..b {
                        let mut acc = 0.0;
                        for c in 0..di {
                            acc += htbeta_full[[c, gi]] * solved_cols[[c, bj]];
                        }
                        schur_block[[bi, bj]] -= acc;
                    }
                }
            }
            Ok(())
        };
        // Each row materializes an `O(K²)` cross-block (Kronecker) plus `Σ_k b_k`
        // triangular solves — the preconditioner build's whole per-row cost at
        // the SAE LLM shape (#1017), and the rows are independent. Fan over fixed
        // row chunks above the threshold, staying serial for the handful-of-rows
        // non-SAE callers and inside a rayon worker (topology-race nesting guard)
        // — the same gate `schur_matvec` uses.
        let n = sys.rows.len();
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let n_blocks = block_offsets.len();
            let block_dims: Vec<usize> = block_offsets.iter().map(|r| r.end - r.start).collect();
            let partials: Vec<Vec<Array2<f64>>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let mut local: Vec<Array2<f64>> = block_dims
                        .iter()
                        .map(|&b| Array2::<f64>::zeros((b, b)))
                        .collect();
                    for i in idxs {
                        row_into(i, &sys.rows[i], &mut local)?;
                    }
                    Ok::<_, ArrowSchurError>(local)
                })
                .collect::<Result<Vec<_>, _>>()?;
            // Deterministic ordered reduction: fold chunk partials left-to-right.
            for local in &partials {
                for bidx in 0..n_blocks {
                    schur_blocks[bidx] += &local[bidx];
                }
            }
        } else {
            for (i, row) in sys.rows.iter().enumerate() {
                row_into(i, row, &mut schur_blocks)?;
            }
        }

        // Factor each accumulated block: LLT, with scalar-diagonal fallback for
        // a block that comes out non-PD at this ridge.
        let mut blocks = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let b = range.end - range.start;
            let schur_block = &schur_blocks[block_idx];
            let factor_opt = {
                use faer::Side;
                let view = FaerArrayView::new(schur_block);
                FaerLlt::new(view.as_ref(), Side::Lower).ok()
            };
            if let Some(llt) = factor_opt {
                blocks.push(BlockFactor::Chol {
                    factor: llt,
                    range: range.clone(),
                });
            } else {
                // Non-PD block: fall back to scalar diagonal for this block.
                let mut inv = Array1::<f64>::zeros(b);
                for bi in 0..b {
                    let v = schur_block[[bi, bi]];
                    if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
                        return Err(ArrowSchurError::PcgFailed {
                            reason: format!(
                                "block Jacobi scalar fallback: non-PD diagonal at \
                                 global index {}: {v}; regularization required",
                                range.start + bi
                            ),
                        });
                    }
                    inv[bi] = 1.0 / v;
                }
                blocks.push(BlockFactor::Scalar {
                    inv,
                    range: range.clone(),
                });
            }
        }
        Ok(Self { blocks })
    }
2007
2008 pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2009 let mut out = Array1::<f64>::zeros(r.len());
2010 for block in &self.blocks {
2011 match block {
2012 BlockFactor::Scalar { inv, range } => {
2013 for (local, gi) in range.clone().enumerate() {
2014 out[gi] = inv[local] * r[gi];
2015 }
2016 }
2017 BlockFactor::Chol { factor, range } => {
2018 let b = range.end - range.start;
2019 let mut rhs = Array1::<f64>::zeros(b);
2020 for (local, gi) in range.clone().enumerate() {
2021 rhs[local] = r[gi];
2022 }
2023 use faer::linalg::solvers::Solve;
2024 let stride = rhs.strides()[0];
2025 let len = rhs.len();
2026 // SAFETY: rhs is a uniquely-borrowed contiguous Array1
2027 // with positive stride (standard layout).
2028 let rhs_mat =
2029 unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2030 let solved = factor.solve(rhs_mat);
2031 for (local, gi) in range.clone().enumerate() {
2032 out[gi] = solved[(local, 0)];
2033 }
2034 }
2035 }
2036 }
2037 out
2038 }
2039}
2040
2041// ---------------------------------------------------------------------------
2042// Preconditioner ladder: SchurPreconditionerKind, ClusterJacobi,
2043// AdditiveSchwarz (issue #299)
2044// ---------------------------------------------------------------------------
2045
/// Which Schur preconditioner to use in the inexact-PCG path.
///
/// Ladder ordered by cost / effectiveness. This describes the first five
/// variants. `BlockIncompleteCholesky` is declared last, but per its
/// description below it keeps `ClusterJacobi`'s inter-block coupling at lower
/// build/apply cost for large sparse components, so it is not necessarily
/// the most expensive tier.
/// - `Diagonal`: scalar Jacobi (pre-#283 behaviour).
/// - `BetaBlockJacobi`: block-Jacobi per `block_offsets` term (#287).
/// - `ClusterJacobi`: one dense block per beta-graph connected component.
/// - `AdditiveSchwarz { overlap }`: component + `overlap`-hop expansion,
///   overlapping columns averaged by partition-of-unity weights (full dense
///   local-inverse apply per subdomain).
/// - `DiagAssembledSchwarz { overlap }`: the cheap Schwarz variant (#299) —
///   same overlapping decomposition, but each subdomain contributes only the
///   diagonal of its local inverse `(A_k⁻¹)_ii`, assembled additively with
///   partition-of-unity weights into a single `O(K)`-apply diagonal.
/// - `BlockIncompleteCholesky`: level-0 incomplete Cholesky (#299). Within each
///   connected component of the β-coupling graph the dense reduced-Schur block
///   `S[C,C]` is assembled once, its structural-nonzero pattern is taken as the
///   level-0 fill pattern, and a no-fill incomplete Cholesky `S ≈ L̃ L̃ᵀ` is
///   formed keeping ONLY that pattern (Saad, *Iterative Methods*, IC(0)). Apply
///   is a sparse triangular forward/back solve over `nnz(S[C,C])`, so for a
///   large component with internal sparsity it is far cheaper to build and apply
///   than `ClusterJacobi`'s full dense Cholesky (which fills the whole `b×b`
///   factor) while retaining the inter-block coupling that ClusterJacobi keeps
///   but the diagonal/Schwarz tiers discard. A non-PD incomplete pivot degrades
///   that component to the scalar reciprocal diagonal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurPreconditionerKind {
    /// Scalar Jacobi: reciprocal of the reduced-Schur diagonal.
    Diagonal,
    /// One dense Schur block per `block_offsets` term.
    BetaBlockJacobi,
    /// One dense Schur block per connected component of the β-coupling graph.
    ClusterJacobi,
    /// Components expanded by `overlap` graph hops, full local inverse per
    /// subdomain, overlaps averaged by partition-of-unity weights.
    AdditiveSchwarz {
        /// Number of graph hops each base component is expanded by.
        overlap: usize,
    },
    /// Same overlapping subdomains as `AdditiveSchwarz`, but only the diagonal
    /// of each local inverse is kept and assembled into one diagonal.
    DiagAssembledSchwarz {
        /// Number of graph hops each base component is expanded by.
        overlap: usize,
    },
    /// Level-0 (no-fill) incomplete Cholesky per connected component.
    BlockIncompleteCholesky,
}
2079
/// Size gate for climbing the preconditioner ladder past `BetaBlockJacobi`.
///
/// Escalation is considered only when BOTH hold: the reduced dimension `K` is
/// strictly greater than this value, and PCG ran out of `max_iterations`.
pub(crate) const PRECOND_ESCALATE_K_THRESHOLD: usize = 100;

/// Fractional safety margin for the #1026 matrix-free Schur curvature floor.
///
/// This is the unbounded-PCG counterpart of the dense
/// `spectral_pd_floored_schur`. When the unbounded SAE inner PCG meets
/// `pᵀSp ≤ 0`, the operator ridge is raised by the smallest amount that makes
/// the curvature along that direction positive again. It is then raised by
/// this extra fraction, so the next CG iterate lies strictly inside the
/// positive cone instead of on the `0` boundary.
pub(crate) const SCHUR_CURVATURE_FLOOR_MARGIN: f64 = 0.01;
/// Minimum curvature-floor ridge bump, as a fraction of the rhs scale. It
/// guarantees a strictly positive bump even when `pᵀSp` rounds to exactly `0`.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_FLOOR: f64 = 1e-12;
/// Maximum accumulated curvature-floor ridge, as a fraction of the rhs scale.
///
/// Past this point the operator is considered impossible to condition with a
/// minimal floor. The recoverable failure is then passed to the outer LM loop,
/// which rebuilds the whole system at a heavier ridge. The value is generous
/// so that a large collapsed over-subtraction `(H_tβ)²/H_tt` can still be
/// reached.
pub(crate) const SCHUR_CURVATURE_FLOOR_REL_CEILING: f64 = 1e12;
/// Per-attempt multiplier for the ridge escalation on a DIAGONAL refusal.
///
/// That path has no `(curvature, ‖p‖²)` deficit to size the bump from. The
/// multiplier matches `RIDGE_GROWTH_FACTOR` in the per-row
/// `factor_one_row_result`.
pub(crate) const SCHUR_CURVATURE_FLOOR_DIAG_GROWTH: f64 = 10.0;
/// Maximum number of curvature-floor ridge lifts before handing off to the
/// outer LM loop.
///
/// With the ×10 diagonal-refusal growth, the reachable ridge is bounded by
/// about `rhs_scale · 10^(attempts)`. The accumulated ridge is also capped
/// independently by `SCHUR_CURVATURE_FLOOR_REL_CEILING`.
pub(crate) const SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS: usize = 24;
2109
2110/// Cholesky or scalar factor for one cluster of the beta-coefficient graph.
2111#[derive(Clone)]
2112pub(crate) enum ClusterFactor {
2113 Chol {
2114 cols: Vec<usize>,
2115 factor: FaerLlt<f64>,
2116 },
2117 Scalar {
2118 cols: Vec<usize>,
2119 inv: Vec<f64>,
2120 },
2121}
2122
2123impl std::fmt::Debug for ClusterFactor {
2124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2125 match self {
2126 ClusterFactor::Chol { cols, .. } => {
2127 write!(f, "ClusterFactor::Chol {{ cols.len: {} }}", cols.len())
2128 }
2129 ClusterFactor::Scalar { cols, inv } => write!(
2130 f,
2131 "ClusterFactor::Scalar {{ cols.len: {}, inv.len: {} }}",
2132 cols.len(),
2133 inv.len()
2134 ),
2135 }
2136 }
2137}
2138
/// Largest cluster, in columns, that cluster-Jacobi factors densely.
///
/// Wider clusters use the scalar reciprocal-diagonal fallback instead.
pub(crate) const CLUSTER_JACOBI_MAX_CLUSTER: usize = 512;

/// Largest connected component, in columns, for which the IC(0)
/// preconditioner assembles the dense `S[C,C]` to read off its sparsity
/// pattern.
///
/// Applying IC(0) is cheap at any size. Deriving the pattern, however, needs
/// the `O(b²)`-memory dense assembly. Components above this size fall back to
/// the scalar reciprocal diagonal. This is the same concern as
/// `CLUSTER_JACOBI_MAX_CLUSTER`, with a higher limit because the IC(0) FACTOR
/// itself is sparse.
pub(crate) const IC0_MAX_COMPONENT: usize = 4_096;

/// Relative cutoff for treating an assembled `S[i,j]` as a structural zero
/// when building the IC(0) level-0 pattern.
///
/// The cutoff is scaled by `sqrt(|S_ii|·|S_jj|)`, making it invariant to column
/// scaling. Entries that are pure FMA round-off (a truly decoupled `(i,j)` pair
/// assembles to ~0) are thereby kept out of the fill pattern.
pub(crate) const IC0_PATTERN_REL_DROP: f64 = 1e-13;
2156
2157/// Assemble the dense `b×b` reduced-Schur block for the column set `cols`:
2158/// `S[cols, cols] = H_ββ[cols, cols] + ridge·I − Σ_i H_tβ[cols]ᵀ (H_tt^i)⁻¹ H_tβ[cols]`.
2159///
2160/// Shared by `ClusterJacobiPreconditioner::build_from_column_groups` (which
2161/// Cholesky-factors the returned block) and `DiagAssembledSchwarzPreconditioner`
2162/// (which inverts each subdomain block and keeps only its diagonal). The result
2163/// is the LOWER triangle filled by the row reduction; callers that need the full
2164/// symmetric block must `symmetrize_upper_from_lower`.
2165///
2166/// The per-row Schur contribution is fanned over fixed 64-row chunks above
2167/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` and folded left-to-right so the assembly is
2168/// bit-identical to the serial path (and run-to-run deterministic), exactly as
2169/// in `build_block_jacobi` (#1017).
2170pub(crate) fn assemble_local_schur_block<B: BatchedBlockSolver + Sync>(
2171 sys: &ArrowSchurSystem,
2172 htt_factors: &ArrowFactorSlab,
2173 ridge_beta: f64,
2174 backend: &B,
2175 cols: &[usize],
2176) -> Array2<f64> {
2177 let d = sys.d;
2178 let b = cols.len();
2179 let mut s_block = Array2::<f64>::zeros((b, b));
2180 // Initialise from H_ββ via penalty_subblock_add (#296): routes through
2181 // penalty_op or falls back to hbb / hbb_diag inline.
2182 sys.penalty_subblock_add(cols, &mut s_block);
2183 for bi in 0..b {
2184 s_block[[bi, bi]] += ridge_beta;
2185 }
2186 let cluster_row_into = |row_idx: usize, row: &ArrowRowBlock, acc: &mut Array2<f64>| {
2187 let mut col_vec = Array1::<f64>::zeros(d);
2188 let mut solved_cols = Array2::<f64>::zeros((d, b));
2189 for bj in 0..b {
2190 let gj = cols[bj];
2191 for c in 0..d {
2192 col_vec[c] = row.htbeta[[c, gj]];
2193 }
2194 let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
2195 for c in 0..d {
2196 solved_cols[[c, bj]] = solved[c];
2197 }
2198 }
2199 for bi in 0..b {
2200 let gi = cols[bi];
2201 for bj in 0..b {
2202 let mut dot = 0.0;
2203 for c in 0..d {
2204 dot += row.htbeta[[c, gi]] * solved_cols[[c, bj]];
2205 }
2206 acc[[bi, bj]] -= dot;
2207 }
2208 }
2209 };
2210 let n = sys.rows.len();
2211 let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
2212 if parallel {
2213 use rayon::prelude::*;
2214 const CHUNK: usize = 64;
2215 let partials: Vec<Array2<f64>> = (0..n)
2216 .into_par_iter()
2217 .chunks(CHUNK)
2218 .map(|idxs| {
2219 let mut local = Array2::<f64>::zeros((b, b));
2220 for i in idxs {
2221 cluster_row_into(i, &sys.rows[i], &mut local);
2222 }
2223 local
2224 })
2225 .collect();
2226 for local in &partials {
2227 s_block += local;
2228 }
2229 } else {
2230 for (row_idx, row) in sys.rows.iter().enumerate() {
2231 cluster_row_into(row_idx, row, &mut s_block);
2232 }
2233 }
2234 s_block
2235}
2236
2237/// Dense Schur block per connected component of the beta-coupling graph.
2238///
2239/// Nodes = beta blocks (`block_offsets`); edges = rows where two blocks
2240/// co-occur with nonzero `H_t_beta` entries. One Cholesky factor per
2241/// connected component; applied as a triangular solve.
2242#[derive(Debug, Clone)]
2243pub struct ClusterJacobiPreconditioner {
2244 pub(crate) clusters: Vec<ClusterFactor>,
2245}
2246
2247impl ClusterJacobiPreconditioner {
2248 pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2249 sys: &ArrowSchurSystem,
2250 htt_factors: &ArrowFactorSlab,
2251 ridge_beta: f64,
2252 backend: &B,
2253 ) -> Result<Self, ArrowSchurError> {
2254 if sys.block_offsets.is_empty() {
2255 let cols: Vec<usize> = (0..sys.k).collect();
2256 return Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &[cols]);
2257 }
2258 let graph = BetaCouplingGraph::build(
2259 &sys.block_offsets,
2260 &sys.rows
2261 .iter()
2262 .map(|r| r.htbeta.clone())
2263 .collect::<Vec<_>>(),
2264 );
2265 let col_groups: Vec<Vec<usize>> = graph
2266 .component_partition()
2267 .iter()
2268 .map(|comp_blocks| {
2269 let mut cols: Vec<usize> = comp_blocks
2270 .iter()
2271 .flat_map(|&b| sys.block_offsets[b].clone())
2272 .collect();
2273 cols.sort_unstable();
2274 cols
2275 })
2276 .collect();
2277 Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2278 }
2279
2280 pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2281 sys: &ArrowSchurSystem,
2282 htt_factors: &ArrowFactorSlab,
2283 ridge_beta: f64,
2284 backend: &B,
2285 col_groups: &[Vec<usize>],
2286 ) -> Result<Self, ArrowSchurError> {
2287 let mut clusters = Vec::with_capacity(col_groups.len());
2288 for cols in col_groups {
2289 let b = cols.len();
2290 if b == 0 {
2291 continue;
2292 }
2293 if b > CLUSTER_JACOBI_MAX_CLUSTER {
2294 let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2295 clusters.push(ClusterFactor::Scalar {
2296 cols: cols.clone(),
2297 inv,
2298 });
2299 continue;
2300 }
2301 let mut s_block =
2302 assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2303 symmetrize_upper_from_lower(&mut s_block);
2304 let factor_opt = {
2305 use faer::Side;
2306 let view = FaerArrayView::new(&s_block);
2307 FaerLlt::new(view.as_ref(), Side::Lower).ok()
2308 };
2309 if let Some(llt) = factor_opt {
2310 clusters.push(ClusterFactor::Chol {
2311 cols: cols.clone(),
2312 factor: llt,
2313 });
2314 } else {
2315 let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2316 clusters.push(ClusterFactor::Scalar {
2317 cols: cols.clone(),
2318 inv,
2319 });
2320 }
2321 }
2322 Ok(Self { clusters })
2323 }
2324
2325 pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2326 let mut out = Array1::<f64>::zeros(r.len());
2327 for cluster in &self.clusters {
2328 apply_cluster(cluster, r, &mut out, &ClusterApplyMode::Overwrite);
2329 }
2330 out
2331 }
2332}
2333
/// Additive Schwarz: base components expanded by `overlap` graph-hops;
/// overlapping columns averaged by partition-of-unity weights.
#[derive(Debug, Clone)]
pub struct AdditiveSchwarzPreconditioner {
    /// Per-subdomain factors (dense Cholesky or scalar fallback) built from the
    /// overlapping column groups.
    pub(crate) clusters: Vec<ClusterFactor>,
    /// Partition-of-unity weight per global column: `1 / c` for a column
    /// shared by `c` subdomains, `1.0` for a column that no subdomain covers.
    pub(crate) weights: Vec<f64>,
}
2341
2342impl AdditiveSchwarzPreconditioner {
2343 pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2344 sys: &ArrowSchurSystem,
2345 htt_factors: &ArrowFactorSlab,
2346 ridge_beta: f64,
2347 backend: &B,
2348 overlap: usize,
2349 ) -> Result<Self, ArrowSchurError> {
2350 if sys.block_offsets.is_empty() {
2351 let cols: Vec<usize> = (0..sys.k).collect();
2352 let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2353 sys,
2354 htt_factors,
2355 ridge_beta,
2356 backend,
2357 &[cols],
2358 )?;
2359 return Ok(Self {
2360 clusters: inner.clusters,
2361 weights: vec![1.0f64; sys.k],
2362 });
2363 }
2364 let graph = BetaCouplingGraph::build(
2365 &sys.block_offsets,
2366 &sys.rows
2367 .iter()
2368 .map(|r| r.htbeta.clone())
2369 .collect::<Vec<_>>(),
2370 );
2371 let col_groups: Vec<Vec<usize>> = graph
2372 .component_partition()
2373 .iter()
2374 .map(|seed| {
2375 let mut current = seed.clone();
2376 for _ in 0..overlap {
2377 current = graph.expand_one_hop(¤t);
2378 }
2379 let mut cols: Vec<usize> = current
2380 .iter()
2381 .flat_map(|&b| sys.block_offsets[b].clone())
2382 .collect();
2383 cols.sort_unstable();
2384 cols.dedup();
2385 cols
2386 })
2387 .collect();
2388 let mut counts = vec![0u32; sys.k];
2389 for cols in &col_groups {
2390 for &gi in cols {
2391 counts[gi] += 1;
2392 }
2393 }
2394 let weights: Vec<f64> = counts
2395 .iter()
2396 .map(|&c| if c == 0 { 1.0 } else { 1.0 / c as f64 })
2397 .collect();
2398 let inner = ClusterJacobiPreconditioner::build_from_column_groups(
2399 sys,
2400 htt_factors,
2401 ridge_beta,
2402 backend,
2403 &col_groups,
2404 )?;
2405 Ok(Self {
2406 clusters: inner.clusters,
2407 weights,
2408 })
2409 }
2410
2411 pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2412 let mut out = Array1::<f64>::zeros(r.len());
2413 for cluster in &self.clusters {
2414 apply_cluster(
2415 cluster,
2416 r,
2417 &mut out,
2418 &ClusterApplyMode::Accumulate {
2419 weights: &self.weights,
2420 },
2421 );
2422 }
2423 out
2424 }
2425}
2426
/// Diagonal-assembled additive Schwarz (#299).
///
/// The cheap Schwarz variant the domain-decomposition literature recommends as
/// the default for sparse-coupling β-graphs: instead of storing and applying a
/// dense Cholesky factor per overlapping subdomain (as
/// [`AdditiveSchwarzPreconditioner`] does), it inverts each overlapping
/// subdomain Schur block ONCE at build time and keeps only the **diagonal of the
/// local inverse** `(A_k⁻¹)_ii`. Those per-subdomain diagonal contributions are
/// then assembled additively across overlapping subdomains with partition-of-
/// unity weights into a single global diagonal `m`, applied as `out[i] = m[i]·r[i]`.
///
/// This is strictly richer than scalar Jacobi (`1/S_ii`): the local inverse
/// diagonal `(A_k⁻¹)_ii` folds in the off-diagonal coupling WITHIN the subdomain,
/// so a strongly-coupled column gets a smaller (better-damped) effective scale
/// than its bare reciprocal diagonal would give — while the apply stays `O(K)`
/// (one multiply per column), unlike the `O(Σ b_k²)` triangular solves of dense
/// Schwarz. For `overlap = 0` and one column per subdomain it reduces exactly to
/// scalar Jacobi.
///
/// Columns covered by no subdomain get a unit multiplier. Construction fails if
/// any assembled multiplier is non-finite or non-positive.
#[derive(Debug, Clone)]
pub struct DiagAssembledSchwarzPreconditioner {
    /// Global per-column multiplier `m[i]`; `out[i] = m[i] · r[i]`.
    pub(crate) inv_diag: Vec<f64>,
}
2450
2451impl DiagAssembledSchwarzPreconditioner {
2452 pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2453 sys: &ArrowSchurSystem,
2454 htt_factors: &ArrowFactorSlab,
2455 ridge_beta: f64,
2456 backend: &B,
2457 overlap: usize,
2458 ) -> Result<Self, ArrowSchurError> {
2459 // Build the overlapping subdomain column groups exactly like
2460 // AdditiveSchwarz (component partition + `overlap` graph-hop expansion),
2461 // so the two Schwarz variants decompose the β space identically and
2462 // differ only in how each subdomain's local inverse is applied.
2463 let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2464 vec![(0..sys.k).collect()]
2465 } else {
2466 let graph = BetaCouplingGraph::build(
2467 &sys.block_offsets,
2468 &sys.rows
2469 .iter()
2470 .map(|r| r.htbeta.clone())
2471 .collect::<Vec<_>>(),
2472 );
2473 graph
2474 .component_partition()
2475 .iter()
2476 .map(|seed| {
2477 let mut current = seed.clone();
2478 for _ in 0..overlap {
2479 current = graph.expand_one_hop(¤t);
2480 }
2481 let mut cols: Vec<usize> = current
2482 .iter()
2483 .flat_map(|&b| sys.block_offsets[b].clone())
2484 .collect();
2485 cols.sort_unstable();
2486 cols.dedup();
2487 cols
2488 })
2489 .collect()
2490 };
2491 Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
2492 }
2493
2494 pub(crate) fn build_from_column_groups<B: BatchedBlockSolver + Sync>(
2495 sys: &ArrowSchurSystem,
2496 htt_factors: &ArrowFactorSlab,
2497 ridge_beta: f64,
2498 backend: &B,
2499 col_groups: &[Vec<usize>],
2500 ) -> Result<Self, ArrowSchurError> {
2501 // Partition-of-unity weights: a column shared by `c` subdomains gets each
2502 // of its `c` diagonal contributions scaled by `1/c`, so the assembled
2503 // diagonal is a convex combination (and reduces to a single contribution
2504 // for non-overlapping columns).
2505 let mut counts = vec![0u32; sys.k];
2506 for cols in col_groups {
2507 for &gi in cols {
2508 counts[gi] += 1;
2509 }
2510 }
2511 let mut accum = vec![0.0f64; sys.k];
2512 for cols in col_groups {
2513 let b = cols.len();
2514 if b == 0 {
2515 continue;
2516 }
2517 // For large subdomains, the dense inverse is too costly; fall back to
2518 // the global scalar Schur diagonal inverse `1/S_ii` for those columns
2519 // (the diag-assembled variant then coincides with scalar Jacobi over
2520 // that subdomain, which is exactly the intended cheap degradation).
2521 if b > CLUSTER_JACOBI_MAX_CLUSTER {
2522 let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2523 for (local, &gi) in cols.iter().enumerate() {
2524 let w = if counts[gi] == 0 {
2525 1.0
2526 } else {
2527 1.0 / counts[gi] as f64
2528 };
2529 accum[gi] += w * inv[local];
2530 }
2531 continue;
2532 }
2533 let mut s_block =
2534 assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2535 symmetrize_upper_from_lower(&mut s_block);
2536 // Diagonal of the local inverse `(A_k⁻¹)_ii`, obtained by solving
2537 // `A_k X = I` through the same faer Cholesky used elsewhere; on a
2538 // non-PD local block, degrade to the scalar reciprocal diagonal.
2539 let local_inv_diag = match local_inverse_diagonal(&s_block) {
2540 Some(diag) => diag,
2541 None => {
2542 let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2543 inv
2544 }
2545 };
2546 for (local, &gi) in cols.iter().enumerate() {
2547 let w = if counts[gi] == 0 {
2548 1.0
2549 } else {
2550 1.0 / counts[gi] as f64
2551 };
2552 accum[gi] += w * local_inv_diag[local];
2553 }
2554 }
2555 // A column never covered by any subdomain (only possible for `k` columns
2556 // with no block_offsets coverage) keeps a neutral unit scale.
2557 for (gi, &c) in counts.iter().enumerate() {
2558 if c == 0 {
2559 accum[gi] = 1.0;
2560 }
2561 }
2562 for (gi, m) in accum.iter().enumerate() {
2563 if !m.is_finite() || *m <= 0.0 {
2564 return Err(ArrowSchurError::PcgFailed {
2565 reason: format!(
2566 "diag-assembled Schwarz: non-positive assembled diagonal at index {gi}: {m}"
2567 ),
2568 });
2569 }
2570 }
2571 Ok(Self { inv_diag: accum })
2572 }
2573
2574 pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2575 let mut out = Array1::<f64>::zeros(r.len());
2576 for (gi, &m) in self.inv_diag.iter().enumerate() {
2577 out[gi] = m * r[gi];
2578 }
2579 out
2580 }
2581}
2582
2583/// Diagonal of `A⁻¹` for a small dense SPD block `A`, via the same faer
2584/// Cholesky used by the cluster/Schwarz factors. Returns `None` if `A` is not
2585/// positive-definite (caller degrades to the scalar reciprocal diagonal).
2586pub(crate) fn local_inverse_diagonal(a: &Array2<f64>) -> Option<Vec<f64>> {
2587 let b = a.nrows();
2588 let llt = {
2589 use faer::Side;
2590 let view = FaerArrayView::new(a);
2591 FaerLlt::new(view.as_ref(), Side::Lower).ok()?
2592 };
2593 use faer::linalg::solvers::Solve;
2594 let mut diag = Vec::with_capacity(b);
2595 for col in 0..b {
2596 // Solve `A x = e_col`; the `col`-th entry of `x` is `(A⁻¹)_{col,col}`.
2597 let mut rhs = Array1::<f64>::zeros(b);
2598 rhs[col] = 1.0;
2599 let stride = rhs.strides()[0];
2600 let len = rhs.len();
2601 // SAFETY: `rhs` is a uniquely-borrowed contiguous `Array1<f64>` of `len`
2602 // elements with positive row stride; a single column never dereferences
2603 // the column stride, so `0` is sound.
2604 let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2605 let solved = llt.solve(rhs_mat);
2606 diag.push(solved[(col, 0)]);
2607 }
2608 Some(diag)
2609}
2610
/// How a cluster factor's contribution is written into the output vector.
///
/// `Overwrite` assigns `out[gi] = value` (non-overlapping clusters, each global
/// column touched by exactly one cluster). `Accumulate` adds the partition-of-unity
/// weighted contribution `out[gi] += weights[gi] * value` (overlapping Schwarz
/// clusters, where a column may belong to several clusters).
pub(crate) enum ClusterApplyMode<'w> {
    /// Assign: `out[gi] = value`.
    Overwrite,
    /// Weighted add: `out[gi] += weights[gi] * value`; `weights` is indexed by
    /// global column.
    Accumulate { weights: &'w [f64] },
}
2621
2622impl ClusterApplyMode<'_> {
2623 #[inline]
2624 pub(crate) fn write(&self, out: &mut Array1<f64>, gi: usize, value: f64) {
2625 match self {
2626 ClusterApplyMode::Overwrite => out[gi] = value,
2627 ClusterApplyMode::Accumulate { weights } => out[gi] += weights[gi] * value,
2628 }
2629 }
2630}
2631
2632/// Apply a single cluster factor to the residual `r`, writing into `out`
2633/// according to `mode` (overwrite for non-overlapping clusters, weighted
2634/// accumulate for overlapping Schwarz clusters).
2635pub(crate) fn apply_cluster(
2636 cluster: &ClusterFactor,
2637 r: &Array1<f64>,
2638 out: &mut Array1<f64>,
2639 mode: &ClusterApplyMode<'_>,
2640) {
2641 match cluster {
2642 ClusterFactor::Scalar { cols, inv } => {
2643 for (local, &gi) in cols.iter().enumerate() {
2644 mode.write(out, gi, inv[local] * r[gi]);
2645 }
2646 }
2647 ClusterFactor::Chol { cols, factor } => {
2648 let b = cols.len();
2649 let mut rhs = Array1::<f64>::zeros(b);
2650 for (local, &gi) in cols.iter().enumerate() {
2651 rhs[local] = r[gi];
2652 }
2653 use faer::linalg::solvers::Solve;
2654 let stride = rhs.strides()[0];
2655 let len = rhs.len();
2656 // SAFETY: rhs is uniquely-borrowed contiguous Array1 with positive stride.
2657 let rhs_mat = unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
2658 let solved = factor.solve(rhs_mat);
2659 for (local, &gi) in cols.iter().enumerate() {
2660 mode.write(out, gi, solved[(local, 0)]);
2661 }
2662 }
2663 }
2664}
2665
/// One connected-component factor of the block IC(0) preconditioner.
///
/// `IncompleteChol` holds a sparse lower-triangular `L̃` in column-compressed
/// form over the component's local indices: `col_ptr[j]..col_ptr[j+1]` indexes
/// into `(row_idx, val)` for column `j` (rows `>= j`, diagonal first). `cols`
/// maps a local index back to its global β column. `Scalar` is the non-PD /
/// oversized degradation, identical in meaning to [`ClusterFactor::Scalar`].
#[derive(Clone)]
pub(crate) enum Ic0Factor {
    /// Sparse incomplete-Cholesky factor `L̃` in CSC form.
    IncompleteChol {
        /// Local index -> global β column.
        cols: Vec<usize>,
        /// Column pointers into `row_idx` / `val` (`cols.len() + 1` entries).
        col_ptr: Vec<usize>,
        /// Local row index of each stored entry (diagonal first per column).
        row_idx: Vec<usize>,
        /// Value of each stored entry, parallel to `row_idx`.
        val: Vec<f64>,
    },
    /// Scalar degradation: `out[cols[l]] = inv[l] · r[cols[l]]`.
    Scalar {
        cols: Vec<usize>,
        inv: Vec<f64>,
    },
}
2686
2687impl std::fmt::Debug for Ic0Factor {
2688 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2689 match self {
2690 Ic0Factor::IncompleteChol { cols, val, .. } => write!(
2691 f,
2692 "Ic0Factor::IncompleteChol {{ cols.len: {}, nnz: {} }}",
2693 cols.len(),
2694 val.len()
2695 ),
2696 Ic0Factor::Scalar { cols, .. } => {
2697 write!(f, "Ic0Factor::Scalar {{ cols.len: {} }}", cols.len())
2698 }
2699 }
2700 }
2701}
2702
/// Level-0 incomplete-Cholesky Schur preconditioner (#299).
///
/// One sparse incomplete-Cholesky factor per connected component of the
/// β-coupling graph. Within a component the dense `S[C,C]` is assembled, its
/// structural-nonzero pattern `P = { (i,j) : |S_ij| > drop·sqrt(S_ii S_jj) }`
/// is taken as the level-0 fill set, and the no-fill incomplete Cholesky
/// `S ≈ L̃ L̃ᵀ` is formed keeping only `P` (drop any update landing outside it).
/// See [`SchurPreconditionerKind::BlockIncompleteCholesky`].
#[derive(Debug, Clone)]
pub struct BlockIncompleteCholeskyPreconditioner {
    /// Per-component factors: sparse IC(0) or the scalar reciprocal-diagonal
    /// fallback.
    pub(crate) components: Vec<Ic0Factor>,
}
2715
2716impl BlockIncompleteCholeskyPreconditioner {
2717 pub fn from_arrow_schur<B: BatchedBlockSolver + Sync>(
2718 sys: &ArrowSchurSystem,
2719 htt_factors: &ArrowFactorSlab,
2720 ridge_beta: f64,
2721 backend: &B,
2722 ) -> Result<Self, ArrowSchurError> {
2723 // Column grouping mirrors ClusterJacobi: one group per connected
2724 // component of the β-coupling graph (whole-K single group when no
2725 // block_offsets are registered), so IC(0) preconditions exactly the
2726 // coupling ClusterJacobi keeps, but with a sparse (no-fill) factor.
2727 let col_groups: Vec<Vec<usize>> = if sys.block_offsets.is_empty() {
2728 vec![(0..sys.k).collect()]
2729 } else {
2730 let graph = BetaCouplingGraph::build(
2731 &sys.block_offsets,
2732 &sys.rows
2733 .iter()
2734 .map(|r| r.htbeta.clone())
2735 .collect::<Vec<_>>(),
2736 );
2737 graph
2738 .component_partition()
2739 .iter()
2740 .map(|comp| {
2741 let mut cols: Vec<usize> = comp
2742 .iter()
2743 .flat_map(|&blk| sys.block_offsets[blk].clone())
2744 .collect();
2745 cols.sort_unstable();
2746 cols.dedup();
2747 cols
2748 })
2749 .collect()
2750 };
2751
2752 let mut components = Vec::with_capacity(col_groups.len());
2753 for cols in &col_groups {
2754 let b = cols.len();
2755 if b == 0 {
2756 continue;
2757 }
2758 if b > IC0_MAX_COMPONENT {
2759 let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2760 components.push(Ic0Factor::Scalar {
2761 cols: cols.clone(),
2762 inv,
2763 });
2764 continue;
2765 }
2766 let mut s_block =
2767 assemble_local_schur_block(sys, htt_factors, ridge_beta, backend, cols);
2768 symmetrize_upper_from_lower(&mut s_block);
2769 match incomplete_cholesky_level0(&s_block) {
2770 Some((col_ptr, row_idx, val)) => components.push(Ic0Factor::IncompleteChol {
2771 cols: cols.clone(),
2772 col_ptr,
2773 row_idx,
2774 val,
2775 }),
2776 None => {
2777 // Non-PD incomplete pivot: degrade this component to the
2778 // scalar reciprocal diagonal (mirrors the ClusterJacobi
2779 // non-PD fallback), which is always applicable for a
2780 // PD-floored Schur diagonal.
2781 let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
2782 components.push(Ic0Factor::Scalar {
2783 cols: cols.clone(),
2784 inv,
2785 });
2786 }
2787 }
2788 }
2789 Ok(Self { components })
2790 }
2791
2792 pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
2793 let mut out = Array1::<f64>::zeros(r.len());
2794 for comp in &self.components {
2795 match comp {
2796 Ic0Factor::Scalar { cols, inv } => {
2797 for (local, &gi) in cols.iter().enumerate() {
2798 out[gi] = inv[local] * r[gi];
2799 }
2800 }
2801 Ic0Factor::IncompleteChol {
2802 cols,
2803 col_ptr,
2804 row_idx,
2805 val,
2806 } => {
2807 let b = cols.len();
2808 // Gather the local residual, solve `L̃ L̃ᵀ z = r_local` by a
2809 // sparse forward solve (`L̃ y = r`) then a sparse back solve
2810 // (`L̃ᵀ z = y`), then scatter `z` back to global columns.
2811 let mut z = vec![0.0f64; b];
2812 for (local, &gi) in cols.iter().enumerate() {
2813 z[local] = r[gi];
2814 }
2815 // Forward solve `L̃ y = r` (overwrite z with y). Column-major
2816 // CSC: row_idx[col_ptr[j]] == j (diagonal stored first).
2817 for j in 0..b {
2818 let dstart = col_ptr[j];
2819 let diag = val[dstart];
2820 z[j] /= diag;
2821 let yj = z[j];
2822 for k in (dstart + 1)..col_ptr[j + 1] {
2823 z[row_idx[k]] -= val[k] * yj;
2824 }
2825 }
2826 // Back solve `L̃ᵀ z = y` (overwrite z). Walk columns in
2827 // reverse; the below-diagonal entries of column j are the
2828 // off-diagonal entries of row j of L̃ᵀ.
2829 for j in (0..b).rev() {
2830 let dstart = col_ptr[j];
2831 let mut acc = z[j];
2832 for k in (dstart + 1)..col_ptr[j + 1] {
2833 acc -= val[k] * z[row_idx[k]];
2834 }
2835 z[j] = acc / val[dstart];
2836 }
2837 for (local, &gi) in cols.iter().enumerate() {
2838 out[gi] = z[local];
2839 }
2840 }
2841 }
2842 }
2843 out
2844 }
2845}
2846
2847/// Level-0 incomplete Cholesky of a dense SPD-ish block `a` (`b×b`, symmetric).
2848///
2849/// Returns the lower factor `L̃` in column-compressed (CSC) form
2850/// `(col_ptr, row_idx, val)` where each column lists its diagonal entry FIRST
2851/// followed by the strictly-below-diagonal entries, in increasing row order.
2852/// The kept pattern is the level-0 set `P` = structural nonzeros of `a` (a
2853/// relative drop threshold prunes round-off). IC(0) computes the standard
2854/// Cholesky recurrence but DROPS any value at a position outside `P`, so the
2855/// factor has exactly `nnz(tril(P))` entries — no fill. Returns `None` on a
2856/// non-positive pivot (caller degrades to scalar diagonal).
2857///
2858/// Reference: Y. Saad, *Iterative Methods for Sparse Linear Systems*, 2nd ed.,
2859/// §10.3.2 (IC(0)). This is the left-looking, pattern-restricted variant.
2860pub(crate) fn incomplete_cholesky_level0(
2861 a: &Array2<f64>,
2862) -> Option<(Vec<usize>, Vec<usize>, Vec<f64>)> {
2863 let b = a.nrows();
2864 assert_eq!(a.ncols(), b, "incomplete Cholesky needs a square block");
2865
2866 // ---- derive the level-0 lower-triangular pattern from `a` --------------
2867 // Per column j, the kept below-or-on-diagonal rows i>=j with a structurally
2868 // nonzero a[i,j]. The diagonal is always kept.
2869 let mut col_ptr = vec![0usize; b + 1];
2870 let mut row_idx: Vec<usize> = Vec::new();
2871 // value buffer, parallel to row_idx, initialised from tril(a) on the pattern
2872 let mut val: Vec<f64> = Vec::new();
2873 // For O(1) "is (i,j) in pattern + where" lookups during the recurrence, keep
2874 // a per-column map from global row -> position in that column's value slice.
2875 let mut col_pos: Vec<std::collections::HashMap<usize, usize>> = Vec::with_capacity(b);
2876 for j in 0..b {
2877 let ajj = a[[j, j]];
2878 let scale_j = ajj.abs().max(0.0).sqrt();
2879 let mut map = std::collections::HashMap::new();
2880 // diagonal first
2881 map.insert(j, val.len());
2882 row_idx.push(j);
2883 val.push(ajj);
2884 for i in (j + 1)..b {
2885 let aij = a[[i, j]];
2886 let scale_i = a[[i, i]].abs().sqrt();
2887 let thresh = IC0_PATTERN_REL_DROP * scale_i * scale_j;
2888 if aij.abs() > thresh {
2889 map.insert(i, val.len());
2890 row_idx.push(i);
2891 val.push(aij);
2892 }
2893 }
2894 col_pos.push(map);
2895 col_ptr[j + 1] = val.len();
2896 }
2897
2898 // ---- IC(0) recurrence, left-looking over columns -----------------------
2899 // For column j: subtract the contributions of all prior columns k<j that
2900 // have BOTH a nonzero at row j (so they touch the diagonal/the column) — the
2901 // multiplier L[j,k] — and a nonzero at the rows i of column j's pattern.
2902 // Any update whose target (i,j) is OUTSIDE the kept pattern is dropped.
2903 for j in 0..b {
2904 // Diagonal: a[j,j] - Σ_{k<j} L[j,k]². Each prior column k<j contributes
2905 // its row-j entry L[j,k] (looked up by row, so the column index is not
2906 // needed); columns without a row-j entry contribute nothing.
2907 let dpos = col_ptr[j];
2908 let mut diag = val[dpos];
2909 for mapk in &col_pos[..j] {
2910 if let Some(&pjk) = mapk.get(&j) {
2911 let ljk = val[pjk];
2912 diag -= ljk * ljk;
2913 }
2914 }
2915 if !diag.is_finite() || diag <= JACOBI_DIAGONAL_PD_FLOOR {
2916 return None;
2917 }
2918 let ljj = diag.sqrt();
2919 val[dpos] = ljj;
2920 // Below-diagonal of column j: L[i,j] = (a[i,j] - Σ_{k<j} L[i,k] L[j,k]) / L[j,j]
2921 for p in (dpos + 1)..col_ptr[j + 1] {
2922 let i = row_idx[p];
2923 let mut s = val[p];
2924 for mapk in &col_pos[..j] {
2925 if let (Some(&pik), Some(&pjk)) = (mapk.get(&i), mapk.get(&j)) {
2926 s -= val[pik] * val[pjk];
2927 }
2928 }
2929 val[p] = s / ljj;
2930 }
2931 }
2932 Some((col_ptr, row_idx, val))
2933}
2934
/// One row of the #299 preconditioner-ladder iteration study: the converged
/// PCG iteration count and stop reason for a single preconditioner tier.
#[derive(Debug, Clone, Copy)]
pub struct PrecondLadderRow {
    /// PCG iterations to convergence (or to the `MaxIter` cutoff).
    pub iterations: usize,
    /// `true` only when the PCG stop reason was `PcgStopReason::Converged`
    /// (vs hitting `MaxIter` / negative curvature).
    pub converged: bool,
    /// Final relative residual reported by the PCG.
    pub final_relative_residual: f64,
}
2946
2947/// Full #299 ladder iteration study on one reduced-Schur system: run the SAME
2948/// preconditioned CG (same `rhs`, tolerances, trust radius) once per ladder tier
2949/// and report the iteration count of each. This is the public seam the
2950/// `tests/owed_299.rs` iteration-reduction gate drives — it keeps the internal
2951/// `run_pcg_with_preconditioner` / preconditioner constructors `pub(crate)`
2952/// while exposing exactly the per-tier measurement the issue asks for.
2953///
2954/// Tiers (in escalation order): scalar `Diagonal`, `BetaBlockJacobi`,
2955/// `ClusterJacobi`, `AdditiveSchwarz{overlap:1}`, `DiagAssembledSchwarz{1}`, and
2956/// `BlockIncompleteCholesky`. A tier whose build fails (e.g. non-PD reduced
2957/// Schur with no curvature floor) reports `None` for that entry; every healthy
2958/// SPD reduced system populates all six.
2959pub fn arrow_precond_ladder_iteration_study(
2960 sys: &ArrowSchurSystem,
2961 ridge_beta: f64,
2962 rhs: &Array1<f64>,
2963 pcg: &ArrowPcgOptions,
2964 trust: &ArrowTrustRegionOptions,
2965) -> Result<Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)>, ArrowSchurError> {
2966 let backend = CpuBatchedBlockSolver;
2967 let htt_factors = backend.factor_blocks(&sys.rows, 0.0, sys.d, false)?;
2968
2969 let run = |apply: &dyn Fn(&Array1<f64>) -> Array1<f64>| -> Option<PrecondLadderRow> {
2970 let (_sol, diag) = run_pcg_with_preconditioner(
2971 sys,
2972 &htt_factors,
2973 ridge_beta,
2974 rhs,
2975 |r| apply(r),
2976 pcg,
2977 trust,
2978 &backend,
2979 None,
2980 None,
2981 None,
2982 )
2983 .ok()?;
2984 Some(PrecondLadderRow {
2985 iterations: diag.iterations,
2986 converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
2987 final_relative_residual: diag.final_relative_residual,
2988 })
2989 };
2990
2991 let mut out: Vec<(SchurPreconditionerKind, Option<PrecondLadderRow>)> = Vec::with_capacity(6);
2992
2993 // Scalar Diagonal Jacobi: force the scalar path by clearing block_offsets on
2994 // a clone so the build does not pick up the per-block dense Schur blocks.
2995 let diag_row = {
2996 let mut bare = sys.clone();
2997 bare.set_block_offsets(std::sync::Arc::from([] as [Range<usize>; 0]));
2998 let bare_factors = backend.factor_blocks(&bare.rows, 0.0, bare.d, false)?;
2999 JacobiPreconditioner::from_arrow_schur(&bare, &bare_factors, ridge_beta, &backend, None)
3000 .ok()
3001 .and_then(|p| {
3002 run_pcg_with_preconditioner(
3003 &bare,
3004 &bare_factors,
3005 ridge_beta,
3006 rhs,
3007 |r| p.apply(r),
3008 pcg,
3009 trust,
3010 &backend,
3011 None,
3012 None,
3013 None,
3014 )
3015 .ok()
3016 .map(|(_s, diag)| PrecondLadderRow {
3017 iterations: diag.iterations,
3018 converged: matches!(diag.stopping_reason, PcgStopReason::Converged),
3019 final_relative_residual: diag.final_relative_residual,
3020 })
3021 })
3022 };
3023 out.push((SchurPreconditionerKind::Diagonal, diag_row));
3024
3025 let block_row =
3026 JacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, None)
3027 .ok()
3028 .and_then(|p| run(&|r| p.apply(r)));
3029 out.push((SchurPreconditionerKind::BetaBlockJacobi, block_row));
3030
3031 let cluster_row =
3032 ClusterJacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend)
3033 .ok()
3034 .and_then(|p| run(&|r| p.apply(r)));
3035 out.push((SchurPreconditionerKind::ClusterJacobi, cluster_row));
3036
3037 let schwarz_row =
3038 AdditiveSchwarzPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend, 1)
3039 .ok()
3040 .and_then(|p| run(&|r| p.apply(r)));
3041 out.push((
3042 SchurPreconditionerKind::AdditiveSchwarz { overlap: 1 },
3043 schwarz_row,
3044 ));
3045
3046 let diag_schwarz_row = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
3047 sys,
3048 &htt_factors,
3049 ridge_beta,
3050 &backend,
3051 1,
3052 )
3053 .ok()
3054 .and_then(|p| run(&|r| p.apply(r)));
3055 out.push((
3056 SchurPreconditionerKind::DiagAssembledSchwarz { overlap: 1 },
3057 diag_schwarz_row,
3058 ));
3059
3060 let ic0_row = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
3061 sys,
3062 &htt_factors,
3063 ridge_beta,
3064 &backend,
3065 )
3066 .ok()
3067 .and_then(|p| run(&|r| p.apply(r)));
3068 out.push((SchurPreconditionerKind::BlockIncompleteCholesky, ic0_row));
3069
3070 Ok(out)
3071}
3072
3073/// Build scalar diagonal inverses for a set of global column indices.
3074///
3075/// Used when a cluster is non-PD or exceeds `CLUSTER_JACOBI_MAX_CLUSTER`.
3076pub(crate) fn build_schur_scalar_inv<B: BatchedBlockSolver>(
3077 sys: &ArrowSchurSystem,
3078 htt_factors: &ArrowFactorSlab,
3079 ridge_beta: f64,
3080 backend: &B,
3081 cols: &[usize],
3082) -> Result<Vec<f64>, ArrowSchurError> {
3083 let d = sys.d;
3084 let mut result = Vec::with_capacity(cols.len());
3085 let mut col_vec = Array1::<f64>::zeros(d);
3086 // Extract the penalty diagonal for all K columns once, then index per-column.
3087 let mut full_diag = Array1::<f64>::zeros(sys.k);
3088 {
3089 let diag_slice = full_diag.as_slice_mut().expect("full_diag contiguous");
3090 sys.penalty_diagonal_add(diag_slice);
3091 }
3092 for &gi in cols {
3093 let mut s = full_diag[gi] + ridge_beta;
3094 for (row_idx, row) in sys.rows.iter().enumerate() {
3095 for c in 0..d {
3096 col_vec[c] = row.htbeta[[c, gi]];
3097 }
3098 let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
3099 let mut acc = 0.0;
3100 for c in 0..d {
3101 acc += col_vec[c] * solved[c];
3102 }
3103 s -= acc;
3104 }
3105 if !s.is_finite() || s <= JACOBI_DIAGONAL_PD_FLOOR {
3106 return Err(ArrowSchurError::PcgFailed {
3107 reason: format!(
3108 "cluster Schur scalar fallback: non-PD diagonal at index {gi}: {s}"
3109 ),
3110 });
3111 }
3112 result.push(1.0 / s);
3113 }
3114 Ok(result)
3115}
3116
3117/// Inexact PCG with automatic preconditioner-ladder escalation.
3118///
3119/// Starts with `JacobiPreconditioner` (Diagonal or BetaBlockJacobi).
3120/// If PCG hits `MaxIter` and `k > PRECOND_ESCALATE_K_THRESHOLD`,
3121/// escalates to `ClusterJacobi`; if still `MaxIter`, escalates to
3122/// `AdditiveSchwarz { overlap: 1 }`.
3123pub(crate) fn steihaug_pcg_auto<B: BatchedBlockSolver + Sync>(
3124 sys: &ArrowSchurSystem,
3125 htt_factors: &ArrowFactorSlab,
3126 ridge_beta: f64,
3127 rhs: &Array1<f64>,
3128 pcg: &ArrowPcgOptions,
3129 trust: &ArrowTrustRegionOptions,
3130 backend: &B,
3131 gpu_matvec: Option<&GpuSchurMatvec>,
3132 metric_weights: Option<&MetricWeights>,
3133 curvature_floor: Option<f64>,
3134) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3135 // #1017 CPU residency: stage the per-row reduced-Schur factors `(L_i, Y_i)`
3136 // (NOT the dense `p×p` block — `di ≪ p`, so the factored form is `O(n·di·p)`
3137 // memory and `2·support_i·p + 2·di·p` flops/row including the sparse
3138 // gather/scatter over the active support) once, up
3139 // front, when the SAE structure is installed and the matvec runs on host
3140 // (CPU). The GPU matvec carries its own residency, so skip when it is engaged.
3141 // The same staged operator is reused across the whole preconditioner ladder
3142 // (Jacobi → ClusterJacobi → AdditiveSchwarz) — built once, not per tier.
3143 let resident = if gpu_matvec.is_none() {
3144 SaeResidentReducedSchur::build(sys, htt_factors, backend)
3145 } else {
3146 None
3147 };
3148 // #1026 — curvature-floor retry on the Jacobi tier. The unbounded SAE inner
3149 // PCG (trust radius = ∞) fails on `pᵀSp ≤ 0` when the reduced Schur is
3150 // indefinite (K≥4 co-collapse: a near-singular per-row `H_tt` over-subtracts
3151 // `S`). Instead of letting that failure propagate to the outer LM loop —
3152 // which inflates `ridge_β` over EVERY β direction and makes the inner Newton
3153 // crawl — floor the OPERATOR by the minimal ridge `δ = |pᵀSp|/‖p‖² · (1+ε)`
3154 // that restores positive curvature along the offending direction, rebuild the
3155 // Jacobi preconditioner at the lifted ridge, and retry. This is the
3156 // matrix-free analogue of the dense `spectral_pd_floored_schur`: the healthy
3157 // β subspace (where curvature is already positive) is essentially untouched
3158 // by a tiny `δ`, while the collapsed direction gets exactly the stiffness it
3159 // needs to make a real descent step. A PD reduced Schur never hits `pᵀSp ≤ 0`,
3160 // so this loop is a strict no-op there (bit-for-bit unchanged). Bounded by a
3161 // small attempt cap and a relative ridge ceiling; on exhaustion the original
3162 // recoverable failure still reaches the outer LM loop.
3163 let mut effective_ridge = ridge_beta;
3164 let mut x0_diag0: Option<(Array1<f64>, PcgDiagnostics)> = None;
3165 let mut last_curvature_err: Option<ArrowSchurError> = None;
3166 let rhs_scale = metric_norm(rhs.view(), metric_weights).max(1.0);
3167 let ridge_ceiling = ridge_beta.max(SCHUR_CURVATURE_FLOOR_REL_CEILING * rhs_scale);
3168 for _attempt in 0..=SCHUR_CURVATURE_FLOOR_MAX_ATTEMPTS {
3169 // The Jacobi preconditioner build itself refuses a non-PD Schur diagonal
3170 // (`PcgFailed: invalid Schur Jacobi diagonal`) — the SAME co-collapse
3171 // signature reached BEFORE the CG loop, since `S_ii = H_ββ,ii − Σ …` goes
3172 // negative. Treat that build failure as a curvature deficit too: when the
3173 // floor is enabled, lift the ridge and retry; otherwise propagate.
3174 let jacobi = match JacobiPreconditioner::from_arrow_schur(
3175 sys,
3176 htt_factors,
3177 effective_ridge,
3178 backend,
3179 resident.as_ref(),
3180 ) {
3181 Ok(jacobi) => jacobi,
3182 Err(err @ ArrowSchurError::PcgFailed { .. }) => {
3183 if curvature_floor.is_none() {
3184 return Err(err);
3185 }
3186 // A diagonal refusal carries no `(curvature, ‖p‖²)` deficit, and
3187 // the over-subtraction magnitude `Σ H_tβᵀ(H_tt)⁻¹H_tβ` is
3188 // unbounded relative to `rhs_scale`, so a small additive bump
3189 // would crawl. Escalate the ridge MULTIPLICATIVELY (×10, matching
3190 // the per-row `factor_one_row_result` RIDGE_GROWTH_FACTOR), seeded
3191 // at `rhs_scale`, so even a large deficit (the collapsed
3192 // `(H_tβ)²/H_tt` over-subtraction) is reached in a handful of
3193 // attempts. The ceiling + attempt cap still bound it; on
3194 // exhaustion the recoverable failure reaches the outer LM loop.
3195 let next = if effective_ridge > 0.0 {
3196 effective_ridge * SCHUR_CURVATURE_FLOOR_DIAG_GROWTH
3197 } else {
3198 rhs_scale
3199 };
3200 last_curvature_err = Some(err);
3201 if !next.is_finite() || next > ridge_ceiling {
3202 break;
3203 }
3204 effective_ridge = next;
3205 continue;
3206 }
3207 Err(other) => return Err(other),
3208 };
3209 match run_pcg_with_preconditioner(
3210 sys,
3211 htt_factors,
3212 effective_ridge,
3213 rhs,
3214 |r| jacobi.apply(r),
3215 pcg,
3216 trust,
3217 backend,
3218 gpu_matvec,
3219 metric_weights,
3220 resident.as_ref(),
3221 ) {
3222 Ok(result) => {
3223 x0_diag0 = Some(result);
3224 break;
3225 }
3226 Err(ArrowSchurError::UnboundedNegativeCurvature {
3227 curvature,
3228 direction_norm_sq,
3229 }) => {
3230 // Only floor when the caller opted in (SAE solve path); otherwise
3231 // propagate the raw negative-curvature signal so BA / non-SAE
3232 // unbounded solves keep their existing failure contract.
3233 let Some(relative_floor) = curvature_floor else {
3234 return Err(ArrowSchurError::UnboundedNegativeCurvature {
3235 curvature,
3236 direction_norm_sq,
3237 });
3238 };
3239 // Minimal ridge to make `pᵀ(S+δI)p = |curvature| + δ·‖p‖² > 0`,
3240 // with a margin so the next CG iterate has strictly positive
3241 // curvature rather than sitting on the `0` knife-edge.
3242 let deficit = if direction_norm_sq > 0.0 {
3243 curvature.abs() / direction_norm_sq
3244 } else {
3245 0.0
3246 };
3247 let bump = (deficit * (1.0 + SCHUR_CURVATURE_FLOOR_MARGIN))
3248 .max(relative_floor.max(SCHUR_CURVATURE_FLOOR_REL_FLOOR) * rhs_scale);
3249 let next = (effective_ridge + bump).max(effective_ridge * 2.0);
3250 last_curvature_err = Some(ArrowSchurError::UnboundedNegativeCurvature {
3251 curvature,
3252 direction_norm_sq,
3253 });
3254 if !next.is_finite() || next > ridge_ceiling {
3255 break;
3256 }
3257 effective_ridge = next;
3258 }
3259 Err(other) => return Err(other),
3260 }
3261 }
3262 let (x0, diag0) = match x0_diag0 {
3263 Some(result) => result,
3264 None => {
3265 // The curvature floor could not condition the operator within the
3266 // ceiling; hand the recoverable failure to the outer LM loop, which
3267 // re-forms the system at a heavier ridge.
3268 return Err(last_curvature_err.unwrap_or(ArrowSchurError::PcgFailed {
3269 reason: "unbounded Schur PCG negative curvature unresolved by curvature floor"
3270 .to_string(),
3271 }));
3272 }
3273 };
3274 if sys.k <= PRECOND_ESCALATE_K_THRESHOLD || diag0.stopping_reason != PcgStopReason::MaxIter {
3275 return Ok((x0, diag0));
3276 }
3277 // Escalation tiers reuse the curvature-floored `effective_ridge` so the
3278 // operator they precondition is the SAME (PD-floored) one the Jacobi tier
3279 // settled on; a still-negative-curvature signal here is handed to the outer
3280 // LM loop (it only arises if the floored Jacobi tier merely ran out of
3281 // iterations yet a coarser preconditioner still finds an indefinite
3282 // direction — rare; the LM loop re-forms at a heavier ridge).
3283 let cluster =
3284 ClusterJacobiPreconditioner::from_arrow_schur(sys, htt_factors, effective_ridge, backend)?;
3285 let (x1, diag1) = run_pcg_with_preconditioner(
3286 sys,
3287 htt_factors,
3288 effective_ridge,
3289 rhs,
3290 |r| cluster.apply(r),
3291 pcg,
3292 trust,
3293 backend,
3294 gpu_matvec,
3295 metric_weights,
3296 resident.as_ref(),
3297 )?;
3298 if diag1.stopping_reason != PcgStopReason::MaxIter {
3299 return Ok((x1, diag1));
3300 }
3301 let schwarz = AdditiveSchwarzPreconditioner::from_arrow_schur(
3302 sys,
3303 htt_factors,
3304 effective_ridge,
3305 backend,
3306 1,
3307 )?;
3308 let (x2, diag2) = run_pcg_with_preconditioner(
3309 sys,
3310 htt_factors,
3311 effective_ridge,
3312 rhs,
3313 |r| schwarz.apply(r),
3314 pcg,
3315 trust,
3316 backend,
3317 gpu_matvec,
3318 metric_weights,
3319 resident.as_ref(),
3320 )?;
3321 if diag2.stopping_reason != PcgStopReason::MaxIter {
3322 return Ok((x2, diag2));
3323 }
3324 // Final tier — diagonal-assembled additive Schwarz (#299), the cheap-apply
3325 // Schwarz variant. When the dense-block AdditiveSchwarz still ran out of
3326 // iterations its O(Σ b_k²) apply may have throttled the iteration budget on
3327 // a wide subdomain; the diag-assembled variant keeps Schwarz's overlapping
3328 // local-inverse conditioning but applies in O(K), so it can take more CG
3329 // iterations within the same wall budget. Same overlap (1) and same
3330 // curvature-floored ridge as the dense-block tier.
3331 let diag_schwarz = DiagAssembledSchwarzPreconditioner::from_arrow_schur(
3332 sys,
3333 htt_factors,
3334 effective_ridge,
3335 backend,
3336 1,
3337 )?;
3338 let (x3, diag3) = run_pcg_with_preconditioner(
3339 sys,
3340 htt_factors,
3341 effective_ridge,
3342 rhs,
3343 |r| diag_schwarz.apply(r),
3344 pcg,
3345 trust,
3346 backend,
3347 gpu_matvec,
3348 metric_weights,
3349 resident.as_ref(),
3350 )?;
3351 if diag3.stopping_reason != PcgStopReason::MaxIter {
3352 return Ok((x3, diag3));
3353 }
3354 // Richest tier — level-0 incomplete Cholesky (#299). ClusterJacobi keeps the
3355 // full DENSE Cholesky of each component (so on a single large connected
3356 // component it fills the whole `b×b` factor and its `O(b²)` apply throttles
3357 // the CG iteration budget), while the diagonal/Schwarz tiers drop most
3358 // inter-block coupling. IC(0) keeps the component's full structural coupling
3359 // but only the level-0 (no-fill) pattern, so its sparse triangular apply is
3360 // `O(nnz(S[C,C]))` — it can take more CG iterations within the same wall
3361 // budget AND conditions the off-diagonal coupling the cheap tiers discard.
3362 // Last in the ladder so it is only paid when every cheaper tier stalled.
3363 let ic0 = BlockIncompleteCholeskyPreconditioner::from_arrow_schur(
3364 sys,
3365 htt_factors,
3366 effective_ridge,
3367 backend,
3368 )?;
3369 let (x4, diag4) = run_pcg_with_preconditioner(
3370 sys,
3371 htt_factors,
3372 effective_ridge,
3373 rhs,
3374 |r| ic0.apply(r),
3375 pcg,
3376 trust,
3377 backend,
3378 gpu_matvec,
3379 metric_weights,
3380 resident.as_ref(),
3381 )?;
3382 // All five preconditioner tiers (Jacobi -> ClusterJacobi -> AdditiveSchwarz
3383 // -> DiagAssembledSchwarz -> BlockIncompleteCholesky) exhausted their
3384 // iteration budget without driving the residual below tolerance. Returning a
3385 // truncated iterate as `Ok` would feed an arbitrarily-large-residual step
3386 // into the Newton driver, where the PCG diagnostics are discarded. Surface a
3387 // recoverable failure instead so `solve_with_lm_escalation_inner` escalates
3388 // the proximal ridge: better conditioning is precisely what a stalled PCG on
3389 // an ill-conditioned reduced system needs.
3390 if diag4.stopping_reason == PcgStopReason::MaxIter {
3391 return Err(ArrowSchurError::PcgFailed {
3392 reason: format!(
3393 "Schur PCG exhausted all preconditioner tiers (Jacobi, ClusterJacobi, \
3394 AdditiveSchwarz, DiagAssembledSchwarz, BlockIncompleteCholesky) at MaxIter; \
3395 final relative residual = {:e}",
3396 diag4.final_relative_residual
3397 ),
3398 });
3399 }
3400 Ok((x4, diag4))
3401}
3402
3403/// Run Steihaug-CG with a generic preconditioner closure.
3404/// Routes matvec through GPU when `gpu_matvec` is set.
3405pub(crate) fn run_pcg_with_preconditioner<ApplyPrec, B: BatchedBlockSolver + Sync>(
3406 sys: &ArrowSchurSystem,
3407 htt_factors: &ArrowFactorSlab,
3408 ridge_beta: f64,
3409 rhs: &Array1<f64>,
3410 apply_prec: ApplyPrec,
3411 pcg: &ArrowPcgOptions,
3412 trust: &ArrowTrustRegionOptions,
3413 backend: &B,
3414 gpu_matvec: Option<&GpuSchurMatvec>,
3415 metric_weights: Option<&MetricWeights>,
3416 resident: Option<&SaeResidentReducedSchur>,
3417) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3418where
3419 ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3420{
3421 let max_iters = pcg.max_iterations.min(trust.max_iterations);
3422 let tol = pcg
3423 .relative_tolerance
3424 .max(trust.steihaug_relative_tolerance);
3425 if let Some(gpu_mv) = gpu_matvec {
3426 let gpu_mv = Arc::clone(gpu_mv);
3427 steihaug_cg(
3428 rhs,
3429 move |p, out| gpu_mv(p, out),
3430 apply_prec,
3431 max_iters,
3432 tol,
3433 trust.radius,
3434 metric_weights,
3435 )
3436 } else {
3437 steihaug_cg(
3438 rhs,
3439 |p, out| schur_matvec(sys, htt_factors, ridge_beta, p, out, backend, resident),
3440 apply_prec,
3441 max_iters,
3442 tol,
3443 trust.radius,
3444 metric_weights,
3445 )
3446 }
3447}
3448
3449#[derive(Debug, Clone, Copy)]
3450pub(crate) struct IdentityPreconditioner;
3451
3452impl IdentityPreconditioner {
3453 pub(crate) fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
3454 r.clone()
3455 }
3456}
3457
3458pub(crate) fn steihaug_dense_system(
3459 schur: &Array2<f64>,
3460 rhs: &Array1<f64>,
3461 preconditioner: &IdentityPreconditioner,
3462 pcg: &ArrowPcgOptions,
3463 trust: &ArrowTrustRegionOptions,
3464 metric_weights: Option<&MetricWeights>,
3465) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
3466 steihaug_cg(
3467 rhs,
3468 |p, out| dense_matvec(schur, p, out),
3469 |r| preconditioner.apply(r),
3470 pcg.max_iterations,
3471 pcg.relative_tolerance,
3472 trust.radius,
3473 metric_weights,
3474 )
3475}
3476
3477pub(crate) fn steihaug_cg<MatVec, ApplyPrec>(
3478 rhs: &Array1<f64>,
3479 mut matvec: MatVec,
3480 mut apply_preconditioner: ApplyPrec,
3481 max_iterations: usize,
3482 relative_tolerance: f64,
3483 trust_radius: f64,
3484 metric_weights: Option<&MetricWeights>,
3485) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
3486where
3487 MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
3488 ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
3489{
3490 let n = rhs.len();
3491 if let Some(weights) = metric_weights {
3492 assert_eq!(
3493 weights.len(),
3494 n,
3495 "Steihaug-CG metric weight length must match solve dimension"
3496 );
3497 }
3498 let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
3499 trust_radius
3500 } else {
3501 f64::INFINITY
3502 };
3503 let rhs_norm = metric_norm(rhs.view(), metric_weights);
3504 if rhs_norm == 0.0 {
3505 return Ok((Array1::<f64>::zeros(n), PcgDiagnostics::default()));
3506 }
3507 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(PCG_ABSOLUTE_TOLERANCE_FLOOR);
3508 let mut x = Array1::<f64>::zeros(n);
3509 let mut r = rhs.clone();
3510 let mut z = apply_preconditioner(&r);
3511 let mut diag = PcgDiagnostics {
3512 precond_apply_calls: 1,
3513 ..PcgDiagnostics::default()
3514 };
3515 let mut p = z.clone();
3516 let mut rz = metric_dot(&r, &z, metric_weights);
3517 if rz <= 0.0 || !rz.is_finite() {
3518 if radius.is_finite() {
3519 diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3520 diag.stopping_reason = PcgStopReason::TrustRegion;
3521 return Ok((step_to_trust_boundary(&x, &r, radius, metric_weights), diag));
3522 }
3523 // Unbounded (radius = ∞) non-positive preconditioned residual: the
3524 // reduced Schur is indefinite at the very first direction. Surface the
3525 // typed curvature-floor signal so `steihaug_pcg_auto` floors the
3526 // operator minimally and retries, instead of failing into a global
3527 // `ridge_β` ramp. `rz = rᵀM⁻¹r` is a preconditioner-metric curvature;
3528 // report it with the residual norm² as the direction scale.
3529 return Err(ArrowSchurError::UnboundedNegativeCurvature {
3530 curvature: rz,
3531 direction_norm_sq: metric_dot(&r, &r, metric_weights),
3532 });
3533 }
3534 if metric_norm(r.view(), metric_weights) <= tol {
3535 diag.final_relative_residual = 0.0;
3536 diag.stopping_reason = PcgStopReason::Converged;
3537 return Ok((x, diag));
3538 }
3539 let mut ap = Array1::<f64>::zeros(n);
3540 // Reused candidate scratch — avoid per-iteration clone of x.
3541 let mut candidate = Array1::<f64>::zeros(n);
3542 for _ in 0..max_iterations {
3543 matvec(&p, &mut ap);
3544 diag.matvec_calls += 1;
3545 diag.iterations += 1;
3546 let pap = metric_dot(&p, &ap, metric_weights);
3547 if pap <= 0.0 || !pap.is_finite() {
3548 if radius.is_finite() {
3549 diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3550 diag.stopping_reason = PcgStopReason::TrustRegion;
3551 return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3552 }
3553 // Unbounded negative curvature `pᵀSp ≤ 0`: the reduced Schur is
3554 // indefinite along `p` (the #1026 co-collapse direction). Surface
3555 // the typed signal carrying `pᵀSp` and `‖p‖²` so the caller floors
3556 // the operator by the minimal ridge `δ = |pᵀSp|/‖p‖²` (which makes
3557 // `pᵀ(S+δI)p = 0⁺`) plus a margin, and retries.
3558 return Err(ArrowSchurError::UnboundedNegativeCurvature {
3559 curvature: pap,
3560 direction_norm_sq: metric_dot(&p, &p, metric_weights),
3561 });
3562 }
3563 let alpha = rz / pap;
3564 for i in 0..n {
3565 candidate[i] = x[i] + alpha * p[i];
3566 }
3567 if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
3568 diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3569 diag.stopping_reason = PcgStopReason::TrustRegion;
3570 return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
3571 }
3572 x.assign(&candidate);
3573 for i in 0..n {
3574 r[i] -= alpha * ap[i];
3575 }
3576 if metric_norm(r.view(), metric_weights) <= tol {
3577 diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3578 diag.stopping_reason = PcgStopReason::Converged;
3579 return Ok((x, diag));
3580 }
3581 z = apply_preconditioner(&r);
3582 diag.precond_apply_calls += 1;
3583 let rz_next = metric_dot(&r, &z, metric_weights);
3584 if rz_next <= 0.0 || !rz_next.is_finite() {
3585 return Err(ArrowSchurError::PcgFailed {
3586 reason: "non-positive or non-finite PCG residual".to_string(),
3587 });
3588 }
3589 let beta = rz_next / rz;
3590 for i in 0..n {
3591 p[i] = z[i] + beta * p[i];
3592 }
3593 rz = rz_next;
3594 }
3595 diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
3596 diag.stopping_reason = PcgStopReason::MaxIter;
3597 Ok((x, diag))
3598}
3599
3600pub(crate) fn step_to_trust_boundary(
3601 x: &Array1<f64>,
3602 p: &Array1<f64>,
3603 radius: f64,
3604 metric_weights: Option<&MetricWeights>,
3605) -> Array1<f64> {
3606 let pp = metric_dot(p, p, metric_weights);
3607 if pp == 0.0 {
3608 return x.clone();
3609 }
3610 let xp = metric_dot(x, p, metric_weights);
3611 let xx = metric_dot(x, x, metric_weights);
3612 let disc = (xp * xp + pp * (radius * radius - xx)).max(0.0);
3613 let tau = (-xp + disc.sqrt()) / pp;
3614 let mut out = x.clone();
3615 for i in 0..out.len() {
3616 out[i] += tau * p[i];
3617 }
3618 out
3619}
3620
3621pub(crate) fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
3622 let n = a.nrows();
3623 for i in 0..n {
3624 let mut acc = 0.0;
3625 for j in 0..n {
3626 acc += a[[i, j]] * x[j];
3627 }
3628 out[i] = acc;
3629 }
3630}
3631
3632pub(crate) fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
3633 let mut acc = 0.0;
3634 for i in 0..a.len() {
3635 acc += a[i] * b[i];
3636 }
3637 acc
3638}
3639
3640pub(crate) fn metric_dot(
3641 a: &Array1<f64>,
3642 b: &Array1<f64>,
3643 metric_weights: Option<&MetricWeights>,
3644) -> f64 {
3645 assert_eq!(a.len(), b.len());
3646 match metric_weights {
3647 Some(weights) => {
3648 assert_eq!(weights.len(), a.len());
3649 let mut acc = 0.0;
3650 for i in 0..a.len() {
3651 acc += weights[i] * a[i] * b[i];
3652 }
3653 acc
3654 }
3655 None => dot(a, b),
3656 }
3657}
3658
3659pub(crate) fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
3660 let mut acc = 0.0;
3661 match metric_weights {
3662 Some(weights) => {
3663 assert_eq!(weights.len(), v.len());
3664 for i in 0..v.len() {
3665 acc += weights[i] * v[i] * v[i];
3666 }
3667 }
3668 None => {
3669 for x in v.iter() {
3670 acc += x * x;
3671 }
3672 }
3673 }
3674 acc.sqrt()
3675}
3676
3677pub(crate) fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
3678 let n = a.nrows().min(a.ncols());
3679 for i in 0..n {
3680 for j in 0..i {
3681 let v = 0.5 * (a[[i, j]] + a[[j, i]]);
3682 a[[i, j]] = v;
3683 a[[j, i]] = v;
3684 }
3685 }
3686}
3687
/// Errors raised by [`ArrowSchurSystem::solve`].
///
/// Several variants are recoverable signals for an outer LM / proximal-ridge
/// wrapper rather than terminal failures; the per-variant notes say which
/// ridge that wrapper is expected to escalate.
#[derive(Debug, Clone)]
pub enum ArrowSchurError {
    /// A per-row `H_tt^(i)` block was not positive-definite at the
    /// supplied ridge. Indicates an under-regularized latent block —
    /// typically a gauge-free fit without an identifiability penalty.
    PerRowFactorFailed { row: usize, reason: String },
    /// A per-row `H_tt^(i)` block factored, but the Cholesky factor failed
    /// the safe-inversion guard for the Schur reduction. This can be either
    /// an excessive diagonal-ratio condition-number estimate or a numerically
    /// tiny pivot relative to the row block scale. Cholesky technically
    /// succeeded, but the inverse used in
    /// `S = H_ββ − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)` is contaminated
    /// by spectral terms on the order of `κ_i`; functionally
    /// equivalent to a PSD-fail for Schur stability. The LM outer
    /// wrapper escalates `ridge_t` identically to `PerRowFactorFailed`.
    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
    /// The Schur complement was not positive-definite. Indicates a
    /// near-collinear decoder or a degenerate weighting; the LM outer
    /// wrapper should escalate `ridge_beta` and retry. Also used to refuse a
    /// dense Schur assembly that would exceed `DENSE_SCHUR_BYTES_BUDGET`.
    SchurFactorFailed { reason: String },
    /// A Schur PCG solve (BA inexact step or the matrix-free reduced solve)
    /// failed before producing a usable Steihaug step. Examples: Steihaug-CG
    /// met a non-positive or non-finite preconditioned residual `rᵀM⁻¹r`
    /// after an iteration, a preconditioner build refused an invalid Schur
    /// diagonal, or every tier of the preconditioner escalation ladder ran
    /// out of iterations. Treated as recoverable by the outer LM loop, which
    /// retries at a heavier ridge.
    PcgFailed { reason: String },
    /// The UNBOUNDED (trust-radius = ∞) Schur PCG encountered negative
    /// curvature `pᵀSp ≤ 0` (or a non-positive preconditioned residual): the
    /// reduced Schur is indefinite, the #1026 K≥4 co-collapse signature where
    /// a near-singular per-row `H_tt` over-subtracts `S`. With no trust radius
    /// there is no boundary to step to, so CG cannot proceed. `curvature` is
    /// the offending `pᵀSp` and `direction_norm_sq` the `‖p‖²` of the
    /// negative-curvature direction. When raised on the very first residual,
    /// they are instead `rᵀM⁻¹r` and `‖r‖²`. The caller floors the operator
    /// with the minimal ridge `δ = (|curvature|/‖p‖²)·(1+ε)` that restores
    /// positive curvature along `p` and retries (matrix-free analogue of the
    /// dense `spectral_pd_floored_schur`), rather than blindly inflating `ridge_β`.
    UnboundedNegativeCurvature {
        /// The non-positive (or non-finite) curvature value that stopped CG.
        curvature: f64,
        /// Squared metric norm of the direction that `curvature` was measured on.
        direction_norm_sq: f64,
    },
    /// Adaptive proximal damping could not produce an Armijo-accepted
    /// nonlinear step.
    AdaptiveCorrectionFailed { reason: String },
}
3730
3731impl std::fmt::Display for ArrowSchurError {
3732 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3733 match self {
3734 ArrowSchurError::PerRowFactorFailed { row, reason } => write!(
3735 f,
3736 "arrow-Schur: per-row H_tt^({row}) Cholesky failed: {reason}"
3737 ),
3738 ArrowSchurError::PerRowFactorIllConditioned {
3739 row,
3740 kappa_estimate,
3741 } => write!(
3742 f,
3743 "arrow-Schur: per-row H_tt^({row}) Cholesky succeeded but failed \
3744 the safe-inversion guard (kappa_estimate={kappa_estimate:e}); \
3745 Schur reduction would be numerically contaminated"
3746 ),
3747 ArrowSchurError::SchurFactorFailed { reason } => {
3748 write!(f, "arrow-Schur: Schur complement Cholesky failed: {reason}")
3749 }
3750 ArrowSchurError::PcgFailed { reason } => {
3751 write!(f, "arrow-Schur: Schur PCG failed: {reason}")
3752 }
3753 ArrowSchurError::UnboundedNegativeCurvature {
3754 curvature,
3755 direction_norm_sq,
3756 } => write!(
3757 f,
3758 "arrow-Schur: unbounded Schur PCG hit negative curvature pᵀSp={curvature:e} \
3759 (‖p‖²={direction_norm_sq:e}); reduced Schur is indefinite (co-collapse), \
3760 retry with a curvature-floor ridge"
3761 ),
3762 ArrowSchurError::AdaptiveCorrectionFailed { reason } => {
3763 write!(
3764 f,
3765 "arrow-Schur: adaptive proximal correction failed: {reason}"
3766 )
3767 }
3768 }
3769 }
3770}
3771
3772impl std::error::Error for ArrowSchurError {}
3773
3774// ---------------------------------------------------------------------------
3775// Cholesky helpers (kept local to avoid a new public-API dependency on the
3776// linalg crate. The systems here are tiny per-row (d × d, d ∈ {1..16}) and
3777// modest at the Schur level (K × K, K ∈ {basis size}). For production SAE
3778// scales the Schur factor should switch to faer; this module's `cholesky_lower`
3779// is the obvious replacement site.)
3780// ---------------------------------------------------------------------------
3781
3782pub(crate) fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
3783 let n = a.nrows();
3784 if a.ncols() != n {
3785 return Err(format!("cholesky_lower: non-square {}×{}", n, a.ncols()));
3786 }
3787 if let Some((idx, _)) = a.iter().enumerate().find(|(_, v)| !v.is_finite()) {
3788 return Err(format!(
3789 "cholesky_lower: non-finite entry at linear index {idx}"
3790 ));
3791 }
3792
3793 let mut maybe_device = a.clone();
3794 if gam_gpu::try_cholesky_lower_inplace(&mut maybe_device).is_some() {
3795 return Ok(maybe_device);
3796 }
3797
3798 let mut l = Array2::<f64>::zeros((n, n));
3799 for i in 0..n {
3800 for j in 0..=i {
3801 let mut sum = a[[i, j]];
3802 for kk in 0..j {
3803 sum -= l[[i, kk]] * l[[j, kk]];
3804 }
3805 if i == j {
3806 if !sum.is_finite() || sum <= 0.0 {
3807 return Err(format!(
3808 "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
3809 ));
3810 }
3811 l[[i, j]] = sum.sqrt();
3812 } else {
3813 l[[i, j]] = sum / l[[j, j]];
3814 }
3815 }
3816 }
3817 Ok(l)
3818}