1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
//! ARD (automatic relevance determination) coordinate-precision + latent-block
//! helpers for `SaeManifoldTerm`, moved verbatim out of construction.rs to keep it
//! under the 10k-line ban gate. Pure code move, no logic change.
use super::*;
impl SaeManifoldTerm {
/// Data-fit sufficient statistic of the ARD precision update: for every atom
/// `k` and latent axis `j`, the accumulated squared coordinate
/// `‖t_kj‖² = Σ_i t_{i,k,j}²`.
///
/// The result holds one `Array1` per atom (in `assignment.coords` order), each
/// of length `d_k = latent_dim()`.
///
/// The per-observation contribution is the `sq_equiv` field reported by
/// `ArdAxisPrior::eval`, not a hand-rolled `t²`. On a *periodic* (Circle) axis
/// that is the von-Mises energy-equivalent `(2/κ²)(1 − cos κ t_i)`, chosen so
/// that `½·α·sumsq == Σ_i V(t_i)` agrees with `ard_value`. The fixed point
/// `α ← n / (sumsq + tr H⁻¹)` is therefore consistent with the actual periodic
/// prior energy instead of the origin-dependent raw `t²`.
pub(crate) fn ard_coord_sumsq(&self) -> Vec<Array1<f64>> {
    let mut per_atom = Vec::with_capacity(self.k_atoms());
    per_atom.extend(self.assignment.coords.iter().map(|coord| {
        let periods = coord.effective_axis_periods();
        let mut acc = Array1::<f64>::zeros(coord.latent_dim());
        // Rows outer, axes inner: each axis sums its rows in ascending order.
        for row in 0..coord.n_obs() {
            let t = coord.row(row);
            for (axis, slot) in acc.iter_mut().enumerate() {
                // `sq_equiv` does not depend on the precision, so a unit
                // `alpha` is passed as a placeholder.
                *slot += ArdAxisPrior::eval(1.0, t[axis], periods[axis]).sq_equiv;
            }
        }
        acc
    }));
    per_atom
}
/// Posterior-variance trace per atom and axis:
/// `tr_kj(H⁻¹) = Σ_i [(H⁻¹)_tt]_{(i,k,j),(i,k,j)}`, read off the converged
/// factor cache.
///
/// The latent inverse diagonal is laid out in the cache's compact per-row
/// `delta_t` order (row `i` starts at `cache.row_offsets[i]`); within a row
/// the logit scalars come first, followed by the coordinate axes of each
/// active atom. This routine adds up the diagonal entries that belong to each
/// `(atom k, axis j)` over all observation rows.
///
/// `self.last_row_layout` has to be the layout of the *same* assemble that
/// produced `cache`:
/// - `Some(layout)` — compact active-set mode. For row `i` the atom found at
///   position `pos` of `layout.active_atoms[i]` has its coordinate block at
///   `layout.coord_starts[i][pos]`; atoms not listed for a row add nothing.
/// - `None` — dense full-support layout. Atom `k`'s block sits at the fixed
///   in-row offset `coord_offsets[k]` for every row.
///
/// The returned traces are the posterior-variance term of the
/// Mackay/Fellner–Schall fixed point `α_new = n / (‖t_kj‖² + tr_kj(H⁻¹))`.
///
/// When `k_atoms() >= ARD_TRACE_HUTCHINSON_MIN_ATOMS` the exact diagonal
/// (`cache.latent_block_inverse_diagonal()`) is replaced by the matrix-free
/// estimate [`Self::latent_block_inverse_diagonal_hutchinson`]; below that
/// threshold the exact diagonal is used.
///
/// # Errors
/// Forwards any `ArrowSchurError` raised while obtaining the inverse diagonal.
pub(crate) fn ard_inverse_traces(
    &self,
    cache: &ArrowFactorCache,
) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
    let exact_path = self.k_atoms() < Self::ARD_TRACE_HUTCHINSON_MIN_ATOMS;
    let inv_diag = if exact_path {
        cache.latent_block_inverse_diagonal()?
    } else {
        // Massive-K regime: one full-arrow solve per probe instead of one
        // dense Schur solve per latent coordinate. The grouped sums below
        // average over many entries, which damps the stochastic error.
        Self::latent_block_inverse_diagonal_hutchinson(
            cache,
            Self::ARD_TRACE_HUTCHINSON_PROBES,
            Self::ARD_TRACE_HUTCHINSON_SEED,
        )?
    };
    let n_rows = self.n_obs();
    let dense_offsets = self.assignment.coord_offsets();
    let mut traces: Vec<Array1<f64>> = Vec::new();
    for coord in self.assignment.coords.iter() {
        traces.push(Array1::<f64>::zeros(coord.latent_dim()));
    }
    // Shared by both layouts: fold one atom's coordinate block of one row
    // into that atom's running per-axis trace.
    let mut add_block = |atom: usize, row_base: usize, block_start: usize| {
        let dim = self.assignment.coords[atom].latent_dim();
        for axis in 0..dim {
            traces[atom][axis] += inv_diag[row_base + block_start + axis];
        }
    };
    for row in 0..n_rows {
        let row_base = cache.row_offsets[row];
        if let Some(ref layout) = self.last_row_layout {
            let active = &layout.active_atoms[row];
            let starts = &layout.coord_starts[row];
            for (pos, &atom) in active.iter().enumerate() {
                add_block(atom, row_base, starts[pos]);
            }
        } else {
            for atom in 0..self.k_atoms() {
                add_block(atom, row_base, dense_offsets[atom]);
            }
        }
    }
    Ok(traces)
}
/// Atom-count threshold at/above which [`Self::ard_inverse_traces`] switches
/// from the exact selected-inverse latent diagonal (one dense `K×K` Schur
/// solve per latent coordinate — the `O(total_t·K²) ≈ O(K³)` massive-`K`
/// wall) to the matrix-free Hutchinson stochastic-diagonal estimator
/// [`Self::latent_block_inverse_diagonal_hutchinson`]. Set to match the
/// smoothness-dof Hutchinson gate ([`Self::SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS`]),
/// well above every exact-path test fixture so ordinary-`K` behaviour — and
/// its bit-for-bit tests — is unchanged; the estimator engages only in the
/// massive dictionary regime (`K` up to 32k).
///
/// The gate is inclusive: `ard_inverse_traces` compares with
/// `k_atoms() >= ARD_TRACE_HUTCHINSON_MIN_ATOMS`, so a dictionary of exactly
/// this many atoms already takes the estimator path.
pub(crate) const ARD_TRACE_HUTCHINSON_MIN_ATOMS: usize = 2048;
/// Rademacher probe count for the Hutchinson latent-inverse-diagonal
/// estimator. One [`ArrowFactorCache::full_inverse_apply`] per probe yields
/// the WHOLE diagonal at once, so this is the total full-arrow solve count
/// that replaces the exact `total_t` per-coordinate Schur solves.
///
/// This is the value [`Self::ard_inverse_traces`] passes as `num_probes`; the
/// estimator itself clamps a probe count of `0` up to `1` before averaging.
pub(crate) const ARD_TRACE_HUTCHINSON_PROBES: usize = 64;
/// Fixed base seed so the ARD-trace estimate is bit-reproducible across REML
/// outer iterations (cf. the SLQ log-det and smoothness-dof seeds).
///
/// The estimator derives the generator state of probe `p` as
/// `seed.wrapping_add(p)`, so this single constant fixes every probe vector.
pub(crate) const ARD_TRACE_HUTCHINSON_SEED: u64 = 0x5AED_A3D0_1ACE_9C01;
/// Stochastic (Hutchinson) estimate of `diag((H⁻¹)_tt)`: the quantity that
/// [`ArrowFactorCache::latent_block_inverse_diagonal`] computes exactly,
/// obtained here at `O(num_probes · matvec)` cost instead of the exact
/// `O(total_t · K²)`.
///
/// # Method
/// The exact path pays a dense `K×K` Schur solve for each of the
/// `total_t = Σ_i d_i` latent coordinates. Here each probe draws a Rademacher
/// vector `z` over the `t`-block (`E[z zᵀ] = I`) and performs ONE
/// [`ArrowFactorCache::full_inverse_apply`] on `[z; 0]`. Because the border
/// right-hand side is zero, the `t`-block of the solution is
/// `u_t = (H⁻¹)_tt z`. The Hadamard product `z ⊙ u_t` then has expectation
/// `diag((H⁻¹)_tt)` (the off-diagonal terms vanish since `E[z_i z_j] = 0` for
/// `i ≠ j`), and the probe average is an unbiased estimate of the diagonal.
/// The solve is against the same `H_full` the exact path inverts, cross-row
/// Woodbury correction included.
///
/// # Parameters
/// - `cache`: factor cache providing `delta_t_len()`, the border size `k`
///   and `full_inverse_apply`.
/// - `num_probes`: number of probes; `0` is treated as `1`.
/// - `seed`: base seed; probe `p` starts its generator at
///   `seed.wrapping_add(p)`.
///
/// # Determinism
/// Probes are processed one after another and accumulated in a fixed order,
/// so for a given `(seed, num_probes)` the output is bit-reproducible.
///
/// # Errors
/// Forwards the first error reported by `full_inverse_apply`.
pub(crate) fn latent_block_inverse_diagonal_hutchinson(
    cache: &ArrowFactorCache,
    num_probes: usize,
    seed: u64,
) -> Result<Array1<f64>, ArrowSchurError> {
    let t_len = cache.delta_t_len();
    let probe_count = num_probes.max(1);
    let border_rhs = Array1::<f64>::zeros(cache.k);
    let mut probe_vec = Array1::<f64>::zeros(t_len);
    let mut diag_acc = Array1::<f64>::zeros(t_len);
    for probe_idx in 0..probe_count {
        // Fill the probe with ±1 signs, consuming one generator word per 64
        // entries, least-significant bit first.
        let mut rng_state = seed.wrapping_add(probe_idx as u64);
        let mut word = 0u64;
        for (slot, entry) in probe_vec.iter_mut().enumerate() {
            if slot % 64 == 0 {
                word = gam_linalg::utils::splitmix64(&mut rng_state);
            }
            *entry = if word & 1 == 1 { 1.0 } else { -1.0 };
            word >>= 1;
        }
        // Zero border RHS ⇒ the returned t-block is exactly (H⁻¹)_tt z.
        let (solved_t, _solved_beta) =
            cache.full_inverse_apply(probe_vec.view(), border_rhs.view())?;
        for idx in 0..t_len {
            diag_acc[idx] += probe_vec[idx] * solved_t[idx];
        }
    }
    let scale = 1.0 / (probe_count as f64);
    diag_acc.iter_mut().for_each(|entry| *entry *= scale);
    Ok(diag_acc)
}
/// Explicit derivative of the ARD prior terms with respect to every ARD
/// log-precision `log α_kj`, evaluated at the current coordinates.
///
/// For each atom the result holds one `Array1` with the same length as
/// `rho.log_ard[atom]`; an atom whose `rho.log_ard` entry is empty yields an
/// empty array. Each axis entry is the sum of two parts:
/// - the energy part, `Σ_i ArdAxisPrior::eval(α, t_i, period).value` over the
///   atom's rows (NOTE(review): using `.value` itself as the log-α derivative
///   presumes the prior energy is linear in `α` — confirm against
///   `ArdAxisPrior`);
/// - the normalizer part, `-n/2` on a non-periodic axis and
///   `n·η·(I₁(η)/I₀(η) − 1)` with `η = α/κ²`, `κ = 2π/period` on a periodic
///   axis, where `n = self.n_obs()`.
///
/// # Errors
/// Returns a message when `rho.log_ard` does not have one entry per atom, or
/// when a non-empty entry's length differs from that atom's latent dimension.
pub(crate) fn ard_log_precision_explicit_derivatives(
    &self,
    rho: &SaeManifoldRho,
) -> Result<Vec<Array1<f64>>, String> {
    if rho.log_ard.len() != self.k_atoms() {
        return Err(format!(
            "ARD rho has {} atoms but term has {}",
            rho.log_ard.len(),
            self.k_atoms()
        ));
    }
    let n = self.n_obs() as f64;
    self.assignment
        .coords
        .iter()
        .enumerate()
        .map(|(atom_idx, coord)| -> Result<Array1<f64>, String> {
            let d = coord.latent_dim();
            let log_alphas = &rho.log_ard[atom_idx];
            let mut atom_out = Array1::<f64>::zeros(log_alphas.len());
            // An empty entry yields an empty result for this atom.
            if log_alphas.is_empty() {
                return Ok(atom_out);
            }
            if log_alphas.len() != d {
                return Err(format!(
                    "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
                    rho.log_ard[atom_idx].len()
                ));
            }
            let periods = coord.effective_axis_periods();
            for (axis, slot) in atom_out.iter_mut().enumerate() {
                let alpha = SaeManifoldRho::stable_exp_strength(log_alphas[axis]);
                let period = periods[axis];
                // Explicit left fold from +0.0 keeps the row-ascending
                // summation order (and the empty-sum value) fixed.
                let energy_deriv = (0..coord.n_obs()).fold(0.0_f64, |acc, row| {
                    acc + ArdAxisPrior::eval(alpha, coord.row(row)[axis], period).value
                });
                let normalizer_deriv = if let Some(p) = period {
                    let kappa = std::f64::consts::TAU / p;
                    let eta = alpha / (kappa * kappa);
                    // d/d(log α) of `n[-η + log I0(η)]` = `n η (I1/I0 - 1)`.
                    // The helper returns the ratio without forming `e^{η}`,
                    // so it stays finite for large `η`, where the naive
                    // `bessel_i1(η)/bessel_i0(η)` gives `inf/inf = NaN` (#1113).
                    let ratio = bessel_i0_log_and_ratio(eta).1;
                    n * eta * (-1.0 + ratio)
                } else {
                    -0.5 * n
                };
                *slot = energy_deriv + normalizer_deriv;
            }
            Ok(atom_out)
        })
        .collect()
}
/// Per-(atom, axis) Hessian-trace contribution of the ARD prior: for every
/// coordinate slot the routine accumulates
/// `½·[(H⁻¹)_tt]_ss·hess − ½·correction(s, hess)`, where `hess` is the
/// non-negative part of `ArdAxisPrior::eval(..).hess` at the current
/// coordinate and `correction` is the deflation correction of that slot.
///
/// The diagonal comes from `solver.latent_inverse_diagonal()` (the DEFLATED
/// inverse). The kept-subspace + rotation deflation correction
/// `tr(inv_vv·(D − DΦ[D]))` is subtracted per (row, axis) through
/// `Self::deflation_block_correction`. Each ARD component `(atom k, axis)`
/// touches a single coordinate slot, so its `D` is the rank-one
/// `hess·e_s e_sᵀ` at the local slot `s`.
///
/// Atoms whose `rho.log_ard[k]` is empty get an empty output array and are
/// skipped. Slot positions follow `self.last_row_layout`: the compact
/// `coord_starts` when it is `Some`, otherwise the dense `coord_offsets`.
///
/// NOTE(review): `rho.log_ard` is indexed by atom and axis without a length
/// check here; callers are presumably expected to pass a `rho` that already
/// matches the term — confirm.
///
/// # Errors
/// Failures of `latent_inverse_diagonal` or `solve` are wrapped in
/// `ArrowSchurError::SchurFactorFailed`.
pub(crate) fn ard_log_precision_hessian_trace(
    &self,
    rho: &SaeManifoldRho,
    cache: &ArrowFactorCache,
    solver: &DeflatedArrowSolver<'_>,
) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
    let inv_diag = solver
        .latent_inverse_diagonal()
        .map_err(|err| ArrowSchurError::SchurFactorFailed { reason: err })?;
    let n_rows = self.n_obs();
    let t_len = cache.delta_t_len();
    let dense_offsets = self.assignment.coord_offsets();
    let axis_periods: Vec<Vec<Option<f64>>> = self
        .assignment
        .coords
        .iter()
        .map(LatentCoordValues::effective_axis_periods)
        .collect();
    let mut traces: Vec<Array1<f64>> = self
        .assignment
        .coords
        .iter()
        .enumerate()
        .map(|(atom, coord)| {
            let len = if rho.log_ard[atom].is_empty() {
                0
            } else {
                coord.latent_dim()
            };
            Array1::<f64>::zeros(len)
        })
        .collect();
    // One right-hand-side buffer is reused for every (row, col) solve: a
    // single entry is set and cleared per column, instead of allocating and
    // zero-filling a fresh length-`t_len` vector each time.
    let mut unit_rhs = Array1::<f64>::zeros(t_len);
    let zero_beta = Array1::<f64>::zeros(cache.k);
    for row in 0..n_rows {
        let row_base = cache.row_offsets[row];
        let q = cache.row_dims[row];
        let dirs = cache
            .deflated_row_directions
            .get(row)
            .map(Vec::as_slice)
            .unwrap_or(&[]);
        let spectrum = cache
            .deflation_row_spectra
            .get(row)
            .and_then(Option::as_ref);
        // Selected-inverse t-block of this row; only needed (and only built)
        // when the row has deflated directions.
        let inv_vv = if !dirs.is_empty() {
            let mut block = Array2::<f64>::zeros((q, q));
            for col in 0..q {
                let hot = row_base + col;
                unit_rhs[hot] = 1.0;
                let solved = solver
                    .solve(unit_rhs.view(), zero_beta.view())
                    .map_err(|err| ArrowSchurError::SchurFactorFailed { reason: err })?;
                unit_rhs[hot] = 0.0;
                for r in 0..q {
                    block[[r, col]] = solved.t[row_base + r];
                }
            }
            Some(block)
        } else {
            None
        };
        // Deflation correction for local slot `s` carrying curvature `hess`;
        // zero when the row is not deflated, the slot is outside the row
        // block, or the curvature vanishes.
        let slot_correction = |s: usize, hess: f64| -> f64 {
            match inv_vv.as_ref() {
                Some(iv) if s < q && hess != 0.0 => {
                    let mut d = Array2::<f64>::zeros((q, q));
                    d[[s, s]] = hess;
                    Self::deflation_block_correction(iv, &d, dirs, spectrum)
                }
                _ => 0.0,
            }
        };
        // Per-axis accumulation for one atom's coordinate block in this row;
        // shared by the compact and the dense layout below.
        let mut add_block = |atom: usize, block_start: usize| {
            let coord = &self.assignment.coords[atom];
            let dim = coord.latent_dim();
            for axis in 0..dim {
                let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
                let t = coord.row(row)[axis];
                let prior = ArdAxisPrior::eval(alpha, t, axis_periods[atom][axis]);
                let hess = prior.hess.max(0.0);
                let s = block_start + axis;
                traces[atom][axis] += 0.5 * inv_diag[row_base + s] * hess;
                traces[atom][axis] -= 0.5 * slot_correction(s, hess);
            }
        };
        if let Some(ref layout) = self.last_row_layout {
            let active = &layout.active_atoms[row];
            let starts = &layout.coord_starts[row];
            for (pos, &atom) in active.iter().enumerate() {
                if rho.log_ard[atom].is_empty() {
                    continue;
                }
                add_block(atom, starts[pos]);
            }
        } else {
            for atom in 0..self.k_atoms() {
                if rho.log_ard[atom].is_empty() {
                    continue;
                }
                add_block(atom, dense_offsets[atom]);
            }
        }
    }
    Ok(traces)
}
}