1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
#[cfg(test)]
mod amortized_encoder_tests {
use crate::terms::sae::manifold::tests::small_two_atom_periodic_term;
/// #1026 ladder item 2/3 — drives the amortized encoder end-to-end from a
/// fitted term, feeding it the dictionary's own fit-time target.
///
/// Pins the shape and bookkeeping of the per-atom results:
/// - exactly one result per atom (`k_atoms`), in dictionary order;
/// - each result's `coords` has one row per observation and as many columns
///   as that atom's `latent_dim`;
/// - `encode_uncertified_count` equals the number of `false` entries in the
///   `certified` mask, and the mask has one entry per row — so an uncertified
///   row is never silently miscounted.
#[test]
fn amortized_encode_fitted_is_reachable_and_certificate_honest() {
    let (term, target, rho) = small_two_atom_periodic_term();
    let (n_rows, n_atoms) = (term.n_obs(), term.k_atoms());
    let encoded = term
        .amortized_encode_fitted(target.view(), &rho)
        .expect("amortized encode of the fit-time target runs end-to-end");
    assert_eq!(
        encoded.len(),
        n_atoms,
        "one encode result per atom in dictionary order"
    );
    for (atom_idx, atom_result) in encoded.iter().enumerate() {
        // Coordinate block shape: rows follow the data, columns follow the atom.
        let coords = &atom_result.coords;
        assert_eq!(
            coords.nrows(),
            n_rows,
            "atom {atom_idx} encode must produce one coordinate per row"
        );
        assert_eq!(
            coords.ncols(),
            term.atoms[atom_idx].latent_dim,
            "atom {atom_idx} encode coords must match its latent dim"
        );
        // Recount the rows the certificate did not gate straight from the mask
        // and compare against the tally the encoder reported.
        let mask_false_entries = atom_result
            .certified
            .iter()
            .filter(|&&is_certified| !is_certified)
            .count();
        assert_eq!(
            atom_result.encode_uncertified_count, mask_false_entries,
            "atom {atom_idx} uncertified count must match the certificate mask"
        );
        assert_eq!(
            atom_result.certified.len(),
            n_rows,
            "atom {atom_idx} certificate mask must cover every row"
        );
    }
}
/// The amplitudes returned by `fitted_assignment_amplitudes` must be exactly
/// the per-row assignment masses: the matrix is `n_obs × k_atoms`, and every
/// entry compares equal (exact `assert_eq!`, no tolerance) to the matching
/// entry of `assignment.try_assignments_row_for_rho` for that row.
#[test]
fn fitted_assignment_amplitudes_match_the_assignment_masses() {
    let (term, _target, rho) = small_two_atom_periodic_term();
    let (n_rows, n_atoms) = (term.n_obs(), term.k_atoms());
    let derived = term
        .fitted_assignment_amplitudes(&rho)
        .expect("fitted amplitudes derive from the assignment");
    assert_eq!(derived.dim(), (n_rows, n_atoms));
    for row in 0..n_rows {
        // Reference masses for this row, resolved directly from the assignment.
        let masses = term
            .assignment
            .try_assignments_row_for_rho(row, &rho)
            .expect("assignment row resolves");
        for atom_idx in 0..n_atoms {
            assert_eq!(
                derived[[row, atom_idx]],
                masses[atom_idx],
                "amplitude[{row},{atom_idx}] must equal the assignment mass"
            );
        }
    }
}
}
#[cfg(test)]
mod outer_gradient_error_classification_1451_tests {
use super::OuterGradientError;
/// #1451 — pins how `OuterGradientError::classify_arrow_solver_error` sorts
/// the failure messages raised inside the deflation path. The three
/// numerical/linear-algebra failure sites (`apply_cached_arrow_hessian`, the
/// projected `h_span.eigh`, and `DeflatedArrowSolver::from_orthonormal_gauges`)
/// all route through this helper.
///
/// Each message family is classified against the same `IllConditioned`
/// fallback error:
/// - shape/dimension mismatches must come back as `InternalInvariant`, must
///   not be conditioning-recoverable, and must not admit the plain-solver
///   fallback even at a finite cost;
/// - non-finite intermediates must likewise come back as `InternalInvariant`
///   and must not be conditioning-recoverable;
/// - genuine near-singular failures on finite, correctly shaped input must
///   stay `IllConditioned`, remain conditioning-recoverable, and admit the
///   plain-solver fallback at finite cost (the #1273 FD fallback).
///
/// Before the #1451 fix every failure at these sites was re-labelled
/// `IllConditioned`, so the shape/non-finite cases would have been
/// FD-eligible — masking an internal defect behind a plausible-but-wrong FD
/// descent direction, the regression #1436 set out to eliminate.
#[test]
fn classify_arrow_solver_error_routes_shape_and_nonfinite_to_internal_1451() {
    // Fresh copy of the conditioning error handed to the classifier each call.
    let conditioning = || OuterGradientError::IllConditioned {
        reason: "near-singular joint Hessian (min/max pivot ratio 5.3e-16)".to_string(),
    };
    let classify = |message: &'static str| {
        OuterGradientError::classify_arrow_solver_error(message, conditioning())
    };

    // Family 1: shape/dimension-mismatch markers emitted by the deflation
    // helpers — InternalInvariant, hence NOT FD-eligible.
    let shape_messages = [
        "apply_cached_arrow_hessian: vector shapes (t=3, beta=2) != cache shapes (t=4, beta=2)",
        "DeflatedArrowSolver: gauge length 5 != cache full length 6",
        "DeflatedArrowSolver: solution length 5 != cache full length 6",
    ];
    for message in shape_messages {
        let classified = classify(message);
        assert!(
            matches!(classified, OuterGradientError::InternalInvariant { .. }),
            "shape mismatch must classify to InternalInvariant (#1451); got {classified}"
        );
        assert!(
            !classified.is_conditioning_recoverable(),
            "a shape mismatch must NOT be conditioning-recoverable (#1451); got {classified}"
        );
        assert!(
            !classified.admits_plain_solver_fallback(1.0),
            "a shape mismatch must NOT admit the plain-solver fallback even at finite cost (#1451)"
        );
    }

    // Family 2: non-finite-intermediate markers — these propagate as internal
    // defects too.
    let nonfinite_messages = [
        "DeflatedArrowSolver: gauge stiffness must be finite and positive; got NaN",
        "outer_gradient_arrow_solver: non-finite entry in projected gauge Hessian",
    ];
    for message in nonfinite_messages {
        let classified = classify(message);
        assert!(
            matches!(classified, OuterGradientError::InternalInvariant { .. }),
            "non-finite intermediate must classify to InternalInvariant (#1451); \
             got {classified}"
        );
        assert!(
            !classified.is_conditioning_recoverable(),
            "a non-finite intermediate must NOT be conditioning-recoverable (#1451); got {classified}"
        );
    }

    // Family 3: a real near-singular linear-algebra failure on finite,
    // correctly shaped input (a back-solve or Cholesky/Woodbury factor that
    // tripped on rank-deficiency). This is the legitimate #1273 conditioning
    // case and must KEEP IllConditioned.
    let conditioning_messages = [
        "DeflatedArrowSolver: gauge Woodbury factor failed: matrix is not positive definite",
        "DeflatedArrowSolver: gauge back-solve: singular factor",
    ];
    for message in conditioning_messages {
        let classified = classify(message);
        assert!(
            matches!(classified, OuterGradientError::IllConditioned { .. }),
            "a finite, correctly-shaped near-singular failure must KEEP \
             IllConditioned (#1451 / #1273); got {classified}"
        );
        assert!(
            classified.is_conditioning_recoverable(),
            "a genuine conditioning failure must remain conditioning-recoverable (#1273); got {classified}"
        );
        assert!(
            classified.admits_plain_solver_fallback(1.0),
            "a genuine conditioning failure at finite cost must admit the plain-solver fallback (#1273)"
        );
    }
}
}
#[cfg(test)]
mod softmax_majorizer_active_entry_1410_tests {
//! #1410 — the active-only softmax-entropy curvature helpers
//! ([`super::active_softmax_gershgorin_majorizer_entry`],
//! [`super::softmax_dense_entropy_hessian_entry`],
//! [`super::softmax_majorizer_log_mean`]) let the compact assembly /
//! θ-adjoint / exact-Hessian-correction paths read one `(k)` diagonal or
//! `(k,j)` matrix entry WITHOUT materialising the full-`K` `d` vector or the
//! `K×K` dense entropy/majorizer blocks per row — the residual per-worker
//! `O(K)`/`O(K²)` scratch that defeated the compact `O(top_k·d)`-per-token
//! contract.
//!
//! Correctness is single-sourced: these helpers MUST reproduce the
//! `SoftmaxAssignmentSparsityPenalty` dense library routines
//! (`psd_majorizer_abs_row_sums`, `row_psd_majorizer`, `row_dense_hessian`)
//! BIT-FOR-BIT, because the assembled `B`, the criterion's `log|H|`, and the
//! #1006 θ-adjoint all differentiate ONE operator. If the dense library
//! formula ever changes, this oracle fails and forces the helpers back into
//! sync (preventing the value↔adjoint desync the compact rewrite must not
//! introduce).
use crate::terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty;
/// Deterministic, well-spread softmax logit rows (a long tail plus a few
/// peaks) so the abs-row-sum / dense-Hessian algebra is exercised across
/// near-zero and near-one assignment masses.
///
/// Returns three rows, each of length `k`:
/// 1. a floor of `-7.0` with peaks of `5.0 + 0.001 * index` at indices `0`,
///    `k / 3`, `2 * k / 3` and `k - 1`;
/// 2. smoothly varying logits `2·sin(0.37·i) − i / k`;
/// 3. a constant row of `0.01`.
///
/// `k == 0` yields three empty rows rather than underflowing on `k - 1`.
fn logit_rows(k: usize) -> Vec<Vec<f64>> {
    // Row a: a few sharp peaks spread across K, deep floor elsewhere.
    let mut peaked = vec![-7.0_f64; k];
    // `checked_sub` guards the last-index computation: with `k == 0` there is
    // no valid peak position at all, so the (empty) row is left untouched.
    if let Some(last) = k.checked_sub(1) {
        for peak in [0, k / 3, 2 * k / 3, last] {
            peaked[peak] = 5.0 + (peak as f64) * 0.001;
        }
    }

    // Row b: smoothly varying logits (no degenerate ties).
    let smooth: Vec<f64> = (0..k)
        .map(|i| ((i as f64) * 0.37).sin() * 2.0 - (i as f64) / (k as f64))
        .collect();

    // Row c: near-uniform (entropy Hessian indefinite here — the regime the
    // Gershgorin majorizer exists for).
    let uniform = vec![0.01; k];

    vec![peaked, smooth, uniform]
}
/// For every fixture logit row, each single-entry value from
/// `active_softmax_gershgorin_majorizer_entry` must compare exactly equal
/// (`assert_eq!`, no tolerance) to the matching entry of the dense
/// `psd_majorizer_abs_row_sums` vector.
#[test]
fn active_softmax_gershgorin_matches_dense_majorizer_1410() {
    let k = 64usize;
    let (temperature, scale) = (0.8_f64, 1.7_f64);
    let dense_penalty = SoftmaxAssignmentSparsityPenalty::new(k, temperature);
    for logits in logit_rows(k) {
        // Dense reference: the full-K abs-row-sum diagonal `d`.
        let dense_diag = dense_penalty.psd_majorizer_abs_row_sums(&logits, scale);
        // The helper takes the softmax row, not raw logits (the way the
        // assembly/adjoint feed it `assignments`), so build that row with the
        // shared `softmax_row` routine at the same temperature.
        let softmax_owned = crate::terms::sae::assignment::softmax_row(
            ndarray::ArrayView1::from(logits.as_slice()),
            temperature,
        );
        let softmax = softmax_owned.as_slice().expect("softmax row contiguous");
        let log_mean = super::softmax_majorizer_log_mean(softmax);
        for kk in 0..k {
            let got =
                super::active_softmax_gershgorin_majorizer_entry(softmax, kk, log_mean, scale);
            assert_eq!(
                got, dense_diag[kk],
                "active Gershgorin majorizer entry must equal the dense \
                 psd_majorizer_abs_row_sums[{kk}] BIT-FOR-BIT (single-source #1410/#1419)"
            );
        }
    }
}
/// For every fixture logit row and every `(kk, jj)` pair, the single-entry
/// value from `softmax_dense_entropy_hessian_entry` must compare exactly equal
/// (`assert_eq!`, no tolerance) to the matching entry of the dense
/// `row_dense_hessian` block.
#[test]
fn active_softmax_dense_entropy_hessian_entry_matches_dense_block_1410() {
    let k = 48usize;
    let (temperature, scale) = (1.3_f64, 0.9_f64);
    let dense_penalty = SoftmaxAssignmentSparsityPenalty::new(k, temperature);
    for logits in logit_rows(k) {
        // Dense reference: the full K×K entropy Hessian for this row.
        let dense_block = dense_penalty.row_dense_hessian(&logits, scale);
        let softmax_owned = crate::terms::sae::assignment::softmax_row(
            ndarray::ArrayView1::from(logits.as_slice()),
            temperature,
        );
        let softmax = softmax_owned.as_slice().expect("softmax row contiguous");
        let log_mean = super::softmax_majorizer_log_mean(softmax);
        for kk in 0..k {
            for jj in 0..k {
                let got =
                    super::softmax_dense_entropy_hessian_entry(softmax, kk, jj, log_mean, scale);
                assert_eq!(
                    got, dense_block[[kk, jj]],
                    "active dense entropy-Hessian entry ({kk},{jj}) must equal \
                     row_dense_hessian BIT-FOR-BIT (single-source #1410/#1418)"
                );
            }
        }
    }
}
/// For every fixture logit row and for `w` in `{0, k / 2, k - 1}`, each value
/// from `active_softmax_majorizer_logit_derivative_entry` must compare exactly
/// equal (`assert_eq!`, no tolerance) to the `[[kk, kk]]` diagonal entry of
/// the dense `row_psd_majorizer_logit_derivative` matrix.
#[test]
fn active_softmax_majorizer_logit_derivative_matches_dense_1410() {
    let k = 40usize;
    let (temperature, scale) = (0.7_f64, 1.1_f64);
    let inv_tau = 1.0 / temperature;
    let dense_penalty = SoftmaxAssignmentSparsityPenalty::new(k, temperature);
    for logits in logit_rows(k) {
        let softmax_owned = crate::terms::sae::assignment::softmax_row(
            ndarray::ArrayView1::from(logits.as_slice()),
            temperature,
        );
        let softmax = softmax_owned.as_slice().expect("softmax row contiguous");
        let log_mean = super::softmax_majorizer_log_mean(softmax);
        // Compare the active diagonal entry with the dense library derivative
        // matrix, read at `[[kk, kk]]`, for a first, middle and last
        // differentiation index `w`.
        for w in [0usize, k / 2, k - 1] {
            let dense_derivative =
                dense_penalty.row_psd_majorizer_logit_derivative(&logits, scale, w);
            for kk in 0..k {
                let got = super::active_softmax_majorizer_logit_derivative_entry(
                    softmax, kk, w, log_mean, scale, inv_tau,
                );
                assert_eq!(
                    got, dense_derivative[[kk, kk]],
                    "active majorizer logit-derivative ∂D_({kk},{kk})/∂z_{w} must equal \
                     row_psd_majorizer_logit_derivative diagonal BIT-FOR-BIT \
                     (single-source #1410/#1419/#1006)"
                );
            }
        }
    }
}
}
/// #1418: the implicit-function (IFT) back-substitution must invert the EXACT
/// stationarity Jacobian `A = ∇²_θθ L`, not the assembled surrogate `B`.
#[cfg(test)]
mod exact_stationarity_solve_1418_tests {
use super::*;
use crate::terms::sae::manifold::tests::gamma_fd_tiny_fixture;
use ndarray::Array1;
/// Produce a converged tiny SAE state whose inner residual is genuinely
/// nonzero, so that the dropped curvature
/// `ΔC = A − B = ⟨r, ∂²f⟩ + (H_entropy − D) + min(V'',0)` is materially
/// nonzero and `A ≠ B`.
///
/// Starts from `gamma_fd_tiny_fixture`, adds a sinusoidal perturbation to
/// every target entry, sets the sparsity / smoothness / ARD log-strengths in
/// `rho`, then runs `reml_criterion_with_cache` to obtain the cache.
///
/// Returns `(term, perturbed target, rho, converged cache)`.
fn converged_state_with_residual() -> (
    SaeManifoldTerm,
    Array2<f64>,
    SaeManifoldRho,
    ArrowFactorCache,
) {
    let (mut term, mut target, mut rho) = gamma_fd_tiny_fixture();

    // Push the target off the model manifold so the inner optimum carries a
    // real residual `r`, and with it a real `⟨r, ∂²f⟩` curvature delta.
    let (n, p) = (target.nrows(), target.ncols());
    for row in 0..n {
        // The angle depends on the row only; the column adds a phase offset.
        let theta = std::f64::consts::TAU * ((row as f64 + 0.35) / n as f64);
        for col in 0..p {
            target[[row, col]] += 0.6 * (3.0 * theta + 0.5 * col as f64).sin();
        }
    }

    // Switch on the sparsity / smoothness / ARD prior strengths so the softmax
    // entropy delta and the periodic-ARD `min(V'',0)` delta are live as well.
    rho.log_lambda_sparse = -0.5;
    rho.log_lambda_smooth = -1.0;
    rho.log_ard
        .iter_mut()
        .flat_map(|axis| axis.iter_mut())
        .for_each(|strength| *strength = -0.5);

    let (_criterion, _inner_loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 40, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache with residual");
    (term, target, rho, cache)
}
/// Residual norm `‖A x − rhs‖` under the exact stationarity Jacobian `A`
/// (the matrix-free `B v + ΔC v` apply).
///
/// Applies `A` to `x` through `apply_exact_hessian`, subtracts `rhs` block by
/// block (`t` and `beta`), and returns `sae_norm` of the difference. Panics if
/// the matvec fails.
fn a_residual_norm(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    target: ArrayView2<'_, f64>,
    cache: &ArrowFactorCache,
    x: &SaeArrowVector,
    rhs: &SaeArrowVector,
) -> f64 {
    let applied = term
        .apply_exact_hessian(rho, target, cache, x)
        .expect("A matvec");
    let t_gap = &applied.t - &rhs.t;
    let beta_gap = &applied.beta - &rhs.beta;
    sae_norm(&SaeArrowVector {
        t: t_gap,
        beta: beta_gap,
    })
}
/// #1418 — `solve_exact_stationarity` must solve `A x = rhs` for the EXACT
/// stationarity Jacobian `A`, not for the assembled surrogate `B`.
///
/// Three checks on a converged state with a nonzero residual:
/// 1. the exact solve leaves `‖A x − rhs‖ ≤ 1e-6 · rhs_norm`;
/// 2. the surrogate solve `x_B = B⁻¹ rhs` leaves
///    `‖A x_B − rhs‖ ≥ 1e-2 · rhs_norm`, so `A ≠ B` is genuinely exercised
///    and the certificate is non-vacuous;
/// 3. the exact residual is below `1e-3` times the surrogate residual.
///
/// Here `rhs_norm` is `sae_norm(&rhs)` floored at `1.0`.
///
/// Before #1418 the implicit step used `x_B` (the truncated `B⁻¹`-Neumann
/// iterate), whose `A`-residual is the large value check 2 asserts; the
/// Neumann variant diverges outright once `ρ(B⁻¹ΔC) ≥ 1`. The test therefore
/// fails on the old code and passes only when the solve targets the exact `A`.
#[test]
fn solve_exact_stationarity_inverts_a_not_b_1418() {
    let (term, target, rho, cache) = converged_state_with_residual();
    let plain_solver = DeflatedArrowSolver::plain(&cache);

    // Deterministic nonzero right-hand side covering both the latent (t) and
    // the decoder (β) blocks.
    let latent_len = cache.delta_t_len();
    let t_block =
        Array1::from_shape_fn(latent_len, |i| 0.3 + 0.1 * ((i % 5) as f64) - 0.02 * i as f64);
    let beta_block = Array1::from_shape_fn(cache.k, |j| 0.2 - 0.05 * ((j % 3) as f64));
    let rhs = SaeArrowVector {
        t: t_block,
        beta: beta_block,
    };
    let rhs_norm = sae_norm(&rhs).max(1.0);

    // Exact A-solve through the #1418 path, then its A-residual.
    let x_exact = term
        .solve_exact_stationarity(&rho, target.view(), &cache, &plain_solver, &rhs)
        .expect("exact stationarity solve");
    let exact_resid = a_residual_norm(&term, &rho, target.view(), &cache, &x_exact, &rhs);

    // Surrogate solve x_B = B⁻¹ rhs (the pre-#1418 implicit step), measured
    // against the same exact operator A.
    let x_surrogate = plain_solver
        .solve(rhs.t.view(), rhs.beta.view())
        .expect("B inverse");
    let surrogate_resid =
        a_residual_norm(&term, &rho, target.view(), &cache, &x_surrogate, &rhs);

    // 1) The exact solve drives the A-residual to ~0.
    assert!(
        exact_resid <= 1.0e-6 * rhs_norm,
        "solve_exact_stationarity must invert the EXACT A: ‖A x − rhs‖/‖rhs‖ = {:.3e} \
         (rhs_norm={rhs_norm:.3e}) — the IFT step is not solving A x = rhs (#1418)",
        exact_resid / rhs_norm
    );
    // 2) Non-vacuity: the surrogate B-solve must leave a materially large
    //    A-residual, so A ≠ B is really exercised. The pre-#1418 code used x_B
    //    for the implicit step, which is exactly the error #1418 removed.
    assert!(
        surrogate_resid >= 1.0e-2 * rhs_norm,
        "the surrogate B-solve must leave a large A-residual so the A≠B fix is \
         non-vacuous: ‖A x_B − rhs‖/‖rhs‖ = {:.3e} — ΔC = A − B is too small to \
         distinguish the exact stationarity Jacobian from the surrogate",
        surrogate_resid / rhs_norm
    );
    // 3) The exact solve is a strict, large improvement over the surrogate.
    assert!(
        exact_resid < 1.0e-3 * surrogate_resid,
        "exact A-solve residual {exact_resid:.3e} must be far below surrogate {surrogate_resid:.3e}"
    );
}
}