Skip to main content

gam_solve/arrow_schur/
system.rs

1//! The bordered arrow-Schur system itself: [`ArrowRowBlock`], the
2//! [`ArrowSchurSystem`] container and its assembly impl, cross-row latent
3//! penalties, the streaming builder, and the per-row factor caches.
4
5use super::*;
6
/// Per-row block data for the arrow-Schur system.
///
/// `htt` holds the `d × d` Gauss–Newton block for row `i` (including any
/// analytic-penalty contributions on that row); `htbeta` holds the
/// `d × K` cross-block `H_tβ^(i)`; `gt` is the `d`-length latent
/// gradient for row `i`.
///
/// Here `d` is this row's own latent dimension, which may differ between rows
/// of a heterogeneous system. The width of the `htbeta` slab is fixed at
/// allocation time: [`ArrowRowBlock::new`] uses `K`, while
/// [`ArrowRowBlock::new_with_htbeta_cols`] lets the caller choose it.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
    /// `H_tt^(i)`, shape `(d, d)`.
    pub htt: Array2<f64>,
    /// `H_tβ^(i)`, shape `(d, K)` when allocated by [`ArrowRowBlock::new`];
    /// `(d, htbeta_cols)` when allocated by
    /// [`ArrowRowBlock::new_with_htbeta_cols`].
    pub htbeta: Array2<f64>,
    /// `g_t^(i)`, shape `(d,)`.
    pub gt: Array1<f64>,
}
22
23impl ArrowRowBlock {
24    /// Allocate one BA point-block row: local latent Hessian, point-camera
25    /// cross block, and point gradient.
26    pub fn new(d: usize, k: usize) -> Self {
27        Self::new_with_htbeta_cols(d, k)
28    }
29
30    /// Allocate one BA row whose dense cross-block slab has `htbeta_cols`
31    /// columns. This is used by matrix-free assemblers that keep the shared
32    /// beta tier at one width while dense row supplements live in another
33    /// coordinate system.
34    pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
35        Self {
36            htt: Array2::<f64>::zeros((d, d)),
37            htbeta: Array2::<f64>::zeros((d, htbeta_cols)),
38            gt: Array1::<f64>::zeros(d),
39        }
40    }
41}
42
/// Bordered (t, β) Newton system with arrow structure.
///
/// The β-block is held as a dense `K × K` Hessian `H_ββ` plus a `K`-length
/// gradient `g_β` for direct BA modes. Large-scale inexact BA callers may
/// additionally install a matrix-free `H_ββ x` operator and diagonal via
/// [`ArrowSchurSystem::set_shared_beta_operator`]; the InexactPCG mode then
/// avoids dense Schur formation/factorization.
/// The t-block is a `Vec<ArrowRowBlock>` of length `N`. Rows may have
/// different latent dimensions; `row_dims` and `row_offsets` describe the
/// per-row layout.
///
/// Construction is the driver's responsibility: the driver
///
///   1. evaluates Φ(t) and the radial jet `∂Φ/∂t` (the latter via
///      [`gam_terms::latent::LatentCoordValues::design_gradient_wrt_t`]);
///   2. forms the working-weighted Gauss–Newton blocks
///      `H_tt^(i) += (g_i β)(g_i β)^T`, `H_tβ^(i) += (g_i β) ⊗ Φ_i`,
///      `H_ββ += Φ^T W Φ + Σ_k λ_k S_k`;
///   3. calls [`ArrowSchurSystem::add_analytic_penalty_contributions`] to
///      fold row-block Psi-tier analytic penalties (`ARDPenalty`,
///      `SparsityPenalty`) into `H_tt^(i)` and Beta-tier penalties into `H_ββ`;
///   4. calls [`ArrowSchurSystem::solve`] to obtain `(Δt, Δβ)`.
///
/// Cloning duplicates the dense blocks. Operators and layout tables stored
/// behind `Arc` are shared between the clone and the original.
pub struct ArrowSchurSystem {
    /// Per-row latent blocks, length `N`. Row `i` holds a `d_i × d_i` `H_tt`,
    /// a `d_i × htbeta_cols` `H_tβ` slab (`htbeta_cols == K` for the plain
    /// constructors), and a `d_i`-length gradient, where `d_i == row_dims[i]`.
    pub rows: Vec<ArrowRowBlock>,
    /// `H_ββ`, shape `(K, K)` for direct BA modes. It is `(0, 0)` when
    /// constructed by [`ArrowSchurSystem::new_matrix_free_shared`] for
    /// PCG-only use, or by one of the `*_empty_hbb_*` constructors.
    pub hbb: Array2<f64>,
    /// Optional matrix-free `H_ββ x` operator for large BA Schur PCG.
    ///
    /// Direct and Square-Root BA modes still require `hbb`; InexactPCG uses
    /// this operator when present, avoiding dense shared-block storage for
    /// SAE-manifold scale `K`. Installed together with `hbb_diag` (and a
    /// mirrored `penalty_op`) by [`Self::set_shared_beta_operator`] and
    /// [`Self::new_matrix_free_shared`].
    pub hbb_matvec: Option<SharedBetaMatvec>,
    /// Optional row-local matrix-free multiply for `H_tβ^(i) x`.
    ///
    /// When present, all inner-Schur paths route through this operator instead
    /// of indexing the per-row `htbeta` dense slabs: `reduced_rhs_beta`,
    /// `schur_matvec` (PCG hot loop), back-substitution,
    /// `JacobiPreconditioner` construction, `build_dense_schur_direct`, and
    /// `build_dense_schur_sqrt_ba` all call `sys_htbeta_apply_row` or
    /// `sys_htbeta_materialize_row`. Factor caches also retain the operator
    /// for IFT/evidence consumers. Installed by
    /// [`Self::set_row_htbeta_operator`].
    pub htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Optional row-local matrix-free transpose multiply `out += H_βt^(i) · v`.
    ///
    /// The sparse adjoint of [`Self::htbeta_matvec`]. When present, the
    /// reduced-Schur matvec applies `H_βt^(i)` directly (sparse `scatter`)
    /// instead of probing the forward operator against `K` basis vectors. This
    /// is the per-row sparse apply that lifts the `O(K)` column-probe in the
    /// GPU PCG and streaming Schur paths to `O(m_i · p)` per row. Installed in
    /// lock-step with `htbeta_matvec` by [`Self::set_row_htbeta_operator`].
    pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// Whether `rows[*].htbeta` contains a dense contribution that must be added
    /// on top of the matrix-free row operator. `false` on construction; set by
    /// [`Self::activate_dense_htbeta_supplement`].
    pub htbeta_dense_supplement: bool,
    /// Optional diagonal of the matrix-free shared block, used by the
    /// Schur-Jacobi preconditioner in the Agarwal-style PCG path. Installed
    /// together with `hbb_matvec`.
    pub hbb_diag: Option<Array1<f64>>,
    /// `g_β`, shape `(K,)`.
    pub gb: Array1<f64>,
    /// Maximum per-row latent dimensionality across all rows.
    ///
    /// For homogeneous systems (all rows have the same dim) this equals the
    /// common per-row `d`.  For heterogeneous systems (e.g. sparse SAE rows
    /// where JumpReLU / TopK / sparsemax active sets vary per observation)
    /// this is `max_i row_dims[i]` (`0` when there are no rows).  Per-row code
    /// should use `row.htt.nrows()` or `row_dims[i]`; `d` is an upper bound
    /// for scratch-buffer sizing.
    pub d: usize,
    /// Per-row latent dimensionality: `row_dims[i] == rows[i].htt.nrows()`.
    ///
    /// For homogeneous systems `row_dims[i] == d` for all `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for the `delta_t` vector produced by
    /// [`Self::solve`] / [`solve_arrow_newton_step_core`].
    ///
    /// Holds `N + 1` entries: `row_offsets[i]` is the start index for row
    /// `i`'s slice in `delta_t`, and `row_offsets[n]` is the total `delta_t`
    /// length.  For homogeneous systems `row_offsets[i] == i * d`.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local latent blocks after optional
    /// Riemannian projection. Euclidean/no-op geometry uses the sentinel.
    pub manifold_mode_fingerprint: u64,
    /// Structural/value tag for row-local Hessian factors and their Schur
    /// inputs. Stale caches must be rejected when row-dependent Hessian
    /// penalties or cross-blocks change. Constructors set it to `0`;
    /// [`Self::refresh_row_hessian_fingerprint`] stores the current value.
    pub row_hessian_fingerprint: u64,
    /// Registry-side tag for row-dependent analytic-penalty Hessian inputs.
    /// Combined with the materialized row blocks in
    /// [`Self::current_row_hessian_fingerprint`].
    pub analytic_row_hessian_fingerprint: u64,
    /// Term-block column ranges for the block-Jacobi Schur preconditioner.
    ///
    /// Each entry `r` means that indices `r.start..r.end` belong to one
    /// coefficient block (a GAM term or a custom parameter family from
    /// `ParameterBlockSpec`). When populated via
    /// [`Self::set_block_offsets`], the Jacobi preconditioner inverts the
    /// full `b × b` Schur block for each term instead of only its diagonal.
    ///
    /// The default (empty slice) causes `JacobiPreconditioner` to fall back
    /// to pure scalar diagonal inversion, preserving the pre-#283 behaviour.
    pub block_offsets: Arc<[Range<usize>]>,
    /// Optional matrix-free penalty-side `H_ββ` operator (#296).
    ///
    /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
    /// `JacobiPreconditioner`, quadratic-form reduction) route through this
    /// operator instead of the dense `hbb` accumulator, enabling
    /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to skip the `O(K²)` dense
    /// materialisation for structured smoothness penalties.
    ///
    /// When `None`, those paths fall back to wrapping `hbb` in a transient
    /// `DensePenaltyOp`. This gives identical observable behaviour and no new
    /// hot-path allocation for callers that have not opted in.
    /// [`Self::new_matrix_free_shared`] and
    /// [`Self::set_shared_beta_operator`] install a `MatvecDiagPenaltyOp`
    /// here that mirrors `hbb_matvec` / `hbb_diag`.
    pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
    /// Device-uploadable SAE Kronecker data for CUDA-resident reduced PCG.
    ///
    /// The generic matrix-free closures remain the authoritative CPU path. This
    /// descriptor is installed only when SAE assembly has a matching CUDA sparse
    /// representation for both `H_tβ` and `H_ββ`.
    pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
    /// Registered Psi-tier analytic penalties whose Hessian couples *distinct*
    /// latent rows (non-row-block-diagonal), captured by
    /// [`Self::add_analytic_penalty_contributions`].
    ///
    /// These penalties (`TotalVariationPenalty`, `SheafConsistencyPenalty`,
    /// block-orthogonality, …) produce off-row Hessian blocks `∂²P/∂t_i∂t_j`
    /// (`i ≠ j`). The arrow elimination assumes each `H_tt^(i)` is
    /// independent of every other row, so it cannot represent them. Their
    /// *gradient* is still folded into `g_t` exactly like every other Psi
    /// penalty; only their curvature is held here. It is applied during the
    /// solve as a full-latent Hessian-vector product `P_cross · Δt` against
    /// the penalty's `psd_majorizer_hvp`.
    ///
    /// When this vector is non-empty,
    /// [`solve_arrow_newton_step_artifacts`] auto-selects the matrix-free
    /// full-system PCG path (arrow block-diagonal inverse as preconditioner)
    /// instead of the exact one-shot Schur elimination. When empty, the system
    /// is purely row-block-diagonal and the exact Schur path is unchanged.
    pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
    /// Optional row-local gauge directions for evidence-only Faddeev-Popov
    /// deflation of an otherwise non-PD `H_tt` row block.
    ///
    /// These vectors live in each row's actual chart block, so compact SAE rows
    /// and dense rows share the same factorization path. Ordinary Newton solves
    /// ignore them; only undamped evidence factors with
    /// `tolerate_ill_conditioning` set may stiffen a gauge-explained row
    /// direction. Installed by [`Self::set_row_gauge_deflation`].
    pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
    /// Optional exact cross-row IBP low-rank source (#1038). When set, the
    /// factorization downdates the per-row logit-slot self term and layers the
    /// exact rank-`R` Woodbury correction onto the evidence cache (value,
    /// log-determinant, and θ/ρ-adjoint together). `None` for all non-IBP
    /// systems — the row-block-diagonal arrow path is then unchanged.
    /// Installed by [`Self::set_ibp_cross_row_source`], which stores `None`
    /// for an empty source.
    pub ibp_cross_row: Option<IbpCrossRowSource>,
}
197
198impl Clone for ArrowSchurSystem {
199    fn clone(&self) -> Self {
200        Self {
201            rows: self.rows.clone(),
202            hbb: self.hbb.clone(),
203            hbb_matvec: self.hbb_matvec.clone(),
204            htbeta_matvec: self.htbeta_matvec.clone(),
205            htbeta_transpose_matvec: self.htbeta_transpose_matvec.clone(),
206            htbeta_dense_supplement: self.htbeta_dense_supplement,
207            hbb_diag: self.hbb_diag.clone(),
208            gb: self.gb.clone(),
209            d: self.d,
210            row_dims: Arc::clone(&self.row_dims),
211            row_offsets: Arc::clone(&self.row_offsets),
212            k: self.k,
213            manifold_mode_fingerprint: self.manifold_mode_fingerprint,
214            row_hessian_fingerprint: self.row_hessian_fingerprint,
215            analytic_row_hessian_fingerprint: self.analytic_row_hessian_fingerprint,
216            block_offsets: Arc::clone(&self.block_offsets),
217            penalty_op: self.penalty_op.clone(),
218            device_sae_pcg: self.device_sae_pcg.clone(),
219            cross_row_penalties: self.cross_row_penalties.clone(),
220            row_gauge_deflation: self.row_gauge_deflation.clone(),
221            ibp_cross_row: self.ibp_cross_row.clone(),
222        }
223    }
224}
225
/// A captured cross-row Psi-tier analytic penalty: the penalty kind plus the
/// global-ρ slice (`rho_local`) it was registered with.
///
/// Holds an owned copy of the local ρ-axes so the penalty's
/// [`AnalyticPenaltyKind::psd_majorizer_hvp`] can be evaluated during the
/// matrix-free full-system solve without re-deriving the ρ layout. The penalty
/// itself is an `Arc`-backed clone (cheap), so capturing it does not copy the
/// penalty payload.
#[derive(Clone)]
pub struct CrossRowLatentPenalty {
    /// The non-row-block-diagonal Psi penalty (e.g. `TotalVariationPenalty`).
    pub penalty: AnalyticPenaltyKind,
    /// The penalty's local ρ-axes (its slice of the global ρ vector).
    pub rho_local: Array1<f64>,
    /// The flat latent vector the penalty's curvature was linearized at —
    /// i.e. the `target_t` passed to
    /// [`ArrowSchurSystem::add_analytic_penalty_contributions`]. For
    /// homogeneous systems this is `N·d` long and row-major. When row
    /// dimensions differ it is presumably laid out by
    /// `ArrowSchurSystem::row_offsets` (NOTE(review): confirm against the
    /// assembler).
    ///
    /// The Hessian of a nonlinear penalty (the smoothed-TV curvature weights
    /// `φ''(D t)`, etc.) depends on this point, so `psd_majorizer_hvp` must be
    /// evaluated against it for the Newton operator to be the true Hessian at
    /// the current iterate.
    pub target_t: Array1<f64>,
}
249
/// Exact cross-row low-rank IBP source (#1038): the per-column rank-one Hessian
/// terms `H_(i,k),(j,k) = d_k·z'_ik·z'_jk` (for ALL `i,j`, including the `i=j`
/// self term) that couple DISTINCT latent rows through a shared atom column `k`.
///
/// Stacking over rows, this is `H_full = H₀' + U D Uᵀ`, where:
/// * `U` is `delta_t_len × R` with `U[g, k] = z'_ik` at the global latent index
///   `g` of row `i`'s logit slot for atom `k` (zero elsewhere) — i.e. column `k`
///   is supported on the atom-`k` logit slot of every row;
/// * `D = diag(d_k)`, `d_k = w·s'_k` ([`gam_terms::analytic_penalties::IbpHessianDiagThirdChannels::cross_row_d`]);
/// * `H₀'` is the assembled latent block-diagonal `H₀` with the per-row self
///   term `d_k·z'_ik²` REMOVED from each logit-slot diagonal. The assembled
///   `H₀` already carries that term, and the FULL rank-one outer product
///   `U D Uᵀ` re-adds the `i=j` diagonal, so it would be counted twice without
///   this downdate. The determinant lemma `log det(I_R + D UᵀH₀'⁻¹U)` is the
///   exact rank-`R` correction only against this no-self base.
///
/// The arrow elimination assumes each row's `H_tt^(i)` is independent of every
/// other row, so it structurally cannot hold this coupling block-locally. The
/// factorization owner (`solver::arrow_schur`) consumes this source to:
///
/// * (a) downdate the per-row logit diagonal before factoring;
/// * (b) build `U`/`D` onto the resulting [`ArrowFactorCache`] as a
///   [`CrossRowWoodbury`];
/// * (c) apply the exact Woodbury correction to the value/curvature solve, the
///   evidence log-determinant, and the θ/ρ-adjoint TOGETHER (they all describe
///   the SAME `H_full`).
///
/// Shape invariants relied on by `dense_u` / `self_term_downdate`:
/// `d.len() == r`, every entry's `atom_k < r`, and every
/// `global_t_index < delta_t_len`. Out-of-range indices panic there. Entries
/// should also be unique per `(global_t_index, atom_k)`. `dense_u` sums
/// duplicate `z` values into one `U` element, whereas `self_term_downdate`
/// sums their individual squares, so duplicates would make the downdate
/// disagree with `diag(U D Uᵀ)`. `Default` yields an empty source (`r == 0`).
#[derive(Clone, Debug, Default)]
pub struct IbpCrossRowSource {
    /// Number of atom columns `R` (the rank of the cross-row update).
    pub r: usize,
    /// `d_k = w·s'_k`, the scalar `D`-coefficient of column `k`. Length `R`.
    pub d: Array1<f64>,
    /// Per-row column entries `(global_t_index, atom_k, z'_ik)`: each tuple
    /// places `z'_ik` at `U[global_t_index, atom_k]`. The `global_t_index` is
    /// `row_offsets[i] + local_slot` for the row's logit slot of atom `k`. Only
    /// nonzero entries are listed (one per active (row, atom) pair).
    pub entries: Vec<(usize, usize, f64)>,
}
286
287impl IbpCrossRowSource {
288    /// Build the dense `delta_t_len × R` factor `U` (each column supported on
289    /// its atom's per-row logit slots) from the sparse entry list.
290    pub(crate) fn dense_u(&self, delta_t_len: usize) -> Array2<f64> {
291        let mut u = Array2::<f64>::zeros((delta_t_len, self.r));
292        for &(g, k, z) in &self.entries {
293            u[[g, k]] += z;
294        }
295        u
296    }
297
298    /// Per-row-slot self-term downdate: returns, for each global latent index,
299    /// the scalar `Σ_k d_k·z'_ik²` to subtract from that logit slot's diagonal
300    /// so the factored base is `H₀'` (self term removed). Indexed by global
301    /// `delta_t` position.
302    pub(crate) fn self_term_downdate(&self, delta_t_len: usize) -> Array1<f64> {
303        let mut down = Array1::<f64>::zeros(delta_t_len);
304        for &(g, k, z) in &self.entries {
305            down[g] += self.d[k] * z * z;
306        }
307        down
308    }
309}
310
311impl ArrowSchurSystem {
312    /// Allocate an empty BA reduced-camera-system instance sized
313    /// `(N point/latent rows × d, K shared decoder parameters)`.
314    pub fn new(n: usize, d: usize, k: usize) -> Self {
315        Self::new_with_hbb(n, d, k, Array2::<f64>::zeros((k, k)))
316    }
317
318    /// Allocate an arrow system with no dense shared `H_ββ` block and with
319    /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
320    pub fn new_with_empty_hbb_and_htbeta_cols(
321        n: usize,
322        d: usize,
323        k: usize,
324        htbeta_cols: usize,
325    ) -> Self {
326        let rows = (0..n)
327            .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
328            .collect();
329        let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
330        let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
331        Self {
332            rows,
333            hbb: Array2::<f64>::zeros((0, 0)),
334            hbb_matvec: None,
335            htbeta_matvec: None,
336            htbeta_transpose_matvec: None,
337            htbeta_dense_supplement: false,
338            hbb_diag: None,
339            gb: Array1::<f64>::zeros(k),
340            d,
341            row_dims,
342            row_offsets,
343            k,
344            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
345            row_hessian_fingerprint: 0,
346            analytic_row_hessian_fingerprint: 0,
347            block_offsets: Arc::from([] as [Range<usize>; 0]),
348            penalty_op: None,
349            device_sae_pcg: None,
350            cross_row_penalties: Vec::new(),
351            row_gauge_deflation: None,
352            ibp_cross_row: None,
353        }
354    }
355
356    /// Allocate an arrow system using a caller-owned dense shared-block buffer.
357    /// The buffer must already have shape `(k, k)` and is zeroed in place before
358    /// use so callers can recycle it across assemblies without changing
359    /// numerics.
360    pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
361        Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, k)
362    }
363
364    /// Allocate an arrow system with a caller-owned dense shared-block buffer and
365    /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
366    pub fn new_with_hbb_and_htbeta_cols(
367        n: usize,
368        d: usize,
369        k: usize,
370        mut hbb: Array2<f64>,
371        htbeta_cols: usize,
372    ) -> Self {
373        assert_eq!(hbb.dim(), (k, k));
374        hbb.fill(0.0);
375        let rows = (0..n)
376            .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
377            .collect();
378        let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
379        let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
380        Self {
381            rows,
382            hbb,
383            hbb_matvec: None,
384            htbeta_matvec: None,
385            htbeta_transpose_matvec: None,
386            htbeta_dense_supplement: false,
387            hbb_diag: None,
388            gb: Array1::<f64>::zeros(k),
389            d,
390            row_dims,
391            row_offsets,
392            k,
393            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
394            row_hessian_fingerprint: 0,
395            analytic_row_hessian_fingerprint: 0,
396            block_offsets: Arc::from([] as [Range<usize>; 0]),
397            penalty_op: None,
398            device_sae_pcg: None,
399            cross_row_penalties: Vec::new(),
400            row_gauge_deflation: None,
401            ibp_cross_row: None,
402        }
403    }
404
405    /// Allocate an arrow system whose shared `H_ββ` block is supplied only as
406    /// a matrix-free operator for large BA InexactPCG.
407    ///
408    /// Direct and Square-Root BA modes require dense `hbb` and must not be
409    /// used with this constructor. The row-local `H_tβ` slabs remain explicit;
410    /// a future MegBA backend can replace those slab operations behind
411    /// [`BatchedBlockSolver`].
412    pub fn new_matrix_free_shared<F>(
413        n: usize,
414        d: usize,
415        k: usize,
416        matvec: F,
417        diag: Array1<f64>,
418    ) -> Self
419    where
420        F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
421    {
422        assert_eq!(diag.len(), k);
423        let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
424        let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
425        let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
426        let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
427        // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
428        // route through the trait while preserving hbb_matvec + hbb_diag for
429        // code that inspects them directly.
430        let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
431            k,
432            Arc::clone(&matvec_arc),
433            diag.clone(),
434        )));
435        Self {
436            rows,
437            hbb: Array2::<f64>::zeros((0, 0)),
438            hbb_matvec: Some(matvec_arc),
439            htbeta_matvec: None,
440            htbeta_transpose_matvec: None,
441            htbeta_dense_supplement: false,
442            hbb_diag: Some(diag),
443            gb: Array1::<f64>::zeros(k),
444            d,
445            row_dims,
446            row_offsets,
447            k,
448            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
449            row_hessian_fingerprint: 0,
450            analytic_row_hessian_fingerprint: 0,
451            block_offsets: Arc::from([] as [Range<usize>; 0]),
452            penalty_op,
453            device_sae_pcg: None,
454            cross_row_penalties: Vec::new(),
455            row_gauge_deflation: None,
456            ibp_cross_row: None,
457        }
458    }
459
460
461
462    /// Allocate a heterogeneous-row arrow system with no dense shared `H_ββ`
463    /// block and with row `H_tβ` slabs allocated at `htbeta_cols` columns.
464    pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
465        per_row_dims: Vec<usize>,
466        k: usize,
467        htbeta_cols: usize,
468    ) -> Self {
469        let n = per_row_dims.len();
470        let d = per_row_dims.iter().copied().max().unwrap_or(0);
471        let mut offsets = Vec::with_capacity(n + 1);
472        let mut cursor = 0usize;
473        offsets.push(cursor);
474        for &dim in &per_row_dims {
475            cursor += dim;
476            offsets.push(cursor);
477        }
478        let rows = per_row_dims
479            .iter()
480            .map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
481            .collect();
482        Self {
483            rows,
484            hbb: Array2::<f64>::zeros((0, 0)),
485            hbb_matvec: None,
486            htbeta_matvec: None,
487            htbeta_transpose_matvec: None,
488            htbeta_dense_supplement: false,
489            hbb_diag: None,
490            gb: Array1::<f64>::zeros(k),
491            d,
492            row_dims: Arc::from(per_row_dims.into_boxed_slice()),
493            row_offsets: Arc::from(offsets.into_boxed_slice()),
494            k,
495            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
496            row_hessian_fingerprint: 0,
497            analytic_row_hessian_fingerprint: 0,
498            block_offsets: Arc::from([] as [Range<usize>; 0]),
499            penalty_op: None,
500            device_sae_pcg: None,
501            cross_row_penalties: Vec::new(),
502            row_gauge_deflation: None,
503            ibp_cross_row: None,
504        }
505    }
506
507    /// Allocate a heterogeneous-row system using a caller-owned dense shared
508    /// block and row `H_tβ` slabs allocated at `htbeta_cols` columns.
509    pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
510        per_row_dims: Vec<usize>,
511        k: usize,
512        mut hbb: Array2<f64>,
513        htbeta_cols: usize,
514    ) -> Self {
515        assert_eq!(hbb.dim(), (k, k));
516        hbb.fill(0.0);
517        let n = per_row_dims.len();
518        let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
519        let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
520        let mut off_vec = Vec::with_capacity(n + 1);
521        let mut cursor = 0usize;
522        for &di in &per_row_dims {
523            off_vec.push(cursor);
524            cursor += di;
525        }
526        off_vec.push(cursor);
527        let row_offsets: Arc<[usize]> = off_vec.into();
528        let rows = per_row_dims
529            .iter()
530            .map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
531            .collect();
532        Self {
533            rows,
534            hbb,
535            hbb_matvec: None,
536            htbeta_matvec: None,
537            htbeta_transpose_matvec: None,
538            htbeta_dense_supplement: false,
539            hbb_diag: None,
540            gb: Array1::<f64>::zeros(k),
541            d: max_d,
542            row_dims,
543            row_offsets,
544            k,
545            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
546            row_hessian_fingerprint: 0,
547            analytic_row_hessian_fingerprint: 0,
548            block_offsets: Arc::from([] as [Range<usize>; 0]),
549            penalty_op: None,
550            device_sae_pcg: None,
551            cross_row_penalties: Vec::new(),
552            row_gauge_deflation: None,
553            ibp_cross_row: None,
554        }
555    }
556
557    pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
558        self.row_gauge_deflation = Some(deflation);
559    }
560
561    /// Register the exact cross-row IBP low-rank source (#1038). The assembly
562    /// passes the per-column `D`-coefficients (`cross_row_d`) and the `(global
563    /// latent index, atom, z'_ik)` entries built from `z_jac`; the factorization
564    /// then carries the exact rank-`R` Woodbury (value + log-determinant +
565    /// θ/ρ-adjoint) on the evidence cache. An empty source (`r == 0` or no
566    /// entries) is treated as absent so the row-block-diagonal path is unchanged.
567    pub fn set_ibp_cross_row_source(&mut self, source: IbpCrossRowSource) {
568        if source.r == 0 || source.entries.is_empty() {
569            self.ibp_cross_row = None;
570        } else {
571            self.ibp_cross_row = Some(source);
572        }
573    }
574
575    /// Number of BA point/latent rows `N`.
576    pub fn n(&self) -> usize {
577        self.rows.len()
578    }
579
580    /// Recompute the row-system fingerprint from the currently materialized
581    /// row blocks, cross-blocks, and shared-block diagonal.
582    pub fn compute_row_hessian_fingerprint(&self) -> u64 {
583        row_hessian_fingerprint_for_system(self)
584    }
585
586    /// Current effective row-system fingerprint, including the materialized
587    /// row blocks and any registry metadata captured while folding analytic
588    /// penalties into the system.
589    pub fn current_row_hessian_fingerprint(&self) -> u64 {
590        combine_row_and_registry_fingerprints(
591            self.compute_row_hessian_fingerprint(),
592            self.analytic_row_hessian_fingerprint,
593        )
594    }
595
596    /// Store the current row-system fingerprint on the system.
597    ///
598    /// This is intentionally explicit and expensive. Cache and evidence callers
599    /// use [`Self::current_row_hessian_fingerprint`] at the point they need the
600    /// value, after assembly has populated the system, instead of hashing each
601    /// intermediate construction/mutation step.
602    pub fn refresh_row_hessian_fingerprint(&mut self) {
603        self.row_hessian_fingerprint = self.current_row_hessian_fingerprint();
604    }
605
606    /// Install a matrix-free shared-block operator for Agarwal-style
607    /// inexact Schur PCG.
608    ///
609    /// `diag` must be the diagonal of the same `H_ββ` operator and is used
610    /// for the Schur-Jacobi preconditioner. This is the BA "large camera
611    /// system" path mapped to large decoder coefficient blocks.
612    pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
613    where
614        F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
615    {
616        assert_eq!(diag.len(), self.k);
617        let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
618        // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
619        // route through the trait, preserving the existing hbb_matvec +
620        // hbb_diag fields for code that inspects them directly.
621        self.penalty_op = Some(Arc::new(MatvecDiagPenaltyOp::new(
622            self.k,
623            Arc::clone(&matvec_arc),
624            diag.clone(),
625        )));
626        self.hbb_matvec = Some(matvec_arc);
627        self.hbb_diag = Some(diag);
628    }
629
630    /// Mark the dense per-row cross-block slabs as active supplements to the
631    /// installed matrix-free row operator.
632    pub fn activate_dense_htbeta_supplement(&mut self) {
633        self.htbeta_dense_supplement = true;
634    }
635
636    /// Install a matrix-free per-row cross-block operator and its sparse
637    /// adjoint.
638    ///
639    /// `forward` must write `out = H_tβ^(row) x` for `out.len() == d` and
640    /// `x.len() == K`. `transpose` must **add** `H_βt^(row) v` into `out` for
641    /// `out.len() == K` and `v.len() == d` (the sparse `scatter` adjoint).
642    ///
643    /// When installed, the forward operator is used during the Newton solve
644    /// (inside `reduced_rhs_beta`, `schur_matvec`, back-substitution, and
645    /// `JacobiPreconditioner` construction) and afterwards by IFT/evidence
646    /// predictors.  Per-row `htbeta` slabs in `ArrowRowBlock` may be left
647    /// zero-sized when this operator is installed — all inner-Schur paths route
648    /// through the matvec instead of indexing the dense block. The transpose
649    /// operator lets the reduced-Schur matvec apply `H_βt^(row)` directly
650    /// (`O(m_i · p)`) instead of probing `forward` against `K` basis vectors.
651    pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
652    where
653        F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
654        T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
655    {
656        self.htbeta_matvec = Some(Arc::new(forward));
657        self.htbeta_transpose_matvec = Some(Arc::new(transpose));
658    }
659
660    /// Register term-block column ranges for the block-Jacobi Schur preconditioner.
661    ///
662    /// Each `Range<usize>` covers the columns of one GAM term (or custom
663    /// parameter family) in the shared `β` vector. The ranges must be
664    /// non-overlapping, sorted, and their union must cover `0..k`.
665    ///
666    /// Call this after building the system and before [`Self::solve`] /
667    /// [`Self::solve_with_options`] whenever the solver will use
668    /// [`ArrowSolverMode::InexactPCG`]. Absent a call, the preconditioner
669    /// falls back to scalar diagonal Jacobi (the pre-#283 behaviour).
670    ///
671    /// The same plumbing is compatible with #287 (custom `ParameterBlockSpec`
672    /// families): callers from that path simply supply ranges derived from
673    /// their own block layout.
674    pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
675        self.block_offsets = offsets;
676    }
677
678    /// Install a matrix-free penalty-side `H_ββ` operator (#296).
679    ///
680    /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
681    /// `JacobiPreconditioner`, quadratic-form reduction) route through this
682    /// operator instead of the dense `hbb` accumulator, enabling
683    /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to avoid `O(K²)` allocation
684    /// for structured smoothness penalties.
685    pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
686        self.penalty_op = Some(op);
687    }
688
689    pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
690        assert_eq!(data.beta_dim, self.k);
691        // The frames-engaged builder (`build_framed_device_sae_data`) carries the
692        // per-row cross block through `frame.frame_blocks` and intentionally leaves
693        // the full-`B` `a_phi`/`local_jac` slabs EMPTY (#1033). Only the non-framed
694        // full-`B` path populates those per-row slabs, so the length contract
695        // applies only when there is no frame.
696        if data.frame.is_none() {
697            assert_eq!(data.a_phi.len(), self.rows.len());
698            assert_eq!(data.local_jac.len(), self.rows.len());
699        }
700        self.device_sae_pcg = Some(Arc::new(data));
701    }
702
703    /// Return the effective penalty operator: the installed `penalty_op` if
704    /// present, otherwise a `DensePenaltyOp` wrapping the current `hbb`.
705    ///
706    /// Note: when `penalty_op` is `None`, this clones `hbb` into a new
707    /// `DensePenaltyOp`. Callers in hot loops should call this once and
708    /// store the result, not call it per-iteration.
709    pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
710        match self.penalty_op.as_ref() {
711            Some(op) => Arc::clone(op),
712            None => Arc::new(DensePenaltyOp(self.hbb.clone())),
713        }
714    }
715
716    /// `y += P x` without allocating a new Arc; dispatches to `penalty_op`
717    /// or falls back to `hbb` inline, avoiding the K×K clone hot-path cost.
718    #[inline]
719    pub(crate) fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
720        if let Some(op) = self.penalty_op.as_ref() {
721            op.matvec(x, y);
722        } else {
723            let k = self.hbb.nrows();
724            for a in 0..k {
725                let mut acc = 0.0_f64;
726                for b in 0..k {
727                    acc += self.hbb[[a, b]] * x[b];
728                }
729                y[a] += acc;
730            }
731        }
732    }
733
734    /// Reduced-Schur matvec prologue `y = (P + ridge·I) x` written fresh into a
735    /// zeroed `y` (the caller clears `out` first; this is the first writer).
736    ///
737    /// At the SAE LLM border width (#1017) the dense `H_ββ` fallback is a `k×k`
738    /// GEMV whose `O(k²)` cost (≈4M flops at k=2048) runs once per CG iteration
739    /// and was the serial Amdahl ceiling on the per-row-parallel matvec: while
740    /// the `n`-row point-elimination term fans across all cores, this prologue
741    /// pinned one core and grows as `k²`. The dense GEMV is embarrassingly
742    /// parallel over output rows `a` — each `y[a] = Σ_b hbb[a,b]·x[b] + ridge·x[a]`
743    /// is independent and its inner sum order is identical whether one thread or
744    /// many compute it. Here parallelism is over independent output rows (NOT a
745    /// reassociated reduction), so each `y[a]` accumulates in the SAME order as
746    /// serial — the result is bit-identical to serial, not merely deterministic
747    /// run-to-run (the #1017 determinism gate). On THIS exact-order path the
748    /// criterion ranking is invariant; that no-move guarantee holds because the
749    /// order matches serial, and does NOT generalise to chunk-reassociated
750    /// reductions, where a near-tie winner can flip within the f64 margin
751    /// (#1211). The `penalty_op` path stays serial — it is an opaque operator
752    /// with its own structure (SAE uses the dense `hbb`), and small `k` stays
753    /// serial to avoid rayon overhead on a trivial GEMV.
754    ///
755    /// `parallel` is the caller's top-level / not-nested-in-rayon decision (the
756    /// same guard the row loop uses), so this never oversubscribes inside the
757    /// topology race.
758    pub(crate) fn penalty_ridge_prologue_into(
759        &self,
760        x: &[f64],
761        ridge: f64,
762        y: &mut [f64],
763        parallel: bool,
764    ) {
765        let k = self.hbb.nrows();
766        let dense_parallel = parallel
767            && self.penalty_op.is_none()
768            && self.hbb.dim() == (k, k)
769            && k >= SCHUR_PROLOGUE_PARALLEL_K_MIN;
770        if dense_parallel {
771            use rayon::prelude::*;
772            let hbb = &self.hbb;
773            y.par_iter_mut().enumerate().for_each(|(a, ya)| {
774                let mut acc = 0.0_f64;
775                for b in 0..k {
776                    acc += hbb[[a, b]] * x[b];
777                }
778                *ya = acc + ridge * x[a];
779            });
780        } else {
781            self.penalty_matvec_add(x, y);
782            for a in 0..k {
783                y[a] += ridge * x[a];
784            }
785        }
786    }
787
788    /// `diag += diag(P)` without allocating; dispatches to `penalty_op`
789    /// or falls back to `hbb` diagonal / `hbb_diag` inline.
790    #[inline]
791    pub(crate) fn penalty_diagonal_add(&self, diag: &mut [f64]) {
792        if let Some(op) = self.penalty_op.as_ref() {
793            op.diagonal(diag);
794        } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
795            let k = hbb_diag.len().min(diag.len());
796            for j in 0..k {
797                diag[j] += hbb_diag[j];
798            }
799        } else {
800            let k = self.hbb.nrows().min(diag.len());
801            for j in 0..k {
802                diag[j] += self.hbb[[j, j]];
803            }
804        }
805    }
806
807    /// Add the `b×b` penalty sub-block for `id` to `out`, routing through
808    /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
809    #[inline]
810    pub(crate) fn penalty_block_add(
811        &self,
812        id: BetaBlockId,
813        offsets: &[Range<usize>],
814        out: &mut Array2<f64>,
815    ) {
816        if let Some(op) = self.penalty_op.as_ref() {
817            op.block(id, offsets, out);
818        } else {
819            let range = &offsets[id.0];
820            let b = range.end - range.start;
821            if self.hbb.dim() == (self.k, self.k) {
822                for bi in 0..b {
823                    for bj in 0..b {
824                        out[[bi, bj]] += self.hbb[[range.start + bi, range.start + bj]];
825                    }
826                }
827            } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
828                for bi in 0..b {
829                    out[[bi, bi]] += hbb_diag[range.start + bi];
830                }
831            }
832        }
833    }
834
835    /// Fill a `b×b` penalty sub-block for a set of arbitrary (possibly
836    /// non-contiguous) global column indices `cols`, routing through
837    /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
838    ///
839    /// Used by the cluster-Jacobi preconditioner (#299) which groups columns
840    /// by spectral adjacency rather than contiguous block ranges.
841    #[inline]
842    pub(crate) fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
843        let b = cols.len();
844        if let Some(op) = self.penalty_op.as_ref() {
845            // Probe each column basis vector and extract the sub-block entries.
846            let mut probe = Array1::<f64>::zeros(self.k);
847            let mut result = Array1::<f64>::zeros(self.k);
848            for bj in 0..b {
849                probe.fill(0.0);
850                probe[cols[bj]] = 1.0;
851                result.fill(0.0);
852                {
853                    let p_slice = probe.as_slice().expect("probe contiguous");
854                    let r_slice = result.as_slice_mut().expect("result contiguous");
855                    op.matvec(p_slice, r_slice);
856                }
857                for bi in 0..b {
858                    out[[bi, bj]] += result[cols[bi]];
859                }
860            }
861        } else if self.hbb.dim() == (self.k, self.k) {
862            for bi in 0..b {
863                for bj in 0..b {
864                    out[[bi, bj]] += self.hbb[[cols[bi], cols[bj]]];
865                }
866            }
867        } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
868            for bi in 0..b {
869                out[[bi, bi]] += hbb_diag[cols[bi]];
870            }
871        }
872    }
873
    /// Fold analytic-penalty contributions into the appropriate blocks.
    ///
    /// BA source mapping: these are extra prior/regularization normal-equation
    /// terms before point elimination, the same place Ceres/g2o attach robust
    /// priors or gauge-fixing constraints.
    ///
    /// **Composition path.** Each registered [`AnalyticPenaltyKind`] is
    /// queried for `grad_target` (added to `g_t` or `g_β`) and then for
    /// `hessian_diag` first. Diagonal penalties (ARD and the shipped
    /// sparsity kernels) are injected directly. The row-block-only Psi-tier
    /// penalties are `ARDPenalty`, `SparsityPenalty`,
    /// `SoftmaxAssignmentSparsity`, `IBPAssignment`,
    /// `RowPrecisionPrior`, `ParametricRowPrecisionPrior`, and
    /// `ScadMcpPenalty`. Their `d × d` per-row Hessian folds into
    /// `rows[i].htt`, so the exact arrow Schur elimination (`N` independent
    /// `d × d` row solves) represents them exactly. Dense Beta-tier penalties
    /// still fall back to `hvp` probes against the canonical basis vectors for
    /// `β`.
    ///
    /// **Cross-row Psi penalties.** Penalties whose Hessian couples *distinct*
    /// latent rows — `TotalVariationPenalty`, `SheafConsistencyPenalty`,
    /// block-orthogonality, … — produce off-row blocks `∂²P/∂t_i∂t_j`
    /// (`i ≠ j`). The arrow elimination cannot store these, because it
    /// assumes each `H_tt^(i)` is independent of every other row.
    ///
    /// These penalties are handled without any approximation:
    /// - Their **gradient** is folded into `g_t` exactly as for every other
    ///   Psi penalty (`grad_target → g_t`).
    /// - Their full **curvature** is captured into
    ///   [`Self::cross_row_penalties`] as a matrix-free operator.
    ///
    /// At solve time, `K = K0 + P_cross`. Here `K0` is the block-diagonal arrow
    /// operator, and `P_cross · Δt = Σ_p ρ_p · psd_majorizer_hvp_p(t, Δt)` is
    /// the cross-row penalty Hessian applied to the full flat latent vector.
    ///
    /// If any cross-row penalty is captured, [`Self::solve`] automatically
    /// routes through the matrix-free full-system PCG path, using the exact
    /// arrow block-diagonal inverse `K0⁻¹` as the preconditioner `M⁻¹`. A
    /// purely row-block-diagonal system keeps the exact one-shot Schur path
    /// unchanged. No flag is involved: the route is selected from the captured
    /// penalty set alone (magic by default).
    ///
    /// `target_t` is the full flat latent-coordinate vector at the current
    /// iterate (row-major, `N·d` entries), and `target_beta` is the current
    /// `β`. `rho_global` is the global ρ vector; each penalty receives only
    /// its own slice, as given by [`AnalyticPenaltyRegistry::rho_layout`].
    ///
    /// Side effects:
    /// - `self.cross_row_penalties` is cleared and rebuilt on every call.
    /// - `self.analytic_row_hessian_fingerprint` is overwritten; it is set to
    ///   `0` when no penalty produced a fingerprint.
    /// - Gradient and Hessian contributions are *added* (`+=`) to the values
    ///   already stored in the row / β blocks. Calling this twice on the same
    ///   assembled system would therefore presumably double-count them. TODO
    ///   confirm that callers invoke it once per assembly.
    ///
    /// # Errors
    /// As written, the body has no fallible step and always returns `Ok(())`.
    pub fn add_analytic_penalty_contributions(
        &mut self,
        registry: &AnalyticPenaltyRegistry,
        target_t: ArrayView1<'_, f64>,
        target_beta: ArrayView1<'_, f64>,
        rho_global: ArrayView1<'_, f64>,
    ) -> Result<(), ArrowSchurError> {
        // One `(rho_slice, tier, name)` entry per registered penalty; zipped
        // with `registry.penalties` in registration order below.
        let layout = registry.rho_layout();
        let mut penalty_fingerprints = Vec::new();
        // Drop anything captured by a previous call; the list is rebuilt below.
        self.cross_row_penalties.clear();
        for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
            let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
            match tier {
                PenaltyTier::Psi => {
                    if analytic_penalty_is_row_block_diagonal(penalty) {
                        // Row-block-diagonal: fold gradient + per-row d×d
                        // curvature into rows[i].htt, exactly representable by
                        // the arrow Schur elimination.
                        self.add_ext_coord_penalty(penalty, target_t, rho_local);
                        if let Some(fingerprint) =
                            analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
                        {
                            penalty_fingerprints.push(fingerprint);
                        }
                    } else {
                        // Cross-row: fold the gradient into g_t (exact, like
                        // every Psi penalty), but DO NOT fold any curvature into
                        // the row blocks — its off-row coupling cannot be stored
                        // there. Capture the penalty so the solve applies its
                        // full Hessian-vector product P_cross·Δt over the flat
                        // latent vector. This auto-selects the matrix-free
                        // full-system PCG path.
                        self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
                        // Owned snapshots: the captured operator must outlive
                        // the borrowed `target_t` / `rho_global` views.
                        self.cross_row_penalties.push(CrossRowLatentPenalty {
                            penalty: penalty.clone(),
                            rho_local: rho_local.to_owned(),
                            target_t: target_t.to_owned(),
                        });
                    }
                }
                PenaltyTier::Beta => {
                    self.add_beta_penalty(penalty, target_beta, rho_local);
                }
                PenaltyTier::Rho => {
                    // Rho-tier hyperpriors do not contribute to the inner
                    // (t, β) Newton step; they enter only at the REML
                    // outer level.
                }
            }
        }
        // Cross-row penalties contribute to the Newton Hessian operator, not
        // the stored row blocks, so they must still invalidate the row-Hessian
        // cache when their curvature changes. Probe each captured penalty's PSD
        // majorizer against the current latent vector (a deterministic, generic
        // probe) and fold the resulting signature in.
        for cross in &self.cross_row_penalties {
            penalty_fingerprints.push(cross_row_penalty_fingerprint(
                &cross.penalty,
                target_t,
                cross.rho_local.view(),
            ));
        }
        // Combine the per-penalty signatures under a versioned domain tag.
        // Row-block fingerprints come first (registration order), then the
        // cross-row ones, so the push order above is part of the hash.
        self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
            0
        } else {
            let mut hasher = Fingerprinter::new();
            hasher.write_str("arrow-schur-row-hessian-registry-v1");
            hasher.write_usize(penalty_fingerprints.len());
            for fingerprint in penalty_fingerprints {
                hasher.write_u64(fingerprint);
            }
            hasher.finish_u64()
        };
        Ok(())
    }
990
991    /// Convert row-local Euclidean latent blocks to Riemannian tangent blocks.
992    ///
993    /// This is the only arrow-Schur algebra change needed for manifold
994    /// latents: `g_t`, `H_tt`, and each `H_tβ` column are projected to
995    /// `T_{t_i}M`, while the shared β block and Schur structure remain
996    /// untouched. Embedded constrained manifolds carry a pinned normal block
997    /// so the existing ambient Cholesky factorization still works; all RHS
998    /// terms live in the tangent space, so the solved update retracts cleanly.
999    pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
1000        let manifold = latent.manifold();
1001        self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
1002        if manifold.is_euclidean() {
1003            return;
1004        }
1005        assert_eq!(latent.n_obs(), self.rows.len());
1006        assert_eq!(latent.latent_dim(), self.d);
1007        for (i, row) in self.rows.iter_mut().enumerate() {
1008            let t_i = ArrayView1::from(latent.row(i));
1009            let gt_e = row.gt.clone();
1010            let htt_e = row.htt.clone();
1011            let htbeta_e = row.htbeta.clone();
1012            row.gt = manifold.project_gradient_to_tangent(t_i, gt_e.view());
1013            row.htt = manifold.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
1014            row.htbeta = manifold.project_matrix_columns_to_gradient_tangent(
1015                t_i,
1016                gt_e.view(),
1017                htbeta_e.view(),
1018            );
1019        }
1020    }
1021
    /// Fold a row-block-diagonal Psi-tier penalty into the row blocks: its
    /// latent gradient into `g_t` and its per-row `d × d` curvature into
    /// `H_tt^(i)`.
    ///
    /// `target_t` is the flat row-major latent vector of length `N · d`, and
    /// `rho_local` is this penalty's slice of the global ρ. The callbacks
    /// passed to `apply_analytic_penalty` map a flat index `flat = i·d + j` to
    /// row `i`, coordinate `j`:
    /// 1. gradient entries are added to `rows[i].gt[j]`;
    /// 2. diagonal Hessian entries are added to `rows[i].htt[[j, j]]`;
    /// 3. the HVP probe for column `a` sets coordinate `a` of EVERY row at
    ///    once; the response slice belonging to row `i` is then added into
    ///    column `a` of `rows[i].htt`.
    ///
    /// Step 3 attributes each row's response to that row alone. This is exact
    /// only when the penalty has no cross-row coupling. Cross-row penalties
    /// go through [`Self::add_ext_coord_penalty_gradient_only`] plus the
    /// captured `cross_row_penalties` instead.
    ///
    /// NOTE(review): `apply_analytic_penalty` decides which callbacks it
    /// invokes (diagonal vs. HVP probes) and how it uses the `n * d` / `d`
    /// size arguments. Presumably they are the total vector length and the
    /// number of probe columns; confirm against that helper.
    pub(crate) fn add_ext_coord_penalty(
        &mut self,
        penalty: &AnalyticPenaltyKind,
        target_t: ArrayView1<'_, f64>,
        rho_local: ArrayView1<'_, f64>,
    ) {
        let d = self.d;
        let n = self.rows.len();
        apply_analytic_penalty(
            penalty,
            target_t,
            rho_local,
            n * d,
            d,
            self,
            // Gradient scatter: flat index -> (row, coordinate).
            |sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
            // Diagonal-Hessian scatter onto the row block's diagonal.
            |sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
            // Probe for latent column `a`: unit entry at coordinate `a` of
            // every row simultaneously (valid only without cross-row coupling).
            |a, probe| {
                for i in 0..n {
                    probe[i * d + a] = 1.0;
                }
            },
            // HVP response for column `a`: row `i`'s d-length slice is column
            // `a` of that row's own `H_tt^(i)`.
            |sys, a, hv| {
                for i in 0..n {
                    for b in 0..d {
                        sys.rows[i].htt[[b, a]] += hv[i * d + b];
                    }
                }
            },
        );
    }
1053
1054    /// Fold ONLY the latent gradient `grad_target → g_t` of an analytic
1055    /// penalty, leaving the row-block Hessian untouched.
1056    ///
1057    /// Used for cross-row Psi penalties: their gradient enters `g_t` exactly
1058    /// like every other Psi penalty, but their curvature must NOT be scattered
1059    /// into the per-row `H_tt^(i)` blocks (the diagonal piece would be
1060    /// double-counted and the off-row coupling cannot be stored there). The
1061    /// full curvature is instead applied as a matrix-free `P_cross · Δt`
1062    /// during the solve, via [`Self::cross_row_penalties`].
1063    pub(crate) fn add_ext_coord_penalty_gradient_only(
1064        &mut self,
1065        penalty: &AnalyticPenaltyKind,
1066        target_t: ArrayView1<'_, f64>,
1067        rho_local: ArrayView1<'_, f64>,
1068    ) {
1069        let d = self.d;
1070        let n = self.rows.len();
1071        assert_eq!(target_t.len(), n * d);
1072        let grad = penalty.grad_target(target_t, rho_local);
1073        for flat in 0..n * d {
1074            self.rows[flat / d].gt[flat % d] += grad[flat];
1075        }
1076    }
1077
1078    /// Apply the aggregate cross-row penalty Hessian `P_cross · v` over the
1079    /// full flat latent vector `v` (length `Σ_i row_dims[i]`), accumulating
1080    /// into `out`.
1081    ///
1082    /// `P_cross = Σ_p psd_majorizer_hvp_p(target_t, ·; ρ_p)` summed over every
1083    /// captured cross-row penalty. Each penalty's `psd_majorizer_hvp` is its
1084    /// exact (PSD) Hessian-vector product over the `N·d` flat latent vector —
1085    /// for `TotalVariationPenalty` this is `Dᵀ diag(φ''(D t)) D · v`, the
1086    /// graph/forward-difference Laplacian-style coupling that links distinct
1087    /// rows. The ρ scaling is already baked into each penalty's resolved
1088    /// weight, so no extra factor is applied here.
1089    ///
1090    /// This is only valid for homogeneous systems (every row of dimension
1091    /// `d`), the only shape cross-row latent penalties are defined on; the
1092    /// flat-index convention `flat = i·d + j` matches every penalty's
1093    /// `latent_dim`/row-major contract.
1094    pub(crate) fn apply_cross_row_penalty_hessian(
1095        &self,
1096        v: ArrayView1<'_, f64>,
1097        out: &mut Array1<f64>,
1098    ) {
1099        for cross in &self.cross_row_penalties {
1100            assert_eq!(cross.target_t.len(), v.len());
1101            let hv =
1102                cross
1103                    .penalty
1104                    .psd_majorizer_hvp(cross.target_t.view(), cross.rho_local.view(), v);
1105            assert_eq!(hv.len(), out.len());
1106            for i in 0..out.len() {
1107                out[i] += hv[i];
1108            }
1109        }
1110    }
1111
    /// Fold a Beta-tier analytic penalty into the shared β gradient `gb` and
    /// the β Hessian storage (`hbb` and/or `hbb_diag`).
    ///
    /// `target_beta` is the current `β` (length `K`), and `rho_local` is this
    /// penalty's ρ slice. The callbacks passed to `apply_analytic_penalty`:
    /// - add gradient entries to `gb[j]`;
    /// - add diagonal Hessian entries to `hbb[[j, j]]` when `hbb` is a dense
    ///   `K × K` matrix, and mirror them into `hbb_diag` whenever it exists;
    /// - build the HVP probe for column `j` as the basis vector `e_j`, then
    ///   add the response into column `j` of `hbb` (again mirroring the `j`-th
    ///   entry into `hbb_diag`).
    ///
    /// `hvp_columns` is `K` only when dense `hbb` is allocated, and `0`
    /// otherwise. Presumably that tells `apply_analytic_penalty` to skip the
    /// probe loop, since the HVP scatter below indexes `hbb[[i, j]]`
    /// unconditionally. TODO confirm against that helper.
    ///
    /// NOTE(review): these writes land in `hbb` / `hbb_diag` only. If a
    /// matrix-free `penalty_op` was built from a copy of them, it will not see
    /// contributions made afterwards. Confirm the install order at call sites.
    pub(crate) fn add_beta_penalty(
        &mut self,
        penalty: &AnalyticPenaltyKind,
        target_beta: ArrayView1<'_, f64>,
        rho_local: ArrayView1<'_, f64>,
    ) {
        let k = self.k;
        // Only request dense HVP column probes when there is a dense `hbb` to
        // scatter them into.
        let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
        apply_analytic_penalty(
            penalty,
            target_beta,
            rho_local,
            k,
            hvp_columns,
            self,
            // Gradient scatter into the shared β gradient.
            |sys, j, value| sys.gb[j] += value,
            // Diagonal-Hessian scatter into whichever β storage exists.
            |sys, j, value| {
                if sys.hbb.dim() == (k, k) {
                    sys.hbb[[j, j]] += value;
                }
                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
                    hbb_diag[j] += value;
                }
            },
            // Canonical basis probe `e_j`.
            |j, probe| probe[j] = 1.0,
            |sys, j, hv| {
                for i in 0..k {
                    sys.hbb[[i, j]] += hv[i];
                }
                // Keep `hbb_diag` consistent with the dense `hbb` Hessian when
                // both are populated (the dense-allocated path + a later
                // `set_shared_beta_operator` install). The HVP probe for
                // column `j` returns the full Hessian column, whose `j`-th
                // entry is the diagonal contribution of this penalty. Without
                // this mirror, the Jacobi Schur preconditioner — which prefers
                // `hbb_diag` over `hbb`'s diagonal — would silently use a
                // stale diagonal for any Beta-tier analytic penalty that
                // exposes only an HVP (no `hessian_diag`).
                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
                    hbb_diag[j] += hv[j];
                }
            },
        );
    }
1156
1157    /// Schur-eliminate the per-row latent block and solve for `(Δt, Δβ, diag)`.
1158    ///
1159    /// This uses [`ArrowSolveOptions::automatic`]: BA dense RCS for
1160    /// `K <= 2000`, and Agarwal-style inexact Schur PCG above that size.
1161    /// Call [`ArrowSchurSystem::solve_with_options`] to force Square-Root BA
1162    /// or a specific inexact solve policy.
1163    ///
1164    /// Returns `(delta_t, delta_beta, PcgDiagnostics)` with `delta_t` flat
1165    /// row-major of length `N · d` and `delta_beta` of length `K`. The sign
1166    /// convention matches `solve_newton_direction_dense`: the returned
1167    /// increments satisfy the bordered system with RHS `[-g_t; -g_β]`, i.e.
1168    /// they are the *negated* solutions of the standard Newton-direction
1169    /// formulation. `PcgDiagnostics` is zero-valued for the Direct path and
1170    /// carries live counters (PCG iters, ridge escalations, residual) for
1171    /// InexactPCG.
1172    ///
1173    /// `ridge_t` and `ridge_beta` are nonnegative diagonal regularizers
1174    /// added to the latent and β blocks respectively before factorization
1175    /// — used by the LM damping outer wrapper to recover from near-singular
1176    /// inner steps. Pass `0.0` for both to obtain the unregularized
1177    /// Newton direction.
1178    pub fn solve(
1179        &self,
1180        ridge_t: f64,
1181        ridge_beta: f64,
1182    ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1183        let options = ArrowSolveOptions::automatic(self.k);
1184        solve_arrow_newton_step_core(self, ridge_t, ridge_beta, &options)
1185    }
1186
1187    /// Solve with the standard LM-style ridge escalation: if a per-row
1188    /// `H_tt + ridge_t·I` Cholesky pivot is non-PD, or the reduced Schur
1189    /// factor fails, geometrically grow both ridges and retry. This is the
1190    /// same Ceres-style proximal correction the Newton driver in
1191    /// `run_joint_fit_arrow_schur` performs around `solve`, lifted into the
1192    /// system itself so every entry point (predict OOS reconstruction,
1193    /// single-shot Newton refinement, …) is self-healing against the
1194    /// pathological per-row blocks produced by PCA-seeded latent
1195    /// coordinates on subset / new data — see #163 and #175.
1196    ///
1197    /// `ridge_t` / `ridge_beta` are the caller-nominal Tikhonov ridges; the
1198    /// escalation only adds extra damping on top of them when the factor
1199    /// fails. PCG / AdaptiveCorrection failures are left untouched because
1200    /// they are not factorization-recoverable.
1201    pub fn solve_with_lm_escalation(
1202        &self,
1203        ridge_t: f64,
1204        ridge_beta: f64,
1205    ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1206        let options = ArrowSolveOptions::automatic(self.k);
1207        solve_with_lm_escalation_inner(self, ridge_t, ridge_beta, &options)
1208    }
1209
1210    /// Solve with an explicit BA Schur mode, returning `(Δt, Δβ, PcgDiagnostics)`.
1211    ///
1212    /// [`ArrowSolverMode::Direct`] is the classic dense reduced-camera-system
1213    /// Cholesky path; [`ArrowSolverMode::SqrtBA`] forms the same dense system
1214    /// through Square-Root BA factors; [`ArrowSolverMode::InexactPCG`] runs
1215    /// inexact-step LM on the reduced system with Jacobi-preconditioned
1216    /// Steihaug-CG. `PcgDiagnostics` is zero-valued for Direct/SqrtBA and
1217    /// carries live counters for InexactPCG (iterations, matvec calls,
1218    /// preconditioner escalations, final relative residual, stopping reason).
1219    pub fn solve_with_options(
1220        &self,
1221        ridge_t: f64,
1222        ridge_beta: f64,
1223        options: &ArrowSolveOptions,
1224    ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1225        solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
1226    }
1227}
1228
/// Chunked Schur assembler that never retains all row cross-blocks.
///
/// Rows are processed in chunks (see `chunk_size`) so that the `(N·K)`-scale
/// residency of every per-row `H_tβ` slab is avoided. Per-row latent
/// dimensions may vary (`row_dims`), with `d` as their upper bound.
pub struct StreamingArrowSchur {
    /// Number of latent rows; `row_offsets` holds `n_rows + 1` entries.
    pub n_rows: usize,
    /// Maximum per-row latent dim (upper bound for scratch buffers).
    pub d: usize,
    /// Per-row latent dims `row_dims[i] == rows[i].htt.nrows()`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets: `row_offsets[i]` is the start of row `i` in
    /// `delta_t`; `row_offsets[n_rows]` is the total `delta_t` length.
    pub row_offsets: Arc<[usize]>,
    /// Width of the shared β border.
    pub k: usize,
    /// Rows per streamed chunk — presumably; confirm in `accumulate_chunk`.
    pub chunk_size: usize,
    /// NOTE(review): by name, the running reduced-Schur accumulator over the β
    /// border (presumably `k × k`); confirm against `accumulate_chunk`.
    pub s_acc: Array2<f64>,
    /// NOTE(review): by name, the running reduced right-hand side for β;
    /// confirm against `accumulate_chunk`.
    pub(crate) rhs_acc: Array1<f64>,
    /// NOTE(review): assumed to be the β-block Hessian copied from the source
    /// system; confirm at construction.
    pub(crate) hbb: Array2<f64>,
    /// NOTE(review): assumed to be the β gradient copied from the source
    /// system; confirm at construction.
    pub(crate) gb: Array1<f64>,
    /// Producer of per-row blocks for the chunk loop — presumably; confirm
    /// against `StreamingArrowRowBuilder`.
    pub(crate) row_builder: StreamingArrowRowBuilder,
    /// Procedural cross-block operator `H_tβ^(i) x`. When present, the dense
    /// per-row `H_tβ` slabs are never materialized: `accumulate_chunk` and
    /// `back_substitute` probe this operator column-by-column to apply the
    /// cross-block, matching the Kronecker / matrix-free assembly path. When
    /// `None` (legacy dense BA callers), the per-row `row.htbeta` slab is used.
    /// NOTE(review): when `htbeta_transpose_matvec` is also set, the cheaper
    /// transpose-probe rebuild described on that field may be the path taken
    /// instead; confirm which one `accumulate_chunk` uses.
    pub(crate) htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Sparse adjoint of `htbeta_matvec`. When present, `row_htbeta` rebuilds
    /// the dense `(d_i × K)` cross-block by probing the transpose with `d_i`
    /// basis vectors — `O(d_i · m_i · p)` total, vs the `O(K · m_i · p)` cost
    /// of probing the forward operator with `K` basis vectors. Since
    /// `d_i ≪ K`, this is the per-row sparse apply that replaces the `O(K)`
    /// column-probe in the streaming reduced-Schur accumulation.
    pub(crate) htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// Lift the per-row κ rejection for evidence/log-det-only solves; see
    /// [`ArrowSolveOptions::tolerate_ill_conditioning`]. Set by [`Self::solve`]
    /// from the options; defaults to `false` so direct callers of
    /// [`Self::accumulate_chunk`] keep the full guard.
    pub(crate) tolerate_ill_conditioning: bool,
    /// Set when the source system carried an exact cross-row IBP source
    /// ([`IbpCrossRowSource`], #1038). The streaming chunked accumulator cannot
    /// hold the rank-`R` Woodbury correction chunk-locally — `U`'s columns span
    /// ALL rows, so the capacitance `I_R + D Uᵀ H₀'⁻¹ U` needs the per-row
    /// factors retained for a global `H₀'⁻¹U` back-solve, which is exactly the
    /// `(N·K)`-scale residency the streaming path exists to avoid. Rather than
    /// silently DROP the cross-row term (an inexact logdet that would desync
    /// from the dense-resident gradient), the streaming log-determinant errors
    /// loudly when this is set, forcing IBP-active fits onto the dense resident
    /// [`ArrowFactorCache::arrow_log_det`] path (which carries the exact
    /// Woodbury). See the #1038 streaming note.
    pub(crate) ibp_cross_row_active: bool,
    /// SAE manifold evidence-path per-row gauge deflation, copied from the
    /// source [`ArrowSchurSystem::row_gauge_deflation`] (#1273/#1377). When
    /// present, the streaming per-row factor MUST apply the SAME spectral
    /// discovery-and-deflation of an intrinsic-dimension-flat `H_tt^(i)`
    /// direction (eigenvalue → +1, ρ-independent `log 1 = 0` evidence) that the
    /// dense [`factor_blocks_for_system`] path applies, or the two routes report
    /// different log-determinants for the SAME system — the cross-route
    /// invariant `streaming_logdet == full_logdet` would break (the #1377
    /// regression: #1273 wired the deflation into the dense path only). `None`
    /// for every non-evidence caller, which keeps the strict non-PD refusal.
    pub(crate) row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
    /// The exact cross-row IBP source ([`IbpCrossRowSource`], #1038), cloned from
    /// the assembled [`ArrowSchurSystem`]. The bare streaming paths
    /// ([`Self::solve`] / [`Self::reduced_schur_and_log_det_tt`]) still REFUSE
    /// when this is present (they cannot carry the rank-`R` Woodbury), but the
    /// evidence-only [`Self::reduced_schur_log_det_tt_woodbury`] consumes it to
    /// accumulate the chunk-additive Woodbury building blocks `M0 = Uᵀ A⁻¹ U`
    /// and `W = Bᵀ A⁻¹ U` against the NO-SELF base `H₀'` (self term downdated),
    /// which the caller closes into the exact `log det(I_R + D Uᵀ H₀'⁻¹ U)` via
    /// [`streaming_cross_row_woodbury_log_det`].
    pub(crate) ibp_cross_row: Option<IbpCrossRowSource>,
}
1298
1299/// One chunk's contribution to the streaming cross-row IBP Woodbury (#1038).
1300///
1301/// The capacitance `C = I_R + D·M` with `M = Uᵀ H₀'⁻¹ U` is chunk-additive in
1302/// its two ingredients (`A = H₀'` is block-diagonal over rows, `U` is supported
1303/// per-row): `M = M0 + Wᵀ S⁻¹ W` where `M0 = Σ_i Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
1304/// `W = Σ_i Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`) accumulate row-by-row, and `S` is the final
1305/// GLOBAL reduced Schur. This carries one chunk's `M0`/`W` plus the per-atom
1306/// `D`; the caller sums them and closes the capacitance after the chunk loop.
1307#[derive(Debug, Clone)]
1308pub struct StreamingWoodburyChunk {
1309    /// `Σ_{i∈chunk} Uᵢᵀ Aᵢ⁻¹ Uᵢ`, `R×R`.
1310    pub m0: Array2<f64>,
1311    /// `Σ_{i∈chunk} Bᵢᵀ Aᵢ⁻¹ Uᵢ`, `k×R` (`k` = reduced β border).
1312    pub w: Array2<f64>,
1313    /// Per-atom `D`-coefficients `d_k = w·s'_k`, length `R`.
1314    pub d: Array1<f64>,
1315}
1316
1317impl std::fmt::Debug for StreamingArrowSchur {
1318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1319        f.debug_struct("StreamingArrowSchur")
1320            .field("n_rows", &self.n_rows)
1321            .field("d", &self.d)
1322            .field("k", &self.k)
1323            .field("chunk_size", &self.chunk_size)
1324            .finish_non_exhaustive()
1325    }
1326}
1327
1328impl StreamingArrowSchur {
1329    #[must_use]
1330    pub fn new(
1331        n_rows: usize,
1332        d: usize,
1333        row_dims: Arc<[usize]>,
1334        row_offsets: Arc<[usize]>,
1335        k: usize,
1336        hbb: Array2<f64>,
1337        gb: Array1<f64>,
1338        row_builder: StreamingArrowRowBuilder,
1339        chunk_size: usize,
1340    ) -> Self {
1341        assert_eq!(hbb.dim(), (k, k));
1342        assert_eq!(gb.len(), k);
1343        Self {
1344            n_rows,
1345            d,
1346            row_dims,
1347            row_offsets,
1348            k,
1349            chunk_size: chunk_size.max(1),
1350            s_acc: Array2::<f64>::zeros((k, k)),
1351            rhs_acc: Array1::<f64>::zeros(k),
1352            hbb,
1353            gb,
1354            row_builder,
1355            htbeta_matvec: None,
1356            htbeta_transpose_matvec: None,
1357            tolerate_ill_conditioning: false,
1358            ibp_cross_row_active: false,
1359            row_gauge_deflation: None,
1360            ibp_cross_row: None,
1361        }
1362    }
1363
1364    #[must_use]
1365    pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
1366        // When a Kronecker / matrix-free htbeta_matvec is installed, the dense
1367        // row.htbeta slabs may be zero-sized.  Rather than materialize every
1368        // `(d × K)` slab (the very `(N·K)`-scale buffer the streaming path
1369        // exists to avoid), retain the procedural operator and probe it per row
1370        // inside `accumulate_chunk` / `back_substitute`.  The row builder then
1371        // only carries the small `H_tt` / `g_t` blocks.
1372        let htbeta_matvec = sys.htbeta_matvec.clone();
1373        let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
1374            sys.rows
1375                .iter()
1376                .map(|row| ArrowRowBlock {
1377                    htt: row.htt.clone(),
1378                    htbeta: Array2::<f64>::zeros((0, 0)),
1379                    gt: row.gt.clone(),
1380                })
1381                .collect()
1382        } else {
1383            sys.rows.clone()
1384        };
1385        let rows = Arc::new(rows);
1386        let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
1387            rows.get(row)
1388                .cloned()
1389                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1390                    reason: format!("streaming row {row} out of bounds"),
1391                })
1392        });
1393        // Materialize the dense β-block from the effective penalty operator so
1394        // the streaming accumulator stays correct when contributions live in a
1395        // structured `BetaPenaltyOp` (e.g. the SAE data-fit Gauss-Newton block,
1396        // represented as `G ⊗ I_p`) rather than the dense `hbb` accumulator.
1397        // When no `penalty_op` is installed this reduces to `hbb.clone()`.
1398        let hbb_dense = sys.effective_penalty_op().to_dense();
1399        let mut streaming = Self::new(
1400            sys.rows.len(),
1401            sys.d,
1402            Arc::clone(&sys.row_dims),
1403            Arc::clone(&sys.row_offsets),
1404            sys.k,
1405            hbb_dense,
1406            sys.gb.clone(),
1407            row_builder,
1408            chunk_size,
1409        );
1410        streaming.htbeta_matvec = htbeta_matvec;
1411        streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
1412        streaming.ibp_cross_row_active = sys.ibp_cross_row.is_some();
1413        streaming.ibp_cross_row = sys.ibp_cross_row.clone();
1414        // Carry the SAE evidence-path per-row gauge deflation so the streaming
1415        // per-row factor matches the dense `factor_blocks_for_system` exactly
1416        // (#1377): without it, a row with an intrinsic-dimension-flat `H_tt`
1417        // would deflate on the dense path but be refused / log-det-divergent on
1418        // the streaming path, breaking `streaming_logdet == full_logdet`.
1419        streaming.row_gauge_deflation = sys.row_gauge_deflation.clone();
1420        streaming
1421    }
1422
1423    /// Factor one streaming row's `H_tt^(i)`, applying the SAME per-row gauge /
1424    /// spectral deflation the dense [`factor_blocks_for_system`] path applies
1425    /// when this is the SAE manifold evidence path (an installed
1426    /// `row_gauge_deflation`). For every non-evidence caller this is exactly the
1427    /// generic [`factor_one_row`] (strict non-PD refusal), so PD blocks are
1428    /// bit-for-bit unchanged. Routing both the dense and streaming per-row
1429    /// factors through the identical recovery is what keeps their
1430    /// log-determinants identical (#1273/#1377).
1431    fn factor_row(
1432        &self,
1433        row: &ArrowRowBlock,
1434        ridge_t: f64,
1435        di: usize,
1436        row_idx: usize,
1437    ) -> Result<Array2<f64>, ArrowSchurError> {
1438        match self.row_gauge_deflation.as_ref() {
1439            Some(deflation) => factor_one_row_result(
1440                row,
1441                ridge_t,
1442                di,
1443                row_idx,
1444                self.tolerate_ill_conditioning,
1445                deflation.row(row_idx),
1446                // Evidence path: opt into spectral discovery of an
1447                // intrinsic-dimension-flat direction even when this row's
1448                // supplied gauge list is empty/non-spanning — matching the
1449                // `allow_spectral_deflation = true` the dense path passes.
1450                true,
1451            )
1452            .map(|result| result.factor),
1453            None => factor_one_row(row, ridge_t, di, row_idx, self.tolerate_ill_conditioning),
1454        }
1455    }
1456
1457    /// Build the `(di × k)` cross-block for `row_idx` on demand.
1458    ///
1459    /// When the sparse transpose adjoint is installed, probes it with `di`
1460    /// standard basis vectors — each yields a full `K`-row of `H_βt^(i)`
1461    /// (i.e. a row of the `(di × k)` block) via the sparse scatter, for
1462    /// `O(di · m_i · p)` total, far below the `O(K · m_i · p)` cost of probing
1463    /// the forward operator with `K` basis vectors when `di ≪ K`.
1464    ///
1465    /// When only the forward operator is installed (no adjoint), falls back to
1466    /// the `k`-column forward probe. Otherwise clones the dense `row.htbeta`
1467    /// slab.
1468    pub(crate) fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
1469        if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
1470            // Probe the adjoint: for each latent index c, scatter e_c to obtain
1471            // row c of the (di × k) block.
1472            let mut mat = Array2::<f64>::zeros((di, self.k));
1473            let mut e_c = Array1::<f64>::zeros(di);
1474            let mut beta_row = Array1::<f64>::zeros(self.k);
1475            for c in 0..di {
1476                e_c.fill(0.0);
1477                e_c[c] = 1.0;
1478                beta_row.fill(0.0);
1479                op_t(row_idx, e_c.view(), &mut beta_row);
1480                for a in 0..self.k {
1481                    mat[[c, a]] = beta_row[a];
1482                }
1483            }
1484            return mat;
1485        }
1486        match self.htbeta_matvec.as_ref() {
1487            Some(op) => {
1488                let mut mat = Array2::<f64>::zeros((di, self.k));
1489                let mut e_a = Array1::<f64>::zeros(self.k);
1490                let mut col = Array1::<f64>::zeros(di);
1491                for a in 0..self.k {
1492                    e_a.fill(0.0);
1493                    e_a[a] = 1.0;
1494                    col.fill(0.0);
1495                    op(row_idx, e_a.view(), &mut col);
1496                    for c in 0..di {
1497                        mat[[c, a]] = col[c];
1498                    }
1499                }
1500                mat
1501            }
1502            None => row.htbeta.clone(),
1503        }
1504    }
1505
1506    /// Move out the accumulated reduced Schur block `s_acc` and reduced RHS
1507    /// `rhs_acc`, leaving fresh zero buffers in their place.
1508    ///
1509    /// The reduced contribution is `s_acc = hbb − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`
1510    /// (the β-block `hbb` seeded by `reset_accumulator`, minus the per-row
1511    /// reduction summed by `accumulate_chunk`) and
1512    /// `rhs_acc = +Σ_i H_βt^(i)(H_tt^(i))⁻¹g_t^(i)`. Used by external online
1513    /// drivers (e.g. the SAE streaming joint fit) that accumulate the reduced
1514    /// system across re-materialized chunk systems.
1515    #[must_use]
1516    pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
1517        let s = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1518        let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(self.k));
1519        (s, rhs)
1520    }
1521
1522    /// Reset the dense shared accumulator to `H_ββ + ridge_beta I`.
1523    pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
1524        if self.hbb.dim() != (self.k, self.k) {
1525            return Err(ArrowSchurError::SchurFactorFailed {
1526                reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
1527            });
1528        }
1529        self.s_acc.assign(&self.hbb);
1530        for j in 0..self.k {
1531            self.s_acc[[j, j]] += ridge_beta;
1532            self.rhs_acc[j] = 0.0;
1533        }
1534        Ok(())
1535    }
1536
1537    /// Accumulate rows `[start, end)` into the reduced RHS and Schur block.
1538    pub fn accumulate_chunk(
1539        &mut self,
1540        start: usize,
1541        end: usize,
1542        ridge_t: f64,
1543        mode: ArrowSolverMode,
1544    ) -> Result<(), ArrowSchurError> {
1545        if start > end || end > self.n_rows {
1546            return Err(ArrowSchurError::SchurFactorFailed {
1547                reason: format!(
1548                    "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
1549                    self.n_rows
1550                ),
1551            });
1552        }
1553        let backend = CpuBatchedBlockSolver;
1554        let k = self.k;
1555        // Per-row factor + two block solves + a `k×k` GEMM subtract is the whole
1556        // assembly cost at the SAE LLM shape (#1017); the rows are independent so
1557        // the chunk fans across cores. Stay sequential for the handful-of-rows
1558        // non-SAE callers, or when already inside a rayon worker (the topology
1559        // race fans candidates with `run_topology_race_parallel`) to avoid
1560        // nested-rayon oversubscription — the same gate `schur_matvec` uses.
1561        let parallel = (end - start) >= SCHUR_MATVEC_PARALLEL_ROW_MIN
1562            && rayon::current_thread_index().is_none();
1563        if parallel {
1564            use rayon::prelude::*;
1565            const CHUNK: usize = 64;
1566            // Bind `&self` so the per-row body borrows only the immutable
1567            // streaming state. Each row contributes `+H_βt^(i)(H_tt^(i))⁻¹ g_t^(i)`
1568            // (length `k`) to the reduced RHS and `−H_βt^(i)(H_tt^(i))⁻¹ H_tβ^(i)`
1569            // (`k×k`) to the reduced Schur complement; both are written INTO a
1570            // worker-private `(rhs_part, s_part)` pair so the chunk partials fold
1571            // back in chunk order — bit-identical run-to-run regardless of thread
1572            // scheduling (the #1017 verification gate). The chunk fold reassociates
1573            // the row sum relative to serial, so the criterion ranking is stable
1574            // only up to that f64 margin; a near-tie winner inside the margin can
1575            // flip — not an exact no-move guarantee (#1211).
1576            let this: &Self = self;
1577            let row_into = |row_idx: usize,
1578                            rhs_part: &mut Array1<f64>,
1579                            s_part: &mut Array2<f64>|
1580             -> Result<(), ArrowSchurError> {
1581                let row = (this.row_builder)(row_idx)?;
1582                let di = row.htt.nrows();
1583                this.validate_row(row_idx, &row)?;
1584                let htbeta = this.row_htbeta(row_idx, &row, di);
1585                let factor = this.factor_row(&row, ridge_t, di, row_idx)?;
1586                let v = backend.solve_block_vector(factor.view(), row.gt.view());
1587                for c in 0..di {
1588                    let vc = v[c];
1589                    if vc == 0.0 {
1590                        continue;
1591                    }
1592                    for a in 0..k {
1593                        rhs_part[a] += htbeta[[c, a]] * vc;
1594                    }
1595                }
1596                match mode {
1597                    // InexactPCG differs from Direct only in how the *reduced*
1598                    // system is solved, not in how it is assembled, so it shares
1599                    // the dense Schur subtraction here (see the serial branch).
1600                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1601                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1602                        backend.block_gemm_subtract(s_part, &htbeta, &solved);
1603                    }
1604                    ArrowSolverMode::SqrtBA => {
1605                        let whitened =
1606                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1607                        backend.block_gemm_subtract(s_part, &whitened, &whitened);
1608                    }
1609                }
1610                Ok(())
1611            };
1612            let partials: Vec<(Array1<f64>, Array2<f64>)> = (start..end)
1613                .into_par_iter()
1614                .chunks(CHUNK)
1615                .map(|idxs| {
1616                    let mut rhs_part = Array1::<f64>::zeros(k);
1617                    let mut s_part = Array2::<f64>::zeros((k, k));
1618                    for i in idxs {
1619                        row_into(i, &mut rhs_part, &mut s_part)?;
1620                    }
1621                    Ok::<_, ArrowSchurError>((rhs_part, s_part))
1622                })
1623                .collect::<Result<Vec<_>, _>>()?;
1624            // Deterministic ordered reduction: fold chunk partials left-to-right.
1625            // `block_gemm_subtract` already subtracted into each `s_part`, so the
1626            // partials carry the negative Schur contribution; add them in.
1627            for (rhs_part, s_part) in &partials {
1628                for a in 0..k {
1629                    self.rhs_acc[a] += rhs_part[a];
1630                }
1631                self.s_acc += s_part;
1632            }
1633        } else {
1634            // Serial path accumulates DIRECTLY into the running `self.{rhs,s}_acc`
1635            // (which carry the `reset_accumulator` seed `H_ββ + ridge·I`), exactly
1636            // as before — bit-for-bit unchanged for the handful-of-rows callers.
1637            for row_idx in start..end {
1638                let row = (self.row_builder)(row_idx)?;
1639                let di = row.htt.nrows();
1640                self.validate_row(row_idx, &row)?;
1641                let htbeta = self.row_htbeta(row_idx, &row, di);
1642                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1643                let v = backend.solve_block_vector(factor.view(), row.gt.view());
1644                for c in 0..di {
1645                    let vc = v[c];
1646                    if vc == 0.0 {
1647                        continue;
1648                    }
1649                    for a in 0..k {
1650                        self.rhs_acc[a] += htbeta[[c, a]] * vc;
1651                    }
1652                }
1653                match mode {
1654                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1655                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1656                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1657                    }
1658                    ArrowSolverMode::SqrtBA => {
1659                        let whitened =
1660                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1661                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1662                    }
1663                }
1664            }
1665        }
1666        Ok(())
1667    }
1668
1669    /// Compute the exact arrow Hessian log-determinant by accumulating the
1670    /// reduced Schur complement in row chunks, without retaining the full set
1671    /// of per-row Cholesky factors.
1672    ///
1673    /// This is the streaming analogue of [`ArrowFactorCache::arrow_log_det`]:
1674    ///
1675    /// ```text
1676    /// log|H| = Σ_i log|H_tt^(i)| + log|H_ββ - Σ_i H_βt^(i) H_tt^(i)⁻¹ H_tβ^(i)|.
1677    /// ```
1678    ///
1679    /// The same row builder and procedural `H_tβ` callbacks used by the
1680    /// streaming Newton solve are consumed here, so callers can score REML
1681    /// evidence without materialising the full `(N × q × K)` cross block or
1682    /// the full list of row factors.
1683    pub fn reduced_schur_and_log_det_tt(
1684        &mut self,
1685        ridge_t: f64,
1686        ridge_beta: f64,
1687        options: &ArrowSolveOptions,
1688    ) -> Result<(f64, Array2<f64>), ArrowSchurError> {
1689        if self.ibp_cross_row_active {
1690            return Err(ArrowSchurError::SchurFactorFailed {
1691                reason: "streaming arrow log-det cannot carry the exact cross-row IBP \
1692                         Woodbury correction (#1038): U's columns span all rows, so the \
1693                         rank-R capacitance needs the per-row factors retained — the very \
1694                         (N·K) residency the streaming path avoids. Route IBP-active fits \
1695                         through the dense resident ArrowFactorCache::arrow_log_det instead."
1696                    .to_string(),
1697            });
1698        }
1699        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1700        self.reset_accumulator(ridge_beta)?;
1701        let backend = CpuBatchedBlockSolver;
1702        let mut log_det_tt = 0.0_f64;
1703        for start in (0..self.n_rows).step_by(self.chunk_size) {
1704            let end = (start + self.chunk_size).min(self.n_rows);
1705            for row_idx in start..end {
1706                let row = (self.row_builder)(row_idx)?;
1707                let di = row.htt.nrows();
1708                self.validate_row(row_idx, &row)?;
1709                let htbeta = self.row_htbeta(row_idx, &row, di);
1710                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1711                for axis in 0..di {
1712                    log_det_tt += 2.0 * factor[[axis, axis]].ln();
1713                }
1714                match options.mode {
1715                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1716                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1717                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1718                    }
1719                    ArrowSolverMode::SqrtBA => {
1720                        let whitened =
1721                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1722                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1723                    }
1724                }
1725            }
1726        }
1727        symmetrize_upper_from_lower(&mut self.s_acc);
1728        let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1729        Ok((log_det_tt, schur))
1730    }
1731
1732    /// As [`Self::reduced_schur_and_log_det_tt`], but ALSO accumulates the exact
1733    /// cross-row IBP Woodbury building blocks (#1038) when this streaming system
1734    /// carried an [`IbpCrossRowSource`].
1735    ///
1736    /// Mirrors the dense `factor_blocks_for_system` + [`CrossRowWoodbury::build`]
1737    /// path exactly:
1738    ///
1739    /// * the per-row logit self term `Σ_k d_k·z'_ik²`
1740    ///   ([`IbpCrossRowSource::self_term_downdate`]) is subtracted from each
1741    ///   `H_tt^(i)` BEFORE factoring, so the factored base — and therefore the
1742    ///   returned `log_det_tt`/Schur — is the NO-SELF `H₀'` (the dense path
1743    ///   factors `H₀'` too, then layers the rank-`R` Woodbury);
1744    /// * for each row it forms `Aᵢ⁻¹ Uᵢ` (one block solve against the row's
1745    ///   `H₀'` factor, `Uᵢ` the J-weighted atom-column indicator) and adds the
1746    ///   row's contributions to `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
1747    ///   `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`, `Bᵢ = H_tβ^(i)`).
1748    ///
1749    /// The caller sums `M0`/`W`/`log_det_tt`/Schur over chunks and closes the
1750    /// capacitance `M = M0 + Wᵀ S⁻¹ W`, `log det(I_R + D·M)` against the GLOBAL
1751    /// reduced Schur `S` via [`streaming_cross_row_woodbury_log_det`]. This is
1752    /// the exact `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`
1753    /// the dense [`ArrowFactorCache::arrow_log_det`] returns — the streaming and
1754    /// dense evidence then optimize the SAME REML objective (#1225).
1755    ///
1756    /// When no IBP source is present this delegates to
1757    /// [`Self::reduced_schur_and_log_det_tt`] and returns `None` woodbury, so
1758    /// every non-IBP (softmax / JumpReLU) caller is bit-for-bit unchanged.
1759    pub fn reduced_schur_log_det_tt_woodbury(
1760        &mut self,
1761        ridge_t: f64,
1762        ridge_beta: f64,
1763        options: &ArrowSolveOptions,
1764    ) -> Result<(f64, Array2<f64>, Option<StreamingWoodburyChunk>), ArrowSchurError> {
1765        let Some(source) = self.ibp_cross_row.clone() else {
1766            // Temporarily clear the refusal flag so the shared bare path runs;
1767            // it is `false` whenever the source is absent, so this is a no-op.
1768            let (log_det_tt, schur) =
1769                self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1770            return Ok((log_det_tt, schur, None));
1771        };
1772        let r = source.r;
1773        let total_len = self.row_offsets[self.n_rows];
1774        let down = source.self_term_downdate(total_len);
1775        // Group the sparse `U` entries `(global_t_index, atom, z'_ik)` by row as
1776        // `(local_slot, atom, z)`. `row_offsets` is strictly ascending, so the
1777        // owning row of a global index is located by binary search.
1778        let mut row_entries: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); self.n_rows];
1779        for &(g, atom, z) in &source.entries {
1780            let i = match self.row_offsets.binary_search(&g) {
1781                Ok(idx) => idx,
1782                Err(idx) => idx - 1,
1783            };
1784            let slot = g - self.row_offsets[i];
1785            row_entries[i].push((slot, atom, z));
1786        }
1787        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1788        self.reset_accumulator(ridge_beta)?;
1789        let backend = CpuBatchedBlockSolver;
1790        let mut log_det_tt = 0.0_f64;
1791        let mut m0 = Array2::<f64>::zeros((r, r));
1792        let mut w = Array2::<f64>::zeros((self.k, r));
1793        for start in (0..self.n_rows).step_by(self.chunk_size) {
1794            let end = (start + self.chunk_size).min(self.n_rows);
1795            for row_idx in start..end {
1796                let mut row = (self.row_builder)(row_idx)?;
1797                let di = row.htt.nrows();
1798                self.validate_row(row_idx, &row)?;
1799                // Downdate the per-row logit self term so the factored base is
1800                // `H₀'` (the dense path downdates the SAME `down[base + j]`).
1801                let base = self.row_offsets[row_idx];
1802                for j in 0..di {
1803                    row.htt[[j, j]] -= down[base + j];
1804                }
1805                let htbeta = self.row_htbeta(row_idx, &row, di);
1806                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1807                for axis in 0..di {
1808                    log_det_tt += 2.0 * factor[[axis, axis]].ln();
1809                }
1810                match options.mode {
1811                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1812                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1813                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1814                    }
1815                    ArrowSolverMode::SqrtBA => {
1816                        let whitened =
1817                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1818                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1819                    }
1820                }
1821                let entries = &row_entries[row_idx];
1822                if !entries.is_empty() {
1823                    // `Uᵢ`: `di × R`, column `atom` carries `z'_ik` at its logit
1824                    // slot. `Aᵢ⁻¹ Uᵢ` via the row's `H₀'` Cholesky.
1825                    let mut u_local = Array2::<f64>::zeros((di, r));
1826                    for &(slot, atom, z) in entries {
1827                        u_local[[slot, atom]] += z;
1828                    }
1829                    let ainv_u = backend.solve_block_matrix(factor.view(), u_local.view());
1830                    // `M0 += Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`); `W += Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`).
1831                    m0 += &u_local.t().dot(&ainv_u);
1832                    w += &htbeta.t().dot(&ainv_u);
1833                }
1834            }
1835        }
1836        symmetrize_upper_from_lower(&mut self.s_acc);
1837        let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1838        Ok((
1839            log_det_tt,
1840            schur,
1841            Some(StreamingWoodburyChunk {
1842                m0,
1843                w,
1844                d: source.d.clone(),
1845            }),
1846        ))
1847    }
1848
1849    pub fn reduced_schur_log_det(
1850        schur: &Array2<f64>,
1851        options: &ArrowSolveOptions,
1852    ) -> Result<f64, ArrowSchurError> {
1853        let rhs = Array1::<f64>::zeros(schur.nrows());
1854        let trust_metric_weights = None;
1855        let (delta, schur_factor, diag) =
1856            solve_dense_reduced_system(schur, &rhs, options, trust_metric_weights)?;
1857        if delta.len() != schur.nrows() || diag.iterations != 0 {
1858            return Err(ArrowSchurError::SchurFactorFailed {
1859                reason: "streaming log-det reduced solve returned incoherent diagnostics"
1860                    .to_string(),
1861            });
1862        }
1863        let schur_factor = schur_factor.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1864            reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
1865        })?;
1866        let mut log_det_schur = 0.0_f64;
1867        for axis in 0..schur_factor.nrows() {
1868            log_det_schur += 2.0 * schur_factor[[axis, axis]].ln();
1869        }
1870        Ok(log_det_schur)
1871    }
1872
1873    pub fn exact_arrow_log_det(
1874        &mut self,
1875        ridge_t: f64,
1876        ridge_beta: f64,
1877        options: &ArrowSolveOptions,
1878    ) -> Result<f64, ArrowSchurError> {
1879        let (log_det_tt, schur) =
1880            self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1881        Ok(log_det_tt + Self::reduced_schur_log_det(&schur, options)?)
1882    }
1883
1884    pub fn solve(
1885        &mut self,
1886        ridge_t: f64,
1887        ridge_beta: f64,
1888        options: &ArrowSolveOptions,
1889    ) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
1890        if self.ibp_cross_row_active {
1891            return Err(ArrowSchurError::SchurFactorFailed {
1892                reason: "streaming arrow solve cannot carry the exact cross-row IBP \
1893                         Woodbury correction (#1038); route IBP-active fits through the \
1894                         dense resident solve_arrow_newton_step_with_options instead."
1895                    .to_string(),
1896            });
1897        }
1898        // Propagate the evidence/log-det ill-conditioning tolerance to the
1899        // per-row factor calls inside `accumulate_chunk` / `back_substitute`,
1900        // which take their stable public signatures. Direct callers of
1901        // `accumulate_chunk` keep the conservative default (`false`, full guard).
1902        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1903        self.reset_accumulator(ridge_beta)?;
1904        for start in (0..self.n_rows).step_by(self.chunk_size) {
1905            let end = (start + self.chunk_size).min(self.n_rows);
1906            self.accumulate_chunk(start, end, ridge_t, options.mode)?;
1907        }
1908        for j in 0..self.k {
1909            self.rhs_acc[j] -= self.gb[j];
1910        }
1911        symmetrize_upper_from_lower(&mut self.s_acc);
1912        let trust_metric_weights = None;
1913        let (delta_beta, schur_factor, _diag) =
1914            solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
1915        let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
1916        Ok((delta_t, delta_beta, schur_factor))
1917    }
1918
1919    pub(crate) fn back_substitute(
1920        &self,
1921        ridge_t: f64,
1922        delta_beta: ArrayView1<'_, f64>,
1923    ) -> Result<Array1<f64>, ArrowSchurError> {
1924        let backend = CpuBatchedBlockSolver;
1925        // Total delta_t length = row_offsets[n_rows].
1926        let total_len = self.row_offsets[self.n_rows];
1927        let mut delta_t = Array1::<f64>::zeros(total_len);
1928        // Each row's back-solve `Δt_i = -(H_tt^(i))⁻¹(g_t^(i) + H_tβ^(i)Δβ)`
1929        // writes a DISJOINT segment `delta_t[row_base .. row_base+di]` — no
1930        // cross-row reduction, so this is embarrassingly parallel and the scatter
1931        // is bit-identical regardless of which thread produced each segment (the
1932        // #1017 verification gate). At the SAE LLM shape (`n` in the thousands)
1933        // the per-row factor + solve is the whole cost; below the threshold, or
1934        // when already inside a rayon worker (the topology race fans candidates
1935        // with `run_topology_race_parallel`), stay sequential to avoid
1936        // nested-rayon oversubscription — the same guard `schur_matvec` uses.
1937        let parallel =
1938            self.n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1939        if parallel {
1940            use rayon::prelude::*;
1941            const CHUNK: usize = 64;
1942            // Per-row body: factor, form the RHS, solve, return `-(dt_i)`.
1943            let row_solve = |row_idx: usize| -> Result<(usize, Array1<f64>), ArrowSchurError> {
1944                let row = (self.row_builder)(row_idx)?;
1945                let di = row.htt.nrows();
1946                self.validate_row(row_idx, &row)?;
1947                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1948                let mut htbeta_delta = Array1::<f64>::zeros(di);
1949                if let Some(op) = self.htbeta_matvec.as_ref() {
1950                    op(row_idx, delta_beta, &mut htbeta_delta);
1951                } else {
1952                    for c in 0..di {
1953                        let mut acc = 0.0_f64;
1954                        for a in 0..self.k {
1955                            acc += row.htbeta[[c, a]] * delta_beta[a];
1956                        }
1957                        htbeta_delta[c] = acc;
1958                    }
1959                }
1960                let mut rhs = Array1::<f64>::zeros(di);
1961                for c in 0..di {
1962                    rhs[c] = row.gt[c] + htbeta_delta[c];
1963                }
1964                let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
1965                let mut neg = Array1::<f64>::zeros(di);
1966                for c in 0..di {
1967                    neg[c] = -dt_i[c];
1968                }
1969                Ok((self.row_offsets[row_idx], neg))
1970            };
1971            // Collect per-row segments under rayon, then scatter into the disjoint
1972            // slices. Errors are surfaced via `collect::<Result<…>>`.
1973            let segments: Vec<(usize, Array1<f64>)> = (0..self.n_rows)
1974                .into_par_iter()
1975                .chunks(CHUNK)
1976                .map(|idxs| {
1977                    idxs.into_iter()
1978                        .map(&row_solve)
1979                        .collect::<Result<Vec<_>, _>>()
1980                })
1981                .collect::<Result<Vec<_>, _>>()?
1982                .into_iter()
1983                .flatten()
1984                .collect();
1985            for (base, seg) in &segments {
1986                for (c, &v) in seg.iter().enumerate() {
1987                    delta_t[base + c] = v;
1988                }
1989            }
1990        } else {
1991            let mut rhs = Array1::<f64>::zeros(self.d);
1992            for start in (0..self.n_rows).step_by(self.chunk_size) {
1993                let end = (start + self.chunk_size).min(self.n_rows);
1994                for row_idx in start..end {
1995                    let row = (self.row_builder)(row_idx)?;
1996                    let di = row.htt.nrows();
1997                    self.validate_row(row_idx, &row)?;
1998                    let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1999                    // `H_tβ^(i) Δβ`: route through the procedural operator when
2000                    // present (no dense slab), else through the dense slab.
2001                    let mut htbeta_delta = Array1::<f64>::zeros(di);
2002                    if let Some(op) = self.htbeta_matvec.as_ref() {
2003                        op(row_idx, delta_beta, &mut htbeta_delta);
2004                    } else {
2005                        for c in 0..di {
2006                            let mut acc = 0.0_f64;
2007                            for a in 0..self.k {
2008                                acc += row.htbeta[[c, a]] * delta_beta[a];
2009                            }
2010                            htbeta_delta[c] = acc;
2011                        }
2012                    }
2013                    for c in 0..di {
2014                        rhs[c] = row.gt[c] + htbeta_delta[c];
2015                    }
2016                    let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
2017                    let row_base = self.row_offsets[row_idx];
2018                    for c in 0..di {
2019                        delta_t[row_base + c] = -dt_i[c];
2020                    }
2021                }
2022            }
2023        }
2024        Ok(delta_t)
2025    }
2026
2027    pub(crate) fn validate_row(
2028        &self,
2029        row_idx: usize,
2030        row: &ArrowRowBlock,
2031    ) -> Result<(), ArrowSchurError> {
2032        let expected_di = if row_idx < self.row_dims.len() {
2033            self.row_dims[row_idx]
2034        } else {
2035            self.d
2036        };
2037        let actual_di = row.htt.nrows();
2038        if actual_di != expected_di || row.htt.ncols() != expected_di {
2039            return Err(ArrowSchurError::PerRowFactorFailed {
2040                row: row_idx,
2041                reason: format!(
2042                    "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
2043                    row.htt.dim(),
2044                ),
2045            });
2046        }
2047        // The dense `H_tβ` slab is only validated when no procedural operator is
2048        // installed; with `htbeta_matvec` the slab is intentionally zero-sized
2049        // and the cross-block is probed in `row_htbeta`.
2050        if self.htbeta_matvec.is_none() && row.htbeta.dim() != (expected_di, self.k) {
2051            return Err(ArrowSchurError::SchurFactorFailed {
2052                reason: format!(
2053                    "streaming row H_tβ shape {:?} != ({expected_di}, {})",
2054                    row.htbeta.dim(),
2055                    self.k
2056                ),
2057            });
2058        }
2059        if row.gt.len() != expected_di {
2060            return Err(ArrowSchurError::PerRowFactorFailed {
2061                row: row_idx,
2062                reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
2063            });
2064        }
2065        Ok::<(), _>(())
2066    }
2067}
2068
2069pub(crate) fn apply_analytic_penalty<S, G, D, P, H>(
2070    penalty: &AnalyticPenaltyKind,
2071    target: ArrayView1<'_, f64>,
2072    rho_local: ArrayView1<'_, f64>,
2073    expected_target_len: usize,
2074    hvp_columns: usize,
2075    scatter_target: &mut S,
2076    mut grad_scatter: G,
2077    mut diag_scatter: D,
2078    seed_hvp_probe: P,
2079    mut hvp_column_scatter: H,
2080) where
2081    G: FnMut(&mut S, usize, f64),
2082    D: FnMut(&mut S, usize, f64),
2083    P: Fn(usize, &mut Array1<f64>),
2084    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
2085{
2086    assert_eq!(target.len(), expected_target_len);
2087
2088    let grad = penalty.grad_target(target, rho_local);
2089    for index in 0..expected_target_len {
2090        grad_scatter(scatter_target, index, grad[index]);
2091    }
2092
2093    // The scattered curvature lands in the arrow-Schur `H_tt` / `H_ββ` blocks,
2094    // which are Cholesky-factored (with LM ridge escalation) as the Newton /
2095    // PIRLS curvature operator and must therefore stay PSD. Nonconvex
2096    // sparsifiers (log sparsity, JumpReLU) have an *indefinite* exact Hessian
2097    // that would destroy that positive-definiteness, so we scatter the PSD
2098    // majorizer here — never the exact `hessian_diag` / `hvp`. For convex
2099    // penalties the majorizer equals the exact Hessian (the trait default
2100    // delegates), so this is exact for them. Exact-derivative consumers (the
2101    // outer objective Hessian) use `hessian_diag` / `hvp` directly elsewhere.
2102    if let Some(diag) = penalty.psd_majorizer_diag(target, rho_local) {
2103        assert_eq!(diag.len(), expected_target_len);
2104        for index in 0..expected_target_len {
2105            diag_scatter(scatter_target, index, diag[index]);
2106        }
2107        return;
2108    }
2109
2110    let mut probe = Array1::<f64>::zeros(expected_target_len);
2111    for column in 0..hvp_columns {
2112        probe.fill(0.0);
2113        seed_hvp_probe(column, &mut probe);
2114        let hv = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
2115        hvp_column_scatter(scatter_target, column, hv.view());
2116    }
2117}
2118
/// Thin forwarding wrapper that returns `penalty.is_row_block_diagonal()`.
///
/// NOTE(review): judging by the name, this presumably reports whether the
/// penalty's curvature stays within individual row blocks (no cross-row
/// coupling). Confirm against `AnalyticPenaltyKind::is_row_block_diagonal`.
pub(crate) fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
    penalty.is_row_block_diagonal()
}
2122
/// Flat, reference-counted storage for a sequence of square per-row factor
/// blocks, such as the per-row `H_tt` Cholesky factors held by
/// [`ArrowFactorCache`].
///
/// All block values live in one contiguous `data` buffer. Block `i` occupies
/// `data[offsets[i] .. offsets[i + 1]]`, stored row-major, and has shape
/// `(dims[i], dims[i])`. Every field is an `Arc`, so cloning a slab bumps
/// reference counts rather than copying values. The cache's factors are what
/// the IFT warm-start predictor in `crate::persistent_warm_start` re-uses
/// against a refreshed RHS.
#[derive(Clone)]
pub struct ArrowFactorSlab {
    /// Concatenated row-major values of every block.
    pub(crate) data: Arc<[f64]>,
    /// Start of each block in `data`; has `len() + 1` entries, the last being
    /// `data.len()`.
    pub(crate) offsets: Arc<[usize]>,
    /// Side length of each square block.
    pub(crate) dims: Arc<[usize]>,
}
2135
2136impl ArrowFactorSlab {
2137    pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
2138        let mut data = Vec::new();
2139        let mut offsets = Vec::with_capacity(blocks.len() + 1);
2140        let mut dims = Vec::with_capacity(blocks.len());
2141        offsets.push(0);
2142        for block in blocks {
2143            let (rows, cols) = block.dim();
2144            assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
2145            dims.push(rows);
2146            data.extend(block.iter().copied());
2147            offsets.push(data.len());
2148        }
2149        Self {
2150            data: data.into(),
2151            offsets: offsets.into(),
2152            dims: dims.into(),
2153        }
2154    }
2155
2156    pub fn len(&self) -> usize {
2157        self.dims.len()
2158    }
2159
2160    pub fn is_empty(&self) -> bool {
2161        self.dims.is_empty()
2162    }
2163
2164    pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
2165        let dim = self.dims[row];
2166        let range = self.offsets[row]..self.offsets[row + 1];
2167        ArrayView2::from_shape((dim, dim), &self.data[range])
2168            .expect("ArrowFactorSlab row offset/dim invariant violated")
2169    }
2170
2171    pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
2172        (0..self.len()).map(|row| self.factor(row))
2173    }
2174}
2175
2176impl std::fmt::Debug for ArrowFactorSlab {
2177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2178        f.debug_struct("ArrowFactorSlab")
2179            .field("rows", &self.len())
2180            .field("values", &self.data.len())
2181            .finish()
2182    }
2183}
2184
/// Storage for the undamped per-row `H_tt` factors kept next to the damped
/// ones in [`ArrowFactorCache::htt_factors_undamped`].
#[derive(Clone)]
pub enum ArrowUndampedFactors {
    /// No separate copy is stored.
    /// NOTE(review): presumably used when the undamped factors coincide with
    /// the damped `htt_factors` (e.g. `ridge_t == 0`). Confirm at the
    /// construction site.
    SameAsDamped,
    /// A separately factored slab of undamped per-row factors.
    Owned(ArrowFactorSlab),
}
2190
2191impl std::fmt::Debug for ArrowUndampedFactors {
2192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2193        match self {
2194            Self::SameAsDamped => f.write_str("SameAsDamped"),
2195            Self::Owned(factors) => f.debug_tuple("Owned").field(&factors.len()).finish(),
2196        }
2197    }
2198}
2199
2200/// Apply `H_tβ^(row) · x` for one row, writing into `out` (length `d`).
2201///
2202/// Sums the installed matrix-free operator, when present, and any correctly
2203/// shaped dense `row.htbeta` slab. This lets structured data-fit rows coexist
2204/// with dense analytic-penalty cross blocks on the same row.
2205pub(crate) fn sys_htbeta_apply_row(
2206    sys: &ArrowSchurSystem,
2207    row_idx: usize,
2208    row: &ArrowRowBlock,
2209    x: ArrayView1<'_, f64>,
2210    out: &mut Array1<f64>,
2211) {
2212    out.fill(0.0);
2213    if let Some(op) = sys.htbeta_matvec.as_ref() {
2214        op(row_idx, x, out);
2215    }
2216    if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2217        && row.htbeta.dim() == (out.len(), sys.k)
2218    {
2219        let di = row.htbeta.nrows();
2220        for c in 0..di {
2221            let mut acc = 0.0_f64;
2222            for a in 0..sys.k {
2223                acc += row.htbeta[[c, a]] * x[a];
2224            }
2225            out[c] += acc;
2226        }
2227    }
2228}
2229
2230/// Accumulate `H_βt^(row) · v` into `out` (length `k`).
2231///
2232/// `out[a] += Σ_c H_tβ^(row)[c, a] · v[c]`
2233///
2234/// Sums the installed matrix-free operator, when present, and any correctly
2235/// shaped dense `row.htbeta` slab.
2236pub(crate) fn sys_htbeta_accumulate_transpose(
2237    sys: &ArrowSchurSystem,
2238    row_idx: usize,
2239    row: &ArrowRowBlock,
2240    v: ArrayView1<'_, f64>,
2241    out: &mut Array1<f64>,
2242) {
2243    if let Some(op) = sys.htbeta_matvec.as_ref() {
2244        htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
2245    }
2246    if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2247        && row.htbeta.dim() == (v.len(), sys.k)
2248    {
2249        let di = row.htbeta.nrows();
2250        for c in 0..di {
2251            let vc = v[c];
2252            if vc == 0.0 {
2253                continue;
2254            }
2255            for a in 0..sys.k {
2256                out[a] += row.htbeta[[c, a]] * vc;
2257            }
2258        }
2259    }
2260}
2261
2262/// Materialize the dense `(di, k)` cross-block for one row.
2263///
2264/// Materializes the sum of the installed matrix-free operator and any correctly
2265/// shaped dense slab on the row.
2266pub(crate) fn sys_htbeta_materialize_row(
2267    sys: &ArrowSchurSystem,
2268    row_idx: usize,
2269    row: &ArrowRowBlock,
2270) -> Result<Array2<f64>, ArrowSchurError> {
2271    let di = sys.row_dims[row_idx];
2272    let k = sys.k;
2273    let use_dense = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
2274    let mut mat = if use_dense && row.htbeta.dim() == (di, k) {
2275        row.htbeta.clone()
2276    } else {
2277        Array2::<f64>::zeros((di, k))
2278    };
2279    if let Some(op) = sys.htbeta_matvec.as_ref() {
2280        let mut e_a = Array1::<f64>::zeros(k);
2281        let mut col = Array1::<f64>::zeros(di);
2282        for a in 0..k {
2283            e_a.fill(0.0);
2284            e_a[a] = 1.0;
2285            col.fill(0.0);
2286            op(row_idx, e_a.view(), &mut col);
2287            for c in 0..di {
2288                mat[[c, a]] += col[c];
2289            }
2290        }
2291    } else if use_dense && row.htbeta.dim() != (di, k) {
2292        return Err(ArrowSchurError::SchurFactorFailed {
2293            reason: format!(
2294                "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
2295                row.htbeta.dim()
2296            ),
2297        });
2298    }
2299    Ok(mat)
2300}
2301
2302/// Probe each column of `H_tβ^(row)` by applying the operator to `e_a` and
2303/// dotting the result with `v`.  Accumulates into `out[a]` for all `a in 0..k`.
2304///
2305/// `out[a] += (H_tβ^(row) e_a) · v = H_βt^(row)[a, :] · v`
2306pub(crate) fn htbeta_probe_transpose(
2307    row: usize,
2308    op: &RowHtbetaMatvec,
2309    v: ArrayView1<'_, f64>,
2310    out: &mut Array1<f64>,
2311    d: usize,
2312    k: usize,
2313) {
2314    let mut e_a = Array1::<f64>::zeros(k);
2315    let mut col_a = Array1::<f64>::zeros(d);
2316    for a in 0..k {
2317        e_a.fill(0.0);
2318        e_a[a] = 1.0;
2319        col_a.fill(0.0);
2320        op(row, e_a.view(), &mut col_a);
2321        let mut acc = 0.0_f64;
2322        for c in 0..d {
2323            acc += col_a[c] * v[c];
2324        }
2325        out[a] += acc;
2326    }
2327}
2328
/// Per-row access to the `H_tβ^(i)` cross-blocks retained by an
/// [`ArrowFactorCache`] (its `htbeta` field).
#[derive(Clone)]
pub enum ArrowHtbetaCache {
    /// One dense `H_tβ^(i)` block per row, indexed by row.
    Dense {
        blocks: Arc<[Array2<f64>]>,
        /// NOTE(review): presumably the memory estimate behind the choice of
        /// variant. Confirm at the construction site.
        estimated_bytes: usize,
    },
    /// A row matvec callback, applied on demand instead of storing dense blocks.
    Matvec {
        op: RowHtbetaMatvec,
        /// NOTE(review): presumably the memory estimate behind the choice of
        /// variant. Confirm at the construction site.
        estimated_bytes: usize,
    },
    /// No cross-block access is retained; the forward `apply_row` reports `false`.
    Disabled {
        /// NOTE(review): presumably the memory estimate behind the choice of
        /// variant. Confirm at the construction site.
        estimated_bytes: usize,
    },
}
2343
2344impl std::fmt::Debug for ArrowHtbetaCache {
2345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2346        match self {
2347            Self::Dense {
2348                blocks,
2349                estimated_bytes,
2350            } => f
2351                .debug_struct("Dense")
2352                .field("blocks", &blocks.len())
2353                .field("estimated_bytes", estimated_bytes)
2354                .finish(),
2355            Self::Matvec {
2356                estimated_bytes, ..
2357            } => f
2358                .debug_struct("Matvec")
2359                .field("estimated_bytes", estimated_bytes)
2360                .finish(),
2361            Self::Disabled { estimated_bytes } => f
2362                .debug_struct("Disabled")
2363                .field("estimated_bytes", estimated_bytes)
2364                .finish(),
2365        }
2366    }
2367}
2368
2369impl ArrowHtbetaCache {
2370    pub(crate) fn is_available(&self) -> bool {
2371        !matches!(self, Self::Disabled { .. })
2372    }
2373
2374    pub(crate) fn apply_row(
2375        &self,
2376        row: usize,
2377        delta_beta: ArrayView1<'_, f64>,
2378        out: &mut Array1<f64>,
2379    ) -> bool {
2380        match self {
2381            Self::Dense { blocks, .. } => {
2382                let Some(block) = blocks.get(row) else {
2383                    return false;
2384                };
2385                if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
2386                    return false;
2387                }
2388                for c in 0..block.nrows() {
2389                    let mut acc = 0.0_f64;
2390                    for a in 0..block.ncols() {
2391                        acc += block[[c, a]] * delta_beta[a];
2392                    }
2393                    out[c] = acc;
2394                }
2395                true
2396            }
2397            Self::Matvec { op, .. } => {
2398                op(row, delta_beta, out);
2399                true
2400            }
2401            Self::Disabled { .. } => false,
2402        }
2403    }
2404
2405    /// Apply the transpose: `out[a] += H_βt^(row)[a, c] · v[c]` for all `a`.
2406    ///
2407    /// `v` has length `d`; `out` has length `k`. Accumulates (does NOT zero
2408    /// `out` first) so callers can sum contributions across rows into a shared
2409    /// accumulator.  Returns `false` when the cache is `Disabled` and no
2410    /// `fallback_op` is provided.
2411    pub(crate) fn apply_row_transpose_accumulate(
2412        &self,
2413        row: usize,
2414        v: ArrayView1<'_, f64>,
2415        out: &mut Array1<f64>,
2416        d: usize,
2417        k: usize,
2418        fallback_op: Option<&RowHtbetaMatvec>,
2419    ) -> bool {
2420        match self {
2421            Self::Dense { blocks, .. } => {
2422                let Some(block) = blocks.get(row) else {
2423                    return false;
2424                };
2425                if block.nrows() != v.len() || block.ncols() != out.len() {
2426                    return false;
2427                }
2428                // H_βt^(i) · v: outer-loop c hoists v[c], inner-loop a is
2429                // contiguous in row-major (d, k) layout.
2430                for c in 0..block.nrows() {
2431                    let vc = v[c];
2432                    if vc == 0.0 {
2433                        continue;
2434                    }
2435                    for a in 0..block.ncols() {
2436                        out[a] += block[[c, a]] * vc;
2437                    }
2438                }
2439                true
2440            }
2441            Self::Matvec { op, .. } => {
2442                // Probe column-by-column: H_tβ^(row) e_a is column a.  dot(col_a, v)
2443                // is entry a of H_βt^(row) v.
2444                htbeta_probe_transpose(row, op, v, out, d, k);
2445                true
2446            }
2447            Self::Disabled { .. } => {
2448                // No cached block.  Use the caller-supplied fallback op if present.
2449                if let Some(op) = fallback_op {
2450                    htbeta_probe_transpose(row, op, v, out, d, k);
2451                    true
2452                } else {
2453                    false
2454                }
2455            }
2456        }
2457    }
2458}
2459
/// RAW per-row spectral data of a spectrally-deflated undamped evidence `H_tt`
/// block (see [`ArrowFactorCache::deflation_row_spectra`]).
///
/// - `evecs`: its columns are the RAW symmetric eigenvectors `uₘ` of `H_tt`.
///   They are orthonormal, and the deflated directions `vᵢ` are the subset
///   whose eigenvalue was pinned.
/// - `raw_evals[m]`: the RAW eigenvalue `λₘ` BEFORE the unit-pin / floor-clamp.
/// - `cond_evals[m]`: the conditioned eigenvalue `λ̃ₘ` the factor actually uses.
///   That is `λ̃ = λ` for an unclamped kept direction, the positive `floor` for
///   a clamped kept direction, and `1` for a deflated direction.
///
/// Together these give the Daleckii–Krein divided differences that the
/// outer-gradient deflation correction needs.
#[derive(Debug, Clone)]
pub struct RowDeflationSpectrum {
    /// RAW orthonormal eigenvectors `uₘ` of `H_tt`, one per column.
    pub evecs: Array2<f64>,
    /// RAW eigenvalues `λₘ`, before the unit-pin / floor-clamp.
    pub raw_evals: Array1<f64>,
    /// Conditioned eigenvalues `λ̃ₘ` actually used by the factor.
    pub cond_evals: Array1<f64>,
}
2477
/// Per-row + Schur factor cache from one bordered arrow-Schur Newton solve.
///
/// It holds:
/// - the damped per-row `H_tt` Cholesky factors, plus their undamped
///   counterparts;
/// - the optional dense Schur factor;
/// - the ridges the factors were built with;
/// - per-row `H_tβ` access;
/// - the row layout (`row_dims`, `row_offsets`) needed to re-use all of it.
///
/// It is produced by the arrow Newton solve
/// (`solve_arrow_newton_step_with_options`) and consumed by the IFT
/// warm-start predictor in `crate::persistent_warm_start`. When the outer loop
/// perturbs `(β, ρ)` slightly, that predictor re-uses these factors against a
/// refreshed RHS and so skips the dominant `O(N d³ + K³)` factorization.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
    /// Per-row lower-triangular Cholesky factors of `H_tt^(i) + ridge_t·I`.
    ///
    /// These are the *damped* factors used inside the Newton solve. The IFT
    /// predictor must NOT use them — see [`Self::htt_factors_undamped`].
    pub htt_factors: ArrowFactorSlab,
    /// Per-row lower-triangular Cholesky factors of the UNDAMPED
    /// `H_tt^(i)` (no `ridge_t` added).
    ///
    /// The IFT predictor formula
    /// `Δt_i = -(H_tt^(i))⁻¹ · (H_tβ^(i) Δβ + δg_t^(i))` is derived from
    /// `∂g_t/∂t = H_tt` at the stationary point, with no LM damping term.
    /// Reusing the damped factors would bias the predicted shift toward zero
    /// in proportion to `ridge_t`. We pay one extra `O(N d³)` Cholesky per
    /// Newton solve — the same complexity class as the Newton solve itself —
    /// to make the IFT exact.
    pub htt_factors_undamped: ArrowUndampedFactors,
    /// Lower-triangular Cholesky factor of the Schur complement, present when
    /// the selected BA mode formed and factored the dense RCS. `None` for
    /// [`ArrowSolverMode::InexactPCG`], where Agarwal-style inexact LM avoids
    /// the dense `K × K` factor.
    pub schur_factor: Option<Array2<f64>>,
    /// Exact undamped joint-Hessian log-determinant from the dense
    /// factorization path. REML evidence consumes this directly, so the
    /// Laplace normalizer cannot miss the log-det even when later cache
    /// consumers only need solves/traces.
    pub joint_hessian_log_det: Option<f64>,
    /// BA mode used to create this cache.
    pub solver_mode: ArrowSolverMode,
    /// Ridge values used to build the cached factors. They are recorded so the
    /// warm-start predictor knows whether the cache is still valid for a
    /// requested ridge level.
    pub ridge_t: f64,
    /// Shared-block (`β`) ridge, recorded alongside `ridge_t` for the same
    /// validity check.
    pub ridge_beta: f64,
    /// Per-row cross-block access for `H_tβ^(i) x`.
    ///
    /// Large caches keep a row matvec callback, or disable β-coupled IFT
    /// prediction, instead of cloning every dense `d × K` slab.
    pub htbeta: ArrowHtbetaCache,
    /// Maximum per-row latent dim (an upper bound; matches `sys.d` at creation).
    pub d: usize,
    /// Per-row latent dims: `row_dims[i]` is the active dim for row `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for `delta_t` / IFT output vectors.
    /// `row_offsets[i]` is the start of row `i`; `row_offsets[n]` is the
    /// total length.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local factors and cross-blocks.
    pub manifold_mode_fingerprint: u64,
    /// Row-system tag for the cached per-row factors, cross-blocks, and the
    /// shared-block diagonal used to build the Schur factor.
    pub row_hessian_fingerprint: u64,
    /// PCG instrumentation from the solve that produced this cache.
    ///
    /// Zero-valued (the default) when the selected mode did not use PCG
    /// (i.e. `Direct` or `SqrtBA`).
    pub pcg_diagnostics: PcgDiagnostics,
    /// Number of row-local gauge directions stiffened in an undamped evidence
    /// factorization.
    ///
    /// Each direction is stiffened at UNIT stiffness `kappa = 1.0`, so it
    /// contributes `log(1) = 0` to the row-block logdet through the returned
    /// Cholesky factor. The gauge orbit is a criterion null direction and adds
    /// nothing to the Laplace normalizer (the quotient pseudo-determinant
    /// convention, cf. `PenaltyPseudologdet`). It has zero theta/rho dependence.
    pub gauge_deflated_directions: usize,
    /// Per-row unit-norm directions `vᵢ` (in each row's `d`-dim latent block
    /// coordinates) that an undamped evidence factorization stiffened to UNIT
    /// stiffness `λ̃ = 1` (gauge or spectral deflation). Indexed by row. An
    /// entry is empty for every PD row factored without deflation, and the
    /// whole list is empty on the non-deflating solver paths (streaming /
    /// cross-row-penalty CG / device).
    ///
    /// A deflated direction contributes `log(1) = 0` to the row-block log-det
    /// and is ρ/θ-INDEPENDENT, so its true contribution to `∂log|H|/∂ρ` is `0`.
    ///
    /// The analytic outer-gradient traces are
    /// `assignment_log_strength_hessian_trace`,
    /// `learnable_ibp_data_logdet_alpha_trace`, and `logdet_theta_adjoint`.
    /// They contract `∂H_raw/∂ρ` (the RAW, pre-deflation block derivative)
    /// against the DEFLATED inverse. That inverse assigns `1/λ̃ = 1` to each
    /// `vᵢ` and therefore spuriously adds `½ vᵢᵀ (∂H_raw/∂ρ) vᵢ`.
    ///
    /// Those traces subtract this per-row term (the kept-subspace restriction)
    /// using these directions. Without them the REML outer ρ-gradient is biased
    /// by `+Σ_deflated ½ vᵢᵀ ∂H_raw/∂ρ vᵢ`.
    pub deflated_row_directions: Arc<[Vec<Array1<f64>>]>,
    /// Per-row RAW spectral decomposition of each undamped evidence `H_tt`
    /// block that underwent SPECTRAL deflation. It is surfaced so the outer
    /// ρ/θ-gradient traces can apply the EXACT deflation-map (Daleckii–Krein)
    /// derivative correction, not just the within-row kept-subspace term.
    ///
    /// The criterion VALUE re-deflates `H_tt` at every ρ, so its gradient is
    /// `tr(H_deflated⁻¹ DΦ[∂H_raw/∂ρ])`, where `Φ` is the spectral
    /// pin-to-unit map. By Daleckii–Krein, `DΦ[Ȧ] = U (F ∘ UᵀȦU) Uᵀ` with the
    /// divided-difference matrix `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)` (raw `λ` in
    /// the denominator, conditioned `λ̃` in the numerator).
    ///
    /// The blocks of `F` are:
    /// - kept×kept: `1` (the kept subspace contracts the raw derivative
    ///   unchanged);
    /// - deflated×deflated: `0`;
    /// - kept(m)×deflated(i): `(λₘ − 1)/(λₘ − λᵢ)`.
    ///
    /// This last, ROTATION, term is what the per-row kept-subspace correction
    /// alone misses. It couples to the β-block through the Schur
    /// back-substitution carried in `(H⁻¹)_tt`.
    ///
    /// An entry is `Some(spectrum)` only for spectrally-deflated rows. It is
    /// `None` for:
    /// - PD rows;
    /// - gauge-only deflation (a ρ-independent structural null, where the
    ///   within-row term suffices);
    /// - every non-SAE-evidence solver path (streaming / device / cross-row CG).
    ///
    /// The whole list is empty when no row deflated spectrally.
    pub deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>,
    /// Exact cross-row IBP rank-`R` Woodbury correction (#1038). Present iff
    /// the source system carried an [`IbpCrossRowSource`].
    ///
    /// When set, the per-row factors above are of the NO-SELF base `H₀'`: the
    /// self term `d_k·z'_ik²` is downdated from each logit diagonal. This
    /// carrier supplies the exact rank-`R` correction, so that all of the
    /// following describe the same `H_full = H₀' + U D Uᵀ`:
    /// - the value/curvature solve ([`Self::full_inverse_apply`]);
    /// - the evidence log-determinant ([`Self::arrow_log_det`]);
    /// - the θ/ρ-adjoint.
    pub cross_row_woodbury: Option<CrossRowWoodbury>,
}
2594
/// Materialized exact cross-row IBP Woodbury correction (#1038), built against
/// an [`ArrowFactorCache`] whose per-row factors are the NO-SELF base `H₀'`.
///
/// It holds:
/// - `U`, the `delta_t_len × R` arrow-`t` factor (its β-part is implicitly
///   zero);
/// - `D = diag(d_k)`;
/// - the projected `M = UᵀH₀'⁻¹U`;
/// - the columns `H₀'⁻¹U`;
/// - the partial-pivot LU factorization of the capacitance `C = I_R + D·M`.
///
/// The capacitance is generally non-symmetric and possibly indefinite, because
/// `d_k = w·s'_k` is not sign-definite. A partial-pivot LU is exact for any
/// sign, so it is used instead of a symmetric factorization.
///
/// That one factorization serves three consumers:
/// - the log-determinant `log det C`;
/// - the inverse correction
///   `H_full⁻¹w = H₀'⁻¹w − H₀'⁻¹U·C⁻¹·(D Uᵀ H₀'⁻¹w)`;
/// - the adjoint's selected inverse (`C⁻¹` and `M`).
///
/// The full inverse, value/curvature solve, log-determinant, and adjoint
/// therefore all describe the SAME `H_full = H₀' + U D Uᵀ`.
#[derive(Debug, Clone)]
pub struct CrossRowWoodbury {
    /// `U`: `delta_t_len × R`; column `k` is supported on the atom-`k` logit slots.
    pub u: Array2<f64>,
    /// `d_k`, length `R` (the diagonal of `D`).
    pub d: Array1<f64>,
    /// `H₀'⁻¹ U` (the `t`-block), `delta_t_len × R`.
    pub h0inv_u: Array2<f64>,
    /// `(H₀'⁻¹ U)` β-block, `K × R`. `U` has no β support, but the bordered
    /// solve couples the latent columns to `β` through the Schur complement.
    /// This block is therefore generally nonzero, and the inverse correction
    /// must apply it to the `β` output too.
    pub h0inv_u_beta: Array2<f64>,
    /// `M = Uᵀ H₀'⁻¹ U`, `R × R` (symmetric). Retained for the θ/ρ-adjoint.
    pub m: Array2<f64>,
    /// Partial-pivot LU of the capacitance `C = I_R + D·M` (`lu` packs `L`/`U`,
    /// `piv` holds the row swaps), built by [`small_lu_factor`].
    pub capacitance_lu: SmallLu,
    /// The sparse `U` entries `(global_t_index, atom_k, z'_ik)`. They are
    /// retained so `Uᵀ·v` can be formed over the atom slots without
    /// re-deriving them.
    pub entries: Vec<(usize, usize, f64)>,
}
2630
/// Dense partial-pivot LU of a small square matrix.
///
/// It is used for the cross-row IBP capacitance `C = I_R + D·M`. That matrix
/// is generally non-symmetric and possibly indefinite (`d_k = w·s'_k` is not
/// sign-definite), so a Cholesky/LDLᵀ factorization is unavailable. `R` is the
/// atom count, so this dense factorization is cheap.
#[derive(Debug, Clone)]
pub struct SmallLu {
    /// Packed factors, `R × R`, in the row-permuted order encoded by `piv`:
    /// `L` (unit lower) below the diagonal, `U` (upper) on and above it. The
    /// unit diagonal of `L` is implicit.
    pub(crate) lu: Array2<f64>,
    /// Row permutation: `piv[i]` is the original row now in position `i`.
    pub(crate) piv: Vec<usize>,
    /// Sign of the permutation (`±1`), folded into the determinant.
    pub(crate) perm_sign: f64,
}
2645
2646/// Partial-pivot LU factorization of a small dense square matrix `a` (`R × R`).
2647/// Returns `None` only when a pivot is exactly zero (singular `C`).
2648pub(crate) fn small_lu_factor(a: &Array2<f64>) -> Option<SmallLu> {
2649    let r = a.nrows();
2650    assert_eq!(a.ncols(), r, "small_lu_factor: non-square input");
2651    let mut lu = a.clone();
2652    let mut piv: Vec<usize> = (0..r).collect();
2653    let mut perm_sign = 1.0_f64;
2654    for col in 0..r {
2655        // Partial pivot: pick the largest-magnitude entry on/below the diagonal.
2656        let mut pivot_row = col;
2657        let mut pivot_mag = lu[[col, col]].abs();
2658        for row in (col + 1)..r {
2659            let mag = lu[[row, col]].abs();
2660            if mag > pivot_mag {
2661                pivot_mag = mag;
2662                pivot_row = row;
2663            }
2664        }
2665        // Reject not just an exactly-zero pivot, but any non-finite or
2666        // subnormal magnitude: dividing by a subnormal in the elimination /
2667        // back-solve produces `Inf`/`NaN` that would otherwise flow silently
2668        // into the Woodbury inverse and the evidence log-det (#1038). A
2669        // capacitance this degenerate is a desync the caller must surface
2670        // (→ `Ok(None)` cross-row-absent / `SchurFactorFailed`), not consume.
2671        if !pivot_mag.is_finite() || pivot_mag < f64::MIN_POSITIVE {
2672            return None;
2673        }
2674        if pivot_row != col {
2675            for c in 0..r {
2676                lu.swap((col, c), (pivot_row, c));
2677            }
2678            piv.swap(col, pivot_row);
2679            perm_sign = -perm_sign;
2680        }
2681        let pivot = lu[[col, col]];
2682        for row in (col + 1)..r {
2683            let factor = lu[[row, col]] / pivot;
2684            lu[[row, col]] = factor;
2685            for c in (col + 1)..r {
2686                let v = lu[[col, c]];
2687                lu[[row, c]] -= factor * v;
2688            }
2689        }
2690    }
2691    // Post-elimination invariant: every U diagonal is finite and not subnormal.
2692    // The per-column pivot guard above validates each diagonal as it is chosen,
2693    // but assert it explicitly so `SmallLu::solve` can divide by `lu[[i, i]]`
2694    // without a per-entry guard and so a `SmallLu` value can never carry a
2695    // factor that would silently emit `Inf`/`NaN` into the capacitance solve.
2696    for i in 0..r {
2697        let u = lu[[i, i]];
2698        if !u.is_finite() || u.abs() < f64::MIN_POSITIVE {
2699            return None;
2700        }
2701    }
2702    Some(SmallLu { lu, piv, perm_sign })
2703}
2704
2705impl SmallLu {
2706    pub(crate) fn dim(&self) -> usize {
2707        self.lu.nrows()
2708    }
2709
2710    /// `log|det|` and the determinant sign (`±1`).
2711    pub(crate) fn log_abs_det_and_sign(&self) -> (f64, f64) {
2712        let mut log_abs = 0.0_f64;
2713        let mut sign = self.perm_sign;
2714        for i in 0..self.dim() {
2715            let u = self.lu[[i, i]];
2716            log_abs += u.abs().ln();
2717            if u < 0.0 {
2718                sign = -sign;
2719            }
2720        }
2721        (log_abs, sign)
2722    }
2723
2724    /// Solve `C x = b` reusing the factorization (in place into a fresh vector).
2725    ///
2726    /// Returns `None` when the solve cannot produce a finite result — either a
2727    /// `U` diagonal is non-finite/subnormal (defensive: `small_lu_factor`
2728    /// already rejects such factors, but a future construction path might not)
2729    /// or the back-substitution overflows to `Inf`/`NaN` for an extreme RHS on
2730    /// an ill-conditioned (yet validly factored) capacitance. Surfacing `None`
2731    /// lets the Woodbury / evidence consumers fail loudly (#1038) instead of
2732    /// flowing a silent `NaN` into the log-det and outer gradient.
2733    pub(crate) fn solve(&self, b: &Array1<f64>) -> Option<Array1<f64>> {
2734        let r = self.dim();
2735        // Apply the row permutation: y = P b.
2736        let mut y = Array1::<f64>::zeros(r);
2737        for i in 0..r {
2738            y[i] = b[self.piv[i]];
2739        }
2740        // Forward solve L y' = P b (L unit-lower).
2741        for i in 0..r {
2742            let mut sum = y[i];
2743            for j in 0..i {
2744                sum -= self.lu[[i, j]] * y[j];
2745            }
2746            y[i] = sum;
2747        }
2748        // Back solve U x = y' (U upper, explicit diagonal).
2749        let mut x = Array1::<f64>::zeros(r);
2750        for i in (0..r).rev() {
2751            let mut sum = y[i];
2752            for j in (i + 1)..r {
2753                sum -= self.lu[[i, j]] * x[j];
2754            }
2755            let pivot = self.lu[[i, i]];
2756            if !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE {
2757                return None;
2758            }
2759            x[i] = sum / pivot;
2760        }
2761        if x.iter().all(|v| v.is_finite()) {
2762            Some(x)
2763        } else {
2764            None
2765        }
2766    }
2767}
2768
2769/// Close the streaming cross-row IBP Woodbury (#1038) into its exact evidence
2770/// log-determinant correction.
2771///
2772/// Given the chunk-summed `M0 = Uᵀ A⁻¹ U` (`R×R`), `W = Bᵀ A⁻¹ U` (`k×R`), the
2773/// GLOBAL reduced Schur `schur` (`S`, `k×k`, PD), and the per-atom `D`, this
2774/// forms the EXACT projected `M = Uᵀ H₀'⁻¹ U = M0 + Wᵀ S⁻¹ W` (the bordered
2775/// inverse `t`-block `(H₀'⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹` contracted by `U`),
2776/// then the capacitance `C = I_R + D·M` and returns `log det C`.
2777///
2778/// This is byte-for-byte the same quantity, and the same sign convention, that
2779/// the dense [`CrossRowWoodbury::log_det`] returns (symmetrize `M`, scale rows
2780/// by `D`, partial-pivot LU, reject a non-positive determinant). Returns
2781/// `Ok(None)` when the capacitance is non-PD / singular — the recoverable
2782/// "cross-row IBP joint Hessian is non-PD at this ρ" infeasible-probe refusal,
2783/// NOT a silent wrong value.
2784pub fn streaming_cross_row_woodbury_log_det(
2785    schur: &Array2<f64>,
2786    m0: &Array2<f64>,
2787    w: &Array2<f64>,
2788    d: &Array1<f64>,
2789) -> Result<Option<f64>, ArrowSchurError> {
2790    let r = d.len();
2791    let factor =
2792        cholesky_lower(schur).map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
2793    // M = M0 + Wᵀ S⁻¹ W. With S⁻¹ symmetric, `(Wᵀ S⁻¹ W)[a, b] = (S⁻¹ w_a)·w_b`.
2794    let mut m = m0.clone();
2795    for a in 0..r {
2796        let w_a = w.column(a).to_owned();
2797        let sinv_w_a = cholesky_solve_vector(&factor, &w_a);
2798        for b in 0..r {
2799            m[[a, b]] += sinv_w_a.dot(&w.column(b));
2800        }
2801    }
2802    // Symmetrize to clear back-substitution rounding asymmetry (matches
2803    // `CrossRowWoodbury::build`).
2804    for a in 0..r {
2805        for b in (a + 1)..r {
2806            let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2807            m[[a, b]] = avg;
2808            m[[b, a]] = avg;
2809        }
2810    }
2811    // Capacitance C = I_R + D·M (row k scaled by d_k), factored by partial-pivot
2812    // LU (`d_k = w·s'_k` is not sign-definite, so C is generally non-symmetric /
2813    // indefinite — exactly the dense carrier's factorization).
2814    let mut c = Array2::<f64>::zeros((r, r));
2815    for a in 0..r {
2816        for b in 0..r {
2817            c[[a, b]] = d[a] * m[[a, b]];
2818        }
2819        c[[a, a]] += 1.0;
2820    }
2821    match small_lu_factor(&c) {
2822        Some(lu) => {
2823            let (log_abs, sign) = lu.log_abs_det_and_sign();
2824            Ok((sign > 0.0).then_some(log_abs))
2825        }
2826        // Exactly-singular capacitance: `H_full` det is 0, the Laplace log-det is
2827        // undefined — surface as the recoverable non-PD probe refusal.
2828        None => Ok(None),
2829    }
2830}
2831
2832impl CrossRowWoodbury {
2833    /// Build the exact rank-`R` cross-row Woodbury carrier from the IBP source
2834    /// and a cache whose per-row factors are the NO-SELF base `H₀'`.
2835    ///
2836    /// Computes `H₀'⁻¹U` (one [`ArrowFactorCache::full_inverse_apply`] back-solve
2837    /// per column, β-RHS zero — the `t`-block of the result is `H₀'⁻¹U`'s
2838    /// column), `M = UᵀH₀'⁻¹U`, and the LU of `C = I_R + D·M`. Returns `None`
2839    /// when the capacitance is exactly singular (the only un-representable case;
2840    /// the caller then proceeds with the bare `H₀'` cache and the cross-row term
2841    /// is absent — never silently inconsistent, since logdet/inverse/adjoint all
2842    /// key off the presence of this carrier).
2843    pub(crate) fn build(
2844        cache: &ArrowFactorCache,
2845        source: &IbpCrossRowSource,
2846    ) -> Result<Option<Self>, ArrowSchurError> {
2847        let r = source.r;
2848        let total_len = cache.delta_t_len();
2849        let u = source.dense_u(total_len);
2850        let d = source.d.clone();
2851        let zero_beta = Array1::<f64>::zeros(cache.k);
2852        // h0inv_u[:, k] = (H₀'⁻¹ U)_t for column k; h0inv_u_beta[:, k] its β-block.
2853        let mut h0inv_u = Array2::<f64>::zeros((total_len, r));
2854        let mut h0inv_u_beta = Array2::<f64>::zeros((cache.k, r));
2855        for k in 0..r {
2856            let col = u.column(k).to_owned();
2857            let (sol_t, sol_beta) = cache.full_inverse_apply(col.view(), zero_beta.view())?;
2858            for g in 0..total_len {
2859                h0inv_u[[g, k]] = sol_t[g];
2860            }
2861            for c in 0..cache.k {
2862                h0inv_u_beta[[c, k]] = sol_beta[c];
2863            }
2864        }
2865        // M = Uᵀ (H₀'⁻¹ U), symmetric R×R. U is sparse (atom-slot supported), so
2866        // contract over the listed entries.
2867        let mut m = Array2::<f64>::zeros((r, r));
2868        for a in 0..r {
2869            for b in 0..r {
2870                let mut acc = 0.0_f64;
2871                for &(g, k, z) in &source.entries {
2872                    if k == a {
2873                        acc += z * h0inv_u[[g, b]];
2874                    }
2875                }
2876                m[[a, b]] = acc;
2877            }
2878        }
2879        // Symmetrize M to clear back-substitution rounding asymmetry.
2880        for a in 0..r {
2881            for b in (a + 1)..r {
2882                let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2883                m[[a, b]] = avg;
2884                m[[b, a]] = avg;
2885            }
2886        }
2887        // Capacitance C = I_R + D·M (row k scaled by d_k).
2888        let mut c = Array2::<f64>::zeros((r, r));
2889        for a in 0..r {
2890            for b in 0..r {
2891                c[[a, b]] = d[a] * m[[a, b]];
2892            }
2893            c[[a, a]] += 1.0;
2894        }
2895        let Some(capacitance_lu) = small_lu_factor(&c) else {
2896            return Ok(None);
2897        };
2898        Ok(Some(Self {
2899            u,
2900            d,
2901            h0inv_u,
2902            h0inv_u_beta,
2903            m,
2904            capacitance_lu,
2905            entries: source.entries.clone(),
2906        }))
2907    }
2908
2909    /// The sparse `U` entry list `(global_t_index, atom_k, z'_ik)`.
2910    pub(crate) fn source_entries(&self) -> &[(usize, usize, f64)] {
2911        &self.entries
2912    }
2913
2914    /// `C⁻¹ D` as a dense `R × R` matrix (`R` capacitance solves; column `l` is
2915    /// `d_l · C⁻¹ e_l`). Shared by the inverse-diagonal correction and any
2916    /// adjoint trace that needs the selected inverse of the capacitance.
2917    ///
2918    /// Returns `None` when any capacitance solve fails to produce a finite
2919    /// result (#1038); the consumer must surface this as a loud failure rather
2920    /// than propagate a `NaN` into the evidence/gradient.
2921    pub fn capacitance_inv_times_d(&self) -> Option<Array2<f64>> {
2922        let r = self.d.len();
2923        let mut out = Array2::<f64>::zeros((r, r));
2924        let mut e_l = Array1::<f64>::zeros(r);
2925        for l in 0..r {
2926            e_l.fill(0.0);
2927            e_l[l] = 1.0;
2928            let col = self.capacitance_lu.solve(&e_l)?;
2929            for k in 0..r {
2930                out[[k, l]] = col[k] * self.d[l];
2931            }
2932        }
2933        Some(out)
2934    }
2935
2936    /// Subtract the rank-`R` Woodbury term from the latent inverse diagonal:
2937    /// `diag ← diag − diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹)`. With `G = h0inv_u` and
2938    /// `(C⁻¹D) = capacitance_inv_times_d()`, the entry at global index `g` is
2939    /// `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`.
2940    pub(crate) fn subtract_inverse_diagonal(
2941        &self,
2942        diag: &mut Array1<f64>,
2943    ) -> Result<(), ArrowSchurError> {
2944        let r = self.d.len();
2945        let cinv_d =
2946            self.capacitance_inv_times_d()
2947                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
2948                    reason: "cross-row Woodbury capacitance solve produced a non-finite \
2949                         C⁻¹D for the inverse-diagonal correction (#1038): \
2950                         singular/ill-conditioned cross-row capacitance"
2951                        .to_string(),
2952                })?;
2953        let total_len = self.h0inv_u.nrows();
2954        for g in 0..total_len {
2955            let mut acc = 0.0_f64;
2956            for k in 0..r {
2957                let gk = self.h0inv_u[[g, k]];
2958                if gk == 0.0 {
2959                    continue;
2960                }
2961                for l in 0..r {
2962                    acc += gk * cinv_d[[k, l]] * self.h0inv_u[[g, l]];
2963                }
2964            }
2965            diag[g] -= acc;
2966        }
2967        Ok(())
2968    }
2969
2970    /// `log det(I_R + D·M)` (the matrix-determinant-lemma correction). Returns
2971    /// `None` when the capacitance LU has a negative determinant — i.e. the
2972    /// implied `H_full` is non-PD, which is a desync the evidence must reject
2973    /// loudly rather than return a complex/`NaN` log-det.
2974    pub fn log_det(&self) -> Option<f64> {
2975        let (log_abs, sign) = self.log_det_correction();
2976        if sign > 0.0 { Some(log_abs) } else { None }
2977    }
2978
2979    /// `log det(I_R + D·M)`: the exact additive correction
2980    /// `log det H_full − log det H₀'` (matrix-determinant lemma). For a genuine
2981    /// PD `H_full` this is real; the LU sign is returned for the caller to
2982    /// surface a non-PD capacitance as an error rather than a silent `NaN`.
2983    pub(crate) fn log_det_correction(&self) -> (f64, f64) {
2984        self.capacitance_lu.log_abs_det_and_sign()
2985    }
2986
2987    /// Apply the rank-`R` inverse correction in place on BOTH arrow blocks:
2988    /// `u ← u − (H₀'⁻¹U) · C⁻¹ · (D Uᵀ (H₀'⁻¹ rhs)_t)`, where `h0inv_rhs_t` is
2989    /// the `t`-block of `H₀'⁻¹ rhs` already computed by the base
2990    /// [`ArrowFactorCache::full_inverse_apply`]. Implements the Woodbury
2991    /// identity `H_full⁻¹ = H₀'⁻¹ − H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹`. `U` has no `β`
2992    /// support so `Uᵀ·v` reads only the `t`-block, but `H₀'⁻¹U` couples to `β`
2993    /// through the Schur complement, so the correction touches `u_beta` too.
2994    ///
2995    /// `entries` lets `Uᵀ·v` be formed over the sparse atom slots.
2996    pub(crate) fn apply_inverse_correction(
2997        &self,
2998        h0inv_rhs_t: ArrayView1<'_, f64>,
2999        entries: &[(usize, usize, f64)],
3000        u_t: &mut Array1<f64>,
3001        u_beta: &mut Array1<f64>,
3002    ) -> Result<(), ArrowSchurError> {
3003        let r = self.d.len();
3004        // p = D Uᵀ (H₀'⁻¹ rhs)_t.
3005        let mut p = Array1::<f64>::zeros(r);
3006        for &(g, k, z) in entries {
3007            p[k] += z * h0inv_rhs_t[g];
3008        }
3009        for k in 0..r {
3010            p[k] *= self.d[k];
3011        }
3012        // q = C⁻¹ p. A non-finite solve is a singular/ill-conditioned cross-row
3013        // capacitance (#1038): fail loudly rather than write `NaN` into the
3014        // Newton step / adjoint solve.
3015        let q =
3016            self.capacitance_lu
3017                .solve(&p)
3018                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
3019                    reason: "cross-row Woodbury capacitance solve produced a non-finite \
3020                         C⁻¹p for the inverse correction (#1038): \
3021                         singular/ill-conditioned cross-row capacitance"
3022                        .to_string(),
3023                })?;
3024        // u_t -= (H₀'⁻¹U)_t · q.
3025        for g in 0..u_t.len() {
3026            let mut acc = 0.0_f64;
3027            for k in 0..r {
3028                acc += self.h0inv_u[[g, k]] * q[k];
3029            }
3030            u_t[g] -= acc;
3031        }
3032        // u_beta -= (H₀'⁻¹U)_β · q.
3033        for c in 0..u_beta.len() {
3034            let mut acc = 0.0_f64;
3035            for k in 0..r {
3036                acc += self.h0inv_u_beta[[c, k]] * q[k];
3037            }
3038            u_beta[c] -= acc;
3039        }
3040        Ok(())
3041    }
3042
3043    /// Forward apply of the rank-`R` cross-row curvature on the latent (`t`)
3044    /// block: `out_t += U D Uᵀ v_t`. This is the EXACT forward of the same
3045    /// `H_full = H₀' + U D Uᵀ` whose inverse [`Self::apply_inverse_correction`]
3046    /// applies and whose log-determinant [`Self::log_det_correction`] reports, so
3047    /// a forward Hessian apply that adds this term stays consistent (operator and
3048    /// preconditioner describe the SAME operator). `U` has no `β` support, so the
3049    /// forward term touches only the `t` block; the `β` coupling is purely an
3050    /// inverse-side Schur artifact (`h0inv_u_beta`) and must NOT appear here.
3051    ///
3052    /// Formed over the sparse `entries` `(global_t_index, atom_k, z'_ik)` so it
3053    /// matches `Uᵀ·v` in `apply_inverse_correction` bit-for-bit:
3054    /// `p_k = d_k · Σ_{(g,k,z)} z·v_t[g]`, then `out_t[g] += Σ_{(g,k,z)} z·p_k`.
3055    pub fn apply_forward_t(&self, v_t: ArrayView1<'_, f64>, out_t: &mut Array1<f64>) {
3056        let r = self.d.len();
3057        // p = D Uᵀ v_t.
3058        let mut p = Array1::<f64>::zeros(r);
3059        for &(g, k, z) in &self.entries {
3060            p[k] += z * v_t[g];
3061        }
3062        for k in 0..r {
3063            p[k] *= self.d[k];
3064        }
3065        // out_t += U p.
3066        for &(g, k, z) in &self.entries {
3067            out_t[g] += z * p[k];
3068        }
3069    }
3070}
3071
/// Smallest cached Cholesky pivots (squared lower-factor diagonals, i.e. on the
/// Hessian scale), reported per source and combined.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
    /// Smallest pivot over the per-row `H_tt` factors; `None` when none exist.
    pub min_row_pivot: Option<f64>,
    /// Smallest pivot of the dense Schur factor; `None` when no dense factor
    /// was formed (e.g. an InexactPCG cache).
    pub min_schur_pivot: Option<f64>,
    /// `f64::min` of whichever of the two sources are present; `None` only when
    /// both are absent.
    pub min_pivot: Option<f64>,
}

impl ArrowFactorMinPivot {
    /// Assemble the summary from the per-source minima, folding the present
    /// ones (row first, then Schur) with `f64::min` into `min_pivot`.
    pub(crate) fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        let overall = row.into_iter().chain(schur).reduce(f64::min);
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            min_pivot: overall,
        }
    }
}
3094
3095pub(crate) fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3096    let width = factor.nrows().min(factor.ncols());
3097    let mut out = None;
3098    for idx in 0..width {
3099        let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3100        out = Some(match out {
3101            Some(current) => f64::min(current, pivot),
3102            None => pivot,
3103        });
3104    }
3105    out
3106}
3107
3108pub(crate) fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3109    let width = factor.nrows().min(factor.ncols());
3110    let mut out = None;
3111    for idx in 0..width {
3112        let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3113        out = Some(match out {
3114            Some(current) => f64::max(current, pivot),
3115            None => pivot,
3116        });
3117    }
3118    out
3119}
3120
3121/// Smallest cached Cholesky pivot for row blocks and the dense Schur factor.
3122///
3123/// Pivots are returned as squared lower-factor diagonals, matching the Hessian
3124/// scale rather than the Cholesky-factor scale. In inexact PCG mode the dense
3125/// Schur factor is absent, so `min_schur_pivot` is `None`.
3126pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
3127    let mut min_row_pivot = None;
3128    for factor in cache.htt_factors.iter() {
3129        if let Some(pivot) = lower_cholesky_min_pivot(factor) {
3130            min_row_pivot = Some(match min_row_pivot {
3131                Some(current) => f64::min(current, pivot),
3132                None => pivot,
3133            });
3134        }
3135    }
3136    let min_schur_pivot = cache
3137        .schur_factor
3138        .as_ref()
3139        .and_then(|factor| lower_cholesky_min_pivot(factor.view()));
3140    ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
3141}
3142
3143/// Largest cached Cholesky pivot across the row blocks and the dense Schur
3144/// factor (Hessian scale, i.e. squared lower-factor diagonal). This is the
3145/// diagonal magnitude scale a safe-SPD pivot floor is measured against: the
3146/// curvature-homotopy tracker (#1007) compares the min pivot against
3147/// `√eps · max(this, 1)`, the same floor the inner solver's
3148/// [`safe_spd_pivot_min`] uses. `None` only for an empty cache.
3149pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
3150    let mut max_pivot: Option<f64> = None;
3151    for factor in cache.htt_factors.iter() {
3152        if let Some(pivot) = lower_cholesky_max_pivot(factor) {
3153            max_pivot = Some(match max_pivot {
3154                Some(current) => f64::max(current, pivot),
3155                None => pivot,
3156            });
3157        }
3158    }
3159    if let Some(factor) = cache.schur_factor.as_ref()
3160        && let Some(pivot) = lower_cholesky_max_pivot(factor.view())
3161    {
3162        max_pivot = Some(match max_pivot {
3163            Some(current) => f64::max(current, pivot),
3164            None => pivot,
3165        });
3166    }
3167    max_pivot
3168}
3169
3170impl ArrowFactorCache {
3171    pub fn n_rows(&self) -> usize {
3172        self.htt_factors.len()
3173    }
3174
3175    pub fn htbeta_available(&self) -> bool {
3176        self.htbeta.is_available()
3177    }
3178
3179    /// Whether the Newton solve that produced this cache actually executed on
3180    /// the device: the device-resident Direct dense solve or the device-resident
3181    /// matrix-free SAE PCG (whose matvec runs in CUDA kernels). This does NOT
3182    /// include the injected host-procedural reduced-Schur matvec, whose
3183    /// arithmetic runs on the CPU even when a CUDA context was opened to build
3184    /// per-row factors (#1209) — that path sets
3185    /// `PcgDiagnostics::injected_host_procedural_matvec` instead. Read-only
3186    /// routing provenance: lets a fit result record device-vs-CPU as ground
3187    /// truth instead of inferring it from the runtime probe. Mirrors
3188    /// `PcgDiagnostics::used_device_arrow`.
3189    #[must_use]
3190    pub fn used_device(&self) -> bool {
3191        self.pcg_diagnostics.used_device_arrow
3192    }
3193
3194    pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
3195        match &self.htt_factors_undamped {
3196            ArrowUndampedFactors::SameAsDamped => self.htt_factors.factor(row),
3197            ArrowUndampedFactors::Owned(factors) => factors.factor(row),
3198        }
3199    }
3200
3201    pub fn undamped_factor_count(&self) -> usize {
3202        match &self.htt_factors_undamped {
3203            ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
3204            ArrowUndampedFactors::Owned(factors) => factors.len(),
3205        }
3206    }
3207
3208    pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
3209        (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
3210    }
3211
3212    pub fn compute_undamped_arrow_log_det(&self) -> Option<f64> {
3213        if self.ridge_t != 0.0 || self.ridge_beta != 0.0 {
3214            return None;
3215        }
3216        // When the shared β block is empty (`k == 0`) the joint Hessian is
3217        // exactly the block diagonal of the per-row latent blocks: there is no
3218        // reduced Schur complement to form, so the dense Direct path leaves
3219        // `schur_factor = None` legitimately (not the InexactPCG "never formed
3220        // the dense K×K factor" case, which has `k > 0`). The log-det is then
3221        // the per-row sum with a zero (empty `0×0`) Schur contribution. Without
3222        // this the `schur_factor.as_ref()?` below would return `None` for a
3223        // β-profiled atom (#1132 euclidean K=4) and starve the REML Laplace
3224        // normaliser of the joint Hessian log-det it requires.
3225        let schur = match self.schur_factor.as_ref() {
3226            Some(schur) => Some(schur),
3227            None if self.k == 0 => None,
3228            None => return None,
3229        };
3230
3231        let mut acc = 0.0_f64;
3232        for l in self.undamped_factors_iter() {
3233            for i in 0..l.nrows() {
3234                let d = l[[i, i]];
3235                if d <= 0.0 || !d.is_finite() {
3236                    return None;
3237                }
3238                acc += 2.0 * d.ln();
3239            }
3240        }
3241        if let Some(schur) = schur {
3242            for i in 0..schur.nrows() {
3243                let d = schur[[i, i]];
3244                if d <= 0.0 || !d.is_finite() {
3245                    return None;
3246                }
3247                acc += 2.0 * d.ln();
3248            }
3249        }
3250        let woodbury_correction = self.cross_row_woodbury_log_det();
3251        if !woodbury_correction.is_finite() {
3252            return None;
3253        }
3254        Some(acc + woodbury_correction)
3255    }
3256
3257    /// The total length of `delta_t` / IFT output vectors for this cache.
3258    pub fn delta_t_len(&self) -> usize {
3259        self.row_offsets[self.n_rows()]
3260    }
3261
3262    pub fn apply_htbeta_row(
3263        &self,
3264        row: usize,
3265        delta_beta: ArrayView1<'_, f64>,
3266        out: &mut Array1<f64>,
3267    ) -> bool {
3268        let di = if row < self.row_dims.len() {
3269            self.row_dims[row]
3270        } else {
3271            self.d
3272        };
3273        if out.len() != di || delta_beta.len() != self.k {
3274            return false;
3275        }
3276        self.htbeta.apply_row(row, delta_beta, out)
3277    }
3278
3279    /// Accumulate `out[a] += H_βt^(row)[a, :] · v` for all `a in 0..k`.
3280    ///
3281    /// `v` has length `row_dims[row]`; `out` has length `k`. The caller must
3282    /// zero `out` before the first call if it needs a fresh result.  Returns
3283    /// `false` when the cache is `Disabled` and no `fallback_op` is provided;
3284    /// callers must treat the accumulator as invalid in that case.
3285    pub fn apply_htbeta_row_transpose(
3286        &self,
3287        row: usize,
3288        v: ArrayView1<'_, f64>,
3289        out: &mut Array1<f64>,
3290        fallback_op: Option<&RowHtbetaMatvec>,
3291    ) -> bool {
3292        let di = if row < self.row_dims.len() {
3293            self.row_dims[row]
3294        } else {
3295            self.d
3296        };
3297        if v.len() != di || out.len() != self.k {
3298            return false;
3299        }
3300        self.htbeta
3301            .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
3302    }
3303
3304    /// Arrow log-determinant
3305    /// `log|H| = Σ_i log|H_{t_i t_i}| + log|Schur_β|`
3306    /// using the cached (damped) factors.
3307    ///
3308    /// Returns `(log_det_tt_sum, log_det_schur)` so the caller can decide
3309    /// what to do with the Schur piece (e.g. REML evidence wants both;
3310    /// some diagnostics want only the per-row sum). `None` for the Schur
3311    /// piece signals that the cache was produced by an InexactPCG solve
3312    /// and never formed/factored the dense `K × K` reduced system.
3313    ///
3314    /// The log-determinant of a Cholesky factor `L` of `M` is
3315    /// `2 Σ log L_ii`.
3316    pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
3317        let mut log_det_tt = 0.0_f64;
3318        for l in self.htt_factors.iter() {
3319            for i in 0..l.nrows() {
3320                log_det_tt += l[[i, i]].ln();
3321            }
3322        }
3323        log_det_tt *= 2.0;
3324        let log_det_schur = self.schur_factor.as_ref().map(|l| {
3325            let mut s = 0.0_f64;
3326            for i in 0..l.nrows() {
3327                s += l[[i, i]].ln();
3328            }
3329            2.0 * s + self.cross_row_woodbury_log_det()
3330        });
3331        (log_det_tt, log_det_schur)
3332    }
3333
3334    /// The exact cross-row IBP correction `log det(I_R + D·M)` to add to the
3335    /// base `log det H₀'` (#1038). Zero when no [`CrossRowWoodbury`] is present,
3336    /// so non-IBP caches are unaffected. The determinant lemma gives
3337    /// `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`; this is the
3338    /// second term, the only piece beyond the bare arrow log-determinant.
3339    ///
3340    /// Panics-free: a negative capacitance determinant (non-PD `H_full`) yields
3341    /// `NaN` here so the evidence surfaces the desync rather than silently
3342    /// dropping the imaginary part. Callers that must reject it should check
3343    /// [`CrossRowWoodbury::log_det`] directly.
3344    pub fn cross_row_woodbury_log_det(&self) -> f64 {
3345        match self.cross_row_woodbury.as_ref() {
3346            Some(w) => w.log_det().unwrap_or(f64::NAN),
3347            None => 0.0,
3348        }
3349    }
3350
3351    /// Diagonal of the latent (`t`-block) of the *full* bordered-arrow
3352    /// inverse `(H⁻¹)_tt`, in `delta_t` layout (length [`Self::delta_t_len`]).
3353    ///
3354    /// For the bordered arrow Hessian
3355    /// `H = [[A, B], [Bᵀ, H_ββ]]` with `A = H_tt` (block-diagonal per row,
3356    /// `A_i = H_tt^(i)`) and `B = H_tβ`, the standard block-inverse identity
3357    /// gives the `t`-block
3358    /// `(H⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹`, where
3359    /// `S = H_ββ − Bᵀ A⁻¹ B` is the Schur complement on `β`. Because `A` is
3360    /// block-diagonal, the `(i, j)` diagonal entry of `(H⁻¹)_tt` is computed
3361    /// purely from row `i`'s factor and cross-block:
3362    ///
3363    /// ```text
3364    /// a    = A_i⁻¹ e_j                       (chol_solve on the per-row factor)
3365    /// [A_i⁻¹]_{jj} = a[j]
3366    /// w    = B_iᵀ a = H_βt^(i) a             (a K-vector)
3367    /// z    = S⁻¹ w                           (chol_solve on the Schur factor)
3368    /// diag = a[j] + w · z
3369    /// ```
3370    ///
3371    /// The UNDAMPED per-row factors ([`Self::undamped_factor`]) are used so
3372    /// the result is the inverse of the *true* `H_tt`, not the LM-damped
3373    /// `H_tt + ridge_t·I` — same rationale the IFT predictor docstring gives
3374    /// at the top of this struct.
3375    ///
3376    /// # Consuming the diagonal as a per-(atom, axis) trace
3377    ///
3378    /// `(H⁻¹)_tt` is the latent covariance block. The selected-inverse trace
3379    /// for a contiguous group of latent coordinates (e.g. one atom's rows, or
3380    /// one axis across rows) is simply the sum of the returned diagonal entries
3381    /// over those `row_offsets[i] + j` indices — no off-diagonal terms are
3382    /// needed for the trace `tr[(H⁻¹)_tt · D]` against a diagonal selector `D`.
3383    ///
3384    /// # Errors
3385    ///
3386    /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3387    /// dense Schur factor or no usable `H_βt` coupling — i.e. it was produced
3388    /// by an [`ArrowSolverMode::InexactPCG`] solve (no dense `K × K` factor) or
3389    /// by a `Disabled` `htbeta` cache. The selected-inverse block-trace is not
3390    /// yet supported for the matrix-free PCG mode; that branch needs a separate
3391    /// Lanczos/Hutchinson estimator.
3392    pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
3393        let Some(schur_factor) = self.schur_factor.as_ref() else {
3394            return Err(ArrowSchurError::SchurFactorFailed {
3395                reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
3396                         the InexactPCG mode does not form one"
3397                    .to_string(),
3398            });
3399        };
3400        if !self.htbeta_available() {
3401            return Err(ArrowSchurError::SchurFactorFailed {
3402                reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
3403                         but this cache's htbeta is Disabled"
3404                    .to_string(),
3405            });
3406        }
3407        let n = self.undamped_factor_count();
3408        let total_len = self.delta_t_len();
3409        let mut out = Array1::<f64>::zeros(total_len);
3410        // Per-row scratch, sized to the max latent dim / K.
3411        let mut e_j = Array1::<f64>::zeros(self.d);
3412        let mut w = Array1::<f64>::zeros(self.k);
3413        for i in 0..n {
3414            let di = self.row_dims[i];
3415            let row_base = self.row_offsets[i];
3416            let factor = self.undamped_factor(i);
3417            for j in 0..di {
3418                // a = A_i⁻¹ e_j.
3419                for c in 0..di {
3420                    e_j[c] = 0.0;
3421                }
3422                e_j[j] = 1.0;
3423                let e_j_slice = e_j.slice(ndarray::s![..di]).to_owned();
3424                let a = cholesky_solve_vector(factor, &e_j_slice);
3425                // w = H_βt^(i) a (a K-vector); accumulator must start zeroed.
3426                w.fill(0.0);
3427                if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
3428                    return Err(ArrowSchurError::SchurFactorFailed {
3429                        reason: format!(
3430                            "latent_block_inverse_diagonal: H_βt^({i}) apply failed \
3431                             (htbeta cache could not supply row {i})"
3432                        ),
3433                    });
3434                }
3435                // z = S⁻¹ w; correction = w · z.
3436                let z = cholesky_solve_vector(schur_factor, &w);
3437                let mut corr = 0.0_f64;
3438                for c in 0..self.k {
3439                    corr += w[c] * z[c];
3440                }
3441                out[row_base + j] = a[j] + corr;
3442            }
3443        }
3444        if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3445            // #1038: the factors above are `H₀'`, so `out` is diag((H₀'⁻¹)_tt).
3446            // The full inverse diagonal subtracts the rank-`R` Woodbury term
3447            // diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹). With `G = h0inv_u = (H₀'⁻¹U)_t` and
3448            // (by symmetry of `H₀'⁻¹`) `(Uᵀ H₀'⁻¹)_t = Gᵀ`, the diagonal entry at
3449            // global index `g` is `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`. Form the
3450            // `R×R` matrix `C⁻¹D` once (R solves), then contract per row index.
3451            woodbury.subtract_inverse_diagonal(&mut out)?;
3452        }
3453        Ok(out)
3454    }
3455
3456    /// Solve the full bordered-arrow system `H·u = w` on the cached factor
3457    /// (#1006): `w` arrives in arrow layout — `w_t` flat per
3458    /// [`Self::delta_t_len`] / `row_offsets`, `w_beta` of length `K` — and the
3459    /// solution comes back in the same layout. Standard block elimination on
3460    /// the SAME factors whose log-determinant the evidence reports:
3461    ///
3462    /// ```text
3463    ///   y_i      = H_tt^(i)⁻¹ · w_t^(i)
3464    ///   r_β      = w_β − Σ_i H_βt^(i) · y_i
3465    ///   u_β      = Schur⁻¹ · r_β
3466    ///   u_t^(i)  = y_i − H_tt^(i)⁻¹ · (H_tβ^(i) · u_β)
3467    /// ```
3468    ///
3469    /// This is the IFT / adjoint back-solve the analytic outer ρ-gradient
3470    /// consumes: `u_j = H⁻¹ (∂g/∂ρ_j)` per outer coordinate and the
3471    /// `H⁻¹`-side of the third-order correction `−½·Γᵀ·H⁻¹·(∂g/∂ρ_j)`.
3472    /// Contract: the cache must be the ridge-0 Direct evidence factor
3473    /// (undamped per-row factors + dense Schur), so the solve is against the
3474    /// criterion's own `H` — never a damped surrogate (that would desync the
3475    /// gradient from the reported evidence).
3476    ///
3477    /// When the cache carries an exact cross-row IBP
3478    /// [`CrossRowWoodbury`] (#1038), the per-row factors are the NO-SELF base
3479    /// `H₀'` and this method layers the rank-`R` Woodbury correction so the
3480    /// returned solve is against the FULL `H_full = H₀' + U D Uᵀ` — the same
3481    /// operator whose log-determinant [`Self::arrow_log_det`] reports. The
3482    /// θ/ρ-adjoint that consumes this therefore sees the cross-row curvature.
3483    pub fn full_inverse_apply(
3484        &self,
3485        w_t: ArrayView1<'_, f64>,
3486        w_beta: ArrayView1<'_, f64>,
3487    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3488        let (mut u_t, mut u_beta) = self.full_inverse_apply_base(w_t, w_beta)?;
3489        if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3490            // u ← u − H₀'⁻¹U C⁻¹ D Uᵀ u. `u_t` is the `t`-block of `H₀'⁻¹ w`.
3491            let h0inv_w_t = u_t.clone();
3492            woodbury.apply_inverse_correction(
3493                h0inv_w_t.view(),
3494                woodbury.source_entries(),
3495                &mut u_t,
3496                &mut u_beta,
3497            )?;
3498        }
3499        Ok((u_t, u_beta))
3500    }
3501
3502    /// Bare bordered-arrow inverse solve against the cached per-row factors and
3503    /// Schur factor (the NO-SELF base `H₀'` when a cross-row Woodbury is
3504    /// present). [`Self::full_inverse_apply`] wraps this with the rank-`R`
3505    /// correction; [`CrossRowWoodbury::build`] calls this directly (before the
3506    /// carrier exists) to form `H₀'⁻¹U`.
3507    pub(crate) fn full_inverse_apply_base(
3508        &self,
3509        w_t: ArrayView1<'_, f64>,
3510        w_beta: ArrayView1<'_, f64>,
3511    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3512        let total_len = self.delta_t_len();
3513        if w_t.len() != total_len || w_beta.len() != self.k {
3514            return Err(ArrowSchurError::SchurFactorFailed {
3515                reason: format!(
3516                    "full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
3517                    w_t.len(),
3518                    w_beta.len(),
3519                    total_len,
3520                    self.k
3521                ),
3522            });
3523        }
3524        let n = self.undamped_factor_count();
3525        // Forward pass: y_i = H_tt^(i)⁻¹ w_t^(i), accumulating the border RHS.
3526        let mut y = Array1::<f64>::zeros(total_len);
3527        let mut r_beta = w_beta.to_owned();
3528        for i in 0..n {
3529            let di = self.row_dims[i];
3530            let base = self.row_offsets[i];
3531            let factor = self.undamped_factor(i);
3532            let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
3533            let y_row = cholesky_solve_vector(factor, &w_row);
3534            if self.k > 0 {
3535                // r_β −= H_βt^(i) y_i: accumulate into a scratch then subtract,
3536                // because the helper ACCUMULATES (+=) into its output.
3537                let mut acc = Array1::<f64>::zeros(self.k);
3538                if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
3539                    return Err(ArrowSchurError::SchurFactorFailed {
3540                        reason: format!(
3541                            "full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
3542                             could not supply row {i}; htbeta={:?}, di={}, k={})",
3543                            self.htbeta,
3544                            self.row_dims.get(i).copied().unwrap_or(self.d),
3545                            self.k
3546                        ),
3547                    });
3548                }
3549                for c in 0..self.k {
3550                    r_beta[c] -= acc[c];
3551                }
3552            }
3553            for j in 0..di {
3554                y[base + j] = y_row[j];
3555            }
3556        }
3557        // Border solve + back-substitution.
3558        let u_beta = if self.k > 0 {
3559            self.schur_inverse_apply(r_beta.view())?
3560        } else {
3561            Array1::<f64>::zeros(0)
3562        };
3563        let mut u_t = y;
3564        if self.k > 0 {
3565            let mut cross = Array1::<f64>::zeros(self.d);
3566            for i in 0..n {
3567                let di = self.row_dims[i];
3568                let base = self.row_offsets[i];
3569                let mut cross_row = cross.slice_mut(ndarray::s![..di]);
3570                cross_row.fill(0.0);
3571                let mut cross_owned = cross_row.to_owned();
3572                if !self.apply_htbeta_row(i, u_beta.view(), &mut cross_owned) {
3573                    return Err(ArrowSchurError::SchurFactorFailed {
3574                        reason: format!(
3575                            "full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
3576                             could not supply row {i})"
3577                        ),
3578                    });
3579                }
3580                let factor = self.undamped_factor(i);
3581                let corr = cholesky_solve_vector(factor, &cross_owned);
3582                for j in 0..di {
3583                    u_t[base + j] -= corr[j];
3584                }
3585            }
3586        }
3587        Ok((u_t, u_beta))
3588    }
3589
3590    /// Apply the β-block of the full inverse, `(H⁻¹)_ββ · rhs = S_β⁻¹ · rhs`,
3591    /// where `S_β` is the Schur complement on β whose Cholesky factor this
3592    /// cache holds in [`Self::schur_factor`].
3593    ///
3594    /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the
3595    /// β-block of `H⁻¹` is exactly the inverse of the Schur complement
3596    /// `S_β = H_ββ − Bᵀ A⁻¹ B`. One Cholesky back-substitution per call,
3597    /// reusing the cached factor; `rhs` and the returned vector both have
3598    /// length `K`.
3599    ///
3600    /// This is the general single-solve primitive for the β border. Callers
3601    /// that need a Schur-inverse trace `tr(S_β⁻¹ M)` against a structured
3602    /// penalty `M` (e.g. the SAE λ_smooth Fellner-Schall step, where
3603    /// `M = blockdiag_k(λ_k S_k ⊗ I_p)`) build it as
3604    /// `Σ_col e_colᵀ S_β⁻¹ M e_col` — apply this to each column of `M`
3605    /// (exploiting whatever sparsity `M` has) and read off `result[col]`.
3606    /// Keeping `M`'s layout on the caller side avoids coupling this solver
3607    /// to penalty-op types.
3608    ///
3609    /// # Errors
3610    ///
3611    /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3612    /// dense Schur factor (an [`ArrowSolverMode::InexactPCG`] solve) — the
3613    /// same not-yet-supported branch as [`Self::latent_block_inverse_diagonal`]
3614    /// — or when `rhs.len() != k`.
3615    ///
3616    /// Cross-row IBP (#1038) note: this is the β-block primitive of the
3617    /// factored base `S_β` (`H₀'` when a [`CrossRowWoodbury`] is present), used
3618    /// internally by [`Self::full_inverse_apply_base`]; it is deliberately NOT
3619    /// Woodbury-corrected so the base solve stays bare. The cross-row term has
3620    /// no `β` support, so `(H_full⁻¹)_ββ = S_β⁻¹` exactly on the directions any
3621    /// IBP ρ-trace contracts. A consumer needing the full `(H_full⁻¹)_ββ` for a
3622    /// β-supported direction should call [`Self::full_inverse_apply`] with a
3623    /// unit `β`-RHS (which applies the rank-`R` correction).
3624    pub fn schur_inverse_apply(
3625        &self,
3626        rhs: ArrayView1<'_, f64>,
3627    ) -> Result<Array1<f64>, ArrowSchurError> {
3628        let Some(schur_factor) = self.schur_factor.as_ref() else {
3629            return Err(ArrowSchurError::SchurFactorFailed {
3630                reason: "schur_inverse_apply requires a dense Schur factor; \
3631                         the InexactPCG mode does not form one"
3632                    .to_string(),
3633            });
3634        };
3635        if rhs.len() != self.k {
3636            return Err(ArrowSchurError::SchurFactorFailed {
3637                reason: format!(
3638                    "schur_inverse_apply: rhs length {} != K {}",
3639                    rhs.len(),
3640                    self.k
3641                ),
3642            });
3643        }
3644        let rhs_owned = rhs.to_owned();
3645        Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
3646    }
3647
3648    /// Dense principal sub-block of the β-block of the full inverse,
3649    /// `(H⁻¹)_ββ[block, block] = S_β⁻¹[block, block]`, shape `(W, W)` with
3650    /// `W = block.len()`.
3651    ///
3652    /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the β-block
3653    /// of `H⁻¹` is exactly `S_β⁻¹` (the inverse of the Schur complement whose
3654    /// Cholesky factor this cache holds). This returns the contiguous
3655    /// `block × block` sub-block — e.g. one SAE atom's decoder coefficients via
3656    /// [`gam_terms::sae::manifold::SaeManifoldTerm::beta_block_offsets`] — by
3657    /// solving `S_β x = e_j` for each `j ∈ block` (reusing the cached factor)
3658    /// and gathering the `block` rows of each solution column. `W`
3659    /// back-substitutions of size `K`; the result is symmetrized to clear
3660    /// back-substitution rounding asymmetry. Up to a dispersion scale `φ`, this
3661    /// block is the joint posterior covariance `Cov(β_block)` of those
3662    /// coefficients with the latent coordinates already marginalized out (that
3663    /// is precisely what Schur-eliminating the per-row `t`-blocks does).
3664    ///
3665    /// Same dense-Schur requirement / error contract as
3666    /// [`Self::schur_inverse_apply`]; additionally errors when `block` runs past
3667    /// `K`.
3668    pub fn schur_inverse_block(
3669        &self,
3670        block: std::ops::Range<usize>,
3671    ) -> Result<Array2<f64>, ArrowSchurError> {
3672        let Some(schur_factor) = self.schur_factor.as_ref() else {
3673            return Err(ArrowSchurError::SchurFactorFailed {
3674                reason: "schur_inverse_block requires a dense Schur factor; \
3675                         the InexactPCG mode does not form one"
3676                    .to_string(),
3677            });
3678        };
3679        if block.end > self.k {
3680            return Err(ArrowSchurError::SchurFactorFailed {
3681                reason: format!(
3682                    "schur_inverse_block: block end {} exceeds K {}",
3683                    block.end, self.k
3684                ),
3685            });
3686        }
3687        let w = block.len();
3688        let mut out = Array2::<f64>::zeros((w, w));
3689        let mut e_j = Array1::<f64>::zeros(self.k);
3690        for (jc, j) in block.clone().enumerate() {
3691            e_j.fill(0.0);
3692            e_j[j] = 1.0;
3693            let col = cholesky_solve_vector(schur_factor, &e_j);
3694            for (ic, i) in block.clone().enumerate() {
3695                out[[ic, jc]] = col[i];
3696            }
3697        }
3698        // S_β⁻¹ is symmetric; symmetrize to clear back-substitution rounding.
3699        for ic in 0..w {
3700            for jc in (ic + 1)..w {
3701                let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
3702                out[[ic, jc]] = avg;
3703                out[[jc, ic]] = avg;
3704            }
3705        }
3706        Ok(out)
3707    }
3708}