gam_solve/arrow_schur/system.rs
1//! The bordered arrow-Schur system itself: [`ArrowRowBlock`], the
2//! [`ArrowSchurSystem`] container and its assembly impl, cross-row latent
3//! penalties, the streaming builder, and the per-row factor caches.
4
5use super::*;
6
/// Per-row block data for the arrow-Schur system.
///
/// `htt` holds the `d × d` Gauss–Newton block for row `i` (including any
/// analytic-penalty contributions on that row); `htbeta` holds the
/// `d × K` cross-block `H_tβ^(i)`; `gt` is the `d`-length latent
/// gradient for row `i`.
///
/// The cross-block slab is only `K` columns wide when the row is allocated
/// through [`ArrowRowBlock::new`]; [`ArrowRowBlock::new_with_htbeta_cols`]
/// allocates it at a caller-chosen width instead.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
    /// `H_tt^(i)`, shape `(d, d)`.
    pub htt: Array2<f64>,
    /// `H_tβ^(i)`, shape `(d, K)` (or `(d, htbeta_cols)` when allocated with
    /// an explicit slab width).
    pub htbeta: Array2<f64>,
    /// `g_t^(i)`, shape `(d,)`.
    pub gt: Array1<f64>,
}
22
23impl ArrowRowBlock {
24 /// Allocate one BA point-block row: local latent Hessian, point-camera
25 /// cross block, and point gradient.
26 pub fn new(d: usize, k: usize) -> Self {
27 Self::new_with_htbeta_cols(d, k)
28 }
29
30 /// Allocate one BA row whose dense cross-block slab has `htbeta_cols`
31 /// columns. This is used by matrix-free assemblers that keep the shared
32 /// beta tier at one width while dense row supplements live in another
33 /// coordinate system.
34 pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
35 Self {
36 htt: Array2::<f64>::zeros((d, d)),
37 htbeta: Array2::<f64>::zeros((d, htbeta_cols)),
38 gt: Array1::<f64>::zeros(d),
39 }
40 }
41}
42
/// Bordered (t, β) Newton system with arrow structure.
///
/// The β-block is held as a dense `K × K` Hessian `H_ββ` plus a `K`-length
/// gradient `g_β` for direct BA modes. Large-scale inexact BA callers may
/// additionally install a matrix-free `H_ββ x` operator and diagonal via
/// [`ArrowSchurSystem::set_shared_beta_operator`]; the InexactPCG mode then
/// avoids dense Schur formation/factorization.
/// The t-block is a `Vec<ArrowRowBlock>` of length `N`.
///
/// Construction is the driver's responsibility: the driver
///
/// 1. evaluates Φ(t) and the radial jet `∂Φ/∂t` (the latter via
///    [`gam_terms::latent::LatentCoordValues::design_gradient_wrt_t`]);
/// 2. forms the working-weighted Gauss–Newton blocks
///    `H_tt^(i) += (g_i β)(g_i β)^T`, `H_tβ^(i) += (g_i β) ⊗ Φ_i`,
///    `H_ββ += Φ^T W Φ + Σ_k λ_k S_k`;
/// 3. calls [`ArrowSchurSystem::add_analytic_penalty_contributions`] to
///    fold row-block Psi-tier analytic penalties (`ARDPenalty`,
///    `SparsityPenalty`) into `H_tt^(i)` and Beta-tier penalties into `H_ββ`;
/// 4. calls [`ArrowSchurSystem::solve`] to obtain `(Δt, Δβ)`.
///
/// `Clone` is implemented by hand (field by field) rather than derived.
pub struct ArrowSchurSystem {
    /// Per-row latent block (length `N`, each row `d × d` / `d × K` / `d`).
    pub rows: Vec<ArrowRowBlock>,
    /// `H_ββ`, shape `(K, K)` for direct BA modes. It is left `0 × 0` by the
    /// constructors that build no dense shared block:
    /// [`ArrowSchurSystem::new_matrix_free_shared`] (PCG-only use),
    /// [`ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols`], and
    /// [`ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols`].
    /// Code that needs the β width must therefore read `k`, not `hbb.nrows()`.
    pub hbb: Array2<f64>,
    /// Optional matrix-free `H_ββ x` operator for large BA Schur PCG.
    ///
    /// Direct and Square-Root BA modes still require `hbb`; InexactPCG uses
    /// this operator when present, avoiding dense shared-block storage for
    /// SAE-manifold scale `K`.
    pub hbb_matvec: Option<SharedBetaMatvec>,
    /// Optional row-local matrix-free multiply for `H_tβ^(i) x`.
    ///
    /// When present, all inner-Schur paths route through this operator instead
    /// of indexing the per-row `htbeta` dense slabs: `reduced_rhs_beta`,
    /// `schur_matvec` (PCG hot loop), back-substitution,
    /// `JacobiPreconditioner` construction, `build_dense_schur_direct`, and
    /// `build_dense_schur_sqrt_ba` all call `sys_htbeta_apply_row` or
    /// `sys_htbeta_materialize_row`. Factor caches retain the operator for
    /// IFT/evidence consumers as before.
    pub htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Optional row-local matrix-free transpose multiply `out += H_βt^(i) · v`.
    ///
    /// The sparse adjoint of [`Self::htbeta_matvec`]. When present, the
    /// reduced-Schur matvec applies `H_βt^(i)` directly (sparse `scatter`)
    /// instead of probing the forward operator against `K` basis vectors. This
    /// is the per-row sparse apply that lifts the `O(K)` column-probe in the
    /// GPU PCG and streaming Schur paths to `O(m_i · p)` per row. Installed in
    /// lock-step with `htbeta_matvec` by [`Self::set_row_htbeta_operator`].
    pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// Whether `rows[*].htbeta` contains a dense contribution that must be added
    /// on top of the matrix-free row operator. Starts `false` in every
    /// constructor and is switched on by
    /// [`Self::activate_dense_htbeta_supplement`].
    pub htbeta_dense_supplement: bool,
    /// Optional diagonal of the matrix-free shared block, used by the
    /// Schur-Jacobi preconditioner in the Agarwal-style PCG path.
    pub hbb_diag: Option<Array1<f64>>,
    /// `g_β`, shape `(K,)`.
    pub gb: Array1<f64>,
    /// Maximum per-row latent dimensionality across all rows.
    ///
    /// For homogeneous systems (all rows have the same dim) this equals the
    /// common per-row `d`. For heterogeneous systems (e.g. sparse SAE rows
    /// where JumpReLU / TopK / sparsemax active sets vary per observation)
    /// this is `max_i row_dims[i]`. Per-row code should use
    /// `row.htt.nrows()` or `row_dims[i]`; `d` is an upper bound for
    /// scratch-buffer sizing.
    pub d: usize,
    /// Per-row latent dimensionality: `row_dims[i] == rows[i].htt.nrows()`.
    ///
    /// For homogeneous systems `row_dims[i] == d` for all `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for the `delta_t` vector produced by
    /// [`Self::solve`] / [`solve_arrow_newton_step_core`].
    ///
    /// `row_offsets[i]` is the start index for row `i`'s slice in `delta_t`;
    /// `row_offsets[n]` is the total `delta_t` length. For homogeneous
    /// systems `row_offsets[i] == i * d`.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local latent blocks after optional
    /// Riemannian projection. Euclidean/no-op geometry uses the sentinel.
    pub manifold_mode_fingerprint: u64,
    /// Structural/value tag for row-local Hessian factors and their Schur
    /// inputs. Stale caches must be rejected when row-dependent Hessian
    /// penalties or cross-blocks change.
    pub row_hessian_fingerprint: u64,
    /// Registry-side tag for row-dependent analytic-penalty Hessian inputs.
    /// Combined with the materialized row blocks in
    /// [`Self::current_row_hessian_fingerprint`].
    pub analytic_row_hessian_fingerprint: u64,
    /// Term-block column ranges for the block-Jacobi Schur preconditioner.
    ///
    /// Each entry `r` means that indices `r.start..r.end` belong to one
    /// coefficient block (a GAM term or a custom parameter family from
    /// `ParameterBlockSpec`). When populated via
    /// [`Self::set_block_offsets`], the Jacobi preconditioner inverts the
    /// full `b × b` Schur block for each term instead of only its diagonal.
    ///
    /// The default (empty slice) causes `JacobiPreconditioner` to fall back
    /// to pure scalar diagonal inversion, preserving the pre-#283 behaviour.
    pub block_offsets: Arc<[Range<usize>]>,
    /// Optional matrix-free penalty-side `H_ββ` operator (#296).
    ///
    /// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
    /// `JacobiPreconditioner`, quadratic-form reduction) route through this
    /// operator instead of the dense `hbb` accumulator, enabling
    /// `BlockPenaltyOp` / `KroneckerPenaltyOp` to skip the `O(K²)` dense
    /// materialisation for structured smoothness penalties.
    ///
    /// When `None`, those paths fall back to wrapping `hbb` in a transient
    /// `DensePenaltyOp` — identical observable behaviour, no new allocation
    /// hot-path cost for callers that have not opted in.
    pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
    /// Device-uploadable SAE Kronecker data for CUDA-resident reduced PCG.
    ///
    /// The generic matrix-free closures remain the authoritative CPU path. This
    /// descriptor is installed only when SAE assembly has a matching CUDA sparse
    /// representation for both `H_tβ` and `H_ββ`.
    pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
    /// Registered Psi-tier analytic penalties whose Hessian couples *distinct*
    /// latent rows (non-row-block-diagonal), captured by
    /// [`Self::add_analytic_penalty_contributions`].
    ///
    /// These penalties (`TotalVariationPenalty`, `SheafConsistencyPenalty`,
    /// block-orthogonality, …) produce off-row Hessian blocks `∂²P/∂t_i∂t_j`
    /// (`i ≠ j`) that the arrow elimination — which assumes each `H_tt^(i)` is
    /// independent of every other row — cannot represent. Their *gradient* is
    /// still folded into `g_t` exactly like every other Psi penalty; only their
    /// curvature is held here, applied during the solve as a full-latent
    /// Hessian-vector product `P_cross · Δt` against the penalty's
    /// `psd_majorizer_hvp`. When this vector is non-empty,
    /// [`solve_arrow_newton_step_artifacts`] auto-selects the matrix-free
    /// full-system PCG path (arrow block-diagonal inverse as preconditioner)
    /// instead of the exact one-shot Schur elimination. When empty, the system
    /// is purely row-block-diagonal and the exact Schur path is unchanged.
    pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
    /// Optional row-local gauge directions for evidence-only Faddeev-Popov
    /// deflation of an otherwise non-PD `H_tt` row block.
    ///
    /// These vectors live in each row's actual chart block, so compact SAE rows
    /// and dense rows share the same factorization path. Ordinary Newton solves
    /// ignore them; only undamped evidence factors with
    /// `tolerate_ill_conditioning` set may stiffen a gauge-explained row
    /// direction.
    pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
    /// Optional exact cross-row IBP low-rank source (#1038). When set, the
    /// factorization downdates the per-row logit-slot self term and layers the
    /// exact rank-`R` Woodbury correction onto the evidence cache (value,
    /// log-determinant, and θ/ρ-adjoint together). `None` for all non-IBP
    /// systems — the row-block-diagonal arrow path is then unchanged.
    pub ibp_cross_row: Option<IbpCrossRowSource>,
}
197
198impl Clone for ArrowSchurSystem {
199 fn clone(&self) -> Self {
200 Self {
201 rows: self.rows.clone(),
202 hbb: self.hbb.clone(),
203 hbb_matvec: self.hbb_matvec.clone(),
204 htbeta_matvec: self.htbeta_matvec.clone(),
205 htbeta_transpose_matvec: self.htbeta_transpose_matvec.clone(),
206 htbeta_dense_supplement: self.htbeta_dense_supplement,
207 hbb_diag: self.hbb_diag.clone(),
208 gb: self.gb.clone(),
209 d: self.d,
210 row_dims: Arc::clone(&self.row_dims),
211 row_offsets: Arc::clone(&self.row_offsets),
212 k: self.k,
213 manifold_mode_fingerprint: self.manifold_mode_fingerprint,
214 row_hessian_fingerprint: self.row_hessian_fingerprint,
215 analytic_row_hessian_fingerprint: self.analytic_row_hessian_fingerprint,
216 block_offsets: Arc::clone(&self.block_offsets),
217 penalty_op: self.penalty_op.clone(),
218 device_sae_pcg: self.device_sae_pcg.clone(),
219 cross_row_penalties: self.cross_row_penalties.clone(),
220 row_gauge_deflation: self.row_gauge_deflation.clone(),
221 ibp_cross_row: self.ibp_cross_row.clone(),
222 }
223 }
224}
225
/// A captured cross-row Psi-tier analytic penalty: the penalty kind plus the
/// global-ρ slice (`rho_local`) it was registered with, and the latent point
/// (`target_t`) its curvature was linearized at.
///
/// Holds an owned copy of the local ρ-axes so the penalty's
/// [`AnalyticPenaltyKind::psd_majorizer_hvp`] can be evaluated during the
/// matrix-free full-system solve without re-deriving the ρ layout. The penalty
/// itself is an `Arc`-backed clone (cheap), so capturing it does not copy the
/// penalty payload.
#[derive(Clone)]
pub struct CrossRowLatentPenalty {
    /// The non-row-block-diagonal Psi penalty (e.g. `TotalVariationPenalty`).
    pub penalty: AnalyticPenaltyKind,
    /// The penalty's local ρ-axes (its slice of the global ρ vector).
    pub rho_local: Array1<f64>,
    /// The flat latent vector (`N·d`, row-major) the penalty's curvature was
    /// linearized at — i.e. the `target_t` passed to
    /// [`ArrowSchurSystem::add_analytic_penalty_contributions`]. The Hessian of
    /// a nonlinear penalty (the smoothed-TV curvature weights `φ''(D t)`,
    /// etc.) depends on this point, so `psd_majorizer_hvp` must be evaluated
    /// against it for the Newton operator to be the true Hessian at the
    /// current iterate.
    pub target_t: Array1<f64>,
}
249
/// Exact cross-row low-rank IBP source (#1038): the per-column rank-one Hessian
/// terms `H_(i,k),(j,k) = d_k·z'_ik·z'_jk` (for ALL `i,j`, including the `i=j`
/// self term) that couple DISTINCT latent rows through a shared atom column `k`.
///
/// Stacking over rows, this is `H_full = H₀' + U D Uᵀ`, where:
/// * `U` is `delta_t_len × R` with `U[g, k] = z'_ik` at the global latent index
///   `g` of row `i`'s logit slot for atom `k` (zero elsewhere) — i.e. column `k`
///   is supported on the atom-`k` logit slot of every row;
/// * `D = diag(d_k)`, `d_k = w·s'_k` ([`gam_terms::analytic_penalties::IbpHessianDiagThirdChannels::cross_row_d`]);
/// * `H₀'` is the assembled latent block-diagonal `H₀` with the per-row self
///   term `d_k·z'_ik²` REMOVED from each logit-slot diagonal (the assembled
///   `H₀` already carries it, so the FULL rank-one outer product `U D Uᵀ` —
///   which re-adds the `i=j` diagonal — would double-count without this
///   downdate). The determinant lemma `log det(I_R + D UᵀH₀'⁻¹U)` is only the
///   exact rank-`R` correction against this no-self base.
///
/// The arrow elimination assumes each row's `H_tt^(i)` is independent of every
/// other row, so it structurally cannot hold this coupling block-locally. The
/// factorization owner (`solver::arrow_schur`) consumes this source to (a)
/// downdate the per-row logit diagonal before factoring, (b) build `U`/`D` onto
/// the resulting [`ArrowFactorCache`] as a [`CrossRowWoodbury`], and (c) apply
/// the exact Woodbury correction to the value/curvature solve, the evidence
/// log-determinant, and the θ/ρ-adjoint TOGETHER (they all describe the SAME
/// `H_full`).
///
/// A source with `r == 0` or an empty `entries` list is treated as absent by
/// [`ArrowSchurSystem::set_ibp_cross_row_source`].
#[derive(Clone, Debug, Default)]
pub struct IbpCrossRowSource {
    /// Number of atom columns `R` (the rank of the cross-row update).
    pub r: usize,
    /// `d_k = w·s'_k`, the scalar `D`-coefficient of column `k`. Length `R`.
    pub d: Array1<f64>,
    /// Per-row column entries `(global_t_index, atom_k, z'_ik)`: each tuple
    /// places `z'_ik` at `U[global_t_index, atom_k]`. The `global_t_index` is
    /// `row_offsets[i] + local_slot` for the row's logit slot of atom `k`. Only
    /// nonzero entries are listed (one per active (row, atom) pair).
    pub entries: Vec<(usize, usize, f64)>,
}
286
287impl IbpCrossRowSource {
288 /// Build the dense `delta_t_len × R` factor `U` (each column supported on
289 /// its atom's per-row logit slots) from the sparse entry list.
290 pub(crate) fn dense_u(&self, delta_t_len: usize) -> Array2<f64> {
291 let mut u = Array2::<f64>::zeros((delta_t_len, self.r));
292 for &(g, k, z) in &self.entries {
293 u[[g, k]] += z;
294 }
295 u
296 }
297
298 /// Per-row-slot self-term downdate: returns, for each global latent index,
299 /// the scalar `Σ_k d_k·z'_ik²` to subtract from that logit slot's diagonal
300 /// so the factored base is `H₀'` (self term removed). Indexed by global
301 /// `delta_t` position.
302 pub(crate) fn self_term_downdate(&self, delta_t_len: usize) -> Array1<f64> {
303 let mut down = Array1::<f64>::zeros(delta_t_len);
304 for &(g, k, z) in &self.entries {
305 down[g] += self.d[k] * z * z;
306 }
307 down
308 }
309}
310
311impl ArrowSchurSystem {
312 /// Allocate an empty BA reduced-camera-system instance sized
313 /// `(N point/latent rows × d, K shared decoder parameters)`.
314 pub fn new(n: usize, d: usize, k: usize) -> Self {
315 Self::new_with_hbb(n, d, k, Array2::<f64>::zeros((k, k)))
316 }
317
318 /// Allocate an arrow system with no dense shared `H_ββ` block and with
319 /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
320 pub fn new_with_empty_hbb_and_htbeta_cols(
321 n: usize,
322 d: usize,
323 k: usize,
324 htbeta_cols: usize,
325 ) -> Self {
326 let rows = (0..n)
327 .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
328 .collect();
329 let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
330 let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
331 Self {
332 rows,
333 hbb: Array2::<f64>::zeros((0, 0)),
334 hbb_matvec: None,
335 htbeta_matvec: None,
336 htbeta_transpose_matvec: None,
337 htbeta_dense_supplement: false,
338 hbb_diag: None,
339 gb: Array1::<f64>::zeros(k),
340 d,
341 row_dims,
342 row_offsets,
343 k,
344 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
345 row_hessian_fingerprint: 0,
346 analytic_row_hessian_fingerprint: 0,
347 block_offsets: Arc::from([] as [Range<usize>; 0]),
348 penalty_op: None,
349 device_sae_pcg: None,
350 cross_row_penalties: Vec::new(),
351 row_gauge_deflation: None,
352 ibp_cross_row: None,
353 }
354 }
355
356 /// Allocate an arrow system using a caller-owned dense shared-block buffer.
357 /// The buffer must already have shape `(k, k)` and is zeroed in place before
358 /// use so callers can recycle it across assemblies without changing
359 /// numerics.
360 pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
361 Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, k)
362 }
363
364 /// Allocate an arrow system with a caller-owned dense shared-block buffer and
365 /// per-row dense `H_tβ` slabs allocated at `htbeta_cols` columns.
366 pub fn new_with_hbb_and_htbeta_cols(
367 n: usize,
368 d: usize,
369 k: usize,
370 mut hbb: Array2<f64>,
371 htbeta_cols: usize,
372 ) -> Self {
373 assert_eq!(hbb.dim(), (k, k));
374 hbb.fill(0.0);
375 let rows = (0..n)
376 .map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
377 .collect();
378 let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
379 let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
380 Self {
381 rows,
382 hbb,
383 hbb_matvec: None,
384 htbeta_matvec: None,
385 htbeta_transpose_matvec: None,
386 htbeta_dense_supplement: false,
387 hbb_diag: None,
388 gb: Array1::<f64>::zeros(k),
389 d,
390 row_dims,
391 row_offsets,
392 k,
393 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
394 row_hessian_fingerprint: 0,
395 analytic_row_hessian_fingerprint: 0,
396 block_offsets: Arc::from([] as [Range<usize>; 0]),
397 penalty_op: None,
398 device_sae_pcg: None,
399 cross_row_penalties: Vec::new(),
400 row_gauge_deflation: None,
401 ibp_cross_row: None,
402 }
403 }
404
405 /// Allocate an arrow system whose shared `H_ββ` block is supplied only as
406 /// a matrix-free operator for large BA InexactPCG.
407 ///
408 /// Direct and Square-Root BA modes require dense `hbb` and must not be
409 /// used with this constructor. The row-local `H_tβ` slabs remain explicit;
410 /// a future MegBA backend can replace those slab operations behind
411 /// [`BatchedBlockSolver`].
412 pub fn new_matrix_free_shared<F>(
413 n: usize,
414 d: usize,
415 k: usize,
416 matvec: F,
417 diag: Array1<f64>,
418 ) -> Self
419 where
420 F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
421 {
422 assert_eq!(diag.len(), k);
423 let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
424 let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
425 let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
426 let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
427 // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
428 // route through the trait while preserving hbb_matvec + hbb_diag for
429 // code that inspects them directly.
430 let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
431 k,
432 Arc::clone(&matvec_arc),
433 diag.clone(),
434 )));
435 Self {
436 rows,
437 hbb: Array2::<f64>::zeros((0, 0)),
438 hbb_matvec: Some(matvec_arc),
439 htbeta_matvec: None,
440 htbeta_transpose_matvec: None,
441 htbeta_dense_supplement: false,
442 hbb_diag: Some(diag),
443 gb: Array1::<f64>::zeros(k),
444 d,
445 row_dims,
446 row_offsets,
447 k,
448 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
449 row_hessian_fingerprint: 0,
450 analytic_row_hessian_fingerprint: 0,
451 block_offsets: Arc::from([] as [Range<usize>; 0]),
452 penalty_op,
453 device_sae_pcg: None,
454 cross_row_penalties: Vec::new(),
455 row_gauge_deflation: None,
456 ibp_cross_row: None,
457 }
458 }
459
460
461
462 /// Allocate a heterogeneous-row arrow system with no dense shared `H_ββ`
463 /// block and with row `H_tβ` slabs allocated at `htbeta_cols` columns.
464 pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
465 per_row_dims: Vec<usize>,
466 k: usize,
467 htbeta_cols: usize,
468 ) -> Self {
469 let n = per_row_dims.len();
470 let d = per_row_dims.iter().copied().max().unwrap_or(0);
471 let mut offsets = Vec::with_capacity(n + 1);
472 let mut cursor = 0usize;
473 offsets.push(cursor);
474 for &dim in &per_row_dims {
475 cursor += dim;
476 offsets.push(cursor);
477 }
478 let rows = per_row_dims
479 .iter()
480 .map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
481 .collect();
482 Self {
483 rows,
484 hbb: Array2::<f64>::zeros((0, 0)),
485 hbb_matvec: None,
486 htbeta_matvec: None,
487 htbeta_transpose_matvec: None,
488 htbeta_dense_supplement: false,
489 hbb_diag: None,
490 gb: Array1::<f64>::zeros(k),
491 d,
492 row_dims: Arc::from(per_row_dims.into_boxed_slice()),
493 row_offsets: Arc::from(offsets.into_boxed_slice()),
494 k,
495 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
496 row_hessian_fingerprint: 0,
497 analytic_row_hessian_fingerprint: 0,
498 block_offsets: Arc::from([] as [Range<usize>; 0]),
499 penalty_op: None,
500 device_sae_pcg: None,
501 cross_row_penalties: Vec::new(),
502 row_gauge_deflation: None,
503 ibp_cross_row: None,
504 }
505 }
506
507 /// Allocate a heterogeneous-row system using a caller-owned dense shared
508 /// block and row `H_tβ` slabs allocated at `htbeta_cols` columns.
509 pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
510 per_row_dims: Vec<usize>,
511 k: usize,
512 mut hbb: Array2<f64>,
513 htbeta_cols: usize,
514 ) -> Self {
515 assert_eq!(hbb.dim(), (k, k));
516 hbb.fill(0.0);
517 let n = per_row_dims.len();
518 let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
519 let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
520 let mut off_vec = Vec::with_capacity(n + 1);
521 let mut cursor = 0usize;
522 for &di in &per_row_dims {
523 off_vec.push(cursor);
524 cursor += di;
525 }
526 off_vec.push(cursor);
527 let row_offsets: Arc<[usize]> = off_vec.into();
528 let rows = per_row_dims
529 .iter()
530 .map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
531 .collect();
532 Self {
533 rows,
534 hbb,
535 hbb_matvec: None,
536 htbeta_matvec: None,
537 htbeta_transpose_matvec: None,
538 htbeta_dense_supplement: false,
539 hbb_diag: None,
540 gb: Array1::<f64>::zeros(k),
541 d: max_d,
542 row_dims,
543 row_offsets,
544 k,
545 manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
546 row_hessian_fingerprint: 0,
547 analytic_row_hessian_fingerprint: 0,
548 block_offsets: Arc::from([] as [Range<usize>; 0]),
549 penalty_op: None,
550 device_sae_pcg: None,
551 cross_row_penalties: Vec::new(),
552 row_gauge_deflation: None,
553 ibp_cross_row: None,
554 }
555 }
556
/// Install the row-local gauge-deflation directions, replacing any that were
/// previously set.
pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
    self.row_gauge_deflation = Some(deflation);
}
560
561 /// Register the exact cross-row IBP low-rank source (#1038). The assembly
562 /// passes the per-column `D`-coefficients (`cross_row_d`) and the `(global
563 /// latent index, atom, z'_ik)` entries built from `z_jac`; the factorization
564 /// then carries the exact rank-`R` Woodbury (value + log-determinant +
565 /// θ/ρ-adjoint) on the evidence cache. An empty source (`r == 0` or no
566 /// entries) is treated as absent so the row-block-diagonal path is unchanged.
567 pub fn set_ibp_cross_row_source(&mut self, source: IbpCrossRowSource) {
568 if source.r == 0 || source.entries.is_empty() {
569 self.ibp_cross_row = None;
570 } else {
571 self.ibp_cross_row = Some(source);
572 }
573 }
574
/// Number of BA point/latent rows `N`, i.e. the length of `rows`.
pub fn n(&self) -> usize {
    self.rows.len()
}
579
/// Recompute the row-system fingerprint from the currently materialized
/// row blocks, cross-blocks, and shared-block diagonal.
///
/// Delegates to `row_hessian_fingerprint_for_system`; the stored
/// `row_hessian_fingerprint` field is neither read nor updated here.
pub fn compute_row_hessian_fingerprint(&self) -> u64 {
    row_hessian_fingerprint_for_system(self)
}
585
586 /// Current effective row-system fingerprint, including the materialized
587 /// row blocks and any registry metadata captured while folding analytic
588 /// penalties into the system.
589 pub fn current_row_hessian_fingerprint(&self) -> u64 {
590 combine_row_and_registry_fingerprints(
591 self.compute_row_hessian_fingerprint(),
592 self.analytic_row_hessian_fingerprint,
593 )
594 }
595
596 /// Store the current row-system fingerprint on the system.
597 ///
598 /// This is intentionally explicit and expensive. Cache and evidence callers
599 /// use [`Self::current_row_hessian_fingerprint`] at the point they need the
600 /// value, after assembly has populated the system, instead of hashing each
601 /// intermediate construction/mutation step.
602 pub fn refresh_row_hessian_fingerprint(&mut self) {
603 self.row_hessian_fingerprint = self.current_row_hessian_fingerprint();
604 }
605
606 /// Install a matrix-free shared-block operator for Agarwal-style
607 /// inexact Schur PCG.
608 ///
609 /// `diag` must be the diagonal of the same `H_ββ` operator and is used
610 /// for the Schur-Jacobi preconditioner. This is the BA "large camera
611 /// system" path mapped to large decoder coefficient blocks.
612 pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
613 where
614 F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
615 {
616 assert_eq!(diag.len(), self.k);
617 let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
618 // Mirror the closure into a BetaPenaltyOp so all hot paths (#296)
619 // route through the trait, preserving the existing hbb_matvec +
620 // hbb_diag fields for code that inspects them directly.
621 self.penalty_op = Some(Arc::new(MatvecDiagPenaltyOp::new(
622 self.k,
623 Arc::clone(&matvec_arc),
624 diag.clone(),
625 )));
626 self.hbb_matvec = Some(matvec_arc);
627 self.hbb_diag = Some(diag);
628 }
629
/// Mark the dense per-row cross-block slabs as active supplements to the
/// installed matrix-free row operator.
///
/// Only sets the `htbeta_dense_supplement` flag; the slabs themselves are
/// not touched.
pub fn activate_dense_htbeta_supplement(&mut self) {
    self.htbeta_dense_supplement = true;
}
635
636 /// Install a matrix-free per-row cross-block operator and its sparse
637 /// adjoint.
638 ///
639 /// `forward` must write `out = H_tβ^(row) x` for `out.len() == d` and
640 /// `x.len() == K`. `transpose` must **add** `H_βt^(row) v` into `out` for
641 /// `out.len() == K` and `v.len() == d` (the sparse `scatter` adjoint).
642 ///
643 /// When installed, the forward operator is used during the Newton solve
644 /// (inside `reduced_rhs_beta`, `schur_matvec`, back-substitution, and
645 /// `JacobiPreconditioner` construction) and afterwards by IFT/evidence
646 /// predictors. Per-row `htbeta` slabs in `ArrowRowBlock` may be left
647 /// zero-sized when this operator is installed — all inner-Schur paths route
648 /// through the matvec instead of indexing the dense block. The transpose
649 /// operator lets the reduced-Schur matvec apply `H_βt^(row)` directly
650 /// (`O(m_i · p)`) instead of probing `forward` against `K` basis vectors.
651 pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
652 where
653 F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
654 T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
655 {
656 self.htbeta_matvec = Some(Arc::new(forward));
657 self.htbeta_transpose_matvec = Some(Arc::new(transpose));
658 }
659
/// Register term-block column ranges for the block-Jacobi Schur preconditioner.
///
/// Each `Range<usize>` covers the columns of one GAM term (or custom
/// parameter family) in the shared `β` vector. The ranges must be
/// non-overlapping, sorted, and their union must cover `0..k`. This setter
/// stores the ranges as given and does not itself check those requirements.
///
/// Call this after building the system and before [`Self::solve`] /
/// [`Self::solve_with_options`] whenever the solver will use
/// [`ArrowSolverMode::InexactPCG`]. Absent a call, the preconditioner
/// falls back to scalar diagonal Jacobi (the pre-#283 behaviour).
///
/// The same plumbing is compatible with #287 (custom `ParameterBlockSpec`
/// families): callers from that path simply supply ranges derived from
/// their own block layout.
pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
    self.block_offsets = offsets;
}
677
/// Install a matrix-free penalty-side `H_ββ` operator (#296), replacing any
/// operator that was previously installed.
///
/// When set, all hot paths (`schur_matvec`, `build_dense_schur_*`,
/// `JacobiPreconditioner`, quadratic-form reduction) route through this
/// operator instead of the dense `hbb` accumulator, enabling
/// `BlockPenaltyOp` / `KroneckerPenaltyOp` to avoid `O(K²)` allocation
/// for structured smoothness penalties.
pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
    self.penalty_op = Some(op);
}
688
689 pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
690 assert_eq!(data.beta_dim, self.k);
691 assert_eq!(data.a_phi.len(), self.rows.len());
692 assert_eq!(data.local_jac.len(), self.rows.len());
693 self.device_sae_pcg = Some(Arc::new(data));
694 }
695
696 /// Return the effective penalty operator: the installed `penalty_op` if
697 /// present, otherwise a `DensePenaltyOp` wrapping the current `hbb`.
698 ///
699 /// Note: when `penalty_op` is `None`, this clones `hbb` into a new
700 /// `DensePenaltyOp`. Callers in hot loops should call this once and
701 /// store the result, not call it per-iteration.
702 pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
703 match self.penalty_op.as_ref() {
704 Some(op) => Arc::clone(op),
705 None => Arc::new(DensePenaltyOp(self.hbb.clone())),
706 }
707 }
708
709 /// `y += P x` without allocating a new Arc; dispatches to `penalty_op`
710 /// or falls back to `hbb` inline, avoiding the K×K clone hot-path cost.
711 #[inline]
712 pub(crate) fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
713 if let Some(op) = self.penalty_op.as_ref() {
714 op.matvec(x, y);
715 } else {
716 let k = self.hbb.nrows();
717 for a in 0..k {
718 let mut acc = 0.0_f64;
719 for b in 0..k {
720 acc += self.hbb[[a, b]] * x[b];
721 }
722 y[a] += acc;
723 }
724 }
725 }
726
727 /// Reduced-Schur matvec prologue `y = (P + ridge·I) x` written fresh into a
728 /// zeroed `y` (the caller clears `out` first; this is the first writer).
729 ///
730 /// At the SAE LLM border width (#1017) the dense `H_ββ` fallback is a `k×k`
731 /// GEMV whose `O(k²)` cost (≈4M flops at k=2048) runs once per CG iteration
732 /// and was the serial Amdahl ceiling on the per-row-parallel matvec: while
733 /// the `n`-row point-elimination term fans across all cores, this prologue
734 /// pinned one core and grows as `k²`. The dense GEMV is embarrassingly
735 /// parallel over output rows `a` — each `y[a] = Σ_b hbb[a,b]·x[b] + ridge·x[a]`
736 /// is independent and its inner sum order is identical whether one thread or
737 /// many compute it. Here parallelism is over independent output rows (NOT a
738 /// reassociated reduction), so each `y[a]` accumulates in the SAME order as
739 /// serial — the result is bit-identical to serial, not merely deterministic
740 /// run-to-run (the #1017 determinism gate). On THIS exact-order path the
741 /// criterion ranking is invariant; that no-move guarantee holds because the
742 /// order matches serial, and does NOT generalise to chunk-reassociated
743 /// reductions, where a near-tie winner can flip within the f64 margin
744 /// (#1211). The `penalty_op` path stays serial — it is an opaque operator
745 /// with its own structure (SAE uses the dense `hbb`), and small `k` stays
746 /// serial to avoid rayon overhead on a trivial GEMV.
747 ///
748 /// `parallel` is the caller's top-level / not-nested-in-rayon decision (the
749 /// same guard the row loop uses), so this never oversubscribes inside the
750 /// topology race.
751 pub(crate) fn penalty_ridge_prologue_into(
752 &self,
753 x: &[f64],
754 ridge: f64,
755 y: &mut [f64],
756 parallel: bool,
757 ) {
758 let k = self.hbb.nrows();
759 let dense_parallel = parallel
760 && self.penalty_op.is_none()
761 && self.hbb.dim() == (k, k)
762 && k >= SCHUR_PROLOGUE_PARALLEL_K_MIN;
763 if dense_parallel {
764 use rayon::prelude::*;
765 let hbb = &self.hbb;
766 y.par_iter_mut().enumerate().for_each(|(a, ya)| {
767 let mut acc = 0.0_f64;
768 for b in 0..k {
769 acc += hbb[[a, b]] * x[b];
770 }
771 *ya = acc + ridge * x[a];
772 });
773 } else {
774 self.penalty_matvec_add(x, y);
775 for a in 0..k {
776 y[a] += ridge * x[a];
777 }
778 }
779 }
780
781 /// `diag += diag(P)` without allocating; dispatches to `penalty_op`
782 /// or falls back to `hbb` diagonal / `hbb_diag` inline.
783 #[inline]
784 pub(crate) fn penalty_diagonal_add(&self, diag: &mut [f64]) {
785 if let Some(op) = self.penalty_op.as_ref() {
786 op.diagonal(diag);
787 } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
788 let k = hbb_diag.len().min(diag.len());
789 for j in 0..k {
790 diag[j] += hbb_diag[j];
791 }
792 } else {
793 let k = self.hbb.nrows().min(diag.len());
794 for j in 0..k {
795 diag[j] += self.hbb[[j, j]];
796 }
797 }
798 }
799
800 /// Add the `b×b` penalty sub-block for `id` to `out`, routing through
801 /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
802 #[inline]
803 pub(crate) fn penalty_block_add(
804 &self,
805 id: BetaBlockId,
806 offsets: &[Range<usize>],
807 out: &mut Array2<f64>,
808 ) {
809 if let Some(op) = self.penalty_op.as_ref() {
810 op.block(id, offsets, out);
811 } else {
812 let range = &offsets[id.0];
813 let b = range.end - range.start;
814 if self.hbb.dim() == (self.k, self.k) {
815 for bi in 0..b {
816 for bj in 0..b {
817 out[[bi, bj]] += self.hbb[[range.start + bi, range.start + bj]];
818 }
819 }
820 } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
821 for bi in 0..b {
822 out[[bi, bi]] += hbb_diag[range.start + bi];
823 }
824 }
825 }
826 }
827
828 /// Fill a `b×b` penalty sub-block for a set of arbitrary (possibly
829 /// non-contiguous) global column indices `cols`, routing through
830 /// `penalty_op` or falling back to `hbb` / `hbb_diag` inline.
831 ///
832 /// Used by the cluster-Jacobi preconditioner (#299) which groups columns
833 /// by spectral adjacency rather than contiguous block ranges.
834 #[inline]
835 pub(crate) fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
836 let b = cols.len();
837 if let Some(op) = self.penalty_op.as_ref() {
838 // Probe each column basis vector and extract the sub-block entries.
839 let mut probe = Array1::<f64>::zeros(self.k);
840 let mut result = Array1::<f64>::zeros(self.k);
841 for bj in 0..b {
842 probe.fill(0.0);
843 probe[cols[bj]] = 1.0;
844 result.fill(0.0);
845 {
846 let p_slice = probe.as_slice().expect("probe contiguous");
847 let r_slice = result.as_slice_mut().expect("result contiguous");
848 op.matvec(p_slice, r_slice);
849 }
850 for bi in 0..b {
851 out[[bi, bj]] += result[cols[bi]];
852 }
853 }
854 } else if self.hbb.dim() == (self.k, self.k) {
855 for bi in 0..b {
856 for bj in 0..b {
857 out[[bi, bj]] += self.hbb[[cols[bi], cols[bj]]];
858 }
859 }
860 } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
861 for bi in 0..b {
862 out[[bi, bi]] += hbb_diag[cols[bi]];
863 }
864 }
865 }
866
867 /// Fold analytic-penalty contributions into the appropriate blocks.
868 ///
869 /// BA source mapping: these are extra prior/regularization normal-equation
870 /// terms before point elimination, the same place Ceres/g2o attach robust
871 /// priors or gauge-fixing constraints.
872 ///
873 /// **Composition path.** Each registered [`AnalyticPenaltyKind`] is
874 /// queried for `grad_target` (added to `g_t` or `g_β`) and then for
875 /// `hessian_diag` first. Diagonal penalties (ARD and the shipped
876 /// sparsity kernels) are injected directly. The row-block-only Psi-tier
877 /// penalties are `ARDPenalty`, `SparsityPenalty`,
878 /// `SoftmaxAssignmentSparsity`, `IBPAssignment`,
879 /// `RowPrecisionPrior`, `ParametricRowPrecisionPrior`, and
880 /// `ScadMcpPenalty`. Their `d × d` per-row Hessian folds into
881 /// `rows[i].htt`, so the exact arrow Schur elimination (`N` independent
882 /// `d × d` row solves) represents them exactly. Dense Beta-tier penalties
883 /// still fall back to `hvp` probes against the canonical basis vectors for
884 /// `β`.
885 ///
886 /// **Cross-row Psi penalties.** Penalties whose Hessian couples *distinct*
887 /// latent rows — `TotalVariationPenalty`, `SheafConsistencyPenalty`,
888 /// block-orthogonality, … — produce off-row blocks `∂²P/∂t_i∂t_j`
889 /// (`i ≠ j`) that the arrow elimination cannot store, since it assumes each
890 /// `H_tt^(i)` is independent of every other row. These are handled without
891 /// any approximation: their **gradient** is folded into `g_t` exactly as
892 /// for every other Psi penalty (`grad_target → g_t`), and their full
893 /// **curvature** is captured into [`Self::cross_row_penalties`] as a
894 /// matrix-free operator. At solve time, `K = K0 + P_cross` where `K0` is
895 /// the block-diagonal arrow operator and `P_cross · Δt = Σ_p ρ_p ·
896 /// psd_majorizer_hvp_p(t, Δt)` is the cross-row penalty Hessian applied to
897 /// the full flat latent vector. The presence of any captured cross-row
898 /// penalty auto-routes [`Self::solve`] through the matrix-free full-system
899 /// PCG path (the exact arrow block-diagonal inverse `K0⁻¹` is the
900 /// preconditioner `M⁻¹`); a purely row-block-diagonal system keeps the
901 /// exact one-shot Schur path unchanged. No new flag is involved — the route
902 /// is selected from the captured penalty set alone (magic by default).
903 ///
904 /// `target_t` is the full flat latent-coordinate vector (row-major, `N·d` entries)
905 /// at the current iterate; `target_beta` is the current `β`. `rho`
906 /// is the global ρ vector restricted to each penalty's local slice
907 /// by [`AnalyticPenaltyRegistry::rho_layout`].
908 pub fn add_analytic_penalty_contributions(
909 &mut self,
910 registry: &AnalyticPenaltyRegistry,
911 target_t: ArrayView1<'_, f64>,
912 target_beta: ArrayView1<'_, f64>,
913 rho_global: ArrayView1<'_, f64>,
914 ) -> Result<(), ArrowSchurError> {
915 let layout = registry.rho_layout();
916 let mut penalty_fingerprints = Vec::new();
917 self.cross_row_penalties.clear();
918 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
919 let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
920 match tier {
921 PenaltyTier::Psi => {
922 if analytic_penalty_is_row_block_diagonal(penalty) {
923 // Row-block-diagonal: fold gradient + per-row d×d
924 // curvature into rows[i].htt, exactly representable by
925 // the arrow Schur elimination.
926 self.add_ext_coord_penalty(penalty, target_t, rho_local);
927 if let Some(fingerprint) =
928 analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
929 {
930 penalty_fingerprints.push(fingerprint);
931 }
932 } else {
933 // Cross-row: fold the gradient into g_t (exact, like
934 // every Psi penalty), but DO NOT fold any curvature into
935 // the row blocks — its off-row coupling cannot be stored
936 // there. Capture the penalty so the solve applies its
937 // full Hessian-vector product P_cross·Δt over the flat
938 // latent vector. This auto-selects the matrix-free
939 // full-system PCG path.
940 self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
941 self.cross_row_penalties.push(CrossRowLatentPenalty {
942 penalty: penalty.clone(),
943 rho_local: rho_local.to_owned(),
944 target_t: target_t.to_owned(),
945 });
946 }
947 }
948 PenaltyTier::Beta => {
949 self.add_beta_penalty(penalty, target_beta, rho_local);
950 }
951 PenaltyTier::Rho => {
952 // Rho-tier hyperpriors do not contribute to the inner
953 // (t, β) Newton step; they enter only at the REML
954 // outer level.
955 }
956 }
957 }
958 // Cross-row penalties contribute to the Newton Hessian operator, not
959 // the stored row blocks, so they must still invalidate the row-Hessian
960 // cache when their curvature changes. Probe each captured penalty's PSD
961 // majorizer against the current latent vector (a deterministic, generic
962 // probe) and fold the resulting signature in.
963 for cross in &self.cross_row_penalties {
964 penalty_fingerprints.push(cross_row_penalty_fingerprint(
965 &cross.penalty,
966 target_t,
967 cross.rho_local.view(),
968 ));
969 }
970 self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
971 0
972 } else {
973 let mut hasher = Fingerprinter::new();
974 hasher.write_str("arrow-schur-row-hessian-registry-v1");
975 hasher.write_usize(penalty_fingerprints.len());
976 for fingerprint in penalty_fingerprints {
977 hasher.write_u64(fingerprint);
978 }
979 hasher.finish_u64()
980 };
981 Ok(())
982 }
983
984 /// Convert row-local Euclidean latent blocks to Riemannian tangent blocks.
985 ///
986 /// This is the only arrow-Schur algebra change needed for manifold
987 /// latents: `g_t`, `H_tt`, and each `H_tβ` column are projected to
988 /// `T_{t_i}M`, while the shared β block and Schur structure remain
989 /// untouched. Embedded constrained manifolds carry a pinned normal block
990 /// so the existing ambient Cholesky factorization still works; all RHS
991 /// terms live in the tangent space, so the solved update retracts cleanly.
992 pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
993 let manifold = latent.manifold();
994 self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
995 if manifold.is_euclidean() {
996 return;
997 }
998 assert_eq!(latent.n_obs(), self.rows.len());
999 assert_eq!(latent.latent_dim(), self.d);
1000 for (i, row) in self.rows.iter_mut().enumerate() {
1001 let t_i = ArrayView1::from(latent.row(i));
1002 let gt_e = row.gt.clone();
1003 let htt_e = row.htt.clone();
1004 let htbeta_e = row.htbeta.clone();
1005 row.gt = manifold.project_gradient_to_tangent(t_i, gt_e.view());
1006 row.htt = manifold.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
1007 row.htbeta = manifold.project_matrix_columns_to_gradient_tangent(
1008 t_i,
1009 gt_e.view(),
1010 htbeta_e.view(),
1011 );
1012 }
1013 }
1014
1015 pub(crate) fn add_ext_coord_penalty(
1016 &mut self,
1017 penalty: &AnalyticPenaltyKind,
1018 target_t: ArrayView1<'_, f64>,
1019 rho_local: ArrayView1<'_, f64>,
1020 ) {
1021 let d = self.d;
1022 let n = self.rows.len();
1023 apply_analytic_penalty(
1024 penalty,
1025 target_t,
1026 rho_local,
1027 n * d,
1028 d,
1029 self,
1030 |sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
1031 |sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
1032 |a, probe| {
1033 for i in 0..n {
1034 probe[i * d + a] = 1.0;
1035 }
1036 },
1037 |sys, a, hv| {
1038 for i in 0..n {
1039 for b in 0..d {
1040 sys.rows[i].htt[[b, a]] += hv[i * d + b];
1041 }
1042 }
1043 },
1044 );
1045 }
1046
1047 /// Fold ONLY the latent gradient `grad_target → g_t` of an analytic
1048 /// penalty, leaving the row-block Hessian untouched.
1049 ///
1050 /// Used for cross-row Psi penalties: their gradient enters `g_t` exactly
1051 /// like every other Psi penalty, but their curvature must NOT be scattered
1052 /// into the per-row `H_tt^(i)` blocks (the diagonal piece would be
1053 /// double-counted and the off-row coupling cannot be stored there). The
1054 /// full curvature is instead applied as a matrix-free `P_cross · Δt`
1055 /// during the solve, via [`Self::cross_row_penalties`].
1056 pub(crate) fn add_ext_coord_penalty_gradient_only(
1057 &mut self,
1058 penalty: &AnalyticPenaltyKind,
1059 target_t: ArrayView1<'_, f64>,
1060 rho_local: ArrayView1<'_, f64>,
1061 ) {
1062 let d = self.d;
1063 let n = self.rows.len();
1064 assert_eq!(target_t.len(), n * d);
1065 let grad = penalty.grad_target(target_t, rho_local);
1066 for flat in 0..n * d {
1067 self.rows[flat / d].gt[flat % d] += grad[flat];
1068 }
1069 }
1070
1071 /// Apply the aggregate cross-row penalty Hessian `P_cross · v` over the
1072 /// full flat latent vector `v` (length `Σ_i row_dims[i]`), accumulating
1073 /// into `out`.
1074 ///
1075 /// `P_cross = Σ_p psd_majorizer_hvp_p(target_t, ·; ρ_p)` summed over every
1076 /// captured cross-row penalty. Each penalty's `psd_majorizer_hvp` is its
1077 /// exact (PSD) Hessian-vector product over the `N·d` flat latent vector —
1078 /// for `TotalVariationPenalty` this is `Dᵀ diag(φ''(D t)) D · v`, the
1079 /// graph/forward-difference Laplacian-style coupling that links distinct
1080 /// rows. The ρ scaling is already baked into each penalty's resolved
1081 /// weight, so no extra factor is applied here.
1082 ///
1083 /// This is only valid for homogeneous systems (every row of dimension
1084 /// `d`), the only shape cross-row latent penalties are defined on; the
1085 /// flat-index convention `flat = i·d + j` matches every penalty's
1086 /// `latent_dim`/row-major contract.
1087 pub(crate) fn apply_cross_row_penalty_hessian(
1088 &self,
1089 v: ArrayView1<'_, f64>,
1090 out: &mut Array1<f64>,
1091 ) {
1092 for cross in &self.cross_row_penalties {
1093 assert_eq!(cross.target_t.len(), v.len());
1094 let hv =
1095 cross
1096 .penalty
1097 .psd_majorizer_hvp(cross.target_t.view(), cross.rho_local.view(), v);
1098 assert_eq!(hv.len(), out.len());
1099 for i in 0..out.len() {
1100 out[i] += hv[i];
1101 }
1102 }
1103 }
1104
1105 pub(crate) fn add_beta_penalty(
1106 &mut self,
1107 penalty: &AnalyticPenaltyKind,
1108 target_beta: ArrayView1<'_, f64>,
1109 rho_local: ArrayView1<'_, f64>,
1110 ) {
1111 let k = self.k;
1112 let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
1113 apply_analytic_penalty(
1114 penalty,
1115 target_beta,
1116 rho_local,
1117 k,
1118 hvp_columns,
1119 self,
1120 |sys, j, value| sys.gb[j] += value,
1121 |sys, j, value| {
1122 if sys.hbb.dim() == (k, k) {
1123 sys.hbb[[j, j]] += value;
1124 }
1125 if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
1126 hbb_diag[j] += value;
1127 }
1128 },
1129 |j, probe| probe[j] = 1.0,
1130 |sys, j, hv| {
1131 for i in 0..k {
1132 sys.hbb[[i, j]] += hv[i];
1133 }
1134 // Keep `hbb_diag` consistent with the dense `hbb` Hessian when
1135 // both are populated (the dense-allocated path + a later
1136 // `set_shared_beta_operator` install). The HVP probe for
1137 // column `j` returns the full Hessian column, whose `j`-th
1138 // entry is the diagonal contribution of this penalty. Without
1139 // this mirror, the Jacobi Schur preconditioner — which prefers
1140 // `hbb_diag` over `hbb`'s diagonal — would silently use a
1141 // stale diagonal for any Beta-tier analytic penalty that
1142 // exposes only an HVP (no `hessian_diag`).
1143 if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
1144 hbb_diag[j] += hv[j];
1145 }
1146 },
1147 );
1148 }
1149
1150 /// Schur-eliminate the per-row latent block and solve for `(Δt, Δβ, diag)`.
1151 ///
1152 /// This uses [`ArrowSolveOptions::automatic`]: BA dense RCS for
1153 /// `K <= 2000`, and Agarwal-style inexact Schur PCG above that size.
1154 /// Call [`ArrowSchurSystem::solve_with_options`] to force Square-Root BA
1155 /// or a specific inexact solve policy.
1156 ///
1157 /// Returns `(delta_t, delta_beta, PcgDiagnostics)` with `delta_t` flat
1158 /// row-major of length `N · d` and `delta_beta` of length `K`. The sign
1159 /// convention matches `solve_newton_direction_dense`: the returned
1160 /// increments satisfy the bordered system with RHS `[-g_t; -g_β]`, i.e.
1161 /// they are the *negated* solutions of the standard Newton-direction
1162 /// formulation. `PcgDiagnostics` is zero-valued for the Direct path and
1163 /// carries live counters (PCG iters, ridge escalations, residual) for
1164 /// InexactPCG.
1165 ///
1166 /// `ridge_t` and `ridge_beta` are nonnegative diagonal regularizers
1167 /// added to the latent and β blocks respectively before factorization
1168 /// — used by the LM damping outer wrapper to recover from near-singular
1169 /// inner steps. Pass `0.0` for both to obtain the unregularized
1170 /// Newton direction.
1171 pub fn solve(
1172 &self,
1173 ridge_t: f64,
1174 ridge_beta: f64,
1175 ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1176 let options = ArrowSolveOptions::automatic(self.k);
1177 solve_arrow_newton_step_core(self, ridge_t, ridge_beta, &options)
1178 }
1179
1180 /// Solve with the standard LM-style ridge escalation: if a per-row
1181 /// `H_tt + ridge_t·I` Cholesky pivot is non-PD, or the reduced Schur
1182 /// factor fails, geometrically grow both ridges and retry. This is the
1183 /// same Ceres-style proximal correction the Newton driver in
1184 /// `run_joint_fit_arrow_schur` performs around `solve`, lifted into the
1185 /// system itself so every entry point (predict OOS reconstruction,
1186 /// single-shot Newton refinement, …) is self-healing against the
1187 /// pathological per-row blocks produced by PCA-seeded latent
1188 /// coordinates on subset / new data — see #163 and #175.
1189 ///
1190 /// `ridge_t` / `ridge_beta` are the caller-nominal Tikhonov ridges; the
1191 /// escalation only adds extra damping on top of them when the factor
1192 /// fails. PCG / AdaptiveCorrection failures are left untouched because
1193 /// they are not factorization-recoverable.
1194 pub fn solve_with_lm_escalation(
1195 &self,
1196 ridge_t: f64,
1197 ridge_beta: f64,
1198 ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1199 let options = ArrowSolveOptions::automatic(self.k);
1200 solve_with_lm_escalation_inner(self, ridge_t, ridge_beta, &options)
1201 }
1202
1203 /// Solve with an explicit BA Schur mode, returning `(Δt, Δβ, PcgDiagnostics)`.
1204 ///
1205 /// [`ArrowSolverMode::Direct`] is the classic dense reduced-camera-system
1206 /// Cholesky path; [`ArrowSolverMode::SqrtBA`] forms the same dense system
1207 /// through Square-Root BA factors; [`ArrowSolverMode::InexactPCG`] runs
1208 /// inexact-step LM on the reduced system with Jacobi-preconditioned
1209 /// Steihaug-CG. `PcgDiagnostics` is zero-valued for Direct/SqrtBA and
1210 /// carries live counters for InexactPCG (iterations, matvec calls,
1211 /// preconditioner escalations, final relative residual, stopping reason).
1212 pub fn solve_with_options(
1213 &self,
1214 ridge_t: f64,
1215 ridge_beta: f64,
1216 options: &ArrowSolveOptions,
1217 ) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
1218 solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
1219 }
1220}
1221
/// Chunked Schur assembler that never retains all row cross-blocks.
///
/// Rows are obtained on demand from `row_builder` and folded chunk-by-chunk
/// into the `k × k` accumulator `s_acc` and the length-`k` reduced RHS
/// `rhs_acc`.
pub struct StreamingArrowSchur {
    /// Number of latent rows; chunk ranges handed to `accumulate_chunk` must
    /// lie within `0..n_rows`.
    pub n_rows: usize,
    /// Maximum per-row latent dim (upper bound for scratch buffers).
    pub d: usize,
    /// Per-row latent dims `row_dims[i] == rows[i].htt.nrows()`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets: `row_offsets[i]` is the start of row `i` in
    /// `delta_t`; `row_offsets[n_rows]` is the total `delta_t` length.
    pub row_offsets: Arc<[usize]>,
    /// Shared β-block dimension: `hbb` is `k × k` and `gb` has length `k`
    /// (both asserted by [`Self::new`]).
    pub k: usize,
    /// Chunk size, clamped to at least 1 by [`Self::new`].
    /// NOTE(review): presumably the number of rows per `accumulate_chunk`
    /// call — confirm against the driver loop.
    pub chunk_size: usize,
    /// `k × k` reduced Schur accumulator; `reset_accumulator` seeds it with
    /// `hbb + ridge_beta·I`.
    pub s_acc: Array2<f64>,
    /// Length-`k` reduced right-hand-side accumulator; zeroed by
    /// `reset_accumulator`.
    pub(crate) rhs_acc: Array1<f64>,
    /// Dense `k × k` β-block `H_ββ`.
    pub(crate) hbb: Array2<f64>,
    /// Length-`k` β gradient `g_β`.
    pub(crate) gb: Array1<f64>,
    /// Callback producing the [`ArrowRowBlock`] for a row index on demand.
    pub(crate) row_builder: StreamingArrowRowBuilder,
    /// Procedural cross-block operator `H_tβ^(i) x`. When present, the dense
    /// per-row `H_tβ` slabs are never materialized: `accumulate_chunk` and
    /// `back_substitute` probe this operator column-by-column to apply the
    /// cross-block, matching the Kronecker / matrix-free assembly path. When
    /// `None` (legacy dense BA callers), the per-row `row.htbeta` slab is used.
    pub(crate) htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Sparse adjoint of `htbeta_matvec`. When present, `row_htbeta` rebuilds
    /// the dense `(d_i × K)` cross-block by probing the transpose with `d_i`
    /// basis vectors — `O(d_i · m_i · p)` total, vs the `O(K · m_i · p)` cost
    /// of probing the forward operator with `K` basis vectors. Since
    /// `d_i ≪ K`, this is the per-row sparse apply that replaces the `O(K)`
    /// column-probe in the streaming reduced-Schur accumulation.
    pub(crate) htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// Lift the per-row κ rejection for evidence/log-det-only solves; see
    /// [`ArrowSolveOptions::tolerate_ill_conditioning`]. Set by [`Self::solve`]
    /// from the options; defaults to `false` so direct callers of
    /// [`Self::accumulate_chunk`] keep the full guard.
    pub(crate) tolerate_ill_conditioning: bool,
    /// Set when the source system carried an exact cross-row IBP source
    /// ([`IbpCrossRowSource`], #1038). The streaming chunked accumulator cannot
    /// hold the rank-`R` Woodbury correction chunk-locally — `U`'s columns span
    /// ALL rows, so the capacitance `I_R + D Uᵀ H₀'⁻¹ U` needs the per-row
    /// factors retained for a global `H₀'⁻¹U` back-solve, which is exactly the
    /// `(N·K)`-scale residency the streaming path exists to avoid. Rather than
    /// silently DROP the cross-row term (an inexact logdet that would desync
    /// from the dense-resident gradient), the streaming log-determinant errors
    /// loudly when this is set, forcing IBP-active fits onto the dense resident
    /// [`ArrowFactorCache::arrow_log_det`] path (which carries the exact
    /// Woodbury). See the #1038 streaming note.
    pub(crate) ibp_cross_row_active: bool,
    /// SAE manifold evidence-path per-row gauge deflation, copied from the
    /// source [`ArrowSchurSystem::row_gauge_deflation`] (#1273/#1377). When
    /// present, the streaming per-row factor MUST apply the SAME spectral
    /// discovery-and-deflation of an intrinsic-dimension-flat `H_tt^(i)`
    /// direction (eigenvalue → +1, ρ-independent `log 1 = 0` evidence) that the
    /// dense [`factor_blocks_for_system`] path applies, or the two routes report
    /// different log-determinants for the SAME system — the cross-route
    /// invariant `streaming_logdet == full_logdet` would break (the #1377
    /// regression: #1273 wired the deflation into the dense path only). `None`
    /// for every non-evidence caller, which keeps the strict non-PD refusal.
    pub(crate) row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
    /// The exact cross-row IBP source ([`IbpCrossRowSource`], #1038), cloned from
    /// the assembled [`ArrowSchurSystem`]. The bare streaming paths
    /// ([`Self::solve`] / [`Self::reduced_schur_and_log_det_tt`]) still REFUSE
    /// when this is present (they cannot carry the rank-`R` Woodbury), but the
    /// evidence-only [`Self::reduced_schur_log_det_tt_woodbury`] consumes it to
    /// accumulate the chunk-additive Woodbury building blocks `M0 = Uᵀ A⁻¹ U`
    /// and `W = Bᵀ A⁻¹ U` against the NO-SELF base `H₀'` (self term downdated),
    /// which the caller closes into the exact `log det(I_R + D Uᵀ H₀'⁻¹ U)` via
    /// [`streaming_cross_row_woodbury_log_det`].
    pub(crate) ibp_cross_row: Option<IbpCrossRowSource>,
}
1291
/// One chunk's contribution to the streaming cross-row IBP Woodbury (#1038).
///
/// The capacitance `C = I_R + D·M` with `M = Uᵀ H₀'⁻¹ U` is chunk-additive in
/// its two ingredients (`A = H₀'` is block-diagonal over rows, `U` is supported
/// per-row): `M = M0 + Wᵀ S⁻¹ W` where `M0 = Σ_i Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
/// `W = Σ_i Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`) accumulate row-by-row, and `S` is the final
/// GLOBAL reduced Schur. This carries one chunk's `M0`/`W` plus the per-atom
/// `D`; the caller sums them and closes the capacitance after the chunk loop.
///
/// Throughout, `R` is the Woodbury rank (one entry of `d` per atom) and `k`
/// is the width of the reduced β border.
#[derive(Debug, Clone)]
pub struct StreamingWoodburyChunk {
    /// `Σ_{i∈chunk} Uᵢᵀ Aᵢ⁻¹ Uᵢ`, `R×R`.
    pub m0: Array2<f64>,
    /// `Σ_{i∈chunk} Bᵢᵀ Aᵢ⁻¹ Uᵢ`, `k×R` (`k` = reduced β border).
    pub w: Array2<f64>,
    /// Per-atom `D`-coefficients `d_k = w·s'_k`, length `R`.
    pub d: Array1<f64>,
}
1309
1310impl std::fmt::Debug for StreamingArrowSchur {
1311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1312 f.debug_struct("StreamingArrowSchur")
1313 .field("n_rows", &self.n_rows)
1314 .field("d", &self.d)
1315 .field("k", &self.k)
1316 .field("chunk_size", &self.chunk_size)
1317 .finish_non_exhaustive()
1318 }
1319}
1320
1321impl StreamingArrowSchur {
1322 #[must_use]
1323 pub fn new(
1324 n_rows: usize,
1325 d: usize,
1326 row_dims: Arc<[usize]>,
1327 row_offsets: Arc<[usize]>,
1328 k: usize,
1329 hbb: Array2<f64>,
1330 gb: Array1<f64>,
1331 row_builder: StreamingArrowRowBuilder,
1332 chunk_size: usize,
1333 ) -> Self {
1334 assert_eq!(hbb.dim(), (k, k));
1335 assert_eq!(gb.len(), k);
1336 Self {
1337 n_rows,
1338 d,
1339 row_dims,
1340 row_offsets,
1341 k,
1342 chunk_size: chunk_size.max(1),
1343 s_acc: Array2::<f64>::zeros((k, k)),
1344 rhs_acc: Array1::<f64>::zeros(k),
1345 hbb,
1346 gb,
1347 row_builder,
1348 htbeta_matvec: None,
1349 htbeta_transpose_matvec: None,
1350 tolerate_ill_conditioning: false,
1351 ibp_cross_row_active: false,
1352 row_gauge_deflation: None,
1353 ibp_cross_row: None,
1354 }
1355 }
1356
1357 #[must_use]
1358 pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
1359 // When a Kronecker / matrix-free htbeta_matvec is installed, the dense
1360 // row.htbeta slabs may be zero-sized. Rather than materialize every
1361 // `(d × K)` slab (the very `(N·K)`-scale buffer the streaming path
1362 // exists to avoid), retain the procedural operator and probe it per row
1363 // inside `accumulate_chunk` / `back_substitute`. The row builder then
1364 // only carries the small `H_tt` / `g_t` blocks.
1365 let htbeta_matvec = sys.htbeta_matvec.clone();
1366 let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
1367 sys.rows
1368 .iter()
1369 .map(|row| ArrowRowBlock {
1370 htt: row.htt.clone(),
1371 htbeta: Array2::<f64>::zeros((0, 0)),
1372 gt: row.gt.clone(),
1373 })
1374 .collect()
1375 } else {
1376 sys.rows.clone()
1377 };
1378 let rows = Arc::new(rows);
1379 let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
1380 rows.get(row)
1381 .cloned()
1382 .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1383 reason: format!("streaming row {row} out of bounds"),
1384 })
1385 });
1386 // Materialize the dense β-block from the effective penalty operator so
1387 // the streaming accumulator stays correct when contributions live in a
1388 // structured `BetaPenaltyOp` (e.g. the SAE data-fit Gauss-Newton block,
1389 // represented as `G ⊗ I_p`) rather than the dense `hbb` accumulator.
1390 // When no `penalty_op` is installed this reduces to `hbb.clone()`.
1391 let hbb_dense = sys.effective_penalty_op().to_dense();
1392 let mut streaming = Self::new(
1393 sys.rows.len(),
1394 sys.d,
1395 Arc::clone(&sys.row_dims),
1396 Arc::clone(&sys.row_offsets),
1397 sys.k,
1398 hbb_dense,
1399 sys.gb.clone(),
1400 row_builder,
1401 chunk_size,
1402 );
1403 streaming.htbeta_matvec = htbeta_matvec;
1404 streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
1405 streaming.ibp_cross_row_active = sys.ibp_cross_row.is_some();
1406 streaming.ibp_cross_row = sys.ibp_cross_row.clone();
1407 // Carry the SAE evidence-path per-row gauge deflation so the streaming
1408 // per-row factor matches the dense `factor_blocks_for_system` exactly
1409 // (#1377): without it, a row with an intrinsic-dimension-flat `H_tt`
1410 // would deflate on the dense path but be refused / log-det-divergent on
1411 // the streaming path, breaking `streaming_logdet == full_logdet`.
1412 streaming.row_gauge_deflation = sys.row_gauge_deflation.clone();
1413 streaming
1414 }
1415
1416 /// Factor one streaming row's `H_tt^(i)`, applying the SAME per-row gauge /
1417 /// spectral deflation the dense [`factor_blocks_for_system`] path applies
1418 /// when this is the SAE manifold evidence path (an installed
1419 /// `row_gauge_deflation`). For every non-evidence caller this is exactly the
1420 /// generic [`factor_one_row`] (strict non-PD refusal), so PD blocks are
1421 /// bit-for-bit unchanged. Routing both the dense and streaming per-row
1422 /// factors through the identical recovery is what keeps their
1423 /// log-determinants identical (#1273/#1377).
1424 fn factor_row(
1425 &self,
1426 row: &ArrowRowBlock,
1427 ridge_t: f64,
1428 di: usize,
1429 row_idx: usize,
1430 ) -> Result<Array2<f64>, ArrowSchurError> {
1431 match self.row_gauge_deflation.as_ref() {
1432 Some(deflation) => factor_one_row_result(
1433 row,
1434 ridge_t,
1435 di,
1436 row_idx,
1437 self.tolerate_ill_conditioning,
1438 deflation.row(row_idx),
1439 // Evidence path: opt into spectral discovery of an
1440 // intrinsic-dimension-flat direction even when this row's
1441 // supplied gauge list is empty/non-spanning — matching the
1442 // `allow_spectral_deflation = true` the dense path passes.
1443 true,
1444 )
1445 .map(|result| result.factor),
1446 None => factor_one_row(row, ridge_t, di, row_idx, self.tolerate_ill_conditioning),
1447 }
1448 }
1449
1450 /// Build the `(di × k)` cross-block for `row_idx` on demand.
1451 ///
1452 /// When the sparse transpose adjoint is installed, probes it with `di`
1453 /// standard basis vectors — each yields a full `K`-row of `H_βt^(i)`
1454 /// (i.e. a row of the `(di × k)` block) via the sparse scatter, for
1455 /// `O(di · m_i · p)` total, far below the `O(K · m_i · p)` cost of probing
1456 /// the forward operator with `K` basis vectors when `di ≪ K`.
1457 ///
1458 /// When only the forward operator is installed (no adjoint), falls back to
1459 /// the `k`-column forward probe. Otherwise clones the dense `row.htbeta`
1460 /// slab.
1461 pub(crate) fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
1462 if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
1463 // Probe the adjoint: for each latent index c, scatter e_c to obtain
1464 // row c of the (di × k) block.
1465 let mut mat = Array2::<f64>::zeros((di, self.k));
1466 let mut e_c = Array1::<f64>::zeros(di);
1467 let mut beta_row = Array1::<f64>::zeros(self.k);
1468 for c in 0..di {
1469 e_c.fill(0.0);
1470 e_c[c] = 1.0;
1471 beta_row.fill(0.0);
1472 op_t(row_idx, e_c.view(), &mut beta_row);
1473 for a in 0..self.k {
1474 mat[[c, a]] = beta_row[a];
1475 }
1476 }
1477 return mat;
1478 }
1479 match self.htbeta_matvec.as_ref() {
1480 Some(op) => {
1481 let mut mat = Array2::<f64>::zeros((di, self.k));
1482 let mut e_a = Array1::<f64>::zeros(self.k);
1483 let mut col = Array1::<f64>::zeros(di);
1484 for a in 0..self.k {
1485 e_a.fill(0.0);
1486 e_a[a] = 1.0;
1487 col.fill(0.0);
1488 op(row_idx, e_a.view(), &mut col);
1489 for c in 0..di {
1490 mat[[c, a]] = col[c];
1491 }
1492 }
1493 mat
1494 }
1495 None => row.htbeta.clone(),
1496 }
1497 }
1498
1499 /// Move out the accumulated reduced Schur block `s_acc` and reduced RHS
1500 /// `rhs_acc`, leaving fresh zero buffers in their place.
1501 ///
1502 /// The reduced contribution is `s_acc = hbb − Σ_i H_βt^(i)(H_tt^(i))⁻¹H_tβ^(i)`
1503 /// (the β-block `hbb` seeded by `reset_accumulator`, minus the per-row
1504 /// reduction summed by `accumulate_chunk`) and
1505 /// `rhs_acc = +Σ_i H_βt^(i)(H_tt^(i))⁻¹g_t^(i)`. Used by external online
1506 /// drivers (e.g. the SAE streaming joint fit) that accumulate the reduced
1507 /// system across re-materialized chunk systems.
1508 #[must_use]
1509 pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
1510 let s = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1511 let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(self.k));
1512 (s, rhs)
1513 }
1514
1515 /// Reset the dense shared accumulator to `H_ββ + ridge_beta I`.
1516 pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
1517 if self.hbb.dim() != (self.k, self.k) {
1518 return Err(ArrowSchurError::SchurFactorFailed {
1519 reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
1520 });
1521 }
1522 self.s_acc.assign(&self.hbb);
1523 for j in 0..self.k {
1524 self.s_acc[[j, j]] += ridge_beta;
1525 self.rhs_acc[j] = 0.0;
1526 }
1527 Ok(())
1528 }
1529
    /// Accumulate rows `[start, end)` into the reduced RHS and Schur block.
    ///
    /// For every row in the range this rebuilds the row via `row_builder`,
    /// validates it, materializes its cross-block, factors `H_tt^(i)` (with
    /// `ridge_t`), and then
    ///
    /// * adds `H_βt^(i)(H_tt^(i))⁻¹ g_t^(i)` to `self.rhs_acc`, and
    /// * subtracts the row's Schur reduction from `self.s_acc` (formed from the
    ///   plain block solve for `Direct` / `InexactPCG`, or from the whitened
    ///   block for `SqrtBA`).
    ///
    /// The accumulators are NOT reset here: contributions are added on top of
    /// whatever `self.rhs_acc` / `self.s_acc` already hold.
    ///
    /// Large chunks (at least `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows) are fanned
    /// out over rayon unless the caller is already running on a rayon worker;
    /// otherwise the rows are processed sequentially.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowSchurError::SchurFactorFailed`] when the range is not
    /// inside `0..n_rows`, and propagates any error from the row builder, row
    /// validation, or the per-row factorization.
    pub fn accumulate_chunk(
        &mut self,
        start: usize,
        end: usize,
        ridge_t: f64,
        mode: ArrowSolverMode,
    ) -> Result<(), ArrowSchurError> {
        // Reject inverted or out-of-range chunks before touching any state.
        if start > end || end > self.n_rows {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
                    self.n_rows
                ),
            });
        }
        let backend = CpuBatchedBlockSolver;
        let k = self.k;
        // Per-row factor + two block solves + a `k×k` GEMM subtract is the whole
        // assembly cost at the SAE LLM shape (#1017); the rows are independent so
        // the chunk fans across cores. Stay sequential for the handful-of-rows
        // non-SAE callers, or when already inside a rayon worker (the topology
        // race fans candidates with `run_topology_race_parallel`) to avoid
        // nested-rayon oversubscription — the same gate `schur_matvec` uses.
        let parallel = (end - start) >= SCHUR_MATVEC_PARALLEL_ROW_MIN
            && rayon::current_thread_index().is_none();
        if parallel {
            use rayon::prelude::*;
            // Rows per worker-private partial accumulator.
            const CHUNK: usize = 64;
            // Bind `&self` so the per-row body borrows only the immutable
            // streaming state. Each row contributes `+H_βt^(i)(H_tt^(i))⁻¹ g_t^(i)`
            // (length `k`) to the reduced RHS and `−H_βt^(i)(H_tt^(i))⁻¹ H_tβ^(i)`
            // (`k×k`) to the reduced Schur complement; both are written INTO a
            // worker-private `(rhs_part, s_part)` pair so the chunk partials fold
            // back in chunk order — bit-identical run-to-run regardless of thread
            // scheduling (the #1017 verification gate). The chunk fold reassociates
            // the row sum relative to serial, so the criterion ranking is stable
            // only up to that f64 margin; a near-tie winner inside the margin can
            // flip — not an exact no-move guarantee (#1211).
            let this: &Self = self;
            let row_into = |row_idx: usize,
                            rhs_part: &mut Array1<f64>,
                            s_part: &mut Array2<f64>|
             -> Result<(), ArrowSchurError> {
                let row = (this.row_builder)(row_idx)?;
                let di = row.htt.nrows();
                this.validate_row(row_idx, &row)?;
                let htbeta = this.row_htbeta(row_idx, &row, di);
                let factor = this.factor_row(&row, ridge_t, di, row_idx)?;
                // v = (H_tt^(i))⁻¹ g_t^(i); then rhs_part += H_βt^(i) v.
                let v = backend.solve_block_vector(factor.view(), row.gt.view());
                for c in 0..di {
                    let vc = v[c];
                    // Exact zeros contribute nothing; skip the inner loop.
                    if vc == 0.0 {
                        continue;
                    }
                    for a in 0..k {
                        rhs_part[a] += htbeta[[c, a]] * vc;
                    }
                }
                match mode {
                    // InexactPCG differs from Direct only in how the *reduced*
                    // system is solved, not in how it is assembled, so it shares
                    // the dense Schur subtraction here (see the serial branch).
                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(s_part, &htbeta, &solved);
                    }
                    ArrowSolverMode::SqrtBA => {
                        let whitened =
                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(s_part, &whitened, &whitened);
                    }
                }
                Ok(())
            };
            let partials: Vec<(Array1<f64>, Array2<f64>)> = (start..end)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    // Each chunk starts from zero partials, so the fold below
                    // only ever adds row contributions (never the seed).
                    let mut rhs_part = Array1::<f64>::zeros(k);
                    let mut s_part = Array2::<f64>::zeros((k, k));
                    for i in idxs {
                        row_into(i, &mut rhs_part, &mut s_part)?;
                    }
                    Ok::<_, ArrowSchurError>((rhs_part, s_part))
                })
                .collect::<Result<Vec<_>, _>>()?;
            // Deterministic ordered reduction: fold chunk partials left-to-right.
            // `block_gemm_subtract` already subtracted into each `s_part`, so the
            // partials carry the negative Schur contribution; add them in.
            for (rhs_part, s_part) in &partials {
                for a in 0..k {
                    self.rhs_acc[a] += rhs_part[a];
                }
                self.s_acc += s_part;
            }
        } else {
            // Serial path accumulates DIRECTLY into the running `self.{rhs,s}_acc`
            // (which carry the `reset_accumulator` seed `H_ββ + ridge·I`), exactly
            // as before — bit-for-bit unchanged for the handful-of-rows callers.
            for row_idx in start..end {
                let row = (self.row_builder)(row_idx)?;
                let di = row.htt.nrows();
                self.validate_row(row_idx, &row)?;
                let htbeta = self.row_htbeta(row_idx, &row, di);
                let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
                // v = (H_tt^(i))⁻¹ g_t^(i); then rhs_acc += H_βt^(i) v.
                let v = backend.solve_block_vector(factor.view(), row.gt.view());
                for c in 0..di {
                    let vc = v[c];
                    if vc == 0.0 {
                        continue;
                    }
                    for a in 0..k {
                        self.rhs_acc[a] += htbeta[[c, a]] * vc;
                    }
                }
                match mode {
                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
                    }
                    ArrowSolverMode::SqrtBA => {
                        let whitened =
                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
                    }
                }
            }
        }
        Ok(())
    }
1661
1662 /// Compute the exact arrow Hessian log-determinant by accumulating the
1663 /// reduced Schur complement in row chunks, without retaining the full set
1664 /// of per-row Cholesky factors.
1665 ///
1666 /// This is the streaming analogue of [`ArrowFactorCache::arrow_log_det`]:
1667 ///
1668 /// ```text
1669 /// log|H| = Σ_i log|H_tt^(i)| + log|H_ββ - Σ_i H_βt^(i) H_tt^(i)⁻¹ H_tβ^(i)|.
1670 /// ```
1671 ///
1672 /// The same row builder and procedural `H_tβ` callbacks used by the
1673 /// streaming Newton solve are consumed here, so callers can score REML
1674 /// evidence without materialising the full `(N × q × K)` cross block or
1675 /// the full list of row factors.
1676 pub fn reduced_schur_and_log_det_tt(
1677 &mut self,
1678 ridge_t: f64,
1679 ridge_beta: f64,
1680 options: &ArrowSolveOptions,
1681 ) -> Result<(f64, Array2<f64>), ArrowSchurError> {
1682 if self.ibp_cross_row_active {
1683 return Err(ArrowSchurError::SchurFactorFailed {
1684 reason: "streaming arrow log-det cannot carry the exact cross-row IBP \
1685 Woodbury correction (#1038): U's columns span all rows, so the \
1686 rank-R capacitance needs the per-row factors retained — the very \
1687 (N·K) residency the streaming path avoids. Route IBP-active fits \
1688 through the dense resident ArrowFactorCache::arrow_log_det instead."
1689 .to_string(),
1690 });
1691 }
1692 self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1693 self.reset_accumulator(ridge_beta)?;
1694 let backend = CpuBatchedBlockSolver;
1695 let mut log_det_tt = 0.0_f64;
1696 for start in (0..self.n_rows).step_by(self.chunk_size) {
1697 let end = (start + self.chunk_size).min(self.n_rows);
1698 for row_idx in start..end {
1699 let row = (self.row_builder)(row_idx)?;
1700 let di = row.htt.nrows();
1701 self.validate_row(row_idx, &row)?;
1702 let htbeta = self.row_htbeta(row_idx, &row, di);
1703 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1704 for axis in 0..di {
1705 log_det_tt += 2.0 * factor[[axis, axis]].ln();
1706 }
1707 match options.mode {
1708 ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1709 let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1710 backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1711 }
1712 ArrowSolverMode::SqrtBA => {
1713 let whitened =
1714 backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1715 backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1716 }
1717 }
1718 }
1719 }
1720 symmetrize_upper_from_lower(&mut self.s_acc);
1721 let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1722 Ok((log_det_tt, schur))
1723 }
1724
1725 /// As [`Self::reduced_schur_and_log_det_tt`], but ALSO accumulates the exact
1726 /// cross-row IBP Woodbury building blocks (#1038) when this streaming system
1727 /// carried an [`IbpCrossRowSource`].
1728 ///
1729 /// Mirrors the dense `factor_blocks_for_system` + [`CrossRowWoodbury::build`]
1730 /// path exactly:
1731 ///
1732 /// * the per-row logit self term `Σ_k d_k·z'_ik²`
1733 /// ([`IbpCrossRowSource::self_term_downdate`]) is subtracted from each
1734 /// `H_tt^(i)` BEFORE factoring, so the factored base — and therefore the
1735 /// returned `log_det_tt`/Schur — is the NO-SELF `H₀'` (the dense path
1736 /// factors `H₀'` too, then layers the rank-`R` Woodbury);
1737 /// * for each row it forms `Aᵢ⁻¹ Uᵢ` (one block solve against the row's
1738 /// `H₀'` factor, `Uᵢ` the J-weighted atom-column indicator) and adds the
1739 /// row's contributions to `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`) and
1740 /// `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`, `Bᵢ = H_tβ^(i)`).
1741 ///
1742 /// The caller sums `M0`/`W`/`log_det_tt`/Schur over chunks and closes the
1743 /// capacitance `M = M0 + Wᵀ S⁻¹ W`, `log det(I_R + D·M)` against the GLOBAL
1744 /// reduced Schur `S` via [`streaming_cross_row_woodbury_log_det`]. This is
1745 /// the exact `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`
1746 /// the dense [`ArrowFactorCache::arrow_log_det`] returns — the streaming and
1747 /// dense evidence then optimize the SAME REML objective (#1225).
1748 ///
1749 /// When no IBP source is present this delegates to
1750 /// [`Self::reduced_schur_and_log_det_tt`] and returns `None` woodbury, so
1751 /// every non-IBP (softmax / JumpReLU) caller is bit-for-bit unchanged.
1752 pub fn reduced_schur_log_det_tt_woodbury(
1753 &mut self,
1754 ridge_t: f64,
1755 ridge_beta: f64,
1756 options: &ArrowSolveOptions,
1757 ) -> Result<(f64, Array2<f64>, Option<StreamingWoodburyChunk>), ArrowSchurError> {
1758 let Some(source) = self.ibp_cross_row.clone() else {
1759 // Temporarily clear the refusal flag so the shared bare path runs;
1760 // it is `false` whenever the source is absent, so this is a no-op.
1761 let (log_det_tt, schur) =
1762 self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1763 return Ok((log_det_tt, schur, None));
1764 };
1765 let r = source.r;
1766 let total_len = self.row_offsets[self.n_rows];
1767 let down = source.self_term_downdate(total_len);
1768 // Group the sparse `U` entries `(global_t_index, atom, z'_ik)` by row as
1769 // `(local_slot, atom, z)`. `row_offsets` is strictly ascending, so the
1770 // owning row of a global index is located by binary search.
1771 let mut row_entries: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); self.n_rows];
1772 for &(g, atom, z) in &source.entries {
1773 let i = match self.row_offsets.binary_search(&g) {
1774 Ok(idx) => idx,
1775 Err(idx) => idx - 1,
1776 };
1777 let slot = g - self.row_offsets[i];
1778 row_entries[i].push((slot, atom, z));
1779 }
1780 self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1781 self.reset_accumulator(ridge_beta)?;
1782 let backend = CpuBatchedBlockSolver;
1783 let mut log_det_tt = 0.0_f64;
1784 let mut m0 = Array2::<f64>::zeros((r, r));
1785 let mut w = Array2::<f64>::zeros((self.k, r));
1786 for start in (0..self.n_rows).step_by(self.chunk_size) {
1787 let end = (start + self.chunk_size).min(self.n_rows);
1788 for row_idx in start..end {
1789 let mut row = (self.row_builder)(row_idx)?;
1790 let di = row.htt.nrows();
1791 self.validate_row(row_idx, &row)?;
1792 // Downdate the per-row logit self term so the factored base is
1793 // `H₀'` (the dense path downdates the SAME `down[base + j]`).
1794 let base = self.row_offsets[row_idx];
1795 for j in 0..di {
1796 row.htt[[j, j]] -= down[base + j];
1797 }
1798 let htbeta = self.row_htbeta(row_idx, &row, di);
1799 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1800 for axis in 0..di {
1801 log_det_tt += 2.0 * factor[[axis, axis]].ln();
1802 }
1803 match options.mode {
1804 ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
1805 let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
1806 backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
1807 }
1808 ArrowSolverMode::SqrtBA => {
1809 let whitened =
1810 backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
1811 backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
1812 }
1813 }
1814 let entries = &row_entries[row_idx];
1815 if !entries.is_empty() {
1816 // `Uᵢ`: `di × R`, column `atom` carries `z'_ik` at its logit
1817 // slot. `Aᵢ⁻¹ Uᵢ` via the row's `H₀'` Cholesky.
1818 let mut u_local = Array2::<f64>::zeros((di, r));
1819 for &(slot, atom, z) in entries {
1820 u_local[[slot, atom]] += z;
1821 }
1822 let ainv_u = backend.solve_block_matrix(factor.view(), u_local.view());
1823 // `M0 += Uᵢᵀ Aᵢ⁻¹ Uᵢ` (`R×R`); `W += Bᵢᵀ Aᵢ⁻¹ Uᵢ` (`k×R`).
1824 m0 += &u_local.t().dot(&ainv_u);
1825 w += &htbeta.t().dot(&ainv_u);
1826 }
1827 }
1828 }
1829 symmetrize_upper_from_lower(&mut self.s_acc);
1830 let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
1831 Ok((
1832 log_det_tt,
1833 schur,
1834 Some(StreamingWoodburyChunk {
1835 m0,
1836 w,
1837 d: source.d.clone(),
1838 }),
1839 ))
1840 }
1841
1842 pub fn reduced_schur_log_det(
1843 schur: &Array2<f64>,
1844 options: &ArrowSolveOptions,
1845 ) -> Result<f64, ArrowSchurError> {
1846 let rhs = Array1::<f64>::zeros(schur.nrows());
1847 let trust_metric_weights = None;
1848 let (delta, schur_factor, diag) =
1849 solve_dense_reduced_system(schur, &rhs, options, trust_metric_weights)?;
1850 if delta.len() != schur.nrows() || diag.iterations != 0 {
1851 return Err(ArrowSchurError::SchurFactorFailed {
1852 reason: "streaming log-det reduced solve returned incoherent diagnostics"
1853 .to_string(),
1854 });
1855 }
1856 let schur_factor = schur_factor.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
1857 reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
1858 })?;
1859 let mut log_det_schur = 0.0_f64;
1860 for axis in 0..schur_factor.nrows() {
1861 log_det_schur += 2.0 * schur_factor[[axis, axis]].ln();
1862 }
1863 Ok(log_det_schur)
1864 }
1865
1866 pub fn exact_arrow_log_det(
1867 &mut self,
1868 ridge_t: f64,
1869 ridge_beta: f64,
1870 options: &ArrowSolveOptions,
1871 ) -> Result<f64, ArrowSchurError> {
1872 let (log_det_tt, schur) =
1873 self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
1874 Ok(log_det_tt + Self::reduced_schur_log_det(&schur, options)?)
1875 }
1876
1877 pub fn solve(
1878 &mut self,
1879 ridge_t: f64,
1880 ridge_beta: f64,
1881 options: &ArrowSolveOptions,
1882 ) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
1883 if self.ibp_cross_row_active {
1884 return Err(ArrowSchurError::SchurFactorFailed {
1885 reason: "streaming arrow solve cannot carry the exact cross-row IBP \
1886 Woodbury correction (#1038); route IBP-active fits through the \
1887 dense resident solve_arrow_newton_step_with_options instead."
1888 .to_string(),
1889 });
1890 }
1891 // Propagate the evidence/log-det ill-conditioning tolerance to the
1892 // per-row factor calls inside `accumulate_chunk` / `back_substitute`,
1893 // which take their stable public signatures. Direct callers of
1894 // `accumulate_chunk` keep the conservative default (`false`, full guard).
1895 self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
1896 self.reset_accumulator(ridge_beta)?;
1897 for start in (0..self.n_rows).step_by(self.chunk_size) {
1898 let end = (start + self.chunk_size).min(self.n_rows);
1899 self.accumulate_chunk(start, end, ridge_t, options.mode)?;
1900 }
1901 for j in 0..self.k {
1902 self.rhs_acc[j] -= self.gb[j];
1903 }
1904 symmetrize_upper_from_lower(&mut self.s_acc);
1905 let trust_metric_weights = None;
1906 let (delta_beta, schur_factor, _diag) =
1907 solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
1908 let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
1909 Ok((delta_t, delta_beta, schur_factor))
1910 }
1911
1912 pub(crate) fn back_substitute(
1913 &self,
1914 ridge_t: f64,
1915 delta_beta: ArrayView1<'_, f64>,
1916 ) -> Result<Array1<f64>, ArrowSchurError> {
1917 let backend = CpuBatchedBlockSolver;
1918 // Total delta_t length = row_offsets[n_rows].
1919 let total_len = self.row_offsets[self.n_rows];
1920 let mut delta_t = Array1::<f64>::zeros(total_len);
1921 // Each row's back-solve `Δt_i = -(H_tt^(i))⁻¹(g_t^(i) + H_tβ^(i)Δβ)`
1922 // writes a DISJOINT segment `delta_t[row_base .. row_base+di]` — no
1923 // cross-row reduction, so this is embarrassingly parallel and the scatter
1924 // is bit-identical regardless of which thread produced each segment (the
1925 // #1017 verification gate). At the SAE LLM shape (`n` in the thousands)
1926 // the per-row factor + solve is the whole cost; below the threshold, or
1927 // when already inside a rayon worker (the topology race fans candidates
1928 // with `run_topology_race_parallel`), stay sequential to avoid
1929 // nested-rayon oversubscription — the same guard `schur_matvec` uses.
1930 let parallel =
1931 self.n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
1932 if parallel {
1933 use rayon::prelude::*;
1934 const CHUNK: usize = 64;
1935 // Per-row body: factor, form the RHS, solve, return `-(dt_i)`.
1936 let row_solve = |row_idx: usize| -> Result<(usize, Array1<f64>), ArrowSchurError> {
1937 let row = (self.row_builder)(row_idx)?;
1938 let di = row.htt.nrows();
1939 self.validate_row(row_idx, &row)?;
1940 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1941 let mut htbeta_delta = Array1::<f64>::zeros(di);
1942 if let Some(op) = self.htbeta_matvec.as_ref() {
1943 op(row_idx, delta_beta, &mut htbeta_delta);
1944 } else {
1945 for c in 0..di {
1946 let mut acc = 0.0_f64;
1947 for a in 0..self.k {
1948 acc += row.htbeta[[c, a]] * delta_beta[a];
1949 }
1950 htbeta_delta[c] = acc;
1951 }
1952 }
1953 let mut rhs = Array1::<f64>::zeros(di);
1954 for c in 0..di {
1955 rhs[c] = row.gt[c] + htbeta_delta[c];
1956 }
1957 let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
1958 let mut neg = Array1::<f64>::zeros(di);
1959 for c in 0..di {
1960 neg[c] = -dt_i[c];
1961 }
1962 Ok((self.row_offsets[row_idx], neg))
1963 };
1964 // Collect per-row segments under rayon, then scatter into the disjoint
1965 // slices. Errors are surfaced via `collect::<Result<…>>`.
1966 let segments: Vec<(usize, Array1<f64>)> = (0..self.n_rows)
1967 .into_par_iter()
1968 .chunks(CHUNK)
1969 .map(|idxs| {
1970 idxs.into_iter()
1971 .map(&row_solve)
1972 .collect::<Result<Vec<_>, _>>()
1973 })
1974 .collect::<Result<Vec<_>, _>>()?
1975 .into_iter()
1976 .flatten()
1977 .collect();
1978 for (base, seg) in &segments {
1979 for (c, &v) in seg.iter().enumerate() {
1980 delta_t[base + c] = v;
1981 }
1982 }
1983 } else {
1984 let mut rhs = Array1::<f64>::zeros(self.d);
1985 for start in (0..self.n_rows).step_by(self.chunk_size) {
1986 let end = (start + self.chunk_size).min(self.n_rows);
1987 for row_idx in start..end {
1988 let row = (self.row_builder)(row_idx)?;
1989 let di = row.htt.nrows();
1990 self.validate_row(row_idx, &row)?;
1991 let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
1992 // `H_tβ^(i) Δβ`: route through the procedural operator when
1993 // present (no dense slab), else through the dense slab.
1994 let mut htbeta_delta = Array1::<f64>::zeros(di);
1995 if let Some(op) = self.htbeta_matvec.as_ref() {
1996 op(row_idx, delta_beta, &mut htbeta_delta);
1997 } else {
1998 for c in 0..di {
1999 let mut acc = 0.0_f64;
2000 for a in 0..self.k {
2001 acc += row.htbeta[[c, a]] * delta_beta[a];
2002 }
2003 htbeta_delta[c] = acc;
2004 }
2005 }
2006 for c in 0..di {
2007 rhs[c] = row.gt[c] + htbeta_delta[c];
2008 }
2009 let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
2010 let row_base = self.row_offsets[row_idx];
2011 for c in 0..di {
2012 delta_t[row_base + c] = -dt_i[c];
2013 }
2014 }
2015 }
2016 }
2017 Ok(delta_t)
2018 }
2019
2020 pub(crate) fn validate_row(
2021 &self,
2022 row_idx: usize,
2023 row: &ArrowRowBlock,
2024 ) -> Result<(), ArrowSchurError> {
2025 let expected_di = if row_idx < self.row_dims.len() {
2026 self.row_dims[row_idx]
2027 } else {
2028 self.d
2029 };
2030 let actual_di = row.htt.nrows();
2031 if actual_di != expected_di || row.htt.ncols() != expected_di {
2032 return Err(ArrowSchurError::PerRowFactorFailed {
2033 row: row_idx,
2034 reason: format!(
2035 "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
2036 row.htt.dim(),
2037 ),
2038 });
2039 }
2040 // The dense `H_tβ` slab is only validated when no procedural operator is
2041 // installed; with `htbeta_matvec` the slab is intentionally zero-sized
2042 // and the cross-block is probed in `row_htbeta`.
2043 if self.htbeta_matvec.is_none() && row.htbeta.dim() != (expected_di, self.k) {
2044 return Err(ArrowSchurError::SchurFactorFailed {
2045 reason: format!(
2046 "streaming row H_tβ shape {:?} != ({expected_di}, {})",
2047 row.htbeta.dim(),
2048 self.k
2049 ),
2050 });
2051 }
2052 if row.gt.len() != expected_di {
2053 return Err(ArrowSchurError::PerRowFactorFailed {
2054 row: row_idx,
2055 reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
2056 });
2057 }
2058 Ok::<(), _>(())
2059 }
2060}
2061
2062pub(crate) fn apply_analytic_penalty<S, G, D, P, H>(
2063 penalty: &AnalyticPenaltyKind,
2064 target: ArrayView1<'_, f64>,
2065 rho_local: ArrayView1<'_, f64>,
2066 expected_target_len: usize,
2067 hvp_columns: usize,
2068 scatter_target: &mut S,
2069 mut grad_scatter: G,
2070 mut diag_scatter: D,
2071 seed_hvp_probe: P,
2072 mut hvp_column_scatter: H,
2073) where
2074 G: FnMut(&mut S, usize, f64),
2075 D: FnMut(&mut S, usize, f64),
2076 P: Fn(usize, &mut Array1<f64>),
2077 H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
2078{
2079 assert_eq!(target.len(), expected_target_len);
2080
2081 let grad = penalty.grad_target(target, rho_local);
2082 for index in 0..expected_target_len {
2083 grad_scatter(scatter_target, index, grad[index]);
2084 }
2085
2086 // The scattered curvature lands in the arrow-Schur `H_tt` / `H_ββ` blocks,
2087 // which are Cholesky-factored (with LM ridge escalation) as the Newton /
2088 // PIRLS curvature operator and must therefore stay PSD. Nonconvex
2089 // sparsifiers (log sparsity, JumpReLU) have an *indefinite* exact Hessian
2090 // that would destroy that positive-definiteness, so we scatter the PSD
2091 // majorizer here — never the exact `hessian_diag` / `hvp`. For convex
2092 // penalties the majorizer equals the exact Hessian (the trait default
2093 // delegates), so this is exact for them. Exact-derivative consumers (the
2094 // outer objective Hessian) use `hessian_diag` / `hvp` directly elsewhere.
2095 if let Some(diag) = penalty.psd_majorizer_diag(target, rho_local) {
2096 assert_eq!(diag.len(), expected_target_len);
2097 for index in 0..expected_target_len {
2098 diag_scatter(scatter_target, index, diag[index]);
2099 }
2100 return;
2101 }
2102
2103 let mut probe = Array1::<f64>::zeros(expected_target_len);
2104 for column in 0..hvp_columns {
2105 probe.fill(0.0);
2106 seed_hvp_probe(column, &mut probe);
2107 let hv = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
2108 hvp_column_scatter(scatter_target, column, hv.view());
2109 }
2110}
2111
/// Report whether `penalty` is row-block diagonal.
///
/// Thin free-function wrapper that forwards to
/// `AnalyticPenaltyKind::is_row_block_diagonal` and returns its answer
/// unchanged.
pub(crate) fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
    penalty.is_row_block_diagonal()
}
2115
/// Per-row + Schur Cholesky factor cache produced by
/// [`solve_arrow_newton_step_with_options`]. Consumed downstream by the IFT warm-start
/// predictor in `crate::persistent_warm_start`: when the outer
/// loop perturbs `(β, ρ)` by a small amount, the new Newton step can be
/// predicted by re-using these factors against a refreshed RHS, saving
/// the dominant `O(N d³ + K³)` factorization cost.
///
/// NOTE(review): the paragraph above reads like a description of the whole
/// factor cache; this type itself only stores a sequence of square blocks
/// packed into one shared flat buffer (see `from_blocks` / `factor`) —
/// confirm the intended wording.
///
/// All three fields are reference-counted slices, so cloning the slab clones
/// the handles rather than the numeric data.
#[derive(Clone)]
pub struct ArrowFactorSlab {
    /// Every block's elements, concatenated in block order; each block is laid
    /// out in the order `from_blocks` iterates it and `factor` reshapes it.
    pub(crate) data: Arc<[f64]>,
    /// `offsets[i]..offsets[i + 1]` is block `i`'s range inside `data`
    /// (one more entry than there are blocks, starting at `0`).
    pub(crate) offsets: Arc<[usize]>,
    /// Side length of each square block; `dims.len()` is the block count.
    pub(crate) dims: Arc<[usize]>,
}
2128
2129impl ArrowFactorSlab {
2130 pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
2131 let mut data = Vec::new();
2132 let mut offsets = Vec::with_capacity(blocks.len() + 1);
2133 let mut dims = Vec::with_capacity(blocks.len());
2134 offsets.push(0);
2135 for block in blocks {
2136 let (rows, cols) = block.dim();
2137 assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
2138 dims.push(rows);
2139 data.extend(block.iter().copied());
2140 offsets.push(data.len());
2141 }
2142 Self {
2143 data: data.into(),
2144 offsets: offsets.into(),
2145 dims: dims.into(),
2146 }
2147 }
2148
2149 pub fn len(&self) -> usize {
2150 self.dims.len()
2151 }
2152
2153 pub fn is_empty(&self) -> bool {
2154 self.dims.is_empty()
2155 }
2156
2157 pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
2158 let dim = self.dims[row];
2159 let range = self.offsets[row]..self.offsets[row + 1];
2160 ArrayView2::from_shape((dim, dim), &self.data[range])
2161 .expect("ArrowFactorSlab row offset/dim invariant violated")
2162 }
2163
2164 pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
2165 (0..self.len()).map(|row| self.factor(row))
2166 }
2167}
2168
2169impl std::fmt::Debug for ArrowFactorSlab {
2170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2171 f.debug_struct("ArrowFactorSlab")
2172 .field("rows", &self.len())
2173 .field("values", &self.data.len())
2174 .finish()
2175 }
2176}
2177
/// Where the undamped per-row factors live.
///
/// NOTE(review): semantics inferred from the variant names only — presumably
/// "undamped" means factored without the LM ridge; confirm against the code
/// that constructs this enum.
#[derive(Clone)]
pub enum ArrowUndampedFactors {
    /// No separate storage; presumably the damped factors are to be used as
    /// the undamped ones — TODO confirm.
    SameAsDamped,
    /// A dedicated slab of factors owned by this value.
    Owned(ArrowFactorSlab),
}
2183
2184impl std::fmt::Debug for ArrowUndampedFactors {
2185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2186 match self {
2187 Self::SameAsDamped => f.write_str("SameAsDamped"),
2188 Self::Owned(factors) => f.debug_tuple("Owned").field(&factors.len()).finish(),
2189 }
2190 }
2191}
2192
2193/// Apply `H_tβ^(row) · x` for one row, writing into `out` (length `d`).
2194///
2195/// Sums the installed matrix-free operator, when present, and any correctly
2196/// shaped dense `row.htbeta` slab. This lets structured data-fit rows coexist
2197/// with dense analytic-penalty cross blocks on the same row.
2198pub(crate) fn sys_htbeta_apply_row(
2199 sys: &ArrowSchurSystem,
2200 row_idx: usize,
2201 row: &ArrowRowBlock,
2202 x: ArrayView1<'_, f64>,
2203 out: &mut Array1<f64>,
2204) {
2205 out.fill(0.0);
2206 if let Some(op) = sys.htbeta_matvec.as_ref() {
2207 op(row_idx, x, out);
2208 }
2209 if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2210 && row.htbeta.dim() == (out.len(), sys.k)
2211 {
2212 let di = row.htbeta.nrows();
2213 for c in 0..di {
2214 let mut acc = 0.0_f64;
2215 for a in 0..sys.k {
2216 acc += row.htbeta[[c, a]] * x[a];
2217 }
2218 out[c] += acc;
2219 }
2220 }
2221}
2222
2223/// Accumulate `H_βt^(row) · v` into `out` (length `k`).
2224///
2225/// `out[a] += Σ_c H_tβ^(row)[c, a] · v[c]`
2226///
2227/// Sums the installed matrix-free operator, when present, and any correctly
2228/// shaped dense `row.htbeta` slab.
2229pub(crate) fn sys_htbeta_accumulate_transpose(
2230 sys: &ArrowSchurSystem,
2231 row_idx: usize,
2232 row: &ArrowRowBlock,
2233 v: ArrayView1<'_, f64>,
2234 out: &mut Array1<f64>,
2235) {
2236 if let Some(op) = sys.htbeta_matvec.as_ref() {
2237 htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
2238 }
2239 if (sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none())
2240 && row.htbeta.dim() == (v.len(), sys.k)
2241 {
2242 let di = row.htbeta.nrows();
2243 for c in 0..di {
2244 let vc = v[c];
2245 if vc == 0.0 {
2246 continue;
2247 }
2248 for a in 0..sys.k {
2249 out[a] += row.htbeta[[c, a]] * vc;
2250 }
2251 }
2252 }
2253}
2254
2255/// Materialize the dense `(di, k)` cross-block for one row.
2256///
2257/// Materializes the sum of the installed matrix-free operator and any correctly
2258/// shaped dense slab on the row.
2259pub(crate) fn sys_htbeta_materialize_row(
2260 sys: &ArrowSchurSystem,
2261 row_idx: usize,
2262 row: &ArrowRowBlock,
2263) -> Result<Array2<f64>, ArrowSchurError> {
2264 let di = sys.row_dims[row_idx];
2265 let k = sys.k;
2266 let use_dense = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
2267 let mut mat = if use_dense && row.htbeta.dim() == (di, k) {
2268 row.htbeta.clone()
2269 } else {
2270 Array2::<f64>::zeros((di, k))
2271 };
2272 if let Some(op) = sys.htbeta_matvec.as_ref() {
2273 let mut e_a = Array1::<f64>::zeros(k);
2274 let mut col = Array1::<f64>::zeros(di);
2275 for a in 0..k {
2276 e_a.fill(0.0);
2277 e_a[a] = 1.0;
2278 col.fill(0.0);
2279 op(row_idx, e_a.view(), &mut col);
2280 for c in 0..di {
2281 mat[[c, a]] += col[c];
2282 }
2283 }
2284 } else if use_dense && row.htbeta.dim() != (di, k) {
2285 return Err(ArrowSchurError::SchurFactorFailed {
2286 reason: format!(
2287 "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
2288 row.htbeta.dim()
2289 ),
2290 });
2291 }
2292 Ok(mat)
2293}
2294
2295/// Probe each column of `H_tβ^(row)` by applying the operator to `e_a` and
2296/// dotting the result with `v`. Accumulates into `out[a]` for all `a in 0..k`.
2297///
2298/// `out[a] += (H_tβ^(row) e_a) · v = H_βt^(row)[a, :] · v`
2299pub(crate) fn htbeta_probe_transpose(
2300 row: usize,
2301 op: &RowHtbetaMatvec,
2302 v: ArrayView1<'_, f64>,
2303 out: &mut Array1<f64>,
2304 d: usize,
2305 k: usize,
2306) {
2307 let mut e_a = Array1::<f64>::zeros(k);
2308 let mut col_a = Array1::<f64>::zeros(d);
2309 for a in 0..k {
2310 e_a.fill(0.0);
2311 e_a[a] = 1.0;
2312 col_a.fill(0.0);
2313 op(row, e_a.view(), &mut col_a);
2314 let mut acc = 0.0_f64;
2315 for c in 0..d {
2316 acc += col_a[c] * v[c];
2317 }
2318 out[a] += acc;
2319 }
2320}
2321
/// Cached access to the per-row cross-blocks `H_tβ^(i)`.
///
/// Every variant carries an `estimated_bytes` figure.
/// NOTE(review): presumably the estimated size of the dense cross-block
/// storage, recorded even when it is not materialized — confirm where the
/// cache is built.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
    /// One dense block per row, looked up by row index; `apply_row` treats a
    /// block's rows as the output dimension and its columns as the `Δβ`
    /// dimension.
    Dense {
        blocks: Arc<[Array2<f64>]>,
        estimated_bytes: usize,
    },
    /// Matrix-free access: `op(row, delta_beta, out)` produces the row's
    /// product without storing any block.
    Matvec {
        op: RowHtbetaMatvec,
        estimated_bytes: usize,
    },
    /// No cross-block access is available; `apply_row` reports failure.
    Disabled {
        estimated_bytes: usize,
    },
}
2336
2337impl std::fmt::Debug for ArrowHtbetaCache {
2338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2339 match self {
2340 Self::Dense {
2341 blocks,
2342 estimated_bytes,
2343 } => f
2344 .debug_struct("Dense")
2345 .field("blocks", &blocks.len())
2346 .field("estimated_bytes", estimated_bytes)
2347 .finish(),
2348 Self::Matvec {
2349 estimated_bytes, ..
2350 } => f
2351 .debug_struct("Matvec")
2352 .field("estimated_bytes", estimated_bytes)
2353 .finish(),
2354 Self::Disabled { estimated_bytes } => f
2355 .debug_struct("Disabled")
2356 .field("estimated_bytes", estimated_bytes)
2357 .finish(),
2358 }
2359 }
2360}
2361
2362impl ArrowHtbetaCache {
2363 pub(crate) fn is_available(&self) -> bool {
2364 !matches!(self, Self::Disabled { .. })
2365 }
2366
2367 pub(crate) fn apply_row(
2368 &self,
2369 row: usize,
2370 delta_beta: ArrayView1<'_, f64>,
2371 out: &mut Array1<f64>,
2372 ) -> bool {
2373 match self {
2374 Self::Dense { blocks, .. } => {
2375 let Some(block) = blocks.get(row) else {
2376 return false;
2377 };
2378 if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
2379 return false;
2380 }
2381 for c in 0..block.nrows() {
2382 let mut acc = 0.0_f64;
2383 for a in 0..block.ncols() {
2384 acc += block[[c, a]] * delta_beta[a];
2385 }
2386 out[c] = acc;
2387 }
2388 true
2389 }
2390 Self::Matvec { op, .. } => {
2391 op(row, delta_beta, out);
2392 true
2393 }
2394 Self::Disabled { .. } => false,
2395 }
2396 }
2397
2398 /// Apply the transpose: `out[a] += H_βt^(row)[a, c] · v[c]` for all `a`.
2399 ///
2400 /// `v` has length `d`; `out` has length `k`. Accumulates (does NOT zero
2401 /// `out` first) so callers can sum contributions across rows into a shared
2402 /// accumulator. Returns `false` when the cache is `Disabled` and no
2403 /// `fallback_op` is provided.
2404 pub(crate) fn apply_row_transpose_accumulate(
2405 &self,
2406 row: usize,
2407 v: ArrayView1<'_, f64>,
2408 out: &mut Array1<f64>,
2409 d: usize,
2410 k: usize,
2411 fallback_op: Option<&RowHtbetaMatvec>,
2412 ) -> bool {
2413 match self {
2414 Self::Dense { blocks, .. } => {
2415 let Some(block) = blocks.get(row) else {
2416 return false;
2417 };
2418 if block.nrows() != v.len() || block.ncols() != out.len() {
2419 return false;
2420 }
2421 // H_βt^(i) · v: outer-loop c hoists v[c], inner-loop a is
2422 // contiguous in row-major (d, k) layout.
2423 for c in 0..block.nrows() {
2424 let vc = v[c];
2425 if vc == 0.0 {
2426 continue;
2427 }
2428 for a in 0..block.ncols() {
2429 out[a] += block[[c, a]] * vc;
2430 }
2431 }
2432 true
2433 }
2434 Self::Matvec { op, .. } => {
2435 // Probe column-by-column: H_tβ^(row) e_a is column a. dot(col_a, v)
2436 // is entry a of H_βt^(row) v.
2437 htbeta_probe_transpose(row, op, v, out, d, k);
2438 true
2439 }
2440 Self::Disabled { .. } => {
2441 // No cached block. Use the caller-supplied fallback op if present.
2442 if let Some(op) = fallback_op {
2443 htbeta_probe_transpose(row, op, v, out, d, k);
2444 true
2445 } else {
2446 false
2447 }
2448 }
2449 }
2450 }
2451}
2452
/// RAW per-row spectral data of a spectrally-deflated undamped evidence `H_tt`
/// block (see [`ArrowFactorCache::deflation_row_spectra`]).
///
/// `evecs` columns are the RAW symmetric eigenvectors `uₘ` of `H_tt`
/// (orthonormal; the deflated directions `vᵢ` are the subset whose eigenvalue
/// was pinned). `raw_evals[m]` is the RAW eigenvalue `λₘ` BEFORE the unit-pin /
/// floor-clamp. `cond_evals[m]` is the conditioned eigenvalue `λ̃ₘ` the factor
/// actually uses (`λ̃ = λ` for an unclamped kept direction, the positive `floor`
/// for a clamped kept direction, `1` for a deflated direction). Together they
/// give the Daleckii–Krein divided differences the outer-gradient deflation
/// correction needs.
#[derive(Debug, Clone)]
pub struct RowDeflationSpectrum {
    /// RAW eigenvectors `uₘ` of the row's `H_tt`, one per column.
    pub evecs: Array2<f64>,
    /// RAW eigenvalues `λₘ`, before any unit-pin or floor-clamp.
    pub raw_evals: Array1<f64>,
    /// Conditioned eigenvalues `λ̃ₘ` that the factor actually uses.
    pub cond_evals: Array1<f64>,
}
2470
/// Factorization state retained from one arrow-Schur solve.
///
/// Bundles the per-row `H_tt` Cholesky factors (damped and undamped), the
/// optional dense Schur-complement factor, per-row cross-block access, and the
/// bookkeeping (ridges, dimensions, fingerprints, deflation data, cross-row
/// Woodbury carrier) that later consumers read when they reuse the factors.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
    /// Per-row lower-triangular Cholesky factors of `H_tt^(i) + ridge_t·I`.
    ///
    /// These are the *damped* factors used inside the Newton solve. The IFT
    /// predictor must NOT use them — see [`Self::htt_factors_undamped`].
    pub htt_factors: ArrowFactorSlab,
    /// Per-row lower-triangular Cholesky factors of the UNDAMPED
    /// `H_tt^(i)` (no `ridge_t` added).
    ///
    /// The IFT predictor formula
    /// `Δt_i = -(H_tt^(i))⁻¹ · (H_tβ^(i) Δβ + δg_t^(i))` is derived from
    /// `∂g_t/∂t = H_tt` at the stationary point, with no LM damping term.
    /// Reusing the damped factors would bias the predicted shift toward zero
    /// in proportion to `ridge_t`. We pay one extra `O(N d³)` Cholesky per
    /// Newton solve — the same complexity class as the Newton solve itself —
    /// to make the IFT exact.
    pub htt_factors_undamped: ArrowUndampedFactors,
    /// Lower-triangular Cholesky factor of the Schur complement when the
    /// selected BA mode formed/factored dense RCS. `None` for
    /// [`ArrowSolverMode::InexactPCG`], where Agarwal-style inexact LM avoids
    /// the dense `K × K` factor.
    pub schur_factor: Option<Array2<f64>>,
    /// Exact undamped joint-Hessian log-determinant produced by the dense
    /// factorization path. REML evidence consumes this directly so the Laplace
    /// normalizer cannot miss the log-det even when later cache consumers only
    /// need solves/traces.
    pub joint_hessian_log_det: Option<f64>,
    /// BA mode used to create this cache.
    pub solver_mode: ArrowSolverMode,
    /// Row-block ridge (`H_tt^(i) + ridge_t·I`) used to build the cached
    /// damped factors. Recorded, together with [`Self::ridge_beta`], so the
    /// warm-start predictor knows whether the cache is still valid for a
    /// requested ridge level.
    pub ridge_t: f64,
    /// Companion ridge recorded alongside [`Self::ridge_t`] for the same
    /// cache-validity check.
    ///
    /// NOTE(review): presumably the damping applied to the shared `β` block
    /// when the Schur factor was built — confirm against the solver.
    pub ridge_beta: f64,
    /// Per-row cross-block access for `H_tβ^(i) x`.
    ///
    /// Large caches retain a row matvec callback or disable β-coupled IFT
    /// prediction instead of cloning every dense `d × K` slab.
    pub htbeta: ArrowHtbetaCache,
    /// Maximum per-row latent dim (upper bound; matches `sys.d` at creation).
    pub d: usize,
    /// Per-row latent dims: `row_dims[i]` is the active dim for row `i`.
    pub row_dims: Arc<[usize]>,
    /// Flat-buffer row offsets for `delta_t` / IFT output vectors.
    /// `row_offsets[i]` is the start of row `i`; `row_offsets[n]` is the
    /// total length.
    pub row_offsets: Arc<[usize]>,
    /// β dimensionality `K`.
    pub k: usize,
    /// Geometry tag for the row-local factors and cross-blocks.
    pub manifold_mode_fingerprint: u64,
    /// Row-system tag for the cached per-row factors, cross-blocks, and
    /// shared-block diagonal used to build the Schur factor.
    pub row_hessian_fingerprint: u64,
    /// PCG instrumentation from the solve that produced this cache.
    ///
    /// Zero-valued (default) when the selected mode did not use PCG
    /// (i.e. `Direct` or `SqrtBA`).
    pub pcg_diagnostics: PcgDiagnostics,
    /// Number of row-local gauge directions stiffened in an undamped evidence
    /// factorization.
    ///
    /// Each direction is stiffened at UNIT stiffness `kappa = 1.0`, so it
    /// contributes `log(1) = 0` to the row-block logdet through the returned
    /// Cholesky factor: the gauge orbit is a criterion null direction and adds
    /// nothing to the Laplace normalizer (the quotient pseudo-determinant
    /// convention, cf. `PenaltyPseudologdet`). Zero theta/rho dependence.
    pub gauge_deflated_directions: usize,
    /// Per-row unit-norm directions `vᵢ` (in each row's `d`-dim latent block
    /// coordinates) that an undamped evidence factorization stiffened to UNIT
    /// stiffness `λ̃ = 1` (gauge or spectral deflation). Indexed by row; empty
    /// for every PD row factored without deflation, and empty overall on the
    /// non-deflating solver paths (streaming / cross-row-penalty CG / device).
    ///
    /// A deflated direction contributes `log(1) = 0` to the row-block log-det
    /// and is ρ/θ-INDEPENDENT, so its true contribution to `∂log|H|/∂ρ` is `0`.
    /// The analytic outer-gradient traces (`assignment_log_strength_hessian_trace`,
    /// `learnable_ibp_data_logdet_alpha_trace`, `logdet_theta_adjoint`) contract
    /// `∂H_raw/∂ρ` (the RAW, pre-deflation block derivative) against the DEFLATED
    /// inverse, which assigns `1/λ̃ = 1` to each `vᵢ` and therefore spuriously
    /// adds `½ vᵢᵀ (∂H_raw/∂ρ) vᵢ`. Those traces subtract this per-row term
    /// (kept-subspace restriction) using these directions; without them the
    /// REML outer ρ-gradient is biased by `+Σ_deflated ½ vᵢᵀ ∂H_raw/∂ρ vᵢ`.
    pub deflated_row_directions: Arc<[Vec<Array1<f64>>]>,
    /// Per-row RAW spectral decomposition of an undamped evidence `H_tt` block
    /// that underwent SPECTRAL deflation, surfaced so the outer ρ/θ-gradient
    /// traces can apply the EXACT deflation-map (Daleckii–Krein) derivative
    /// correction, not just the within-row kept-subspace term.
    ///
    /// The criterion VALUE re-deflates `H_tt` at every ρ, so its gradient is
    /// `tr(H_deflated⁻¹ DΦ[∂H_raw/∂ρ])`, where `Φ` is the spectral pin-to-unit
    /// map. By Daleckii–Krein `DΦ[Ȧ] = U (F ∘ UᵀȦU) Uᵀ` with the divided-
    /// difference matrix `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)` (raw `λ` in the
    /// denominator, conditioned `λ̃` in the numerator). The kept×kept block of
    /// `F` is `1` (the kept subspace contracts the raw derivative unchanged), the
    /// deflated×deflated block is `0`, and the kept(m)×deflated(i) block is
    /// `(λₘ − 1)/(λₘ − λᵢ)` — this last, ROTATION, term is what the per-row
    /// kept-subspace correction alone misses; it couples to the β-block through
    /// the Schur back-substitution carried in `(H⁻¹)_tt`.
    ///
    /// `Some(spectrum)` only for spectrally-deflated rows; `None` for PD rows,
    /// gauge-only deflation (ρ-independent structural null — within-row term
    /// suffices), and every non-SAE-evidence solver path (streaming / device /
    /// cross-row CG). Empty overall when no row deflated spectrally.
    pub deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>,
    /// Exact cross-row IBP rank-`R` Woodbury correction (#1038), present iff the
    /// source system carried an [`IbpCrossRowSource`]. When set, the per-row
    /// factors above are of the NO-SELF base `H₀'` (self term `d_k·z'_ik²`
    /// downdated from each logit diagonal), and this carrier supplies the exact
    /// rank-`R` correction so the value/curvature solve
    /// ([`Self::full_inverse_apply`]), the evidence log-determinant
    /// ([`Self::arrow_log_det`]), and the θ/ρ-adjoint all describe the same
    /// `H_full = H₀' + U D Uᵀ`.
    pub cross_row_woodbury: Option<CrossRowWoodbury>,
}
2587
/// Materialized exact cross-row IBP Woodbury correction (#1038), built against
/// an [`ArrowFactorCache`] whose per-row factors are the NO-SELF base `H₀'`.
///
/// Holds `U` (the `delta_t_len × R` arrow-`t` factor, β-part implicitly zero),
/// `D = diag(d_k)`, the projected `M = UᵀH₀'⁻¹U`, the columns `H₀'⁻¹U`, and the
/// **LU factorization of the (generally non-symmetric, possibly indefinite)
/// capacitance** `C = I_R + D·M`. `d_k = w·s'_k` is not sign-definite, so the
/// capacitance is factored by a partial-pivot LU (exact for any sign); the same
/// factorization serves the log-determinant `log det C`, the inverse correction
/// `H_full⁻¹w = H₀'⁻¹w − H₀'⁻¹U·C⁻¹·(D Uᵀ H₀'⁻¹w)`, and the adjoint's
/// selected-inverse (`C⁻¹` and `M`). The full inverse, value/curvature solve,
/// log-determinant, and adjoint therefore all describe the SAME
/// `H_full = H₀' + U D Uᵀ`.
#[derive(Debug, Clone)]
pub struct CrossRowWoodbury {
    /// `U`: `delta_t_len × R`, column `k` supported on atom-`k` logit slots.
    pub u: Array2<f64>,
    /// Diagonal of `D`: the per-atom weights `d_k`, length `R`. Not
    /// sign-definite.
    pub d: Array1<f64>,
    /// `H₀'⁻¹ U` (the `t`-block), `delta_t_len × R`.
    pub h0inv_u: Array2<f64>,
    /// `(H₀'⁻¹ U)` β-block, `K × R`. `U` has no β support, but the bordered
    /// solve couples the latent columns to `β` through the Schur complement, so
    /// this block is generally nonzero and the inverse correction must apply it
    /// to the `β` output too.
    pub h0inv_u_beta: Array2<f64>,
    /// `M = Uᵀ H₀'⁻¹ U`, `R × R`, stored after explicit symmetrization.
    /// Retained for the θ/ρ-adjoint.
    pub m: Array2<f64>,
    /// Partial-pivot LU of the capacitance `C = I_R + D·M` (`lu` packs `L`/`U`,
    /// `piv` the row swaps), built by [`small_lu_factor`].
    pub capacitance_lu: SmallLu,
    /// The sparse `U` entries `(global_t_index, atom_k, z'_ik)` — retained so
    /// `Uᵀ·v` can be formed over the atom slots without re-deriving them.
    pub entries: Vec<(usize, usize, f64)>,
}
2623
/// Dense partial-pivot LU of a small square matrix. Used for the cross-row IBP
/// capacitance `C = I_R + D·M`, which is generally non-symmetric and possibly
/// indefinite (`d_k = w·s'_k` is not sign-definite), so a Cholesky/LDLᵀ is
/// unavailable. `R` is the atom count, so this is a cheap dense factorization.
///
/// Values are produced by [`small_lu_factor`].
#[derive(Debug, Clone)]
pub struct SmallLu {
    /// Packed `L` (unit lower, below diagonal) and `U` (upper, on/above
    /// diagonal), `R × R`, in the row-permuted order encoded by `piv`.
    /// The unit diagonal of `L` is implicit; the stored diagonal belongs to `U`.
    pub(crate) lu: Array2<f64>,
    /// Row permutation: `piv[i]` is the original row now in position `i`
    /// (a solve reads its right-hand side as `b[piv[i]]`).
    pub(crate) piv: Vec<usize>,
    /// Sign of the permutation (`±1`), folded into the determinant.
    pub(crate) perm_sign: f64,
}
2638
2639/// Partial-pivot LU factorization of a small dense square matrix `a` (`R × R`).
2640/// Returns `None` only when a pivot is exactly zero (singular `C`).
2641pub(crate) fn small_lu_factor(a: &Array2<f64>) -> Option<SmallLu> {
2642 let r = a.nrows();
2643 assert_eq!(a.ncols(), r, "small_lu_factor: non-square input");
2644 let mut lu = a.clone();
2645 let mut piv: Vec<usize> = (0..r).collect();
2646 let mut perm_sign = 1.0_f64;
2647 for col in 0..r {
2648 // Partial pivot: pick the largest-magnitude entry on/below the diagonal.
2649 let mut pivot_row = col;
2650 let mut pivot_mag = lu[[col, col]].abs();
2651 for row in (col + 1)..r {
2652 let mag = lu[[row, col]].abs();
2653 if mag > pivot_mag {
2654 pivot_mag = mag;
2655 pivot_row = row;
2656 }
2657 }
2658 // Reject not just an exactly-zero pivot, but any non-finite or
2659 // subnormal magnitude: dividing by a subnormal in the elimination /
2660 // back-solve produces `Inf`/`NaN` that would otherwise flow silently
2661 // into the Woodbury inverse and the evidence log-det (#1038). A
2662 // capacitance this degenerate is a desync the caller must surface
2663 // (→ `Ok(None)` cross-row-absent / `SchurFactorFailed`), not consume.
2664 if !pivot_mag.is_finite() || pivot_mag < f64::MIN_POSITIVE {
2665 return None;
2666 }
2667 if pivot_row != col {
2668 for c in 0..r {
2669 lu.swap((col, c), (pivot_row, c));
2670 }
2671 piv.swap(col, pivot_row);
2672 perm_sign = -perm_sign;
2673 }
2674 let pivot = lu[[col, col]];
2675 for row in (col + 1)..r {
2676 let factor = lu[[row, col]] / pivot;
2677 lu[[row, col]] = factor;
2678 for c in (col + 1)..r {
2679 let v = lu[[col, c]];
2680 lu[[row, c]] -= factor * v;
2681 }
2682 }
2683 }
2684 // Post-elimination invariant: every U diagonal is finite and not subnormal.
2685 // The per-column pivot guard above validates each diagonal as it is chosen,
2686 // but assert it explicitly so `SmallLu::solve` can divide by `lu[[i, i]]`
2687 // without a per-entry guard and so a `SmallLu` value can never carry a
2688 // factor that would silently emit `Inf`/`NaN` into the capacitance solve.
2689 for i in 0..r {
2690 let u = lu[[i, i]];
2691 if !u.is_finite() || u.abs() < f64::MIN_POSITIVE {
2692 return None;
2693 }
2694 }
2695 Some(SmallLu { lu, piv, perm_sign })
2696}
2697
2698impl SmallLu {
2699 pub(crate) fn dim(&self) -> usize {
2700 self.lu.nrows()
2701 }
2702
2703 /// `log|det|` and the determinant sign (`±1`).
2704 pub(crate) fn log_abs_det_and_sign(&self) -> (f64, f64) {
2705 let mut log_abs = 0.0_f64;
2706 let mut sign = self.perm_sign;
2707 for i in 0..self.dim() {
2708 let u = self.lu[[i, i]];
2709 log_abs += u.abs().ln();
2710 if u < 0.0 {
2711 sign = -sign;
2712 }
2713 }
2714 (log_abs, sign)
2715 }
2716
2717 /// Solve `C x = b` reusing the factorization (in place into a fresh vector).
2718 ///
2719 /// Returns `None` when the solve cannot produce a finite result — either a
2720 /// `U` diagonal is non-finite/subnormal (defensive: `small_lu_factor`
2721 /// already rejects such factors, but a future construction path might not)
2722 /// or the back-substitution overflows to `Inf`/`NaN` for an extreme RHS on
2723 /// an ill-conditioned (yet validly factored) capacitance. Surfacing `None`
2724 /// lets the Woodbury / evidence consumers fail loudly (#1038) instead of
2725 /// flowing a silent `NaN` into the log-det and outer gradient.
2726 pub(crate) fn solve(&self, b: &Array1<f64>) -> Option<Array1<f64>> {
2727 let r = self.dim();
2728 // Apply the row permutation: y = P b.
2729 let mut y = Array1::<f64>::zeros(r);
2730 for i in 0..r {
2731 y[i] = b[self.piv[i]];
2732 }
2733 // Forward solve L y' = P b (L unit-lower).
2734 for i in 0..r {
2735 let mut sum = y[i];
2736 for j in 0..i {
2737 sum -= self.lu[[i, j]] * y[j];
2738 }
2739 y[i] = sum;
2740 }
2741 // Back solve U x = y' (U upper, explicit diagonal).
2742 let mut x = Array1::<f64>::zeros(r);
2743 for i in (0..r).rev() {
2744 let mut sum = y[i];
2745 for j in (i + 1)..r {
2746 sum -= self.lu[[i, j]] * x[j];
2747 }
2748 let pivot = self.lu[[i, i]];
2749 if !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE {
2750 return None;
2751 }
2752 x[i] = sum / pivot;
2753 }
2754 if x.iter().all(|v| v.is_finite()) {
2755 Some(x)
2756 } else {
2757 None
2758 }
2759 }
2760}
2761
2762/// Close the streaming cross-row IBP Woodbury (#1038) into its exact evidence
2763/// log-determinant correction.
2764///
2765/// Given the chunk-summed `M0 = Uᵀ A⁻¹ U` (`R×R`), `W = Bᵀ A⁻¹ U` (`k×R`), the
2766/// GLOBAL reduced Schur `schur` (`S`, `k×k`, PD), and the per-atom `D`, this
2767/// forms the EXACT projected `M = Uᵀ H₀'⁻¹ U = M0 + Wᵀ S⁻¹ W` (the bordered
2768/// inverse `t`-block `(H₀'⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹` contracted by `U`),
2769/// then the capacitance `C = I_R + D·M` and returns `log det C`.
2770///
2771/// This is byte-for-byte the same quantity, and the same sign convention, that
2772/// the dense [`CrossRowWoodbury::log_det`] returns (symmetrize `M`, scale rows
2773/// by `D`, partial-pivot LU, reject a non-positive determinant). Returns
2774/// `Ok(None)` when the capacitance is non-PD / singular — the recoverable
2775/// "cross-row IBP joint Hessian is non-PD at this ρ" infeasible-probe refusal,
2776/// NOT a silent wrong value.
2777pub fn streaming_cross_row_woodbury_log_det(
2778 schur: &Array2<f64>,
2779 m0: &Array2<f64>,
2780 w: &Array2<f64>,
2781 d: &Array1<f64>,
2782) -> Result<Option<f64>, ArrowSchurError> {
2783 let r = d.len();
2784 let factor =
2785 cholesky_lower(schur).map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
2786 // M = M0 + Wᵀ S⁻¹ W. With S⁻¹ symmetric, `(Wᵀ S⁻¹ W)[a, b] = (S⁻¹ w_a)·w_b`.
2787 let mut m = m0.clone();
2788 for a in 0..r {
2789 let w_a = w.column(a).to_owned();
2790 let sinv_w_a = cholesky_solve_vector(&factor, &w_a);
2791 for b in 0..r {
2792 m[[a, b]] += sinv_w_a.dot(&w.column(b));
2793 }
2794 }
2795 // Symmetrize to clear back-substitution rounding asymmetry (matches
2796 // `CrossRowWoodbury::build`).
2797 for a in 0..r {
2798 for b in (a + 1)..r {
2799 let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2800 m[[a, b]] = avg;
2801 m[[b, a]] = avg;
2802 }
2803 }
2804 // Capacitance C = I_R + D·M (row k scaled by d_k), factored by partial-pivot
2805 // LU (`d_k = w·s'_k` is not sign-definite, so C is generally non-symmetric /
2806 // indefinite — exactly the dense carrier's factorization).
2807 let mut c = Array2::<f64>::zeros((r, r));
2808 for a in 0..r {
2809 for b in 0..r {
2810 c[[a, b]] = d[a] * m[[a, b]];
2811 }
2812 c[[a, a]] += 1.0;
2813 }
2814 match small_lu_factor(&c) {
2815 Some(lu) => {
2816 let (log_abs, sign) = lu.log_abs_det_and_sign();
2817 Ok((sign > 0.0).then_some(log_abs))
2818 }
2819 // Exactly-singular capacitance: `H_full` det is 0, the Laplace log-det is
2820 // undefined — surface as the recoverable non-PD probe refusal.
2821 None => Ok(None),
2822 }
2823}
2824
2825impl CrossRowWoodbury {
2826 /// Build the exact rank-`R` cross-row Woodbury carrier from the IBP source
2827 /// and a cache whose per-row factors are the NO-SELF base `H₀'`.
2828 ///
2829 /// Computes `H₀'⁻¹U` (one [`ArrowFactorCache::full_inverse_apply`] back-solve
2830 /// per column, β-RHS zero — the `t`-block of the result is `H₀'⁻¹U`'s
2831 /// column), `M = UᵀH₀'⁻¹U`, and the LU of `C = I_R + D·M`. Returns `None`
2832 /// when the capacitance is exactly singular (the only un-representable case;
2833 /// the caller then proceeds with the bare `H₀'` cache and the cross-row term
2834 /// is absent — never silently inconsistent, since logdet/inverse/adjoint all
2835 /// key off the presence of this carrier).
2836 pub(crate) fn build(
2837 cache: &ArrowFactorCache,
2838 source: &IbpCrossRowSource,
2839 ) -> Result<Option<Self>, ArrowSchurError> {
2840 let r = source.r;
2841 let total_len = cache.delta_t_len();
2842 let u = source.dense_u(total_len);
2843 let d = source.d.clone();
2844 let zero_beta = Array1::<f64>::zeros(cache.k);
2845 // h0inv_u[:, k] = (H₀'⁻¹ U)_t for column k; h0inv_u_beta[:, k] its β-block.
2846 let mut h0inv_u = Array2::<f64>::zeros((total_len, r));
2847 let mut h0inv_u_beta = Array2::<f64>::zeros((cache.k, r));
2848 for k in 0..r {
2849 let col = u.column(k).to_owned();
2850 let (sol_t, sol_beta) = cache.full_inverse_apply(col.view(), zero_beta.view())?;
2851 for g in 0..total_len {
2852 h0inv_u[[g, k]] = sol_t[g];
2853 }
2854 for c in 0..cache.k {
2855 h0inv_u_beta[[c, k]] = sol_beta[c];
2856 }
2857 }
2858 // M = Uᵀ (H₀'⁻¹ U), symmetric R×R. U is sparse (atom-slot supported), so
2859 // contract over the listed entries.
2860 let mut m = Array2::<f64>::zeros((r, r));
2861 for a in 0..r {
2862 for b in 0..r {
2863 let mut acc = 0.0_f64;
2864 for &(g, k, z) in &source.entries {
2865 if k == a {
2866 acc += z * h0inv_u[[g, b]];
2867 }
2868 }
2869 m[[a, b]] = acc;
2870 }
2871 }
2872 // Symmetrize M to clear back-substitution rounding asymmetry.
2873 for a in 0..r {
2874 for b in (a + 1)..r {
2875 let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
2876 m[[a, b]] = avg;
2877 m[[b, a]] = avg;
2878 }
2879 }
2880 // Capacitance C = I_R + D·M (row k scaled by d_k).
2881 let mut c = Array2::<f64>::zeros((r, r));
2882 for a in 0..r {
2883 for b in 0..r {
2884 c[[a, b]] = d[a] * m[[a, b]];
2885 }
2886 c[[a, a]] += 1.0;
2887 }
2888 let Some(capacitance_lu) = small_lu_factor(&c) else {
2889 return Ok(None);
2890 };
2891 Ok(Some(Self {
2892 u,
2893 d,
2894 h0inv_u,
2895 h0inv_u_beta,
2896 m,
2897 capacitance_lu,
2898 entries: source.entries.clone(),
2899 }))
2900 }
2901
2902 /// The sparse `U` entry list `(global_t_index, atom_k, z'_ik)`.
2903 pub(crate) fn source_entries(&self) -> &[(usize, usize, f64)] {
2904 &self.entries
2905 }
2906
2907 /// `C⁻¹ D` as a dense `R × R` matrix (`R` capacitance solves; column `l` is
2908 /// `d_l · C⁻¹ e_l`). Shared by the inverse-diagonal correction and any
2909 /// adjoint trace that needs the selected inverse of the capacitance.
2910 ///
2911 /// Returns `None` when any capacitance solve fails to produce a finite
2912 /// result (#1038); the consumer must surface this as a loud failure rather
2913 /// than propagate a `NaN` into the evidence/gradient.
2914 pub fn capacitance_inv_times_d(&self) -> Option<Array2<f64>> {
2915 let r = self.d.len();
2916 let mut out = Array2::<f64>::zeros((r, r));
2917 let mut e_l = Array1::<f64>::zeros(r);
2918 for l in 0..r {
2919 e_l.fill(0.0);
2920 e_l[l] = 1.0;
2921 let col = self.capacitance_lu.solve(&e_l)?;
2922 for k in 0..r {
2923 out[[k, l]] = col[k] * self.d[l];
2924 }
2925 }
2926 Some(out)
2927 }
2928
2929 /// Subtract the rank-`R` Woodbury term from the latent inverse diagonal:
2930 /// `diag ← diag − diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹)`. With `G = h0inv_u` and
2931 /// `(C⁻¹D) = capacitance_inv_times_d()`, the entry at global index `g` is
2932 /// `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`.
2933 pub(crate) fn subtract_inverse_diagonal(
2934 &self,
2935 diag: &mut Array1<f64>,
2936 ) -> Result<(), ArrowSchurError> {
2937 let r = self.d.len();
2938 let cinv_d =
2939 self.capacitance_inv_times_d()
2940 .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
2941 reason: "cross-row Woodbury capacitance solve produced a non-finite \
2942 C⁻¹D for the inverse-diagonal correction (#1038): \
2943 singular/ill-conditioned cross-row capacitance"
2944 .to_string(),
2945 })?;
2946 let total_len = self.h0inv_u.nrows();
2947 for g in 0..total_len {
2948 let mut acc = 0.0_f64;
2949 for k in 0..r {
2950 let gk = self.h0inv_u[[g, k]];
2951 if gk == 0.0 {
2952 continue;
2953 }
2954 for l in 0..r {
2955 acc += gk * cinv_d[[k, l]] * self.h0inv_u[[g, l]];
2956 }
2957 }
2958 diag[g] -= acc;
2959 }
2960 Ok(())
2961 }
2962
2963 /// `log det(I_R + D·M)` (the matrix-determinant-lemma correction). Returns
2964 /// `None` when the capacitance LU has a negative determinant — i.e. the
2965 /// implied `H_full` is non-PD, which is a desync the evidence must reject
2966 /// loudly rather than return a complex/`NaN` log-det.
2967 pub fn log_det(&self) -> Option<f64> {
2968 let (log_abs, sign) = self.log_det_correction();
2969 if sign > 0.0 { Some(log_abs) } else { None }
2970 }
2971
2972 /// `log det(I_R + D·M)`: the exact additive correction
2973 /// `log det H_full − log det H₀'` (matrix-determinant lemma). For a genuine
2974 /// PD `H_full` this is real; the LU sign is returned for the caller to
2975 /// surface a non-PD capacitance as an error rather than a silent `NaN`.
2976 pub(crate) fn log_det_correction(&self) -> (f64, f64) {
2977 self.capacitance_lu.log_abs_det_and_sign()
2978 }
2979
2980 /// Apply the rank-`R` inverse correction in place on BOTH arrow blocks:
2981 /// `u ← u − (H₀'⁻¹U) · C⁻¹ · (D Uᵀ (H₀'⁻¹ rhs)_t)`, where `h0inv_rhs_t` is
2982 /// the `t`-block of `H₀'⁻¹ rhs` already computed by the base
2983 /// [`ArrowFactorCache::full_inverse_apply`]. Implements the Woodbury
2984 /// identity `H_full⁻¹ = H₀'⁻¹ − H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹`. `U` has no `β`
2985 /// support so `Uᵀ·v` reads only the `t`-block, but `H₀'⁻¹U` couples to `β`
2986 /// through the Schur complement, so the correction touches `u_beta` too.
2987 ///
2988 /// `entries` lets `Uᵀ·v` be formed over the sparse atom slots.
2989 pub(crate) fn apply_inverse_correction(
2990 &self,
2991 h0inv_rhs_t: ArrayView1<'_, f64>,
2992 entries: &[(usize, usize, f64)],
2993 u_t: &mut Array1<f64>,
2994 u_beta: &mut Array1<f64>,
2995 ) -> Result<(), ArrowSchurError> {
2996 let r = self.d.len();
2997 // p = D Uᵀ (H₀'⁻¹ rhs)_t.
2998 let mut p = Array1::<f64>::zeros(r);
2999 for &(g, k, z) in entries {
3000 p[k] += z * h0inv_rhs_t[g];
3001 }
3002 for k in 0..r {
3003 p[k] *= self.d[k];
3004 }
3005 // q = C⁻¹ p. A non-finite solve is a singular/ill-conditioned cross-row
3006 // capacitance (#1038): fail loudly rather than write `NaN` into the
3007 // Newton step / adjoint solve.
3008 let q =
3009 self.capacitance_lu
3010 .solve(&p)
3011 .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
3012 reason: "cross-row Woodbury capacitance solve produced a non-finite \
3013 C⁻¹p for the inverse correction (#1038): \
3014 singular/ill-conditioned cross-row capacitance"
3015 .to_string(),
3016 })?;
3017 // u_t -= (H₀'⁻¹U)_t · q.
3018 for g in 0..u_t.len() {
3019 let mut acc = 0.0_f64;
3020 for k in 0..r {
3021 acc += self.h0inv_u[[g, k]] * q[k];
3022 }
3023 u_t[g] -= acc;
3024 }
3025 // u_beta -= (H₀'⁻¹U)_β · q.
3026 for c in 0..u_beta.len() {
3027 let mut acc = 0.0_f64;
3028 for k in 0..r {
3029 acc += self.h0inv_u_beta[[c, k]] * q[k];
3030 }
3031 u_beta[c] -= acc;
3032 }
3033 Ok(())
3034 }
3035
3036 /// Forward apply of the rank-`R` cross-row curvature on the latent (`t`)
3037 /// block: `out_t += U D Uᵀ v_t`. This is the EXACT forward of the same
3038 /// `H_full = H₀' + U D Uᵀ` whose inverse [`Self::apply_inverse_correction`]
3039 /// applies and whose log-determinant [`Self::log_det_correction`] reports, so
3040 /// a forward Hessian apply that adds this term stays consistent (operator and
3041 /// preconditioner describe the SAME operator). `U` has no `β` support, so the
3042 /// forward term touches only the `t` block; the `β` coupling is purely an
3043 /// inverse-side Schur artifact (`h0inv_u_beta`) and must NOT appear here.
3044 ///
3045 /// Formed over the sparse `entries` `(global_t_index, atom_k, z'_ik)` so it
3046 /// matches `Uᵀ·v` in `apply_inverse_correction` bit-for-bit:
3047 /// `p_k = d_k · Σ_{(g,k,z)} z·v_t[g]`, then `out_t[g] += Σ_{(g,k,z)} z·p_k`.
3048 pub fn apply_forward_t(&self, v_t: ArrayView1<'_, f64>, out_t: &mut Array1<f64>) {
3049 let r = self.d.len();
3050 // p = D Uᵀ v_t.
3051 let mut p = Array1::<f64>::zeros(r);
3052 for &(g, k, z) in &self.entries {
3053 p[k] += z * v_t[g];
3054 }
3055 for k in 0..r {
3056 p[k] *= self.d[k];
3057 }
3058 // out_t += U p.
3059 for &(g, k, z) in &self.entries {
3060 out_t[g] += z * p[k];
3061 }
3062 }
3063}
3064
/// Smallest Cholesky pivots of a factor cache, reported per origin and overall.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
    /// Smallest pivot found among the per-row factors, if any.
    pub min_row_pivot: Option<f64>,
    /// Smallest pivot of the dense Schur factor, if one exists.
    pub min_schur_pivot: Option<f64>,
    /// Minimum of the two values above; `None` only when both are `None`.
    pub min_pivot: Option<f64>,
}

impl ArrowFactorMinPivot {
    /// Record the row and Schur minima verbatim and derive the overall minimum
    /// from whichever of them are present.
    pub(crate) fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            // Zero, one, or two candidates, reduced pairwise with `f64::min`.
            min_pivot: row.into_iter().chain(schur).reduce(f64::min),
        }
    }
}
3087
3088pub(crate) fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3089 let width = factor.nrows().min(factor.ncols());
3090 let mut out = None;
3091 for idx in 0..width {
3092 let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3093 out = Some(match out {
3094 Some(current) => f64::min(current, pivot),
3095 None => pivot,
3096 });
3097 }
3098 out
3099}
3100
3101pub(crate) fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
3102 let width = factor.nrows().min(factor.ncols());
3103 let mut out = None;
3104 for idx in 0..width {
3105 let pivot = factor[[idx, idx]] * factor[[idx, idx]];
3106 out = Some(match out {
3107 Some(current) => f64::max(current, pivot),
3108 None => pivot,
3109 });
3110 }
3111 out
3112}
3113
3114/// Smallest cached Cholesky pivot for row blocks and the dense Schur factor.
3115///
3116/// Pivots are returned as squared lower-factor diagonals, matching the Hessian
3117/// scale rather than the Cholesky-factor scale. In inexact PCG mode the dense
3118/// Schur factor is absent, so `min_schur_pivot` is `None`.
3119pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
3120 let mut min_row_pivot = None;
3121 for factor in cache.htt_factors.iter() {
3122 if let Some(pivot) = lower_cholesky_min_pivot(factor) {
3123 min_row_pivot = Some(match min_row_pivot {
3124 Some(current) => f64::min(current, pivot),
3125 None => pivot,
3126 });
3127 }
3128 }
3129 let min_schur_pivot = cache
3130 .schur_factor
3131 .as_ref()
3132 .and_then(|factor| lower_cholesky_min_pivot(factor.view()));
3133 ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
3134}
3135
3136/// Largest cached Cholesky pivot across the row blocks and the dense Schur
3137/// factor (Hessian scale, i.e. squared lower-factor diagonal). This is the
3138/// diagonal magnitude scale a safe-SPD pivot floor is measured against: the
3139/// curvature-homotopy tracker (#1007) compares the min pivot against
3140/// `√eps · max(this, 1)`, the same floor the inner solver's
3141/// [`safe_spd_pivot_min`] uses. `None` only for an empty cache.
3142pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
3143 let mut max_pivot: Option<f64> = None;
3144 for factor in cache.htt_factors.iter() {
3145 if let Some(pivot) = lower_cholesky_max_pivot(factor) {
3146 max_pivot = Some(match max_pivot {
3147 Some(current) => f64::max(current, pivot),
3148 None => pivot,
3149 });
3150 }
3151 }
3152 if let Some(factor) = cache.schur_factor.as_ref()
3153 && let Some(pivot) = lower_cholesky_max_pivot(factor.view())
3154 {
3155 max_pivot = Some(match max_pivot {
3156 Some(current) => f64::max(current, pivot),
3157 None => pivot,
3158 });
3159 }
3160 max_pivot
3161}
3162
3163impl ArrowFactorCache {
    /// Number of per-row factors held in this cache, i.e. the length of
    /// `htt_factors`.
    pub fn n_rows(&self) -> usize {
        self.htt_factors.len()
    }
3167
    /// Whether the cached `htbeta` coupling reports itself available; this
    /// simply forwards to its `is_available` query.
    pub fn htbeta_available(&self) -> bool {
        self.htbeta.is_available()
    }
3171
    /// Device-vs-CPU provenance of the Newton solve that produced this cache.
    ///
    /// Whether the Newton solve that produced this cache actually executed on
    /// the device: the device-resident Direct dense solve or the device-resident
    /// matrix-free SAE PCG (whose matvec runs in CUDA kernels). This does NOT
    /// include the injected host-procedural reduced-Schur matvec, whose
    /// arithmetic runs on the CPU even when a CUDA context was opened to build
    /// per-row factors (#1209) — that path sets
    /// `PcgDiagnostics::injected_host_procedural_matvec` instead. Read-only
    /// routing provenance: lets a fit result record device-vs-CPU as ground
    /// truth instead of inferring it from the runtime probe. Mirrors
    /// `PcgDiagnostics::used_device_arrow` (the value is read straight from
    /// `pcg_diagnostics.used_device_arrow`).
    #[must_use]
    pub fn used_device(&self) -> bool {
        self.pcg_diagnostics.used_device_arrow
    }
3186
3187 pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
3188 match &self.htt_factors_undamped {
3189 ArrowUndampedFactors::SameAsDamped => self.htt_factors.factor(row),
3190 ArrowUndampedFactors::Owned(factors) => factors.factor(row),
3191 }
3192 }
3193
3194 pub fn undamped_factor_count(&self) -> usize {
3195 match &self.htt_factors_undamped {
3196 ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
3197 ArrowUndampedFactors::Owned(factors) => factors.len(),
3198 }
3199 }
3200
3201 pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
3202 (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
3203 }
3204
3205 pub fn compute_undamped_arrow_log_det(&self) -> Option<f64> {
3206 if self.ridge_t != 0.0 || self.ridge_beta != 0.0 {
3207 return None;
3208 }
3209 // When the shared β block is empty (`k == 0`) the joint Hessian is
3210 // exactly the block diagonal of the per-row latent blocks: there is no
3211 // reduced Schur complement to form, so the dense Direct path leaves
3212 // `schur_factor = None` legitimately (not the InexactPCG "never formed
3213 // the dense K×K factor" case, which has `k > 0`). The log-det is then
3214 // the per-row sum with a zero (empty `0×0`) Schur contribution. Without
3215 // this the `schur_factor.as_ref()?` below would return `None` for a
3216 // β-profiled atom (#1132 euclidean K=4) and starve the REML Laplace
3217 // normaliser of the joint Hessian log-det it requires.
3218 let schur = match self.schur_factor.as_ref() {
3219 Some(schur) => Some(schur),
3220 None if self.k == 0 => None,
3221 None => return None,
3222 };
3223
3224 let mut acc = 0.0_f64;
3225 for l in self.undamped_factors_iter() {
3226 for i in 0..l.nrows() {
3227 let d = l[[i, i]];
3228 if d <= 0.0 || !d.is_finite() {
3229 return None;
3230 }
3231 acc += 2.0 * d.ln();
3232 }
3233 }
3234 if let Some(schur) = schur {
3235 for i in 0..schur.nrows() {
3236 let d = schur[[i, i]];
3237 if d <= 0.0 || !d.is_finite() {
3238 return None;
3239 }
3240 acc += 2.0 * d.ln();
3241 }
3242 }
3243 let woodbury_correction = self.cross_row_woodbury_log_det();
3244 if !woodbury_correction.is_finite() {
3245 return None;
3246 }
3247 Some(acc + woodbury_correction)
3248 }
3249
    /// The total length of `delta_t` / IFT output vectors for this cache.
    ///
    /// Read as the entry of `row_offsets` at index [`Self::n_rows`], so
    /// `row_offsets` must hold at least `n_rows() + 1` entries for the lookup
    /// to succeed (presumably it stores per-row start offsets followed by the
    /// total — confirm against the code that fills it).
    pub fn delta_t_len(&self) -> usize {
        self.row_offsets[self.n_rows()]
    }
3254
3255 pub fn apply_htbeta_row(
3256 &self,
3257 row: usize,
3258 delta_beta: ArrayView1<'_, f64>,
3259 out: &mut Array1<f64>,
3260 ) -> bool {
3261 let di = if row < self.row_dims.len() {
3262 self.row_dims[row]
3263 } else {
3264 self.d
3265 };
3266 if out.len() != di || delta_beta.len() != self.k {
3267 return false;
3268 }
3269 self.htbeta.apply_row(row, delta_beta, out)
3270 }
3271
3272 /// Accumulate `out[a] += H_βt^(row)[a, :] · v` for all `a in 0..k`.
3273 ///
3274 /// `v` has length `row_dims[row]`; `out` has length `k`. The caller must
3275 /// zero `out` before the first call if it needs a fresh result. Returns
3276 /// `false` when the cache is `Disabled` and no `fallback_op` is provided;
3277 /// callers must treat the accumulator as invalid in that case.
3278 pub fn apply_htbeta_row_transpose(
3279 &self,
3280 row: usize,
3281 v: ArrayView1<'_, f64>,
3282 out: &mut Array1<f64>,
3283 fallback_op: Option<&RowHtbetaMatvec>,
3284 ) -> bool {
3285 let di = if row < self.row_dims.len() {
3286 self.row_dims[row]
3287 } else {
3288 self.d
3289 };
3290 if v.len() != di || out.len() != self.k {
3291 return false;
3292 }
3293 self.htbeta
3294 .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
3295 }
3296
3297 /// Arrow log-determinant
3298 /// `log|H| = Σ_i log|H_{t_i t_i}| + log|Schur_β|`
3299 /// using the cached (damped) factors.
3300 ///
3301 /// Returns `(log_det_tt_sum, log_det_schur)` so the caller can decide
3302 /// what to do with the Schur piece (e.g. REML evidence wants both;
3303 /// some diagnostics want only the per-row sum). `None` for the Schur
3304 /// piece signals that the cache was produced by an InexactPCG solve
3305 /// and never formed/factored the dense `K × K` reduced system.
3306 ///
3307 /// The log-determinant of a Cholesky factor `L` of `M` is
3308 /// `2 Σ log L_ii`.
3309 pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
3310 let mut log_det_tt = 0.0_f64;
3311 for l in self.htt_factors.iter() {
3312 for i in 0..l.nrows() {
3313 log_det_tt += l[[i, i]].ln();
3314 }
3315 }
3316 log_det_tt *= 2.0;
3317 let log_det_schur = self.schur_factor.as_ref().map(|l| {
3318 let mut s = 0.0_f64;
3319 for i in 0..l.nrows() {
3320 s += l[[i, i]].ln();
3321 }
3322 2.0 * s + self.cross_row_woodbury_log_det()
3323 });
3324 (log_det_tt, log_det_schur)
3325 }
3326
3327 /// The exact cross-row IBP correction `log det(I_R + D·M)` to add to the
3328 /// base `log det H₀'` (#1038). Zero when no [`CrossRowWoodbury`] is present,
3329 /// so non-IBP caches are unaffected. The determinant lemma gives
3330 /// `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`; this is the
3331 /// second term, the only piece beyond the bare arrow log-determinant.
3332 ///
3333 /// Panics-free: a negative capacitance determinant (non-PD `H_full`) yields
3334 /// `NaN` here so the evidence surfaces the desync rather than silently
3335 /// dropping the imaginary part. Callers that must reject it should check
3336 /// [`CrossRowWoodbury::log_det`] directly.
3337 pub fn cross_row_woodbury_log_det(&self) -> f64 {
3338 match self.cross_row_woodbury.as_ref() {
3339 Some(w) => w.log_det().unwrap_or(f64::NAN),
3340 None => 0.0,
3341 }
3342 }
3343
3344 /// Diagonal of the latent (`t`-block) of the *full* bordered-arrow
3345 /// inverse `(H⁻¹)_tt`, in `delta_t` layout (length [`Self::delta_t_len`]).
3346 ///
3347 /// For the bordered arrow Hessian
3348 /// `H = [[A, B], [Bᵀ, H_ββ]]` with `A = H_tt` (block-diagonal per row,
3349 /// `A_i = H_tt^(i)`) and `B = H_tβ`, the standard block-inverse identity
3350 /// gives the `t`-block
3351 /// `(H⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹`, where
3352 /// `S = H_ββ − Bᵀ A⁻¹ B` is the Schur complement on `β`. Because `A` is
3353 /// block-diagonal, the `(i, j)` diagonal entry of `(H⁻¹)_tt` is computed
3354 /// purely from row `i`'s factor and cross-block:
3355 ///
3356 /// ```text
3357 /// a = A_i⁻¹ e_j (chol_solve on the per-row factor)
3358 /// [A_i⁻¹]_{jj} = a[j]
3359 /// w = B_iᵀ a = H_βt^(i) a (a K-vector)
3360 /// z = S⁻¹ w (chol_solve on the Schur factor)
3361 /// diag = a[j] + w · z
3362 /// ```
3363 ///
3364 /// The UNDAMPED per-row factors ([`Self::undamped_factor`]) are used so
3365 /// the result is the inverse of the *true* `H_tt`, not the LM-damped
3366 /// `H_tt + ridge_t·I` — same rationale the IFT predictor docstring gives
3367 /// at the top of this struct.
3368 ///
3369 /// # Consuming the diagonal as a per-(atom, axis) trace
3370 ///
3371 /// `(H⁻¹)_tt` is the latent covariance block. The selected-inverse trace
3372 /// for a contiguous group of latent coordinates (e.g. one atom's rows, or
3373 /// one axis across rows) is simply the sum of the returned diagonal entries
3374 /// over those `row_offsets[i] + j` indices — no off-diagonal terms are
3375 /// needed for the trace `tr[(H⁻¹)_tt · D]` against a diagonal selector `D`.
3376 ///
3377 /// # Errors
3378 ///
3379 /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3380 /// dense Schur factor or no usable `H_βt` coupling — i.e. it was produced
3381 /// by an [`ArrowSolverMode::InexactPCG`] solve (no dense `K × K` factor) or
3382 /// by a `Disabled` `htbeta` cache. The selected-inverse block-trace is not
3383 /// yet supported for the matrix-free PCG mode; that branch needs a separate
3384 /// Lanczos/Hutchinson estimator.
3385 pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
3386 let Some(schur_factor) = self.schur_factor.as_ref() else {
3387 return Err(ArrowSchurError::SchurFactorFailed {
3388 reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
3389 the InexactPCG mode does not form one"
3390 .to_string(),
3391 });
3392 };
3393 if !self.htbeta_available() {
3394 return Err(ArrowSchurError::SchurFactorFailed {
3395 reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
3396 but this cache's htbeta is Disabled"
3397 .to_string(),
3398 });
3399 }
3400 let n = self.undamped_factor_count();
3401 let total_len = self.delta_t_len();
3402 let mut out = Array1::<f64>::zeros(total_len);
3403 // Per-row scratch, sized to the max latent dim / K.
3404 let mut e_j = Array1::<f64>::zeros(self.d);
3405 let mut w = Array1::<f64>::zeros(self.k);
3406 for i in 0..n {
3407 let di = self.row_dims[i];
3408 let row_base = self.row_offsets[i];
3409 let factor = self.undamped_factor(i);
3410 for j in 0..di {
3411 // a = A_i⁻¹ e_j.
3412 for c in 0..di {
3413 e_j[c] = 0.0;
3414 }
3415 e_j[j] = 1.0;
3416 let e_j_slice = e_j.slice(ndarray::s![..di]).to_owned();
3417 let a = cholesky_solve_vector(factor, &e_j_slice);
3418 // w = H_βt^(i) a (a K-vector); accumulator must start zeroed.
3419 w.fill(0.0);
3420 if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
3421 return Err(ArrowSchurError::SchurFactorFailed {
3422 reason: format!(
3423 "latent_block_inverse_diagonal: H_βt^({i}) apply failed \
3424 (htbeta cache could not supply row {i})"
3425 ),
3426 });
3427 }
3428 // z = S⁻¹ w; correction = w · z.
3429 let z = cholesky_solve_vector(schur_factor, &w);
3430 let mut corr = 0.0_f64;
3431 for c in 0..self.k {
3432 corr += w[c] * z[c];
3433 }
3434 out[row_base + j] = a[j] + corr;
3435 }
3436 }
3437 if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3438 // #1038: the factors above are `H₀'`, so `out` is diag((H₀'⁻¹)_tt).
3439 // The full inverse diagonal subtracts the rank-`R` Woodbury term
3440 // diag(H₀'⁻¹U C⁻¹ D Uᵀ H₀'⁻¹). With `G = h0inv_u = (H₀'⁻¹U)_t` and
3441 // (by symmetry of `H₀'⁻¹`) `(Uᵀ H₀'⁻¹)_t = Gᵀ`, the diagonal entry at
3442 // global index `g` is `Σ_{k,l} G[g,k] (C⁻¹D)[k,l] G[g,l]`. Form the
3443 // `R×R` matrix `C⁻¹D` once (R solves), then contract per row index.
3444 woodbury.subtract_inverse_diagonal(&mut out)?;
3445 }
3446 Ok(out)
3447 }
3448
3449 /// Solve the full bordered-arrow system `H·u = w` on the cached factor
3450 /// (#1006): `w` arrives in arrow layout — `w_t` flat per
3451 /// [`Self::delta_t_len`] / `row_offsets`, `w_beta` of length `K` — and the
3452 /// solution comes back in the same layout. Standard block elimination on
3453 /// the SAME factors whose log-determinant the evidence reports:
3454 ///
3455 /// ```text
3456 /// y_i = H_tt^(i)⁻¹ · w_t^(i)
3457 /// r_β = w_β − Σ_i H_βt^(i) · y_i
3458 /// u_β = Schur⁻¹ · r_β
3459 /// u_t^(i) = y_i − H_tt^(i)⁻¹ · (H_tβ^(i) · u_β)
3460 /// ```
3461 ///
3462 /// This is the IFT / adjoint back-solve the analytic outer ρ-gradient
3463 /// consumes: `u_j = H⁻¹ (∂g/∂ρ_j)` per outer coordinate and the
3464 /// `H⁻¹`-side of the third-order correction `−½·Γᵀ·H⁻¹·(∂g/∂ρ_j)`.
3465 /// Contract: the cache must be the ridge-0 Direct evidence factor
3466 /// (undamped per-row factors + dense Schur), so the solve is against the
3467 /// criterion's own `H` — never a damped surrogate (that would desync the
3468 /// gradient from the reported evidence).
3469 ///
3470 /// When the cache carries an exact cross-row IBP
3471 /// [`CrossRowWoodbury`] (#1038), the per-row factors are the NO-SELF base
3472 /// `H₀'` and this method layers the rank-`R` Woodbury correction so the
3473 /// returned solve is against the FULL `H_full = H₀' + U D Uᵀ` — the same
3474 /// operator whose log-determinant [`Self::arrow_log_det`] reports. The
3475 /// θ/ρ-adjoint that consumes this therefore sees the cross-row curvature.
3476 pub fn full_inverse_apply(
3477 &self,
3478 w_t: ArrayView1<'_, f64>,
3479 w_beta: ArrayView1<'_, f64>,
3480 ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3481 let (mut u_t, mut u_beta) = self.full_inverse_apply_base(w_t, w_beta)?;
3482 if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
3483 // u ← u − H₀'⁻¹U C⁻¹ D Uᵀ u. `u_t` is the `t`-block of `H₀'⁻¹ w`.
3484 let h0inv_w_t = u_t.clone();
3485 woodbury.apply_inverse_correction(
3486 h0inv_w_t.view(),
3487 woodbury.source_entries(),
3488 &mut u_t,
3489 &mut u_beta,
3490 )?;
3491 }
3492 Ok((u_t, u_beta))
3493 }
3494
3495 /// Bare bordered-arrow inverse solve against the cached per-row factors and
3496 /// Schur factor (the NO-SELF base `H₀'` when a cross-row Woodbury is
3497 /// present). [`Self::full_inverse_apply`] wraps this with the rank-`R`
3498 /// correction; [`CrossRowWoodbury::build`] calls this directly (before the
3499 /// carrier exists) to form `H₀'⁻¹U`.
3500 pub(crate) fn full_inverse_apply_base(
3501 &self,
3502 w_t: ArrayView1<'_, f64>,
3503 w_beta: ArrayView1<'_, f64>,
3504 ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
3505 let total_len = self.delta_t_len();
3506 if w_t.len() != total_len || w_beta.len() != self.k {
3507 return Err(ArrowSchurError::SchurFactorFailed {
3508 reason: format!(
3509 "full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
3510 w_t.len(),
3511 w_beta.len(),
3512 total_len,
3513 self.k
3514 ),
3515 });
3516 }
3517 let n = self.undamped_factor_count();
3518 // Forward pass: y_i = H_tt^(i)⁻¹ w_t^(i), accumulating the border RHS.
3519 let mut y = Array1::<f64>::zeros(total_len);
3520 let mut r_beta = w_beta.to_owned();
3521 for i in 0..n {
3522 let di = self.row_dims[i];
3523 let base = self.row_offsets[i];
3524 let factor = self.undamped_factor(i);
3525 let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
3526 let y_row = cholesky_solve_vector(factor, &w_row);
3527 if self.k > 0 {
3528 // r_β −= H_βt^(i) y_i: accumulate into a scratch then subtract,
3529 // because the helper ACCUMULATES (+=) into its output.
3530 let mut acc = Array1::<f64>::zeros(self.k);
3531 if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
3532 return Err(ArrowSchurError::SchurFactorFailed {
3533 reason: format!(
3534 "full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
3535 could not supply row {i}; htbeta={:?}, di={}, k={})",
3536 self.htbeta,
3537 self.row_dims.get(i).copied().unwrap_or(self.d),
3538 self.k
3539 ),
3540 });
3541 }
3542 for c in 0..self.k {
3543 r_beta[c] -= acc[c];
3544 }
3545 }
3546 for j in 0..di {
3547 y[base + j] = y_row[j];
3548 }
3549 }
3550 // Border solve + back-substitution.
3551 let u_beta = if self.k > 0 {
3552 self.schur_inverse_apply(r_beta.view())?
3553 } else {
3554 Array1::<f64>::zeros(0)
3555 };
3556 let mut u_t = y;
3557 if self.k > 0 {
3558 let mut cross = Array1::<f64>::zeros(self.d);
3559 for i in 0..n {
3560 let di = self.row_dims[i];
3561 let base = self.row_offsets[i];
3562 let mut cross_row = cross.slice_mut(ndarray::s![..di]);
3563 cross_row.fill(0.0);
3564 let mut cross_owned = cross_row.to_owned();
3565 if !self.apply_htbeta_row(i, u_beta.view(), &mut cross_owned) {
3566 return Err(ArrowSchurError::SchurFactorFailed {
3567 reason: format!(
3568 "full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
3569 could not supply row {i})"
3570 ),
3571 });
3572 }
3573 let factor = self.undamped_factor(i);
3574 let corr = cholesky_solve_vector(factor, &cross_owned);
3575 for j in 0..di {
3576 u_t[base + j] -= corr[j];
3577 }
3578 }
3579 }
3580 Ok((u_t, u_beta))
3581 }
3582
3583 /// Apply the β-block of the full inverse, `(H⁻¹)_ββ · rhs = S_β⁻¹ · rhs`,
3584 /// where `S_β` is the Schur complement on β whose Cholesky factor this
3585 /// cache holds in [`Self::schur_factor`].
3586 ///
3587 /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the
3588 /// β-block of `H⁻¹` is exactly the inverse of the Schur complement
3589 /// `S_β = H_ββ − Bᵀ A⁻¹ B`. One Cholesky back-substitution per call,
3590 /// reusing the cached factor; `rhs` and the returned vector both have
3591 /// length `K`.
3592 ///
3593 /// This is the general single-solve primitive for the β border. Callers
3594 /// that need a Schur-inverse trace `tr(S_β⁻¹ M)` against a structured
3595 /// penalty `M` (e.g. the SAE λ_smooth Fellner-Schall step, where
3596 /// `M = blockdiag_k(λ_k S_k ⊗ I_p)`) build it as
3597 /// `Σ_col e_colᵀ S_β⁻¹ M e_col` — apply this to each column of `M`
3598 /// (exploiting whatever sparsity `M` has) and read off `result[col]`.
3599 /// Keeping `M`'s layout on the caller side avoids coupling this solver
3600 /// to penalty-op types.
3601 ///
3602 /// # Errors
3603 ///
3604 /// Returns [`ArrowSchurError::SchurFactorFailed`] when this cache has no
3605 /// dense Schur factor (an [`ArrowSolverMode::InexactPCG`] solve) — the
3606 /// same not-yet-supported branch as [`Self::latent_block_inverse_diagonal`]
3607 /// — or when `rhs.len() != k`.
3608 ///
3609 /// Cross-row IBP (#1038) note: this is the β-block primitive of the
3610 /// factored base `S_β` (`H₀'` when a [`CrossRowWoodbury`] is present), used
3611 /// internally by [`Self::full_inverse_apply_base`]; it is deliberately NOT
3612 /// Woodbury-corrected so the base solve stays bare. The cross-row term has
3613 /// no `β` support, so `(H_full⁻¹)_ββ = S_β⁻¹` exactly on the directions any
3614 /// IBP ρ-trace contracts. A consumer needing the full `(H_full⁻¹)_ββ` for a
3615 /// β-supported direction should call [`Self::full_inverse_apply`] with a
3616 /// unit `β`-RHS (which applies the rank-`R` correction).
3617 pub fn schur_inverse_apply(
3618 &self,
3619 rhs: ArrayView1<'_, f64>,
3620 ) -> Result<Array1<f64>, ArrowSchurError> {
3621 let Some(schur_factor) = self.schur_factor.as_ref() else {
3622 return Err(ArrowSchurError::SchurFactorFailed {
3623 reason: "schur_inverse_apply requires a dense Schur factor; \
3624 the InexactPCG mode does not form one"
3625 .to_string(),
3626 });
3627 };
3628 if rhs.len() != self.k {
3629 return Err(ArrowSchurError::SchurFactorFailed {
3630 reason: format!(
3631 "schur_inverse_apply: rhs length {} != K {}",
3632 rhs.len(),
3633 self.k
3634 ),
3635 });
3636 }
3637 let rhs_owned = rhs.to_owned();
3638 Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
3639 }
3640
3641 /// Dense principal sub-block of the β-block of the full inverse,
3642 /// `(H⁻¹)_ββ[block, block] = S_β⁻¹[block, block]`, shape `(W, W)` with
3643 /// `W = block.len()`.
3644 ///
3645 /// For the bordered arrow Hessian `H = [[A, B], [Bᵀ, H_ββ]]`, the β-block
3646 /// of `H⁻¹` is exactly `S_β⁻¹` (the inverse of the Schur complement whose
3647 /// Cholesky factor this cache holds). This returns the contiguous
3648 /// `block × block` sub-block — e.g. one SAE atom's decoder coefficients via
3649 /// [`gam_terms::sae::manifold::SaeManifoldTerm::beta_block_offsets`] — by
3650 /// solving `S_β x = e_j` for each `j ∈ block` (reusing the cached factor)
3651 /// and gathering the `block` rows of each solution column. `W`
3652 /// back-substitutions of size `K`; the result is symmetrized to clear
3653 /// back-substitution rounding asymmetry. Up to a dispersion scale `φ`, this
3654 /// block is the joint posterior covariance `Cov(β_block)` of those
3655 /// coefficients with the latent coordinates already marginalized out (that
3656 /// is precisely what Schur-eliminating the per-row `t`-blocks does).
3657 ///
3658 /// Same dense-Schur requirement / error contract as
3659 /// [`Self::schur_inverse_apply`]; additionally errors when `block` runs past
3660 /// `K`.
3661 pub fn schur_inverse_block(
3662 &self,
3663 block: std::ops::Range<usize>,
3664 ) -> Result<Array2<f64>, ArrowSchurError> {
3665 let Some(schur_factor) = self.schur_factor.as_ref() else {
3666 return Err(ArrowSchurError::SchurFactorFailed {
3667 reason: "schur_inverse_block requires a dense Schur factor; \
3668 the InexactPCG mode does not form one"
3669 .to_string(),
3670 });
3671 };
3672 if block.end > self.k {
3673 return Err(ArrowSchurError::SchurFactorFailed {
3674 reason: format!(
3675 "schur_inverse_block: block end {} exceeds K {}",
3676 block.end, self.k
3677 ),
3678 });
3679 }
3680 let w = block.len();
3681 let mut out = Array2::<f64>::zeros((w, w));
3682 let mut e_j = Array1::<f64>::zeros(self.k);
3683 for (jc, j) in block.clone().enumerate() {
3684 e_j.fill(0.0);
3685 e_j[j] = 1.0;
3686 let col = cholesky_solve_vector(schur_factor, &e_j);
3687 for (ic, i) in block.clone().enumerate() {
3688 out[[ic, jc]] = col[i];
3689 }
3690 }
3691 // S_β⁻¹ is symmetric; symmetrize to clear back-substitution rounding.
3692 for ic in 0..w {
3693 for jc in (ic + 1)..w {
3694 let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
3695 out[[ic, jc]] = avg;
3696 out[[jc, ic]] = avg;
3697 }
3698 }
3699 Ok(out)
3700 }
3701}