Skip to main content

gam_solve/arrow_schur/
system.rs

1//! The bordered arrow-Schur system itself: [`ArrowRowBlock`], the
2//! [`ArrowSchurSystem`] container and its assembly impl, cross-row latent
3//! penalties, the streaming builder, and the per-row factor caches.
4
5use super::*;
6
/// Per-row block data for the arrow-Schur system.
///
/// `htt` holds the `d × d` Gauss–Newton block for row `i` (including any
/// analytic-penalty contributions on that row); `htbeta` holds the
/// `d × K` cross-block `H_tβ^(i)`; `gt` is the `d`-length latent
/// gradient for row `i`.
///
/// The column count of `htbeta` is chosen at allocation time: it is `K` when
/// the block is built with [`ArrowRowBlock::new`], and the caller-supplied
/// `htbeta_cols` when built with [`ArrowRowBlock::new_with_htbeta_cols`].
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
    /// `H_tt^(i)`, shape `(d, d)`.
    pub htt: Array2<f64>,
    /// `H_tβ^(i)`, shape `(d, K)` (or `(d, htbeta_cols)` when allocated via
    /// [`ArrowRowBlock::new_with_htbeta_cols`]).
    pub htbeta: Array2<f64>,
    /// `g_t^(i)`, shape `(d,)`.
    pub gt: Array1<f64>,
}
22
23impl ArrowRowBlock {
24    /// Allocate one BA point-block row: local latent Hessian, point-camera
25    /// cross block, and point gradient.
26    pub fn new(d: usize, k: usize) -> Self {
27        Self::new_with_htbeta_cols(d, k)
28    }
29
30    /// Allocate one BA row whose dense cross-block slab has `htbeta_cols`
31    /// columns. This is used by matrix-free assemblers that keep the shared
32    /// beta tier at one width while dense row supplements live in another
33    /// coordinate system.
34    pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
35        Self {
36            htt: Array2::<f64>::zeros((d, d)),
37            htbeta: Array2::<f64>::zeros((d, htbeta_cols)),
38            gt: Array1::<f64>::zeros(d),
39        }
40    }
41}
42
/// Bordered (t, β) Newton system with arrow structure.
///
/// The β-block is held as a dense `K × K` Hessian `H_ββ` plus a `K`-length
/// gradient `g_β` for direct BA modes. Large-scale inexact BA callers may
/// additionally install a matrix-free `H_ββ x` operator and diagonal via
/// [`ArrowSchurSystem::set_shared_beta_operator`]; the InexactPCG mode then
/// avoids dense Schur formation/factorization.
/// The t-block is a `Vec<ArrowRowBlock>` of length `N`.
///
/// Construction is the driver's responsibility: the driver
///
///   1. evaluates Φ(t) and the radial jet `∂Φ/∂t` (the latter via
///      [`gam_terms::latent::LatentCoordValues::design_gradient_wrt_t`]);
///   2. forms the working-weighted Gauss–Newton blocks
///      `H_tt^(i) += (g_i β)(g_i β)^T`, `H_tβ^(i) += (g_i β) ⊗ Φ_i`,
///      `H_ββ += Φ^T W Φ + Σ_k λ_k S_k`;
///   3. calls [`ArrowSchurSystem::add_analytic_penalty_contributions`] to
///      fold row-block Psi-tier analytic penalties (`ARDPenalty`,
///      `SparsityPenalty`) into `H_tt^(i)` and Beta-tier penalties into `H_ββ`;
///   4. calls [`ArrowSchurSystem::solve`] to obtain `(Δt, Δβ)`.
pub struct ArrowSchurSystem {
    /// Per-row latent block (length `N`, each row `d × d` / `d × K` / `d`).
    pub rows: Vec<ArrowRowBlock>,
    /// `H_ββ`, shape `(K, K)` for direct BA modes; empty (`0 × 0`) when
    /// constructed by [`ArrowSchurSystem::new_matrix_free_shared`] for
    /// PCG-only use, or by one of the `*_empty_hbb_*` constructors.
    pub hbb: Array2<f64>,
    /// Optional matrix-free `H_ββ x` operator for large BA Schur PCG.
    ///
    /// Direct and Square-Root BA modes still require `hbb`; InexactPCG uses
    /// this operator when present, avoiding dense shared-block storage for
    /// SAE-manifold scale `K`.
    pub hbb_matvec: Option<SharedBetaMatvec>,
    /// Optional row-local matrix-free multiply for `H_tβ^(i) x`.
    ///
    /// When present, all inner-Schur paths route through this operator instead
    /// of indexing the per-row `htbeta` dense slabs: `reduced_rhs_beta`,
    /// `schur_matvec` (PCG hot loop), back-substitution,
    /// `JacobiPreconditioner` construction, `build_dense_schur_direct`, and
    /// `build_dense_schur_sqrt_ba` all call `sys_htbeta_apply_row` or
    /// `sys_htbeta_materialize_row`.  Factor caches retain the operator for
    /// IFT/evidence consumers as before.
    pub htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Optional row-local matrix-free transpose multiply `out += H_βt^(i) · v`.
    ///
    /// The sparse adjoint of [`Self::htbeta_matvec`]. When present, the
    /// reduced-Schur matvec applies `H_βt^(i)` directly (sparse `scatter`)
    /// instead of probing the forward operator against `K` basis vectors. This
    /// is the per-row sparse apply that lifts the `O(K)` column-probe in the
    /// GPU PCG and streaming Schur paths to `O(m_i · p)` per row. Installed in
    /// lock-step with `htbeta_matvec` by [`Self::set_row_htbeta_operator`].
    pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// Whether `rows[*].htbeta` contains a dense contribution that must be added
    /// on top of the matrix-free row operator.
    ///
    /// Switched on by [`Self::activate_dense_htbeta_supplement`]; constructors
    /// such as [`Self::new_with_hbb_and_htbeta_cols`] start with `false`.
    pub htbeta_dense_supplement: bool,
    /// Optional diagonal of the matrix-free shared block, used by the
    /// Schur-Jacobi preconditioner in the Agarwal-style PCG path.
    pub hbb_diag: Option<Array1<f64>>,
    /// `g_β`, shape `(K,)`.
    pub gb: Array1<f64>,
    /// Maximum per-row latent dimensionality across all rows.
    ///
    /// For homogeneous systems (all rows have the same dim) this equals the
    /// common per-row `d`.  For heterogeneous systems (e.g. sparse SAE rows
    /// where JumpReLU / TopK / sparsemax active sets vary per observation)
    /// this is `max_i row_dims[i]`.  Per-row code should use
    /// `row.htt.nrows()` or `row_dims[i]`; `d` is an upper bound for
    /// scratch-buffer sizing.
    pub d: usize,
    /// Per-row latent dimensionality: `row_dims[i] == rows[i].htt.nrows()`.
    ///
    /// For homogeneous systems `row_dims[i] == d` for all `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for the `delta_t` vector produced by
    /// [`Self::solve`] / [`solve_arrow_newton_step_core`].
    ///
    /// `row_offsets[i]` is the start index for row `i`'s slice in `delta_t`;
    /// `row_offsets[n]` is the total `delta_t` length.  For homogeneous
    /// systems `row_offsets[i] == i * d`.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local latent blocks after optional
    /// Riemannian projection. Euclidean/no-op geometry uses the sentinel.
    pub manifold_mode_fingerprint: u64,
    /// Structural/value tag for row-local Hessian factors and their Schur
    /// inputs. Stale caches must be rejected when row-dependent Hessian
    /// penalties or cross-blocks change.
    ///
    /// Constructors such as [`Self::new_with_hbb_and_htbeta_cols`] initialise
    /// this to `0`; [`Self::refresh_row_hessian_fingerprint`] stores the
    /// current value explicitly.
    pub row_hessian_fingerprint: u64,
    /// Registry-side tag for row-dependent analytic-penalty Hessian inputs.
    /// Combined with the materialized row blocks in
    /// [`Self::current_row_hessian_fingerprint`].
    pub analytic_row_hessian_fingerprint: u64,
    /// Term-block column ranges for the block-Jacobi Schur preconditioner.
    ///
    /// Each entry `r` means that indices `r.start..r.end` belong to one
    /// coefficient block (a GAM term or a custom parameter family from
    /// `ParameterBlockSpec`). When populated via
    /// [`Self::set_block_offsets`], the Jacobi preconditioner inverts the
    /// full `b × b` Schur block for each term instead of only its diagonal.
    ///
    /// The default (empty slice) causes `JacobiPreconditioner` to fall back
    /// to pure scalar diagonal inversion, preserving the pre-#283 behaviour.
    pub block_offsets: Arc<[Range<usize>]>,
    /// Optional matrix-free penalty-side `H_ββ` operator (#296).
    ///
    /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
    /// `JacobiPreconditioner`, quadratic-form reduction) route through this
    /// operator instead of the dense `hbb` accumulator, enabling
    /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to skip the `O(K²)` dense
    /// materialisation for structured smoothness penalties.
    ///
    /// When `None`, those paths fall back to wrapping `hbb` in a transient
    /// `DensePenaltyOp` — identical observable behaviour, no new allocation
    /// hot-path cost for callers that have not opted in.
    pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
    /// Device-uploadable SAE Kronecker data for CUDA-resident reduced PCG.
    ///
    /// The generic matrix-free closures remain the authoritative CPU path. This
    /// descriptor is installed only when SAE assembly has a matching CUDA sparse
    /// representation for both `H_tβ` and `H_ββ`.
    pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
    /// Registered Psi-tier analytic penalties whose Hessian couples *distinct*
    /// latent rows (non-row-block-diagonal), captured by
    /// [`Self::add_analytic_penalty_contributions`].
    ///
    /// These penalties (`TotalVariationPenalty`, `SheafConsistencyPenalty`,
    /// block-orthogonality, …) produce off-row Hessian blocks `∂²P/∂t_i∂t_j`
    /// (`i ≠ j`) that the arrow elimination — which assumes each `H_tt^(i)` is
    /// independent of every other row — cannot represent. Their *gradient* is
    /// still folded into `g_t` exactly like every other Psi penalty; only their
    /// curvature is held here, applied during the solve as a full-latent
    /// Hessian-vector product `P_cross · Δt` against the penalty's
    /// `psd_majorizer_hvp`. When this vector is non-empty,
    /// [`solve_arrow_newton_step_artifacts`] auto-selects the matrix-free
    /// full-system PCG path (arrow block-diagonal inverse as preconditioner)
    /// instead of the exact one-shot Schur elimination. When empty, the system
    /// is purely row-block-diagonal and the exact Schur path is unchanged.
    pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
    /// Optional row-local gauge directions for evidence-only Faddeev-Popov
    /// deflation of an otherwise non-PD `H_tt` row block.
    ///
    /// These vectors live in each row's actual chart block, so compact SAE rows
    /// and dense rows share the same factorization path. Ordinary Newton solves
    /// ignore them; only undamped evidence factors with
    /// `tolerate_ill_conditioning` set may stiffen a gauge-explained row
    /// direction.
    pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
    /// Optional exact cross-row IBP low-rank source (#1038). When set, the
    /// factorization downdates the per-row logit-slot self term and layers the
    /// exact rank-`R` Woodbury correction onto the evidence cache (value,
    /// log-determinant, and θ/ρ-adjoint together). `None` for all non-IBP
    /// systems — the row-block-diagonal arrow path is then unchanged.
    ///
    /// Installed through [`Self::set_ibp_cross_row_source`], which stores
    /// `None` for an empty source.
    pub ibp_cross_row: Option<IbpCrossRowSource>,
}
197
198impl Clone for ArrowSchurSystem {
199    fn clone(&self) -> Self {
200        Self {
201            rows: self.rows.clone(),
202            hbb: self.hbb.clone(),
203            hbb_matvec: self.hbb_matvec.clone(),
204            htbeta_matvec: self.htbeta_matvec.clone(),
205            htbeta_transpose_matvec: self.htbeta_transpose_matvec.clone(),
206            htbeta_dense_supplement: self.htbeta_dense_supplement,
207            hbb_diag: self.hbb_diag.clone(),
208            gb: self.gb.clone(),
209            d: self.d,
210            row_dims: Arc::clone(&self.row_dims),
211            row_offsets: Arc::clone(&self.row_offsets),
212            k: self.k,
213            manifold_mode_fingerprint: self.manifold_mode_fingerprint,
214            row_hessian_fingerprint: self.row_hessian_fingerprint,
215            analytic_row_hessian_fingerprint: self.analytic_row_hessian_fingerprint,
216            block_offsets: Arc::clone(&self.block_offsets),
217            penalty_op: self.penalty_op.clone(),
218            device_sae_pcg: self.device_sae_pcg.clone(),
219            cross_row_penalties: self.cross_row_penalties.clone(),
220            row_gauge_deflation: self.row_gauge_deflation.clone(),
221            ibp_cross_row: self.ibp_cross_row.clone(),
222        }
223    }
224}
225
/// A captured cross-row Psi-tier analytic penalty: the penalty kind, the
/// global-ρ slice (`rho_local`) it was registered with, and the latent point
/// (`target_t`) its curvature was linearized at.
///
/// Holds an owned copy of the local ρ-axes so the penalty's
/// [`AnalyticPenaltyKind::psd_majorizer_hvp`] can be evaluated during the
/// matrix-free full-system solve without re-deriving the ρ layout. The penalty
/// itself is an `Arc`-backed clone (cheap), so capturing it does not copy the
/// penalty payload.
#[derive(Clone)]
pub struct CrossRowLatentPenalty {
    /// The non-row-block-diagonal Psi penalty (e.g. `TotalVariationPenalty`).
    pub penalty: AnalyticPenaltyKind,
    /// The penalty's local ρ-axes (its slice of the global ρ vector), stored
    /// as an owned array.
    pub rho_local: Array1<f64>,
    /// The flat latent vector (`N·d`, row-major) the penalty's curvature was
    /// linearized at — i.e. the `target_t` passed to
    /// [`ArrowSchurSystem::add_analytic_penalty_contributions`]. The Hessian of
    /// a nonlinear penalty (the smoothed-TV curvature weights `φ''(D t)`,
    /// etc.) depends on this point, so `psd_majorizer_hvp` must be evaluated
    /// against it for the Newton operator to be the true Hessian at the
    /// current iterate.
    pub target_t: Array1<f64>,
}
249
/// Exact cross-row low-rank IBP source (#1038): the per-column rank-one Hessian
/// terms `H_(i,k),(j,k) = d_k·z'_ik·z'_jk` (for ALL `i,j`, including the `i=j`
/// self term) that couple DISTINCT latent rows through a shared atom column `k`.
///
/// Stacking over rows, this is `H_full = H₀' + U D Uᵀ`, where:
/// * `U` is `delta_t_len × R` with `U[g, k] = z'_ik` at the global latent index
///   `g` of row `i`'s logit slot for atom `k` (zero elsewhere) — i.e. column `k`
///   is supported on the atom-`k` logit slot of every row;
/// * `D = diag(d_k)`, `d_k = w·s'_k` ([`gam_terms::analytic_penalties::IbpHessianDiagThirdChannels::cross_row_d`]);
/// * `H₀'` is the assembled latent block-diagonal `H₀` with the per-row self
///   term `d_k·z'_ik²` REMOVED from each logit-slot diagonal (the assembled
///   `H₀` already carries it, so the FULL rank-one outer product `U D Uᵀ` —
///   which re-adds the `i=j` diagonal — would double-count without this
///   downdate). The determinant lemma `log det(I_R + D UᵀH₀'⁻¹U)` is only the
///   exact rank-`R` correction against this no-self base.
///
/// The arrow elimination assumes each row's `H_tt^(i)` is independent of every
/// other row, so it structurally cannot hold this coupling block-locally. The
/// factorization owner (`solver::arrow_schur`) consumes this source to (a)
/// downdate the per-row logit diagonal before factoring, (b) build `U`/`D` onto
/// the resulting [`ArrowFactorCache`] as a [`CrossRowWoodbury`], and (c) apply
/// the exact Woodbury correction to the value/curvature solve, the evidence
/// log-determinant, and the θ/ρ-adjoint TOGETHER (they all describe the SAME
/// `H_full`).
///
/// The derived `Default` value has `r == 0` and no entries;
/// [`ArrowSchurSystem::set_ibp_cross_row_source`] treats such a source as
/// absent.
#[derive(Clone, Debug, Default)]
pub struct IbpCrossRowSource {
    /// Number of atom columns `R` (the rank of the cross-row update).
    pub r: usize,
    /// `d_k = w·s'_k`, the scalar `D`-coefficient of column `k`. Length `R`.
    ///
    /// `self_term_downdate` indexes this array with the `atom_k` of every
    /// entry, so each `atom_k` in `entries` must be a valid index here.
    pub d: Array1<f64>,
    /// Per-row column entries `(global_t_index, atom_k, z'_ik)`: each tuple
    /// places `z'_ik` at `U[global_t_index, atom_k]`. The `global_t_index` is
    /// `row_offsets[i] + local_slot` for the row's logit slot of atom `k`. Only
    /// nonzero entries are listed (one per active (row, atom) pair).
    ///
    /// Both `dense_u` and `self_term_downdate` accumulate with `+=`, so a
    /// repeated `(global_t_index, atom_k)` pair contributes additively rather
    /// than overwriting.
    pub entries: Vec<(usize, usize, f64)>,
}
286
287impl IbpCrossRowSource {
288    /// Build the dense `delta_t_len × R` factor `U` (each column supported on
289    /// its atom's per-row logit slots) from the sparse entry list.
290    pub(crate) fn dense_u(&self, delta_t_len: usize) -> Array2<f64> {
291        let mut u = Array2::<f64>::zeros((delta_t_len, self.r));
292        for &(g, k, z) in &self.entries {
293            u[[g, k]] += z;
294        }
295        u
296    }
297
298    /// Per-row-slot self-term downdate: returns, for each global latent index,
299    /// the scalar `Σ_k d_k·z'_ik²` to subtract from that logit slot's diagonal
300    /// so the factored base is `H₀'` (self term removed). Indexed by global
301    /// `delta_t` position.
302    pub(crate) fn self_term_downdate(&self, delta_t_len: usize) -> Array1<f64> {
303        let mut down = Array1::<f64>::zeros(delta_t_len);
304        for &(g, k, z) in &self.entries {
305            down[g] += self.d[k] * z * z;
306        }
307        down
308    }
309}
310
311impl ArrowSchurSystem {
312    /// Allocate an empty BA reduced-camera-system instance sized
313    /// `(N point/latent rows × d, K shared decoder parameters)`.
314    pub fn new(n: usize, d: usize, k: usize) -> Self {
315        Self::new_with_hbb(n, d, k, Array2::<f64>::zeros((k, k)))
316    }
317
318    /// Allocate an arrow system with no dense shared `H_ββ` block and with
319    /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
320    pub fn new_with_empty_hbb_and_htbeta_cols(
321        n: usize,
322        d: usize,
323        k: usize,
324        htbeta_cols: usize,
325    ) -> Self {
326        let rows = (0..n)
327            .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
328            .collect();
329        let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
330        let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
331        Self {
332            rows,
333            hbb: Array2::<f64>::zeros((0, 0)),
334            hbb_matvec: None,
335            htbeta_matvec: None,
336            htbeta_transpose_matvec: None,
337            htbeta_dense_supplement: false,
338            hbb_diag: None,
339            gb: Array1::<f64>::zeros(k),
340            d,
341            row_dims,
342            row_offsets,
343            k,
344            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
345            row_hessian_fingerprint: 0,
346            analytic_row_hessian_fingerprint: 0,
347            block_offsets: Arc::from([] as [Range<usize>; 0]),
348            penalty_op: None,
349            device_sae_pcg: None,
350            cross_row_penalties: Vec::new(),
351            row_gauge_deflation: None,
352            ibp_cross_row: None,
353        }
354    }
355
356    /// Allocate an arrow system using a caller-owned dense shared-block buffer.
357    /// The buffer must already have shape `(k, k)` and is zeroed in place before
358    /// use so callers can recycle it across assemblies without changing
359    /// numerics.
360    pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
361        Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, k)
362    }
363
364    /// Allocate an arrow system with a caller-owned dense shared-block buffer and
365    /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
366    pub fn new_with_hbb_and_htbeta_cols(
367        n: usize,
368        d: usize,
369        k: usize,
370        mut hbb: Array2<f64>,
371        htbeta_cols: usize,
372    ) -> Self {
373        assert_eq!(hbb.dim(), (k, k));
374        hbb.fill(0.0);
375        let rows = (0..n)
376            .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
377            .collect();
378        let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
379        let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
380        Self {
381            rows,
382            hbb,
383            hbb_matvec: None,
384            htbeta_matvec: None,
385            htbeta_transpose_matvec: None,
386            htbeta_dense_supplement: false,
387            hbb_diag: None,
388            gb: Array1::<f64>::zeros(k),
389            d,
390            row_dims,
391            row_offsets,
392            k,
393            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
394            row_hessian_fingerprint: 0,
395            analytic_row_hessian_fingerprint: 0,
396            block_offsets: Arc::from([] as [Range<usize>; 0]),
397            penalty_op: None,
398            device_sae_pcg: None,
399            cross_row_penalties: Vec::new(),
400            row_gauge_deflation: None,
401            ibp_cross_row: None,
402        }
403    }
404
405    /// Allocate an arrow system whose shared `H_ββ` block is supplied only as
406    /// a matrix-free operator for large BA InexactPCG.
407    ///
408    /// Direct and Square-Root BA modes require dense `hbb` and must not be
409    /// used with this constructor. The row-local `H_tβ` slabs remain explicit;
410    /// a future MegBA backend can replace those slab operations behind
411    /// [`BatchedBlockSolver`].
412    pub fn new_matrix_free_shared<F>(
413        n: usize,
414        d: usize,
415        k: usize,
416        matvec: F,
417        diag: Array1<f64>,
418    ) -> Self
419    where
420        F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
421    {
422        assert_eq!(diag.len(), k);
423        let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
424        let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
425        let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
426        let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
427        // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
428        // route through the trait while preserving hbb_matvec + hbb_diag for
429        // code that inspects them directly.
430        let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
431            k,
432            Arc::clone(&matvec_arc),
433            diag.clone(),
434        )));
435        Self {
436            rows,
437            hbb: Array2::<f64>::zeros((0, 0)),
438            hbb_matvec: Some(matvec_arc),
439            htbeta_matvec: None,
440            htbeta_transpose_matvec: None,
441            htbeta_dense_supplement: false,
442            hbb_diag: Some(diag),
443            gb: Array1::<f64>::zeros(k),
444            d,
445            row_dims,
446            row_offsets,
447            k,
448            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
449            row_hessian_fingerprint: 0,
450            analytic_row_hessian_fingerprint: 0,
451            block_offsets: Arc::from([] as [Range<usize>; 0]),
452            penalty_op,
453            device_sae_pcg: None,
454            cross_row_penalties: Vec::new(),
455            row_gauge_deflation: None,
456            ibp_cross_row: None,
457        }
458    }
459
460
461
462    /// Allocate a heterogeneous-row arrow system with no dense shared `H_ββ`
463    /// block and with row `H_tβ` slabs allocated at `htbeta_cols` columns.
464    pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
465        per_row_dims: Vec<usize>,
466        k: usize,
467        htbeta_cols: usize,
468    ) -> Self {
469        let n = per_row_dims.len();
470        let d = per_row_dims.iter().copied().max().unwrap_or(0);
471        let mut offsets = Vec::with_capacity(n + 1);
472        let mut cursor = 0usize;
473        offsets.push(cursor);
474        for &dim in &per_row_dims {
475            cursor += dim;
476            offsets.push(cursor);
477        }
478        let rows = per_row_dims
479            .iter()
480            .map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
481            .collect();
482        Self {
483            rows,
484            hbb: Array2::<f64>::zeros((0, 0)),
485            hbb_matvec: None,
486            htbeta_matvec: None,
487            htbeta_transpose_matvec: None,
488            htbeta_dense_supplement: false,
489            hbb_diag: None,
490            gb: Array1::<f64>::zeros(k),
491            d,
492            row_dims: Arc::from(per_row_dims.into_boxed_slice()),
493            row_offsets: Arc::from(offsets.into_boxed_slice()),
494            k,
495            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
496            row_hessian_fingerprint: 0,
497            analytic_row_hessian_fingerprint: 0,
498            block_offsets: Arc::from([] as [Range<usize>; 0]),
499            penalty_op: None,
500            device_sae_pcg: None,
501            cross_row_penalties: Vec::new(),
502            row_gauge_deflation: None,
503            ibp_cross_row: None,
504        }
505    }
506
507    /// Allocate a heterogeneous-row system using a caller-owned dense shared
508    /// block and row `H_tβ` slabs allocated at `htbeta_cols` columns.
509    pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
510        per_row_dims: Vec<usize>,
511        k: usize,
512        mut hbb: Array2<f64>,
513        htbeta_cols: usize,
514    ) -> Self {
515        assert_eq!(hbb.dim(), (k, k));
516        hbb.fill(0.0);
517        let n = per_row_dims.len();
518        let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
519        let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
520        let mut off_vec = Vec::with_capacity(n + 1);
521        let mut cursor = 0usize;
522        for &di in &per_row_dims {
523            off_vec.push(cursor);
524            cursor += di;
525        }
526        off_vec.push(cursor);
527        let row_offsets: Arc<[usize]> = off_vec.into();
528        let rows = per_row_dims
529            .iter()
530            .map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
531            .collect();
532        Self {
533            rows,
534            hbb,
535            hbb_matvec: None,
536            htbeta_matvec: None,
537            htbeta_transpose_matvec: None,
538            htbeta_dense_supplement: false,
539            hbb_diag: None,
540            gb: Array1::<f64>::zeros(k),
541            d: max_d,
542            row_dims,
543            row_offsets,
544            k,
545            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
546            row_hessian_fingerprint: 0,
547            analytic_row_hessian_fingerprint: 0,
548            block_offsets: Arc::from([] as [Range<usize>; 0]),
549            penalty_op: None,
550            device_sae_pcg: None,
551            cross_row_penalties: Vec::new(),
552            row_gauge_deflation: None,
553            ibp_cross_row: None,
554        }
555    }
556
    /// Install the row-local gauge-deflation data on the system.
    ///
    /// Stores `deflation` in `row_gauge_deflation`, replacing any value that
    /// was set before. Nothing else on the system is touched.
    pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
        self.row_gauge_deflation = Some(deflation);
    }
560
561    /// Register the exact cross-row IBP low-rank source (#1038). The assembly
562    /// passes the per-column `D`-coefficients (`cross_row_d`) and the `(global
563    /// latent index, atom, z'_ik)` entries built from `z_jac`; the factorization
564    /// then carries the exact rank-`R` Woodbury (value + log-determinant +
565    /// θ/ρ-adjoint) on the evidence cache. An empty source (`r == 0` or no
566    /// entries) is treated as absent so the row-block-diagonal path is unchanged.
567    pub fn set_ibp_cross_row_source(&mut self, source: IbpCrossRowSource) {
568        if source.r == 0 || source.entries.is_empty() {
569            self.ibp_cross_row = None;
570        } else {
571            self.ibp_cross_row = Some(source);
572        }
573    }
574
    /// Number of BA point/latent rows `N`.
    ///
    /// Read directly from the length of `rows`, so it always reflects the
    /// current row vector.
    pub fn n(&self) -> usize {
        self.rows.len()
    }
579
    /// Recompute the row-system fingerprint from the currently materialized
    /// row blocks, cross-blocks, and shared-block diagonal.
    ///
    /// Delegates to `row_hessian_fingerprint_for_system` and returns the
    /// result without storing it on the system. The registry-side tag
    /// `analytic_row_hessian_fingerprint` is folded in separately by
    /// [`Self::current_row_hessian_fingerprint`].
    pub fn compute_row_hessian_fingerprint(&self) -> u64 {
        row_hessian_fingerprint_for_system(self)
    }
585
586    /// Current effective row-system fingerprint, including the materialized
587    /// row blocks and any registry metadata captured while folding analytic
588    /// penalties into the system.
589    pub fn current_row_hessian_fingerprint(&self) -> u64 {
590        combine_row_and_registry_fingerprints(
591            self.compute_row_hessian_fingerprint(),
592            self.analytic_row_hessian_fingerprint,
593        )
594    }
595
596    /// Store the current row-system fingerprint on the system.
597    ///
598    /// This is intentionally explicit and expensive. Cache and evidence callers
599    /// use [`Self::current_row_hessian_fingerprint`] at the point they need the
600    /// value, after assembly has populated the system, instead of hashing each
601    /// intermediate construction/mutation step.
602    pub fn refresh_row_hessian_fingerprint(&mut self) {
603        self.row_hessian_fingerprint = self.current_row_hessian_fingerprint();
604    }
605
606    /// Install a matrix-free shared-block operator for Agarwal-style
607    /// inexact Schur PCG.
608    ///
609    /// `diag` must be the diagonal of the same `H_ββ` operator and is used
610    /// for the Schur-Jacobi preconditioner. This is the BA "large camera
611    /// system" path mapped to large decoder coefficient blocks.
612    pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
613    where
614        F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
615    {
616        assert_eq!(diag.len(), self.k);
617        let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
618        // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
619        // route through the trait, preserving the existing hbb_matvec +
620        // hbb_diag fields for code that inspects them directly.
621        self.penalty_op = Some(Arc::new(MatvecDiagPenaltyOp::new(
622            self.k,
623            Arc::clone(&matvec_arc),
624            diag.clone(),
625        )));
626        self.hbb_matvec = Some(matvec_arc);
627        self.hbb_diag = Some(diag);
628    }
629
    /// Mark the dense per-row cross-block slabs as active supplements to the
    /// installed matrix-free row operator.
    ///
    /// Only sets `htbeta_dense_supplement` to `true`; the row slabs and any
    /// installed operator are left as they are, and this method does not
    /// check that a row operator has actually been installed.
    pub fn activate_dense_htbeta_supplement(&mut self) {
        self.htbeta_dense_supplement = true;
    }
635
636    /// Install a matrix-free per-row cross-block operator and its sparse
637    /// adjoint.
638    ///
639    /// `forward` must write `out = H_tβ^(row) x` for `out.len() == d` and
640    /// `x.len() == K`. `transpose` must **add** `H_βt^(row) v` into `out` for
641    /// `out.len() == K` and `v.len() == d` (the sparse `scatter` adjoint).
642    ///
643    /// When installed, the forward operator is used during the Newton solve
644    /// (inside `reduced_rhs_beta`, `schur_matvec`, back-substitution, and
645    /// `JacobiPreconditioner` construction) and afterwards by IFT/evidence
646    /// predictors.  Per-row `htbeta` slabs in `ArrowRowBlock` may be left
647    /// zero-sized when this operator is installed — all inner-Schur paths route
648    /// through the matvec instead of indexing the dense block. The transpose
649    /// operator lets the reduced-Schur matvec apply `H_βt^(row)` directly
650    /// (`O(m_i · p)`) instead of probing `forward` against `K` basis vectors.
651    pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
652    where
653        F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
654        T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
655    {
656        self.htbeta_matvec = Some(Arc::new(forward));
657        self.htbeta_transpose_matvec = Some(Arc::new(transpose));
658    }
659
    /// Register term-block column ranges for the block-Jacobi Schur preconditioner.
    ///
    /// Each `Range<usize>` covers the columns of one GAM term (or custom
    /// parameter family) in the shared `β` vector. The ranges must be
    /// non-overlapping, sorted, and their union must cover `0..k`.
    ///
    /// Call this after building the system and before [`Self::solve`] /
    /// [`Self::solve_with_options`] whenever the solver will use
    /// [`ArrowSolverMode::InexactPCG`]. Absent a call, the preconditioner
    /// falls back to scalar diagonal Jacobi (the pre-#283 behaviour).
    ///
    /// The same plumbing is compatible with #287 (custom `ParameterBlockSpec`
    /// families): callers from that path simply supply ranges derived from
    /// their own block layout.
    ///
    /// This setter stores `offsets` as given, replacing any previous value; it
    /// does not itself verify the ordering/coverage requirements above, so
    /// meeting them is the caller's responsibility.
    pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
        self.block_offsets = offsets;
    }
677
678    /// Install a matrix-free penalty-side `H_ββ` operator (#296).
679    ///
680    /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
681    /// `JacobiPreconditioner`, quadratic-form reduction) route through this
682    /// operator instead of the dense `hbb` accumulator, enabling
683    /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to avoid `O(K²)` allocation
684    /// for structured smoothness penalties.
685    pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
686        self.penalty_op = Some(op);
687    }
688
689    pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
690        assert_eq!(data.beta_dim, self.k);
691        assert_eq!(data.a_phi.len(), self.rows.len());
692        assert_eq!(data.local_jac.len(), self.rows.len());
693        self.device_sae_pcg = Some(Arc::new(data));
694    }
695
696    /// Return the effective penalty operator: the installed `penalty_op` if
697    /// present, otherwise a `DensePenaltyOp` wrapping the current `hbb`.
698    ///
699    /// Note: when `penalty_op` is `None`, this clones `hbb` into a new
700    /// `DensePenaltyOp`. Callers in hot loops should call this once and
701    /// store the result, not call it per-iteration.
702    pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
703        match self.penalty_op.as_ref() {
704            Some(op) => Arc::clone(op),
705            None => Arc::new(DensePenaltyOp(self.hbb.clone())),
706        }
707    }
708
709    /// `y += P x` without allocating a new Arc; dispatches to `penalty_op`
710    /// or falls back to `hbb` inline, avoiding the K×K clone hot-path cost.
711    #[inline]
712    pub(crate) fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
713        if let Some(op) = self.penalty_op.as_ref() {
714            op.matvec(x, y);
715        } else {
716            let k = self.hbb.nrows();
717            for a in 0..k {
718                let mut acc = 0.0_f64;
719                for b in 0..k {
720                    acc += self.hbb[[a, b]] * x[b];
721                }
722                y[a] += acc;
723            }
724        }
725    }
726
727    /// Reduced-Schur matvec prologue `y = (P + ridge·I) x` written fresh into a
728    /// zeroed `y` (the caller clears `out` first; this is the first writer).
729    ///
730    /// At the SAE LLM border width (#1017) the dense `H_ββ` fallback is a `k×k`
731    /// GEMV whose `O(k²)` cost (≈4M flops at k=2048) runs once per CG iteration
732    /// and was the serial Amdahl ceiling on the per-row-parallel matvec: while
733    /// the `n`-row point-elimination term fans across all cores, this prologue
734    /// pinned one core and grows as `k²`. The dense GEMV is embarrassingly
735    /// parallel over output rows `a` — each `y[a] = Σ_b hbb[a,b]·x[b] + ridge·x[a]`
736    /// is independent and its inner sum order is identical whether one thread or
737    /// many compute it. Here parallelism is over independent output rows (NOT a
738    /// reassociated reduction), so each `y[a]` accumulates in the SAME order as
739    /// serial — the result is bit-identical to serial, not merely deterministic
740    /// run-to-run (the #1017 determinism gate). On THIS exact-order path the
741    /// criterion ranking is invariant; that no-move guarantee holds because the
742    /// order matches serial, and does NOT generalise to chunk-reassociated
743    /// reductions, where a near-tie winner can flip within the f64 margin
744    /// (#1211). The `penalty_op` path stays serial — it is an opaque operator
745    /// with its own structure (SAE uses the dense `hbb`), and small `k` stays
746    /// serial to avoid rayon overhead on a trivial GEMV.
747    ///
748    /// `parallel` is the caller's top-level / not-nested-in-rayon decision (the
749    /// same guard the row loop uses), so this never oversubscribes inside the
750    /// topology race.
751    pub(crate) fn penalty_ridge_prologue_into(
752        &self,
753        x: &[f64],
754        ridge: f64,
755        y: &mut [f64],
756        parallel: bool,
757    ) {
758        let k = self.hbb.nrows();
759        let dense_parallel = parallel
760            && self.penalty_op.is_none()
761            && self.hbb.dim() == (k, k)
762            && k >= SCHUR_PROLOGUE_PARALLEL_K_MIN;
763        if dense_parallel {
764            use rayon::prelude::*;
765            let hbb = &self.hbb;
766            y.par_iter_mut().enumerate().for_each(|(a, ya)| {
767                let mut acc = 0.0_f64;
768                for b in 0..k {
769                    acc += hbb[[a, b]] * x[b];
770                }
771                *ya = acc + ridge * x[a];
772            });
773        } else {
774            self.penalty_matvec_add(x, y);
775            for a in 0..k {
776                y[a] += ridge * x[a];
777            }
778        }
779    }
780
781    /// `diag += diag(P)` without allocating; dispatches to `penalty_op`
782    /// or falls back to `hbb` diagonal / `hbb_diag` inline.
783    #[inline]
784    pub(crate) fn penalty_diagonal_add(&self, diag: &mut [f64]) {
785        if let Some(op) = self.penalty_op.as_ref() {
786            op.diagonal(diag);
787        } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
788            let k = hbb_diag.len().min(diag.len());
789            for j in 0..k {
790                diag[j] += hbb_diag[j];
791            }
792        } else {
793            let k = self.hbb.nrows().min(diag.len());
794            for j in 0..k {
795                diag[j] += self.hbb[[j, j]];
796            }
797        }
798    }
799
800    /// Add the `b×b` penalty sub-block for `id` to `out`, routing through
801    /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
802    #[inline]
803    pub(crate) fn penalty_block_add(
804        &self,
805        id: BetaBlockId,
806        offsets: &[Range<usize>],
807        out: &mut Array2<f64>,
808    ) {
809        if let Some(op) = self.penalty_op.as_ref() {
810            op.block(id, offsets, out);
811        } else {
812            let range = &offsets[id.0];
813            let b = range.end - range.start;
814            if self.hbb.dim() == (self.k, self.k) {
815                for bi in 0..b {
816                    for bj in 0..b {
817                        out[[bi, bj]] += self.hbb[[range.start + bi, range.start + bj]];
818                    }
819                }
820            } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
821                for bi in 0..b {
822                    out[[bi, bi]] += hbb_diag[range.start + bi];
823                }
824            }
825        }
826    }
827
828    /// Fill a `b×b` penalty sub-block for a set of arbitrary (possibly
829    /// non-contiguous) global column indices `cols`, routing through
830    /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
831    ///
832    /// Used by the cluster-Jacobi preconditioner (#299) which groups columns
833    /// by spectral adjacency rather than contiguous block ranges.
834    #[inline]
835    pub(crate) fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
836        let b = cols.len();
837        if let Some(op) = self.penalty_op.as_ref() {
838            // Probe each column basis vector and extract the sub-block entries.
839            let mut probe = Array1::<f64>::zeros(self.k);
840            let mut result = Array1::<f64>::zeros(self.k);
841            for bj in 0..b {
842                probe.fill(0.0);
843                probe[cols[bj]] = 1.0;
844                result.fill(0.0);
845                {
846                    let p_slice = probe.as_slice().expect("probe contiguous");
847                    let r_slice = result.as_slice_mut().expect("result contiguous");
848                    op.matvec(p_slice, r_slice);
849                }
850                for bi in 0..b {
851                    out[[bi, bj]] += result[cols[bi]];
852                }
853            }
854        } else if self.hbb.dim() == (self.k, self.k) {
855            for bi in 0..b {
856                for bj in 0..b {
857                    out[[bi, bj]] += self.hbb[[cols[bi], cols[bj]]];
858                }
859            }
860        } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
861            for bi in 0..b {
862                out[[bi, bi]] += hbb_diag[cols[bi]];
863            }
864        }
865    }
866
867    /// Fold analytic-penalty contributions into the appropriate blocks.
868    ///
869    /// BA source mapping: these are extra prior/regularization normal-equation
870    /// terms before point elimination, the same place Ceres/g2o attach robust
871    /// priors or gauge-fixing constraints.
872    ///
873    /// **Composition path.** Each registered [`AnalyticPenaltyKind`] is
874    /// queried for `grad_target` (added to `g_t` or `g_β`) and then for
875    /// `hessian_diag` first. Diagonal penalties (ARD and the shipped
876    /// sparsity kernels) are injected directly. The row-block-only Psi-tier
877    /// penalties are `ARDPenalty`, `SparsityPenalty`,
878    /// `SoftmaxAssignmentSparsity`, `IBPAssignment`,
879    /// `RowPrecisionPrior`, `ParametricRowPrecisionPrior`, and
880    /// `ScadMcpPenalty`. Their `d × d` per-row Hessian folds into
881    /// `rows[i].htt`, so the exact arrow Schur elimination (`N` independent
882    /// `d × d` row solves) represents them exactly. Dense Beta-tier penalties
883    /// still fall back to `hvp` probes against the canonical basis vectors for
884    /// `β`.
885    ///
886    /// **Cross-row Psi penalties.** Penalties whose Hessian couples *distinct*
887    /// latent rows — `TotalVariationPenalty`, `SheafConsistencyPenalty`,
888    /// block-orthogonality, … — produce off-row blocks `∂²P/∂t_i∂t_j`
889    /// (`i ≠ j`) that the arrow elimination cannot store, since it assumes each
890    /// `H_tt^(i)` is independent of every other row. These are handled without
891    /// any approximation: their **gradient** is folded into `g_t` exactly as
892    /// for every other Psi penalty (`grad_target → g_t`), and their full
893    /// **curvature** is captured into [`Self::cross_row_penalties`] as a
894    /// matrix-free operator. At solve time, `K = K0 + P_cross` where `K0` is
895    /// the block-diagonal arrow operator and `P_cross · Δt = Σ_p ρ_p ·
896    /// psd_majorizer_hvp_p(t, Δt)` is the cross-row penalty Hessian applied to
897    /// the full flat latent vector. The presence of any captured cross-row
898    /// penalty auto-routes [`Self::solve`] through the matrix-free full-system
899    /// PCG path (the exact arrow block-diagonal inverse `K0⁻¹` is the
900    /// preconditioner `M⁻¹`); a purely row-block-diagonal system keeps the
901    /// exact one-shot Schur path unchanged. No new flag is involved — the route
902    /// is selected from the captured penalty set alone (magic by default).
903    ///
904    /// `target_t` is the full flat latent-coordinate vector (row-major, `N·d` entries)
905    /// at the current iterate; `target_beta` is the current `β`. `rho`
906    /// is the global ρ vector restricted to each penalty's local slice
907    /// by [`AnalyticPenaltyRegistry::rho_layout`].
908    pub fn add_analytic_penalty_contributions(
909        &mut self,
910        registry: &AnalyticPenaltyRegistry,
911        target_t: ArrayView1<'_, f64>,
912        target_beta: ArrayView1<'_, f64>,
913        rho_global: ArrayView1<'_, f64>,
914    ) -> Result<(), ArrowSchurError> {
915        let layout = registry.rho_layout();
916        let mut penalty_fingerprints = Vec::new();
917        self.cross_row_penalties.clear();
918        for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
919            let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
920            match tier {
921                PenaltyTier::Psi => {
922                    if analytic_penalty_is_row_block_diagonal(penalty) {
923                        // Row-block-diagonal: fold gradient + per-row d×d
924                        // curvature into rows[i].htt, exactly representable by
925                        // the arrow Schur elimination.
926                        self.add_ext_coord_penalty(penalty, target_t, rho_local);
927                        if let Some(fingerprint) =
928                            analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
929                        {
930                            penalty_fingerprints.push(fingerprint);
931                        }
932                    } else {
933                        // Cross-row: fold the gradient into g_t (exact, like
934                        // every Psi penalty), but DO NOT fold any curvature into
935                        // the row blocks — its off-row coupling cannot be stored
936                        // there. Capture the penalty so the solve applies its
937                        // full Hessian-vector product P_cross·Δt over the flat
938                        // latent vector. This auto-selects the matrix-free
939                        // full-system PCG path.
940                        self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
941                        self.cross_row_penalties.push(CrossRowLatentPenalty {
942                            penalty: penalty.clone(),
943                            rho_local: rho_local.to_owned(),
944                            target_t: target_t.to_owned(),
945                        });
946                    }
947                }
948                PenaltyTier::Beta => {
949                    self.add_beta_penalty(penalty, target_beta, rho_local);
950                }
951                PenaltyTier::Rho => {
952                    // Rho-tier hyperpriors do not contribute to the inner
953                    // (t, β) Newton step; they enter only at the REML
954                    // outer level.
955                }
956            }
957        }
958        // Cross-row penalties contribute to the Newton Hessian operator, not
959        // the stored row blocks, so they must still invalidate the row-Hessian
960        // cache when their curvature changes. Probe each captured penalty's PSD
961        // majorizer against the current latent vector (a deterministic, generic
962        // probe) and fold the resulting signature in.
963        for cross in &self.cross_row_penalties {
964            penalty_fingerprints.push(cross_row_penalty_fingerprint(
965                &cross.penalty,
966                target_t,
967                cross.rho_local.view(),
968            ));
969        }
970        self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
971            0
972        } else {
973            let mut hasher = Fingerprinter::new();
974            hasher.write_str("arrow-schur-row-hessian-registry-v1");
975            hasher.write_usize(penalty_fingerprints.len());
976            for fingerprint in penalty_fingerprints {
977                hasher.write_u64(fingerprint);
978            }
979            hasher.finish_u64()
980        };
981        Ok(())
982    }
983
984    /// Convert row-local Euclidean latent blocks to Riemannian tangent blocks.
985    ///
986    /// This is the only arrow-Schur algebra change needed for manifold
987    /// latents: `g_t`, `H_tt`, and each `H_tβ` column are projected to
988    /// `T_{t_i}M`, while the shared β block and Schur structure remain
989    /// untouched. Embedded constrained manifolds carry a pinned normal block
990    /// so the existing ambient Cholesky factorization still works; all RHS
991    /// terms live in the tangent space, so the solved update retracts cleanly.
992    pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
993        let manifold = latent.manifold();
994        self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
995        if manifold.is_euclidean() {
996            return;
997        }
998        assert_eq!(latent.n_obs(), self.rows.len());
999        assert_eq!(latent.latent_dim(), self.d);
1000        for (i, row) in self.rows.iter_mut().enumerate() {
1001            let t_i = ArrayView1::from(latent.row(i));
1002            let gt_e = row.gt.clone();
1003            let htt_e = row.htt.clone();
1004            let htbeta_e = row.htbeta.clone();
1005            row.gt = manifold.project_gradient_to_tangent(t_i, gt_e.view());
1006            row.htt = manifold.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
1007            row.htbeta = manifold.project_matrix_columns_to_gradient_tangent(
1008                t_i,
1009                gt_e.view(),
1010                htbeta_e.view(),
1011            );
1012        }
1013    }
1014
1015    pub(crate) fn add_ext_coord_penalty(
1016        &mut self,
1017        penalty: &AnalyticPenaltyKind,
1018        target_t: ArrayView1<'_, f64>,
1019        rho_local: ArrayView1<'_, f64>,
1020    ) {
1021        let d = self.d;
1022        let n = self.rows.len();
1023        apply_analytic_penalty(
1024            penalty,
1025            target_t,
1026            rho_local,
1027            n * d,
1028            d,
1029            self,
1030            |sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
1031            |sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
1032            |a, probe| {
1033                for i in 0..n {
1034                    probe[i * d + a] = 1.0;
1035                }
1036            },
1037            |sys, a, hv| {
1038                for i in 0..n {
1039                    for b in 0..d {
1040                        sys.rows[i].htt[[b, a]] += hv[i * d + b];
1041                    }
1042                }
1043            },
1044        );
1045    }
1046
1047    /// Fold ONLY the latent gradient `grad_target → g_t` of an analytic
1048    /// penalty, leaving the row-block Hessian untouched.
1049    ///
1050    /// Used for cross-row Psi penalties: their gradient enters `g_t` exactly
1051    /// like every other Psi penalty, but their curvature must NOT be scattered
1052    /// into the per-row `H_tt^(i)` blocks (the diagonal piece would be
1053    /// double-counted and the off-row coupling cannot be stored there). The
1054    /// full curvature is instead applied as a matrix-free `P_cross · Δt`
1055    /// during the solve, via [`Self::cross_row_penalties`].
1056    pub(crate) fn add_ext_coord_penalty_gradient_only(
1057        &mut self,
1058        penalty: &AnalyticPenaltyKind,
1059        target_t: ArrayView1<'_, f64>,
1060        rho_local: ArrayView1<'_, f64>,
1061    ) {
1062        let d = self.d;
1063        let n = self.rows.len();
1064        assert_eq!(target_t.len(), n * d);
1065        let grad = penalty.grad_target(target_t, rho_local);
1066        for flat in 0..n * d {
1067            self.rows[flat / d].gt[flat % d] += grad[flat];
1068        }
1069    }
1070
1071    /// Apply the aggregate cross-row penalty Hessian `P_cross · v` over the
1072    /// full flat latent vector `v` (length `Σ_i row_dims[i]`), accumulating
1073    /// into `out`.
1074    ///
1075    /// `P_cross = Σ_p psd_majorizer_hvp_p(target_t, ·; ρ_p)` summed over every
1076    /// captured cross-row penalty. Each penalty's `psd_majorizer_hvp` is its
1077    /// exact (PSD) Hessian-vector product over the `N·d` flat latent vector —
1078    /// for `TotalVariationPenalty` this is `Dᵀ diag(φ''(D t)) D · v`, the
1079    /// graph/forward-difference Laplacian-style coupling that links distinct
1080    /// rows. The ρ scaling is already baked into each penalty's resolved
1081    /// weight, so no extra factor is applied here.
1082    ///
1083    /// This is only valid for homogeneous systems (every row of dimension
1084    /// `d`), the only shape cross-row latent penalties are defined on; the
1085    /// flat-index convention `flat = i·d + j` matches every penalty's
1086    /// `latent_dim`/row-major contract.
1087    pub(crate) fn apply_cross_row_penalty_hessian(
1088        &self,
1089        v: ArrayView1<'_, f64>,
1090        out: &mut Array1<f64>,
1091    ) {
1092        for cross in &self.cross_row_penalties {
1093            assert_eq!(cross.target_t.len(), v.len());
1094            let hv =
1095                cross
1096                    .penalty
1097                    .psd_majorizer_hvp(cross.target_t.view(), cross.rho_local.view(), v);
1098            assert_eq!(hv.len(), out.len());
1099            for i in 0..out.len() {
1100                out[i] += hv[i];
1101            }
1102        }
1103    }
1104
1105    pub(crate) fn add_beta_penalty(
1106        &mut self,
1107        penalty: &AnalyticPenaltyKind,
1108        target_beta: ArrayView1<'_, f64>,
1109        rho_local: ArrayView1<'_, f64>,
1110    ) {
1111        let k = self.k;
1112        let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
1113        apply_analytic_penalty(
1114            penalty,
1115            target_beta,
1116            rho_local,
1117            k,
1118            hvp_columns,
1119            self,
1120            |sys, j, value| sys.gb[j] += value,
1121            |sys, j, value| {
1122                if sys.hbb.dim() == (k, k) {
1123                    sys.hbb[[j, j]] += value;
1124                }
1125                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
1126                    hbb_diag[j] += value;
1127                }
1128            },
1129            |j, probe| probe[j] = 1.0,
1130            |sys, j, hv| {
1131                for i in 0..k {
1132                    sys.hbb[[i, j]] += hv[i];
1133                }
1134                // Keep `hbb_diag` consistent with the dense `hbb` Hessian when
1135                // both are populated (the dense-allocated path + a later
1136                // `set_shared_beta_operator` install). The HVP probe for
1137                // column `j` returns the full Hessian column, whose `j`-th
1138                // entry is the diagonal contribution of this penalty. Without
1139                // this mirror, the Jacobi Schur preconditioner — which prefers
1140                // `hbb_diag` over `hbb`'s diagonal — would silently use a
1141                // stale diagonal for any Beta-tier analytic penalty that
1142                // exposes only an HVP (no `hessian_diag`).
1143                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
1144                    hbb_diag[j] += hv[j];
1145                }
1146            },
1147        );
1148    }
1149
1150    /// Schur-eliminate the per-row latent block and solve for `(Δt, Δβ, diag)`.
1151    ///
1152    /// This uses [`ArrowSolveOptions::automatic`]: BA dense RCS for
1153    /// `K <= 2000`, and Agarwal-style inexact Schur PCG above that size.
1154    /// Call [`ArrowSchurSystem::solve_with_options`] to force Square-Root BA
1155    /// or a specific inexact solve policy.
1156    ///
1157    /// Returns `(delta_t, delta_beta, PcgDiagnostics)` with `delta_t` flat
1158    /// row-major of length `N · d` and `delta_beta` of length `K`. The sign
1159    /// convention matches `solve_newton_direction_dense`: the returned
1160    /// increments satisfy the bordered system with RHS `[-g_t; -g_β]`, i.e.
1161    /// they are the *negated* solutions of the standard Newton-direction
1162    /// formulation. `PcgDiagnostics` is zero-valued for the Direct path and
1163    /// carries live counters (PCG iters, ridge escalations, residual) for
1164    /// InexactPCG.
1165    ///
1166    /// `ridge_t` and `ridge_beta` are nonnegative diagonal regularizers
1167    /// added to the latent and β blocks respectively before factorization
1168    /// — used by the LM damping outer wrapper to recover from near-singular
1169    /// inner steps. Pass `0.0` for both to obtain the unregularized
1170    /// Newton direction.
1171    pub fn solve(
1172        &self,
1173        ridge_t: f64,
1174        ridge_beta: f64,
1175    ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1176        let options = ArrowSolveOptions::automatic(self.k);
1177        solve_arrow_newton_step_core(self, ridge_t, ridge_beta, &options)
1178    }
1179
1180    /// Solve with the standard LM-style ridge escalation: if a per-row
1181    /// `H_tt + ridge_t·I` Cholesky pivot is non-PD, or the reduced Schur
1182    /// factor fails, geometrically grow both ridges and retry. This is the
1183    /// same Ceres-style proximal correction the Newton driver in
1184    /// `run_joint_fit_arrow_schur` performs around `solve`, lifted into the
1185    /// system itself so every entry point (predict OOS reconstruction,
1186    /// single-shot Newton refinement, …) is self-healing against the
1187    /// pathological per-row blocks produced by PCA-seeded latent
1188    /// coordinates on subset / new data — see #163 and #175.
1189    ///
1190    /// `ridge_t` / `ridge_beta` are the caller-nominal Tikhonov ridges; the
1191    /// escalation only adds extra damping on top of them when the factor
1192    /// fails. PCG / AdaptiveCorrection failures are left untouched because
1193    /// they are not factorization-recoverable.
1194    pub fn solve_with_lm_escalation(
1195        &self,
1196        ridge_t: f64,
1197        ridge_beta: f64,
1198    ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1199        let options = ArrowSolveOptions::automatic(self.k);
1200        solve_with_lm_escalation_inner(self, ridge_t, ridge_beta, &options)
1201    }
1202
1203    /// Solve with an explicit BA Schur mode, returning `(Δt, Δβ, PcgDiagnostics)`.
1204    ///
1205    /// [`ArrowSolverMode::Direct`] is the classic dense reduced-camera-system
1206    /// Cholesky path; [`ArrowSolverMode::SqrtBA`] forms the same dense system
1207    /// through Square-Root BA factors; [`ArrowSolverMode::InexactPCG`] runs
1208    /// inexact-step LM on the reduced system with Jacobi-preconditioned
1209    /// Steihaug-CG. `PcgDiagnostics` is zero-valued for Direct/SqrtBA and
1210    /// carries live counters for InexactPCG (iterations, matvec calls,
1211    /// preconditioner escalations, final relative residual, stopping reason).
1212    pub fn solve_with_options(
1213        &self,
1214        ridge_t: f64,
1215        ridge_beta: f64,
1216        options: &ArrowSolveOptions,
1217    ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1218        solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
1219    }
1220}
1221
1222/// Chunked Schur assembler that never retains all row cross-blocks.
1223pub struct StreamingArrowSchur {
1224    pub n_rows: usize,
1225    /// Maximum per-row latent dim (upper bound for scratch buffers).
1226    pub d: usize,
1227    /// Per-row latent dims `row_dims[i] == rows[i].htt.nrows()`.
1228    pub row_dims: Arc<[usize]>,
1229    /// Flat-buffer row offsets: `row_offsets[i]` is the start of row `i` in
1230    /// `delta_t`; `row_offsets[n_rows]` is the total `delta_t` length.
1231    pub row_offsets: Arc<[usize]>,
1232    pub k: usize,
1233    pub chunk_size: usize,
1234    pub s_acc: Array2<f64>,
1235    pub(crate) rhs_acc: Array1<f64>,
1236    pub(crate) hbb: Array2<f64>,
1237    pub(crate) gb: Array1<f64>,
1238    pub(crate) row_builder: StreamingArrowRowBuilder,
1239    /// Procedural cross-block operator `H_tβ^(i) x`. When present, the dense
1240    /// per-row `H_tβ` slabs are never materialized: `accumulate_chunk` and
1241    /// `back_substitute` probe this operator column-by-column to apply the
1242    /// cross-block, matching the Kronecker / matrix-free assembly path. When
1243    /// `None` (legacy dense BA callers), the per-row `row.htbeta` slab is used.
1244    pub(crate) htbeta_matvec: Option<RowHtbetaMatvec>,
1245    /// Sparse adjoint of `htbeta_matvec`. When present, `row_htbeta` rebuilds
1246    /// the dense `(d_i × K)` cross-block by probing the transpose with `d_i`
1247    /// basis vectors — `O(d_i · m_i · p)` total, vs the `O(K · m_i · p)` cost
1248    /// of probing the forward operator with `K` basis vectors. Since
1249    /// `d_i ≪ K`, this is the per-row sparse apply that replaces the `O(K)`
1250    /// column-probe in the streaming reduced-Schur accumulation.
1251    pub(crate) htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
1252    /// Lift the per-row κ rejection for evidence/log-det-only solves; see
1253    /// [`ArrowSolveOptions::tolerate_ill_conditioning`]. Set by [`Self::solve`]
1254    /// from the options; defaults to `false` so direct callers of
1255    /// [`Self::accumulate_chunk`] keep the full guard.
1256    pub(crate) tolerate_ill_conditioning: bool,
1257    /// Set when the source system carried an exact cross-row IBP source
1258    /// ([`IbpCrossRowSource`], #1038). The streaming chunked accumulator cannot
1259    /// hold the rank-`R` Woodbury correction chunk-locally — `U`'s columns span
1260    /// ALL rows, so the capacitance `I_R + D Uᵀ H₀'⁻¹ U` needs the per-row
1261    /// factors retained for a global `H₀'⁻¹U` back-solve, which is exactly the
1262    /// `(N·K)`-scale residency the streaming path exists to avoid. Rather than
1263    /// silently DROP the cross-row term (an inexact logdet that would desync
1264    /// from the dense-resident gradient), the streaming log-determinant errors
1265    /// loudly when this is set, forcing IBP-active fits onto the dense resident
1266    /// [`ArrowFactorCache::arrow_log_det`] path (which carries the exact
1267    /// Woodbury). See the #1038 streaming note.
1268    pub(crate) ibp_cross_row_active: bool,
1269    /// SAE manifold evidence-path per-row gauge deflation, copied from the
1270    /// source [`ArrowSchurSystem::row_gauge_deflation`] (#1273/#1377). When
1271    /// present, the streaming per-row factor MUST apply the SAME spectral
1272    /// discovery-and-deflation of an intrinsic-dimension-flat `H_tt^(i)`
1273    /// direction (eigenvalue → +1, ρ-independent `log 1 = 0` evidence) that the
1274    /// dense [`factor_blocks_for_system`] path applies, or the two routes report
1275    /// different log-determinants for the SAME system — the cross-route
1276    /// invariant `streaming_logdet == full_logdet` would break (the #1377
1277    /// regression: #1273 wired the deflation into the dense path only). `None`
1278    /// for every non-evidence caller, which keeps the strict non-PD refusal.
1279    pub(crate) row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
1280    /// The exact cross-row IBP source ([`IbpCrossRowSource`], #1038), cloned from
1281    /// the assembled [`ArrowSchurSystem`]. The bare streaming paths
1282    /// ([`Self::solve`] / [`Self::reduced_schur_and_log_det_tt`]) still REFUSE
1283    /// when this is present (they cannot carry the rank-`R` Woodbury), but the
1284    /// evidence-only [`Self::reduced_schur_log_det_tt_woodbury`] consumes it to
1285    /// accumulate the chunk-additive Woodbury building blocks `M0 = Uᵀ A⁻¹ U`
1286    /// and `W = Bᵀ A⁻¹ U` against the NO-SELF base `H₀'` (self term downdated),
1287    /// which the caller closes into the exact `log det(I_R + D Uᵀ H₀'⁻¹ U)` via
1288    /// [`streaming_cross_row_woodbury_log_det`].
1289    pub(crate) ibp_cross_row: Option<IbpCrossRowSource>,
1290}
1291
1292/// One chunk's contribution to the streaming cross-row IBP Woodbury (#1038).
1293///
1294/// The capacitance `C = I_R + D·M` with `M = Uᵀ H₀'⁻¹ U` is chunk-additive in
1295/// its two ingredients (`A = H₀'` is block-diagonal over rows, `U` is supported
1296/// per-row): `M = M0 + Wᵀ S⁻¹ W` where `M0 = Σ_i Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
1297/// `W = Σ_i Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`) accumulate row-by-row, and `S` is the final
1298/// GLOBAL reduced Schur. This carries one chunk's `M0`/`W` plus the per-atom
1299/// `D`; the caller sums them and closes the capacitance after the chunk loop.
1300#[derive(Debug, Clone)]
1301pub struct StreamingWoodburyChunk {
1302    /// `Σ_{i∈chunk} Uᵢᵀ Aᵢ⁻¹ Uᵢ`, `R×R`.
1303    pub m0: Array2<f64>,
1304    /// `Σ_{i∈chunk} Bᵢᵀ Aᵢ⁻¹ Uᵢ`, `k×R` (`k` = reduced β border).
1305    pub w: Array2<f64>,
1306    /// Per-atom `D`-coefficients `d_k = w·s'_k`, length `R`.
1307    pub d: Array1<f64>,
1308}
1309
1310impl std::fmt::Debug for StreamingArrowSchur {
1311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1312        f.debug_struct("StreamingArrowSchur")
1313            .field("n_rows", &self.n_rows)
1314            .field("d", &self.d)
1315            .field("k", &self.k)
1316            .field("chunk_size", &self.chunk_size)
1317            .finish_non_exhaustive()
1318    }
1319}
1320
1321impl StreamingArrowSchur {
1322    #[must_use]
1323    pub fn new(
1324        n_rows: usize,
1325        d: usize,
1326        row_dims: Arc<[usize]>,
1327        row_offsets: Arc<[usize]>,
1328        k: usize,
1329        hbb: Array2<f64>,
1330        gb: Array1<f64>,
1331        row_builder: StreamingArrowRowBuilder,
1332        chunk_size: usize,
1333    ) -> Self {
1334        assert_eq!(hbb.dim(), (k, k));
1335        assert_eq!(gb.len(), k);
1336        Self {
1337            n_rows,
1338            d,
1339            row_dims,
1340            row_offsets,
1341            k,
1342            chunk_size: chunk_size.max(1),
1343            s_acc: Array2::<f64>::zeros((k, k)),
1344            rhs_acc: Array1::<f64>::zeros(k),
1345            hbb,
1346            gb,
1347            row_builder,
1348            htbeta_matvec: None,
1349            htbeta_transpose_matvec: None,
1350            tolerate_ill_conditioning: false,
1351            ibp_cross_row_active: false,
1352            row_gauge_deflation: None,
1353            ibp_cross_row: None,
1354        }
1355    }
1356
1357    #[must_use]
1358    pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
1359        // When a Kronecker / matrix-free htbeta_matvec is installed, the dense
1360        // row.htbeta slabs may be zero-sized.  Rather than materialize every
1361        // `(d × K)` slab (the very `(N·K)`-scale buffer the streaming path
1362        // exists to avoid), retain the procedural operator and probe it per row
1363        // inside `accumulate_chunk` / `back_substitute`.  The row builder then
1364        // only carries the small `H_tt` / `g_t` blocks.
1365        let htbeta_matvec = sys.htbeta_matvec.clone();
1366        let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
1367            sys.rows
1368                .iter()
1369                .map(|row| ArrowRowBlock {
1370                    htt: row.htt.clone(),
1371                    htbeta: Array2::<f64>::zeros((0, 0)),
1372                    gt: row.gt.clone(),
1373                })
1374                .collect()
1375        } else {
1376            sys.rows.clone()
1377        };
1378        let rows = Arc::new(rows);
1379        let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
1380            rows.get(row)
1381                .cloned()
1382                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1383                    reason: format!("streaming row {row} out of bounds"),
1384                })
1385        });
1386        // Materialize the dense β-block from the effective penalty operator so
1387        // the streaming accumulator stays correct when contributions live in a
1388        // structured `BetaPenaltyOp` (e.g. the SAE data-fit Gauss-Newton block,
1389        // represented as `G ⊗ I_p`) rather than the dense `hbb` accumulator.
1390        // When no `penalty_op` is installed this reduces to `hbb.clone()`.
1391        let hbb_dense = sys.effective_penalty_op().to_dense();
1392        let mut streaming = Self::new(
1393            sys.rows.len(),
1394            sys.d,
1395            Arc::clone(&sys.row_dims),
1396            Arc::clone(&sys.row_offsets),
1397            sys.k,
1398            hbb_dense,
1399            sys.gb.clone(),
1400            row_builder,
1401            chunk_size,
1402        );
1403        streaming.htbeta_matvec = htbeta_matvec;
1404        streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
1405        streaming.ibp_cross_row_active = sys.ibp_cross_row.is_some();
1406        streaming.ibp_cross_row = sys.ibp_cross_row.clone();
1407        // Carry the SAE evidence-path per-row gauge deflation so the streaming
1408        // per-row factor matches the dense `factor_blocks_for_system` exactly
1409        // (#1377): without it, a row with an intrinsic-dimension-flat `H_tt`
1410        // would deflate on the dense path but be refused / log-det-divergent on
1411        // the streaming path, breaking `streaming_logdet == full_logdet`.
1412        streaming.row_gauge_deflation = sys.row_gauge_deflation.clone();
1413        streaming
1414    }
1415
1416    /// Factor one streaming row's `H_tt^(i)`, applying the SAME per-row gauge /
1417    /// spectral deflation the dense [`factor_blocks_for_system`] path applies
1418    /// when this is the SAE manifold evidence path (an installed
1419    /// `row_gauge_deflation`). For every non-evidence caller this is exactly the
1420    /// generic [`factor_one_row`] (strict non-PD refusal), so PD blocks are
1421    /// bit-for-bit unchanged. Routing both the dense and streaming per-row
1422    /// factors through the identical recovery is what keeps their
1423    /// log-determinants identical (#1273/#1377).
1424    fn factor_row(
1425        &self,
1426        row: &ArrowRowBlock,
1427        ridge_t: f64,
1428        di: usize,
1429        row_idx: usize,
1430    ) -> Result<Array2<f64>, ArrowSchurError> {
1431        match self.row_gauge_deflation.as_ref() {
1432            Some(deflation) => factor_one_row_result(
1433                row,
1434                ridge_t,
1435                di,
1436                row_idx,
1437                self.tolerate_ill_conditioning,
1438                deflation.row(row_idx),
1439                // Evidence path: opt into spectral discovery of an
1440                // intrinsic-dimension-flat direction even when this row's
1441                // supplied gauge list is empty/non-spanning — matching the
1442                // `allow_spectral_deflation = true` the dense path passes.
1443                true,
1444            )
1445            .map(|result| result.factor),
1446            None => factor_one_row(row, ridge_t, di, row_idx, self.tolerate_ill_conditioning),
1447        }
1448    }
1449
1450    /// Build the `(di × k)` cross-block for `row_idx` on demand.
1451    ///
1452    /// When the sparse transpose adjoint is installed, probes it with `di`
1453    /// standard basis vectors — each yields a full `K`-row of `H_βt^(i)`
1454    /// (i.e. a row of the `(di × k)` block) via the sparse scatter, for
1455    /// `O(di · m_i · p)` total, far below the `O(K · m_i · p)` cost of probing
1456    /// the forward operator with `K` basis vectors when `di ≪ K`.
1457    ///
1458    /// When only the forward operator is installed (no adjoint), falls back to
1459    /// the `k`-column forward probe. Otherwise clones the dense `row.htbeta`
1460    /// slab.
1461    pub(crate) fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
1462        if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
1463            // Probe the adjoint: for each latent index c, scatter e_c to obtain
1464            // row c of the (di × k) block.
1465            let mut mat = Array2::<f64>::zeros((di, self.k));
1466            let mut e_c = Array1::<f64>::zeros(di);
1467            let mut beta_row = Array1::<f64>::zeros(self.k);
1468            for c in 0..di {
1469                e_c.fill(0.0);
1470                e_c[c] = 1.0;
1471                beta_row.fill(0.0);
1472                op_t(row_idx, e_c.view(), &mut beta_row);
1473                for a in 0..self.k {
1474                    mat[[c, a]] = beta_row[a];
1475                }
1476            }
1477            return mat;
1478        }
1479        match self.htbeta_matvec.as_ref() {
1480            Some(op) => {
1481                let mut mat = Array2::<f64>::zeros((di, self.k));
1482                let mut e_a = Array1::<f64>::zeros(self.k);
1483                let mut col = Array1::<f64>::zeros(di);
1484                for a in 0..self.k {
1485                    e_a.fill(0.0);
1486                    e_a[a] = 1.0;
1487                    col.fill(0.0);
1488                    op(row_idx, e_a.view(), &mut col);
1489                    for c in 0..di {
1490                        mat[[c, a]] = col[c];
1491                    }
1492                }
1493                mat
1494            }
1495            None => row.htbeta.clone(),
1496        }
1497    }
1498
1499    /// Move out the accumulated reduced Schur block `s_acc` and reduced RHS
1500    /// `rhs_acc`, leaving fresh zero buffers in their place.
1501    ///
1502    /// The reduced contribution is `s_acc = hbb − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`
1503    /// (the β-block `hbb` seeded by `reset_accumulator`, minus the per-row
1504    /// reduction summed by `accumulate_chunk`) and
1505    /// `rhs_acc = +Σ_i H_βt^(i)(H_tt^(i))⁻¹g_t^(i)`. Used by external online
1506    /// drivers (e.g. the SAE streaming joint fit) that accumulate the reduced
1507    /// system across re-materialized chunk systems.
1508    #[must_use]
1509    pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
1510        let s = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1511        let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(self.k));
1512        (s, rhs)
1513    }
1514
1515    /// Reset the dense shared accumulator to `H_ββ + ridge_beta I`.
1516    pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
1517        if self.hbb.dim() != (self.k, self.k) {
1518            return Err(ArrowSchurError::SchurFactorFailed {
1519                reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
1520            });
1521        }
1522        self.s_acc.assign(&self.hbb);
1523        for j in 0..self.k {
1524            self.s_acc[[j, j]] += ridge_beta;
1525            self.rhs_acc[j] = 0.0;
1526        }
1527        Ok(())
1528    }
1529
1530    /// Accumulate rows `[start, end)` into the reduced RHS and Schur block.
1531    pub fn accumulate_chunk(
1532        &mut self,
1533        start: usize,
1534        end: usize,
1535        ridge_t: f64,
1536        mode: ArrowSolverMode,
1537    ) -> Result<(), ArrowSchurError> {
1538        if start > end || end > self.n_rows {
1539            return Err(ArrowSchurError::SchurFactorFailed {
1540                reason: format!(
1541                    "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
1542                    self.n_rows
1543                ),
1544            });
1545        }
1546        let backend = CpuBatchedBlockSolver;
1547        let k = self.k;
1548        // Per-row factor + two block solves + a `k×k` GEMM subtract is the whole
1549        // assembly cost at the SAE LLM shape (#1017); the rows are independent so
1550        // the chunk fans across cores. Stay sequential for the handful-of-rows
1551        // non-SAE callers, or when already inside a rayon worker (the topology
1552        // race fans candidates with `run_topology_race_parallel`) to avoid
1553        // nested-rayon oversubscription — the same gate `schur_matvec` uses.
1554        let parallel = (end - start) >= SCHUR_MATVEC_PARALLEL_ROW_MIN
1555            && rayon::current_thread_index().is_none();
1556        if parallel {
1557            use rayon::prelude::*;
1558            const CHUNK: usize = 64;
1559            // Bind `&self` so the per-row body borrows only the immutable
1560            // streaming state. Each row contributes `+H_βt^(i)(H_tt^(i))⁻¹ g_t^(i)`
1561            // (length `k`) to the reduced RHS and `−H_βt^(i)(H_tt^(i))⁻¹ H_tβ^(i)`
1562            // (`k×k`) to the reduced Schur complement; both are written INTO a
1563            // worker-private `(rhs_part, s_part)` pair so the chunk partials fold
1564            // back in chunk order — bit-identical run-to-run regardless of thread
1565            // scheduling (the #1017 verification gate). The chunk fold reassociates
1566            // the row sum relative to serial, so the criterion ranking is stable
1567            // only up to that f64 margin; a near-tie winner inside the margin can
1568            // flip — not an exact no-move guarantee (#1211).
1569            let this: &Self = self;
1570            let row_into = |row_idx: usize,
1571                            rhs_part: &mut Array1<f64>,
1572                            s_part: &mut Array2<f64>|
1573             -> Result<(), ArrowSchurError> {
1574                let row = (this.row_builder)(row_idx)?;
1575                let di = row.htt.nrows();
1576                this.validate_row(row_idx, &row)?;
1577                let htbeta = this.row_htbeta(row_idx, &row, di);
1578                let factor = this.factor_row(&row, ridge_t, di, row_idx)?;
1579                let v = backend.solve_block_vector(factor.view(), row.gt.view());
1580                for c in 0..di {
1581                    let vc = v[c];
1582                    if vc == 0.0 {
1583                        continue;
1584                    }
1585                    for a in 0..k {
1586                        rhs_part[a] += htbeta[[c, a]] * vc;
1587                    }
1588                }
1589                match mode {
1590                    // InexactPCG differs from Direct only in how the *reduced*
1591                    // system is solved, not in how it is assembled, so it shares
1592                    // the dense Schur subtraction here (see the serial branch).
1593                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1594                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1595                        backend.block_gemm_subtract(s_part, &htbeta, &solved);
1596                    }
1597                    ArrowSolverMode::SqrtBA => {
1598                        let whitened =
1599                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1600                        backend.block_gemm_subtract(s_part, &whitened, &whitened);
1601                    }
1602                }
1603                Ok(())
1604            };
1605            let partials: Vec<(Array1<f64>, Array2<f64>)> = (start..end)
1606                .into_par_iter()
1607                .chunks(CHUNK)
1608                .map(|idxs| {
1609                    let mut rhs_part = Array1::<f64>::zeros(k);
1610                    let mut s_part = Array2::<f64>::zeros((k, k));
1611                    for i in idxs {
1612                        row_into(i, &mut rhs_part, &mut s_part)?;
1613                    }
1614                    Ok::<_, ArrowSchurError>((rhs_part, s_part))
1615                })
1616                .collect::<Result<Vec<_>, _>>()?;
1617            // Deterministic ordered reduction: fold chunk partials left-to-right.
1618            // `block_gemm_subtract` already subtracted into each `s_part`, so the
1619            // partials carry the negative Schur contribution; add them in.
1620            for (rhs_part, s_part) in &partials {
1621                for a in 0..k {
1622                    self.rhs_acc[a] += rhs_part[a];
1623                }
1624                self.s_acc += s_part;
1625            }
1626        } else {
1627            // Serial path accumulates DIRECTLY into the running `self.{rhs,s}_acc`
1628            // (which carry the `reset_accumulator` seed `H_ββ + ridge·I`), exactly
1629            // as before — bit-for-bit unchanged for the handful-of-rows callers.
1630            for row_idx in start..end {
1631                let row = (self.row_builder)(row_idx)?;
1632                let di = row.htt.nrows();
1633                self.validate_row(row_idx, &row)?;
1634                let htbeta = self.row_htbeta(row_idx, &row, di);
1635                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1636                let v = backend.solve_block_vector(factor.view(), row.gt.view());
1637                for c in 0..di {
1638                    let vc = v[c];
1639                    if vc == 0.0 {
1640                        continue;
1641                    }
1642                    for a in 0..k {
1643                        self.rhs_acc[a] += htbeta[[c, a]] * vc;
1644                    }
1645                }
1646                match mode {
1647                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1648                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1649                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1650                    }
1651                    ArrowSolverMode::SqrtBA => {
1652                        let whitened =
1653                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1654                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1655                    }
1656                }
1657            }
1658        }
1659        Ok(())
1660    }
1661
1662    /// Compute the exact arrow Hessian log-determinant by accumulating the
1663    /// reduced Schur complement in row chunks, without retaining the full set
1664    /// of per-row Cholesky factors.
1665    ///
1666    /// This is the streaming analogue of [`ArrowFactorCache::arrow_log_det`]:
1667    ///
1668    /// ```text
1669    /// log|H| = Σ_i log|H_tt^(i)| + log|H_ββ - Σ_i H_βt^(i) H_tt^(i)⁻¹ H_tβ^(i)|.
1670    /// ```
1671    ///
1672    /// The same row builder and procedural `H_tβ` callbacks used by the
1673    /// streaming Newton solve are consumed here, so callers can score REML
1674    /// evidence without materialising the full `(N × q × K)` cross block or
1675    /// the full list of row factors.
1676    pub fn reduced_schur_and_log_det_tt(
1677        &mut self,
1678        ridge_t: f64,
1679        ridge_beta: f64,
1680        options: &ArrowSolveOptions,
1681    ) -> Result<(f64, Array2<f64>), ArrowSchurError> {
1682        if self.ibp_cross_row_active {
1683            return Err(ArrowSchurError::SchurFactorFailed {
1684                reason: "streaming arrow log-det cannot carry the exact cross-row IBP \
1685                         Woodbury correction (#1038): U's columns span all rows, so the \
1686                         rank-R capacitance needs the per-row factors retained — the very \
1687                         (N·K) residency the streaming path avoids. Route IBP-active fits \
1688                         through the dense resident ArrowFactorCache::arrow_log_det instead."
1689                    .to_string(),
1690            });
1691        }
1692        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1693        self.reset_accumulator(ridge_beta)?;
1694        let backend = CpuBatchedBlockSolver;
1695        let mut log_det_tt = 0.0_f64;
1696        for start in (0..self.n_rows).step_by(self.chunk_size) {
1697            let end = (start + self.chunk_size).min(self.n_rows);
1698            for row_idx in start..end {
1699                let row = (self.row_builder)(row_idx)?;
1700                let di = row.htt.nrows();
1701                self.validate_row(row_idx, &row)?;
1702                let htbeta = self.row_htbeta(row_idx, &row, di);
1703                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1704                for axis in 0..di {
1705                    log_det_tt += 2.0 * factor[[axis, axis]].ln();
1706                }
1707                match options.mode {
1708                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1709                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1710                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1711                    }
1712                    ArrowSolverMode::SqrtBA => {
1713                        let whitened =
1714                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1715                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1716                    }
1717                }
1718            }
1719        }
1720        symmetrize_upper_from_lower(&mut self.s_acc);
1721        let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1722        Ok((log_det_tt, schur))
1723    }
1724
1725    /// As [`Self::reduced_schur_and_log_det_tt`], but ALSO accumulates the exact
1726    /// cross-row IBP Woodbury building blocks (#1038) when this streaming system
1727    /// carried an [`IbpCrossRowSource`].
1728    ///
1729    /// Mirrors the dense `factor_blocks_for_system` + [`CrossRowWoodbury::build`]
1730    /// path exactly:
1731    ///
1732    /// * the per-row logit self term `Σ_k d_k·z'_ik²`
1733    ///   ([`IbpCrossRowSource::self_term_downdate`]) is subtracted from each
1734    ///   `H_tt^(i)` BEFORE factoring, so the factored base — and therefore the
1735    ///   returned `log_det_tt`/Schur — is the NO-SELF `H₀'` (the dense path
1736    ///   factors `H₀'` too, then layers the rank-`R` Woodbury);
1737    /// * for each row it forms `Aᵢ⁻¹ Uᵢ` (one block solve against the row's
1738    ///   `H₀'` factor, `Uᵢ` the J-weighted atom-column indicator) and adds the
1739    ///   row's contributions to `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
1740    ///   `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`, `Bᵢ = H_tβ^(i)`).
1741    ///
1742    /// The caller sums `M0`/`W`/`log_det_tt`/Schur over chunks and closes the
1743    /// capacitance `M = M0 + Wᵀ S⁻¹ W`, `log det(I_R + D·M)` against the GLOBAL
1744    /// reduced Schur `S` via [`streaming_cross_row_woodbury_log_det`]. This is
1745    /// the exact `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`
1746    /// the dense [`ArrowFactorCache::arrow_log_det`] returns — the streaming and
1747    /// dense evidence then optimize the SAME REML objective (#1225).
1748    ///
1749    /// When no IBP source is present this delegates to
1750    /// [`Self::reduced_schur_and_log_det_tt`] and returns `None` woodbury, so
1751    /// every non-IBP (softmax / JumpReLU) caller is bit-for-bit unchanged.
1752    pub fn reduced_schur_log_det_tt_woodbury(
1753        &mut self,
1754        ridge_t: f64,
1755        ridge_beta: f64,
1756        options: &ArrowSolveOptions,
1757    ) -> Result<(f64, Array2<f64>, Option<StreamingWoodburyChunk>), ArrowSchurError> {
1758        let Some(source) = self.ibp_cross_row.clone() else {
1759            // Temporarily clear the refusal flag so the shared bare path runs;
1760            // it is `false` whenever the source is absent, so this is a no-op.
1761            let (log_det_tt, schur) =
1762                self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1763            return Ok((log_det_tt, schur, None));
1764        };
1765        let r = source.r;
1766        let total_len = self.row_offsets[self.n_rows];
1767        let down = source.self_term_downdate(total_len);
1768        // Group the sparse `U` entries `(global_t_index, atom, z'_ik)` by row as
1769        // `(local_slot, atom, z)`. `row_offsets` is strictly ascending, so the
1770        // owning row of a global index is located by binary search.
1771        let mut row_entries: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); self.n_rows];
1772        for &(g, atom, z) in &source.entries {
1773            let i = match self.row_offsets.binary_search(&g) {
1774                Ok(idx) => idx,
1775                Err(idx) => idx - 1,
1776            };
1777            let slot = g - self.row_offsets[i];
1778            row_entries[i].push((slot, atom, z));
1779        }
1780        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1781        self.reset_accumulator(ridge_beta)?;
1782        let backend = CpuBatchedBlockSolver;
1783        let mut log_det_tt = 0.0_f64;
1784        let mut m0 = Array2::<f64>::zeros((r, r));
1785        let mut w = Array2::<f64>::zeros((self.k, r));
1786        for start in (0..self.n_rows).step_by(self.chunk_size) {
1787            let end = (start + self.chunk_size).min(self.n_rows);
1788            for row_idx in start..end {
1789                let mut row = (self.row_builder)(row_idx)?;
1790                let di = row.htt.nrows();
1791                self.validate_row(row_idx, &row)?;
1792                // Downdate the per-row logit self term so the factored base is
1793                // `H₀'` (the dense path downdates the SAME `down[base + j]`).
1794                let base = self.row_offsets[row_idx];
1795                for j in 0..di {
1796                    row.htt[[j, j]] -= down[base + j];
1797                }
1798                let htbeta = self.row_htbeta(row_idx, &row, di);
1799                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1800                for axis in 0..di {
1801                    log_det_tt += 2.0 * factor[[axis, axis]].ln();
1802                }
1803                match options.mode {
1804                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1805                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1806                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1807                    }
1808                    ArrowSolverMode::SqrtBA => {
1809                        let whitened =
1810                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1811                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1812                    }
1813                }
1814                let entries = &row_entries[row_idx];
1815                if !entries.is_empty() {
1816                    // `Uᵢ`: `di × R`, column `atom` carries `z'_ik` at its logit
1817                    // slot. `Aᵢ⁻¹ Uᵢ` via the row's `H₀'` Cholesky.
1818                    let mut u_local = Array2::<f64>::zeros((di, r));
1819                    for &(slot, atom, z) in entries {
1820                        u_local[[slot, atom]] += z;
1821                    }
1822                    let ainv_u = backend.solve_block_matrix(factor.view(), u_local.view());
1823                    // `M0 += Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`); `W += Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`).
1824                    m0 += &u_local.t().dot(&ainv_u);
1825                    w += &htbeta.t().dot(&ainv_u);
1826                }
1827            }
1828        }
1829        symmetrize_upper_from_lower(&mut self.s_acc);
1830        let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1831        Ok((
1832            log_det_tt,
1833            schur,
1834            Some(StreamingWoodburyChunk {
1835                m0,
1836                w,
1837                d: source.d.clone(),
1838            }),
1839        ))
1840    }
1841
1842    pub fn reduced_schur_log_det(
1843        schur: &Array2<f64>,
1844        options: &ArrowSolveOptions,
1845    ) -> Result<f64, ArrowSchurError> {
1846        let rhs = Array1::<f64>::zeros(schur.nrows());
1847        let trust_metric_weights = None;
1848        let (delta, schur_factor, diag) =
1849            solve_dense_reduced_system(schur, &rhs, options, trust_metric_weights)?;
1850        if delta.len() != schur.nrows() || diag.iterations != 0 {
1851            return Err(ArrowSchurError::SchurFactorFailed {
1852                reason: "streaming log-det reduced solve returned incoherent diagnostics"
1853                    .to_string(),
1854            });
1855        }
1856        let schur_factor = schur_factor.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1857            reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
1858        })?;
1859        let mut log_det_schur = 0.0_f64;
1860        for axis in 0..schur_factor.nrows() {
1861            log_det_schur += 2.0 * schur_factor[[axis, axis]].ln();
1862        }
1863        Ok(log_det_schur)
1864    }
1865
1866    pub fn exact_arrow_log_det(
1867        &mut self,
1868        ridge_t: f64,
1869        ridge_beta: f64,
1870        options: &ArrowSolveOptions,
1871    ) -> Result<f64, ArrowSchurError> {
1872        let (log_det_tt, schur) =
1873            self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1874        Ok(log_det_tt + Self::reduced_schur_log_det(&schur, options)?)
1875    }
1876
1877    pub fn solve(
1878        &mut self,
1879        ridge_t: f64,
1880        ridge_beta: f64,
1881        options: &ArrowSolveOptions,
1882    ) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
1883        if self.ibp_cross_row_active {
1884            return Err(ArrowSchurError::SchurFactorFailed {
1885                reason: "streaming arrow solve cannot carry the exact cross-row IBP \
1886                         Woodbury correction (#1038); route IBP-active fits through the \
1887                         dense resident solve_arrow_newton_step_with_options instead."
1888                    .to_string(),
1889            });
1890        }
1891        // Propagate the evidence/log-det ill-conditioning tolerance to the
1892        // per-row factor calls inside `accumulate_chunk` / `back_substitute`,
1893        // which take their stable public signatures. Direct callers of
1894        // `accumulate_chunk` keep the conservative default (`false`, full guard).
1895        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1896        self.reset_accumulator(ridge_beta)?;
1897        for start in (0..self.n_rows).step_by(self.chunk_size) {
1898            let end = (start + self.chunk_size).min(self.n_rows);
1899            self.accumulate_chunk(start, end, ridge_t, options.mode)?;
1900        }
1901        for j in 0..self.k {
1902            self.rhs_acc[j] -= self.gb[j];
1903        }
1904        symmetrize_upper_from_lower(&mut self.s_acc);
1905        let trust_metric_weights = None;
1906        let (delta_beta, schur_factor, _diag) =
1907            solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
1908        let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
1909        Ok((delta_t, delta_beta, schur_factor))
1910    }
1911
1912    pub(crate) fn back_substitute(
1913        &self,
1914        ridge_t: f64,
1915        delta_beta: ArrayView1<'_, f64>,
1916    ) -> Result<Array1<f64>, ArrowSchurError> {
1917        let backend = CpuBatchedBlockSolver;
1918        // Total delta_t length = row_offsets[n_rows].
1919        let total_len = self.row_offsets[self.n_rows];
1920        let mut delta_t = Array1::<f64>::zeros(total_len);
1921        // Each row's back-solve `Δt_i = -(H_tt^(i))⁻¹(g_t^(i) + H_tβ^(i)Δβ)`
1922        // writes a DISJOINT segment `delta_t[row_base .. row_base+di]` — no
1923        // cross-row reduction, so this is embarrassingly parallel and the scatter
1924        // is bit-identical regardless of which thread produced each segment (the
1925        // #1017 verification gate). At the SAE LLM shape (`n` in the thousands)
1926        // the per-row factor + solve is the whole cost; below the threshold, or
1927        // when already inside a rayon worker (the topology race fans candidates
1928        // with `run_topology_race_parallel`), stay sequential to avoid
1929        // nested-rayon oversubscription — the same guard `schur_matvec` uses.
1930        let parallel =
1931            self.n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1932        if parallel {
1933            use rayon::prelude::*;
1934            const CHUNK: usize = 64;
1935            // Per-row body: factor, form the RHS, solve, return `-(dt_i)`.
1936            let row_solve = |row_idx: usize| -> Result<(usize, Array1<f64>), ArrowSchurError> {
1937                let row = (self.row_builder)(row_idx)?;
1938                let di = row.htt.nrows();
1939                self.validate_row(row_idx, &row)?;
1940                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1941                let mut htbeta_delta = Array1::<f64>::zeros(di);
1942                if let Some(op) = self.htbeta_matvec.as_ref() {
1943                    op(row_idx, delta_beta, &mut htbeta_delta);
1944                } else {
1945                    for c in 0..di {
1946                        let mut acc = 0.0_f64;
1947                        for a in 0..self.k {
1948                            acc += row.htbeta[[c, a]] * delta_beta[a];
1949                        }
1950                        htbeta_delta[c] = acc;
1951                    }
1952                }
1953                let mut rhs = Array1::<f64>::zeros(di);
1954                for c in 0..di {
1955                    rhs[c] = row.gt[c] + htbeta_delta[c];
1956                }
1957                let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
1958                let mut neg = Array1::<f64>::zeros(di);
1959                for c in 0..di {
1960                    neg[c] = -dt_i[c];
1961                }
1962                Ok((self.row_offsets[row_idx], neg))
1963            };
1964            // Collect per-row segments under rayon, then scatter into the disjoint
1965            // slices. Errors are surfaced via `collect::<Result<…>>`.
1966            let segments: Vec<(usize, Array1<f64>)> = (0..self.n_rows)
1967                .into_par_iter()
1968                .chunks(CHUNK)
1969                .map(|idxs| {
1970                    idxs.into_iter()
1971                        .map(&row_solve)
1972                        .collect::<Result<Vec<_>, _>>()
1973                })
1974                .collect::<Result<Vec<_>, _>>()?
1975                .into_iter()
1976                .flatten()
1977                .collect();
1978            for (base, seg) in &segments {
1979                for (c, &v) in seg.iter().enumerate() {
1980                    delta_t[base + c] = v;
1981                }
1982            }
1983        } else {
1984            let mut rhs = Array1::<f64>::zeros(self.d);
1985            for start in (0..self.n_rows).step_by(self.chunk_size) {
1986                let end = (start + self.chunk_size).min(self.n_rows);
1987                for row_idx in start..end {
1988                    let row = (self.row_builder)(row_idx)?;
1989                    let di = row.htt.nrows();
1990                    self.validate_row(row_idx, &row)?;
1991                    let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1992                    // `H_tβ^(i) Δβ`: route through the procedural operator when
1993                    // present (no dense slab), else through the dense slab.
1994                    let mut htbeta_delta = Array1::<f64>::zeros(di);
1995                    if let Some(op) = self.htbeta_matvec.as_ref() {
1996                        op(row_idx, delta_beta, &mut htbeta_delta);
1997                    } else {
1998                        for c in 0..di {
1999                            let mut acc = 0.0_f64;
2000                            for a in 0..self.k {
2001                                acc += row.htbeta[[c, a]] * delta_beta[a];
2002                            }
2003                            htbeta_delta[c] = acc;
2004                        }
2005                    }
2006                    for c in 0..di {
2007                        rhs[c] = row.gt[c] + htbeta_delta[c];
2008                    }
2009                    let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
2010                    let row_base = self.row_offsets[row_idx];
2011                    for c in 0..di {
2012                        delta_t[row_base + c] = -dt_i[c];
2013                    }
2014                }
2015            }
2016        }
2017        Ok(delta_t)
2018    }
2019
2020    pub(crate) fn validate_row(
2021        &self,
2022        row_idx: usize,
2023        row: &ArrowRowBlock,
2024    ) -> Result<(), ArrowSchurError> {
2025        let expected_di = if row_idx < self.row_dims.len() {
2026            self.row_dims[row_idx]
2027        } else {
2028            self.d
2029        };
2030        let actual_di = row.htt.nrows();
2031        if actual_di != expected_di || row.htt.ncols() != expected_di {
2032            return Err(ArrowSchurError::PerRowFactorFailed {
2033                row: row_idx,
2034                reason: format!(
2035                    "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
2036                    row.htt.dim(),
2037                ),
2038            });
2039        }
2040        // The dense `H_tβ` slab is only validated when no procedural operator is
2041        // installed; with `htbeta_matvec` the slab is intentionally zero-sized
2042        // and the cross-block is probed in `row_htbeta`.
2043        if self.htbeta_matvec.is_none() && row.htbeta.dim() != (expected_di, self.k) {
2044            return Err(ArrowSchurError::SchurFactorFailed {
2045                reason: format!(
2046                    "streaming row H_tβ shape {:?} != ({expected_di}, {})",
2047                    row.htbeta.dim(),
2048                    self.k
2049                ),
2050            });
2051        }
2052        if row.gt.len() != expected_di {
2053            return Err(ArrowSchurError::PerRowFactorFailed {
2054                row: row_idx,
2055                reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
2056            });
2057        }
2058        Ok::<(), _>(())
2059    }
2060}
2061
2062pub(crate) fn apply_analytic_penalty<S, G, D, P, H>(
2063    penalty: &AnalyticPenaltyKind,
2064    target: ArrayView1<'_, f64>,
2065    rho_local: ArrayView1<'_, f64>,
2066    expected_target_len: usize,
2067    hvp_columns: usize,
2068    scatter_target: &mut S,
2069    mut grad_scatter: G,
2070    mut diag_scatter: D,
2071    seed_hvp_probe: P,
2072    mut hvp_column_scatter: H,
2073) where
2074    G: FnMut(&mut S, usize, f64),
2075    D: FnMut(&mut S, usize, f64),
2076    P: Fn(usize, &mut Array1<f64>),
2077    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
2078{
2079    assert_eq!(target.len(), expected_target_len);
2080
2081    let grad = penalty.grad_target(target, rho_local);
2082    for index in 0..expected_target_len {
2083        grad_scatter(scatter_target, index, grad[index]);
2084    }
2085
2086    // The scattered curvature lands in the arrow-Schur `H_tt` / `H_ββ` blocks,
2087    // which are Cholesky-factored (with LM ridge escalation) as the Newton /
2088    // PIRLS curvature operator and must therefore stay PSD. Nonconvex
2089    // sparsifiers (log sparsity, JumpReLU) have an *indefinite* exact Hessian
2090    // that would destroy that positive-definiteness, so we scatter the PSD
2091    // majorizer here — never the exact `hessian_diag` / `hvp`. For convex
2092    // penalties the majorizer equals the exact Hessian (the trait default
2093    // delegates), so this is exact for them. Exact-derivative consumers (the
2094    // outer objective Hessian) use `hessian_diag` / `hvp` directly elsewhere.
2095    if let Some(diag) = penalty.psd_majorizer_diag(target, rho_local) {
2096        assert_eq!(diag.len(), expected_target_len);
2097        for index in 0..expected_target_len {
2098            diag_scatter(scatter_target, index, diag[index]);
2099        }
2100        return;
2101    }
2102
2103    let mut probe = Array1::<f64>::zeros(expected_target_len);
2104    for column in 0..hvp_columns {
2105        probe.fill(0.0);
2106        seed_hvp_probe(column, &mut probe);
2107        let hv = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
2108        hvp_column_scatter(scatter_target, column, hv.view());
2109    }
2110}
2111
/// Report whether `penalty` is row-block-diagonal.
///
/// Thin free-function wrapper that forwards to the penalty's own
/// `is_row_block_diagonal` method and returns its answer unchanged.
///
/// NOTE(review): presumably a `true` result means the penalty's curvature has
/// no cross-row coupling and can be scattered into per-row blocks — confirm
/// against the `AnalyticPenaltyKind` definition.
pub(crate) fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
    penalty.is_row_block_diagonal()
}
2115
/// Flat, shareable storage for a sequence of square per-row factor blocks.
///
/// [`ArrowFactorSlab::from_blocks`] concatenates the entries of every block
/// into `data`; block `i` occupies `data[offsets[i]..offsets[i + 1]]` and is
/// exposed as a `dims[i] × dims[i]` view by [`ArrowFactorSlab::factor`]. All
/// three buffers are `Arc` slices, so cloning a slab shares the storage
/// instead of copying it.
///
/// The slab is the storage type of [`ArrowFactorCache::htt_factors`], i.e. of
/// the per-row + Schur Cholesky factor cache produced by
/// [`solve_arrow_newton_step_with_options`]. That cache is consumed downstream
/// by the IFT warm-start predictor in `crate::persistent_warm_start`: when the
/// outer loop perturbs `(β, ρ)` by a small amount, the new Newton step can be
/// predicted by re-using these factors against a refreshed RHS, saving
/// the dominant `O(N d³ + K³)` factorization cost.
#[derive(Clone)]
pub struct ArrowFactorSlab {
    /// Concatenated entries of all blocks, one block after another.
    pub(crate) data: Arc<[f64]>,
    /// Block boundaries into `data`; has one more entry than there are blocks
    /// and starts at `0`.
    pub(crate) offsets: Arc<[usize]>,
    /// Side length of each (square) block.
    pub(crate) dims: Arc<[usize]>,
}
2128
2129impl ArrowFactorSlab {
2130    pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
2131        let mut data = Vec::new();
2132        let mut offsets = Vec::with_capacity(blocks.len() + 1);
2133        let mut dims = Vec::with_capacity(blocks.len());
2134        offsets.push(0);
2135        for block in blocks {
2136            let (rows, cols) = block.dim();
2137            assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
2138            dims.push(rows);
2139            data.extend(block.iter().copied());
2140            offsets.push(data.len());
2141        }
2142        Self {
2143            data: data.into(),
2144            offsets: offsets.into(),
2145            dims: dims.into(),
2146        }
2147    }
2148
2149    pub fn len(&self) -> usize {
2150        self.dims.len()
2151    }
2152
2153    pub fn is_empty(&self) -> bool {
2154        self.dims.is_empty()
2155    }
2156
2157    pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
2158        let dim = self.dims[row];
2159        let range = self.offsets[row]..self.offsets[row + 1];
2160        ArrayView2::from_shape((dim, dim), &self.data[range])
2161            .expect("ArrowFactorSlab row offset/dim invariant violated")
2162    }
2163
2164    pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
2165        (0..self.len()).map(|row| self.factor(row))
2166    }
2167}
2168
2169impl std::fmt::Debug for ArrowFactorSlab {
2170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2171        f.debug_struct("ArrowFactorSlab")
2172            .field("rows", &self.len())
2173            .field("values", &self.data.len())
2174            .finish()
2175    }
2176}
2177
/// Where the factors of the UNDAMPED per-row `H_tt^(i)` blocks live; this is
/// the type of [`ArrowFactorCache::htt_factors_undamped`].
#[derive(Clone)]
pub enum ArrowUndampedFactors {
    /// No separate slab is stored.
    ///
    /// NOTE(review): presumably the damped `htt_factors` then double as the
    /// undamped ones (e.g. when `ridge_t` is zero) — confirm at the site that
    /// builds the cache.
    SameAsDamped,
    /// A dedicated slab holding the undamped factors.
    Owned(ArrowFactorSlab),
}
2183
2184impl std::fmt::Debug for ArrowUndampedFactors {
2185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2186        match self {
2187            Self::SameAsDamped => f.write_str("SameAsDamped"),
2188            Self::Owned(factors) => f.debug_tuple("Owned").field(&factors.len()).finish(),
2189        }
2190    }
2191}
2192
2193/// Apply `H_tβ^(row) · x` for one row, writing into `out` (length `d`).
2194///
2195/// Sums the installed matrix-free operator, when present, and any correctly
2196/// shaped dense `row.htbeta` slab. This lets structured data-fit rows coexist
2197/// with dense analytic-penalty cross blocks on the same row.
2198pub(crate) fn sys_htbeta_apply_row(
2199    sys: &ArrowSchurSystem,
2200    row_idx: usize,
2201    row: &ArrowRowBlock,
2202    x: ArrayView1<'_, f64>,
2203    out: &mut Array1<f64>,
2204) {
2205    out.fill(0.0);
2206    if let Some(op) = sys.htbeta_matvec.as_ref() {
2207        op(row_idx, x, out);
2208    }
2209    if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2210        && row.htbeta.dim() == (out.len(), sys.k)
2211    {
2212        let di = row.htbeta.nrows();
2213        for c in 0..di {
2214            let mut acc = 0.0_f64;
2215            for a in 0..sys.k {
2216                acc += row.htbeta[[c, a]] * x[a];
2217            }
2218            out[c] += acc;
2219        }
2220    }
2221}
2222
2223/// Accumulate `H_βt^(row) · v` into `out` (length `k`).
2224///
2225/// `out[a] += Σ_c H_tβ^(row)[c, a] · v[c]`
2226///
2227/// Sums the installed matrix-free operator, when present, and any correctly
2228/// shaped dense `row.htbeta` slab.
2229pub(crate) fn sys_htbeta_accumulate_transpose(
2230    sys: &ArrowSchurSystem,
2231    row_idx: usize,
2232    row: &ArrowRowBlock,
2233    v: ArrayView1<'_, f64>,
2234    out: &mut Array1<f64>,
2235) {
2236    if let Some(op) = sys.htbeta_matvec.as_ref() {
2237        htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
2238    }
2239    if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2240        && row.htbeta.dim() == (v.len(), sys.k)
2241    {
2242        let di = row.htbeta.nrows();
2243        for c in 0..di {
2244            let vc = v[c];
2245            if vc == 0.0 {
2246                continue;
2247            }
2248            for a in 0..sys.k {
2249                out[a] += row.htbeta[[c, a]] * vc;
2250            }
2251        }
2252    }
2253}
2254
2255/// Materialize the dense `(di, k)` cross-block for one row.
2256///
2257/// Materializes the sum of the installed matrix-free operator and any correctly
2258/// shaped dense slab on the row.
2259pub(crate) fn sys_htbeta_materialize_row(
2260    sys: &ArrowSchurSystem,
2261    row_idx: usize,
2262    row: &ArrowRowBlock,
2263) -> Result<Array2<f64>, ArrowSchurError> {
2264    let di = sys.row_dims[row_idx];
2265    let k = sys.k;
2266    let use_dense = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
2267    let mut mat = if use_dense && row.htbeta.dim() == (di, k) {
2268        row.htbeta.clone()
2269    } else {
2270        Array2::<f64>::zeros((di, k))
2271    };
2272    if let Some(op) = sys.htbeta_matvec.as_ref() {
2273        let mut e_a = Array1::<f64>::zeros(k);
2274        let mut col = Array1::<f64>::zeros(di);
2275        for a in 0..k {
2276            e_a.fill(0.0);
2277            e_a[a] = 1.0;
2278            col.fill(0.0);
2279            op(row_idx, e_a.view(), &mut col);
2280            for c in 0..di {
2281                mat[[c, a]] += col[c];
2282            }
2283        }
2284    } else if use_dense && row.htbeta.dim() != (di, k) {
2285        return Err(ArrowSchurError::SchurFactorFailed {
2286            reason: format!(
2287                "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
2288                row.htbeta.dim()
2289            ),
2290        });
2291    }
2292    Ok(mat)
2293}
2294
2295/// Probe each column of `H_tβ^(row)` by applying the operator to `e_a` and
2296/// dotting the result with `v`.  Accumulates into `out[a]` for all `a in 0..k`.
2297///
2298/// `out[a] += (H_tβ^(row) e_a) · v = H_βt^(row)[a, :] · v`
2299pub(crate) fn htbeta_probe_transpose(
2300    row: usize,
2301    op: &RowHtbetaMatvec,
2302    v: ArrayView1<'_, f64>,
2303    out: &mut Array1<f64>,
2304    d: usize,
2305    k: usize,
2306) {
2307    let mut e_a = Array1::<f64>::zeros(k);
2308    let mut col_a = Array1::<f64>::zeros(d);
2309    for a in 0..k {
2310        e_a.fill(0.0);
2311        e_a[a] = 1.0;
2312        col_a.fill(0.0);
2313        op(row, e_a.view(), &mut col_a);
2314        let mut acc = 0.0_f64;
2315        for c in 0..d {
2316            acc += col_a[c] * v[c];
2317        }
2318        out[a] += acc;
2319    }
2320}
2321
/// How a factor cache gives access to the per-row cross blocks `H_tβ^(i)`;
/// this is the type of [`ArrowFactorCache::htbeta`].
///
/// NOTE(review): `estimated_bytes` is carried by every variant; presumably it
/// is the estimated size of the dense representation used to pick the variant
/// — confirm at the site that builds the cache.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
    /// One dense block per row, shared behind an `Arc`.
    Dense {
        blocks: Arc<[Array2<f64>]>,
        estimated_bytes: usize,
    },
    /// A per-row matrix-free operator instead of dense blocks.
    Matvec {
        op: RowHtbetaMatvec,
        estimated_bytes: usize,
    },
    /// No cross-block access is retained; row applications report `false`.
    Disabled {
        estimated_bytes: usize,
    },
}
2336
2337impl std::fmt::Debug for ArrowHtbetaCache {
2338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2339        match self {
2340            Self::Dense {
2341                blocks,
2342                estimated_bytes,
2343            } => f
2344                .debug_struct("Dense")
2345                .field("blocks", &blocks.len())
2346                .field("estimated_bytes", estimated_bytes)
2347                .finish(),
2348            Self::Matvec {
2349                estimated_bytes, ..
2350            } => f
2351                .debug_struct("Matvec")
2352                .field("estimated_bytes", estimated_bytes)
2353                .finish(),
2354            Self::Disabled { estimated_bytes } => f
2355                .debug_struct("Disabled")
2356                .field("estimated_bytes", estimated_bytes)
2357                .finish(),
2358        }
2359    }
2360}
2361
2362impl ArrowHtbetaCache {
2363    pub(crate) fn is_available(&self) -> bool {
2364        !matches!(self, Self::Disabled { .. })
2365    }
2366
2367    pub(crate) fn apply_row(
2368        &self,
2369        row: usize,
2370        delta_beta: ArrayView1<'_, f64>,
2371        out: &mut Array1<f64>,
2372    ) -> bool {
2373        match self {
2374            Self::Dense { blocks, .. } => {
2375                let Some(block) = blocks.get(row) else {
2376                    return false;
2377                };
2378                if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
2379                    return false;
2380                }
2381                for c in 0..block.nrows() {
2382                    let mut acc = 0.0_f64;
2383                    for a in 0..block.ncols() {
2384                        acc += block[[c, a]] * delta_beta[a];
2385                    }
2386                    out[c] = acc;
2387                }
2388                true
2389            }
2390            Self::Matvec { op, .. } => {
2391                op(row, delta_beta, out);
2392                true
2393            }
2394            Self::Disabled { .. } => false,
2395        }
2396    }
2397
2398    /// Apply the transpose: `out[a] += H_βt^(row)[a, c] · v[c]` for all `a`.
2399    ///
2400    /// `v` has length `d`; `out` has length `k`. Accumulates (does NOT zero
2401    /// `out` first) so callers can sum contributions across rows into a shared
2402    /// accumulator.  Returns `false` when the cache is `Disabled` and no
2403    /// `fallback_op` is provided.
2404    pub(crate) fn apply_row_transpose_accumulate(
2405        &self,
2406        row: usize,
2407        v: ArrayView1<'_, f64>,
2408        out: &mut Array1<f64>,
2409        d: usize,
2410        k: usize,
2411        fallback_op: Option<&RowHtbetaMatvec>,
2412    ) -> bool {
2413        match self {
2414            Self::Dense { blocks, .. } => {
2415                let Some(block) = blocks.get(row) else {
2416                    return false;
2417                };
2418                if block.nrows() != v.len() || block.ncols() != out.len() {
2419                    return false;
2420                }
2421                // H_βt^(i) · v: outer-loop c hoists v[c], inner-loop a is
2422                // contiguous in row-major (d, k) layout.
2423                for c in 0..block.nrows() {
2424                    let vc = v[c];
2425                    if vc == 0.0 {
2426                        continue;
2427                    }
2428                    for a in 0..block.ncols() {
2429                        out[a] += block[[c, a]] * vc;
2430                    }
2431                }
2432                true
2433            }
2434            Self::Matvec { op, .. } => {
2435                // Probe column-by-column: H_tβ^(row) e_a is column a.  dot(col_a, v)
2436                // is entry a of H_βt^(row) v.
2437                htbeta_probe_transpose(row, op, v, out, d, k);
2438                true
2439            }
2440            Self::Disabled { .. } => {
2441                // No cached block.  Use the caller-supplied fallback op if present.
2442                if let Some(op) = fallback_op {
2443                    htbeta_probe_transpose(row, op, v, out, d, k);
2444                    true
2445                } else {
2446                    false
2447                }
2448            }
2449        }
2450    }
2451}
2452
/// RAW per-row spectral data of a spectrally-deflated undamped evidence `H_tt`
/// block (see [`ArrowFactorCache::deflation_row_spectra`]).
///
/// `evecs` columns are the RAW symmetric eigenvectors `uₘ` of `H_tt`
/// (orthonormal; the deflated directions `vᵢ` are the subset whose eigenvalue
/// was pinned). `raw_evals[m]` is the RAW eigenvalue `λₘ` BEFORE the unit-pin /
/// floor-clamp. `cond_evals[m]` is the conditioned eigenvalue `λ̃ₘ` the factor
/// actually uses (`λ̃ = λ` for an unclamped kept direction, the positive `floor`
/// for a clamped kept direction, `1` for a deflated direction). Together they
/// give the Daleckii–Krein divided differences the outer-gradient deflation
/// correction needs.
#[derive(Debug, Clone)]
pub struct RowDeflationSpectrum {
    /// Eigenvectors `uₘ` of the raw `H_tt` block, stored as columns.
    pub evecs: Array2<f64>,
    /// Raw eigenvalues `λₘ`, before any unit-pin / floor-clamp; entry `m`
    /// pairs with column `m` of `evecs`.
    pub raw_evals: Array1<f64>,
    /// Conditioned eigenvalues `λ̃ₘ` actually used by the factor, indexed like
    /// `raw_evals`.
    pub cond_evals: Array1<f64>,
}
2470
/// Factorization cache of one arrow-Schur Newton solve.
///
/// Bundles the per-row `H_tt` factors (damped and undamped), the optional
/// dense Schur-complement factor and joint log-determinant, the cross-block
/// access needed to apply `H_tβ^(i)`, and the metadata (ridges, dimensions,
/// offsets, fingerprints, solver mode, deflation records) that describes the
/// system the factors were built from.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
    /// Per-row lower-triangular Cholesky factors of `H_tt^(i) + ridge_t·I`.
    ///
    /// These are the *damped* factors used inside the Newton solve. The IFT
    /// predictor must NOT use them — see [`Self::htt_factors_undamped`].
    pub htt_factors: ArrowFactorSlab,
    /// Per-row lower-triangular Cholesky factors of the UNDAMPED
    /// `H_tt^(i)` (no `ridge_t` added).
    ///
    /// The IFT predictor formula
    /// `Δt_i = -(H_tt^(i))⁻¹ · (H_tβ^(i) Δβ + δg_t^(i))` is derived from
    /// `∂g_t/∂t = H_tt` at the stationary point, with no LM damping term.
    /// Reusing the damped factors would bias the predicted shift toward zero
    /// in proportion to `ridge_t`. We pay one extra `O(N d³)` Cholesky per
    /// Newton solve — the same complexity class as the Newton solve itself —
    /// to make the IFT exact.
    pub htt_factors_undamped: ArrowUndampedFactors,
    /// Lower-triangular Cholesky factor of the Schur complement when the
    /// selected BA mode formed/factored dense RCS. `None` for
    /// [`ArrowSolverMode::InexactPCG`], where Agarwal-style inexact LM avoids
    /// the dense `K × K` factor.
    pub schur_factor: Option<Array2<f64>>,
    /// Exact undamped joint-Hessian log-determinant produced by the dense
    /// factorization path. REML evidence consumes this directly so the Laplace
    /// normalizer cannot miss the log-det even when later cache consumers only
    /// need solves/traces.
    pub joint_hessian_log_det: Option<f64>,
    /// BA mode used to create this cache.
    pub solver_mode: ArrowSolverMode,
    /// Ridge values used to build the cached factors (recorded so the
    /// warm-start predictor knows whether the cache is still valid for a
    /// requested ridge level).
    pub ridge_t: f64,
    pub ridge_beta: f64,
    /// Per-row cross-block access for `H_tβ^(i) x`.
    ///
    /// Large caches retain a row matvec callback or disable β-coupled IFT
    /// prediction instead of cloning every dense `d × K` slab.
    pub htbeta: ArrowHtbetaCache,
    /// Maximum per-row latent dim (upper bound; matches `sys.d` at creation).
    pub d: usize,
    /// Per-row latent dims: `row_dims[i]` is the active dim for row `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for `delta_t` / IFT output vectors.
    /// `row_offsets[i]` is the start of row `i`; `row_offsets[n]` is the
    /// total length.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local factors and cross-blocks.
    pub manifold_mode_fingerprint: u64,
    /// Row-system tag for the cached per-row factors, cross-blocks, and
    /// shared-block diagonal used to build the Schur factor.
    pub row_hessian_fingerprint: u64,
    /// PCG instrumentation from the solve that produced this cache.
    ///
    /// Zero-valued (default) when the selected mode did not use PCG
    /// (i.e. `Direct` or `SqrtBA`).
    pub pcg_diagnostics: PcgDiagnostics,
    /// Number of row-local gauge directions stiffened in an undamped evidence
    /// factorization.
    ///
    /// Each direction is stiffened at UNIT stiffness `kappa = 1.0`, so it
    /// contributes `log(1) = 0` to the row-block logdet through the returned
    /// Cholesky factor: the gauge orbit is a criterion null direction and adds
    /// nothing to the Laplace normalizer (the quotient pseudo-determinant
    /// convention, cf. `PenaltyPseudologdet`). Zero theta/rho dependence.
    pub gauge_deflated_directions: usize,
    /// Per-row unit-norm directions `vᵢ` (in each row's `d`-dim latent block
    /// coordinates) that an undamped evidence factorization stiffened to UNIT
    /// stiffness `λ̃ = 1` (gauge or spectral deflation). Indexed by row; empty
    /// for every PD row factored without deflation, and empty overall on the
    /// non-deflating solver paths (streaming / cross-row-penalty CG / device).
    ///
    /// A deflated direction contributes `log(1) = 0` to the row-block log-det
    /// and is ρ/θ-INDEPENDENT, so its true contribution to `∂log|H|/∂ρ` is `0`.
    /// The analytic outer-gradient traces (`assignment_log_strength_hessian_trace`,
    /// `learnable_ibp_data_logdet_alpha_trace`, `logdet_theta_adjoint`) contract
    /// `∂H_raw/∂ρ` (the RAW, pre-deflation block derivative) against the DEFLATED
    /// inverse, which assigns `1/λ̃ = 1` to each `vᵢ` and therefore spuriously
    /// adds `½ vᵢᵀ (∂H_raw/∂ρ) vᵢ`. Those traces subtract this per-row term
    /// (kept-subspace restriction) using these directions; without them the
    /// REML outer ρ-gradient is biased by `+Σ_deflated ½ vᵢᵀ ∂H_raw/∂ρ vᵢ`.
    pub deflated_row_directions: Arc<[Vec<Array1<f64>>]>,
    /// Per-row RAW spectral decomposition of an undamped evidence `H_tt` block
    /// that underwent SPECTRAL deflation, surfaced so the outer ρ/θ-gradient
    /// traces can apply the EXACT deflation-map (Daleckii–Krein) derivative
    /// correction, not just the within-row kept-subspace term.
    ///
    /// The criterion VALUE re-deflates `H_tt` at every ρ, so its gradient is
    /// `tr(H_deflated⁻¹ DΦ[∂H_raw/∂ρ])`, where `Φ` is the spectral pin-to-unit
    /// map. By Daleckii–Krein `DΦ[Ȧ] = U (F ∘ UᵀȦU) Uᵀ` with the divided-
    /// difference matrix `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)` (raw `λ` in the
    /// denominator, conditioned `λ̃` in the numerator). The kept×kept block of
    /// `F` is `1` (the kept subspace contracts the raw derivative unchanged), the
    /// deflated×deflated block is `0`, and the kept(m)×deflated(i) block is
    /// `(λₘ − 1)/(λₘ − λᵢ)` — this last, ROTATION, term is what the per-row
    /// kept-subspace correction alone misses; it couples to the β-block through
    /// the Schur back-substitution carried in `(H⁻¹)_tt`.
    ///
    /// `Some(spectrum)` only for spectrally-deflated rows; `None` for PD rows,
    /// gauge-only deflation (ρ-independent structural null — within-row term
    /// suffices), and every non-SAE-evidence solver path (streaming / device /
    /// cross-row CG). Empty overall when no row deflated spectrally.
    pub deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>,
    /// Exact cross-row IBP rank-`R` Woodbury correction (#1038), present iff the
    /// source system carried an [`IbpCrossRowSource`]. When set, the per-row
    /// factors above are of the NO-SELF base `H₀'` (self term `d_k·z'_ik²`
    /// downdated from each logit diagonal), and this carrier supplies the exact
    /// rank-`R` correction so the value/curvature solve
    /// ([`Self::full_inverse_apply`]), the evidence log-determinant
    /// ([`Self::arrow_log_det`]), and the θ/ρ-adjoint all describe the same
    /// `H_full = H₀' + U D Uᵀ`.
    pub cross_row_woodbury: Option<CrossRowWoodbury>,
}
2587
/// Materialized exact cross-row IBP Woodbury correction (#1038), built against
/// an [`ArrowFactorCache`] whose per-row factors are the NO-SELF base `H₀'`.
///
/// Holds `U` (the `delta_t_len × R` arrow-`t` factor, β-part implicitly zero),
/// `D = diag(d_k)`, the projected `M = UᵀH₀'⁻¹U`, the columns `H₀'⁻¹U`, and the
/// **LU factorization of the (generally non-symmetric, possibly indefinite)
/// capacitance** `C = I_R + D·M`. `d_k = w·s'_k` is not sign-definite, so the
/// capacitance is factored by a partial-pivot LU (exact for any sign); the same
/// factorization serves the log-determinant `log det C`, the inverse correction
/// `H_full⁻¹w = H₀'⁻¹w − H₀'⁻¹U·C⁻¹·(D Uᵀ H₀'⁻¹w)`, and the adjoint's
/// selected-inverse (`C⁻¹` and `M`). The full inverse, value/curvature solve,
/// log-determinant, and adjoint therefore all describe the SAME
/// `H_full = H₀' + U D Uᵀ`.
#[derive(Debug, Clone)]
pub struct CrossRowWoodbury {
    /// `U`: `delta_t_len × R`, column `k` supported on atom-`k` logit slots.
    pub u: Array2<f64>,
    /// `d_k`, length `R`.
    ///
    /// Note that this is the diagonal of `D`, a vector — unrelated to the
    /// scalar latent width `d` stored on [`ArrowFactorCache`].
    pub d: Array1<f64>,
    /// `H₀'⁻¹ U` (the `t`-block), `delta_t_len × R`.
    pub h0inv_u: Array2<f64>,
    /// `(H₀'⁻¹ U)` β-block, `K × R`. `U` has no β support, but the bordered
    /// solve couples the latent columns to `β` through the Schur complement, so
    /// this block is generally nonzero and the inverse correction must apply it
    /// to the `β` output too.
    pub h0inv_u_beta: Array2<f64>,
    /// `M = Uᵀ H₀'⁻¹ U`, `R × R` (symmetric). Retained for the θ/ρ-adjoint.
    pub m: Array2<f64>,
    /// Partial-pivot LU of the capacitance `C = I_R + D·M` (`lu` packs `L`/`U`,
    /// `piv` the row swaps), built by [`small_lu_factor`].
    pub capacitance_lu: SmallLu,
    /// The sparse `U` entries `(global_t_index, atom_k, z'_ik)` — retained so
    /// `Uᵀ·v` can be formed over the atom slots without re-deriving them.
    pub entries: Vec<(usize, usize, f64)>,
}
2623
/// Dense partial-pivot LU of a small square matrix. Used for the cross-row IBP
/// capacitance `C = I_R + D·M`, which is generally non-symmetric and possibly
/// indefinite (`d_k = w·s'_k` is not sign-definite), so a Cholesky/LDLᵀ is
/// unavailable. `R` is the atom count, so this is a cheap dense factorization.
///
/// Instances are produced by [`small_lu_factor`]; the fields are crate-private
/// so the packed layout and the permutation bookkeeping stay consistent.
#[derive(Debug, Clone)]
pub struct SmallLu {
    /// Packed `L` (unit lower, below diagonal) and `U` (upper, on/above
    /// diagonal), `R × R`, in the row-permuted order encoded by `piv`.
    pub(crate) lu: Array2<f64>,
    /// Row permutation: `piv[i]` is the original row now in position `i`.
    pub(crate) piv: Vec<usize>,
    /// Sign of the permutation (`±1`), folded into the determinant. It starts
    /// at `+1` and flips once per row swap performed during factorization.
    pub(crate) perm_sign: f64,
}
2638
2639/// Partial-pivot LU factorization of a small dense square matrix `a` (`R × R`).
2640/// Returns `None` only when a pivot is exactly zero (singular `C`).
2641pub(crate) fn small_lu_factor(a: &Array2<f64>) -> Option<SmallLu> {
2642    let r = a.nrows();
2643    assert_eq!(a.ncols(), r, "small_lu_factor: non-square input");
2644    let mut lu = a.clone();
2645    let mut piv: Vec<usize> = (0..r).collect();
2646    let mut perm_sign = 1.0_f64;
2647    for col in 0..r {
2648        // Partial pivot: pick the largest-magnitude entry on/below the diagonal.
2649        let mut pivot_row = col;
2650        let mut pivot_mag = lu[[col, col]].abs();
2651        for row in (col + 1)..r {
2652            let mag = lu[[row, col]].abs();
2653            if mag > pivot_mag {
2654                pivot_mag = mag;
2655                pivot_row = row;
2656            }
2657        }
2658        // Reject not just an exactly-zero pivot, but any non-finite or
2659        // subnormal magnitude: dividing by a subnormal in the elimination /
2660        // back-solve produces `Inf`/`NaN` that would otherwise flow silently
2661        // into the Woodbury inverse and the evidence log-det (#1038). A
2662        // capacitance this degenerate is a desync the caller must surface
2663        // (→ `Ok(None)` cross-row-absent / `SchurFactorFailed`), not consume.
2664        if !pivot_mag.is_finite() || pivot_mag < f64::MIN_POSITIVE {
2665            return None;
2666        }
2667        if pivot_row != col {
2668            for c in 0..r {
2669                lu.swap((col, c), (pivot_row, c));
2670            }
2671            piv.swap(col, pivot_row);
2672            perm_sign = -perm_sign;
2673        }
2674        let pivot = lu[[col, col]];
2675        for row in (col + 1)..r {
2676            let factor = lu[[row, col]] / pivot;
2677            lu[[row, col]] = factor;
2678            for c in (col + 1)..r {
2679                let v = lu[[col, c]];
2680                lu[[row, c]] -= factor * v;
2681            }
2682        }
2683    }
2684    // Post-elimination invariant: every U diagonal is finite and not subnormal.
2685    // The per-column pivot guard above validates each diagonal as it is chosen,
2686    // but assert it explicitly so `SmallLu::solve` can divide by `lu[[i, i]]`
2687    // without a per-entry guard and so a `SmallLu` value can never carry a
2688    // factor that would silently emit `Inf`/`NaN` into the capacitance solve.
2689    for i in 0..r {
2690        let u = lu[[i, i]];
2691        if !u.is_finite() || u.abs() < f64::MIN_POSITIVE {
2692            return None;
2693        }
2694    }
2695    Some(SmallLu { lu, piv, perm_sign })
2696}
2697
2698impl SmallLu {
2699    pub(crate) fn dim(&self) -> usize {
2700        self.lu.nrows()
2701    }
2702
2703    /// `log|det|` and the determinant sign (`±1`).
2704    pub(crate) fn log_abs_det_and_sign(&self) -> (f64, f64) {
2705        let mut log_abs = 0.0_f64;
2706        let mut sign = self.perm_sign;
2707        for i in 0..self.dim() {
2708            let u = self.lu[[i, i]];
2709            log_abs += u.abs().ln();
2710            if u < 0.0 {
2711                sign = -sign;
2712            }
2713        }
2714        (log_abs, sign)
2715    }
2716
2717    /// Solve `C x = b` reusing the factorization (in place into a fresh vector).
2718    ///
2719    /// Returns `None` when the solve cannot produce a finite result — either a
2720    /// `U` diagonal is non-finite/subnormal (defensive: `small_lu_factor`
2721    /// already rejects such factors, but a future construction path might not)
2722    /// or the back-substitution overflows to `Inf`/`NaN` for an extreme RHS on
2723    /// an ill-conditioned (yet validly factored) capacitance. Surfacing `None`
2724    /// lets the Woodbury / evidence consumers fail loudly (#1038) instead of
2725    /// flowing a silent `NaN` into the log-det and outer gradient.
2726    pub(crate) fn solve(&self, b: &Array1<f64>) -> Option<Array1<f64>> {
2727        let r = self.dim();
2728        // Apply the row permutation: y = P b.
2729        let mut y = Array1::<f64>::zeros(r);
2730        for i in 0..r {
2731            y[i] = b[self.piv[i]];
2732        }
2733        // Forward solve L y' = P b (L unit-lower).
2734        for i in 0..r {
2735            let mut sum = y[i];
2736            for j in 0..i {
2737                sum -= self.lu[[i, j]] * y[j];
2738            }
2739            y[i] = sum;
2740        }
2741        // Back solve U x = y' (U upper, explicit diagonal).
2742        let mut x = Array1::<f64>::zeros(r);
2743        for i in (0..r).rev() {
2744            let mut sum = y[i];
2745            for j in (i + 1)..r {
2746                sum -= self.lu[[i, j]] * x[j];
2747            }
2748            let pivot = self.lu[[i, i]];
2749            if !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE {
2750                return None;
2751            }
2752            x[i] = sum / pivot;
2753        }
2754        if x.iter().all(|v| v.is_finite()) {
2755            Some(x)
2756        } else {
2757            None
2758        }
2759    }
2760}
2761
2762/// Close the streaming cross-row IBP Woodbury (#1038) into its exact evidence
2763/// log-determinant correction.
2764///
2765/// Given the chunk-summed `M0 = Uᵀ A⁻¹ U` (`R×R`), `W = Bᵀ A⁻¹ U` (`k×R`), the
2766/// GLOBAL reduced Schur `schur` (`S`, `k×k`, PD), and the per-atom `D`, this
2767/// forms the EXACT projected `M = Uᵀ H₀'⁻¹ U = M0 + Wᵀ S⁻¹ W` (the bordered
2768/// inverse `t`-block `(H₀'⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹` contracted by `U`),
2769/// then the capacitance `C = I_R + D·M` and returns `log det C`.
2770///
2771/// This is byte-for-byte the same quantity, and the same sign convention, that
2772/// the dense [`CrossRowWoodbury::log_det`] returns (symmetrize `M`, scale rows
2773/// by `D`, partial-pivot LU, reject a non-positive determinant). Returns
2774/// `Ok(None)` when the capacitance is non-PD / singular — the recoverable
2775/// "cross-row IBP joint Hessian is non-PD at this ρ" infeasible-probe refusal,
2776/// NOT a silent wrong value.
2777pub fn streaming_cross_row_woodbury_log_det(
2778    schur: &Array2<f64>,
2779    m0: &Array2<f64>,
2780    w: &Array2<f64>,
2781    d: &Array1<f64>,
2782) -> Result<Option<f64>, ArrowSchurError> {
2783    let r = d.len();
2784    let factor =
2785        cholesky_lower(schur).map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
2786    // M = M0 + Wᵀ S⁻¹ W. With S⁻¹ symmetric, `(Wᵀ S⁻¹ W)[a, b] = (S⁻¹ w_a)·w_b`.
2787    let mut m = m0.clone();
2788    for a in 0..r {
2789        let w_a = w.column(a).to_owned();
2790        let sinv_w_a = cholesky_solve_vector(&factor, &w_a);
2791        for b in 0..r {
2792            m[[a, b]] += sinv_w_a.dot(&w.column(b));
2793        }
2794    }
2795    // Symmetrize to clear back-substitution rounding asymmetry (matches
2796    // `CrossRowWoodbury::build`).
2797    for a in 0..r {
2798        for b in (a + 1)..r {
2799            let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2800            m[[a, b]] = avg;
2801            m[[b, a]] = avg;
2802        }
2803    }
2804    // Capacitance C = I_R + D·M (row k scaled by d_k), factored by partial-pivot
2805    // LU (`d_k = w·s'_k` is not sign-definite, so C is generally non-symmetric /
2806    // indefinite — exactly the dense carrier's factorization).
2807    let mut c = Array2::<f64>::zeros((r, r));
2808    for a in 0..r {
2809        for b in 0..r {
2810            c[[a, b]] = d[a] * m[[a, b]];
2811        }
2812        c[[a, a]] += 1.0;
2813    }
2814    match small_lu_factor(&c) {
2815        Some(lu) => {
2816            let (log_abs, sign) = lu.log_abs_det_and_sign();
2817            Ok((sign > 0.0).then_some(log_abs))
2818        }
2819        // Exactly-singular capacitance: `H_full` det is 0, the Laplace log-det is
2820        // undefined — surface as the recoverable non-PD probe refusal.
2821        None => Ok(None),
2822    }
2823}
2824
2825impl CrossRowWoodbury {
2826    /// Build the exact rank-`R` cross-row Woodbury carrier from the IBP source
2827    /// and a cache whose per-row factors are the NO-SELF base `H₀'`.
2828    ///
2829    /// Computes `H₀'⁻¹U` (one [`ArrowFactorCache::full_inverse_apply`] back-solve
2830    /// per column, β-RHS zero — the `t`-block of the result is `H₀'⁻¹U`'s
2831    /// column), `M = UᵀH₀'⁻¹U`, and the LU of `C = I_R + D·M`. Returns `None`
2832    /// when the capacitance is exactly singular (the only un-representable case;
2833    /// the caller then proceeds with the bare `H₀'` cache and the cross-row term
2834    /// is absent — never silently inconsistent, since logdet/inverse/adjoint all
2835    /// key off the presence of this carrier).
2836    pub(crate) fn build(
2837        cache: &ArrowFactorCache,
2838        source: &IbpCrossRowSource,
2839    ) -> Result<Option<Self>, ArrowSchurError> {
2840        let r = source.r;
2841        let total_len = cache.delta_t_len();
2842        let u = source.dense_u(total_len);
2843        let d = source.d.clone();
2844        let zero_beta = Array1::<f64>::zeros(cache.k);
2845        // h0inv_u[:, k] = (H₀'⁻¹ U)_t for column k; h0inv_u_beta[:, k] its β-block.
2846        let mut h0inv_u = Array2::<f64>::zeros((total_len, r));
2847        let mut h0inv_u_beta = Array2::<f64>::zeros((cache.k, r));
2848        for k in 0..r {
2849            let col = u.column(k).to_owned();
2850            let (sol_t, sol_beta) = cache.full_inverse_apply(col.view(), zero_beta.view())?;
2851            for g in 0..total_len {
2852                h0inv_u[[g, k]] = sol_t[g];
2853            }
2854            for c in 0..cache.k {
2855                h0inv_u_beta[[c, k]] = sol_beta[c];
2856            }
2857        }
2858        // M = Uᵀ (H₀'⁻¹ U), symmetric R×R. U is sparse (atom-slot supported), so
2859        // contract over the listed entries.
2860        let mut m = Array2::<f64>::zeros((r, r));
2861        for a in 0..r {
2862            for b in 0..r {
2863                let mut acc = 0.0_f64;
2864                for &(g, k, z) in &source.entries {
2865                    if k == a {
2866                        acc += z * h0inv_u[[g, b]];
2867                    }
2868                }
2869                m[[a, b]] = acc;
2870            }
2871        }
2872        // Symmetrize M to clear back-substitution rounding asymmetry.
2873        for a in 0..r {
2874            for b in (a + 1)..r {
2875                let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2876                m[[a, b]] = avg;
2877                m[[b, a]] = avg;
2878            }
2879        }
2880        // Capacitance C = I_R + D·M (row k scaled by d_k).
2881        let mut c = Array2::<f64>::zeros((r, r));
2882        for a in 0..r {
2883            for b in 0..r {
2884                c[[a, b]] = d[a] * m[[a, b]];
2885            }
2886            c[[a, a]] += 1.0;
2887        }
2888        let Some(capacitance_lu) = small_lu_factor(&c) else {
2889            return Ok(None);
2890        };
2891        Ok(Some(Self {
2892            u,
2893            d,
2894            h0inv_u,
2895            h0inv_u_beta,
2896            m,
2897            capacitance_lu,
2898            entries: source.entries.clone(),
2899        }))
2900    }
2901
2902    /// The sparse `U` entry list `(global_t_index, atom_k, z'_ik)`.
2903    pub(crate) fn source_entries(&self) -> &[(usize, usize, f64)] {
2904        &self.entries
2905    }
2906
2907    /// `C⁻¹ D` as a dense `R × R` matrix (`R` capacitance solves; column `l` is
2908    /// `d_l · C⁻¹ e_l`). Shared by the inverse-diagonal correction and any
2909    /// adjoint trace that needs the selected inverse of the capacitance.
2910    ///
2911    /// Returns `None` when any capacitance solve fails to produce a finite
2912    /// result (#1038); the consumer must surface this as a loud failure rather
2913    /// than propagate a `NaN` into the evidence/gradient.
2914    pub fn capacitance_inv_times_d(&self) -> Option<Array2<f64>> {
2915        let r = self.d.len();
2916        let mut out = Array2::<f64>::zeros((r, r));
2917        let mut e_l = Array1::<f64>::zeros(r);
2918        for l in 0..r {
2919            e_l.fill(0.0);
2920            e_l[l] = 1.0;
2921            let col = self.capacitance_lu.solve(&e_l)?;
2922            for k in 0..r {
2923                out[[k, l]] = col[k] * self.d[l];
2924            }
2925        }
2926        Some(out)
2927    }
2928
2929    /// Subtract the rank-`R` Woodbury term from the latent inverse diagonal:
2930    /// `diag ← diag − diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹)`. With `G = h0inv_u` and
2931    /// `(C⁻¹D) = capacitance_inv_times_d()`, the entry at global index `g` is
2932    /// `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`.
2933    pub(crate) fn subtract_inverse_diagonal(
2934        &self,
2935        diag: &mut Array1<f64>,
2936    ) -> Result<(), ArrowSchurError> {
2937        let r = self.d.len();
2938        let cinv_d =
2939            self.capacitance_inv_times_d()
2940                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
2941                    reason: "cross-row Woodbury capacitance solve produced a non-finite \
2942                         C⁻¹D for the inverse-diagonal correction (#1038): \
2943                         singular/ill-conditioned cross-row capacitance"
2944                        .to_string(),
2945                })?;
2946        let total_len = self.h0inv_u.nrows();
2947        for g in 0..total_len {
2948            let mut acc = 0.0_f64;
2949            for k in 0..r {
2950                let gk = self.h0inv_u[[g, k]];
2951                if gk == 0.0 {
2952                    continue;
2953                }
2954                for l in 0..r {
2955                    acc += gk * cinv_d[[k, l]] * self.h0inv_u[[g, l]];
2956                }
2957            }
2958            diag[g] -= acc;
2959        }
2960        Ok(())
2961    }
2962
2963    /// `log det(I_R + D·M)` (the matrix-determinant-lemma correction). Returns
2964    /// `None` when the capacitance LU has a negative determinant — i.e. the
2965    /// implied `H_full` is non-PD, which is a desync the evidence must reject
2966    /// loudly rather than return a complex/`NaN` log-det.
2967    pub fn log_det(&self) -> Option<f64> {
2968        let (log_abs, sign) = self.log_det_correction();
2969        if sign > 0.0 { Some(log_abs) } else { None }
2970    }
2971
2972    /// `log det(I_R + D·M)`: the exact additive correction
2973    /// `log det H_full − log det H₀'` (matrix-determinant lemma). For a genuine
2974    /// PD `H_full` this is real; the LU sign is returned for the caller to
2975    /// surface a non-PD capacitance as an error rather than a silent `NaN`.
2976    pub(crate) fn log_det_correction(&self) -> (f64, f64) {
2977        self.capacitance_lu.log_abs_det_and_sign()
2978    }
2979
2980    /// Apply the rank-`R` inverse correction in place on BOTH arrow blocks:
2981    /// `u ← u − (H₀'⁻¹U) · C⁻¹ · (D Uᵀ (H₀'⁻¹ rhs)_t)`, where `h0inv_rhs_t` is
2982    /// the `t`-block of `H₀'⁻¹ rhs` already computed by the base
2983    /// [`ArrowFactorCache::full_inverse_apply`]. Implements the Woodbury
2984    /// identity `H_full⁻¹ = H₀'⁻¹ − H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹`. `U` has no `β`
2985    /// support so `Uᵀ·v` reads only the `t`-block, but `H₀'⁻¹U` couples to `β`
2986    /// through the Schur complement, so the correction touches `u_beta` too.
2987    ///
2988    /// `entries` lets `Uᵀ·v` be formed over the sparse atom slots.
2989    pub(crate) fn apply_inverse_correction(
2990        &self,
2991        h0inv_rhs_t: ArrayView1<'_, f64>,
2992        entries: &[(usize, usize, f64)],
2993        u_t: &mut Array1<f64>,
2994        u_beta: &mut Array1<f64>,
2995    ) -> Result<(), ArrowSchurError> {
2996        let r = self.d.len();
2997        // p = D Uᵀ (H₀'⁻¹ rhs)_t.
2998        let mut p = Array1::<f64>::zeros(r);
2999        for &(g, k, z) in entries {
3000            p[k] += z * h0inv_rhs_t[g];
3001        }
3002        for k in 0..r {
3003            p[k] *= self.d[k];
3004        }
3005        // q = C⁻¹ p. A non-finite solve is a singular/ill-conditioned cross-row
3006        // capacitance (#1038): fail loudly rather than write `NaN` into the
3007        // Newton step / adjoint solve.
3008        let q =
3009            self.capacitance_lu
3010                .solve(&p)
3011                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
3012                    reason: "cross-row Woodbury capacitance solve produced a non-finite \
3013                         C⁻¹p for the inverse correction (#1038): \
3014                         singular/ill-conditioned cross-row capacitance"
3015                        .to_string(),
3016                })?;
3017        // u_t -= (H₀'⁻¹U)_t · q.
3018        for g in 0..u_t.len() {
3019            let mut acc = 0.0_f64;
3020            for k in 0..r {
3021                acc += self.h0inv_u[[g, k]] * q[k];
3022            }
3023            u_t[g] -= acc;
3024        }
3025        // u_beta -= (H₀'⁻¹U)_β · q.
3026        for c in 0..u_beta.len() {
3027            let mut acc = 0.0_f64;
3028            for k in 0..r {
3029                acc += self.h0inv_u_beta[[c, k]] * q[k];
3030            }
3031            u_beta[c] -= acc;
3032        }
3033        Ok(())
3034    }
3035
3036    /// Forward apply of the rank-`R` cross-row curvature on the latent (`t`)
3037    /// block: `out_t += U D Uᵀ v_t`. This is the EXACT forward of the same
3038    /// `H_full = H₀' + U D Uᵀ` whose inverse [`Self::apply_inverse_correction`]
3039    /// applies and whose log-determinant [`Self::log_det_correction`] reports, so
3040    /// a forward Hessian apply that adds this term stays consistent (operator and
3041    /// preconditioner describe the SAME operator). `U` has no `β` support, so the
3042    /// forward term touches only the `t` block; the `β` coupling is purely an
3043    /// inverse-side Schur artifact (`h0inv_u_beta`) and must NOT appear here.
3044    ///
3045    /// Formed over the sparse `entries` `(global_t_index, atom_k, z'_ik)` so it
3046    /// matches `Uᵀ·v` in `apply_inverse_correction` bit-for-bit:
3047    /// `p_k = d_k · Σ_{(g,k,z)} z·v_t[g]`, then `out_t[g] += Σ_{(g,k,z)} z·p_k`.
3048    pub fn apply_forward_t(&self, v_t: ArrayView1<'_, f64>, out_t: &mut Array1<f64>) {
3049        let r = self.d.len();
3050        // p = D Uᵀ v_t.
3051        let mut p = Array1::<f64>::zeros(r);
3052        for &(g, k, z) in &self.entries {
3053            p[k] += z * v_t[g];
3054        }
3055        for k in 0..r {
3056            p[k] *= self.d[k];
3057        }
3058        // out_t += U p.
3059        for &(g, k, z) in &self.entries {
3060            out_t[g] += z * p[k];
3061        }
3062    }
3063}
3064
/// Smallest pivots found in an arrow factorization, reported per tier and
/// overall.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
    /// Smallest pivot over the per-row factors, if there was any.
    pub min_row_pivot: Option<f64>,
    /// Smallest pivot of the Schur factor, if there was any.
    pub min_schur_pivot: Option<f64>,
    /// The smaller of the two values above; when only one of them is present
    /// it is that one.
    pub min_pivot: Option<f64>,
}

impl ArrowFactorMinPivot {
    /// Bundle the row-tier and Schur-tier minima and derive the overall
    /// minimum: `f64::min` of the two when both are present, otherwise
    /// whichever one exists (`None` when neither does).
    pub(crate) fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        let overall = row
            .zip(schur)
            .map(|(row_min, schur_min)| row_min.min(schur_min))
            .or(row)
            .or(schur);
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            min_pivot: overall,
        }
    }
}
3087
3088pub(crate) fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3089    let width = factor.nrows().min(factor.ncols());
3090    let mut out = None;
3091    for idx in 0..width {
3092        let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3093        out = Some(match out {
3094            Some(current) => f64::min(current, pivot),
3095            None => pivot,
3096        });
3097    }
3098    out
3099}
3100
3101pub(crate) fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3102    let width = factor.nrows().min(factor.ncols());
3103    let mut out = None;
3104    for idx in 0..width {
3105        let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3106        out = Some(match out {
3107            Some(current) => f64::max(current, pivot),
3108            None => pivot,
3109        });
3110    }
3111    out
3112}
3113
3114/// Smallest cached Cholesky pivot for row blocks and the dense Schur factor.
3115///
3116/// Pivots are returned as squared lower-factor diagonals, matching the Hessian
3117/// scale rather than the Cholesky-factor scale. In inexact PCG mode the dense
3118/// Schur factor is absent, so `min_schur_pivot` is `None`.
3119pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
3120    let mut min_row_pivot = None;
3121    for factor in cache.htt_factors.iter() {
3122        if let Some(pivot) = lower_cholesky_min_pivot(factor) {
3123            min_row_pivot = Some(match min_row_pivot {
3124                Some(current) => f64::min(current, pivot),
3125                None => pivot,
3126            });
3127        }
3128    }
3129    let min_schur_pivot = cache
3130        .schur_factor
3131        .as_ref()
3132        .and_then(|factor| lower_cholesky_min_pivot(factor.view()));
3133    ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
3134}
3135
3136/// Largest cached Cholesky pivot across the row blocks and the dense Schur
3137/// factor (Hessian scale, i.e. squared lower-factor diagonal). This is the
3138/// diagonal magnitude scale a safe-SPD pivot floor is measured against: the
3139/// curvature-homotopy tracker (#1007) compares the min pivot against
3140/// `√eps · max(this, 1)`, the same floor the inner solver's
3141/// [`safe_spd_pivot_min`] uses. `None` only for an empty cache.
3142pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
3143    let mut max_pivot: Option<f64> = None;
3144    for factor in cache.htt_factors.iter() {
3145        if let Some(pivot) = lower_cholesky_max_pivot(factor) {
3146            max_pivot = Some(match max_pivot {
3147                Some(current) => f64::max(current, pivot),
3148                None => pivot,
3149            });
3150        }
3151    }
3152    if let Some(factor) = cache.schur_factor.as_ref()
3153        && let Some(pivot) = lower_cholesky_max_pivot(factor.view())
3154    {
3155        max_pivot = Some(match max_pivot {
3156            Some(current) => f64::max(current, pivot),
3157            None => pivot,
3158        });
3159    }
3160    max_pivot
3161}
3162
3163impl ArrowFactorCache {
3164    pub fn n_rows(&self) -> usize {
3165        self.htt_factors.len()
3166    }
3167
3168    pub fn htbeta_available(&self) -> bool {
3169        self.htbeta.is_available()
3170    }
3171
3172    /// Whether the Newton solve that produced this cache actually executed on
3173    /// the device: the device-resident Direct dense solve or the device-resident
3174    /// matrix-free SAE PCG (whose matvec runs in CUDA kernels). This does NOT
3175    /// include the injected host-procedural reduced-Schur matvec, whose
3176    /// arithmetic runs on the CPU even when a CUDA context was opened to build
3177    /// per-row factors (#1209) — that path sets
3178    /// `PcgDiagnostics::injected_host_procedural_matvec` instead. Read-only
3179    /// routing provenance: lets a fit result record device-vs-CPU as ground
3180    /// truth instead of inferring it from the runtime probe. Mirrors
3181    /// `PcgDiagnostics::used_device_arrow`.
3182    #[must_use]
3183    pub fn used_device(&self) -> bool {
3184        self.pcg_diagnostics.used_device_arrow
3185    }
3186
3187    pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
3188        match &self.htt_factors_undamped {
3189            ArrowUndampedFactors::SameAsDamped => self.htt_factors.factor(row),
3190            ArrowUndampedFactors::Owned(factors) => factors.factor(row),
3191        }
3192    }
3193
3194    pub fn undamped_factor_count(&self) -> usize {
3195        match &self.htt_factors_undamped {
3196            ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
3197            ArrowUndampedFactors::Owned(factors) => factors.len(),
3198        }
3199    }
3200
3201    pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
3202        (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
3203    }
3204
3205    pub fn compute_undamped_arrow_log_det(&self) -> Option<f64> {
3206        if self.ridge_t != 0.0 || self.ridge_beta != 0.0 {
3207            return None;
3208        }
3209        // When the shared β block is empty (`k == 0`) the joint Hessian is
3210        // exactly the block diagonal of the per-row latent blocks: there is no
3211        // reduced Schur complement to form, so the dense Direct path leaves
3212        // `schur_factor = None` legitimately (not the InexactPCG "never formed
3213        // the dense K×K factor" case, which has `k > 0`). The log-det is then
3214        // the per-row sum with a zero (empty `0×0`) Schur contribution. Without
3215        // this the `schur_factor.as_ref()?` below would return `None` for a
3216        // β-profiled atom (#1132 euclidean K=4) and starve the REML Laplace
3217        // normaliser of the joint Hessian log-det it requires.
3218        let schur = match self.schur_factor.as_ref() {
3219            Some(schur) => Some(schur),
3220            None if self.k == 0 => None,
3221            None => return None,
3222        };
3223
3224        let mut acc = 0.0_f64;
3225        for l in self.undamped_factors_iter() {
3226            for i in 0..l.nrows() {
3227                let d = l[[i, i]];
3228                if d <= 0.0 || !d.is_finite() {
3229                    return None;
3230                }
3231                acc += 2.0 * d.ln();
3232            }
3233        }
3234        if let Some(schur) = schur {
3235            for i in 0..schur.nrows() {
3236                let d = schur[[i, i]];
3237                if d <= 0.0 || !d.is_finite() {
3238                    return None;
3239                }
3240                acc += 2.0 * d.ln();
3241            }
3242        }
3243        let woodbury_correction = self.cross_row_woodbury_log_det();
3244        if !woodbury_correction.is_finite() {
3245            return None;
3246        }
3247        Some(acc + woodbury_correction)
3248    }
3249
3250    /// The total length of `delta_t` / IFT output vectors for this cache.
3251    pub fn delta_t_len(&self) -> usize {
3252        self.row_offsets[self.n_rows()]
3253    }
3254
3255    pub fn apply_htbeta_row(
3256        &self,
3257        row: usize,
3258        delta_beta: ArrayView1<'_, f64>,
3259        out: &mut Array1<f64>,
3260    ) -> bool {
3261        let di = if row < self.row_dims.len() {
3262            self.row_dims[row]
3263        } else {
3264            self.d
3265        };
3266        if out.len() != di || delta_beta.len() != self.k {
3267            return false;
3268        }
3269        self.htbeta.apply_row(row, delta_beta, out)
3270    }
3271
3272    /// Accumulate `out[a] += H_βt^(row)[a, :] · v` for all `a in 0..k`.
3273    ///
3274    /// `v` has length `row_dims[row]`; `out` has length `k`. The caller must
3275    /// zero `out` before the first call if it needs a fresh result.  Returns
3276    /// `false` when the cache is `Disabled` and no `fallback_op` is provided;
3277    /// callers must treat the accumulator as invalid in that case.
3278    pub fn apply_htbeta_row_transpose(
3279        &self,
3280        row: usize,
3281        v: ArrayView1<'_, f64>,
3282        out: &mut Array1<f64>,
3283        fallback_op: Option<&RowHtbetaMatvec>,
3284    ) -> bool {
3285        let di = if row < self.row_dims.len() {
3286            self.row_dims[row]
3287        } else {
3288            self.d
3289        };
3290        if v.len() != di || out.len() != self.k {
3291            return false;
3292        }
3293        self.htbeta
3294            .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
3295    }
3296
3297    /// Arrow log-determinant
3298    /// `log|H| = Σ_i log|H_{t_i t_i}| + log|Schur_β|`
3299    /// using the cached (damped) factors.
3300    ///
3301    /// Returns `(log_det_tt_sum, log_det_schur)` so the caller can decide
3302    /// what to do with the Schur piece (e.g. REML evidence wants both;
3303    /// some diagnostics want only the per-row sum). `None` for the Schur
3304    /// piece signals that the cache was produced by an InexactPCG solve
3305    /// and never formed/factored the dense `K × K` reduced system.
3306    ///
3307    /// The log-determinant of a Cholesky factor `L` of `M` is
3308    /// `2 Σ log L_ii`.
3309    pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
3310        let mut log_det_tt = 0.0_f64;
3311        for l in self.htt_factors.iter() {
3312            for i in 0..l.nrows() {
3313                log_det_tt += l[[i, i]].ln();
3314            }
3315        }
3316        log_det_tt *= 2.0;
3317        let log_det_schur = self.schur_factor.as_ref().map(|l| {
3318            let mut s = 0.0_f64;
3319            for i in 0..l.nrows() {
3320                s += l[[i, i]].ln();
3321            }
3322            2.0 * s + self.cross_row_woodbury_log_det()
3323        });
3324        (log_det_tt, log_det_schur)
3325    }
3326
3327    /// The exact cross-row IBP correction `log det(I_R + D·M)` to add to the
3328    /// base `log det H₀'` (#1038). Zero when no [`CrossRowWoodbury`] is present,
3329    /// so non-IBP caches are unaffected. The determinant lemma gives
3330    /// `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`; this is the
3331    /// second term, the only piece beyond the bare arrow log-determinant.
3332    ///
3333    /// Panics-free: a negative capacitance determinant (non-PD `H_full`) yields
3334    /// `NaN` here so the evidence surfaces the desync rather than silently
3335    /// dropping the imaginary part. Callers that must reject it should check
3336    /// [`CrossRowWoodbury::log_det`] directly.
3337    pub fn cross_row_woodbury_log_det(&self) -> f64 {
3338        match self.cross_row_woodbury.as_ref() {
3339            Some(w) => w.log_det().unwrap_or(f64::NAN),
3340            None => 0.0,
3341        }
3342    }
3343
3344    /// Diagonal of the latent (`t`-block) of the *full* bordered-arrow
3345    /// inverse `(H⁻¹)_tt`, in `delta_t` layout (length [`Self::delta_t_len`]).
3346    ///
3347    /// For the bordered arrow Hessian
3348    /// `H = [[A, B], [Bᵀ, H_ββ]]` with `A = H_tt` (block-diagonal per row,
3349    /// `A_i = H_tt^(i)`) and `B = H_tβ`, the standard block-inverse identity
3350    /// gives the `t`-block
3351    /// `(H⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹`, where
3352    /// `S = H_ββ − Bᵀ A⁻¹ B` is the Schur complement on `β`. Because `A` is
3353    /// block-diagonal, the `(i, j)` diagonal entry of `(H⁻¹)_tt` is computed
3354    /// purely from row `i`'s factor and cross-block:
3355    ///
3356    /// ```text
3357    /// a    = A_i⁻¹ e_j                       (chol_solve on the per-row factor)
3358    /// [A_i⁻¹]_{jj} = a[j]
3359    /// w    = B_iᵀ a = H_βt^(i) a             (a K-vector)
3360    /// z    = S⁻¹ w                           (chol_solve on the Schur factor)
3361    /// diag = a[j] + w · z
3362    /// ```
3363    ///
3364    /// The UNDAMPED per-row factors ([`Self::undamped_factor`]) are used so
3365    /// the result is the inverse of the *true* `H_tt`, not the LM-damped
3366    /// `H_tt + ridge_t·I` — same rationale the IFT predictor docstring gives
3367    /// at the top of this struct.
3368    ///
3369    /// # Consuming the diagonal as a per-(atom, axis) trace
3370    ///
3371    /// `(H⁻¹)_tt` is the latent covariance block. The selected-inverse trace
3372    /// for a contiguous group of latent coordinates (e.g. one atom's rows, or
3373    /// one axis across rows) is simply the sum of the returned diagonal entries
3374    /// over those `row_offsets[i] + j` indices — no off-diagonal terms are
3375    /// needed for the trace `tr[(H⁻¹)_tt · D]` against a diagonal selector `D`.
3376    ///
3377    /// # Errors
3378    ///
3379    /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3380    /// dense Schur factor or no usable `H_βt` coupling — i.e. it was produced
3381    /// by an [`ArrowSolverMode::InexactPCG`] solve (no dense `K × K` factor) or
3382    /// by a `Disabled` `htbeta` cache. The selected-inverse block-trace is not
3383    /// yet supported for the matrix-free PCG mode; that branch needs a separate
3384    /// Lanczos/Hutchinson estimator.
3385    pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
3386        let Some(schur_factor) = self.schur_factor.as_ref() else {
3387            return Err(ArrowSchurError::SchurFactorFailed {
3388                reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
3389                         the InexactPCG mode does not form one"
3390                    .to_string(),
3391            });
3392        };
3393        if !self.htbeta_available() {
3394            return Err(ArrowSchurError::SchurFactorFailed {
3395                reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
3396                         but this cache's htbeta is Disabled"
3397                    .to_string(),
3398            });
3399        }
3400        let n = self.undamped_factor_count();
3401        let total_len = self.delta_t_len();
3402        let mut out = Array1::<f64>::zeros(total_len);
3403        // Per-row scratch, sized to the max latent dim / K.
3404        let mut e_j = Array1::<f64>::zeros(self.d);
3405        let mut w = Array1::<f64>::zeros(self.k);
3406        for i in 0..n {
3407            let di = self.row_dims[i];
3408            let row_base = self.row_offsets[i];
3409            let factor = self.undamped_factor(i);
3410            for j in 0..di {
3411                // a = A_i⁻¹ e_j.
3412                for c in 0..di {
3413                    e_j[c] = 0.0;
3414                }
3415                e_j[j] = 1.0;
3416                let e_j_slice = e_j.slice(ndarray::s![..di]).to_owned();
3417                let a = cholesky_solve_vector(factor, &e_j_slice);
3418                // w = H_βt^(i) a (a K-vector); accumulator must start zeroed.
3419                w.fill(0.0);
3420                if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
3421                    return Err(ArrowSchurError::SchurFactorFailed {
3422                        reason: format!(
3423                            "latent_block_inverse_diagonal: H_βt^({i}) apply failed \
3424                             (htbeta cache could not supply row {i})"
3425                        ),
3426                    });
3427                }
3428                // z = S⁻¹ w; correction = w · z.
3429                let z = cholesky_solve_vector(schur_factor, &w);
3430                let mut corr = 0.0_f64;
3431                for c in 0..self.k {
3432                    corr += w[c] * z[c];
3433                }
3434                out[row_base + j] = a[j] + corr;
3435            }
3436        }
3437        if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3438            // #1038: the factors above are `H₀'`, so `out` is diag((H₀'⁻¹)_tt).
3439            // The full inverse diagonal subtracts the rank-`R` Woodbury term
3440            // diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹). With `G = h0inv_u = (H₀'⁻¹U)_t` and
3441            // (by symmetry of `H₀'⁻¹`) `(Uᵀ H₀'⁻¹)_t = Gᵀ`, the diagonal entry at
3442            // global index `g` is `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`. Form the
3443            // `R×R` matrix `C⁻¹D` once (R solves), then contract per row index.
3444            woodbury.subtract_inverse_diagonal(&mut out)?;
3445        }
3446        Ok(out)
3447    }
3448
3449    /// Solve the full bordered-arrow system `H·u = w` on the cached factor
3450    /// (#1006): `w` arrives in arrow layout — `w_t` flat per
3451    /// [`Self::delta_t_len`] / `row_offsets`, `w_beta` of length `K` — and the
3452    /// solution comes back in the same layout. Standard block elimination on
3453    /// the SAME factors whose log-determinant the evidence reports:
3454    ///
3455    /// ```text
3456    ///   y_i      = H_tt^(i)⁻¹ · w_t^(i)
3457    ///   r_β      = w_β − Σ_i H_βt^(i) · y_i
3458    ///   u_β      = Schur⁻¹ · r_β
3459    ///   u_t^(i)  = y_i − H_tt^(i)⁻¹ · (H_tβ^(i) · u_β)
3460    /// ```
3461    ///
3462    /// This is the IFT / adjoint back-solve the analytic outer ρ-gradient
3463    /// consumes: `u_j = H⁻¹ (∂g/∂ρ_j)` per outer coordinate and the
3464    /// `H⁻¹`-side of the third-order correction `−½·Γᵀ·H⁻¹·(∂g/∂ρ_j)`.
3465    /// Contract: the cache must be the ridge-0 Direct evidence factor
3466    /// (undamped per-row factors + dense Schur), so the solve is against the
3467    /// criterion's own `H` — never a damped surrogate (that would desync the
3468    /// gradient from the reported evidence).
3469    ///
3470    /// When the cache carries an exact cross-row IBP
3471    /// [`CrossRowWoodbury`] (#1038), the per-row factors are the NO-SELF base
3472    /// `H₀'` and this method layers the rank-`R` Woodbury correction so the
3473    /// returned solve is against the FULL `H_full = H₀' + U D Uᵀ` — the same
3474    /// operator whose log-determinant [`Self::arrow_log_det`] reports. The
3475    /// θ/ρ-adjoint that consumes this therefore sees the cross-row curvature.
3476    pub fn full_inverse_apply(
3477        &self,
3478        w_t: ArrayView1<'_, f64>,
3479        w_beta: ArrayView1<'_, f64>,
3480    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3481        let (mut u_t, mut u_beta) = self.full_inverse_apply_base(w_t, w_beta)?;
3482        if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3483            // u ← u − H₀'⁻¹U C⁻¹ D Uᵀ u. `u_t` is the `t`-block of `H₀'⁻¹ w`.
3484            let h0inv_w_t = u_t.clone();
3485            woodbury.apply_inverse_correction(
3486                h0inv_w_t.view(),
3487                woodbury.source_entries(),
3488                &mut u_t,
3489                &mut u_beta,
3490            )?;
3491        }
3492        Ok((u_t, u_beta))
3493    }
3494
3495    /// Bare bordered-arrow inverse solve against the cached per-row factors and
3496    /// Schur factor (the NO-SELF base `H₀'` when a cross-row Woodbury is
3497    /// present). [`Self::full_inverse_apply`] wraps this with the rank-`R`
3498    /// correction; [`CrossRowWoodbury::build`] calls this directly (before the
3499    /// carrier exists) to form `H₀'⁻¹U`.
3500    pub(crate) fn full_inverse_apply_base(
3501        &self,
3502        w_t: ArrayView1<'_, f64>,
3503        w_beta: ArrayView1<'_, f64>,
3504    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3505        let total_len = self.delta_t_len();
3506        if w_t.len() != total_len || w_beta.len() != self.k {
3507            return Err(ArrowSchurError::SchurFactorFailed {
3508                reason: format!(
3509                    "full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
3510                    w_t.len(),
3511                    w_beta.len(),
3512                    total_len,
3513                    self.k
3514                ),
3515            });
3516        }
3517        let n = self.undamped_factor_count();
3518        // Forward pass: y_i = H_tt^(i)⁻¹ w_t^(i), accumulating the border RHS.
3519        let mut y = Array1::<f64>::zeros(total_len);
3520        let mut r_beta = w_beta.to_owned();
3521        for i in 0..n {
3522            let di = self.row_dims[i];
3523            let base = self.row_offsets[i];
3524            let factor = self.undamped_factor(i);
3525            let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
3526            let y_row = cholesky_solve_vector(factor, &w_row);
3527            if self.k > 0 {
3528                // r_β −= H_βt^(i) y_i: accumulate into a scratch then subtract,
3529                // because the helper ACCUMULATES (+=) into its output.
3530                let mut acc = Array1::<f64>::zeros(self.k);
3531                if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
3532                    return Err(ArrowSchurError::SchurFactorFailed {
3533                        reason: format!(
3534                            "full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
3535                             could not supply row {i}; htbeta={:?}, di={}, k={})",
3536                            self.htbeta,
3537                            self.row_dims.get(i).copied().unwrap_or(self.d),
3538                            self.k
3539                        ),
3540                    });
3541                }
3542                for c in 0..self.k {
3543                    r_beta[c] -= acc[c];
3544                }
3545            }
3546            for j in 0..di {
3547                y[base + j] = y_row[j];
3548            }
3549        }
3550        // Border solve + back-substitution.
3551        let u_beta = if self.k > 0 {
3552            self.schur_inverse_apply(r_beta.view())?
3553        } else {
3554            Array1::<f64>::zeros(0)
3555        };
3556        let mut u_t = y;
3557        if self.k > 0 {
3558            let mut cross = Array1::<f64>::zeros(self.d);
3559            for i in 0..n {
3560                let di = self.row_dims[i];
3561                let base = self.row_offsets[i];
3562                let mut cross_row = cross.slice_mut(ndarray::s![..di]);
3563                cross_row.fill(0.0);
3564                let mut cross_owned = cross_row.to_owned();
3565                if !self.apply_htbeta_row(i, u_beta.view(), &mut cross_owned) {
3566                    return Err(ArrowSchurError::SchurFactorFailed {
3567                        reason: format!(
3568                            "full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
3569                             could not supply row {i})"
3570                        ),
3571                    });
3572                }
3573                let factor = self.undamped_factor(i);
3574                let corr = cholesky_solve_vector(factor, &cross_owned);
3575                for j in 0..di {
3576                    u_t[base + j] -= corr[j];
3577                }
3578            }
3579        }
3580        Ok((u_t, u_beta))
3581    }
3582
3583    /// Apply the β-block of the full inverse, `(H⁻¹)_ββ · rhs = S_β⁻¹ · rhs`,
3584    /// where `S_β` is the Schur complement on β whose Cholesky factor this
3585    /// cache holds in [`Self::schur_factor`].
3586    ///
3587    /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the
3588    /// β-block of `H⁻¹` is exactly the inverse of the Schur complement
3589    /// `S_β = H_ββ − Bᵀ A⁻¹ B`. One Cholesky back-substitution per call,
3590    /// reusing the cached factor; `rhs` and the returned vector both have
3591    /// length `K`.
3592    ///
3593    /// This is the general single-solve primitive for the β border. Callers
3594    /// that need a Schur-inverse trace `tr(S_β⁻¹ M)` against a structured
3595    /// penalty `M` (e.g. the SAE λ_smooth Fellner-Schall step, where
3596    /// `M = blockdiag_k(λ_k S_k ⊗ I_p)`) build it as
3597    /// `Σ_col e_colᵀ S_β⁻¹ M e_col` — apply this to each column of `M`
3598    /// (exploiting whatever sparsity `M` has) and read off `result[col]`.
3599    /// Keeping `M`'s layout on the caller side avoids coupling this solver
3600    /// to penalty-op types.
3601    ///
3602    /// # Errors
3603    ///
3604    /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3605    /// dense Schur factor (an [`ArrowSolverMode::InexactPCG`] solve) — the
3606    /// same not-yet-supported branch as [`Self::latent_block_inverse_diagonal`]
3607    /// — or when `rhs.len() != k`.
3608    ///
3609    /// Cross-row IBP (#1038) note: this is the β-block primitive of the
3610    /// factored base `S_β` (`H₀'` when a [`CrossRowWoodbury`] is present), used
3611    /// internally by [`Self::full_inverse_apply_base`]; it is deliberately NOT
3612    /// Woodbury-corrected so the base solve stays bare. The cross-row term has
3613    /// no `β` support, so `(H_full⁻¹)_ββ = S_β⁻¹` exactly on the directions any
3614    /// IBP ρ-trace contracts. A consumer needing the full `(H_full⁻¹)_ββ` for a
3615    /// β-supported direction should call [`Self::full_inverse_apply`] with a
3616    /// unit `β`-RHS (which applies the rank-`R` correction).
3617    pub fn schur_inverse_apply(
3618        &self,
3619        rhs: ArrayView1<'_, f64>,
3620    ) -> Result<Array1<f64>, ArrowSchurError> {
3621        let Some(schur_factor) = self.schur_factor.as_ref() else {
3622            return Err(ArrowSchurError::SchurFactorFailed {
3623                reason: "schur_inverse_apply requires a dense Schur factor; \
3624                         the InexactPCG mode does not form one"
3625                    .to_string(),
3626            });
3627        };
3628        if rhs.len() != self.k {
3629            return Err(ArrowSchurError::SchurFactorFailed {
3630                reason: format!(
3631                    "schur_inverse_apply: rhs length {} != K {}",
3632                    rhs.len(),
3633                    self.k
3634                ),
3635            });
3636        }
3637        let rhs_owned = rhs.to_owned();
3638        Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
3639    }
3640
3641    /// Dense principal sub-block of the β-block of the full inverse,
3642    /// `(H⁻¹)_ββ[block, block] = S_β⁻¹[block, block]`, shape `(W, W)` with
3643    /// `W = block.len()`.
3644    ///
3645    /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the β-block
3646    /// of `H⁻¹` is exactly `S_β⁻¹` (the inverse of the Schur complement whose
3647    /// Cholesky factor this cache holds). This returns the contiguous
3648    /// `block × block` sub-block — e.g. one SAE atom's decoder coefficients via
3649    /// [`gam_terms::sae::manifold::SaeManifoldTerm::beta_block_offsets`] — by
3650    /// solving `S_β x = e_j` for each `j ∈ block` (reusing the cached factor)
3651    /// and gathering the `block` rows of each solution column. `W`
3652    /// back-substitutions of size `K`; the result is symmetrized to clear
3653    /// back-substitution rounding asymmetry. Up to a dispersion scale `φ`, this
3654    /// block is the joint posterior covariance `Cov(β_block)` of those
3655    /// coefficients with the latent coordinates already marginalized out (that
3656    /// is precisely what Schur-eliminating the per-row `t`-blocks does).
3657    ///
3658    /// Same dense-Schur requirement / error contract as
3659    /// [`Self::schur_inverse_apply`]; additionally errors when `block` runs past
3660    /// `K`.
3661    pub fn schur_inverse_block(
3662        &self,
3663        block: std::ops::Range<usize>,
3664    ) -> Result<Array2<f64>, ArrowSchurError> {
3665        let Some(schur_factor) = self.schur_factor.as_ref() else {
3666            return Err(ArrowSchurError::SchurFactorFailed {
3667                reason: "schur_inverse_block requires a dense Schur factor; \
3668                         the InexactPCG mode does not form one"
3669                    .to_string(),
3670            });
3671        };
3672        if block.end > self.k {
3673            return Err(ArrowSchurError::SchurFactorFailed {
3674                reason: format!(
3675                    "schur_inverse_block: block end {} exceeds K {}",
3676                    block.end, self.k
3677                ),
3678            });
3679        }
3680        let w = block.len();
3681        let mut out = Array2::<f64>::zeros((w, w));
3682        let mut e_j = Array1::<f64>::zeros(self.k);
3683        for (jc, j) in block.clone().enumerate() {
3684            e_j.fill(0.0);
3685            e_j[j] = 1.0;
3686            let col = cholesky_solve_vector(schur_factor, &e_j);
3687            for (ic, i) in block.clone().enumerate() {
3688                out[[ic, jc]] = col[i];
3689            }
3690        }
3691        // S_β⁻¹ is symmetric; symmetrize to clear back-substitution rounding.
3692        for ic in 0..w {
3693            for jc in (ic + 1)..w {
3694                let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
3695                out[[ic, jc]] = avg;
3696                out[[jc, ic]] = avg;
3697            }
3698        }
3699        Ok(out)
3700    }
3701}