gam_solve/arrow_schur/system.rs
1//! The bordered arrow-Schur system itself: [`ArrowRowBlock`], the
2//! [`ArrowSchurSystem`] container and its assembly impl, cross-row latent
3//! penalties, the streaming builder, and the per-row factor caches.
4
5use super::*;
6
7/// Per-row block data for the arrow-Schur system.
8///
9/// `htt` holds the `d × d` Gauss–Newton block for row `i` (including any
10/// analytic-penalty contributions on that row); `htbeta` holds the
11/// `d × K` cross-block `H_tβ^(i)`; `gt` is the `d`-length latent
12/// gradient for row `i`.
13#[derive(Debug, Clone)]
14pub struct ArrowRowBlock {
15 /// `H_tt^(i)`, shape `(d, d)`.
16 pub htt: Array2<f64>,
17 /// `H_tβ^(i)`, shape `(d, K)`.
18 pub htbeta: Array2<f64>,
19 /// `g_t^(i)`, shape `(d,)`.
20 pub gt: Array1<f64>,
21}
22
23impl ArrowRowBlock {
24 /// Allocate one BA point-block row: local latent Hessian, point-camera
25 /// cross block, and point gradient.
26 pub fn new(d: usize, k: usize) -> Self {
27 Self::new_with_htbeta_cols(d, k)
28 }
29
30 /// Allocate one BA row whose dense cross-block slab has `htbeta_cols`
31 /// columns. This is used by matrix-free assemblers that keep the shared
32 /// beta tier at one width while dense row supplements live in another
33 /// coordinate system.
34 pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
35 Self {
36 htt: Array2::<f64>::zeros((d, d)),
37 htbeta: Array2::<f64>::zeros((d, htbeta_cols)),
38 gt: Array1::<f64>::zeros(d),
39 }
40 }
41}
42
43/// Bordered (t, β) Newton system with arrow structure.
44///
45/// The β-block is held as a dense `K × K` Hessian `H_ββ` plus a `K`-length
46/// gradient `g_β` for direct BA modes. Large-scale inexact BA callers may
47/// additionally install a matrix-free `H_ββ x` operator and diagonal via
48/// [`ArrowSchurSystem::set_shared_beta_operator`]; the InexactPCG mode then
49/// avoids dense Schur formation/factorization.
50/// The t-block is a `Vec<ArrowRowBlock>` of length `N`.
51///
52/// Construction is the driver's responsibility: the driver
53///
54/// 1. evaluates Φ(t) and the radial jet `∂Φ/∂t` (the latter via
55/// [`gam_terms::latent::LatentCoordValues::design_gradient_wrt_t`]);
56/// 2. forms the working-weighted Gauss–Newton blocks
57/// `H_tt^(i) += (g_i β)(g_i β)^T`, `H_tβ^(i) += (g_i β) ⊗ Φ_i`,
58/// `H_ββ += Φ^T W Φ + Σ_k λ_k S_k`;
59/// 3. calls [`ArrowSchurSystem::add_analytic_penalty_contributions`] to
60/// fold row-block Psi-tier analytic penalties (`ARDPenalty`,
61/// `SparsityPenalty`) into `H_tt^(i)` and Beta-tier penalties into `H_ββ`;
62/// 4. calls [`ArrowSchurSystem::solve`] to obtain `(Δt, Δβ)`.
63pub struct ArrowSchurSystem {
64 /// Per-row latent block (length `N`, each row `d × d` / `d × K` / `d`).
65 pub rows: Vec<ArrowRowBlock>,
66 /// `H_ββ`, shape `(K, K)` for direct BA modes; empty when constructed
67 /// by [`ArrowSchurSystem::new_matrix_free_shared`] for PCG-only use.
68 pub hbb: Array2<f64>,
69 /// Optional matrix-free `H_ββ x` operator for large BA Schur PCG.
70 ///
71 /// Direct and Square-Root BA modes still require `hbb`; InexactPCG uses
72 /// this operator when present, avoiding dense shared-block storage for
73 /// SAE-manifold scale `K`.
74 pub hbb_matvec: Option<SharedBetaMatvec>,
75 /// Optional row-local matrix-free multiply for `H_tβ^(i) x`.
76 ///
77 /// When present, all inner-Schur paths route through this operator instead
78 /// of indexing the per-row `htbeta` dense slabs: `reduced_rhs_beta`,
79 /// `schur_matvec` (PCG hot loop), back-substitution,
80 /// `JacobiPreconditioner` construction, `build_dense_schur_direct`, and
81 /// `build_dense_schur_sqrt_ba` all call `sys_htbeta_apply_row` or
82 /// `sys_htbeta_materialize_row`. Factor caches retain the operator for
83 /// IFT/evidence consumers as before.
84 pub htbeta_matvec: Option<RowHtbetaMatvec>,
85 /// Optional row-local matrix-free transpose multiply `out += H_βt^(i) · v`.
86 ///
87 /// The sparse adjoint of [`Self::htbeta_matvec`]. When present, the
88 /// reduced-Schur matvec applies `H_βt^(i)` directly (sparse `scatter`)
89 /// instead of probing the forward operator against `K` basis vectors. This
90 /// is the per-row sparse apply that lifts the `O(K)` column-probe in the
91 /// GPU PCG and streaming Schur paths to `O(m_i · p)` per row. Installed in
92 /// lock-step with `htbeta_matvec` by [`Self::set_row_htbeta_operator`].
93 pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
94 /// Whether `rows[*].htbeta` contains a dense contribution that must be added
95 /// on top of the matrix-free row operator.
96 pub htbeta_dense_supplement: bool,
97 /// Optional diagonal of the matrix-free shared block, used by the
98 /// Schur-Jacobi preconditioner in the Agarwal-style PCG path.
99 pub hbb_diag: Option<Array1<f64>>,
100 /// `g_β`, shape `(K,)`.
101 pub gb: Array1<f64>,
102 /// Maximum per-row latent dimensionality across all rows.
103 ///
104 /// For homogeneous systems (all rows have the same dim) this equals the
105 /// common per-row `d`. For heterogeneous systems (e.g. sparse SAE rows
106 /// where JumpReLU / TopK / sparsemax active sets vary per observation)
107 /// this is `max_i row_dims[i]`. Per-row code should use
108 /// `row.htt.nrows()` or `row_dims[i]`; `d` is an upper bound for
109 /// scratch-buffer sizing.
110 pub d: usize,
111 /// Per-row latent dimensionality: `row_dims[i] == rows[i].htt.nrows()`.
112 ///
113 /// For homogeneous systems `row_dims[i] == d` for all `i`.
114 pub row_dims: Arc<[usize]>,
115 /// Flat-buffer row offsets for the `delta_t` vector produced by
116 /// [`Self::solve`] / [`solve_arrow_newton_step_core`].
117 ///
118 /// `row_offsets[i]` is the start index for row `i`'s slice in `delta_t`;
119 /// `row_offsets[n]` is the total `delta_t` length. For homogeneous
120 /// systems `row_offsets[i] == i * d`.
121 pub row_offsets: Arc<[usize]>,
122 /// β dimensionality `K`.
123 pub k: usize,
124 /// Geometry tag for the row-local latent blocks after optional
125 /// Riemannian projection. Euclidean/no-op geometry uses the sentinel.
126 pub manifold_mode_fingerprint: u64,
127 /// Structural/value tag for row-local Hessian factors and their Schur
128 /// inputs. Stale caches must be rejected when row-dependent Hessian
129 /// penalties or cross-blocks change.
130 pub row_hessian_fingerprint: u64,
131 /// Registry-side tag for row-dependent analytic-penalty Hessian inputs.
132 /// Combined with the materialized row blocks in
133 /// [`Self::current_row_hessian_fingerprint`].
134 pub analytic_row_hessian_fingerprint: u64,
135 /// Term-block column ranges for the block-Jacobi Schur preconditioner.
136 ///
137 /// Each entry `r` means that indices `r.start..r.end` belong to one
138 /// coefficient block (a GAM term or a custom parameter family from
139 /// `ParameterBlockSpec`). When populated via
140 /// [`Self::set_block_offsets`], the Jacobi preconditioner inverts the
141 /// full `b × b` Schur block for each term instead of only its diagonal.
142 ///
143 /// The default (empty slice) causes `JacobiPreconditioner` to fall back
144 /// to pure scalar diagonal inversion, preserving the pre-#283 behaviour.
145 pub block_offsets: Arc<[Range<usize>]>,
146 /// Optional matrix-free penalty-side `H_ββ` operator (#296).
147 ///
148 /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
149 /// `JacobiPreconditioner`, quadratic-form reduction) route through this
150 /// operator instead of the dense `hbb` accumulator, enabling
151 /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to skip the `O(K²)` dense
152 /// materialisation for structured smoothness penalties.
153 ///
154 /// When `None`, those paths fall back to wrapping `hbb` in a transient
155 /// `DensePenaltyOp` — identical observable behaviour, no new allocation
156 /// hot-path cost for callers that have not opted in.
157 pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
158 /// Device-uploadable SAE Kronecker data for CUDA-resident reduced PCG.
159 ///
160 /// The generic matrix-free closures remain the authoritative CPU path. This
161 /// descriptor is installed only when SAE assembly has a matching CUDA sparse
162 /// representation for both `H_tβ` and `H_ββ`.
163 pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
164 /// Registered Psi-tier analytic penalties whose Hessian couples *distinct*
165 /// latent rows (non-row-block-diagonal), captured by
166 /// [`Self::add_analytic_penalty_contributions`].
167 ///
168 /// These penalties (`TotalVariationPenalty`, `SheafConsistencyPenalty`,
169 /// block-orthogonality, …) produce off-row Hessian blocks `∂²P/∂t_i∂t_j`
170 /// (`i ≠ j`) that the arrow elimination — which assumes each `H_tt^(i)` is
171 /// independent of every other row — cannot represent. Their *gradient* is
172 /// still folded into `g_t` exactly like every other Psi penalty; only their
173 /// curvature is held here, applied during the solve as a full-latent
174 /// Hessian-vector product `P_cross · Δt` against the penalty's
175 /// `psd_majorizer_hvp`. When this vector is non-empty,
176 /// [`solve_arrow_newton_step_artifacts`] auto-selects the matrix-free
177 /// full-system PCG path (arrow block-diagonal inverse as preconditioner)
178 /// instead of the exact one-shot Schur elimination. When empty, the system
179 /// is purely row-block-diagonal and the exact Schur path is unchanged.
180 pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
181 /// Optional row-local gauge directions for evidence-only Faddeev-Popov
182 /// deflation of an otherwise non-PD `H_tt` row block.
183 ///
184 /// These vectors live in each row's actual chart block, so compact SAE rows
185 /// and dense rows share the same factorization path. Ordinary Newton solves
186 /// ignore them; only undamped evidence factors with
187 /// `tolerate_ill_conditioning` set may stiffen a gauge-explained row
188 /// direction.
189 pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
190 /// Optional exact cross-row IBP low-rank source (#1038). When set, the
191 /// factorization downdates the per-row logit-slot self term and layers the
192 /// exact rank-`R` Woodbury correction onto the evidence cache (value,
193 /// log-determinant, and θ/ρ-adjoint together). `None` for all non-IBP
194 /// systems — the row-block-diagonal arrow path is then unchanged.
195 pub ibp_cross_row: Option<IbpCrossRowSource>,
196}
197
198impl Clone for ArrowSchurSystem {
199 fn clone(&self) -> Self {
200 Self {
201 rows: self.rows.clone(),
202 hbb: self.hbb.clone(),
203 hbb_matvec: self.hbb_matvec.clone(),
204 htbeta_matvec: self.htbeta_matvec.clone(),
205 htbeta_transpose_matvec: self.htbeta_transpose_matvec.clone(),
206 htbeta_dense_supplement: self.htbeta_dense_supplement,
207 hbb_diag: self.hbb_diag.clone(),
208 gb: self.gb.clone(),
209 d: self.d,
210 row_dims: Arc::clone(&self.row_dims),
211 row_offsets: Arc::clone(&self.row_offsets),
212 k: self.k,
213 manifold_mode_fingerprint: self.manifold_mode_fingerprint,
214 row_hessian_fingerprint: self.row_hessian_fingerprint,
215 analytic_row_hessian_fingerprint: self.analytic_row_hessian_fingerprint,
216 block_offsets: Arc::clone(&self.block_offsets),
217 penalty_op: self.penalty_op.clone(),
218 device_sae_pcg: self.device_sae_pcg.clone(),
219 cross_row_penalties: self.cross_row_penalties.clone(),
220 row_gauge_deflation: self.row_gauge_deflation.clone(),
221 ibp_cross_row: self.ibp_cross_row.clone(),
222 }
223 }
224}
225
226/// A captured cross-row Psi-tier analytic penalty: the penalty kind plus the
227/// global-ρ slice (`rho_local`) it was registered with.
228///
229/// Holds an owned copy of the local ρ-axes so the penalty's
230/// [`AnalyticPenaltyKind::psd_majorizer_hvp`] can be evaluated during the
231/// matrix-free full-system solve without re-deriving the ρ layout. The penalty
232/// itself is an `Arc`-backed clone (cheap), so capturing it does not copy the
233/// penalty payload.
234#[derive(Clone)]
235pub struct CrossRowLatentPenalty {
236 /// The non-row-block-diagonal Psi penalty (e.g. `TotalVariationPenalty`).
237 pub penalty: AnalyticPenaltyKind,
238 /// The penalty's local ρ-axes (its slice of the global ρ vector).
239 pub rho_local: Array1<f64>,
240 /// The flat latent vector (`N·d`, row-major) the penalty's curvature was
241 /// linearized at — i.e. the `target_t` passed to
242 /// [`ArrowSchurSystem::add_analytic_penalty_contributions`]. The Hessian of
243 /// a nonlinear penalty (the smoothed-TV curvature weights `φ''(D t)`,
244 /// etc.) depends on this point, so `psd_majorizer_hvp` must be evaluated
245 /// against it for the Newton operator to be the true Hessian at the
246 /// current iterate.
247 pub target_t: Array1<f64>,
248}
249
250/// Exact cross-row low-rank IBP source (#1038): the per-column rank-one Hessian
251/// terms `H_(i,k),(j,k) = d_k·z'_ik·z'_jk` (for ALL `i,j`, including the `i=j`
252/// self term) that couple DISTINCT latent rows through a shared atom column `k`.
253///
254/// Stacking over rows, this is `H_full = H₀' + U D Uᵀ`, where:
255/// * `U` is `delta_t_len × R` with `U[g, k] = z'_ik` at the global latent index
256/// `g` of row `i`'s logit slot for atom `k` (zero elsewhere) — i.e. column `k`
257/// is supported on the atom-`k` logit slot of every row;
258/// * `D = diag(d_k)`, `d_k = w·s'_k` ([`gam_terms::analytic_penalties::IbpHessianDiagThirdChannels::cross_row_d`]);
259/// * `H₀'` is the assembled latent block-diagonal `H₀` with the per-row self
260/// term `d_k·z'_ik²` REMOVED from each logit-slot diagonal (the assembled
261/// `H₀` already carries it, so the FULL rank-one outer product `U D Uᵀ` —
262/// which re-adds the `i=j` diagonal — would double-count without this
263/// downdate). The determinant lemma `log det(I_R + D UᵀH₀'⁻¹U)` is only the
264/// exact rank-`R` correction against this no-self base.
265///
266/// The arrow elimination assumes each row's `H_tt^(i)` is independent of every
267/// other row, so it structurally cannot hold this coupling block-locally. The
268/// factorization owner (`solver::arrow_schur`) consumes this source to (a)
269/// downdate the per-row logit diagonal before factoring, (b) build `U`/`D` onto
270/// the resulting [`ArrowFactorCache`] as a [`CrossRowWoodbury`], and (c) apply
271/// the exact Woodbury correction to the value/curvature solve, the evidence
272/// log-determinant, and the θ/ρ-adjoint TOGETHER (they all describe the SAME
273/// `H_full`).
274#[derive(Clone, Debug, Default)]
275pub struct IbpCrossRowSource {
276 /// Number of atom columns `R` (the rank of the cross-row update).
277 pub r: usize,
278 /// `d_k = w·s'_k`, the scalar `D`-coefficient of column `k`. Length `R`.
279 pub d: Array1<f64>,
280 /// Per-row column entries `(global_t_index, atom_k, z'_ik)`: each tuple
281 /// places `z'_ik` at `U[global_t_index, atom_k]`. The `global_t_index` is
282 /// `row_offsets[i] + local_slot` for the row's logit slot of atom `k`. Only
283 /// nonzero entries are listed (one per active (row, atom) pair).
284 pub entries: Vec<(usize, usize, f64)>,
285}
286
287impl IbpCrossRowSource {
288 /// Build the dense `delta_t_len × R` factor `U` (each column supported on
289 /// its atom's per-row logit slots) from the sparse entry list.
290 pub(crate) fn dense_u(&self, delta_t_len: usize) -> Array2<f64> {
291 let mut u = Array2::<f64>::zeros((delta_t_len, self.r));
292 for &(g, k, z) in &self.entries {
293 u[[g, k]] += z;
294 }
295 u
296 }
297
298 /// Per-row-slot self-term downdate: returns, for each global latent index,
299 /// the scalar `Σ_k d_k·z'_ik²` to subtract from that logit slot's diagonal
300 /// so the factored base is `H₀'` (self term removed). Indexed by global
301 /// `delta_t` position.
302 pub(crate) fn self_term_downdate(&self, delta_t_len: usize) -> Array1<f64> {
303 let mut down = Array1::<f64>::zeros(delta_t_len);
304 for &(g, k, z) in &self.entries {
305 down[g] += self.d[k] * z * z;
306 }
307 down
308 }
309}
310
311impl ArrowSchurSystem {
312 /// Allocate an empty BA reduced-camera-system instance sized
313 /// `(N point/latent rows × d, K shared decoder parameters)`.
314 pub fn new(n: usize, d: usize, k: usize) -> Self {
315 Self::new_with_hbb(n, d, k, Array2::<f64>::zeros((k, k)))
316 }
317
318 /// Allocate an arrow system with no dense shared `H_ββ` block and with
319 /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
320 pub fn new_with_empty_hbb_and_htbeta_cols(
321 n: usize,
322 d: usize,
323 k: usize,
324 htbeta_cols: usize,
325 ) -> Self {
326 let rows = (0..n)
327 .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
328 .collect();
329 let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
330 let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
331 Self {
332 rows,
333 hbb: Array2::<f64>::zeros((0, 0)),
334 hbb_matvec: None,
335 htbeta_matvec: None,
336 htbeta_transpose_matvec: None,
337 htbeta_dense_supplement: false,
338 hbb_diag: None,
339 gb: Array1::<f64>::zeros(k),
340 d,
341 row_dims,
342 row_offsets,
343 k,
344 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
345 row_hessian_fingerprint: 0,
346 analytic_row_hessian_fingerprint: 0,
347 block_offsets: Arc::from([] as [Range<usize>; 0]),
348 penalty_op: None,
349 device_sae_pcg: None,
350 cross_row_penalties: Vec::new(),
351 row_gauge_deflation: None,
352 ibp_cross_row: None,
353 }
354 }
355
356 /// Allocate an arrow system using a caller-owned dense shared-block buffer.
357 /// The buffer must already have shape `(k, k)` and is zeroed in place before
358 /// use so callers can recycle it across assemblies without changing
359 /// numerics.
360 pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
361 Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, k)
362 }
363
364 /// Allocate an arrow system with a caller-owned dense shared-block buffer and
365 /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
366 pub fn new_with_hbb_and_htbeta_cols(
367 n: usize,
368 d: usize,
369 k: usize,
370 mut hbb: Array2<f64>,
371 htbeta_cols: usize,
372 ) -> Self {
373 assert_eq!(hbb.dim(), (k, k));
374 hbb.fill(0.0);
375 let rows = (0..n)
376 .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
377 .collect();
378 let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
379 let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
380 Self {
381 rows,
382 hbb,
383 hbb_matvec: None,
384 htbeta_matvec: None,
385 htbeta_transpose_matvec: None,
386 htbeta_dense_supplement: false,
387 hbb_diag: None,
388 gb: Array1::<f64>::zeros(k),
389 d,
390 row_dims,
391 row_offsets,
392 k,
393 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
394 row_hessian_fingerprint: 0,
395 analytic_row_hessian_fingerprint: 0,
396 block_offsets: Arc::from([] as [Range<usize>; 0]),
397 penalty_op: None,
398 device_sae_pcg: None,
399 cross_row_penalties: Vec::new(),
400 row_gauge_deflation: None,
401 ibp_cross_row: None,
402 }
403 }
404
405 /// Allocate an arrow system whose shared `H_ββ` block is supplied only as
406 /// a matrix-free operator for large BA InexactPCG.
407 ///
408 /// Direct and Square-Root BA modes require dense `hbb` and must not be
409 /// used with this constructor. The row-local `H_tβ` slabs remain explicit;
410 /// a future MegBA backend can replace those slab operations behind
411 /// [`BatchedBlockSolver`].
412 pub fn new_matrix_free_shared<F>(
413 n: usize,
414 d: usize,
415 k: usize,
416 matvec: F,
417 diag: Array1<f64>,
418 ) -> Self
419 where
420 F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
421 {
422 assert_eq!(diag.len(), k);
423 let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
424 let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
425 let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
426 let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
427 // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
428 // route through the trait while preserving hbb_matvec + hbb_diag for
429 // code that inspects them directly.
430 let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
431 k,
432 Arc::clone(&matvec_arc),
433 diag.clone(),
434 )));
435 Self {
436 rows,
437 hbb: Array2::<f64>::zeros((0, 0)),
438 hbb_matvec: Some(matvec_arc),
439 htbeta_matvec: None,
440 htbeta_transpose_matvec: None,
441 htbeta_dense_supplement: false,
442 hbb_diag: Some(diag),
443 gb: Array1::<f64>::zeros(k),
444 d,
445 row_dims,
446 row_offsets,
447 k,
448 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
449 row_hessian_fingerprint: 0,
450 analytic_row_hessian_fingerprint: 0,
451 block_offsets: Arc::from([] as [Range<usize>; 0]),
452 penalty_op,
453 device_sae_pcg: None,
454 cross_row_penalties: Vec::new(),
455 row_gauge_deflation: None,
456 ibp_cross_row: None,
457 }
458 }
459
460
461
462 /// Allocate a heterogeneous-row arrow system with no dense shared `H_ββ`
463 /// block and with row `H_tβ` slabs allocated at `htbeta_cols` columns.
464 pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
465 per_row_dims: Vec<usize>,
466 k: usize,
467 htbeta_cols: usize,
468 ) -> Self {
469 let n = per_row_dims.len();
470 let d = per_row_dims.iter().copied().max().unwrap_or(0);
471 let mut offsets = Vec::with_capacity(n + 1);
472 let mut cursor = 0usize;
473 offsets.push(cursor);
474 for &dim in &per_row_dims {
475 cursor += dim;
476 offsets.push(cursor);
477 }
478 let rows = per_row_dims
479 .iter()
480 .map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
481 .collect();
482 Self {
483 rows,
484 hbb: Array2::<f64>::zeros((0, 0)),
485 hbb_matvec: None,
486 htbeta_matvec: None,
487 htbeta_transpose_matvec: None,
488 htbeta_dense_supplement: false,
489 hbb_diag: None,
490 gb: Array1::<f64>::zeros(k),
491 d,
492 row_dims: Arc::from(per_row_dims.into_boxed_slice()),
493 row_offsets: Arc::from(offsets.into_boxed_slice()),
494 k,
495 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
496 row_hessian_fingerprint: 0,
497 analytic_row_hessian_fingerprint: 0,
498 block_offsets: Arc::from([] as [Range<usize>; 0]),
499 penalty_op: None,
500 device_sae_pcg: None,
501 cross_row_penalties: Vec::new(),
502 row_gauge_deflation: None,
503 ibp_cross_row: None,
504 }
505 }
506
507 /// Allocate a heterogeneous-row system using a caller-owned dense shared
508 /// block and row `H_tβ` slabs allocated at `htbeta_cols` columns.
509 pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
510 per_row_dims: Vec<usize>,
511 k: usize,
512 mut hbb: Array2<f64>,
513 htbeta_cols: usize,
514 ) -> Self {
515 assert_eq!(hbb.dim(), (k, k));
516 hbb.fill(0.0);
517 let n = per_row_dims.len();
518 let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
519 let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
520 let mut off_vec = Vec::with_capacity(n + 1);
521 let mut cursor = 0usize;
522 for &di in &per_row_dims {
523 off_vec.push(cursor);
524 cursor += di;
525 }
526 off_vec.push(cursor);
527 let row_offsets: Arc<[usize]> = off_vec.into();
528 let rows = per_row_dims
529 .iter()
530 .map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
531 .collect();
532 Self {
533 rows,
534 hbb,
535 hbb_matvec: None,
536 htbeta_matvec: None,
537 htbeta_transpose_matvec: None,
538 htbeta_dense_supplement: false,
539 hbb_diag: None,
540 gb: Array1::<f64>::zeros(k),
541 d: max_d,
542 row_dims,
543 row_offsets,
544 k,
545 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
546 row_hessian_fingerprint: 0,
547 analytic_row_hessian_fingerprint: 0,
548 block_offsets: Arc::from([] as [Range<usize>; 0]),
549 penalty_op: None,
550 device_sae_pcg: None,
551 cross_row_penalties: Vec::new(),
552 row_gauge_deflation: None,
553 ibp_cross_row: None,
554 }
555 }
556
557 pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
558 self.row_gauge_deflation = Some(deflation);
559 }
560
561 /// Register the exact cross-row IBP low-rank source (#1038). The assembly
562 /// passes the per-column `D`-coefficients (`cross_row_d`) and the `(global
563 /// latent index, atom, z'_ik)` entries built from `z_jac`; the factorization
564 /// then carries the exact rank-`R` Woodbury (value + log-determinant +
565 /// θ/ρ-adjoint) on the evidence cache. An empty source (`r == 0` or no
566 /// entries) is treated as absent so the row-block-diagonal path is unchanged.
567 pub fn set_ibp_cross_row_source(&mut self, source: IbpCrossRowSource) {
568 if source.r == 0 || source.entries.is_empty() {
569 self.ibp_cross_row = None;
570 } else {
571 self.ibp_cross_row = Some(source);
572 }
573 }
574
575 /// Number of BA point/latent rows `N`.
576 pub fn n(&self) -> usize {
577 self.rows.len()
578 }
579
580 /// Recompute the row-system fingerprint from the currently materialized
581 /// row blocks, cross-blocks, and shared-block diagonal.
582 pub fn compute_row_hessian_fingerprint(&self) -> u64 {
583 row_hessian_fingerprint_for_system(self)
584 }
585
586 /// Current effective row-system fingerprint, including the materialized
587 /// row blocks and any registry metadata captured while folding analytic
588 /// penalties into the system.
589 pub fn current_row_hessian_fingerprint(&self) -> u64 {
590 combine_row_and_registry_fingerprints(
591 self.compute_row_hessian_fingerprint(),
592 self.analytic_row_hessian_fingerprint,
593 )
594 }
595
596 /// Store the current row-system fingerprint on the system.
597 ///
598 /// This is intentionally explicit and expensive. Cache and evidence callers
599 /// use [`Self::current_row_hessian_fingerprint`] at the point they need the
600 /// value, after assembly has populated the system, instead of hashing each
601 /// intermediate construction/mutation step.
602 pub fn refresh_row_hessian_fingerprint(&mut self) {
603 self.row_hessian_fingerprint = self.current_row_hessian_fingerprint();
604 }
605
606 /// Install a matrix-free shared-block operator for Agarwal-style
607 /// inexact Schur PCG.
608 ///
609 /// `diag` must be the diagonal of the same `H_ββ` operator and is used
610 /// for the Schur-Jacobi preconditioner. This is the BA "large camera
611 /// system" path mapped to large decoder coefficient blocks.
612 pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
613 where
614 F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
615 {
616 assert_eq!(diag.len(), self.k);
617 let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
618 // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
619 // route through the trait, preserving the existing hbb_matvec +
620 // hbb_diag fields for code that inspects them directly.
621 self.penalty_op = Some(Arc::new(MatvecDiagPenaltyOp::new(
622 self.k,
623 Arc::clone(&matvec_arc),
624 diag.clone(),
625 )));
626 self.hbb_matvec = Some(matvec_arc);
627 self.hbb_diag = Some(diag);
628 }
629
630 /// Mark the dense per-row cross-block slabs as active supplements to the
631 /// installed matrix-free row operator.
632 pub fn activate_dense_htbeta_supplement(&mut self) {
633 self.htbeta_dense_supplement = true;
634 }
635
636 /// Install a matrix-free per-row cross-block operator and its sparse
637 /// adjoint.
638 ///
639 /// `forward` must write `out = H_tβ^(row) x` for `out.len() == d` and
640 /// `x.len() == K`. `transpose` must **add** `H_βt^(row) v` into `out` for
641 /// `out.len() == K` and `v.len() == d` (the sparse `scatter` adjoint).
642 ///
643 /// When installed, the forward operator is used during the Newton solve
644 /// (inside `reduced_rhs_beta`, `schur_matvec`, back-substitution, and
645 /// `JacobiPreconditioner` construction) and afterwards by IFT/evidence
646 /// predictors. Per-row `htbeta` slabs in `ArrowRowBlock` may be left
647 /// zero-sized when this operator is installed — all inner-Schur paths route
648 /// through the matvec instead of indexing the dense block. The transpose
649 /// operator lets the reduced-Schur matvec apply `H_βt^(row)` directly
650 /// (`O(m_i · p)`) instead of probing `forward` against `K` basis vectors.
651 pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
652 where
653 F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
654 T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
655 {
656 self.htbeta_matvec = Some(Arc::new(forward));
657 self.htbeta_transpose_matvec = Some(Arc::new(transpose));
658 }
659
660 /// Register term-block column ranges for the block-Jacobi Schur preconditioner.
661 ///
662 /// Each `Range<usize>` covers the columns of one GAM term (or custom
663 /// parameter family) in the shared `β` vector. The ranges must be
664 /// non-overlapping, sorted, and their union must cover `0..k`.
665 ///
666 /// Call this after building the system and before [`Self::solve`] /
667 /// [`Self::solve_with_options`] whenever the solver will use
668 /// [`ArrowSolverMode::InexactPCG`]. Absent a call, the preconditioner
669 /// falls back to scalar diagonal Jacobi (the pre-#283 behaviour).
670 ///
671 /// The same plumbing is compatible with #287 (custom `ParameterBlockSpec`
672 /// families): callers from that path simply supply ranges derived from
673 /// their own block layout.
674 pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
675 self.block_offsets = offsets;
676 }
677
678 /// Install a matrix-free penalty-side `H_ββ` operator (#296).
679 ///
680 /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
681 /// `JacobiPreconditioner`, quadratic-form reduction) route through this
682 /// operator instead of the dense `hbb` accumulator, enabling
683 /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to avoid `O(K²)` allocation
684 /// for structured smoothness penalties.
685 pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
686 self.penalty_op = Some(op);
687 }
688
689 pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
690 assert_eq!(data.beta_dim, self.k);
691 // The frames-engaged builder (`build_framed_device_sae_data`) carries the
692 // per-row cross block through `frame.frame_blocks` and intentionally leaves
693 // the full-`B` `a_phi`/`local_jac` slabs EMPTY (#1033). Only the non-framed
694 // full-`B` path populates those per-row slabs, so the length contract
695 // applies only when there is no frame.
696 if data.frame.is_none() {
697 assert_eq!(data.a_phi.len(), self.rows.len());
698 assert_eq!(data.local_jac.len(), self.rows.len());
699 }
700 self.device_sae_pcg = Some(Arc::new(data));
701 }
702
703 /// Return the effective penalty operator: the installed `penalty_op` if
704 /// present, otherwise a `DensePenaltyOp` wrapping the current `hbb`.
705 ///
706 /// Note: when `penalty_op` is `None`, this clones `hbb` into a new
707 /// `DensePenaltyOp`. Callers in hot loops should call this once and
708 /// store the result, not call it per-iteration.
709 pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
710 match self.penalty_op.as_ref() {
711 Some(op) => Arc::clone(op),
712 None => Arc::new(DensePenaltyOp(self.hbb.clone())),
713 }
714 }
715
716 /// `y += P x` without allocating a new Arc; dispatches to `penalty_op`
717 /// or falls back to `hbb` inline, avoiding the K×K clone hot-path cost.
718 #[inline]
719 pub(crate) fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
720 if let Some(op) = self.penalty_op.as_ref() {
721 op.matvec(x, y);
722 } else {
723 let k = self.hbb.nrows();
724 for a in 0..k {
725 let mut acc = 0.0_f64;
726 for b in 0..k {
727 acc += self.hbb[[a, b]] * x[b];
728 }
729 y[a] += acc;
730 }
731 }
732 }
733
734 /// Reduced-Schur matvec prologue `y = (P + ridge·I) x` written fresh into a
735 /// zeroed `y` (the caller clears `out` first; this is the first writer).
736 ///
737 /// At the SAE LLM border width (#1017) the dense `H_ββ` fallback is a `k×k`
738 /// GEMV whose `O(k²)` cost (≈4M flops at k=2048) runs once per CG iteration
739 /// and was the serial Amdahl ceiling on the per-row-parallel matvec: while
740 /// the `n`-row point-elimination term fans across all cores, this prologue
741 /// pinned one core and grows as `k²`. The dense GEMV is embarrassingly
742 /// parallel over output rows `a` — each `y[a] = Σ_b hbb[a,b]·x[b] + ridge·x[a]`
743 /// is independent and its inner sum order is identical whether one thread or
744 /// many compute it. Here parallelism is over independent output rows (NOT a
745 /// reassociated reduction), so each `y[a]` accumulates in the SAME order as
746 /// serial — the result is bit-identical to serial, not merely deterministic
747 /// run-to-run (the #1017 determinism gate). On THIS exact-order path the
748 /// criterion ranking is invariant; that no-move guarantee holds because the
749 /// order matches serial, and does NOT generalise to chunk-reassociated
750 /// reductions, where a near-tie winner can flip within the f64 margin
751 /// (#1211). The `penalty_op` path stays serial — it is an opaque operator
752 /// with its own structure (SAE uses the dense `hbb`), and small `k` stays
753 /// serial to avoid rayon overhead on a trivial GEMV.
754 ///
755 /// `parallel` is the caller's top-level / not-nested-in-rayon decision (the
756 /// same guard the row loop uses), so this never oversubscribes inside the
757 /// topology race.
758 pub(crate) fn penalty_ridge_prologue_into(
759 &self,
760 x: &[f64],
761 ridge: f64,
762 y: &mut [f64],
763 parallel: bool,
764 ) {
765 let k = self.hbb.nrows();
766 let dense_parallel = parallel
767 && self.penalty_op.is_none()
768 && self.hbb.dim() == (k, k)
769 && k >= SCHUR_PROLOGUE_PARALLEL_K_MIN;
770 if dense_parallel {
771 use rayon::prelude::*;
772 let hbb = &self.hbb;
773 y.par_iter_mut().enumerate().for_each(|(a, ya)| {
774 let mut acc = 0.0_f64;
775 for b in 0..k {
776 acc += hbb[[a, b]] * x[b];
777 }
778 *ya = acc + ridge * x[a];
779 });
780 } else {
781 self.penalty_matvec_add(x, y);
782 for a in 0..k {
783 y[a] += ridge * x[a];
784 }
785 }
786 }
787
788 /// `diag += diag(P)` without allocating; dispatches to `penalty_op`
789 /// or falls back to `hbb` diagonal / `hbb_diag` inline.
790 #[inline]
791 pub(crate) fn penalty_diagonal_add(&self, diag: &mut [f64]) {
792 if let Some(op) = self.penalty_op.as_ref() {
793 op.diagonal(diag);
794 } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
795 let k = hbb_diag.len().min(diag.len());
796 for j in 0..k {
797 diag[j] += hbb_diag[j];
798 }
799 } else {
800 let k = self.hbb.nrows().min(diag.len());
801 for j in 0..k {
802 diag[j] += self.hbb[[j, j]];
803 }
804 }
805 }
806
807 /// Add the `b×b` penalty sub-block for `id` to `out`, routing through
808 /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
809 #[inline]
810 pub(crate) fn penalty_block_add(
811 &self,
812 id: BetaBlockId,
813 offsets: &[Range<usize>],
814 out: &mut Array2<f64>,
815 ) {
816 if let Some(op) = self.penalty_op.as_ref() {
817 op.block(id, offsets, out);
818 } else {
819 let range = &offsets[id.0];
820 let b = range.end - range.start;
821 if self.hbb.dim() == (self.k, self.k) {
822 for bi in 0..b {
823 for bj in 0..b {
824 out[[bi, bj]] += self.hbb[[range.start + bi, range.start + bj]];
825 }
826 }
827 } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
828 for bi in 0..b {
829 out[[bi, bi]] += hbb_diag[range.start + bi];
830 }
831 }
832 }
833 }
834
835 /// Fill a `b×b` penalty sub-block for a set of arbitrary (possibly
836 /// non-contiguous) global column indices `cols`, routing through
837 /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
838 ///
839 /// Used by the cluster-Jacobi preconditioner (#299) which groups columns
840 /// by spectral adjacency rather than contiguous block ranges.
841 #[inline]
842 pub(crate) fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
843 let b = cols.len();
844 if let Some(op) = self.penalty_op.as_ref() {
845 // Probe each column basis vector and extract the sub-block entries.
846 let mut probe = Array1::<f64>::zeros(self.k);
847 let mut result = Array1::<f64>::zeros(self.k);
848 for bj in 0..b {
849 probe.fill(0.0);
850 probe[cols[bj]] = 1.0;
851 result.fill(0.0);
852 {
853 let p_slice = probe.as_slice().expect("probe contiguous");
854 let r_slice = result.as_slice_mut().expect("result contiguous");
855 op.matvec(p_slice, r_slice);
856 }
857 for bi in 0..b {
858 out[[bi, bj]] += result[cols[bi]];
859 }
860 }
861 } else if self.hbb.dim() == (self.k, self.k) {
862 for bi in 0..b {
863 for bj in 0..b {
864 out[[bi, bj]] += self.hbb[[cols[bi], cols[bj]]];
865 }
866 }
867 } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
868 for bi in 0..b {
869 out[[bi, bi]] += hbb_diag[cols[bi]];
870 }
871 }
872 }
873
874 /// Fold analytic-penalty contributions into the appropriate blocks.
875 ///
876 /// BA source mapping: these are extra prior/regularization normal-equation
877 /// terms before point elimination, the same place Ceres/g2o attach robust
878 /// priors or gauge-fixing constraints.
879 ///
880 /// **Composition path.** Each registered [`AnalyticPenaltyKind`] is
881 /// queried for `grad_target` (added to `g_t` or `g_β`) and then for
882 /// `hessian_diag` first. Diagonal penalties (ARD and the shipped
883 /// sparsity kernels) are injected directly. The row-block-only Psi-tier
884 /// penalties are `ARDPenalty`, `SparsityPenalty`,
885 /// `SoftmaxAssignmentSparsity`, `IBPAssignment`,
886 /// `RowPrecisionPrior`, `ParametricRowPrecisionPrior`, and
887 /// `ScadMcpPenalty`. Their `d × d` per-row Hessian folds into
888 /// `rows[i].htt`, so the exact arrow Schur elimination (`N` independent
889 /// `d × d` row solves) represents them exactly. Dense Beta-tier penalties
890 /// still fall back to `hvp` probes against the canonical basis vectors for
891 /// `β`.
892 ///
893 /// **Cross-row Psi penalties.** Penalties whose Hessian couples *distinct*
894 /// latent rows — `TotalVariationPenalty`, `SheafConsistencyPenalty`,
895 /// block-orthogonality, … — produce off-row blocks `∂²P/∂t_i∂t_j`
896 /// (`i ≠ j`) that the arrow elimination cannot store, since it assumes each
897 /// `H_tt^(i)` is independent of every other row. These are handled without
898 /// any approximation: their **gradient** is folded into `g_t` exactly as
899 /// for every other Psi penalty (`grad_target → g_t`), and their full
900 /// **curvature** is captured into [`Self::cross_row_penalties`] as a
901 /// matrix-free operator. At solve time, `K = K0 + P_cross` where `K0` is
902 /// the block-diagonal arrow operator and `P_cross · Δt = Σ_p ρ_p ·
903 /// psd_majorizer_hvp_p(t, Δt)` is the cross-row penalty Hessian applied to
904 /// the full flat latent vector. The presence of any captured cross-row
905 /// penalty auto-routes [`Self::solve`] through the matrix-free full-system
906 /// PCG path (the exact arrow block-diagonal inverse `K0⁻¹` is the
907 /// preconditioner `M⁻¹`); a purely row-block-diagonal system keeps the
908 /// exact one-shot Schur path unchanged. No new flag is involved — the route
909 /// is selected from the captured penalty set alone (magic by default).
910 ///
911 /// `target_t` is the full flat latent-coordinate vector (row-major, `N·d` entries)
912 /// at the current iterate; `target_beta` is the current `β`. `rho`
913 /// is the global ρ vector restricted to each penalty's local slice
914 /// by [`AnalyticPenaltyRegistry::rho_layout`].
915 pub fn add_analytic_penalty_contributions(
916 &mut self,
917 registry: &AnalyticPenaltyRegistry,
918 target_t: ArrayView1<'_, f64>,
919 target_beta: ArrayView1<'_, f64>,
920 rho_global: ArrayView1<'_, f64>,
921 ) -> Result<(), ArrowSchurError> {
922 let layout = registry.rho_layout();
923 let mut penalty_fingerprints = Vec::new();
924 self.cross_row_penalties.clear();
925 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
926 let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
927 match tier {
928 PenaltyTier::Psi => {
929 if analytic_penalty_is_row_block_diagonal(penalty) {
930 // Row-block-diagonal: fold gradient + per-row d×d
931 // curvature into rows[i].htt, exactly representable by
932 // the arrow Schur elimination.
933 self.add_ext_coord_penalty(penalty, target_t, rho_local);
934 if let Some(fingerprint) =
935 analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
936 {
937 penalty_fingerprints.push(fingerprint);
938 }
939 } else {
940 // Cross-row: fold the gradient into g_t (exact, like
941 // every Psi penalty), but DO NOT fold any curvature into
942 // the row blocks — its off-row coupling cannot be stored
943 // there. Capture the penalty so the solve applies its
944 // full Hessian-vector product P_cross·Δt over the flat
945 // latent vector. This auto-selects the matrix-free
946 // full-system PCG path.
947 self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
948 self.cross_row_penalties.push(CrossRowLatentPenalty {
949 penalty: penalty.clone(),
950 rho_local: rho_local.to_owned(),
951 target_t: target_t.to_owned(),
952 });
953 }
954 }
955 PenaltyTier::Beta => {
956 self.add_beta_penalty(penalty, target_beta, rho_local);
957 }
958 PenaltyTier::Rho => {
959 // Rho-tier hyperpriors do not contribute to the inner
960 // (t, β) Newton step; they enter only at the REML
961 // outer level.
962 }
963 }
964 }
965 // Cross-row penalties contribute to the Newton Hessian operator, not
966 // the stored row blocks, so they must still invalidate the row-Hessian
967 // cache when their curvature changes. Probe each captured penalty's PSD
968 // majorizer against the current latent vector (a deterministic, generic
969 // probe) and fold the resulting signature in.
970 for cross in &self.cross_row_penalties {
971 penalty_fingerprints.push(cross_row_penalty_fingerprint(
972 &cross.penalty,
973 target_t,
974 cross.rho_local.view(),
975 ));
976 }
977 self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
978 0
979 } else {
980 let mut hasher = Fingerprinter::new();
981 hasher.write_str("arrow-schur-row-hessian-registry-v1");
982 hasher.write_usize(penalty_fingerprints.len());
983 for fingerprint in penalty_fingerprints {
984 hasher.write_u64(fingerprint);
985 }
986 hasher.finish_u64()
987 };
988 Ok(())
989 }
990
991 /// Convert row-local Euclidean latent blocks to Riemannian tangent blocks.
992 ///
993 /// This is the only arrow-Schur algebra change needed for manifold
994 /// latents: `g_t`, `H_tt`, and each `H_tβ` column are projected to
995 /// `T_{t_i}M`, while the shared β block and Schur structure remain
996 /// untouched. Embedded constrained manifolds carry a pinned normal block
997 /// so the existing ambient Cholesky factorization still works; all RHS
998 /// terms live in the tangent space, so the solved update retracts cleanly.
999 pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
1000 let manifold = latent.manifold();
1001 self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
1002 if manifold.is_euclidean() {
1003 return;
1004 }
1005 assert_eq!(latent.n_obs(), self.rows.len());
1006 assert_eq!(latent.latent_dim(), self.d);
1007 for (i, row) in self.rows.iter_mut().enumerate() {
1008 let t_i = ArrayView1::from(latent.row(i));
1009 let gt_e = row.gt.clone();
1010 let htt_e = row.htt.clone();
1011 let htbeta_e = row.htbeta.clone();
1012 row.gt = manifold.project_gradient_to_tangent(t_i, gt_e.view());
1013 row.htt = manifold.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
1014 row.htbeta = manifold.project_matrix_columns_to_gradient_tangent(
1015 t_i,
1016 gt_e.view(),
1017 htbeta_e.view(),
1018 );
1019 }
1020 }
1021
    /// Fold a latent-coordinate (Psi-tier) analytic penalty into the per-row
    /// blocks `rows[i].gt` and `rows[i].htt`.
    ///
    /// Delegates to `apply_analytic_penalty` over the flat latent vector of
    /// length `n · d` (flat index `i·d + j` ↔ row `i`, coordinate `j`), passing
    /// `d` as the probe count and four callbacks:
    ///
    /// 1. gradient scatter — adds a value to `rows[flat / d].gt[flat % d]`;
    /// 2. diagonal-Hessian scatter — adds a value to the matching diagonal
    ///    entry of `rows[flat / d].htt`;
    /// 3. probe builder — for coordinate `a`, sets entry `a` of EVERY row to
    ///    `1.0`, so one probe covers all rows at once;
    /// 4. probe-result scatter — adds `hv[i·d + b]` into `rows[i].htt[[b, a]]`.
    ///
    /// The shared probe in (3) only yields per-row Hessian columns when the
    /// penalty's Hessian does not couple distinct rows; the visible caller
    /// routes only row-block-diagonal penalties here.
    ///
    /// NOTE(review): when and in which order `apply_analytic_penalty` invokes
    /// these callbacks is defined by that helper and is not visible here.
    pub(crate) fn add_ext_coord_penalty(
        &mut self,
        penalty: &AnalyticPenaltyKind,
        target_t: ArrayView1<'_, f64>,
        rho_local: ArrayView1<'_, f64>,
    ) {
        let d = self.d;
        let n = self.rows.len();
        apply_analytic_penalty(
            penalty,
            target_t,
            rho_local,
            n * d,
            d,
            self,
            // Gradient entry `flat` lands in row `flat / d`, coordinate `flat % d`.
            |sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
            // Diagonal Hessian entry `flat` lands on that row's `htt` diagonal.
            |sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
            // Probe for coordinate `a`: unit entry at coordinate `a` of every row.
            |a, probe| {
                for i in 0..n {
                    probe[i * d + a] = 1.0;
                }
            },
            // Scatter the probe result: row `i`'s slice becomes column `a` of
            // that row's `htt`.
            |sys, a, hv| {
                for i in 0..n {
                    for b in 0..d {
                        sys.rows[i].htt[[b, a]] += hv[i * d + b];
                    }
                }
            },
        );
    }
1053
1054 /// Fold ONLY the latent gradient `grad_target → g_t` of an analytic
1055 /// penalty, leaving the row-block Hessian untouched.
1056 ///
1057 /// Used for cross-row Psi penalties: their gradient enters `g_t` exactly
1058 /// like every other Psi penalty, but their curvature must NOT be scattered
1059 /// into the per-row `H_tt^(i)` blocks (the diagonal piece would be
1060 /// double-counted and the off-row coupling cannot be stored there). The
1061 /// full curvature is instead applied as a matrix-free `P_cross · Δt`
1062 /// during the solve, via [`Self::cross_row_penalties`].
1063 pub(crate) fn add_ext_coord_penalty_gradient_only(
1064 &mut self,
1065 penalty: &AnalyticPenaltyKind,
1066 target_t: ArrayView1<'_, f64>,
1067 rho_local: ArrayView1<'_, f64>,
1068 ) {
1069 let d = self.d;
1070 let n = self.rows.len();
1071 assert_eq!(target_t.len(), n * d);
1072 let grad = penalty.grad_target(target_t, rho_local);
1073 for flat in 0..n * d {
1074 self.rows[flat / d].gt[flat % d] += grad[flat];
1075 }
1076 }
1077
1078 /// Apply the aggregate cross-row penalty Hessian `P_cross · v` over the
1079 /// full flat latent vector `v` (length `Σ_i row_dims[i]`), accumulating
1080 /// into `out`.
1081 ///
1082 /// `P_cross = Σ_p psd_majorizer_hvp_p(target_t, ·; ρ_p)` summed over every
1083 /// captured cross-row penalty. Each penalty's `psd_majorizer_hvp` is its
1084 /// exact (PSD) Hessian-vector product over the `N·d` flat latent vector —
1085 /// for `TotalVariationPenalty` this is `Dᵀ diag(φ''(D t)) D · v`, the
1086 /// graph/forward-difference Laplacian-style coupling that links distinct
1087 /// rows. The ρ scaling is already baked into each penalty's resolved
1088 /// weight, so no extra factor is applied here.
1089 ///
1090 /// This is only valid for homogeneous systems (every row of dimension
1091 /// `d`), the only shape cross-row latent penalties are defined on; the
1092 /// flat-index convention `flat = i·d + j` matches every penalty's
1093 /// `latent_dim`/row-major contract.
1094 pub(crate) fn apply_cross_row_penalty_hessian(
1095 &self,
1096 v: ArrayView1<'_, f64>,
1097 out: &mut Array1<f64>,
1098 ) {
1099 for cross in &self.cross_row_penalties {
1100 assert_eq!(cross.target_t.len(), v.len());
1101 let hv =
1102 cross
1103 .penalty
1104 .psd_majorizer_hvp(cross.target_t.view(), cross.rho_local.view(), v);
1105 assert_eq!(hv.len(), out.len());
1106 for i in 0..out.len() {
1107 out[i] += hv[i];
1108 }
1109 }
1110 }
1111
    /// Fold a Beta-tier analytic penalty into the shared β block: the gradient
    /// `gb`, the dense Hessian `hbb`, and the stored diagonal `hbb_diag`.
    ///
    /// Delegates to `apply_analytic_penalty` over the `k`-length β vector with
    /// four callbacks:
    ///
    /// 1. gradient scatter — adds a value to `gb[j]`;
    /// 2. diagonal-Hessian scatter — adds a value to `hbb[[j, j]]` (only when
    ///    `hbb` is allocated at `k × k`) and to `hbb_diag[j]` (when present);
    /// 3. probe builder — the canonical basis vector `e_j`;
    /// 4. probe-result scatter — adds the returned vector into column `j` of
    ///    `hbb` and mirrors its `j`-th entry into `hbb_diag`.
    ///
    /// The probe count passed to the helper is `k` when `hbb` is a dense
    /// `k × k` matrix and `0` otherwise.
    ///
    /// NOTE(review): a probe count of `0` presumably disables the HVP probing
    /// (callback 4 indexes `hbb` unconditionally, so it must not run when
    /// `hbb` is unallocated) — confirm against `apply_analytic_penalty`.
    pub(crate) fn add_beta_penalty(
        &mut self,
        penalty: &AnalyticPenaltyKind,
        target_beta: ArrayView1<'_, f64>,
        rho_local: ArrayView1<'_, f64>,
    ) {
        let k = self.k;
        // Column probes only make sense when there is a dense `hbb` to fill.
        let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
        apply_analytic_penalty(
            penalty,
            target_beta,
            rho_local,
            k,
            hvp_columns,
            self,
            |sys, j, value| sys.gb[j] += value,
            |sys, j, value| {
                if sys.hbb.dim() == (k, k) {
                    sys.hbb[[j, j]] += value;
                }
                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
                    hbb_diag[j] += value;
                }
            },
            |j, probe| probe[j] = 1.0,
            |sys, j, hv| {
                for i in 0..k {
                    sys.hbb[[i, j]] += hv[i];
                }
                // Keep `hbb_diag` consistent with the dense `hbb` Hessian when
                // both are populated (the dense-allocated path + a later
                // `set_shared_beta_operator` install). The HVP probe for
                // column `j` returns the full Hessian column, whose `j`-th
                // entry is the diagonal contribution of this penalty. Without
                // this mirror, the Jacobi Schur preconditioner — which prefers
                // `hbb_diag` over `hbb`'s diagonal — would silently use a
                // stale diagonal for any Beta-tier analytic penalty that
                // exposes only an HVP (no `hessian_diag`).
                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
                    hbb_diag[j] += hv[j];
                }
            },
        );
    }
1156
1157 /// Schur-eliminate the per-row latent block and solve for `(Δt, Δβ, diag)`.
1158 ///
1159 /// This uses [`ArrowSolveOptions::automatic`]: BA dense RCS for
1160 /// `K <= 2000`, and Agarwal-style inexact Schur PCG above that size.
1161 /// Call [`ArrowSchurSystem::solve_with_options`] to force Square-Root BA
1162 /// or a specific inexact solve policy.
1163 ///
1164 /// Returns `(delta_t, delta_beta, PcgDiagnostics)` with `delta_t` flat
1165 /// row-major of length `N · d` and `delta_beta` of length `K`. The sign
1166 /// convention matches `solve_newton_direction_dense`: the returned
1167 /// increments satisfy the bordered system with RHS `[-g_t; -g_β]`, i.e.
1168 /// they are the *negated* solutions of the standard Newton-direction
1169 /// formulation. `PcgDiagnostics` is zero-valued for the Direct path and
1170 /// carries live counters (PCG iters, ridge escalations, residual) for
1171 /// InexactPCG.
1172 ///
1173 /// `ridge_t` and `ridge_beta` are nonnegative diagonal regularizers
1174 /// added to the latent and β blocks respectively before factorization
1175 /// — used by the LM damping outer wrapper to recover from near-singular
1176 /// inner steps. Pass `0.0` for both to obtain the unregularized
1177 /// Newton direction.
1178 pub fn solve(
1179 &self,
1180 ridge_t: f64,
1181 ridge_beta: f64,
1182 ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1183 let options = ArrowSolveOptions::automatic(self.k);
1184 solve_arrow_newton_step_core(self, ridge_t, ridge_beta, &options)
1185 }
1186
1187 /// Solve with the standard LM-style ridge escalation: if a per-row
1188 /// `H_tt + ridge_t·I` Cholesky pivot is non-PD, or the reduced Schur
1189 /// factor fails, geometrically grow both ridges and retry. This is the
1190 /// same Ceres-style proximal correction the Newton driver in
1191 /// `run_joint_fit_arrow_schur` performs around `solve`, lifted into the
1192 /// system itself so every entry point (predict OOS reconstruction,
1193 /// single-shot Newton refinement, …) is self-healing against the
1194 /// pathological per-row blocks produced by PCA-seeded latent
1195 /// coordinates on subset / new data — see #163 and #175.
1196 ///
1197 /// `ridge_t` / `ridge_beta` are the caller-nominal Tikhonov ridges; the
1198 /// escalation only adds extra damping on top of them when the factor
1199 /// fails. PCG / AdaptiveCorrection failures are left untouched because
1200 /// they are not factorization-recoverable.
1201 pub fn solve_with_lm_escalation(
1202 &self,
1203 ridge_t: f64,
1204 ridge_beta: f64,
1205 ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1206 let options = ArrowSolveOptions::automatic(self.k);
1207 solve_with_lm_escalation_inner(self, ridge_t, ridge_beta, &options)
1208 }
1209
1210 /// Solve with an explicit BA Schur mode, returning `(Δt, Δβ, PcgDiagnostics)`.
1211 ///
1212 /// [`ArrowSolverMode::Direct`] is the classic dense reduced-camera-system
1213 /// Cholesky path; [`ArrowSolverMode::SqrtBA`] forms the same dense system
1214 /// through Square-Root BA factors; [`ArrowSolverMode::InexactPCG`] runs
1215 /// inexact-step LM on the reduced system with Jacobi-preconditioned
1216 /// Steihaug-CG. `PcgDiagnostics` is zero-valued for Direct/SqrtBA and
1217 /// carries live counters for InexactPCG (iterations, matvec calls,
1218 /// preconditioner escalations, final relative residual, stopping reason).
1219 pub fn solve_with_options(
1220 &self,
1221 ridge_t: f64,
1222 ridge_beta: f64,
1223 options: &ArrowSolveOptions,
1224 ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1225 solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
1226 }
1227}
1228
/// Chunked Schur assembler that never retains all row cross-blocks.
///
/// Rows are produced on demand by `row_builder` and folded into the dense
/// `k × k` accumulator `s_acc` and the `k`-length reduced right-hand side
/// `rhs_acc`.
pub struct StreamingArrowSchur {
    /// Number of latent rows `N` in the streamed system.
    pub n_rows: usize,
    /// Maximum per-row latent dim (upper bound for scratch buffers).
    pub d: usize,
    /// Per-row latent dims `row_dims[i] == rows[i].htt.nrows()`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets: `row_offsets[i]` is the start of row `i` in
    /// `delta_t`; `row_offsets[n_rows]` is the total `delta_t` length.
    pub row_offsets: Arc<[usize]>,
    /// Shared β-block width `K`; `hbb` is `(k, k)` and `gb` has length `k`.
    pub k: usize,
    /// Chunk size supplied at construction, clamped there to at least 1.
    pub chunk_size: usize,
    /// Dense `k × k` reduced-Schur accumulator; `reset_accumulator` seeds it
    /// with `H_ββ + ridge_beta·I`.
    pub s_acc: Array2<f64>,
    /// `k`-length reduced right-hand-side accumulator; zeroed by
    /// `reset_accumulator`.
    pub(crate) rhs_acc: Array1<f64>,
    /// Dense `k × k` β-block `H_ββ` used to seed `s_acc`.
    pub(crate) hbb: Array2<f64>,
    /// `k`-length β gradient `g_β`.
    pub(crate) gb: Array1<f64>,
    /// Callback that produces the [`ArrowRowBlock`] for a given row index (or
    /// an error when the row cannot be built).
    pub(crate) row_builder: StreamingArrowRowBuilder,
    /// Procedural cross-block operator `H_tβ^(i) x`. When present, the dense
    /// per-row `H_tβ` slabs are never materialized: `accumulate_chunk` and
    /// `back_substitute` probe this operator column-by-column to apply the
    /// cross-block, matching the Kronecker / matrix-free assembly path. When
    /// `None` (legacy dense BA callers), the per-row `row.htbeta` slab is used.
    pub(crate) htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Sparse adjoint of `htbeta_matvec`. When present, `row_htbeta` rebuilds
    /// the dense `(d_i × K)` cross-block by probing the transpose with `d_i`
    /// basis vectors — `O(d_i · m_i · p)` total, vs the `O(K · m_i · p)` cost
    /// of probing the forward operator with `K` basis vectors. Since
    /// `d_i ≪ K`, this is the per-row sparse apply that replaces the `O(K)`
    /// column-probe in the streaming reduced-Schur accumulation.
    pub(crate) htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// Lift the per-row κ rejection for evidence/log-det-only solves; see
    /// [`ArrowSolveOptions::tolerate_ill_conditioning`]. Set by [`Self::solve`]
    /// from the options; defaults to `false` so direct callers of
    /// [`Self::accumulate_chunk`] keep the full guard.
    pub(crate) tolerate_ill_conditioning: bool,
    /// Set when the source system carried an exact cross-row IBP source
    /// ([`IbpCrossRowSource`], #1038). The streaming chunked accumulator cannot
    /// hold the rank-`R` Woodbury correction chunk-locally — `U`'s columns span
    /// ALL rows, so the capacitance `I_R + D Uᵀ H₀'⁻¹ U` needs the per-row
    /// factors retained for a global `H₀'⁻¹U` back-solve, which is exactly the
    /// `(N·K)`-scale residency the streaming path exists to avoid. Rather than
    /// silently DROP the cross-row term (an inexact logdet that would desync
    /// from the dense-resident gradient), the streaming log-determinant errors
    /// loudly when this is set, forcing IBP-active fits onto the dense resident
    /// [`ArrowFactorCache::arrow_log_det`] path (which carries the exact
    /// Woodbury). See the #1038 streaming note.
    pub(crate) ibp_cross_row_active: bool,
    /// SAE manifold evidence-path per-row gauge deflation, copied from the
    /// source [`ArrowSchurSystem::row_gauge_deflation`] (#1273/#1377). When
    /// present, the streaming per-row factor MUST apply the SAME spectral
    /// discovery-and-deflation of an intrinsic-dimension-flat `H_tt^(i)`
    /// direction (eigenvalue → +1, ρ-independent `log 1 = 0` evidence) that the
    /// dense [`factor_blocks_for_system`] path applies, or the two routes report
    /// different log-determinants for the SAME system — the cross-route
    /// invariant `streaming_logdet == full_logdet` would break (the #1377
    /// regression: #1273 wired the deflation into the dense path only). `None`
    /// for every non-evidence caller, which keeps the strict non-PD refusal.
    pub(crate) row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
    /// The exact cross-row IBP source ([`IbpCrossRowSource`], #1038), cloned from
    /// the assembled [`ArrowSchurSystem`]. The bare streaming paths
    /// ([`Self::solve`] / [`Self::reduced_schur_and_log_det_tt`]) still REFUSE
    /// when this is present (they cannot carry the rank-`R` Woodbury), but the
    /// evidence-only [`Self::reduced_schur_log_det_tt_woodbury`] consumes it to
    /// accumulate the chunk-additive Woodbury building blocks `M0 = Uᵀ A⁻¹ U`
    /// and `W = Bᵀ A⁻¹ U` against the NO-SELF base `H₀'` (self term downdated),
    /// which the caller closes into the exact `log det(I_R + D Uᵀ H₀'⁻¹ U)` via
    /// [`streaming_cross_row_woodbury_log_det`].
    pub(crate) ibp_cross_row: Option<IbpCrossRowSource>,
}
1298
/// One chunk's contribution to the streaming cross-row IBP Woodbury (#1038).
///
/// The capacitance `C = I_R + D·M` with `M = Uᵀ H₀'⁻¹ U` is chunk-additive in
/// its two ingredients (`A = H₀'` is block-diagonal over rows, `U` is supported
/// per-row): `M = M0 + Wᵀ S⁻¹ W` where `M0 = Σ_i Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
/// `W = Σ_i Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`) accumulate row-by-row, and `S` is the final
/// GLOBAL reduced Schur. This carries one chunk's `M0`/`W` plus the per-atom
/// `D`; the caller sums them and closes the capacitance after the chunk loop.
///
/// `m0` and `w` are partial sums over the rows of a single chunk only; `S` is
/// not part of this struct.
#[derive(Debug, Clone)]
pub struct StreamingWoodburyChunk {
    /// `Σ_{i∈chunk} Uᵢᵀ Aᵢ⁻¹ Uᵢ`, `R×R`.
    pub m0: Array2<f64>,
    /// `Σ_{i∈chunk} Bᵢᵀ Aᵢ⁻¹ Uᵢ`, `k×R` (`k` = reduced β border).
    pub w: Array2<f64>,
    /// Per-atom `D`-coefficients `d_k = w·s'_k`, length `R`.
    pub d: Array1<f64>,
}
1316
1317impl std::fmt::Debug for StreamingArrowSchur {
1318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1319 f.debug_struct("StreamingArrowSchur")
1320 .field("n_rows", &self.n_rows)
1321 .field("d", &self.d)
1322 .field("k", &self.k)
1323 .field("chunk_size", &self.chunk_size)
1324 .finish_non_exhaustive()
1325 }
1326}
1327
1328impl StreamingArrowSchur {
1329 #[must_use]
1330 pub fn new(
1331 n_rows: usize,
1332 d: usize,
1333 row_dims: Arc<[usize]>,
1334 row_offsets: Arc<[usize]>,
1335 k: usize,
1336 hbb: Array2<f64>,
1337 gb: Array1<f64>,
1338 row_builder: StreamingArrowRowBuilder,
1339 chunk_size: usize,
1340 ) -> Self {
1341 assert_eq!(hbb.dim(), (k, k));
1342 assert_eq!(gb.len(), k);
1343 Self {
1344 n_rows,
1345 d,
1346 row_dims,
1347 row_offsets,
1348 k,
1349 chunk_size: chunk_size.max(1),
1350 s_acc: Array2::<f64>::zeros((k, k)),
1351 rhs_acc: Array1::<f64>::zeros(k),
1352 hbb,
1353 gb,
1354 row_builder,
1355 htbeta_matvec: None,
1356 htbeta_transpose_matvec: None,
1357 tolerate_ill_conditioning: false,
1358 ibp_cross_row_active: false,
1359 row_gauge_deflation: None,
1360 ibp_cross_row: None,
1361 }
1362 }
1363
1364 #[must_use]
1365 pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
1366 // When a Kronecker / matrix-free htbeta_matvec is installed, the dense
1367 // row.htbeta slabs may be zero-sized. Rather than materialize every
1368 // `(d × K)` slab (the very `(N·K)`-scale buffer the streaming path
1369 // exists to avoid), retain the procedural operator and probe it per row
1370 // inside `accumulate_chunk` / `back_substitute`. The row builder then
1371 // only carries the small `H_tt` / `g_t` blocks.
1372 let htbeta_matvec = sys.htbeta_matvec.clone();
1373 let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
1374 sys.rows
1375 .iter()
1376 .map(|row| ArrowRowBlock {
1377 htt: row.htt.clone(),
1378 htbeta: Array2::<f64>::zeros((0, 0)),
1379 gt: row.gt.clone(),
1380 })
1381 .collect()
1382 } else {
1383 sys.rows.clone()
1384 };
1385 let rows = Arc::new(rows);
1386 let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
1387 rows.get(row)
1388 .cloned()
1389 .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1390 reason: format!("streaming row {row} out of bounds"),
1391 })
1392 });
1393 // Materialize the dense β-block from the effective penalty operator so
1394 // the streaming accumulator stays correct when contributions live in a
1395 // structured `BetaPenaltyOp` (e.g. the SAE data-fit Gauss-Newton block,
1396 // represented as `G ⊗ I_p`) rather than the dense `hbb` accumulator.
1397 // When no `penalty_op` is installed this reduces to `hbb.clone()`.
1398 let hbb_dense = sys.effective_penalty_op().to_dense();
1399 let mut streaming = Self::new(
1400 sys.rows.len(),
1401 sys.d,
1402 Arc::clone(&sys.row_dims),
1403 Arc::clone(&sys.row_offsets),
1404 sys.k,
1405 hbb_dense,
1406 sys.gb.clone(),
1407 row_builder,
1408 chunk_size,
1409 );
1410 streaming.htbeta_matvec = htbeta_matvec;
1411 streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
1412 streaming.ibp_cross_row_active = sys.ibp_cross_row.is_some();
1413 streaming.ibp_cross_row = sys.ibp_cross_row.clone();
1414 // Carry the SAE evidence-path per-row gauge deflation so the streaming
1415 // per-row factor matches the dense `factor_blocks_for_system` exactly
1416 // (#1377): without it, a row with an intrinsic-dimension-flat `H_tt`
1417 // would deflate on the dense path but be refused / log-det-divergent on
1418 // the streaming path, breaking `streaming_logdet == full_logdet`.
1419 streaming.row_gauge_deflation = sys.row_gauge_deflation.clone();
1420 streaming
1421 }
1422
1423 /// Factor one streaming row's `H_tt^(i)`, applying the SAME per-row gauge /
1424 /// spectral deflation the dense [`factor_blocks_for_system`] path applies
1425 /// when this is the SAE manifold evidence path (an installed
1426 /// `row_gauge_deflation`). For every non-evidence caller this is exactly the
1427 /// generic [`factor_one_row`] (strict non-PD refusal), so PD blocks are
1428 /// bit-for-bit unchanged. Routing both the dense and streaming per-row
1429 /// factors through the identical recovery is what keeps their
1430 /// log-determinants identical (#1273/#1377).
1431 fn factor_row(
1432 &self,
1433 row: &ArrowRowBlock,
1434 ridge_t: f64,
1435 di: usize,
1436 row_idx: usize,
1437 ) -> Result<Array2<f64>, ArrowSchurError> {
1438 match self.row_gauge_deflation.as_ref() {
1439 Some(deflation) => factor_one_row_result(
1440 row,
1441 ridge_t,
1442 di,
1443 row_idx,
1444 self.tolerate_ill_conditioning,
1445 deflation.row(row_idx),
1446 // Evidence path: opt into spectral discovery of an
1447 // intrinsic-dimension-flat direction even when this row's
1448 // supplied gauge list is empty/non-spanning — matching the
1449 // `allow_spectral_deflation = true` the dense path passes.
1450 true,
1451 )
1452 .map(|result| result.factor),
1453 None => factor_one_row(row, ridge_t, di, row_idx, self.tolerate_ill_conditioning),
1454 }
1455 }
1456
1457 /// Build the `(di × k)` cross-block for `row_idx` on demand.
1458 ///
1459 /// When the sparse transpose adjoint is installed, probes it with `di`
1460 /// standard basis vectors — each yields a full `K`-row of `H_βt^(i)`
1461 /// (i.e. a row of the `(di × k)` block) via the sparse scatter, for
1462 /// `O(di · m_i · p)` total, far below the `O(K · m_i · p)` cost of probing
1463 /// the forward operator with `K` basis vectors when `di ≪ K`.
1464 ///
1465 /// When only the forward operator is installed (no adjoint), falls back to
1466 /// the `k`-column forward probe. Otherwise clones the dense `row.htbeta`
1467 /// slab.
1468 pub(crate) fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
1469 if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
1470 // Probe the adjoint: for each latent index c, scatter e_c to obtain
1471 // row c of the (di × k) block.
1472 let mut mat = Array2::<f64>::zeros((di, self.k));
1473 let mut e_c = Array1::<f64>::zeros(di);
1474 let mut beta_row = Array1::<f64>::zeros(self.k);
1475 for c in 0..di {
1476 e_c.fill(0.0);
1477 e_c[c] = 1.0;
1478 beta_row.fill(0.0);
1479 op_t(row_idx, e_c.view(), &mut beta_row);
1480 for a in 0..self.k {
1481 mat[[c, a]] = beta_row[a];
1482 }
1483 }
1484 return mat;
1485 }
1486 match self.htbeta_matvec.as_ref() {
1487 Some(op) => {
1488 let mut mat = Array2::<f64>::zeros((di, self.k));
1489 let mut e_a = Array1::<f64>::zeros(self.k);
1490 let mut col = Array1::<f64>::zeros(di);
1491 for a in 0..self.k {
1492 e_a.fill(0.0);
1493 e_a[a] = 1.0;
1494 col.fill(0.0);
1495 op(row_idx, e_a.view(), &mut col);
1496 for c in 0..di {
1497 mat[[c, a]] = col[c];
1498 }
1499 }
1500 mat
1501 }
1502 None => row.htbeta.clone(),
1503 }
1504 }
1505
1506 /// Move out the accumulated reduced Schur block `s_acc` and reduced RHS
1507 /// `rhs_acc`, leaving fresh zero buffers in their place.
1508 ///
1509 /// The reduced contribution is `s_acc = hbb − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`
1510 /// (the β-block `hbb` seeded by `reset_accumulator`, minus the per-row
1511 /// reduction summed by `accumulate_chunk`) and
1512 /// `rhs_acc = +Σ_i H_βt^(i)(H_tt^(i))⁻¹g_t^(i)`. Used by external online
1513 /// drivers (e.g. the SAE streaming joint fit) that accumulate the reduced
1514 /// system across re-materialized chunk systems.
1515 #[must_use]
1516 pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
1517 let s = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1518 let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(self.k));
1519 (s, rhs)
1520 }
1521
1522 /// Reset the dense shared accumulator to `H_ββ + ridge_beta I`.
1523 pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
1524 if self.hbb.dim() != (self.k, self.k) {
1525 return Err(ArrowSchurError::SchurFactorFailed {
1526 reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
1527 });
1528 }
1529 self.s_acc.assign(&self.hbb);
1530 for j in 0..self.k {
1531 self.s_acc[[j, j]] += ridge_beta;
1532 self.rhs_acc[j] = 0.0;
1533 }
1534 Ok(())
1535 }
1536
1537 /// Accumulate rows `[start, end)` into the reduced RHS and Schur block.
1538 pub fn accumulate_chunk(
1539 &mut self,
1540 start: usize,
1541 end: usize,
1542 ridge_t: f64,
1543 mode: ArrowSolverMode,
1544 ) -> Result<(), ArrowSchurError> {
1545 if start > end || end > self.n_rows {
1546 return Err(ArrowSchurError::SchurFactorFailed {
1547 reason: format!(
1548 "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
1549 self.n_rows
1550 ),
1551 });
1552 }
1553 let backend = CpuBatchedBlockSolver;
1554 let k = self.k;
1555 // Per-row factor + two block solves + a `k×k` GEMM subtract is the whole
1556 // assembly cost at the SAE LLM shape (#1017); the rows are independent so
1557 // the chunk fans across cores. Stay sequential for the handful-of-rows
1558 // non-SAE callers, or when already inside a rayon worker (the topology
1559 // race fans candidates with `run_topology_race_parallel`) to avoid
1560 // nested-rayon oversubscription — the same gate `schur_matvec` uses.
1561 let parallel = (end - start) >= SCHUR_MATVEC_PARALLEL_ROW_MIN
1562 && rayon::current_thread_index().is_none();
1563 if parallel {
1564 use rayon::prelude::*;
1565 const CHUNK: usize = 64;
1566 // Bind `&self` so the per-row body borrows only the immutable
1567 // streaming state. Each row contributes `+H_βt^(i)(H_tt^(i))⁻¹ g_t^(i)`
1568 // (length `k`) to the reduced RHS and `−H_βt^(i)(H_tt^(i))⁻¹ H_tβ^(i)`
1569 // (`k×k`) to the reduced Schur complement; both are written INTO a
1570 // worker-private `(rhs_part, s_part)` pair so the chunk partials fold
1571 // back in chunk order — bit-identical run-to-run regardless of thread
1572 // scheduling (the #1017 verification gate). The chunk fold reassociates
1573 // the row sum relative to serial, so the criterion ranking is stable
1574 // only up to that f64 margin; a near-tie winner inside the margin can
1575 // flip — not an exact no-move guarantee (#1211).
1576 let this: &Self = self;
1577 let row_into = |row_idx: usize,
1578 rhs_part: &mut Array1<f64>,
1579 s_part: &mut Array2<f64>|
1580 -> Result<(), ArrowSchurError> {
1581 let row = (this.row_builder)(row_idx)?;
1582 let di = row.htt.nrows();
1583 this.validate_row(row_idx, &row)?;
1584 let htbeta = this.row_htbeta(row_idx, &row, di);
1585 let factor = this.factor_row(&row, ridge_t, di, row_idx)?;
1586 let v = backend.solve_block_vector(factor.view(), row.gt.view());
1587 for c in 0..di {
1588 let vc = v[c];
1589 if vc == 0.0 {
1590 continue;
1591 }
1592 for a in 0..k {
1593 rhs_part[a] += htbeta[[c, a]] * vc;
1594 }
1595 }
1596 match mode {
1597 // InexactPCG differs from Direct only in how the *reduced*
1598 // system is solved, not in how it is assembled, so it shares
1599 // the dense Schur subtraction here (see the serial branch).
1600 ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1601 let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1602 backend.block_gemm_subtract(s_part, &htbeta, &solved);
1603 }
1604 ArrowSolverMode::SqrtBA => {
1605 let whitened =
1606 backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1607 backend.block_gemm_subtract(s_part, &whitened, &whitened);
1608 }
1609 }
1610 Ok(())
1611 };
1612 let partials: Vec<(Array1<f64>, Array2<f64>)> = (start..end)
1613 .into_par_iter()
1614 .chunks(CHUNK)
1615 .map(|idxs| {
1616 let mut rhs_part = Array1::<f64>::zeros(k);
1617 let mut s_part = Array2::<f64>::zeros((k, k));
1618 for i in idxs {
1619 row_into(i, &mut rhs_part, &mut s_part)?;
1620 }
1621 Ok::<_, ArrowSchurError>((rhs_part, s_part))
1622 })
1623 .collect::<Result<Vec<_>, _>>()?;
1624 // Deterministic ordered reduction: fold chunk partials left-to-right.
1625 // `block_gemm_subtract` already subtracted into each `s_part`, so the
1626 // partials carry the negative Schur contribution; add them in.
1627 for (rhs_part, s_part) in &partials {
1628 for a in 0..k {
1629 self.rhs_acc[a] += rhs_part[a];
1630 }
1631 self.s_acc += s_part;
1632 }
1633 } else {
1634 // Serial path accumulates DIRECTLY into the running `self.{rhs,s}_acc`
1635 // (which carry the `reset_accumulator` seed `H_ββ + ridge·I`), exactly
1636 // as before — bit-for-bit unchanged for the handful-of-rows callers.
1637 for row_idx in start..end {
1638 let row = (self.row_builder)(row_idx)?;
1639 let di = row.htt.nrows();
1640 self.validate_row(row_idx, &row)?;
1641 let htbeta = self.row_htbeta(row_idx, &row, di);
1642 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1643 let v = backend.solve_block_vector(factor.view(), row.gt.view());
1644 for c in 0..di {
1645 let vc = v[c];
1646 if vc == 0.0 {
1647 continue;
1648 }
1649 for a in 0..k {
1650 self.rhs_acc[a] += htbeta[[c, a]] * vc;
1651 }
1652 }
1653 match mode {
1654 ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1655 let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1656 backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1657 }
1658 ArrowSolverMode::SqrtBA => {
1659 let whitened =
1660 backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1661 backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1662 }
1663 }
1664 }
1665 }
1666 Ok(())
1667 }
1668
1669 /// Compute the exact arrow Hessian log-determinant by accumulating the
1670 /// reduced Schur complement in row chunks, without retaining the full set
1671 /// of per-row Cholesky factors.
1672 ///
1673 /// This is the streaming analogue of [`ArrowFactorCache::arrow_log_det`]:
1674 ///
1675 /// ```text
1676 /// log|H| = Σ_i log|H_tt^(i)| + log|H_ββ - Σ_i H_βt^(i) H_tt^(i)⁻¹ H_tβ^(i)|.
1677 /// ```
1678 ///
1679 /// The same row builder and procedural `H_tβ` callbacks used by the
1680 /// streaming Newton solve are consumed here, so callers can score REML
1681 /// evidence without materialising the full `(N × q × K)` cross block or
1682 /// the full list of row factors.
1683 pub fn reduced_schur_and_log_det_tt(
1684 &mut self,
1685 ridge_t: f64,
1686 ridge_beta: f64,
1687 options: &ArrowSolveOptions,
1688 ) -> Result<(f64, Array2<f64>), ArrowSchurError> {
1689 if self.ibp_cross_row_active {
1690 return Err(ArrowSchurError::SchurFactorFailed {
1691 reason: "streaming arrow log-det cannot carry the exact cross-row IBP \
1692 Woodbury correction (#1038): U's columns span all rows, so the \
1693 rank-R capacitance needs the per-row factors retained — the very \
1694 (N·K) residency the streaming path avoids. Route IBP-active fits \
1695 through the dense resident ArrowFactorCache::arrow_log_det instead."
1696 .to_string(),
1697 });
1698 }
1699 self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1700 self.reset_accumulator(ridge_beta)?;
1701 let backend = CpuBatchedBlockSolver;
1702 let mut log_det_tt = 0.0_f64;
1703 for start in (0..self.n_rows).step_by(self.chunk_size) {
1704 let end = (start + self.chunk_size).min(self.n_rows);
1705 for row_idx in start..end {
1706 let row = (self.row_builder)(row_idx)?;
1707 let di = row.htt.nrows();
1708 self.validate_row(row_idx, &row)?;
1709 let htbeta = self.row_htbeta(row_idx, &row, di);
1710 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1711 for axis in 0..di {
1712 log_det_tt += 2.0 * factor[[axis, axis]].ln();
1713 }
1714 match options.mode {
1715 ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1716 let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1717 backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1718 }
1719 ArrowSolverMode::SqrtBA => {
1720 let whitened =
1721 backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1722 backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1723 }
1724 }
1725 }
1726 }
1727 symmetrize_upper_from_lower(&mut self.s_acc);
1728 let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1729 Ok((log_det_tt, schur))
1730 }
1731
1732 /// As [`Self::reduced_schur_and_log_det_tt`], but ALSO accumulates the exact
1733 /// cross-row IBP Woodbury building blocks (#1038) when this streaming system
1734 /// carried an [`IbpCrossRowSource`].
1735 ///
1736 /// Mirrors the dense `factor_blocks_for_system` + [`CrossRowWoodbury::build`]
1737 /// path exactly:
1738 ///
1739 /// * the per-row logit self term `Σ_k d_k·z'_ik²`
1740 /// ([`IbpCrossRowSource::self_term_downdate`]) is subtracted from each
1741 /// `H_tt^(i)` BEFORE factoring, so the factored base — and therefore the
1742 /// returned `log_det_tt`/Schur — is the NO-SELF `H₀'` (the dense path
1743 /// factors `H₀'` too, then layers the rank-`R` Woodbury);
1744 /// * for each row it forms `Aᵢ⁻¹ Uᵢ` (one block solve against the row's
1745 /// `H₀'` factor, `Uᵢ` the J-weighted atom-column indicator) and adds the
1746 /// row's contributions to `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
1747 /// `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`, `Bᵢ = H_tβ^(i)`).
1748 ///
1749 /// The caller sums `M0`/`W`/`log_det_tt`/Schur over chunks and closes the
1750 /// capacitance `M = M0 + Wᵀ S⁻¹ W`, `log det(I_R + D·M)` against the GLOBAL
1751 /// reduced Schur `S` via [`streaming_cross_row_woodbury_log_det`]. This is
1752 /// the exact `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`
1753 /// the dense [`ArrowFactorCache::arrow_log_det`] returns — the streaming and
1754 /// dense evidence then optimize the SAME REML objective (#1225).
1755 ///
1756 /// When no IBP source is present this delegates to
1757 /// [`Self::reduced_schur_and_log_det_tt`] and returns `None` woodbury, so
1758 /// every non-IBP (softmax / JumpReLU) caller is bit-for-bit unchanged.
1759 pub fn reduced_schur_log_det_tt_woodbury(
1760 &mut self,
1761 ridge_t: f64,
1762 ridge_beta: f64,
1763 options: &ArrowSolveOptions,
1764 ) -> Result<(f64, Array2<f64>, Option<StreamingWoodburyChunk>), ArrowSchurError> {
1765 let Some(source) = self.ibp_cross_row.clone() else {
1766 // Temporarily clear the refusal flag so the shared bare path runs;
1767 // it is `false` whenever the source is absent, so this is a no-op.
1768 let (log_det_tt, schur) =
1769 self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1770 return Ok((log_det_tt, schur, None));
1771 };
1772 let r = source.r;
1773 let total_len = self.row_offsets[self.n_rows];
1774 let down = source.self_term_downdate(total_len);
1775 // Group the sparse `U` entries `(global_t_index, atom, z'_ik)` by row as
1776 // `(local_slot, atom, z)`. `row_offsets` is strictly ascending, so the
1777 // owning row of a global index is located by binary search.
1778 let mut row_entries: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); self.n_rows];
1779 for &(g, atom, z) in &source.entries {
1780 let i = match self.row_offsets.binary_search(&g) {
1781 Ok(idx) => idx,
1782 Err(idx) => idx - 1,
1783 };
1784 let slot = g - self.row_offsets[i];
1785 row_entries[i].push((slot, atom, z));
1786 }
1787 self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1788 self.reset_accumulator(ridge_beta)?;
1789 let backend = CpuBatchedBlockSolver;
1790 let mut log_det_tt = 0.0_f64;
1791 let mut m0 = Array2::<f64>::zeros((r, r));
1792 let mut w = Array2::<f64>::zeros((self.k, r));
1793 for start in (0..self.n_rows).step_by(self.chunk_size) {
1794 let end = (start + self.chunk_size).min(self.n_rows);
1795 for row_idx in start..end {
1796 let mut row = (self.row_builder)(row_idx)?;
1797 let di = row.htt.nrows();
1798 self.validate_row(row_idx, &row)?;
1799 // Downdate the per-row logit self term so the factored base is
1800 // `H₀'` (the dense path downdates the SAME `down[base + j]`).
1801 let base = self.row_offsets[row_idx];
1802 for j in 0..di {
1803 row.htt[[j, j]] -= down[base + j];
1804 }
1805 let htbeta = self.row_htbeta(row_idx, &row, di);
1806 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1807 for axis in 0..di {
1808 log_det_tt += 2.0 * factor[[axis, axis]].ln();
1809 }
1810 match options.mode {
1811 ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1812 let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1813 backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1814 }
1815 ArrowSolverMode::SqrtBA => {
1816 let whitened =
1817 backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1818 backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1819 }
1820 }
1821 let entries = &row_entries[row_idx];
1822 if !entries.is_empty() {
1823 // `Uᵢ`: `di × R`, column `atom` carries `z'_ik` at its logit
1824 // slot. `Aᵢ⁻¹ Uᵢ` via the row's `H₀'` Cholesky.
1825 let mut u_local = Array2::<f64>::zeros((di, r));
1826 for &(slot, atom, z) in entries {
1827 u_local[[slot, atom]] += z;
1828 }
1829 let ainv_u = backend.solve_block_matrix(factor.view(), u_local.view());
1830 // `M0 += Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`); `W += Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`).
1831 m0 += &u_local.t().dot(&ainv_u);
1832 w += &htbeta.t().dot(&ainv_u);
1833 }
1834 }
1835 }
1836 symmetrize_upper_from_lower(&mut self.s_acc);
1837 let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1838 Ok((
1839 log_det_tt,
1840 schur,
1841 Some(StreamingWoodburyChunk {
1842 m0,
1843 w,
1844 d: source.d.clone(),
1845 }),
1846 ))
1847 }
1848
1849 pub fn reduced_schur_log_det(
1850 schur: &Array2<f64>,
1851 options: &ArrowSolveOptions,
1852 ) -> Result<f64, ArrowSchurError> {
1853 let rhs = Array1::<f64>::zeros(schur.nrows());
1854 let trust_metric_weights = None;
1855 let (delta, schur_factor, diag) =
1856 solve_dense_reduced_system(schur, &rhs, options, trust_metric_weights)?;
1857 if delta.len() != schur.nrows() || diag.iterations != 0 {
1858 return Err(ArrowSchurError::SchurFactorFailed {
1859 reason: "streaming log-det reduced solve returned incoherent diagnostics"
1860 .to_string(),
1861 });
1862 }
1863 let schur_factor = schur_factor.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1864 reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
1865 })?;
1866 let mut log_det_schur = 0.0_f64;
1867 for axis in 0..schur_factor.nrows() {
1868 log_det_schur += 2.0 * schur_factor[[axis, axis]].ln();
1869 }
1870 Ok(log_det_schur)
1871 }
1872
1873 pub fn exact_arrow_log_det(
1874 &mut self,
1875 ridge_t: f64,
1876 ridge_beta: f64,
1877 options: &ArrowSolveOptions,
1878 ) -> Result<f64, ArrowSchurError> {
1879 let (log_det_tt, schur) =
1880 self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1881 Ok(log_det_tt + Self::reduced_schur_log_det(&schur, options)?)
1882 }
1883
1884 pub fn solve(
1885 &mut self,
1886 ridge_t: f64,
1887 ridge_beta: f64,
1888 options: &ArrowSolveOptions,
1889 ) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
1890 if self.ibp_cross_row_active {
1891 return Err(ArrowSchurError::SchurFactorFailed {
1892 reason: "streaming arrow solve cannot carry the exact cross-row IBP \
1893 Woodbury correction (#1038); route IBP-active fits through the \
1894 dense resident solve_arrow_newton_step_with_options instead."
1895 .to_string(),
1896 });
1897 }
1898 // Propagate the evidence/log-det ill-conditioning tolerance to the
1899 // per-row factor calls inside `accumulate_chunk` / `back_substitute`,
1900 // which take their stable public signatures. Direct callers of
1901 // `accumulate_chunk` keep the conservative default (`false`, full guard).
1902 self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1903 self.reset_accumulator(ridge_beta)?;
1904 for start in (0..self.n_rows).step_by(self.chunk_size) {
1905 let end = (start + self.chunk_size).min(self.n_rows);
1906 self.accumulate_chunk(start, end, ridge_t, options.mode)?;
1907 }
1908 for j in 0..self.k {
1909 self.rhs_acc[j] -= self.gb[j];
1910 }
1911 symmetrize_upper_from_lower(&mut self.s_acc);
1912 let trust_metric_weights = None;
1913 let (delta_beta, schur_factor, _diag) =
1914 solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
1915 let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
1916 Ok((delta_t, delta_beta, schur_factor))
1917 }
1918
1919 pub(crate) fn back_substitute(
1920 &self,
1921 ridge_t: f64,
1922 delta_beta: ArrayView1<'_, f64>,
1923 ) -> Result<Array1<f64>, ArrowSchurError> {
1924 let backend = CpuBatchedBlockSolver;
1925 // Total delta_t length = row_offsets[n_rows].
1926 let total_len = self.row_offsets[self.n_rows];
1927 let mut delta_t = Array1::<f64>::zeros(total_len);
1928 // Each row's back-solve `Δt_i = -(H_tt^(i))⁻¹(g_t^(i) + H_tβ^(i)Δβ)`
1929 // writes a DISJOINT segment `delta_t[row_base .. row_base+di]` — no
1930 // cross-row reduction, so this is embarrassingly parallel and the scatter
1931 // is bit-identical regardless of which thread produced each segment (the
1932 // #1017 verification gate). At the SAE LLM shape (`n` in the thousands)
1933 // the per-row factor + solve is the whole cost; below the threshold, or
1934 // when already inside a rayon worker (the topology race fans candidates
1935 // with `run_topology_race_parallel`), stay sequential to avoid
1936 // nested-rayon oversubscription — the same guard `schur_matvec` uses.
1937 let parallel =
1938 self.n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1939 if parallel {
1940 use rayon::prelude::*;
1941 const CHUNK: usize = 64;
1942 // Per-row body: factor, form the RHS, solve, return `-(dt_i)`.
1943 let row_solve = |row_idx: usize| -> Result<(usize, Array1<f64>), ArrowSchurError> {
1944 let row = (self.row_builder)(row_idx)?;
1945 let di = row.htt.nrows();
1946 self.validate_row(row_idx, &row)?;
1947 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1948 let mut htbeta_delta = Array1::<f64>::zeros(di);
1949 if let Some(op) = self.htbeta_matvec.as_ref() {
1950 op(row_idx, delta_beta, &mut htbeta_delta);
1951 } else {
1952 for c in 0..di {
1953 let mut acc = 0.0_f64;
1954 for a in 0..self.k {
1955 acc += row.htbeta[[c, a]] * delta_beta[a];
1956 }
1957 htbeta_delta[c] = acc;
1958 }
1959 }
1960 let mut rhs = Array1::<f64>::zeros(di);
1961 for c in 0..di {
1962 rhs[c] = row.gt[c] + htbeta_delta[c];
1963 }
1964 let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
1965 let mut neg = Array1::<f64>::zeros(di);
1966 for c in 0..di {
1967 neg[c] = -dt_i[c];
1968 }
1969 Ok((self.row_offsets[row_idx], neg))
1970 };
1971 // Collect per-row segments under rayon, then scatter into the disjoint
1972 // slices. Errors are surfaced via `collect::<Result<…>>`.
1973 let segments: Vec<(usize, Array1<f64>)> = (0..self.n_rows)
1974 .into_par_iter()
1975 .chunks(CHUNK)
1976 .map(|idxs| {
1977 idxs.into_iter()
1978 .map(&row_solve)
1979 .collect::<Result<Vec<_>, _>>()
1980 })
1981 .collect::<Result<Vec<_>, _>>()?
1982 .into_iter()
1983 .flatten()
1984 .collect();
1985 for (base, seg) in &segments {
1986 for (c, &v) in seg.iter().enumerate() {
1987 delta_t[base + c] = v;
1988 }
1989 }
1990 } else {
1991 let mut rhs = Array1::<f64>::zeros(self.d);
1992 for start in (0..self.n_rows).step_by(self.chunk_size) {
1993 let end = (start + self.chunk_size).min(self.n_rows);
1994 for row_idx in start..end {
1995 let row = (self.row_builder)(row_idx)?;
1996 let di = row.htt.nrows();
1997 self.validate_row(row_idx, &row)?;
1998 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1999 // `H_tβ^(i) Δβ`: route through the procedural operator when
2000 // present (no dense slab), else through the dense slab.
2001 let mut htbeta_delta = Array1::<f64>::zeros(di);
2002 if let Some(op) = self.htbeta_matvec.as_ref() {
2003 op(row_idx, delta_beta, &mut htbeta_delta);
2004 } else {
2005 for c in 0..di {
2006 let mut acc = 0.0_f64;
2007 for a in 0..self.k {
2008 acc += row.htbeta[[c, a]] * delta_beta[a];
2009 }
2010 htbeta_delta[c] = acc;
2011 }
2012 }
2013 for c in 0..di {
2014 rhs[c] = row.gt[c] + htbeta_delta[c];
2015 }
2016 let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
2017 let row_base = self.row_offsets[row_idx];
2018 for c in 0..di {
2019 delta_t[row_base + c] = -dt_i[c];
2020 }
2021 }
2022 }
2023 }
2024 Ok(delta_t)
2025 }
2026
2027 pub(crate) fn validate_row(
2028 &self,
2029 row_idx: usize,
2030 row: &ArrowRowBlock,
2031 ) -> Result<(), ArrowSchurError> {
2032 let expected_di = if row_idx < self.row_dims.len() {
2033 self.row_dims[row_idx]
2034 } else {
2035 self.d
2036 };
2037 let actual_di = row.htt.nrows();
2038 if actual_di != expected_di || row.htt.ncols() != expected_di {
2039 return Err(ArrowSchurError::PerRowFactorFailed {
2040 row: row_idx,
2041 reason: format!(
2042 "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
2043 row.htt.dim(),
2044 ),
2045 });
2046 }
2047 // The dense `H_tβ` slab is only validated when no procedural operator is
2048 // installed; with `htbeta_matvec` the slab is intentionally zero-sized
2049 // and the cross-block is probed in `row_htbeta`.
2050 if self.htbeta_matvec.is_none() && row.htbeta.dim() != (expected_di, self.k) {
2051 return Err(ArrowSchurError::SchurFactorFailed {
2052 reason: format!(
2053 "streaming row H_tβ shape {:?} != ({expected_di}, {})",
2054 row.htbeta.dim(),
2055 self.k
2056 ),
2057 });
2058 }
2059 if row.gt.len() != expected_di {
2060 return Err(ArrowSchurError::PerRowFactorFailed {
2061 row: row_idx,
2062 reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
2063 });
2064 }
2065 Ok::<(), _>(())
2066 }
2067}
2068
2069pub(crate) fn apply_analytic_penalty<S, G, D, P, H>(
2070 penalty: &AnalyticPenaltyKind,
2071 target: ArrayView1<'_, f64>,
2072 rho_local: ArrayView1<'_, f64>,
2073 expected_target_len: usize,
2074 hvp_columns: usize,
2075 scatter_target: &mut S,
2076 mut grad_scatter: G,
2077 mut diag_scatter: D,
2078 seed_hvp_probe: P,
2079 mut hvp_column_scatter: H,
2080) where
2081 G: FnMut(&mut S, usize, f64),
2082 D: FnMut(&mut S, usize, f64),
2083 P: Fn(usize, &mut Array1<f64>),
2084 H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
2085{
2086 assert_eq!(target.len(), expected_target_len);
2087
2088 let grad = penalty.grad_target(target, rho_local);
2089 for index in 0..expected_target_len {
2090 grad_scatter(scatter_target, index, grad[index]);
2091 }
2092
2093 // The scattered curvature lands in the arrow-Schur `H_tt` / `H_ββ` blocks,
2094 // which are Cholesky-factored (with LM ridge escalation) as the Newton /
2095 // PIRLS curvature operator and must therefore stay PSD. Nonconvex
2096 // sparsifiers (log sparsity, JumpReLU) have an *indefinite* exact Hessian
2097 // that would destroy that positive-definiteness, so we scatter the PSD
2098 // majorizer here — never the exact `hessian_diag` / `hvp`. For convex
2099 // penalties the majorizer equals the exact Hessian (the trait default
2100 // delegates), so this is exact for them. Exact-derivative consumers (the
2101 // outer objective Hessian) use `hessian_diag` / `hvp` directly elsewhere.
2102 if let Some(diag) = penalty.psd_majorizer_diag(target, rho_local) {
2103 assert_eq!(diag.len(), expected_target_len);
2104 for index in 0..expected_target_len {
2105 diag_scatter(scatter_target, index, diag[index]);
2106 }
2107 return;
2108 }
2109
2110 let mut probe = Array1::<f64>::zeros(expected_target_len);
2111 for column in 0..hvp_columns {
2112 probe.fill(0.0);
2113 seed_hvp_probe(column, &mut probe);
2114 let hv = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
2115 hvp_column_scatter(scatter_target, column, hv.view());
2116 }
2117}
2118
2119pub(crate) fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
2120 penalty.is_row_block_diagonal()
2121}
2122
2123/// Per-row + Schur Cholesky factor cache produced by
2124/// [`solve_arrow_newton_step_with_options`]. Consumed downstream by the IFT warm-start
2125/// predictor in `crate::persistent_warm_start`: when the outer
2126/// loop perturbs `(β, ρ)` by a small amount, the new Newton step can be
2127/// predicted by re-using these factors against a refreshed RHS, saving
2128/// the dominant `O(N d³ + K³)` factorization cost.
2129#[derive(Clone)]
2130pub struct ArrowFactorSlab {
2131 pub(crate) data: Arc<[f64]>,
2132 pub(crate) offsets: Arc<[usize]>,
2133 pub(crate) dims: Arc<[usize]>,
2134}
2135
2136impl ArrowFactorSlab {
2137 pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
2138 let mut data = Vec::new();
2139 let mut offsets = Vec::with_capacity(blocks.len() + 1);
2140 let mut dims = Vec::with_capacity(blocks.len());
2141 offsets.push(0);
2142 for block in blocks {
2143 let (rows, cols) = block.dim();
2144 assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
2145 dims.push(rows);
2146 data.extend(block.iter().copied());
2147 offsets.push(data.len());
2148 }
2149 Self {
2150 data: data.into(),
2151 offsets: offsets.into(),
2152 dims: dims.into(),
2153 }
2154 }
2155
2156 pub fn len(&self) -> usize {
2157 self.dims.len()
2158 }
2159
2160 pub fn is_empty(&self) -> bool {
2161 self.dims.is_empty()
2162 }
2163
2164 pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
2165 let dim = self.dims[row];
2166 let range = self.offsets[row]..self.offsets[row + 1];
2167 ArrayView2::from_shape((dim, dim), &self.data[range])
2168 .expect("ArrowFactorSlab row offset/dim invariant violated")
2169 }
2170
2171 pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
2172 (0..self.len()).map(|row| self.factor(row))
2173 }
2174}
2175
2176impl std::fmt::Debug for ArrowFactorSlab {
2177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2178 f.debug_struct("ArrowFactorSlab")
2179 .field("rows", &self.len())
2180 .field("values", &self.data.len())
2181 .finish()
2182 }
2183}
2184
2185#[derive(Clone)]
2186pub enum ArrowUndampedFactors {
2187 SameAsDamped,
2188 Owned(ArrowFactorSlab),
2189}
2190
2191impl std::fmt::Debug for ArrowUndampedFactors {
2192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2193 match self {
2194 Self::SameAsDamped => f.write_str("SameAsDamped"),
2195 Self::Owned(factors) => f.debug_tuple("Owned").field(&factors.len()).finish(),
2196 }
2197 }
2198}
2199
2200/// Apply `H_tβ^(row) · x` for one row, writing into `out` (length `d`).
2201///
2202/// Sums the installed matrix-free operator, when present, and any correctly
2203/// shaped dense `row.htbeta` slab. This lets structured data-fit rows coexist
2204/// with dense analytic-penalty cross blocks on the same row.
2205pub(crate) fn sys_htbeta_apply_row(
2206 sys: &ArrowSchurSystem,
2207 row_idx: usize,
2208 row: &ArrowRowBlock,
2209 x: ArrayView1<'_, f64>,
2210 out: &mut Array1<f64>,
2211) {
2212 out.fill(0.0);
2213 if let Some(op) = sys.htbeta_matvec.as_ref() {
2214 op(row_idx, x, out);
2215 }
2216 if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2217 && row.htbeta.dim() == (out.len(), sys.k)
2218 {
2219 let di = row.htbeta.nrows();
2220 for c in 0..di {
2221 let mut acc = 0.0_f64;
2222 for a in 0..sys.k {
2223 acc += row.htbeta[[c, a]] * x[a];
2224 }
2225 out[c] += acc;
2226 }
2227 }
2228}
2229
2230/// Accumulate `H_βt^(row) · v` into `out` (length `k`).
2231///
2232/// `out[a] += Σ_c H_tβ^(row)[c, a] · v[c]`
2233///
2234/// Sums the installed matrix-free operator, when present, and any correctly
2235/// shaped dense `row.htbeta` slab.
2236pub(crate) fn sys_htbeta_accumulate_transpose(
2237 sys: &ArrowSchurSystem,
2238 row_idx: usize,
2239 row: &ArrowRowBlock,
2240 v: ArrayView1<'_, f64>,
2241 out: &mut Array1<f64>,
2242) {
2243 if let Some(op) = sys.htbeta_matvec.as_ref() {
2244 htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
2245 }
2246 if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2247 && row.htbeta.dim() == (v.len(), sys.k)
2248 {
2249 let di = row.htbeta.nrows();
2250 for c in 0..di {
2251 let vc = v[c];
2252 if vc == 0.0 {
2253 continue;
2254 }
2255 for a in 0..sys.k {
2256 out[a] += row.htbeta[[c, a]] * vc;
2257 }
2258 }
2259 }
2260}
2261
2262/// Materialize the dense `(di, k)` cross-block for one row.
2263///
2264/// Materializes the sum of the installed matrix-free operator and any correctly
2265/// shaped dense slab on the row.
2266pub(crate) fn sys_htbeta_materialize_row(
2267 sys: &ArrowSchurSystem,
2268 row_idx: usize,
2269 row: &ArrowRowBlock,
2270) -> Result<Array2<f64>, ArrowSchurError> {
2271 let di = sys.row_dims[row_idx];
2272 let k = sys.k;
2273 let use_dense = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
2274 let mut mat = if use_dense && row.htbeta.dim() == (di, k) {
2275 row.htbeta.clone()
2276 } else {
2277 Array2::<f64>::zeros((di, k))
2278 };
2279 if let Some(op) = sys.htbeta_matvec.as_ref() {
2280 let mut e_a = Array1::<f64>::zeros(k);
2281 let mut col = Array1::<f64>::zeros(di);
2282 for a in 0..k {
2283 e_a.fill(0.0);
2284 e_a[a] = 1.0;
2285 col.fill(0.0);
2286 op(row_idx, e_a.view(), &mut col);
2287 for c in 0..di {
2288 mat[[c, a]] += col[c];
2289 }
2290 }
2291 } else if use_dense && row.htbeta.dim() != (di, k) {
2292 return Err(ArrowSchurError::SchurFactorFailed {
2293 reason: format!(
2294 "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
2295 row.htbeta.dim()
2296 ),
2297 });
2298 }
2299 Ok(mat)
2300}
2301
2302/// Probe each column of `H_tβ^(row)` by applying the operator to `e_a` and
2303/// dotting the result with `v`. Accumulates into `out[a]` for all `a in 0..k`.
2304///
2305/// `out[a] += (H_tβ^(row) e_a) · v = H_βt^(row)[a, :] · v`
2306pub(crate) fn htbeta_probe_transpose(
2307 row: usize,
2308 op: &RowHtbetaMatvec,
2309 v: ArrayView1<'_, f64>,
2310 out: &mut Array1<f64>,
2311 d: usize,
2312 k: usize,
2313) {
2314 let mut e_a = Array1::<f64>::zeros(k);
2315 let mut col_a = Array1::<f64>::zeros(d);
2316 for a in 0..k {
2317 e_a.fill(0.0);
2318 e_a[a] = 1.0;
2319 col_a.fill(0.0);
2320 op(row, e_a.view(), &mut col_a);
2321 let mut acc = 0.0_f64;
2322 for c in 0..d {
2323 acc += col_a[c] * v[c];
2324 }
2325 out[a] += acc;
2326 }
2327}
2328
2329#[derive(Clone)]
2330pub enum ArrowHtbetaCache {
2331 Dense {
2332 blocks: Arc<[Array2<f64>]>,
2333 estimated_bytes: usize,
2334 },
2335 Matvec {
2336 op: RowHtbetaMatvec,
2337 estimated_bytes: usize,
2338 },
2339 Disabled {
2340 estimated_bytes: usize,
2341 },
2342}
2343
2344impl std::fmt::Debug for ArrowHtbetaCache {
2345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2346 match self {
2347 Self::Dense {
2348 blocks,
2349 estimated_bytes,
2350 } => f
2351 .debug_struct("Dense")
2352 .field("blocks", &blocks.len())
2353 .field("estimated_bytes", estimated_bytes)
2354 .finish(),
2355 Self::Matvec {
2356 estimated_bytes, ..
2357 } => f
2358 .debug_struct("Matvec")
2359 .field("estimated_bytes", estimated_bytes)
2360 .finish(),
2361 Self::Disabled { estimated_bytes } => f
2362 .debug_struct("Disabled")
2363 .field("estimated_bytes", estimated_bytes)
2364 .finish(),
2365 }
2366 }
2367}
2368
2369impl ArrowHtbetaCache {
2370 pub(crate) fn is_available(&self) -> bool {
2371 !matches!(self, Self::Disabled { .. })
2372 }
2373
2374 pub(crate) fn apply_row(
2375 &self,
2376 row: usize,
2377 delta_beta: ArrayView1<'_, f64>,
2378 out: &mut Array1<f64>,
2379 ) -> bool {
2380 match self {
2381 Self::Dense { blocks, .. } => {
2382 let Some(block) = blocks.get(row) else {
2383 return false;
2384 };
2385 if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
2386 return false;
2387 }
2388 for c in 0..block.nrows() {
2389 let mut acc = 0.0_f64;
2390 for a in 0..block.ncols() {
2391 acc += block[[c, a]] * delta_beta[a];
2392 }
2393 out[c] = acc;
2394 }
2395 true
2396 }
2397 Self::Matvec { op, .. } => {
2398 op(row, delta_beta, out);
2399 true
2400 }
2401 Self::Disabled { .. } => false,
2402 }
2403 }
2404
2405 /// Apply the transpose: `out[a] += H_βt^(row)[a, c] · v[c]` for all `a`.
2406 ///
2407 /// `v` has length `d`; `out` has length `k`. Accumulates (does NOT zero
2408 /// `out` first) so callers can sum contributions across rows into a shared
2409 /// accumulator. Returns `false` when the cache is `Disabled` and no
2410 /// `fallback_op` is provided.
2411 pub(crate) fn apply_row_transpose_accumulate(
2412 &self,
2413 row: usize,
2414 v: ArrayView1<'_, f64>,
2415 out: &mut Array1<f64>,
2416 d: usize,
2417 k: usize,
2418 fallback_op: Option<&RowHtbetaMatvec>,
2419 ) -> bool {
2420 match self {
2421 Self::Dense { blocks, .. } => {
2422 let Some(block) = blocks.get(row) else {
2423 return false;
2424 };
2425 if block.nrows() != v.len() || block.ncols() != out.len() {
2426 return false;
2427 }
2428 // H_βt^(i) · v: outer-loop c hoists v[c], inner-loop a is
2429 // contiguous in row-major (d, k) layout.
2430 for c in 0..block.nrows() {
2431 let vc = v[c];
2432 if vc == 0.0 {
2433 continue;
2434 }
2435 for a in 0..block.ncols() {
2436 out[a] += block[[c, a]] * vc;
2437 }
2438 }
2439 true
2440 }
2441 Self::Matvec { op, .. } => {
2442 // Probe column-by-column: H_tβ^(row) e_a is column a. dot(col_a, v)
2443 // is entry a of H_βt^(row) v.
2444 htbeta_probe_transpose(row, op, v, out, d, k);
2445 true
2446 }
2447 Self::Disabled { .. } => {
2448 // No cached block. Use the caller-supplied fallback op if present.
2449 if let Some(op) = fallback_op {
2450 htbeta_probe_transpose(row, op, v, out, d, k);
2451 true
2452 } else {
2453 false
2454 }
2455 }
2456 }
2457 }
2458}
2459
/// RAW per-row spectral data of a spectrally-deflated undamped evidence `H_tt`
/// block (see [`ArrowFactorCache::deflation_row_spectra`]).
///
/// `evecs` columns are the RAW symmetric eigenvectors `uₘ` of `H_tt`
/// (orthonormal; the deflated directions `vᵢ` are the subset whose eigenvalue
/// was pinned). `raw_evals[m]` is the RAW eigenvalue `λₘ` BEFORE the unit-pin /
/// floor-clamp. `cond_evals[m]` is the conditioned eigenvalue `λ̃ₘ` the factor
/// actually uses (`λ̃ = λ` for an unclamped kept direction, the positive `floor`
/// for a clamped kept direction, `1` for a deflated direction). Together they
/// give the Daleckii–Krein divided differences the outer-gradient deflation
/// correction needs.
#[derive(Debug, Clone)]
pub struct RowDeflationSpectrum {
    /// Eigenvectors of the raw `H_tt` block, one per column (`uₘ` is column `m`).
    pub evecs: Array2<f64>,
    /// Raw eigenvalues `λₘ`, before any pinning or clamping.
    pub raw_evals: Array1<f64>,
    /// Conditioned eigenvalues `λ̃ₘ` actually used by the factor, indexed like
    /// `raw_evals`.
    pub cond_evals: Array1<f64>,
}
2477
/// Factorization products retained from one arrow-Schur solve.
///
/// Bundles the per-row `H_tt` factors (damped and undamped), the optional
/// dense Schur factor and joint log-determinant, the ridge levels and
/// fingerprints that identify what was factored, cross-block access for
/// `H_tβ`, the row layout, solver diagnostics, deflation bookkeeping, and the
/// optional cross-row Woodbury correction.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
    /// Per-row lower-triangular Cholesky factors of `H_tt^(i) + ridge_t·I`.
    ///
    /// These are the *damped* factors used inside the Newton solve. The IFT
    /// predictor must NOT use them — see [`Self::htt_factors_undamped`].
    pub htt_factors: ArrowFactorSlab,
    /// Per-row lower-triangular Cholesky factors of the UNDAMPED
    /// `H_tt^(i)` (no `ridge_t` added).
    ///
    /// The IFT predictor formula
    /// `Δt_i = -(H_tt^(i))⁻¹ · (H_tβ^(i) Δβ + δg_t^(i))` is derived from
    /// `∂g_t/∂t = H_tt` at the stationary point, with no LM damping term.
    /// Reusing the damped factors would bias the predicted shift toward zero
    /// in proportion to `ridge_t`. We pay one extra `O(N d³)` Cholesky per
    /// Newton solve — the same complexity class as the Newton solve itself —
    /// to make the IFT exact.
    pub htt_factors_undamped: ArrowUndampedFactors,
    /// Lower-triangular Cholesky factor of the Schur complement when the
    /// selected BA mode formed/factored dense RCS. `None` for
    /// [`ArrowSolverMode::InexactPCG`], where Agarwal-style inexact LM avoids
    /// the dense `K × K` factor.
    pub schur_factor: Option<Array2<f64>>,
    /// Exact undamped joint-Hessian log-determinant produced by the dense
    /// factorization path. REML evidence consumes this directly so the Laplace
    /// normalizer cannot miss the log-det even when later cache consumers only
    /// need solves/traces.
    pub joint_hessian_log_det: Option<f64>,
    /// BA mode used to create this cache.
    pub solver_mode: ArrowSolverMode,
    /// Ridge values used to build the cached factors (recorded so the
    /// warm-start predictor knows whether the cache is still valid for a
    /// requested ridge level).
    pub ridge_t: f64,
    /// Ridge recorded for the `β` block; the companion of `ridge_t` in the
    /// validity check described there.
    pub ridge_beta: f64,
    /// Per-row cross-block access for `H_tβ^(i) x`.
    ///
    /// Large caches retain a row matvec callback or disable β-coupled IFT
    /// prediction instead of cloning every dense `d × K` slab.
    pub htbeta: ArrowHtbetaCache,
    /// Maximum per-row latent dim (upper bound; matches `sys.d` at creation).
    pub d: usize,
    /// Per-row latent dims: `row_dims[i]` is the active dim for row `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for `delta_t` / IFT output vectors.
    /// `row_offsets[i]` is the start of row `i`; `row_offsets[n]` is the
    /// total length.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local factors and cross-blocks.
    pub manifold_mode_fingerprint: u64,
    /// Row-system tag for the cached per-row factors, cross-blocks, and
    /// shared-block diagonal used to build the Schur factor.
    pub row_hessian_fingerprint: u64,
    /// PCG instrumentation from the solve that produced this cache.
    ///
    /// Zero-valued (default) when the selected mode did not use PCG
    /// (i.e. `Direct` or `SqrtBA`).
    pub pcg_diagnostics: PcgDiagnostics,
    /// Number of row-local gauge directions stiffened in an undamped evidence
    /// factorization.
    ///
    /// Each direction is stiffened at UNIT stiffness `kappa = 1.0`, so it
    /// contributes `log(1) = 0` to the row-block logdet through the returned
    /// Cholesky factor: the gauge orbit is a criterion null direction and adds
    /// nothing to the Laplace normalizer (the quotient pseudo-determinant
    /// convention, cf. `PenaltyPseudologdet`). Zero theta/rho dependence.
    pub gauge_deflated_directions: usize,
    /// Per-row unit-norm directions `vᵢ` (in each row's `d`-dim latent block
    /// coordinates) that an undamped evidence factorization stiffened to UNIT
    /// stiffness `λ̃ = 1` (gauge or spectral deflation). Indexed by row; empty
    /// for every PD row factored without deflation, and empty overall on the
    /// non-deflating solver paths (streaming / cross-row-penalty CG / device).
    ///
    /// A deflated direction contributes `log(1) = 0` to the row-block log-det
    /// and is ρ/θ-INDEPENDENT, so its true contribution to `∂log|H|/∂ρ` is `0`.
    /// The analytic outer-gradient traces (`assignment_log_strength_hessian_trace`,
    /// `learnable_ibp_data_logdet_alpha_trace`, `logdet_theta_adjoint`) contract
    /// `∂H_raw/∂ρ` (the RAW, pre-deflation block derivative) against the DEFLATED
    /// inverse, which assigns `1/λ̃ = 1` to each `vᵢ` and therefore spuriously
    /// adds `½ vᵢᵀ (∂H_raw/∂ρ) vᵢ`. Those traces subtract this per-row term
    /// (kept-subspace restriction) using these directions; without them the
    /// REML outer ρ-gradient is biased by `+Σ_deflated ½ vᵢᵀ ∂H_raw/∂ρ vᵢ`.
    pub deflated_row_directions: Arc<[Vec<Array1<f64>>]>,
    /// Per-row RAW spectral decomposition of an undamped evidence `H_tt` block
    /// that underwent SPECTRAL deflation, surfaced so the outer ρ/θ-gradient
    /// traces can apply the EXACT deflation-map (Daleckii–Krein) derivative
    /// correction, not just the within-row kept-subspace term.
    ///
    /// The criterion VALUE re-deflates `H_tt` at every ρ, so its gradient is
    /// `tr(H_deflated⁻¹ DΦ[∂H_raw/∂ρ])`, where `Φ` is the spectral pin-to-unit
    /// map. By Daleckii–Krein `DΦ[Ȧ] = U (F ∘ UᵀȦU) Uᵀ` with the divided-
    /// difference matrix `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)` (raw `λ` in the
    /// denominator, conditioned `λ̃` in the numerator). The kept×kept block of
    /// `F` is `1` (the kept subspace contracts the raw derivative unchanged), the
    /// deflated×deflated block is `0`, and the kept(m)×deflated(i) block is
    /// `(λₘ − 1)/(λₘ − λᵢ)` — this last, ROTATION, term is what the per-row
    /// kept-subspace correction alone misses; it couples to the β-block through
    /// the Schur back-substitution carried in `(H⁻¹)_tt`.
    ///
    /// `Some(spectrum)` only for spectrally-deflated rows; `None` for PD rows,
    /// gauge-only deflation (ρ-independent structural null — within-row term
    /// suffices), and every non-SAE-evidence solver path (streaming / device /
    /// cross-row CG). Empty overall when no row deflated spectrally.
    pub deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>,
    /// Exact cross-row IBP rank-`R` Woodbury correction (#1038), present iff the
    /// source system carried an [`IbpCrossRowSource`]. When set, the per-row
    /// factors above are of the NO-SELF base `H₀'` (self term `d_k·z'_ik²`
    /// downdated from each logit diagonal), and this carrier supplies the exact
    /// rank-`R` correction so the value/curvature solve
    /// ([`Self::full_inverse_apply`]), the evidence log-determinant
    /// ([`Self::arrow_log_det`]), and the θ/ρ-adjoint all describe the same
    /// `H_full = H₀' + U D Uᵀ`.
    pub cross_row_woodbury: Option<CrossRowWoodbury>,
}
2594
/// Materialized exact cross-row IBP Woodbury correction (#1038), built against
/// an [`ArrowFactorCache`] whose per-row factors are the NO-SELF base `H₀'`.
///
/// Holds `U` (the `delta_t_len × R` arrow-`t` factor, β-part implicitly zero),
/// `D = diag(d_k)`, the projected `M = UᵀH₀'⁻¹U`, the columns `H₀'⁻¹U`, and the
/// **LU factorization of the (generally non-symmetric, possibly indefinite)
/// capacitance** `C = I_R + D·M`. `d_k = w·s'_k` is not sign-definite, so the
/// capacitance is factored by a partial-pivot LU (exact for any sign); the same
/// factorization serves the log-determinant `log det C`, the inverse correction
/// `H_full⁻¹w = H₀'⁻¹w − H₀'⁻¹U·C⁻¹·(D Uᵀ H₀'⁻¹w)`, and the adjoint's
/// selected-inverse (`C⁻¹` and `M`). The full inverse, value/curvature solve,
/// log-determinant, and adjoint therefore all describe the SAME
/// `H_full = H₀' + U D Uᵀ`.
#[derive(Debug, Clone)]
pub struct CrossRowWoodbury {
    /// `U`: `delta_t_len × R`, column `k` supported on atom-`k` logit slots.
    /// Stored densely (the sparse form of the same data is `entries`).
    pub u: Array2<f64>,
    /// `d_k`, length `R`: the diagonal of `D`.
    pub d: Array1<f64>,
    /// `H₀'⁻¹ U` (the `t`-block), `delta_t_len × R`.
    pub h0inv_u: Array2<f64>,
    /// `(H₀'⁻¹ U)` β-block, `K × R`. `U` has no β support, but the bordered
    /// solve couples the latent columns to `β` through the Schur complement, so
    /// this block is generally nonzero and the inverse correction must apply it
    /// to the `β` output too.
    pub h0inv_u_beta: Array2<f64>,
    /// `M = Uᵀ H₀'⁻¹ U`, `R × R` (symmetric; the builder averages the two
    /// triangles to remove rounding asymmetry). Retained for the θ/ρ-adjoint.
    pub m: Array2<f64>,
    /// Partial-pivot LU of the capacitance `C = I_R + D·M` (`lu` packs `L`/`U`,
    /// `piv` the row swaps), built by [`small_lu_factor`].
    pub capacitance_lu: SmallLu,
    /// The sparse `U` entries `(global_t_index, atom_k, z'_ik)` — retained so
    /// `Uᵀ·v` can be formed over the atom slots without re-deriving them.
    pub entries: Vec<(usize, usize, f64)>,
}
2630
/// Dense partial-pivot LU of a small square matrix. Used for the cross-row IBP
/// capacitance `C = I_R + D·M`, which is generally non-symmetric and possibly
/// indefinite (`d_k = w·s'_k` is not sign-definite), so a Cholesky/LDLᵀ is
/// unavailable. `R` is the atom count, so this is a cheap dense factorization.
///
/// Values produced by [`small_lu_factor`] have every `U` diagonal finite and
/// of at least the smallest normal magnitude; the fields are crate-visible, so
/// `solve` still re-checks each diagonal before dividing.
#[derive(Debug, Clone)]
pub struct SmallLu {
    /// Packed `L` (unit lower, below diagonal) and `U` (upper, on/above
    /// diagonal), `R × R`, in the row-permuted order encoded by `piv`.
    pub(crate) lu: Array2<f64>,
    /// Row permutation: `piv[i]` is the original row now in position `i`.
    pub(crate) piv: Vec<usize>,
    /// Sign of the permutation (`±1`), folded into the determinant. Starts at
    /// `+1` and is negated once per row interchange during factorization.
    pub(crate) perm_sign: f64,
}
2645
2646/// Partial-pivot LU factorization of a small dense square matrix `a` (`R × R`).
2647/// Returns `None` only when a pivot is exactly zero (singular `C`).
2648pub(crate) fn small_lu_factor(a: &Array2<f64>) -> Option<SmallLu> {
2649 let r = a.nrows();
2650 assert_eq!(a.ncols(), r, "small_lu_factor: non-square input");
2651 let mut lu = a.clone();
2652 let mut piv: Vec<usize> = (0..r).collect();
2653 let mut perm_sign = 1.0_f64;
2654 for col in 0..r {
2655 // Partial pivot: pick the largest-magnitude entry on/below the diagonal.
2656 let mut pivot_row = col;
2657 let mut pivot_mag = lu[[col, col]].abs();
2658 for row in (col + 1)..r {
2659 let mag = lu[[row, col]].abs();
2660 if mag > pivot_mag {
2661 pivot_mag = mag;
2662 pivot_row = row;
2663 }
2664 }
2665 // Reject not just an exactly-zero pivot, but any non-finite or
2666 // subnormal magnitude: dividing by a subnormal in the elimination /
2667 // back-solve produces `Inf`/`NaN` that would otherwise flow silently
2668 // into the Woodbury inverse and the evidence log-det (#1038). A
2669 // capacitance this degenerate is a desync the caller must surface
2670 // (→ `Ok(None)` cross-row-absent / `SchurFactorFailed`), not consume.
2671 if !pivot_mag.is_finite() || pivot_mag < f64::MIN_POSITIVE {
2672 return None;
2673 }
2674 if pivot_row != col {
2675 for c in 0..r {
2676 lu.swap((col, c), (pivot_row, c));
2677 }
2678 piv.swap(col, pivot_row);
2679 perm_sign = -perm_sign;
2680 }
2681 let pivot = lu[[col, col]];
2682 for row in (col + 1)..r {
2683 let factor = lu[[row, col]] / pivot;
2684 lu[[row, col]] = factor;
2685 for c in (col + 1)..r {
2686 let v = lu[[col, c]];
2687 lu[[row, c]] -= factor * v;
2688 }
2689 }
2690 }
2691 // Post-elimination invariant: every U diagonal is finite and not subnormal.
2692 // The per-column pivot guard above validates each diagonal as it is chosen,
2693 // but assert it explicitly so `SmallLu::solve` can divide by `lu[[i, i]]`
2694 // without a per-entry guard and so a `SmallLu` value can never carry a
2695 // factor that would silently emit `Inf`/`NaN` into the capacitance solve.
2696 for i in 0..r {
2697 let u = lu[[i, i]];
2698 if !u.is_finite() || u.abs() < f64::MIN_POSITIVE {
2699 return None;
2700 }
2701 }
2702 Some(SmallLu { lu, piv, perm_sign })
2703}
2704
2705impl SmallLu {
2706 pub(crate) fn dim(&self) -> usize {
2707 self.lu.nrows()
2708 }
2709
2710 /// `log|det|` and the determinant sign (`±1`).
2711 pub(crate) fn log_abs_det_and_sign(&self) -> (f64, f64) {
2712 let mut log_abs = 0.0_f64;
2713 let mut sign = self.perm_sign;
2714 for i in 0..self.dim() {
2715 let u = self.lu[[i, i]];
2716 log_abs += u.abs().ln();
2717 if u < 0.0 {
2718 sign = -sign;
2719 }
2720 }
2721 (log_abs, sign)
2722 }
2723
2724 /// Solve `C x = b` reusing the factorization (in place into a fresh vector).
2725 ///
2726 /// Returns `None` when the solve cannot produce a finite result — either a
2727 /// `U` diagonal is non-finite/subnormal (defensive: `small_lu_factor`
2728 /// already rejects such factors, but a future construction path might not)
2729 /// or the back-substitution overflows to `Inf`/`NaN` for an extreme RHS on
2730 /// an ill-conditioned (yet validly factored) capacitance. Surfacing `None`
2731 /// lets the Woodbury / evidence consumers fail loudly (#1038) instead of
2732 /// flowing a silent `NaN` into the log-det and outer gradient.
2733 pub(crate) fn solve(&self, b: &Array1<f64>) -> Option<Array1<f64>> {
2734 let r = self.dim();
2735 // Apply the row permutation: y = P b.
2736 let mut y = Array1::<f64>::zeros(r);
2737 for i in 0..r {
2738 y[i] = b[self.piv[i]];
2739 }
2740 // Forward solve L y' = P b (L unit-lower).
2741 for i in 0..r {
2742 let mut sum = y[i];
2743 for j in 0..i {
2744 sum -= self.lu[[i, j]] * y[j];
2745 }
2746 y[i] = sum;
2747 }
2748 // Back solve U x = y' (U upper, explicit diagonal).
2749 let mut x = Array1::<f64>::zeros(r);
2750 for i in (0..r).rev() {
2751 let mut sum = y[i];
2752 for j in (i + 1)..r {
2753 sum -= self.lu[[i, j]] * x[j];
2754 }
2755 let pivot = self.lu[[i, i]];
2756 if !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE {
2757 return None;
2758 }
2759 x[i] = sum / pivot;
2760 }
2761 if x.iter().all(|v| v.is_finite()) {
2762 Some(x)
2763 } else {
2764 None
2765 }
2766 }
2767}
2768
2769/// Close the streaming cross-row IBP Woodbury (#1038) into its exact evidence
2770/// log-determinant correction.
2771///
2772/// Given the chunk-summed `M0 = Uᵀ A⁻¹ U` (`R×R`), `W = Bᵀ A⁻¹ U` (`k×R`), the
2773/// GLOBAL reduced Schur `schur` (`S`, `k×k`, PD), and the per-atom `D`, this
2774/// forms the EXACT projected `M = Uᵀ H₀'⁻¹ U = M0 + Wᵀ S⁻¹ W` (the bordered
2775/// inverse `t`-block `(H₀'⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹` contracted by `U`),
2776/// then the capacitance `C = I_R + D·M` and returns `log det C`.
2777///
2778/// This is byte-for-byte the same quantity, and the same sign convention, that
2779/// the dense [`CrossRowWoodbury::log_det`] returns (symmetrize `M`, scale rows
2780/// by `D`, partial-pivot LU, reject a non-positive determinant). Returns
2781/// `Ok(None)` when the capacitance is non-PD / singular — the recoverable
2782/// "cross-row IBP joint Hessian is non-PD at this ρ" infeasible-probe refusal,
2783/// NOT a silent wrong value.
2784pub fn streaming_cross_row_woodbury_log_det(
2785 schur: &Array2<f64>,
2786 m0: &Array2<f64>,
2787 w: &Array2<f64>,
2788 d: &Array1<f64>,
2789) -> Result<Option<f64>, ArrowSchurError> {
2790 let r = d.len();
2791 let factor =
2792 cholesky_lower(schur).map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
2793 // M = M0 + Wᵀ S⁻¹ W. With S⁻¹ symmetric, `(Wᵀ S⁻¹ W)[a, b] = (S⁻¹ w_a)·w_b`.
2794 let mut m = m0.clone();
2795 for a in 0..r {
2796 let w_a = w.column(a).to_owned();
2797 let sinv_w_a = cholesky_solve_vector(&factor, &w_a);
2798 for b in 0..r {
2799 m[[a, b]] += sinv_w_a.dot(&w.column(b));
2800 }
2801 }
2802 // Symmetrize to clear back-substitution rounding asymmetry (matches
2803 // `CrossRowWoodbury::build`).
2804 for a in 0..r {
2805 for b in (a + 1)..r {
2806 let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2807 m[[a, b]] = avg;
2808 m[[b, a]] = avg;
2809 }
2810 }
2811 // Capacitance C = I_R + D·M (row k scaled by d_k), factored by partial-pivot
2812 // LU (`d_k = w·s'_k` is not sign-definite, so C is generally non-symmetric /
2813 // indefinite — exactly the dense carrier's factorization).
2814 let mut c = Array2::<f64>::zeros((r, r));
2815 for a in 0..r {
2816 for b in 0..r {
2817 c[[a, b]] = d[a] * m[[a, b]];
2818 }
2819 c[[a, a]] += 1.0;
2820 }
2821 match small_lu_factor(&c) {
2822 Some(lu) => {
2823 let (log_abs, sign) = lu.log_abs_det_and_sign();
2824 Ok((sign > 0.0).then_some(log_abs))
2825 }
2826 // Exactly-singular capacitance: `H_full` det is 0, the Laplace log-det is
2827 // undefined — surface as the recoverable non-PD probe refusal.
2828 None => Ok(None),
2829 }
2830}
2831
2832impl CrossRowWoodbury {
2833 /// Build the exact rank-`R` cross-row Woodbury carrier from the IBP source
2834 /// and a cache whose per-row factors are the NO-SELF base `H₀'`.
2835 ///
2836 /// Computes `H₀'⁻¹U` (one [`ArrowFactorCache::full_inverse_apply`] back-solve
2837 /// per column, β-RHS zero — the `t`-block of the result is `H₀'⁻¹U`'s
2838 /// column), `M = UᵀH₀'⁻¹U`, and the LU of `C = I_R + D·M`. Returns `None`
2839 /// when the capacitance is exactly singular (the only un-representable case;
2840 /// the caller then proceeds with the bare `H₀'` cache and the cross-row term
2841 /// is absent — never silently inconsistent, since logdet/inverse/adjoint all
2842 /// key off the presence of this carrier).
2843 pub(crate) fn build(
2844 cache: &ArrowFactorCache,
2845 source: &IbpCrossRowSource,
2846 ) -> Result<Option<Self>, ArrowSchurError> {
2847 let r = source.r;
2848 let total_len = cache.delta_t_len();
2849 let u = source.dense_u(total_len);
2850 let d = source.d.clone();
2851 let zero_beta = Array1::<f64>::zeros(cache.k);
2852 // h0inv_u[:, k] = (H₀'⁻¹ U)_t for column k; h0inv_u_beta[:, k] its β-block.
2853 let mut h0inv_u = Array2::<f64>::zeros((total_len, r));
2854 let mut h0inv_u_beta = Array2::<f64>::zeros((cache.k, r));
2855 for k in 0..r {
2856 let col = u.column(k).to_owned();
2857 let (sol_t, sol_beta) = cache.full_inverse_apply(col.view(), zero_beta.view())?;
2858 for g in 0..total_len {
2859 h0inv_u[[g, k]] = sol_t[g];
2860 }
2861 for c in 0..cache.k {
2862 h0inv_u_beta[[c, k]] = sol_beta[c];
2863 }
2864 }
2865 // M = Uᵀ (H₀'⁻¹ U), symmetric R×R. U is sparse (atom-slot supported), so
2866 // contract over the listed entries.
2867 let mut m = Array2::<f64>::zeros((r, r));
2868 for a in 0..r {
2869 for b in 0..r {
2870 let mut acc = 0.0_f64;
2871 for &(g, k, z) in &source.entries {
2872 if k == a {
2873 acc += z * h0inv_u[[g, b]];
2874 }
2875 }
2876 m[[a, b]] = acc;
2877 }
2878 }
2879 // Symmetrize M to clear back-substitution rounding asymmetry.
2880 for a in 0..r {
2881 for b in (a + 1)..r {
2882 let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2883 m[[a, b]] = avg;
2884 m[[b, a]] = avg;
2885 }
2886 }
2887 // Capacitance C = I_R + D·M (row k scaled by d_k).
2888 let mut c = Array2::<f64>::zeros((r, r));
2889 for a in 0..r {
2890 for b in 0..r {
2891 c[[a, b]] = d[a] * m[[a, b]];
2892 }
2893 c[[a, a]] += 1.0;
2894 }
2895 let Some(capacitance_lu) = small_lu_factor(&c) else {
2896 return Ok(None);
2897 };
2898 Ok(Some(Self {
2899 u,
2900 d,
2901 h0inv_u,
2902 h0inv_u_beta,
2903 m,
2904 capacitance_lu,
2905 entries: source.entries.clone(),
2906 }))
2907 }
2908
2909 /// The sparse `U` entry list `(global_t_index, atom_k, z'_ik)`.
2910 pub(crate) fn source_entries(&self) -> &[(usize, usize, f64)] {
2911 &self.entries
2912 }
2913
2914 /// `C⁻¹ D` as a dense `R × R` matrix (`R` capacitance solves; column `l` is
2915 /// `d_l · C⁻¹ e_l`). Shared by the inverse-diagonal correction and any
2916 /// adjoint trace that needs the selected inverse of the capacitance.
2917 ///
2918 /// Returns `None` when any capacitance solve fails to produce a finite
2919 /// result (#1038); the consumer must surface this as a loud failure rather
2920 /// than propagate a `NaN` into the evidence/gradient.
2921 pub fn capacitance_inv_times_d(&self) -> Option<Array2<f64>> {
2922 let r = self.d.len();
2923 let mut out = Array2::<f64>::zeros((r, r));
2924 let mut e_l = Array1::<f64>::zeros(r);
2925 for l in 0..r {
2926 e_l.fill(0.0);
2927 e_l[l] = 1.0;
2928 let col = self.capacitance_lu.solve(&e_l)?;
2929 for k in 0..r {
2930 out[[k, l]] = col[k] * self.d[l];
2931 }
2932 }
2933 Some(out)
2934 }
2935
2936 /// Subtract the rank-`R` Woodbury term from the latent inverse diagonal:
2937 /// `diag ← diag − diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹)`. With `G = h0inv_u` and
2938 /// `(C⁻¹D) = capacitance_inv_times_d()`, the entry at global index `g` is
2939 /// `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`.
2940 pub(crate) fn subtract_inverse_diagonal(
2941 &self,
2942 diag: &mut Array1<f64>,
2943 ) -> Result<(), ArrowSchurError> {
2944 let r = self.d.len();
2945 let cinv_d =
2946 self.capacitance_inv_times_d()
2947 .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
2948 reason: "cross-row Woodbury capacitance solve produced a non-finite \
2949 C⁻¹D for the inverse-diagonal correction (#1038): \
2950 singular/ill-conditioned cross-row capacitance"
2951 .to_string(),
2952 })?;
2953 let total_len = self.h0inv_u.nrows();
2954 for g in 0..total_len {
2955 let mut acc = 0.0_f64;
2956 for k in 0..r {
2957 let gk = self.h0inv_u[[g, k]];
2958 if gk == 0.0 {
2959 continue;
2960 }
2961 for l in 0..r {
2962 acc += gk * cinv_d[[k, l]] * self.h0inv_u[[g, l]];
2963 }
2964 }
2965 diag[g] -= acc;
2966 }
2967 Ok(())
2968 }
2969
2970 /// `log det(I_R + D·M)` (the matrix-determinant-lemma correction). Returns
2971 /// `None` when the capacitance LU has a negative determinant — i.e. the
2972 /// implied `H_full` is non-PD, which is a desync the evidence must reject
2973 /// loudly rather than return a complex/`NaN` log-det.
2974 pub fn log_det(&self) -> Option<f64> {
2975 let (log_abs, sign) = self.log_det_correction();
2976 if sign > 0.0 { Some(log_abs) } else { None }
2977 }
2978
2979 /// `log det(I_R + D·M)`: the exact additive correction
2980 /// `log det H_full − log det H₀'` (matrix-determinant lemma). For a genuine
2981 /// PD `H_full` this is real; the LU sign is returned for the caller to
2982 /// surface a non-PD capacitance as an error rather than a silent `NaN`.
2983 pub(crate) fn log_det_correction(&self) -> (f64, f64) {
2984 self.capacitance_lu.log_abs_det_and_sign()
2985 }
2986
2987 /// Apply the rank-`R` inverse correction in place on BOTH arrow blocks:
2988 /// `u ← u − (H₀'⁻¹U) · C⁻¹ · (D Uᵀ (H₀'⁻¹ rhs)_t)`, where `h0inv_rhs_t` is
2989 /// the `t`-block of `H₀'⁻¹ rhs` already computed by the base
2990 /// [`ArrowFactorCache::full_inverse_apply`]. Implements the Woodbury
2991 /// identity `H_full⁻¹ = H₀'⁻¹ − H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹`. `U` has no `β`
2992 /// support so `Uᵀ·v` reads only the `t`-block, but `H₀'⁻¹U` couples to `β`
2993 /// through the Schur complement, so the correction touches `u_beta` too.
2994 ///
2995 /// `entries` lets `Uᵀ·v` be formed over the sparse atom slots.
2996 pub(crate) fn apply_inverse_correction(
2997 &self,
2998 h0inv_rhs_t: ArrayView1<'_, f64>,
2999 entries: &[(usize, usize, f64)],
3000 u_t: &mut Array1<f64>,
3001 u_beta: &mut Array1<f64>,
3002 ) -> Result<(), ArrowSchurError> {
3003 let r = self.d.len();
3004 // p = D Uᵀ (H₀'⁻¹ rhs)_t.
3005 let mut p = Array1::<f64>::zeros(r);
3006 for &(g, k, z) in entries {
3007 p[k] += z * h0inv_rhs_t[g];
3008 }
3009 for k in 0..r {
3010 p[k] *= self.d[k];
3011 }
3012 // q = C⁻¹ p. A non-finite solve is a singular/ill-conditioned cross-row
3013 // capacitance (#1038): fail loudly rather than write `NaN` into the
3014 // Newton step / adjoint solve.
3015 let q =
3016 self.capacitance_lu
3017 .solve(&p)
3018 .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
3019 reason: "cross-row Woodbury capacitance solve produced a non-finite \
3020 C⁻¹p for the inverse correction (#1038): \
3021 singular/ill-conditioned cross-row capacitance"
3022 .to_string(),
3023 })?;
3024 // u_t -= (H₀'⁻¹U)_t · q.
3025 for g in 0..u_t.len() {
3026 let mut acc = 0.0_f64;
3027 for k in 0..r {
3028 acc += self.h0inv_u[[g, k]] * q[k];
3029 }
3030 u_t[g] -= acc;
3031 }
3032 // u_beta -= (H₀'⁻¹U)_β · q.
3033 for c in 0..u_beta.len() {
3034 let mut acc = 0.0_f64;
3035 for k in 0..r {
3036 acc += self.h0inv_u_beta[[c, k]] * q[k];
3037 }
3038 u_beta[c] -= acc;
3039 }
3040 Ok(())
3041 }
3042
3043 /// Forward apply of the rank-`R` cross-row curvature on the latent (`t`)
3044 /// block: `out_t += U D Uᵀ v_t`. This is the EXACT forward of the same
3045 /// `H_full = H₀' + U D Uᵀ` whose inverse [`Self::apply_inverse_correction`]
3046 /// applies and whose log-determinant [`Self::log_det_correction`] reports, so
3047 /// a forward Hessian apply that adds this term stays consistent (operator and
3048 /// preconditioner describe the SAME operator). `U` has no `β` support, so the
3049 /// forward term touches only the `t` block; the `β` coupling is purely an
3050 /// inverse-side Schur artifact (`h0inv_u_beta`) and must NOT appear here.
3051 ///
3052 /// Formed over the sparse `entries` `(global_t_index, atom_k, z'_ik)` so it
3053 /// matches `Uᵀ·v` in `apply_inverse_correction` bit-for-bit:
3054 /// `p_k = d_k · Σ_{(g,k,z)} z·v_t[g]`, then `out_t[g] += Σ_{(g,k,z)} z·p_k`.
3055 pub fn apply_forward_t(&self, v_t: ArrayView1<'_, f64>, out_t: &mut Array1<f64>) {
3056 let r = self.d.len();
3057 // p = D Uᵀ v_t.
3058 let mut p = Array1::<f64>::zeros(r);
3059 for &(g, k, z) in &self.entries {
3060 p[k] += z * v_t[g];
3061 }
3062 for k in 0..r {
3063 p[k] *= self.d[k];
3064 }
3065 // out_t += U p.
3066 for &(g, k, z) in &self.entries {
3067 out_t[g] += z * p[k];
3068 }
3069 }
3070}
3071
/// Smallest-pivot summary for a factor cache: the row-block minimum, the
/// Schur-factor minimum, and the smaller of whichever of the two is present.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
    /// Minimum over the per-row factors, if any was available.
    pub min_row_pivot: Option<f64>,
    /// Minimum over the Schur factor, if one was available.
    pub min_schur_pivot: Option<f64>,
    /// Overall minimum across both sources; `None` only when both are absent.
    pub min_pivot: Option<f64>,
}

impl ArrowFactorMinPivot {
    /// Assemble the summary from the two partial minima.
    ///
    /// When both are present the overall value is `f64::min` of the two (so a
    /// NaN on one side yields the other side); when only one is present it is
    /// passed through unchanged.
    pub(crate) fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        // `zip` is `Some` only when both sides are; otherwise fall back to
        // whichever single side exists.
        let min_pivot = row
            .zip(schur)
            .map(|(a, b)| a.min(b))
            .or(row)
            .or(schur);
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            min_pivot,
        }
    }
}
3094
3095pub(crate) fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3096 let width = factor.nrows().min(factor.ncols());
3097 let mut out = None;
3098 for idx in 0..width {
3099 let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3100 out = Some(match out {
3101 Some(current) => f64::min(current, pivot),
3102 None => pivot,
3103 });
3104 }
3105 out
3106}
3107
3108pub(crate) fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3109 let width = factor.nrows().min(factor.ncols());
3110 let mut out = None;
3111 for idx in 0..width {
3112 let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3113 out = Some(match out {
3114 Some(current) => f64::max(current, pivot),
3115 None => pivot,
3116 });
3117 }
3118 out
3119}
3120
3121/// Smallest cached Cholesky pivot for row blocks and the dense Schur factor.
3122///
3123/// Pivots are returned as squared lower-factor diagonals, matching the Hessian
3124/// scale rather than the Cholesky-factor scale. In inexact PCG mode the dense
3125/// Schur factor is absent, so `min_schur_pivot` is `None`.
3126pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
3127 let mut min_row_pivot = None;
3128 for factor in cache.htt_factors.iter() {
3129 if let Some(pivot) = lower_cholesky_min_pivot(factor) {
3130 min_row_pivot = Some(match min_row_pivot {
3131 Some(current) => f64::min(current, pivot),
3132 None => pivot,
3133 });
3134 }
3135 }
3136 let min_schur_pivot = cache
3137 .schur_factor
3138 .as_ref()
3139 .and_then(|factor| lower_cholesky_min_pivot(factor.view()));
3140 ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
3141}
3142
3143/// Largest cached Cholesky pivot across the row blocks and the dense Schur
3144/// factor (Hessian scale, i.e. squared lower-factor diagonal). This is the
3145/// diagonal magnitude scale a safe-SPD pivot floor is measured against: the
3146/// curvature-homotopy tracker (#1007) compares the min pivot against
3147/// `√eps · max(this, 1)`, the same floor the inner solver's
3148/// [`safe_spd_pivot_min`] uses. `None` only for an empty cache.
3149pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
3150 let mut max_pivot: Option<f64> = None;
3151 for factor in cache.htt_factors.iter() {
3152 if let Some(pivot) = lower_cholesky_max_pivot(factor) {
3153 max_pivot = Some(match max_pivot {
3154 Some(current) => f64::max(current, pivot),
3155 None => pivot,
3156 });
3157 }
3158 }
3159 if let Some(factor) = cache.schur_factor.as_ref()
3160 && let Some(pivot) = lower_cholesky_max_pivot(factor.view())
3161 {
3162 max_pivot = Some(match max_pivot {
3163 Some(current) => f64::max(current, pivot),
3164 None => pivot,
3165 });
3166 }
3167 max_pivot
3168}
3169
3170impl ArrowFactorCache {
    /// Number of per-row factors held in `htt_factors`, i.e. the row count `N`
    /// of the cached arrow system.
    pub fn n_rows(&self) -> usize {
        self.htt_factors.len()
    }
3174
    /// Whether the cached `htbeta` store reports itself as available
    /// (delegates to its `is_available`).
    ///
    /// [`Self::latent_block_inverse_diagonal`] refuses to run when this is
    /// `false`.
    pub fn htbeta_available(&self) -> bool {
        self.htbeta.is_available()
    }
3178
    /// Returns the `used_device_arrow` flag recorded in this cache's
    /// `pcg_diagnostics`.
    ///
    /// That flag says whether the Newton solve that produced this cache
    /// actually executed on the device: the device-resident Direct dense solve
    /// or the device-resident matrix-free SAE PCG (whose matvec runs in CUDA
    /// kernels). This does NOT include the injected host-procedural
    /// reduced-Schur matvec, whose arithmetic runs on the CPU even when a CUDA
    /// context was opened to build per-row factors (#1209) — that path sets
    /// `PcgDiagnostics::injected_host_procedural_matvec` instead.
    ///
    /// Read-only routing provenance: lets a fit result record device-vs-CPU as
    /// ground truth instead of inferring it from the runtime probe. Mirrors
    /// `PcgDiagnostics::used_device_arrow`.
    #[must_use]
    pub fn used_device(&self) -> bool {
        self.pcg_diagnostics.used_device_arrow
    }
3193
    /// View of the undamped per-row factor for `row`.
    ///
    /// When `htt_factors_undamped` is `SameAsDamped` the entry is read from the
    /// damped `htt_factors` store; when it is `Owned` the entry comes from the
    /// separately stored factor set.
    pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
        match &self.htt_factors_undamped {
            // No separate undamped copy is stored: reuse the damped factors.
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.factor(row),
            ArrowUndampedFactors::Owned(factors) => factors.factor(row),
        }
    }
3200
    /// Number of factors reachable through [`Self::undamped_factor`]: the
    /// length of `htt_factors` when the undamped set is `SameAsDamped`,
    /// otherwise the length of the `Owned` factor set.
    pub fn undamped_factor_count(&self) -> usize {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
            ArrowUndampedFactors::Owned(factors) => factors.len(),
        }
    }
3207
    /// Iterate the undamped per-row factors in row order, yielding
    /// [`Self::undamped_factor`] for each row in
    /// `0..`[`Self::undamped_factor_count`].
    pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
        (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
    }
3211
3212 pub fn compute_undamped_arrow_log_det(&self) -> Option<f64> {
3213 if self.ridge_t != 0.0 || self.ridge_beta != 0.0 {
3214 return None;
3215 }
3216 // When the shared β block is empty (`k == 0`) the joint Hessian is
3217 // exactly the block diagonal of the per-row latent blocks: there is no
3218 // reduced Schur complement to form, so the dense Direct path leaves
3219 // `schur_factor = None` legitimately (not the InexactPCG "never formed
3220 // the dense K×K factor" case, which has `k > 0`). The log-det is then
3221 // the per-row sum with a zero (empty `0×0`) Schur contribution. Without
3222 // this the `schur_factor.as_ref()?` below would return `None` for a
3223 // β-profiled atom (#1132 euclidean K=4) and starve the REML Laplace
3224 // normaliser of the joint Hessian log-det it requires.
3225 let schur = match self.schur_factor.as_ref() {
3226 Some(schur) => Some(schur),
3227 None if self.k == 0 => None,
3228 None => return None,
3229 };
3230
3231 let mut acc = 0.0_f64;
3232 for l in self.undamped_factors_iter() {
3233 for i in 0..l.nrows() {
3234 let d = l[[i, i]];
3235 if d <= 0.0 || !d.is_finite() {
3236 return None;
3237 }
3238 acc += 2.0 * d.ln();
3239 }
3240 }
3241 if let Some(schur) = schur {
3242 for i in 0..schur.nrows() {
3243 let d = schur[[i, i]];
3244 if d <= 0.0 || !d.is_finite() {
3245 return None;
3246 }
3247 acc += 2.0 * d.ln();
3248 }
3249 }
3250 let woodbury_correction = self.cross_row_woodbury_log_det();
3251 if !woodbury_correction.is_finite() {
3252 return None;
3253 }
3254 Some(acc + woodbury_correction)
3255 }
3256
    /// The total length of `delta_t` / IFT output vectors for this cache.
    ///
    /// Read as `row_offsets[n_rows()]`, i.e. the offset one past the last row.
    /// The index is unchecked, so this panics if `row_offsets` does not have an
    /// entry at position [`Self::n_rows`].
    pub fn delta_t_len(&self) -> usize {
        self.row_offsets[self.n_rows()]
    }
3261
3262 pub fn apply_htbeta_row(
3263 &self,
3264 row: usize,
3265 delta_beta: ArrayView1<'_, f64>,
3266 out: &mut Array1<f64>,
3267 ) -> bool {
3268 let di = if row < self.row_dims.len() {
3269 self.row_dims[row]
3270 } else {
3271 self.d
3272 };
3273 if out.len() != di || delta_beta.len() != self.k {
3274 return false;
3275 }
3276 self.htbeta.apply_row(row, delta_beta, out)
3277 }
3278
3279 /// Accumulate `out[a] += H_βt^(row)[a, :] · v` for all `a in 0..k`.
3280 ///
3281 /// `v` has length `row_dims[row]`; `out` has length `k`. The caller must
3282 /// zero `out` before the first call if it needs a fresh result. Returns
3283 /// `false` when the cache is `Disabled` and no `fallback_op` is provided;
3284 /// callers must treat the accumulator as invalid in that case.
3285 pub fn apply_htbeta_row_transpose(
3286 &self,
3287 row: usize,
3288 v: ArrayView1<'_, f64>,
3289 out: &mut Array1<f64>,
3290 fallback_op: Option<&RowHtbetaMatvec>,
3291 ) -> bool {
3292 let di = if row < self.row_dims.len() {
3293 self.row_dims[row]
3294 } else {
3295 self.d
3296 };
3297 if v.len() != di || out.len() != self.k {
3298 return false;
3299 }
3300 self.htbeta
3301 .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
3302 }
3303
3304 /// Arrow log-determinant
3305 /// `log|H| = Σ_i log|H_{t_i t_i}| + log|Schur_β|`
3306 /// using the cached (damped) factors.
3307 ///
3308 /// Returns `(log_det_tt_sum, log_det_schur)` so the caller can decide
3309 /// what to do with the Schur piece (e.g. REML evidence wants both;
3310 /// some diagnostics want only the per-row sum). `None` for the Schur
3311 /// piece signals that the cache was produced by an InexactPCG solve
3312 /// and never formed/factored the dense `K × K` reduced system.
3313 ///
3314 /// The log-determinant of a Cholesky factor `L` of `M` is
3315 /// `2 Σ log L_ii`.
3316 pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
3317 let mut log_det_tt = 0.0_f64;
3318 for l in self.htt_factors.iter() {
3319 for i in 0..l.nrows() {
3320 log_det_tt += l[[i, i]].ln();
3321 }
3322 }
3323 log_det_tt *= 2.0;
3324 let log_det_schur = self.schur_factor.as_ref().map(|l| {
3325 let mut s = 0.0_f64;
3326 for i in 0..l.nrows() {
3327 s += l[[i, i]].ln();
3328 }
3329 2.0 * s + self.cross_row_woodbury_log_det()
3330 });
3331 (log_det_tt, log_det_schur)
3332 }
3333
3334 /// The exact cross-row IBP correction `log det(I_R + D·M)` to add to the
3335 /// base `log det H₀'` (#1038). Zero when no [`CrossRowWoodbury`] is present,
3336 /// so non-IBP caches are unaffected. The determinant lemma gives
3337 /// `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`; this is the
3338 /// second term, the only piece beyond the bare arrow log-determinant.
3339 ///
3340 /// Panics-free: a negative capacitance determinant (non-PD `H_full`) yields
3341 /// `NaN` here so the evidence surfaces the desync rather than silently
3342 /// dropping the imaginary part. Callers that must reject it should check
3343 /// [`CrossRowWoodbury::log_det`] directly.
3344 pub fn cross_row_woodbury_log_det(&self) -> f64 {
3345 match self.cross_row_woodbury.as_ref() {
3346 Some(w) => w.log_det().unwrap_or(f64::NAN),
3347 None => 0.0,
3348 }
3349 }
3350
3351 /// Diagonal of the latent (`t`-block) of the *full* bordered-arrow
3352 /// inverse `(H⁻¹)_tt`, in `delta_t` layout (length [`Self::delta_t_len`]).
3353 ///
3354 /// For the bordered arrow Hessian
3355 /// `H = [[A, B], [Bᵀ, H_ββ]]` with `A = H_tt` (block-diagonal per row,
3356 /// `A_i = H_tt^(i)`) and `B = H_tβ`, the standard block-inverse identity
3357 /// gives the `t`-block
3358 /// `(H⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹`, where
3359 /// `S = H_ββ − Bᵀ A⁻¹ B` is the Schur complement on `β`. Because `A` is
3360 /// block-diagonal, the `(i, j)` diagonal entry of `(H⁻¹)_tt` is computed
3361 /// purely from row `i`'s factor and cross-block:
3362 ///
3363 /// ```text
3364 /// a = A_i⁻¹ e_j (chol_solve on the per-row factor)
3365 /// [A_i⁻¹]_{jj} = a[j]
3366 /// w = B_iᵀ a = H_βt^(i) a (a K-vector)
3367 /// z = S⁻¹ w (chol_solve on the Schur factor)
3368 /// diag = a[j] + w · z
3369 /// ```
3370 ///
3371 /// The UNDAMPED per-row factors ([`Self::undamped_factor`]) are used so
3372 /// the result is the inverse of the *true* `H_tt`, not the LM-damped
3373 /// `H_tt + ridge_t·I` — same rationale the IFT predictor docstring gives
3374 /// at the top of this struct.
3375 ///
3376 /// # Consuming the diagonal as a per-(atom, axis) trace
3377 ///
3378 /// `(H⁻¹)_tt` is the latent covariance block. The selected-inverse trace
3379 /// for a contiguous group of latent coordinates (e.g. one atom's rows, or
3380 /// one axis across rows) is simply the sum of the returned diagonal entries
3381 /// over those `row_offsets[i] + j` indices — no off-diagonal terms are
3382 /// needed for the trace `tr[(H⁻¹)_tt · D]` against a diagonal selector `D`.
3383 ///
3384 /// # Errors
3385 ///
3386 /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3387 /// dense Schur factor or no usable `H_βt` coupling — i.e. it was produced
3388 /// by an [`ArrowSolverMode::InexactPCG`] solve (no dense `K × K` factor) or
3389 /// by a `Disabled` `htbeta` cache. The selected-inverse block-trace is not
3390 /// yet supported for the matrix-free PCG mode; that branch needs a separate
3391 /// Lanczos/Hutchinson estimator.
3392 pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
3393 let Some(schur_factor) = self.schur_factor.as_ref() else {
3394 return Err(ArrowSchurError::SchurFactorFailed {
3395 reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
3396 the InexactPCG mode does not form one"
3397 .to_string(),
3398 });
3399 };
3400 if !self.htbeta_available() {
3401 return Err(ArrowSchurError::SchurFactorFailed {
3402 reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
3403 but this cache's htbeta is Disabled"
3404 .to_string(),
3405 });
3406 }
3407 let n = self.undamped_factor_count();
3408 let total_len = self.delta_t_len();
3409 let mut out = Array1::<f64>::zeros(total_len);
3410 // Per-row scratch, sized to the max latent dim / K.
3411 let mut e_j = Array1::<f64>::zeros(self.d);
3412 let mut w = Array1::<f64>::zeros(self.k);
3413 for i in 0..n {
3414 let di = self.row_dims[i];
3415 let row_base = self.row_offsets[i];
3416 let factor = self.undamped_factor(i);
3417 for j in 0..di {
3418 // a = A_i⁻¹ e_j.
3419 for c in 0..di {
3420 e_j[c] = 0.0;
3421 }
3422 e_j[j] = 1.0;
3423 let e_j_slice = e_j.slice(ndarray::s![..di]).to_owned();
3424 let a = cholesky_solve_vector(factor, &e_j_slice);
3425 // w = H_βt^(i) a (a K-vector); accumulator must start zeroed.
3426 w.fill(0.0);
3427 if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
3428 return Err(ArrowSchurError::SchurFactorFailed {
3429 reason: format!(
3430 "latent_block_inverse_diagonal: H_βt^({i}) apply failed \
3431 (htbeta cache could not supply row {i})"
3432 ),
3433 });
3434 }
3435 // z = S⁻¹ w; correction = w · z.
3436 let z = cholesky_solve_vector(schur_factor, &w);
3437 let mut corr = 0.0_f64;
3438 for c in 0..self.k {
3439 corr += w[c] * z[c];
3440 }
3441 out[row_base + j] = a[j] + corr;
3442 }
3443 }
3444 if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3445 // #1038: the factors above are `H₀'`, so `out` is diag((H₀'⁻¹)_tt).
3446 // The full inverse diagonal subtracts the rank-`R` Woodbury term
3447 // diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹). With `G = h0inv_u = (H₀'⁻¹U)_t` and
3448 // (by symmetry of `H₀'⁻¹`) `(Uᵀ H₀'⁻¹)_t = Gᵀ`, the diagonal entry at
3449 // global index `g` is `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`. Form the
3450 // `R×R` matrix `C⁻¹D` once (R solves), then contract per row index.
3451 woodbury.subtract_inverse_diagonal(&mut out)?;
3452 }
3453 Ok(out)
3454 }
3455
3456 /// Solve the full bordered-arrow system `H·u = w` on the cached factor
3457 /// (#1006): `w` arrives in arrow layout — `w_t` flat per
3458 /// [`Self::delta_t_len`] / `row_offsets`, `w_beta` of length `K` — and the
3459 /// solution comes back in the same layout. Standard block elimination on
3460 /// the SAME factors whose log-determinant the evidence reports:
3461 ///
3462 /// ```text
3463 /// y_i = H_tt^(i)⁻¹ · w_t^(i)
3464 /// r_β = w_β − Σ_i H_βt^(i) · y_i
3465 /// u_β = Schur⁻¹ · r_β
3466 /// u_t^(i) = y_i − H_tt^(i)⁻¹ · (H_tβ^(i) · u_β)
3467 /// ```
3468 ///
3469 /// This is the IFT / adjoint back-solve the analytic outer ρ-gradient
3470 /// consumes: `u_j = H⁻¹ (∂g/∂ρ_j)` per outer coordinate and the
3471 /// `H⁻¹`-side of the third-order correction `−½·Γᵀ·H⁻¹·(∂g/∂ρ_j)`.
3472 /// Contract: the cache must be the ridge-0 Direct evidence factor
3473 /// (undamped per-row factors + dense Schur), so the solve is against the
3474 /// criterion's own `H` — never a damped surrogate (that would desync the
3475 /// gradient from the reported evidence).
3476 ///
3477 /// When the cache carries an exact cross-row IBP
3478 /// [`CrossRowWoodbury`] (#1038), the per-row factors are the NO-SELF base
3479 /// `H₀'` and this method layers the rank-`R` Woodbury correction so the
3480 /// returned solve is against the FULL `H_full = H₀' + U D Uᵀ` — the same
3481 /// operator whose log-determinant [`Self::arrow_log_det`] reports. The
3482 /// θ/ρ-adjoint that consumes this therefore sees the cross-row curvature.
3483 pub fn full_inverse_apply(
3484 &self,
3485 w_t: ArrayView1<'_, f64>,
3486 w_beta: ArrayView1<'_, f64>,
3487 ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3488 let (mut u_t, mut u_beta) = self.full_inverse_apply_base(w_t, w_beta)?;
3489 if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3490 // u ← u − H₀'⁻¹U C⁻¹ D Uᵀ u. `u_t` is the `t`-block of `H₀'⁻¹ w`.
3491 let h0inv_w_t = u_t.clone();
3492 woodbury.apply_inverse_correction(
3493 h0inv_w_t.view(),
3494 woodbury.source_entries(),
3495 &mut u_t,
3496 &mut u_beta,
3497 )?;
3498 }
3499 Ok((u_t, u_beta))
3500 }
3501
3502 /// Bare bordered-arrow inverse solve against the cached per-row factors and
3503 /// Schur factor (the NO-SELF base `H₀'` when a cross-row Woodbury is
3504 /// present). [`Self::full_inverse_apply`] wraps this with the rank-`R`
3505 /// correction; [`CrossRowWoodbury::build`] calls this directly (before the
3506 /// carrier exists) to form `H₀'⁻¹U`.
3507 pub(crate) fn full_inverse_apply_base(
3508 &self,
3509 w_t: ArrayView1<'_, f64>,
3510 w_beta: ArrayView1<'_, f64>,
3511 ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3512 let total_len = self.delta_t_len();
3513 if w_t.len() != total_len || w_beta.len() != self.k {
3514 return Err(ArrowSchurError::SchurFactorFailed {
3515 reason: format!(
3516 "full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
3517 w_t.len(),
3518 w_beta.len(),
3519 total_len,
3520 self.k
3521 ),
3522 });
3523 }
3524 let n = self.undamped_factor_count();
3525 // Forward pass: y_i = H_tt^(i)⁻¹ w_t^(i), accumulating the border RHS.
3526 let mut y = Array1::<f64>::zeros(total_len);
3527 let mut r_beta = w_beta.to_owned();
3528 for i in 0..n {
3529 let di = self.row_dims[i];
3530 let base = self.row_offsets[i];
3531 let factor = self.undamped_factor(i);
3532 let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
3533 let y_row = cholesky_solve_vector(factor, &w_row);
3534 if self.k > 0 {
3535 // r_β −= H_βt^(i) y_i: accumulate into a scratch then subtract,
3536 // because the helper ACCUMULATES (+=) into its output.
3537 let mut acc = Array1::<f64>::zeros(self.k);
3538 if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
3539 return Err(ArrowSchurError::SchurFactorFailed {
3540 reason: format!(
3541 "full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
3542 could not supply row {i}; htbeta={:?}, di={}, k={})",
3543 self.htbeta,
3544 self.row_dims.get(i).copied().unwrap_or(self.d),
3545 self.k
3546 ),
3547 });
3548 }
3549 for c in 0..self.k {
3550 r_beta[c] -= acc[c];
3551 }
3552 }
3553 for j in 0..di {
3554 y[base + j] = y_row[j];
3555 }
3556 }
3557 // Border solve + back-substitution.
3558 let u_beta = if self.k > 0 {
3559 self.schur_inverse_apply(r_beta.view())?
3560 } else {
3561 Array1::<f64>::zeros(0)
3562 };
3563 let mut u_t = y;
3564 if self.k > 0 {
3565 let mut cross = Array1::<f64>::zeros(self.d);
3566 for i in 0..n {
3567 let di = self.row_dims[i];
3568 let base = self.row_offsets[i];
3569 let mut cross_row = cross.slice_mut(ndarray::s![..di]);
3570 cross_row.fill(0.0);
3571 let mut cross_owned = cross_row.to_owned();
3572 if !self.apply_htbeta_row(i, u_beta.view(), &mut cross_owned) {
3573 return Err(ArrowSchurError::SchurFactorFailed {
3574 reason: format!(
3575 "full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
3576 could not supply row {i})"
3577 ),
3578 });
3579 }
3580 let factor = self.undamped_factor(i);
3581 let corr = cholesky_solve_vector(factor, &cross_owned);
3582 for j in 0..di {
3583 u_t[base + j] -= corr[j];
3584 }
3585 }
3586 }
3587 Ok((u_t, u_beta))
3588 }
3589
3590 /// Apply the β-block of the full inverse, `(H⁻¹)_ββ · rhs = S_β⁻¹ · rhs`,
3591 /// where `S_β` is the Schur complement on β whose Cholesky factor this
3592 /// cache holds in [`Self::schur_factor`].
3593 ///
3594 /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the
3595 /// β-block of `H⁻¹` is exactly the inverse of the Schur complement
3596 /// `S_β = H_ββ − Bᵀ A⁻¹ B`. One Cholesky back-substitution per call,
3597 /// reusing the cached factor; `rhs` and the returned vector both have
3598 /// length `K`.
3599 ///
3600 /// This is the general single-solve primitive for the β border. Callers
3601 /// that need a Schur-inverse trace `tr(S_β⁻¹ M)` against a structured
3602 /// penalty `M` (e.g. the SAE λ_smooth Fellner-Schall step, where
3603 /// `M = blockdiag_k(λ_k S_k ⊗ I_p)`) build it as
3604 /// `Σ_col e_colᵀ S_β⁻¹ M e_col` — apply this to each column of `M`
3605 /// (exploiting whatever sparsity `M` has) and read off `result[col]`.
3606 /// Keeping `M`'s layout on the caller side avoids coupling this solver
3607 /// to penalty-op types.
3608 ///
3609 /// # Errors
3610 ///
3611 /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3612 /// dense Schur factor (an [`ArrowSolverMode::InexactPCG`] solve) — the
3613 /// same not-yet-supported branch as [`Self::latent_block_inverse_diagonal`]
3614 /// — or when `rhs.len() != k`.
3615 ///
3616 /// Cross-row IBP (#1038) note: this is the β-block primitive of the
3617 /// factored base `S_β` (`H₀'` when a [`CrossRowWoodbury`] is present), used
3618 /// internally by [`Self::full_inverse_apply_base`]; it is deliberately NOT
3619 /// Woodbury-corrected so the base solve stays bare. The cross-row term has
3620 /// no `β` support, so `(H_full⁻¹)_ββ = S_β⁻¹` exactly on the directions any
3621 /// IBP ρ-trace contracts. A consumer needing the full `(H_full⁻¹)_ββ` for a
3622 /// β-supported direction should call [`Self::full_inverse_apply`] with a
3623 /// unit `β`-RHS (which applies the rank-`R` correction).
3624 pub fn schur_inverse_apply(
3625 &self,
3626 rhs: ArrayView1<'_, f64>,
3627 ) -> Result<Array1<f64>, ArrowSchurError> {
3628 let Some(schur_factor) = self.schur_factor.as_ref() else {
3629 return Err(ArrowSchurError::SchurFactorFailed {
3630 reason: "schur_inverse_apply requires a dense Schur factor; \
3631 the InexactPCG mode does not form one"
3632 .to_string(),
3633 });
3634 };
3635 if rhs.len() != self.k {
3636 return Err(ArrowSchurError::SchurFactorFailed {
3637 reason: format!(
3638 "schur_inverse_apply: rhs length {} != K {}",
3639 rhs.len(),
3640 self.k
3641 ),
3642 });
3643 }
3644 let rhs_owned = rhs.to_owned();
3645 Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
3646 }
3647
3648 /// Dense principal sub-block of the β-block of the full inverse,
3649 /// `(H⁻¹)_ββ[block, block] = S_β⁻¹[block, block]`, shape `(W, W)` with
3650 /// `W = block.len()`.
3651 ///
3652 /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the β-block
3653 /// of `H⁻¹` is exactly `S_β⁻¹` (the inverse of the Schur complement whose
3654 /// Cholesky factor this cache holds). This returns the contiguous
3655 /// `block × block` sub-block — e.g. one SAE atom's decoder coefficients via
3656 /// [`gam_terms::sae::manifold::SaeManifoldTerm::beta_block_offsets`] — by
3657 /// solving `S_β x = e_j` for each `j ∈ block` (reusing the cached factor)
3658 /// and gathering the `block` rows of each solution column. `W`
3659 /// back-substitutions of size `K`; the result is symmetrized to clear
3660 /// back-substitution rounding asymmetry. Up to a dispersion scale `φ`, this
3661 /// block is the joint posterior covariance `Cov(β_block)` of those
3662 /// coefficients with the latent coordinates already marginalized out (that
3663 /// is precisely what Schur-eliminating the per-row `t`-blocks does).
3664 ///
3665 /// Same dense-Schur requirement / error contract as
3666 /// [`Self::schur_inverse_apply`]; additionally errors when `block` runs past
3667 /// `K`.
3668 pub fn schur_inverse_block(
3669 &self,
3670 block: std::ops::Range<usize>,
3671 ) -> Result<Array2<f64>, ArrowSchurError> {
3672 let Some(schur_factor) = self.schur_factor.as_ref() else {
3673 return Err(ArrowSchurError::SchurFactorFailed {
3674 reason: "schur_inverse_block requires a dense Schur factor; \
3675 the InexactPCG mode does not form one"
3676 .to_string(),
3677 });
3678 };
3679 if block.end > self.k {
3680 return Err(ArrowSchurError::SchurFactorFailed {
3681 reason: format!(
3682 "schur_inverse_block: block end {} exceeds K {}",
3683 block.end, self.k
3684 ),
3685 });
3686 }
3687 let w = block.len();
3688 let mut out = Array2::<f64>::zeros((w, w));
3689 let mut e_j = Array1::<f64>::zeros(self.k);
3690 for (jc, j) in block.clone().enumerate() {
3691 e_j.fill(0.0);
3692 e_j[j] = 1.0;
3693 let col = cholesky_solve_vector(schur_factor, &e_j);
3694 for (ic, i) in block.clone().enumerate() {
3695 out[[ic, jc]] = col[i];
3696 }
3697 }
3698 // S_β⁻¹ is symmetric; symmetrize to clear back-substitution rounding.
3699 for ic in 0..w {
3700 for jc in (ic + 1)..w {
3701 let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
3702 out[[ic, jc]] = avg;
3703 out[[jc, ic]] = avg;
3704 }
3705 }
3706 Ok(out)
3707 }
3708}