

Struct ArrowFactorCache 

Source
pub struct ArrowFactorCache {
    pub htt_factors: ArrowFactorSlab,
    pub htt_factors_undamped: ArrowUndampedFactors,
    pub schur_factor: Option<Array2<f64>>,
    pub joint_hessian_log_det: Option<f64>,
    pub solver_mode: ArrowSolverMode,
    pub ridge_t: f64,
    pub ridge_beta: f64,
    pub htbeta: ArrowHtbetaCache,
    pub d: usize,
    pub row_dims: Arc<[usize]>,
    pub row_offsets: Arc<[usize]>,
    pub k: usize,
    pub manifold_mode_fingerprint: u64,
    pub row_hessian_fingerprint: u64,
    pub pcg_diagnostics: PcgDiagnostics,
    pub gauge_deflated_directions: usize,
    pub deflated_row_directions: Arc<[Vec<Array1<f64>>]>,
    pub deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>,
    pub cross_row_woodbury: Option<CrossRowWoodbury>,
}

Fields§

§htt_factors: ArrowFactorSlab

Per-row lower-triangular Cholesky factors of H_tt^(i) + ridge_t·I.

These are the damped factors used inside the Newton solve. The IFT predictor must NOT use them — see Self::htt_factors_undamped.

§htt_factors_undamped: ArrowUndampedFactors

Per-row lower-triangular Cholesky factors of the UNDAMPED H_tt^(i) (no ridge_t added).

The IFT predictor formula Δt_i = -(H_tt^(i))⁻¹ · (H_tβ^(i) Δβ + δg_t^(i)) is derived from ∂g_t/∂t = H_tt at the stationary point, with no LM damping term. Reusing the damped factors would bias the predicted shift toward zero in proportion to ridge_t. We pay one extra O(N d³) Cholesky per Newton solve — the same complexity class as the Newton solve itself — to make the IFT exact.
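To make the bias concrete, here is a minimal sketch for a single row with a scalar latent block (d = 1). The function name `ift_shift_scalar` and its argument layout are illustrative assumptions and not part of this crate's API; passing `ridge = 0.0` corresponds to the undamped factor.

```rust
/// IFT shift for one row whose latent block is a scalar (d = 1):
/// delta_t = -(h_tt + ridge)^-1 * (h_tbeta . delta_beta + dg_t).
/// Hypothetical helper for illustration only.
fn ift_shift_scalar(
    h_tt: f64,
    ridge: f64,
    h_tbeta: &[f64],
    delta_beta: &[f64],
    dg_t: f64,
) -> f64 {
    // Cross-block contribution H_tbeta * delta_beta (a dot product when d = 1).
    let coupling: f64 = h_tbeta.iter().zip(delta_beta).map(|(h, b)| h * b).sum();
    -(coupling + dg_t) / (h_tt + ridge)
}
```

With `h_tt = 2.0` and `ridge = 0.5`, the damped shift is the exact shift scaled by `h_tt / (h_tt + ridge) = 0.8`, which is the shrinkage toward zero described above.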

§schur_factor: Option<Array2<f64>>

Lower-triangular Cholesky factor of the Schur complement when the selected BA mode formed/factored dense RCS. None for ArrowSolverMode::InexactPCG, where Agarwal-style inexact LM avoids the dense K × K factor.

§joint_hessian_log_det: Option<f64>

Exact undamped joint-Hessian log-determinant produced by the dense factorization path. REML evidence consumes this directly so the Laplace normalizer cannot miss the log-det even when later cache consumers only need solves/traces.

§solver_mode: ArrowSolverMode

BA mode used to create this cache.

§ridge_t: f64

Ridge values used to build the cached factors (recorded so the warm-start predictor knows whether the cache is still valid for a requested ridge level).

§ridge_beta: f64

§htbeta: ArrowHtbetaCache

Per-row cross-block access for H_tβ^(i) x.

Large caches retain a row matvec callback or disable β-coupled IFT prediction instead of cloning every dense d × K slab.

§d: usize

Maximum per-row latent dim (upper bound; matches sys.d at creation).

§row_dims: Arc<[usize]>

Per-row latent dims: row_dims[i] is the active dim for row i.

§row_offsets: Arc<[usize]>

Flat-buffer row offsets for delta_t / IFT output vectors. row_offsets[i] is the start of row i; row_offsets[n] is the total length.
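The offset convention can be sketched as a prefix sum over row_dims. Both helper names below are hypothetical and only illustrate the layout; they are not functions of this crate.

```rust
/// Build flat-buffer offsets from per-row dims: offsets[i] is the start of
/// row i and offsets[n] is the total length. Hypothetical helper.
fn row_offsets_from_dims(row_dims: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(row_dims.len() + 1);
    let mut acc = 0usize;
    offsets.push(acc);
    for &d in row_dims {
        acc += d;
        offsets.push(acc);
    }
    offsets
}

/// Borrow row i's segment of a flat delta_t-style buffer.
fn row_slice<'a>(flat: &'a [f64], offsets: &[usize], i: usize) -> &'a [f64] {
    &flat[offsets[i]..offsets[i + 1]]
}
```

For row dims 2, 3, 1 this yields offsets 0, 2, 5, 6, so the last entry doubles as the buffer length.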

§k: usize

β dimensionality K.

§manifold_mode_fingerprint: u64

Geometry tag for the row-local factors and cross-blocks.

§row_hessian_fingerprint: u64

Row-system tag for the cached per-row factors, cross-blocks, and shared-block diagonal used to build the Schur factor.

§pcg_diagnostics: PcgDiagnostics

PCG instrumentation from the solve that produced this cache.

Zero-valued (default) when the selected mode did not use PCG (i.e. Direct or SqrtBA).

§gauge_deflated_directions: usize

Number of row-local gauge directions stiffened in an undamped evidence factorization.

Each direction is stiffened at UNIT stiffness kappa = 1.0, so it contributes log(1) = 0 to the row-block logdet through the returned Cholesky factor: the gauge orbit is a criterion null direction and adds nothing to the Laplace normalizer (the quotient pseudo-determinant convention, cf. PenaltyPseudologdet). Zero theta/rho dependence.

§deflated_row_directions: Arc<[Vec<Array1<f64>>]>

Per-row unit-norm directions vᵢ (in each row’s d-dim latent block coordinates) that an undamped evidence factorization stiffened to UNIT stiffness λ̃ = 1 (gauge or spectral deflation). Indexed by row; empty for every PD row factored without deflation, and empty overall on the non-deflating solver paths (streaming / cross-row-penalty CG / device).

A deflated direction contributes log(1) = 0 to the row-block log-det and is ρ/θ-INDEPENDENT, so its true contribution to ∂log|H|/∂ρ is 0. The analytic outer-gradient traces (assignment_log_strength_hessian_trace, learnable_ibp_data_logdet_alpha_trace, logdet_theta_adjoint) contract ∂H_raw/∂ρ (the RAW, pre-deflation block derivative) against the DEFLATED inverse, which assigns 1/λ̃ = 1 to each vᵢ and therefore spuriously adds ½ vᵢᵀ (∂H_raw/∂ρ) vᵢ. Those traces subtract this per-row term (kept-subspace restriction) using these directions; without them the REML outer ρ-gradient is biased by +Σ_deflated ½ vᵢᵀ ∂H_raw/∂ρ vᵢ.

§deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>

Per-row RAW spectral decomposition of an undamped evidence H_tt block that underwent SPECTRAL deflation, surfaced so the outer ρ/θ-gradient traces can apply the EXACT deflation-map (Daleckii–Krein) derivative correction, not just the within-row kept-subspace term.

The criterion VALUE re-deflates H_tt at every ρ, so its gradient is tr(H_deflated⁻¹ DΦ[∂H_raw/∂ρ]), where Φ is the spectral pin-to-unit map. By Daleckii–Krein, DΦ[Ȧ] = U (F ∘ UᵀȦU) Uᵀ with the divided-difference matrix F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ) (raw λ in the denominator, conditioned λ̃ in the numerator). The kept×kept block of F is 1 (the kept subspace contracts the raw derivative unchanged), the deflated×deflated block is 0, and the kept(m)×deflated(i) block is (λₘ − 1)/(λₘ − λᵢ). This last, ROTATION, term is what the per-row kept-subspace correction alone misses; it couples to the β-block through the Schur back-substitution carried in (H⁻¹)_tt.
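The three blocks of F described above can be tabulated directly from the raw eigenvalues and a deflation mask. This is a sketch under stated assumptions: the function name and the boolean-mask input are hypothetical, Φ is taken to pin deflated eigenvalues to 1 and leave kept ones unchanged, and every kept eigenvalue is assumed distinct from every deflated one.

```rust
/// Divided-difference matrix F for the pin-to-unit map:
/// kept x kept = 1, deflated x deflated = 0,
/// mixed = (cond_m - cond_l) / (raw_m - raw_l).
/// Hypothetical helper; assumes kept and deflated raw eigenvalues differ.
fn divided_difference_matrix(raw: &[f64], deflated: &[bool]) -> Vec<Vec<f64>> {
    let n = raw.len();
    // Conditioned eigenvalue: 1 on deflated directions, raw value otherwise.
    let cond = |m: usize| if deflated[m] { 1.0 } else { raw[m] };
    let mut f = vec![vec![0.0; n]; n];
    for m in 0..n {
        for l in 0..n {
            f[m][l] = match (deflated[m], deflated[l]) {
                (false, false) => 1.0,
                (true, true) => 0.0,
                _ => (cond(m) - cond(l)) / (raw[m] - raw[l]),
            };
        }
    }
    f
}
```

For a kept eigenvalue λₘ = 3 and a deflated raw eigenvalue λᵢ = −0.5, the mixed entry is (3 − 1)/(3 + 0.5), matching the kept(m)×deflated(i) formula, and the matrix is symmetric.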

Some(spectrum) only for spectrally-deflated rows; None for PD rows, gauge-only deflation (ρ-independent structural null — within-row term suffices), and every non-SAE-evidence solver path (streaming / device / cross-row CG). Empty overall when no row deflated spectrally.

§cross_row_woodbury: Option<CrossRowWoodbury>

Exact cross-row IBP rank-R Woodbury correction (#1038), present iff the source system carried an IbpCrossRowSource. When set, the per-row factors above are of the NO-SELF base H₀' (self term d_k·z'_ik² downdated from each logit diagonal), and this carrier supplies the exact rank-R correction so the value/curvature solve (Self::full_inverse_apply), the evidence log-determinant (Self::arrow_log_det), and the θ/ρ-adjoint all describe the same H_full = H₀' + U D Uᵀ.

Implementations§

Source§

impl ArrowFactorCache

Source

pub fn n_rows(&self) -> usize

Source

pub fn htbeta_available(&self) -> bool

Source

pub fn used_device(&self) -> bool

Whether the Newton solve that produced this cache actually executed on the device: the device-resident Direct dense solve or the device-resident matrix-free SAE PCG (whose matvec runs in CUDA kernels). This does NOT include the injected host-procedural reduced-Schur matvec, whose arithmetic runs on the CPU even when a CUDA context was opened to build per-row factors (#1209) — that path sets PcgDiagnostics::injected_host_procedural_matvec instead. Read-only routing provenance: lets a fit result record device-vs-CPU as ground truth instead of inferring it from the runtime probe. Mirrors PcgDiagnostics::used_device_arrow.

Source

pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64>

Source

pub fn undamped_factor_count(&self) -> usize

Source

pub fn undamped_factors_iter( &self, ) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_

Source

pub fn compute_undamped_arrow_log_det(&self) -> Option<f64>

Source

pub fn delta_t_len(&self) -> usize

The total length of delta_t / IFT output vectors for this cache.

Source

pub fn apply_htbeta_row( &self, row: usize, delta_beta: ArrayView1<'_, f64>, out: &mut Array1<f64>, ) -> bool

Source

pub fn apply_htbeta_row_transpose( &self, row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>, fallback_op: Option<&RowHtbetaMatvec>, ) -> bool

Accumulate out[a] += H_βt^(row)[a, :] · v for all a in 0..k.

v has length row_dims[row]; out has length k. The caller must zero out before the first call if it needs a fresh result. Returns false when the cache is Disabled and no fallback_op is provided; callers must treat the accumulator as invalid in that case.
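The accumulate-into-out contract can be illustrated with a dense d × K slab. The slab layout, the function name, and the simplified return convention (false on a shape mismatch, standing in for the Disabled case) are assumptions made for this sketch only.

```rust
/// Accumulate out[a] += sum_j h_tbeta[j][a] * v[j], i.e. out += H_bt * v,
/// where `h_tbeta` is a hypothetical dense d x K slab for one row.
/// Returns false and leaves `out` untouched on a shape mismatch.
fn accumulate_htbeta_transpose(h_tbeta: &[Vec<f64>], v: &[f64], out: &mut [f64]) -> bool {
    if h_tbeta.len() != v.len() || h_tbeta.iter().any(|r| r.len() != out.len()) {
        return false;
    }
    for (row_j, &vj) in h_tbeta.iter().zip(v) {
        for (o, &h) in out.iter_mut().zip(row_j) {
            *o += h * vj; // accumulates; never overwrites
        }
    }
    true
}
```

Because the call adds into `out`, invoking it twice on the same buffer without zeroing doubles the result, which is why the caller owns the reset.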

Source

pub fn arrow_log_det(&self) -> (f64, Option<f64>)

Arrow log-determinant log|H| = Σ_i log|H_{t_i t_i}| + log|Schur_β| using the cached (damped) factors.

Returns (log_det_tt_sum, log_det_schur) so the caller can decide what to do with the Schur piece (e.g. REML evidence wants both; some diagnostics want only the per-row sum). None for the Schur piece signals that the cache was produced by an InexactPCG solve and never formed/factored the dense K × K reduced system.

The log-determinant of a Cholesky factor L of M is 2 Σ log L_ii.
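The identity follows from det M = det(L)·det(Lᵀ) = (Π L_ii)². A minimal sketch in plain Rust, using nested `Vec`s instead of the crate's array types; both function names are hypothetical.

```rust
/// Lower-triangular Cholesky factor of a symmetric matrix, or None if a
/// non-positive pivot shows the matrix is not positive definite.
fn cholesky_lower(m: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = m.len();
    let mut l = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let mut s = m[i][j];
            for p in 0..j {
                s -= l[i][p] * l[j][p];
            }
            if i == j {
                if s <= 0.0 {
                    return None;
                }
                l[i][j] = s.sqrt();
            } else {
                l[i][j] = s / l[j][j];
            }
        }
    }
    Some(l)
}

/// log|M| = 2 * sum_i log L_ii for M = L * L^T.
fn log_det_from_cholesky(l: &[Vec<f64>]) -> f64 {
    2.0 * (0..l.len()).map(|i| l[i][i].ln()).sum::<f64>()
}
```

For M = [[4, 2], [2, 3]] the factor has diagonal (2, √2), so the result is 2·(ln 2 + ½ ln 2) = ln 8, the log of det M = 8.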

Source

pub fn cross_row_woodbury_log_det(&self) -> f64

The exact cross-row IBP correction log det(I_R + D·M) to add to the base log det H₀' (#1038). Zero when no CrossRowWoodbury is present, so non-IBP caches are unaffected. The determinant lemma gives log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U); this is the second term, the only piece beyond the bare arrow log-determinant.

Panic-free: a negative capacitance determinant (non-PD H_full) yields NaN here so the evidence surfaces the desync rather than silently dropping the imaginary part. Callers that must reject it should check CrossRowWoodbury::log_det directly.
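A small numeric check of the determinant lemma, restricted to the simplest case: a diagonal base H₀' and a rank-1 correction (R = 1). Both restrictions and the function name are assumptions made to keep the sketch short; the cached carrier handles the general rank-R case.

```rust
/// Correction term log det(1 + d * u^T H0^-1 u) for a diagonal base H0 and
/// a rank-1 update H_full = H0 + d * u * u^T. Hypothetical helper.
/// A negative capacitance makes `ln` return NaN instead of panicking.
fn rank1_woodbury_log_det(h0_diag: &[f64], u: &[f64], d: f64) -> f64 {
    // m = u^T H0^-1 u, cheap because H0 is diagonal in this sketch.
    let m: f64 = h0_diag.iter().zip(u).map(|(h, ui)| ui * ui / h).sum();
    (1.0 + d * m).ln()
}
```

With H₀' = diag(2, 4), u = (1, 2) and d = 0.5, the full matrix [[2.5, 1], [1, 6]] has determinant 14 = 8 · 1.75, so the base log-determinant ln 8 plus this correction ln 1.75 recovers ln 14.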

Source

pub fn latent_block_inverse_diagonal( &self, ) -> Result<Array1<f64>, ArrowSchurError>

Diagonal of the latent (t-block) of the full bordered-arrow inverse (H⁻¹)_tt, in delta_t layout (length Self::delta_t_len).

For the bordered arrow Hessian H = [[A, B], [Bᵀ, H_ββ]] with A = H_tt (block-diagonal per row, A_i = H_tt^(i)) and B = H_tβ, the standard block-inverse identity gives the t-block (H⁻¹)_tt = A⁻¹ + A⁻¹ B S⁻¹ Bᵀ A⁻¹, where S = H_ββ − Bᵀ A⁻¹ B is the Schur complement on β. Because A is block-diagonal, the diagonal entry of (H⁻¹)_tt for coordinate j of row i is computed purely from row i’s factor and cross-block:

a    = A_i⁻¹ e_j                       (chol_solve on the per-row factor)
[A_i⁻¹]_{jj} = a[j]
w    = B_iᵀ a = H_βt^(i) a             (a K-vector)
z    = S⁻¹ w                           (chol_solve on the Schur factor)
diag = a[j] + w · z

The UNDAMPED per-row factors (Self::undamped_factor) are used so the result is the inverse of the true H_tt, not the LM-damped H_tt + ridge_t·I — same rationale the IFT predictor docstring gives at the top of this struct.
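The recipe above can be traced end to end in the smallest non-trivial setting: every row block is a scalar (d = 1) and β is a scalar (K = 1), so each chol_solve reduces to a division. The function name and this scalar restriction are assumptions for illustration.

```rust
/// Diagonal of (H^-1)_tt for scalar row blocks a[i] = H_tt^(i), couplings
/// b[i] = H_tb^(i), and a scalar beta block c = H_bb. Hypothetical helper.
/// Returns None if a block or the Schur complement is not positive.
fn latent_inverse_diagonal_scalar(a: &[f64], b: &[f64], c: f64) -> Option<Vec<f64>> {
    if a.len() != b.len() || a.iter().any(|&ai| !(ai > 0.0)) {
        return None;
    }
    // Schur complement on beta: S = c - sum_i b_i^2 / a_i.
    let s = c - a.iter().zip(b).map(|(ai, bi)| bi * bi / ai).sum::<f64>();
    if !(s > 0.0) {
        return None;
    }
    let diag = a
        .iter()
        .zip(b)
        .map(|(ai, bi)| {
            let a_inv = 1.0 / ai; // a = A_i^-1 e_j
            let w = bi * a_inv; // w = B_i^T a
            let z = w / s; // z = S^-1 w
            a_inv + w * z // diag = a[j] + w . z
        })
        .collect();
    Some(diag)
}
```

For H = [[2, 0, 1], [0, 4, 2], [1, 2, 3]] this gives 8/12 and 5/12, which agree with the first two diagonal entries of the explicit 3 × 3 inverse.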

§Consuming the diagonal as a per-(atom, axis) trace

(H⁻¹)_tt is the latent covariance block. The selected-inverse trace for a contiguous group of latent coordinates (e.g. one atom’s rows, or one axis across rows) is simply the sum of the returned diagonal entries over those row_offsets[i] + j indices — no off-diagonal terms are needed for the trace tr[(H⁻¹)_tt · D] against a diagonal selector D.
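As a sketch of the "one axis across rows" case, the helper below sums the diagonal entries at `row_offsets[i] + axis` for every row that actually has that axis. The function name is hypothetical; `diag` stands for the vector this method returns.

```rust
/// Sum the selected-inverse diagonal over one latent axis across all rows.
/// Rows whose active dim is <= axis are skipped. Hypothetical helper.
fn axis_trace(diag: &[f64], row_offsets: &[usize], axis: usize) -> f64 {
    row_offsets
        .windows(2)
        // w[0] is the row start, w[1] the next row's start.
        .filter(|w| w[0] + axis < w[1])
        .map(|w| diag[w[0] + axis])
        .sum()
}
```

With row dims 2, 3, 1 (offsets 0, 2, 5, 6), axis 0 touches all three rows, axis 1 only the first two, and axis 2 only the middle row.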

§Errors

Returns ArrowSchurError::SchurFactorFailed when this cache has no dense Schur factor or no usable H_βt coupling, i.e. it was produced by an ArrowSolverMode::InexactPCG solve (no dense K × K factor) or by a Disabled htbeta cache. The selected-inverse block-trace is not yet supported for the matrix-free PCG mode; that branch needs a separate Lanczos/Hutchinson estimator.

Source

pub fn full_inverse_apply( &self, w_t: ArrayView1<'_, f64>, w_beta: ArrayView1<'_, f64>, ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError>

Solve the full bordered-arrow system H·u = w on the cached factor (#1006): w arrives in arrow layout — w_t flat per Self::delta_t_len / row_offsets, w_beta of length K — and the solution comes back in the same layout. Standard block elimination on the SAME factors whose log-determinant the evidence reports:

  y_i      = H_tt^(i)⁻¹ · w_t^(i)
  r_β      = w_β − Σ_i H_βt^(i) · y_i
  u_β      = Schur⁻¹ · r_β
  u_t^(i)  = y_i − H_tt^(i)⁻¹ · (H_tβ^(i) · u_β)

This is the IFT / adjoint back-solve the analytic outer ρ-gradient consumes: u_j = H⁻¹ (∂g/∂ρ_j) per outer coordinate and the H⁻¹-side of the third-order correction −½·Γᵀ·H⁻¹·(∂g/∂ρ_j). Contract: the cache must be the ridge-0 Direct evidence factor (undamped per-row factors + dense Schur), so the solve is against the criterion’s own H — never a damped surrogate (that would desync the gradient from the reported evidence).

When the cache carries an exact cross-row IBP CrossRowWoodbury (#1038), the per-row factors are the NO-SELF base H₀' and this method layers the rank-R Woodbury correction so the returned solve is against the FULL H_full = H₀' + U D Uᵀ — the same operator whose log-determinant Self::arrow_log_det reports. The θ/ρ-adjoint that consumes this therefore sees the cross-row curvature.
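The four elimination steps listed above can be followed in a scalar setting: each row block has d = 1 and β has K = 1, with no cross-row Woodbury term. The function name and these restrictions are assumptions for this sketch; the method itself works on the cached Cholesky factors.

```rust
/// Solve H u = w for the bordered arrow H with scalar row blocks a[i],
/// couplings b[i], and scalar beta block c. Hypothetical helper.
/// Returns (u_t, u_beta), or None on a shape mismatch or non-positive Schur.
fn full_inverse_apply_scalar(
    a: &[f64],
    b: &[f64],
    c: f64,
    w_t: &[f64],
    w_beta: f64,
) -> Option<(Vec<f64>, f64)> {
    if a.len() != b.len() || a.len() != w_t.len() {
        return None;
    }
    let s = c - a.iter().zip(b).map(|(ai, bi)| bi * bi / ai).sum::<f64>();
    if !(s > 0.0) {
        return None;
    }
    // y_i = H_tt^(i)^-1 w_t^(i)
    let y: Vec<f64> = w_t.iter().zip(a).map(|(w, ai)| w / ai).collect();
    // r_beta = w_beta - sum_i H_bt^(i) y_i
    let r_beta = w_beta - b.iter().zip(&y).map(|(bi, yi)| bi * yi).sum::<f64>();
    // u_beta = Schur^-1 r_beta
    let u_beta = r_beta / s;
    // u_t^(i) = y_i - H_tt^(i)^-1 (H_tb^(i) u_beta)
    let u_t: Vec<f64> = y
        .iter()
        .zip(a.iter().zip(b))
        .map(|(yi, (ai, bi))| yi - bi * u_beta / ai)
        .collect();
    Some((u_t, u_beta))
}
```

A quick sanity check for any such solve is to multiply back: a[i]·u_t[i] + b[i]·u_β should reproduce w_t[i], and Σ b[i]·u_t[i] + c·u_β should reproduce w_β.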

Source

pub fn schur_inverse_apply( &self, rhs: ArrayView1<'_, f64>, ) -> Result<Array1<f64>, ArrowSchurError>

Apply the β-block of the full inverse, (H⁻¹)_ββ · rhs = S_β⁻¹ · rhs, where S_β is the Schur complement on β whose Cholesky factor this cache holds in Self::schur_factor.

For the bordered arrow Hessian H = [[A, B], [Bᵀ, H_ββ]], the β-block of H⁻¹ is exactly the inverse of the Schur complement S_β = H_ββ − Bᵀ A⁻¹ B. One Cholesky back-substitution per call, reusing the cached factor; rhs and the returned vector both have length K.

This is the general single-solve primitive for the β border. Callers that need a Schur-inverse trace tr(S_β⁻¹ M) against a structured penalty M (e.g. the SAE λ_smooth Fellner-Schall step, where M = blockdiag_k(λ_k S_k ⊗ I_p)) build it as Σ_col e_colᵀ S_β⁻¹ M e_col — apply this to each column of M (exploiting whatever sparsity M has) and read off result[col]. Keeping M’s layout on the caller side avoids coupling this solver to penalty-op types.
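The column-by-column trace recipe can be sketched as follows, with a plain forward/back substitution standing in for schur_inverse_apply. The function names and the nested-`Vec` dense layout for M are assumptions; a real caller would skip the zero columns of a sparse M.

```rust
/// Solve (L L^T) x = b given the lower-triangular Cholesky factor L.
fn chol_solve(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = l.len();
    // Forward substitution: L y = b.
    let mut y = vec![0.0; n];
    for i in 0..n {
        let mut s = b[i];
        for p in 0..i {
            s -= l[i][p] * y[p];
        }
        y[i] = s / l[i][i];
    }
    // Back substitution: L^T x = y.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let mut s = y[i];
        for p in (i + 1)..n {
            s -= l[p][i] * x[p];
        }
        x[i] = s / l[i][i];
    }
    x
}

/// tr(S^-1 M) = sum_col (S^-1 M e_col)[col]: one solve per column of M,
/// keeping only the matching entry of each solution. Hypothetical helper.
fn schur_inverse_trace(l: &[Vec<f64>], m: &[Vec<f64>]) -> f64 {
    let k = l.len();
    (0..k)
        .map(|col| {
            let m_col: Vec<f64> = (0..k).map(|r| m[r][col]).collect();
            chol_solve(l, &m_col)[col]
        })
        .sum()
}
```

For S = [[4, 2], [2, 3]] (factor diagonal 2, √2) and M = [[1, 1], [1, 2]], the two solves contribute 1/8 and 6/8, giving a trace of 7/8.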

§Errors

Returns ArrowSchurError::SchurFactorFailed when this cache has no dense Schur factor (an ArrowSolverMode::InexactPCG solve, the same not-yet-supported branch as Self::latent_block_inverse_diagonal) or when rhs.len() != k.

Cross-row IBP (#1038) note: this is the β-block primitive of the factored base S_β (H₀' when a CrossRowWoodbury is present), used internally by Self::full_inverse_apply_base; it is deliberately NOT Woodbury-corrected so the base solve stays bare. The cross-row term has no β support, so (H_full⁻¹)_ββ = S_β⁻¹ exactly on the directions any IBP ρ-trace contracts. A consumer needing the full (H_full⁻¹)_ββ for a β-supported direction should call Self::full_inverse_apply with a unit β-RHS (which applies the rank-R correction).

Source

pub fn schur_inverse_block( &self, block: Range<usize>, ) -> Result<Array2<f64>, ArrowSchurError>

Dense principal sub-block of the β-block of the full inverse, (H⁻¹)_ββ[block, block] = S_β⁻¹[block, block], shape (W, W) with W = block.len().

For the bordered arrow Hessian H = [[A, B], [Bᵀ, H_ββ]], the β-block of H⁻¹ is exactly S_β⁻¹ (the inverse of the Schur complement whose Cholesky factor this cache holds). This returns the contiguous block × block sub-block (e.g. one SAE atom’s decoder coefficients via gam_terms::sae::manifold::SaeManifoldTerm::beta_block_offsets) by solving S_β x = e_j for each j ∈ block (reusing the cached factor) and gathering the block rows of each solution column. This costs W back-substitutions of size K; the result is symmetrized to clear back-substitution rounding asymmetry. Up to a dispersion scale φ, this block is the joint posterior covariance Cov(β_block) of those coefficients with the latent coordinates already marginalized out (that is precisely what Schur-eliminating the per-row t-blocks does).

Same dense-Schur requirement / error contract as Self::schur_inverse_apply; additionally errors when block runs past K.

Trait Implementations§

Source§

impl Clone for ArrowFactorCache

Source§

fn clone(&self) -> ArrowFactorCache

Returns a duplicate of the value.
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.
Source§

impl Debug for ArrowFactorCache

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> ByRef<T> for T

Source§

fn by_ref(&self) -> &T

Source§

impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
where ST: ?Sized, DT: ?Sized,

Source§

impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
where ST: ?Sized, DT: ?Sized,

Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.
Source§

impl<T> DistributionExt for T
where T: ?Sized,

Source§

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Imply<T> for U
where T: ?Sized, U: ?Sized,

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.
Source§

impl<T> Read<Exclusive, BecauseExclusive> for T
where T: ?Sized,

Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<SS, SP> SupersetOf<SS> for SP
where SS: SubsetOf<SP>,

Source§

fn to_subset(&self) -> Option<SS>

The inverse inclusion map: attempts to construct self from the equivalent element of its superset.
Source§

fn is_in_subset(&self) -> bool

Checks if self is actually part of its subset T (and can be converted to it).
Source§

fn to_subset_unchecked(&self) -> SS

Use with care! Same as self.to_subset but without any property checks. Always succeeds.
Source§

fn from_subset(element: &SS) -> SP

The inclusion map: converts self to the equivalent element of its superset.
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V