gam_sae/hybrid_split.rs
1//! #1026 — load-bearing curved-vs-linear hybrid split for the fitted SAE
2//! dictionary.
3//!
4//! The selection machinery ([`gam_solve::evidence::select_hybrid_atom`],
5//! [`gam_solve::evidence::select_hybrid_split`]) and the per-atom
6//! integration helper
7//! ([`crate::assignment::select_hybrid_atom_parameterization`]) are
8//! correct and tested, but until now were called nowhere in the fitter: the
9//! post-fit pass only *logged* each `d = 1` atom's fitted turning `Θ`. This
10//! module makes the split LOAD-BEARING by building, per fitted `d = 1` atom, the
11//! two already-realized candidates and adjudicating them by the common
12//! evidence criterion.
13//!
14//! ## The common-evidence comparison on the data (#1202)
15//!
16//! Both candidates are scored against the SAME data: the portion of the
17//! response matrix the atom is responsible for reconstructing, namely its
18//! **leave-this-atom-out residual**
19//!
20//! y_resp[i] = target[i] − ( Σ_j a[i,j]·γ_j(t_{ij}) − a[i,k]·γ_k(t_{ik}) )
21//! = target[i] − without_k[i],
22//!
23//! the response with every OTHER atom's contribution subtracted. Over the rows
24//! assigned to atom `k` (assignment mass `a[i,k] = a_k`), the two candidates
25//! predict that residual:
26//!
27//! * the CURVED candidate is scored at its CONSTRAINED MINIMUM over the atom's
28//! decoder coefficients — re-fit on `y_resp` as
29//! `min_B ‖y_resp − a_k·(Φ(t)·B)‖²` ([`curved_refit_rss`]) — so its data-fit
30//! deviance is the smallest weighted RSS the curved family can attain on this
31//! residual at the realized codes, not the possibly-collapsed realized curve.
32//! * the LINEAR candidate predicts `a_k · (b₀ + (t − t̄)·b₁)`, the best
33//! weighted least-squares straight line fit to `y_resp` (design column
34//! scaled by the same assignment mass `a_k`), so its data-fit deviance is the
35//! weighted RSS of the best line against the SAME residual.
36//!
37//! ## A genuine NESTED min-vs-min comparison (#1051)
38//!
39//! Both arms are now at their constrained minimum on the SAME leave-one-atom-out
40//! residual: the linear arm is the closed-form min-over-lines, and the curved arm
41//! is the min-over-decoders refit ([`curved_refit_rss`]). The linear special case
42//! `a_k·(b₀ + (t−t̄)·b₁)` is a MEMBER of the curved family whenever the straight
//! lane `[1, (t−t̄)]` lies in the column span of the curved basis `Φ`. That holds
//! exactly for the interval / line-segment charts (whose basis carries the constant
//! and linear terms); for the periodic charts it holds only as far as the basis can
//! express a straight lane. Wherever it holds, after the refit
//! `curved_rss ≤ linear_rss` up to the least-squares solver tolerance: the
//! curved family CANNOT do worse than its own `Θ = 0` sub-model. That is the
48//! nested-dominance property restored here — "curved match-or-beats linear" is now
49//! a floor established by re-optimizing the curved arm, not merely asserted.
50//!
51//! The broken (#1051) euclidean / multi-atom OUTER continuation is deliberately
52//! NOT re-entered: the direct per-atom `d = 1` decoder-only refit at the realized
53//! codes is sufficient to score the curved arm at its constrained minimum. When
54//! `Φ` is unavailable (no evaluator) or the refit solve is degenerate the arm
55//! falls back to the already-realized curve's RSS — an honest degradation, never a
56//! fabricated determinant. The argmin then trades the curved arm's (minimized)
57//! data fit against its larger parameter / complexity price, so a genuinely
58//! curved signal is preferred while a straight-line signal ties and collapses to
59//! the cheaper linear lane.
60//!
61//! This replaces an earlier diagnostic in which both candidates targeted the
62//! atom's already-fitted decoded image `γ_k(t)` (giving the curved arm a free
63//! zero residual against itself, #1202) and its successor which scored the
64//! already-REALIZED curved contribution (a post-fit heuristic that did not
65//! establish nested dominance); the curved arm is now re-fit to its minimum.
66
67use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
68
69use gam_linalg::faer_ndarray::FaerEigh;
70use gam_solve::evidence::{
71 HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
72};
73use gam_terms::latent::LatentManifold;
74use crate::chart_canonicalization::d1_atom_fitted_turning;
75use crate::manifold::{SaeManifoldAtom, solve_design_least_squares};
76
/// Laplace negative-log-evidence of a reduced per-atom Gaussian reconstruction
/// sub-model, in its rank-aware form `residual_objective + ½·log|H|`.
///
/// This is the form [`gam_solve::evidence::laplace_evidence`] reduces to for
/// this comparison: there is no smoothing-penalty logdet and the design is full
/// rank (no null space). It is computed inline rather than through
/// `EvidenceLogDetSource` because both candidates already provide their Hessian
/// logdet as a closed-form scalar moment of their shared design. There is no
/// factor cache or HVP callback to assemble.
///
/// # Scale caveat
///
/// This is a FIXED-DISPERSION Laplace / penalized criterion, not a profiled REML
/// marginal likelihood. It assumes unit dispersion (`σ² = 1`) after
/// preprocessing. The residual objective is the bare `½ RSS`, with neither the
/// `RSS/(2σ²)` rescaling nor a profiled-over-σ² log-determinant correction.
///
/// It is therefore NOT scale-invariant. Mapping `y → s·y` multiplies the `RSS`
/// inside `residual_objective` by `s²`, while the `½ log|H|` complexity term
/// stays fixed, which shifts the curved-vs-linear trade-off. Callers must keep
/// the targets on a consistent (preprocessed, roughly unit-scale) footing for
/// the comparison to mean what it says.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    // Laplace complexity price: half the Hessian log-determinant.
    let complexity_price = 0.5 * log_det_h;
    residual_objective + complexity_price
}
97
98/// Rank-aware `log|ΦᵀWΦ|_+` of the curved atom's weighted design Gram over its
99/// `M` decoder basis columns, with per-row weight `wᵢ = a_k²` (the same
100/// assignment-mass design weight the linear arm uses), summed over the
101/// eigenvalues above a relative spectral floor (#1223). This is the genuine
102/// weighted-design determinant the linear arm already reports — `log|XᵀWX|` —
103/// assembled for the curved basis so the two arms' Laplace complexity prices are
104/// computed on the SAME footing instead of pricing the curved arm with a
105/// parameter-count proxy `M·log(Σw)`.
106///
107/// Mirrors the linear arm exactly in what it does NOT include: no smoothing-
108/// penalty `λS` normalizer (the linear arm's Gram is the bare data Gram
109/// `diag(w_sum, s_tt)` too), so the comparison stays symmetric. The Gram is the
110/// design's outer Gram over its basis columns; it is identical across the `p`
111/// output channels (every channel shares the design `Φ`), so the per-channel
112/// `log|G|_+` is multiplied by `p` — matching the linear arm's `p·(…)` form.
113///
114/// `phi` is the curved design `Φ(t)` evaluated on the atom's assigned rows
115/// (`n × M`); `assign` is the per-row assignment mass `a_k` (NOT squared).
116/// Returns `None` when `Φ` is missing rows, the Gram is non-finite, or it has no
117/// positive eigenvalues (a fully rank-deficient design carries no determinant);
118/// the caller then falls back to the parameter-count proxy rather than fabricate
119/// a determinant.
120fn curved_design_gram_logdet(
121 phi: ArrayView2<'_, f64>,
122 assign: ArrayView1<'_, f64>,
123 p: usize,
124) -> Option<f64> {
125 let n = phi.nrows();
126 let m = phi.ncols();
127 if m == 0 || assign.len() != n || n == 0 {
128 return None;
129 }
130 // G = Φᵀ diag(a²) Φ (M×M, symmetric PSD).
131 let mut gram = Array2::<f64>::zeros((m, m));
132 for i in 0..n {
133 let w = assign[i] * assign[i];
134 if !(w.is_finite() && w >= 0.0) {
135 return None;
136 }
137 if w == 0.0 {
138 continue;
139 }
140 let row = phi.row(i);
141 for a in 0..m {
142 let wa = w * row[a];
143 for b in a..m {
144 gram[[a, b]] += wa * row[b];
145 }
146 }
147 }
148 // Symmetrize the lower triangle (we only filled the upper).
149 for a in 0..m {
150 for b in 0..a {
151 gram[[a, b]] = gram[[b, a]];
152 }
153 }
154 if gram.iter().any(|v| !v.is_finite()) {
155 return None;
156 }
157 let (vals, _vecs) = gram.eigh(faer::Side::Lower).ok()?;
158 // Rank-aware log-determinant: sum log of eigenvalues above a relative floor
159 // tied to the largest eigenvalue, dropping the numerically-null directions
160 // (the curved design's null space, analogous to the linear arm's full-rank
161 // 2-D Gram). A design with no positive eigenvalue carries no determinant.
162 let lambda_max = vals.iter().cloned().fold(0.0_f64, f64::max);
163 if !(lambda_max > 0.0 && lambda_max.is_finite()) {
164 return None;
165 }
166 let floor = lambda_max * 1e-12;
167 let mut log_det = 0.0_f64;
168 let mut rank = 0usize;
169 for &lambda in vals.iter() {
170 if lambda > floor {
171 log_det += lambda.ln();
172 rank += 1;
173 }
174 }
175 if rank == 0 || !log_det.is_finite() {
176 return None;
177 }
178 // The design Gram is shared across the p output channels.
179 Some((p as f64) * log_det)
180}
181
182/// The fitted straight sub-model `γ̃(t) = b₀ + (t − t̄)·b₁` of one `d = 1` atom:
183/// the exact assignment-mass-weighted least-squares line fit to the atom's
184/// leave-this-atom-out RESPONSE residual `y_resp` over its assigned rows (the
185/// curved family's nested `Θ = 0` sub-model on common data, #1202). Carried on a
186/// verdict that selects LINEAR so the collapsed reconstruction can replace the
187/// curved decoded row with this straight image at any coordinate WITHOUT
188/// re-entering the (broken, #1051) outer fit — the coefficients are already
189/// realized inside the adjudication.
190#[derive(Clone, Debug)]
191pub struct AtomLinearImage {
192 /// The atom's slot index in the dictionary (so the collapsed assembly knows
193 /// which atom's decoded row to substitute).
194 pub atom_idx: usize,
195 /// The mass-weighted coordinate mean `t̄` the line is centered on.
196 pub t_bar: f64,
197 /// Per-output-channel centered intercept `b₀ = γ̄` at `t̄` (length `p`).
198 pub b0: Array1<f64>,
199 /// Per-output-channel slope `b₁` (length `p`).
200 pub b1: Array1<f64>,
201 /// #1026 collapse-rescue per-row coordinates. `None` for the ordinary path:
202 /// the line is evaluated at the atom's OWN realized coordinate `t`. `Some(u)`
203 /// only when the atom's circle codes had collapsed to a single point
204 /// (`s_tt ≈ 0`) so its own coordinate carries no spread — then the line is fit
205 /// against, and reconstruct evaluates it at, these FRESH per-row codes `uᵢ`
206 /// (the projection of the leave-this-atom-out residual onto its top
207 /// mass-weighted output direction). This is what lets a circle atom that the
208 /// joint fit drove into the degenerate "chord-through-the-arc" fixed point
209 /// still reconstruct its residual's best linear direction — recovering the
210 /// linear-tail reach the hybrid-split was designed to deliver — instead of a
211 /// constant (its collapsed curve), which is the real-OLMo rank-1 co-collapse
212 /// (held-out EV ≈ 0.13 vs the linear ceiling ≈ 0.74). Length `n` (one per
213 /// reconstructed row); unassigned rows are gated to zero by `a_k` anyway.
214 ///
215 /// TRAIN-ONLY CAVEAT (#1777): these are the TRAIN rows' codes. They are only
216 /// meaningful for the exact rows the split was fit on; a held-out row has no
217 /// entry here and used to fall back to the atom's own (collapsed) coordinate
218 /// `own_t` — a DIFFERENT, degraded model out of sample. Prefer [`Self::v`]:
219 /// projecting a held-out row's leave-this-atom-out residual onto `v` recovers
220 /// that row's coordinate by the SAME math the train codes were built with, so
221 /// train and OOS use one model. `row_codes` is retained for back-compat and as
222 /// the exact cached train projection.
223 pub row_codes: Option<Array1<f64>>,
224 /// #1777 collapse-rescue projection DIRECTION `v` (length `p`, unit norm), the
225 /// top mass-weighted output direction of the atom's leave-this-atom-out
226 /// residual. `Some` exactly when this is a collapse-rescued image (paired with
227 /// `row_codes`); `None` for the ordinary straight-image path (which decodes at
228 /// the atom's own coordinate). This is the SERIALIZABLE quantity the FFI must
229 /// persist so an OOS term can recompute any row's coordinate as
230 /// `uᵢ = ⟨y_i − Σ_{j≠k} f_j(x_i), v⟩` — identical to the train code
231 /// `row_codes[i]` on a train row, and the correct held-out coordinate on an OOS
232 /// row (see [`Self::coordinate_from_residual`]). Length must equal `b0`/`b1`.
233 pub v: Option<Array1<f64>>,
234}
235
236impl AtomLinearImage {
237 /// Evaluate the straight sub-model `b₀ + (t − t̄)·b₁` into `out` (length `p`).
238 pub fn fill_row(&self, t: f64, out: &mut [f64]) {
239 let dt = t - self.t_bar;
240 for (j, slot) in out.iter_mut().enumerate() {
241 *slot = self.b0[j] + dt * self.b1[j];
242 }
243 }
244
245 /// The coordinate at which row `row` should evaluate this image: the
246 /// collapse-rescue fresh code `uᵢ` when present (#1026), else the atom's own
247 /// realized coordinate `own_t` passed by the caller.
248 ///
249 /// TRAIN-ONLY: `row_codes` is indexed by TRAIN row, so this is correct only
250 /// for the rows the split was fit on. Out of sample use
251 /// [`Self::coordinate_from_residual`] (target-aware, model-identical to train).
252 pub fn coordinate_for_row(&self, row: usize, own_t: f64) -> f64 {
253 match &self.row_codes {
254 Some(u) if row < u.len() => u[row],
255 _ => own_t,
256 }
257 }
258
259 /// #1777 — the collapse-rescue coordinate of a row from ITS OWN
260 /// leave-this-atom-out residual `resid = y_i − Σ_{j≠k} f_j(x_i)` (length `p`),
261 /// namely `uᵢ = ⟨resid, v⟩`. `Some(uᵢ)` exactly when this is a collapse-rescued
262 /// image (`v` is set); `None` for the ordinary straight-image path (which has
263 /// no projection direction and decodes at the atom's own coordinate).
264 ///
265 /// This is the SAME math [`build_collapse_rescue_linear_image`] used to build
266 /// the train `row_codes` (`row_codes[i] = ⟨target_resid[i], v⟩`), so on a TRAIN
267 /// row it reproduces `row_codes[i]` exactly, and on a HELD-OUT row it yields
268 /// that row's correct coordinate — train and OOS share one model. Returns
269 /// `None` if `resid`'s length disagrees with `v`.
270 pub fn coordinate_from_residual(&self, resid: &[f64]) -> Option<f64> {
271 let v = self.v.as_ref()?;
272 if resid.len() != v.len() {
273 return None;
274 }
275 Some(v.iter().zip(resid).map(|(&vj, &rj)| vj * rj).sum())
276 }
277
278 /// Whether this image is a #1777 collapse-rescued image (carries a projection
279 /// direction `v` and per-row train codes) rather than an ordinary straight
280 /// image evaluated at the atom's own coordinate.
281 pub fn is_collapse_rescued(&self) -> bool {
282 self.v.is_some()
283 }
284}
285
286/// One fitted `d = 1` atom's hybrid-split verdict, surfaced in the model output.
287#[derive(Clone, Debug)]
288pub struct AtomHybridVerdict {
289 /// The atom's name (slot identity in the dictionary).
290 pub atom_name: String,
291 /// The evidence-selected parameterization choice for this slot.
292 pub choice: HybridAtomChoice,
293 /// `true` iff the slot kept the CURVED parameterization (the fitted atom);
294 /// `false` iff it yielded to the LINEAR special case (the straight tail).
295 pub kept_curved: bool,
296 /// The atom's fitted turning `Θ = ∫|κ| ds` (radians), the novel geometric
297 /// quantity #1026 pairs against reconstruction EV: `Θ ≈ 0` is a linear-tail
298 /// direction wearing a curved basis, `Θ ≈ 2π` is a full curved loop. `None`
299 /// iff the evaluator has no analytic second jet or the curve is degenerate.
300 /// Captured here (not just logged) so the EV-vs-Θ frontier is queryable
301 /// structured data on the persisted report rather than a transient log line.
302 pub fitted_turning: Option<f64>,
303 /// The atom's training leave-one-atom-out explained-variance contribution
304 /// `ΔEV_k = EV(full) − EV(full∖{k})` — how much reconstruction EV this single
305 /// atom earns. Paired with [`Self::fitted_turning`] this is the `(Θ, ΔEV)`
306 /// point the #1026 frontier reports: a `Θ ≈ 0` atom with large `ΔEV` is a
307 /// genuine linear-tail direction; a high-`Θ` atom with large `ΔEV` is a
308 /// genuine curved family. `None` iff the caller did not supply LOAO EV.
309 pub train_loao_delta_ev: Option<f64>,
310 /// The fitted straight sub-model for this slot, present iff the verdict
311 /// selected LINEAR (`kept_curved == false`). The collapsed reconstruction
312 /// substitutes this for the atom's curved decoded image, making the verdict
313 /// load-bearing on the reconstruction rather than a passive diagnostic.
314 pub linear_image: Option<AtomLinearImage>,
315}
316
317/// The whole dictionary's hybrid-split report: one verdict per eligible `d = 1`
318/// atom, plus the dictionary-level aggregates the EV-vs-Θ frontier reports
319/// against.
320#[derive(Clone, Debug)]
321pub struct SaeHybridSplitReport {
322 /// One adjudicated verdict per eligible `d = 1` atom, in slot order. Atoms
323 /// that are not eligible (wrong dim, no evaluator, mid-homotopy) are absent
324 /// — they carry no curved/linear adjudication.
325 pub verdicts: Vec<AtomHybridVerdict>,
326 /// The dictionary-level rolled-up selection (summed NLE, total parameters,
327 /// curved/linear counts) over the eligible atoms.
328 pub selection: HybridSplitSelection,
329}
330
/// Minimum number of assigned rows a `d = 1` atom needs before its linear
/// candidate's deviance is defined.
///
/// A straight line has two parameters (intercept and slope), and estimating a
/// residual needs at least one more row, hence `2 + 1`. An atom with fewer rows
/// is skipped and left out of the report. It is never adjudicated on a
/// fabricated deviance.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 2 + 1;
336
/// #1610/#1026 — EV-PRESERVATION gate tolerance for collapsing a `d = 1` slot
/// to its linear tail.
///
/// A slot may collapse only if doing so costs at most this fraction of the
/// target's total (centered) variance in full-reconstruction explained variance.
/// The evidence argmin trades data fit against the curved arm's parameter price
/// in `NLE` units. On a small / low-amplitude fixture that trade can favour the
/// cheaper line even when the curve carries real reconstruction signal. The
/// collapse then DROPS EV (the observed 1.0 → 0.748 over-collapse). This gate
/// guards that regressed quantity directly.
///
/// The per-atom collapse EV impact is exactly
/// `(linear_rss − realized_curved_rss) / SST_full`. Here `realized_curved_rss`
/// is the RSS of the atom's already-REALIZED decoded curve against its
/// leave-this-atom-out residual, as computed by [`collapse_ssr_increase`].
/// Collapsing atom `k` swaps that realized curve for its line, which raises the
/// full reconstruction SSR by exactly that difference while `SST_full` stays
/// fixed.
///
/// It is NOT the difference against the curved arm's REFIT minimum
/// ([`curved_refit_rss`]) that the evidence comparison scores. The
/// reconstruction keeps the realized decoder, so only the realized RSS measures
/// the SSR change. A collapse is vetoed when its loss exceeds this fraction.
///
/// EV is a dimensionless quantity in `[0, 1]`, so an absolute EV-loss tolerance
/// is itself scale-invariant. It does not reintroduce the scale
/// incommensurability that sank the evidence-reformulation attempt.
///
/// A genuinely straight atom, whose curved fit IS its own line, loses `≈ 0` EV
/// and still collapses losslessly. Only a curve doing real reconstruction work
/// (loss `≫ 1e-3`) is kept. `1e-3` = 0.1% of total variance. That is
/// comfortably above f64 round-off on an exact-line collapse yet far below any
/// material EV loss, so it separates the lossless and load-bearing regimes
/// without tuning.
///
/// # Why a fixed dimensionless tolerance, not a noise/dispersion-derived one (#1610)
///
/// This gate is a SAFETY BACKSTOP layered over the evidence/REML selection
/// ([`select_hybrid_split`]). That selection is itself the nested-model
/// statistical test: it already trades the curved arm's data fit against its
/// Laplace parameter price in NLE units, and picks the line when the curve is
/// not evidence-justified. The backstop exists only for the residual case where
/// that argmin prefers the cheaper line yet collapsing DROPS reconstruction EV
/// (the observed 1.0 → 0.748).
///
/// The right instrument for a backstop over a statistical test is a
/// conservative negligibility tolerance on the quantity it protects
/// (full-reconstruction EV). A second, re-derived statistical threshold is not.
///
/// A noise/dispersion-derived per-atom tolerance would also have no SAFE
/// calibration here. Take `df_extra · σ̂² / SST_full`, the curve's expected
/// spurious extra RSS under the null that the image is straight, with `σ̂²` the
/// curved fit's residual dispersion. Both directions fail:
/// * On an exact-line collapse the fit's dispersion `σ̂² → 0`. A pure noise
///   threshold then falls BELOW the least-squares round-off of
///   `linear_rss − curved_rss` (`≈ κ·εmach·SST_full`). It would spuriously
///   VETO the lossless collapses the deterministic tests require.
/// * Raising it to clear round-off only relaxes the backstop TOWARD the
///   over-collapse boundary it was added to hold (larger tolerance ⇒ fewer
///   vetoes ⇒ more collapse).
///
/// So there is no safe direction in which to make it noise-adaptive. The safe
/// window is the wide, well-separated band between solver round-off (`~1e-12`
/// relative) and any material EV loss (`~1e-2`). `1e-3` is the standard "0.1%
/// of variance is negligible" point inside it: dimensionless on an EV `∈ [0, 1]`
/// (hence scale-invariant, per this issue's scale-invariance contract), with a
/// single explicit meaning rather than a corpus-tuned magnitude. A noise-adaptive
/// backstop would also change reconstruction on real activations in a way only
/// the real-OLMo behavioral battery can validate.
///
/// # Per-atom scope
///
/// Applied per slot, this tolerance bounds only ONE atom's individual EV loss.
/// It does NOT by itself bound the dictionary-level EV loss when several atoms
/// collapse at once. The true global RSS change is
/// `Σ_k Δ_k + 2 Σ_{j<k} ⟨Δrecon_j, Δrecon_k⟩`. Both the accumulation of many
/// individually-tolerable `Δ_k` AND the cross terms between co-active atoms are
/// invisible to the per-atom gate.
///
/// [`build_hybrid_split_report`] adds an aggregate global guard on top of the
/// per-atom gate, interpreting this same fraction as a bound on
/// `Σ_k max(Δ_k, 0)`. See there for what that guard does and does not prove.
const SAE_HYBRID_COLLAPSE_MAX_EV_LOSS: f64 = 1.0e-3;
397
398/// The full-reconstruction SSR INCREASE from collapsing one `d = 1` atom to its
399/// fitted straight sub-model: `linear_rss − curved_rss`, where both arms are
400/// scored against the atom's leave-this-atom-out response residual `y_resp`
401/// (`target_resid`) exactly as [`build_atom_candidates`] scores them. Because the
402/// full reconstruction differs from the collapsed one ONLY in this atom's
403/// contribution (`a_k·γ_k` → `a_k·line`) on its assigned rows, this scalar is the
404/// exact amount the full reconstruction's SSR rises when the slot is collapsed;
405/// dividing by the fixed total target variance gives the full-EV loss the
406/// EV-preservation gate keys on. Positive ⇒ the curve out-fits its straight
407/// projection (collapsing hurts); `≤ 0` ⇒ the line is at least as good (collapse
408/// is lossless or improving, never gated). Mirrors the `curved_rss` / `linear_rss`
409/// accumulation in [`build_atom_candidates`] bit-for-bit so the gate and the
410/// evidence comparison see the same residuals.
411fn collapse_ssr_increase(
412 coords: ArrayView1<'_, f64>,
413 assign: ArrayView1<'_, f64>,
414 decoded: ArrayView2<'_, f64>,
415 target_resid: ArrayView2<'_, f64>,
416 t_bar: f64,
417 b0: &Array1<f64>,
418 b1: &Array1<f64>,
419) -> f64 {
420 let n = assign.len();
421 let p = target_resid.ncols();
422 let mut curved_rss = 0.0_f64;
423 let mut linear_rss = 0.0_f64;
424 for i in 0..n {
425 let a = assign[i];
426 let dt = coords[i] - t_bar;
427 for j in 0..p {
428 let y = target_resid[[i, j]];
429 let r_curved = y - a * decoded[[i, j]];
430 curved_rss += r_curved * r_curved;
431 let r_linear = y - a * (b0[j] + dt * b1[j]);
432 linear_rss += r_linear * r_linear;
433 }
434 }
435 linear_rss - curved_rss
436}
437
438/// #1051/#1026 NESTED MIN — re-fit the curved atom's decoder on the SAME
439/// leave-this-atom-out residual `y_resp` and return its **minimum** weighted
440/// reconstruction RSS at the atom's realized codes.
441///
442/// This is the genuine constrained minimum of the curved family over its free
443/// decoder coefficients `B`:
444///
445/// ```text
446/// min_B Σᵢ ‖ y_resp[i] − a_k·( Φ(t_i)·B ) ‖² (design = diag(a)·Φ, rhs = y_resp).
447/// ```
448///
449/// It restores the nested-dominance floor `curved_rss ≤ linear_rss`. The linear
450/// special case `a_k·(b₀ + (t−t̄)·b₁)` is itself a member of this family whenever
451/// the straight lane `[1, (t−t̄)]` lies in the column span of the curved basis
452/// `Φ` — exactly (interval / line-segment charts, whose basis carries the
453/// constant and linear terms) or to the basis's expressiveness (the periodic
454/// charts). So `min_B` over `Φ` cannot do WORSE than the best straight line: the
455/// returned RSS is `≤` the linear arm's RSS up to the least-squares solver
456/// tolerance. This is the direct per-atom `d = 1` refit the module owes; the
457/// broken (#1051) euclidean / multi-atom OUTER continuation is deliberately NOT
458/// re-entered — the decoder-only refit at the realized codes is sufficient to
459/// score the curved arm at its constrained minimum.
460///
461/// `phi` is the curved design `Φ(t)` on the atom's rows (`n × M`); `assign` the
462/// per-row mass `a_k` (NOT squared — the design weight is the mass itself, so the
463/// residual is on the SAME footing the linear arm and the joint loss use).
464/// Returns `None` when the solve is degenerate or non-finite; the caller then
465/// falls back to the already-realized curve's RSS rather than fabricate a value.
466fn curved_refit_rss(
467 phi: ArrayView2<'_, f64>,
468 assign: ArrayView1<'_, f64>,
469 target_resid: ArrayView2<'_, f64>,
470) -> Option<f64> {
471 let n = phi.nrows();
472 let m = phi.ncols();
473 let p = target_resid.ncols();
474 if m == 0 || n == 0 || assign.len() != n || target_resid.nrows() != n || p == 0 {
475 return None;
476 }
477 // Weighted design `diag(a)·Φ` (n×M). The refit minimizes ‖diag(a)·Φ·B − y_resp‖²,
478 // so the fitted prediction is diag(a)·Φ·B — the curved contribution `a_k·γ_k`
479 // at its best decoder `B` on this residual.
480 let mut design = Array2::<f64>::zeros((n, m));
481 for i in 0..n {
482 let a = assign[i];
483 if !a.is_finite() {
484 return None;
485 }
486 for c in 0..m {
487 design[[i, c]] = a * phi[[i, c]];
488 }
489 }
490 let b = solve_design_least_squares(design.view(), target_resid).ok()?;
491 if b.iter().any(|v| !v.is_finite()) {
492 return None;
493 }
494 let pred = design.dot(&b);
495 let mut rss = 0.0_f64;
496 for i in 0..n {
497 for j in 0..p {
498 let r = target_resid[[i, j]] - pred[[i, j]];
499 rss += r * r;
500 }
501 }
502 rss.is_finite().then_some(rss)
503}
504
505/// Build the curved + linear candidates for ONE fitted `d = 1` atom and return
506/// them as `(linear, curved, (t̄, b₀, b₁))`, or `None` if the atom cannot present
507/// an honest pair (too few rows, degenerate coordinate span, or non-finite
508/// numbers). Both candidates are scored against the SAME data — the atom's
509/// leave-this-atom-out response residual `y_resp` — at their CONSTRAINED MINIMUM:
510/// the linear arm is the freshly-fit min-over-lines and the curved arm is the
511/// min-over-decoders refit ([`curved_refit_rss`]), so the comparison is a genuine
512/// nested min-vs-min one and `curved_rss ≤ linear_rss` holds up to solver
513/// tolerance for a basis whose span contains the straight lane (#1051).
514///
515/// Inputs over the atom's assigned rows:
516/// * `coords` — the fitted on-atom coordinate `t`.
517/// * `assign` — the per-row assignment mass `a_k` (NOT squared; this routine
518/// squares it where the design weight `a_k²` is needed).
519/// * `decoded` — the atom's fitted decoded image `γ_k(t) = Φ(t) B_k` (`p` cols),
520/// whose mass-scaled value `a_k·γ_k` is the curved candidate's PREDICTION.
521/// * `target_resid` — the atom's leave-this-atom-out response residual `y_resp`
522/// (`p` cols): the response with every OTHER atom's contribution removed.
523/// This is the data both candidates fit.
524///
525/// The curved candidate's data-fit deviance is `½·min_B Σ ‖y_resp − a_k·(Φ·B)‖²`
526/// (its constrained minimum over the decoder; the mass lives in the prediction);
527/// the linear candidate fits the best mass-weighted straight line to `y_resp` and
528/// pays `½ Σ ‖y_resp − a_k·(b₀ + (t − t̄)·b₁)‖²`. Because the linear prediction is
529/// itself a curved-family member (the straight lane lies in `span(Φ)` for the
530/// eligible charts), the curved arm's minimized RSS is `≤` the linear arm's up to
531/// solver tolerance, so the argmin is a genuine nested min-vs-min dominance
532/// comparison, not a post-fit compression heuristic.
533fn build_atom_candidates(
534 coords: ArrayView1<'_, f64>,
535 assign: ArrayView1<'_, f64>,
536 decoded: ArrayView2<'_, f64>,
537 target_resid: ArrayView2<'_, f64>,
538 curved_num_params: usize,
539 curved_phi: Option<ArrayView2<'_, f64>>,
540 fitted_turning: Option<f64>,
541) -> Option<(
542 HybridAtomCandidate,
543 HybridAtomCandidate,
544 (f64, Array1<f64>, Array1<f64>),
545)> {
546 let n = coords.len();
547 let p = decoded.ncols();
548 if n < MIN_ROWS_FOR_LINEAR_FIT
549 || decoded.nrows() != n
550 || assign.len() != n
551 || target_resid.nrows() != n
552 || target_resid.ncols() != p
553 || p == 0
554 {
555 return None;
556 }
557
558 // The LINEAR candidate fits `a_k·(b₀ + (t − t̄)·b₁)` to the residual `y_resp`,
559 // so the natural design column is `a_k·[1, (t − t̄)]` and the per-row Gram
560 // weight is `wᵢ = a_k²`. We accumulate the mass-weighted coordinate mean `t̄`
561 // and spread `s_tt` under that weight; a row that barely belongs to the atom
562 // (`a_k ≈ 0`) contributes ≈ nothing, exactly as in the joint loss.
563 let mut w_sum = 0.0_f64;
564 let mut t_bar = 0.0_f64;
565 for i in 0..n {
566 let a = assign[i];
567 if !(a.is_finite() && a >= 0.0) {
568 return None;
569 }
570 let w = a * a;
571 w_sum += w;
572 t_bar += w * coords[i];
573 }
574 if !(w_sum > 0.0) {
575 return None;
576 }
577 t_bar /= w_sum;
578
579 // Weighted Σ wᵢ·(t − t̄)² with `wᵢ = a_k²` — the coordinate spread under the
580 // line's design weight. A degenerate (single-point mass) coordinate has no
581 // slope direction; refuse rather than divide by ~0.
582 let mut s_tt = 0.0_f64;
583 for i in 0..n {
584 let dt = coords[i] - t_bar;
585 s_tt += assign[i] * assign[i] * dt * dt;
586 }
587 if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
588 return None;
589 }
590
591 // Per-output-channel mass-weighted least squares for the line fit to the
592 // RESIDUAL `y_resp`. Minimizing `Σᵢ ‖y_resp[i] − a_k·(b₀ + (t − t̄)·b₁)‖²` in
593 // the centered basis has the diagonal normal equations
594 // b₀[j] = (Σ a_k·y_resp[i,j]) / w_sum, (recall the design intercept is a_k)
595 // b₁[j] = (Σ a_k·(t − t̄)·y_resp[i,j]) / s_tt.
596 let mut b0 = Array1::<f64>::zeros(p);
597 let mut b1 = Array1::<f64>::zeros(p);
598 for j in 0..p {
599 let mut s_1y = 0.0_f64;
600 let mut s_ty = 0.0_f64;
601 for i in 0..n {
602 let a = assign[i];
603 let dt = coords[i] - t_bar;
604 let y = target_resid[[i, j]];
605 s_1y += a * y;
606 s_ty += a * dt * y;
607 }
608 b0[j] = s_1y / w_sum;
609 b1[j] = s_ty / s_tt;
610 }
611
612 // Data-fit residual sums of squares of BOTH candidates against `y_resp`, the
613 // common data. The linear candidate predicts the best line
614 // `a_k·(b₀ + (t − t̄)·b₁)`; the curved candidate is scored at its CONSTRAINED
615 // MINIMUM over the decoder coefficients (nested min-vs-min, #1051), re-fit on
616 // this same residual — not the possibly-collapsed already-realized curve. We
617 // also carry the realized curve's RSS as the honest fallback when the basis
618 // `Φ` is unavailable or its refit solve is degenerate.
619 let mut linear_rss = 0.0_f64;
620 let mut realized_curved_rss = 0.0_f64;
621 for i in 0..n {
622 let a = assign[i];
623 let dt = coords[i] - t_bar;
624 for j in 0..p {
625 let y = target_resid[[i, j]];
626 let r_linear = y - a * (b0[j] + dt * b1[j]);
627 linear_rss += r_linear * r_linear;
628 let r_curved = y - a * decoded[[i, j]];
629 realized_curved_rss += r_curved * r_curved;
630 }
631 }
632 // #1051 NESTED MIN — the curved arm's data fit is `min_B ‖y_resp − diag(a)Φ B‖²`,
633 // its constrained minimum over the decoder. Because the linear lane is a member
634 // of the curved family (the straight columns lie in `span(Φ)` for the eligible
635 // charts), this min-curved RSS is `≤ linear_rss` up to solver tolerance — the
636 // "curved match-or-beats linear" floor. Fall back to the realized curve's RSS
637 // only when Φ is absent or the refit is degenerate.
638 let curved_rss = match curved_phi {
639 Some(phi) if phi.nrows() == n => {
640 curved_refit_rss(phi, assign, target_resid).unwrap_or(realized_curved_rss)
641 }
642 _ => realized_curved_rss,
643 };
644
645 // Gaussian-reconstruction deviance: the residual objective `½ RSS` the
646 // Laplace normalizer is added to. The curved arm pays `½·curved_rss` (how
647 // well its REALIZED curve explains the residual) plus its larger `M·p`
648 // parameter price; the linear arm pays `½·linear_rss` plus a `2·p` price.
649 // `curved_rss` is the realized (not re-optimized) curve's misfit, so it is NOT
650 // guaranteed `≤ linear_rss`: when the realized curve underperforms its own best
651 // straight projection the cheaper line simply wins. The argmin trades whatever
652 // data-fit the realized curve buys against the curvature parameter price — a
653 // post-fit compression decision, not a nested match-or-beat floor.
654 let curved_residual_objective = 0.5 * curved_rss;
655 let linear_residual_objective = 0.5 * linear_rss;
656
657 // Linear candidate parameter price: intercept + slope per output channel.
658 let linear_num_params = 2 * p;
659
660 // Laplace logdet of the (weighted) design Gram for the LINEAR candidate.
661 //
662 // For the centered weighted line fit `a_k·(b₀ + (t − t̄)·b₁)`, the per-output-
663 // channel design column is `a_k·[1, (t − t̄)]`, whose Gram is DIAGONAL in the
664 // centered basis: `diag(Σ a_k², Σ a_k²(t − t̄)²) = diag(w_sum, s_tt)`. Its log
665 // determinant is `log(w_sum) + log(s_tt)` PER output channel, i.e.
666 //
667 // log|H_linear| = p · ( log(w_sum) + log(s_tt) ).
668 //
669 // The `log(s_tt)` term is the slope direction's information: a line through a
670 // wide, heavily-massed coordinate spread is better-determined than one through
671 // a tiny spread, and the Laplace evidence must reflect that (#1203).
672 //
673 // The curved arm's Laplace determinant is now the genuine weighted-design
674 // Gram log-determinant `p · log|ΦᵀWΦ|_+` (#1223): the SAME quantity the
675 // linear arm reports (`p·(log w_sum + log s_tt) = p·log|XᵀWX|`), assembled
676 // from the curved basis `Φ` on the atom's assigned rows under the same
677 // assignment-mass design weight `wᵢ = a_k²`. Both arms omit the smoothing
678 // `λS` normalizer, so the complexity price is computed on a symmetric
679 // footing — no parameter-count proxy. Only when `Φ` is unavailable (the
680 // caller could not evaluate the basis) or its Gram is fully rank-deficient do
681 // we fall back to the historical `curved_num_params · log(w_sum)` proxy, so
682 // the comparison degrades gracefully rather than fabricating a determinant.
683 if !(w_sum > 0.0 && w_sum.is_finite() && s_tt.is_finite()) {
684 return None;
685 }
686 let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
687 let curved_log_det_h = curved_phi
688 .and_then(|phi| {
689 if phi.nrows() == n {
690 curved_design_gram_logdet(phi, assign, p)
691 } else {
692 None
693 }
694 })
695 .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());
696
697 // Reduced Laplace NLE `residual_objective + ½ log|H|`. Both omit an explicit
698 // smoothing-penalty logdet (the intrinsic smoothness penalty is
699 // reparameterization-invariant and identical in expectation across the two
700 // parameterizations of the same image).
701 let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
702 let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
703 if !(linear_nle.is_finite() && curved_nle.is_finite()) {
704 return None;
705 }
706
707 let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
708 let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
709 Some((linear, curved, (t_bar, b0, b1)))
710}
711
712/// #1026 collapse rescue. When a `d = 1` atom's own coordinate has collapsed to a
713/// single point (`build_atom_candidates` refuses because `s_tt ≈ 0`), the atom is
714/// stuck in the degenerate "chord-through-the-arc" fixed point and its curved
715/// decode is a constant — the rank-1 dictionary co-collapse (real-OLMo held-out EV
716/// ≈ 0.13 vs the rank-K linear ceiling ≈ 0.74). The hybrid-split was DESIGNED to
717/// let such a linear-tail atom decode as a straight line; the only reason it can't
718/// here is that its own codes carry no spread to fit a slope against.
719///
720/// Recover FRESH per-row codes from the data instead: `uᵢ = yᵢ·v`, the projection
721/// of the leave-this-atom-out residual onto its top mass-weighted output direction
722/// `v` (the rank-1 of `Σᵢ wᵢ yᵢyᵢᵀ`, `wᵢ = a_k²` — the SAME design weight the line
723/// fit uses). These codes span the residual's strongest linear axis by
724/// construction, so the straight image `b₀ + (uᵢ − ū)·b₁` fit against them
725/// reconstructs that axis at LINEAR quality — exactly the linear-tail reach the
726/// split owes. Returns the forced-LINEAR candidate plus the image carrying `uᵢ`,
727/// or `None` when the residual itself carries no usable direction (a genuine zero
728/// atom the mass/decoder guards own).
729fn build_collapse_rescue_linear_image(
730 atom_idx: usize,
731 assign: ArrayView1<'_, f64>,
732 target_resid: ArrayView2<'_, f64>,
733) -> Option<(HybridAtomCandidate, AtomLinearImage)> {
734 let n = assign.len();
735 let p = target_resid.ncols();
736 if n < MIN_ROWS_FOR_LINEAR_FIT || target_resid.nrows() != n || p == 0 {
737 return None;
738 }
739 let mut w_sum = 0.0_f64;
740 for i in 0..n {
741 let a = assign[i];
742 if !(a.is_finite() && a >= 0.0) {
743 return None;
744 }
745 w_sum += a * a;
746 }
747 if !(w_sum > 0.0) {
748 return None;
749 }
750 // Top mass-weighted output direction `v` of the residual via power iteration on
751 // `M = Σᵢ wᵢ yᵢyᵢᵀ` (p×p, never materialized): `v ← normalize(Σᵢ wᵢ yᵢ (yᵢ·v))`.
752 // Seed from the per-channel weighted energy so a rank-1 residual converges in
753 // one step and the seed is deterministic (no RNG).
754 let mut v = Array1::<f64>::zeros(p);
755 for j in 0..p {
756 let mut e = 0.0_f64;
757 for i in 0..n {
758 let a = assign[i];
759 let y = target_resid[[i, j]];
760 e += a * a * y * y;
761 }
762 v[j] = e;
763 }
764 let mut vnorm = v.dot(&v).sqrt();
765 if !(vnorm > 0.0) {
766 return None;
767 }
768 v.mapv_inplace(|x| x / vnorm);
769 for _ in 0..32 {
770 let mut mv = Array1::<f64>::zeros(p);
771 for i in 0..n {
772 let a = assign[i];
773 let w = a * a;
774 let mut proj = 0.0_f64;
775 for j in 0..p {
776 proj += target_resid[[i, j]] * v[j];
777 }
778 let wp = w * proj;
779 for j in 0..p {
780 mv[j] += wp * target_resid[[i, j]];
781 }
782 }
783 vnorm = mv.dot(&mv).sqrt();
784 if !(vnorm > 0.0) {
785 return None;
786 }
787 mv.mapv_inplace(|x| x / vnorm);
788 let cos = mv.dot(&v).abs();
789 v = mv;
790 if cos > 1.0 - 1e-12 {
791 break;
792 }
793 }
794 // Fresh per-row codes `uᵢ = yᵢ·v` and the weighted line fit against them.
795 let mut u = Array1::<f64>::zeros(n);
796 let mut t_bar = 0.0_f64;
797 for i in 0..n {
798 let mut proj = 0.0_f64;
799 for j in 0..p {
800 proj += target_resid[[i, j]] * v[j];
801 }
802 u[i] = proj;
803 t_bar += assign[i] * assign[i] * proj;
804 }
805 t_bar /= w_sum;
806 let mut s_tt = 0.0_f64;
807 for i in 0..n {
808 let dt = u[i] - t_bar;
809 s_tt += assign[i] * assign[i] * dt * dt;
810 }
811 if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
812 return None;
813 }
814 let mut b0 = Array1::<f64>::zeros(p);
815 let mut b1 = Array1::<f64>::zeros(p);
816 let mut linear_rss = 0.0_f64;
817 for j in 0..p {
818 let mut s_1y = 0.0_f64;
819 let mut s_ty = 0.0_f64;
820 for i in 0..n {
821 let a = assign[i];
822 let dt = u[i] - t_bar;
823 let y = target_resid[[i, j]];
824 s_1y += a * y;
825 s_ty += a * dt * y;
826 }
827 b0[j] = s_1y / w_sum;
828 b1[j] = s_ty / s_tt;
829 }
830 for i in 0..n {
831 let a = assign[i];
832 let dt = u[i] - t_bar;
833 for j in 0..p {
834 let r = target_resid[[i, j]] - a * (b0[j] + dt * b1[j]);
835 linear_rss += r * r;
836 }
837 }
838 let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
839 let linear_nle = reduced_laplace_nle(0.5 * linear_rss, linear_log_det_h);
840 if !linear_nle.is_finite() {
841 return None;
842 }
843 let linear = HybridAtomCandidate::linear(linear_nle, 2 * p);
844 let image = AtomLinearImage {
845 atom_idx,
846 t_bar,
847 b0,
848 b1,
849 row_codes: Some(u),
850 // #1777 — persist the projection direction so an OOS row's coordinate can
851 // be recomputed as ⟨residual, v⟩ (identical to the train `row_codes`),
852 // rather than falling back to the atom's collapsed own coordinate.
853 v: Some(v),
854 };
855 Some((linear, image))
856}
857
858/// #1026 item-2 — one collapsed slot's TRUE dictionary-level reconstruction
859/// change `δ_k[i,j] = a_k·(γ_k(t_i) − line_k(t_i))` over ALL globally-aligned rows
860/// (the caller presents `coords`/`assign`/`decoded`/`image` on the same `n` rows,
861/// so these δ vectors ARE cross-atom aligned and their inner products are the
862/// genuine cross terms). `line_k` is evaluated at the image's own coordinate
863/// (collapse-rescue slots evaluate at their fresh per-row codes via
864/// [`AtomLinearImage::coordinate_for_row`], exactly as the collapsed reconstruction
865/// does). Collapsing atom `k` shifts the full reconstruction residual by `+δ_k`.
866fn slot_delta(
867 coords: &Array1<f64>,
868 assign: &Array1<f64>,
869 decoded: &Array2<f64>,
870 image: &AtomLinearImage,
871) -> Array2<f64> {
872 let n = decoded.nrows();
873 let p = decoded.ncols();
874 let mut d = Array2::<f64>::zeros((n, p));
875 for i in 0..n {
876 let a = assign[i];
877 let coord = image.coordinate_for_row(i, coords[i]);
878 let dt = coord - image.t_bar;
879 for j in 0..p {
880 let line = image.b0[j] + dt * image.b1[j];
881 d[[i, j]] = a * (decoded[[i, j]] - line);
882 }
883 }
884 d
885}
886
887/// #1026 item-2 — the ALL-CURVED global reconstruction residual
888/// `R0 = target − Σ_all a·γ` recovered from any single atom's leave-this-atom-out
889/// residual: `R0 = y_resp_k − a_k·γ_k = target_resid − a_k·decoded`. Identical for
890/// every atom (each `target_resid` adds back exactly that atom's own contribution),
891/// so the caller computes it once from the first slot.
892fn slot_r0(assign: &Array1<f64>, decoded: &Array2<f64>, target_resid: &Array2<f64>) -> Array2<f64> {
893 let n = decoded.nrows();
894 let p = decoded.ncols();
895 let mut r0 = Array2::<f64>::zeros((n, p));
896 for i in 0..n {
897 let a = assign[i];
898 for j in 0..p {
899 r0[[i, j]] = target_resid[[i, j]] - a * decoded[[i, j]];
900 }
901 }
902 r0
903}
904
905/// #1026 item-2 — GLOBAL cross-term collapse guard. Given the collapsed slots'
906/// dictionary-level reconstruction-change vectors `δ_k` (globally row-aligned,
907/// `n × p`) and the all-curved global residual `R0 = target − Σ_all a·γ`, decide
908/// which collapses must revert to curved so the TRUE global reconstruction SSR
909/// increase
910///
911/// ```text
912/// ΔRSS(S) = ‖R0 + Σ_{k∈S} δ_k‖² − ‖R0‖²
913/// = 2⟨R0, Σ_{k∈S} δ_k⟩ + ‖Σ_{k∈S} δ_k‖²
914/// = Σ_{k∈S} Δ_k + 2 Σ_{j<k∈S} ⟨δ_j, δ_k⟩
915/// ```
916///
917/// stays within `global_tol`. Unlike the per-atom / summed-loss guard this
918/// INCLUDES the cross terms `2 Σ_{j<k} ⟨δ_j, δ_k⟩` between simultaneously-collapsed
919/// atoms (correlated collapse errors), which the aggregate `Σ max(Δ_k, 0)` bound is
920/// blind to. `forced` are the always-collapsed δ (euclidean / collapse-rescue slots
921/// with no curved alternative): they stay in the reconstruction but cannot be
922/// rolled back. `eligible` are `(slot, curved_evidence_margin, δ)` for
923/// rollback-eligible collapses. When `ΔRSS` over ALL collapses exceeds tolerance,
924/// the least-justified eligible collapses (largest margin — the most marginal
925/// linear win) are reverted one at a time, RECOMPUTING the true global increase
926/// (cross terms and all) after each revert, until within tolerance or none remain.
927/// Returns the slot indices to revert to curved.
928fn global_collapse_rollback(
929 r0: ArrayView2<'_, f64>,
930 eligible: &[(usize, f64, &Array2<f64>)],
931 forced: &[&Array2<f64>],
932 global_tol: f64,
933) -> Vec<usize> {
934 let (n, p) = r0.dim();
935 // True global SSR increase for the active δ set: 2⟨R0, Σδ⟩ + ‖Σδ‖².
936 let increase = |active: &[&Array2<f64>]| -> f64 {
937 let mut cross = 0.0_f64;
938 let mut self_sq = 0.0_f64;
939 for i in 0..n {
940 for j in 0..p {
941 let mut s = 0.0_f64;
942 for d in active {
943 s += d[[i, j]];
944 }
945 cross += r0[[i, j]] * s;
946 self_sq += s * s;
947 }
948 }
949 2.0 * cross + self_sq
950 };
951 let mut kept = vec![true; eligible.len()];
952 let build_active = |kept: &[bool]| -> Vec<&Array2<f64>> {
953 let mut v: Vec<&Array2<f64>> = forced.to_vec();
954 for (idx, &(_, _, d)) in eligible.iter().enumerate() {
955 if kept[idx] {
956 v.push(d);
957 }
958 }
959 v
960 };
961 if increase(&build_active(&kept)) <= global_tol {
962 return Vec::new();
963 }
964 // Revert least-justified first: largest curved_evidence_margin (the most
965 // marginal linear win is the cheapest to give back to curved).
966 let mut order: Vec<usize> = (0..eligible.len()).collect();
967 order.sort_by(|&a, &b| {
968 eligible[b]
969 .1
970 .partial_cmp(&eligible[a].1)
971 .unwrap_or(std::cmp::Ordering::Equal)
972 });
973 let mut reverted = Vec::new();
974 for idx in order {
975 kept[idx] = false;
976 reverted.push(eligible[idx].0);
977 if increase(&build_active(&kept)) <= global_tol {
978 break;
979 }
980 }
981 reverted
982}
983
984/// Assemble the per-atom candidate slots for [`select_hybrid_split`] from the
985/// fitted `d = 1` atoms, run the adjudication, and return the report.
986///
987/// `atoms` are the fitted dictionary atoms; `coords_for` yields the on-atom
988/// coordinate column for a slot, `assign_for` the per-row assignment mass `a_k`,
989/// `decoded_for` the fitted decoded image rows `γ_k`, and `target_resid_for` the
990/// atom's leave-this-atom-out response residual `y_resp` (the data both
991/// candidates are scored against, #1202). `manifold_for` yields the atom's chart
992/// manifold (a flat / Euclidean chart can present only the linear candidate,
993/// enforced inside the selector).
994///
995/// Returns `None` (no report) when no atom is eligible — there is nothing to
996/// adjudicate.
997pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
998 atoms: &'a [SaeManifoldAtom],
999 eligible_d1: impl Iterator<Item = usize>,
1000 mut coords_for: C,
1001 mut assign_for: W,
1002 mut decoded_for: D,
1003 mut target_resid_for: R,
1004 mut manifold_for: M,
1005 mut delta_ev_for: E,
1006 // #1026 — the full target's total (column-centered) variance `SST_full`, the
1007 // fixed denominator of the EV-preservation gate. `≤ 0` / non-finite disables
1008 // the gate (a degenerate, varianceless target has no EV to preserve).
1009 total_centered_variance: f64,
1010) -> Result<Option<SaeHybridSplitReport>, String>
1011where
1012 C: FnMut(usize) -> Array1<f64>,
1013 W: FnMut(usize) -> Array1<f64>,
1014 D: FnMut(usize) -> Array2<f64>,
1015 R: FnMut(usize) -> Array2<f64>,
1016 M: FnMut(usize) -> LatentManifold,
1017 // The atom's held-out LOAO `ΔEV_k`, keyed by atom index. `None` when LOAO EV
1018 // is unavailable (e.g. the caller has no target to measure against).
1019 E: FnMut(usize) -> Option<f64>,
1020{
1021 let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
1022 let mut names: Vec<String> = Vec::new();
1023 let mut manifolds: Vec<LatentManifold> = Vec::new();
1024 // Per-slot fitted straight sub-model `(atom_idx, t̄, b₀, b₁)`, surfaced onto
1025 // the verdict iff the slot selects LINEAR so the collapsed reconstruction can
1026 // substitute it for the curved decoded image.
1027 let mut linear_images: Vec<AtomLinearImage> = Vec::new();
1028 // Per-slot `(Θ, ΔEV)` — the #1026 frontier point — carried onto each verdict
1029 // so the geometry/EV pairing is structured report data, not a log line.
1030 let mut turnings: Vec<Option<f64>> = Vec::new();
1031 let mut delta_evs: Vec<Option<f64>> = Vec::new();
1032 // #1026 item-2 — per-slot collapse loss `Δ_k = linear_rss − curved_rss` for the
1033 // GLOBAL EV-preservation guard below. `Some(Δ_k)` for a curveable slot that
1034 // retained a curved alternative (so a chosen collapse there can be rolled back);
1035 // `None` for euclidean slots (no curved option) and collapse-rescue slots (the
1036 // curve was already degenerate — collapsing recovers EV rather than losing it).
1037 let mut collapse_loss: Vec<Option<f64>> = Vec::new();
1038 // #1026 item-2 — per-slot dictionary-level reconstruction-change vector δ_k
1039 // (globally row-aligned n×p), and the all-curved global residual R0 (computed
1040 // once from the first slot). These feed the GLOBAL cross-term collapse guard,
1041 // which reconstructs the full dictionary with the selected collapses applied
1042 // and measures the TRUE global EV degradation (cross terms and all).
1043 let mut deltas: Vec<Array2<f64>> = Vec::new();
1044 let mut r0: Option<Array2<f64>> = None;
1045
1046 for atom_idx in eligible_d1 {
1047 let atom = &atoms[atom_idx];
1048 let coords = coords_for(atom_idx);
1049 let assign = assign_for(atom_idx);
1050 let decoded = decoded_for(atom_idx);
1051 let target_resid = target_resid_for(atom_idx);
1052 // Curved parameter price = the decoder's `M · p` coefficients.
1053 let curved_num_params = atom.decoder_coefficients.len();
1054 let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
1055 d1_atom_fitted_turning(
1056 evaluator.as_ref(),
1057 atom.decoder_coefficients.view(),
1058 coords.view(),
1059 )
1060 .ok()
1061 .flatten()
1062 });
1063 // Evaluate the curved design `Φ(t)` on this atom's assigned rows so the
1064 // curved arm's Laplace complexity is the real weighted-design Gram
1065 // log-determinant rather than a parameter-count proxy (#1223). A `d = 1`
1066 // atom's coordinate column is presented as an `n × 1` design input. If
1067 // the evaluator is absent or refuses, `curved_phi` stays `None` and
1068 // `build_atom_candidates` falls back to the proxy.
1069 let coords_col = coords
1070 .view()
1071 .into_shape_with_order((coords.len(), 1))
1072 .ok()
1073 .map(|v| v.to_owned());
1074 let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
1075 (Some(evaluator), Some(col)) => {
1076 evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
1077 }
1078 _ => None,
1079 };
1080 // A flat (Euclidean) chart cannot honestly present a curved candidate;
1081 // the selector drops it. Present both for curveable charts.
1082 let manifold = manifold_for(atom_idx);
1083 match build_atom_candidates(
1084 coords.view(),
1085 assign.view(),
1086 decoded.view(),
1087 target_resid.view(),
1088 curved_num_params,
1089 curved_phi.as_ref().map(|phi| phi.view()),
1090 fitted_turning,
1091 ) {
1092 Some((linear, curved, (t_bar, b0, b1))) => {
1093 // #1026 PER-ATOM EV-PRESERVATION gate. Collapsing this slot raises
1094 // the full reconstruction SSR by `linear_rss − curved_rss`; if that
1095 // is more than `SAE_HYBRID_COLLAPSE_MAX_EV_LOSS` of the fixed total
1096 // target variance the collapse would DROP this ONE atom's EV
1097 // materially, so veto it by presenting only the curved candidate (the
1098 // selector must keep curved). A lossless / improving collapse (`≤ 0`)
1099 // and a negligible one stay free to collapse — EV-neutral cases (the
1100 // top-k / birth-topology lines) are untouched. Only curveable charts
1101 // are gated; a euclidean chart never had a curved option. NOTE: this
1102 // gate is PER-ATOM only — the accumulation of many small collapses
1103 // and the dictionary-level cross terms are handled by the aggregate
1104 // guard after selection.
1105 let loss = collapse_ssr_increase(
1106 coords.view(),
1107 assign.view(),
1108 decoded.view(),
1109 target_resid.view(),
1110 t_bar,
1111 &b0,
1112 &b1,
1113 );
1114 let collapse_loses_ev = total_centered_variance.is_finite()
1115 && total_centered_variance > 0.0
1116 && loss > SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
1117 let euclidean = manifold.is_euclidean();
1118 let slot = if euclidean {
1119 vec![linear]
1120 } else if collapse_loses_ev {
1121 vec![curved]
1122 } else {
1123 vec![linear, curved]
1124 };
1125 // Build the straight image, then its globally-aligned δ_k for the
1126 // GLOBAL cross-term guard, before moving it into the report.
1127 let image = AtomLinearImage {
1128 atom_idx,
1129 t_bar,
1130 b0,
1131 b1,
1132 row_codes: None,
1133 // Ordinary straight image: decoded at the atom's own
1134 // coordinate, so it carries no residual-projection direction.
1135 v: None,
1136 };
1137 let delta = slot_delta(&coords, &assign, &decoded, &image);
1138 if r0.is_none() {
1139 r0 = Some(slot_r0(&assign, &decoded, &target_resid));
1140 }
1141 slots.push(slot);
1142 // A euclidean slot never had a curved alternative, so its collapse
1143 // carries no recoverable EV loss for the global guard; record `None`.
1144 collapse_loss.push(if euclidean { None } else { Some(loss) });
1145 names.push(atom.name.clone());
1146 manifolds.push(manifold);
1147 turnings.push(fitted_turning);
1148 delta_evs.push(delta_ev_for(atom_idx));
1149 deltas.push(delta);
1150 linear_images.push(image);
1151 }
1152 // #1026 collapse rescue: `build_atom_candidates` refused because the
1153 // atom's own coordinate collapsed (`s_tt ≈ 0`) — the rank-1 co-collapse
1154 // fixed point. Recover a FRESH linear image from the residual's top
1155 // direction (fresh per-row codes) and force the LINEAR verdict (a
1156 // single-option slot the selector must take) so the slot reconstructs
1157 // its residual's best linear axis at linear quality instead of the
1158 // collapsed-curve constant. `None` only when the residual itself is
1159 // degenerate — then there is genuinely nothing to recover and we skip.
1160 None => match build_collapse_rescue_linear_image(
1161 atom_idx,
1162 assign.view(),
1163 target_resid.view(),
1164 ) {
1165 Some((linear, image)) => {
1166 let delta = slot_delta(&coords, &assign, &decoded, &image);
1167 if r0.is_none() {
1168 r0 = Some(slot_r0(&assign, &decoded, &target_resid));
1169 }
1170 slots.push(vec![linear]);
1171 // Forced-linear rescue: the curve was degenerate, so there is no
1172 // curved alternative to roll back to and no recoverable EV loss.
1173 collapse_loss.push(None);
1174 names.push(atom.name.clone());
1175 manifolds.push(manifold);
1176 turnings.push(fitted_turning);
1177 delta_evs.push(delta_ev_for(atom_idx));
1178 deltas.push(delta);
1179 linear_images.push(image);
1180 }
1181 None => continue,
1182 },
1183 }
1184 }
1185
1186 if slots.is_empty() {
1187 return Ok(None);
1188 }
1189
1190 let mut selection = select_hybrid_split(&slots)?;
1191
1192 // #1026 item-2 — GLOBAL CROSS-TERM EV-preservation guard over the SELECTED
1193 // collapses.
1194 //
1195 // The per-atom gate above bounds each atom's individual EV loss, but the TRUE
1196 // dictionary-level RSS increase from collapsing a SET of atoms is
1197 // ΔRSS = Σ_k Δ_k + 2 Σ_{j<k} ⟨δ_j, δ_k⟩,
1198 // so two effects escape any per-atom or summed-loss bound: (1) the accumulation
1199 // of many individually-tolerable `Δ_k`, and (2) the CROSS TERMS between
1200 // simultaneously-collapsed atoms. The old aggregate `Σ max(Δ_k, 0)` guard bounded
1201 // (1) but was blind to (2): correlated collapse errors whose per-atom losses each
1202 // sit under tolerance can still push the true global loss over it.
1203 //
1204 // Every slot now carries its exact dictionary-level reconstruction-change vector
1205 // δ_k (globally row-aligned — the caller presents all per-atom arrays on the same
1206 // n rows), so we reconstruct the full dictionary WITH the selected collapse set
1207 // applied and measure the real global increase `ΔRSS` DIRECTLY (cross terms
1208 // captured), reverting the least-justified collapses until the degradation is
1209 // within the same EV tolerance and re-adjudicating so `selection` stays consistent.
1210 if total_centered_variance.is_finite() && total_centered_variance > 0.0 {
1211 if let Some(r0) = r0.as_ref() {
1212 let global_tol = SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
1213 // Partition the SELECTED-collapsed slots: rollback-eligible ones (a curved
1214 // alternative still present and a finite per-atom loss) vs forced ones
1215 // (euclidean / collapse-rescue — no curved fallback, stay collapsed but
1216 // still enter the global reconstruction so their cross terms are counted).
1217 let mut eligible: Vec<(usize, f64, &Array2<f64>)> = Vec::new();
1218 let mut forced: Vec<&Array2<f64>> = Vec::new();
1219 for (i, choice) in selection.atoms.iter().enumerate() {
1220 if !choice.param.is_linear() {
1221 continue;
1222 }
1223 let has_curved_alt = slots[i].iter().any(|c| !c.param.is_linear());
1224 let loss_finite = collapse_loss[i].map(|l| l.is_finite()).unwrap_or(false);
1225 if has_curved_alt && loss_finite {
1226 eligible.push((i, choice.curved_evidence_margin, &deltas[i]));
1227 } else {
1228 forced.push(&deltas[i]);
1229 }
1230 }
1231 let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
1232 if !reverted.is_empty() {
1233 for slot in reverted {
1234 if let Some(curved) =
1235 slots[slot].iter().find(|c| !c.param.is_linear()).copied()
1236 {
1237 slots[slot] = vec![curved];
1238 }
1239 }
1240 selection = select_hybrid_split(&slots)?;
1241 }
1242 }
1243 }
1244
1245 let verdicts: Vec<AtomHybridVerdict> = names
1246 .into_iter()
1247 .zip(selection.atoms.iter().copied())
1248 .zip(linear_images.into_iter())
1249 .zip(turnings.into_iter())
1250 .zip(delta_evs.into_iter())
1251 .map(
1252 |((((atom_name, choice), linear_image), fitted_turning), train_loao_delta_ev)| {
1253 let kept_curved = !choice.param.is_linear();
1254 AtomHybridVerdict {
1255 atom_name,
1256 choice,
1257 kept_curved,
1258 fitted_turning,
1259 train_loao_delta_ev,
1260 // Carry the straight sub-model only when the verdict collapses
1261 // this slot to linear — the curved slots keep their fitted image.
1262 linear_image: if kept_curved {
1263 None
1264 } else {
1265 Some(linear_image)
1266 },
1267 }
1268 },
1269 )
1270 .collect();
1271
1272 Ok(Some(SaeHybridSplitReport {
1273 verdicts,
1274 selection,
1275 }))
1276}
1277
1278#[cfg(test)]
1279mod tests {
1280 use super::*;
1281 use std::f64::consts::PI;
1282
1283 /// A straight RESPONSE residual (the atom's data is a line) is explained
1284 /// equally well by both candidates, so the cheaper linear special case wins.
1285 /// With `a_k = 1` the curved decoded image is straight too (Θ = 0), so both
1286 /// the dominance floor and the evidence argmin select linear. This is the
1287 /// common-data nested comparison (#1202): linear is the curved family's
1288 /// `Θ = 0` member, so it cannot lose when a line already explains the data.
1289 #[test]
1290 fn straight_residual_selects_linear() {
1291 let n = 40;
1292 let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1293 let assign = Array1::<f64>::ones(n);
1294 // The data the atom must explain is a straight line in ℝ²; the curved
1295 // decoded image equals that same line (a Θ = 0 curved fit).
1296 let mut data = Array2::<f64>::zeros((n, 2));
1297 let mut decoded = Array2::<f64>::zeros((n, 2));
1298 for i in 0..n {
1299 data[[i, 0]] = coords[i];
1300 data[[i, 1]] = 0.6 * coords[i];
1301 decoded[[i, 0]] = coords[i];
1302 decoded[[i, 1]] = 0.6 * coords[i];
1303 }
1304 let (linear, curved, _) = build_atom_candidates(
1305 coords.view(),
1306 assign.view(),
1307 decoded.view(),
1308 data.view(),
1309 // a generous curved parameter price (M·p)
1310 10,
1311 None,
1312 Some(0.0),
1313 )
1314 .expect("straight residual yields a candidate pair");
1315 let choice =
1316 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1317 assert!(
1318 choice.param.is_linear(),
1319 "a straight response residual must keep the linear special case"
1320 );
1321 }
1322
1323 /// A turning RESPONSE residual (the atom's data traces a full circle) is fit
1324 /// well by the curved decoded image (curved_rss ≈ 0) but poorly by any
1325 /// straight line (large linear_rss), so the curved candidate wins the common
1326 /// evidence comparison once its data-fit gain exceeds its extra parameter
1327 /// price (#1202).
1328 #[test]
1329 fn turning_residual_selects_curved_on_evidence() {
1330 let n = 60;
1331 let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
1332 let assign = Array1::<f64>::ones(n);
1333 // The data is a full circle; the curved decoded image is that same
1334 // circle (the curved atom reconstructs its assigned residual), so the
1335 // curved candidate has ≈ zero data-fit residual while a straight line
1336 // cannot follow the loop.
1337 let mut data = Array2::<f64>::zeros((n, 2));
1338 let mut decoded = Array2::<f64>::zeros((n, 2));
1339 for i in 0..n {
1340 let theta = 2.0 * PI * coords[i];
1341 data[[i, 0]] = theta.cos();
1342 data[[i, 1]] = theta.sin();
1343 decoded[[i, 0]] = theta.cos();
1344 decoded[[i, 1]] = theta.sin();
1345 }
1346 // The curved atom has 5 parameters (just above the 4 = 2·p linear budget);
1347 // the full-circle linear residual exceeds the extra-parameter overhead, so
1348 // curved wins on evidence.
1349 let (linear, curved, _) = build_atom_candidates(
1350 coords.view(),
1351 assign.view(),
1352 decoded.view(),
1353 data.view(),
1354 5,
1355 None,
1356 Some(2.0 * PI),
1357 )
1358 .expect("turning residual yields a candidate pair");
1359 assert!(
1360 linear.negative_log_evidence > curved.negative_log_evidence,
1361 "the line must misfit the circular residual worse than the curve does \
1362 (linear NLE {} should exceed curved NLE {})",
1363 linear.negative_log_evidence,
1364 curved.negative_log_evidence
1365 );
1366 let choice =
1367 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1368 assert_eq!(
1369 choice.param,
1370 gam_solve::evidence::HybridAtomParam::Curved { latent_dim: 1 },
1371 "a full-circle response residual must keep the curved parameterization"
1372 );
1373 assert!(
1374 choice.curved_evidence_margin > 0.0,
1375 "curved must win a positive evidence margin over the linear secant"
1376 );
1377 }
1378
1379 /// The nested-dominance floor on common data (#1202): when the curved decoded
1380 /// image is a WORSE fit to the response residual than its own best straight
1381 /// projection, linear must win — the curved family cannot be charged extra
1382 /// parameters to fit the residual no better than its `Θ = 0` member. Here the
1383 /// data is a line but the curved image bends away from it, so curved_rss >
1384 /// linear_rss and the cheaper, better-fitting line is selected.
1385 #[test]
1386 fn linear_beats_curved_when_curve_misfits_residual() {
1387 let n = 50;
1388 let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
1389 let assign = Array1::<f64>::ones(n);
1390 // Data is a straight line; the curved decoded image is a parabola that
1391 // departs from it, so a straight line fits the data strictly better.
1392 let mut data = Array2::<f64>::zeros((n, 2));
1393 let mut decoded = Array2::<f64>::zeros((n, 2));
1394 for i in 0..n {
1395 let t = coords[i];
1396 data[[i, 0]] = t;
1397 data[[i, 1]] = 0.5 * t;
1398 decoded[[i, 0]] = t;
1399 decoded[[i, 1]] = t * t; // bends away from the linear data
1400 }
1401 let (linear, curved, _) = build_atom_candidates(
1402 coords.view(),
1403 assign.view(),
1404 decoded.view(),
1405 data.view(),
1406 // a real curved Θ above the floor so the dominance floor does not fire
1407 6,
1408 None,
1409 Some(1.0),
1410 )
1411 .expect("candidate pair");
1412 let choice =
1413 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1414 assert!(
1415 choice.param.is_linear(),
1416 "a curved image that fits the data worse than its own line must yield \
1417 to the linear special case on common-data evidence (#1202)"
1418 );
1419 }
1420
1421 /// The LINEAR candidate's Laplace logdet is the genuine weighted-design Gram
1422 /// determinant `p·(log w_sum + log s_tt)` with `w_sum = Σ a_k²`, `s_tt =
1423 /// Σ a_k²(t − t̄)²` — it INCLUDES the coordinate-spread term `log(s_tt)`
1424 /// (#1203). Verify both contributions are present by reading the logdet off a
1425 /// candidate whose linear residual is exactly zero (response residual = the
1426 /// fitted line), so `NLE_linear = ½·logdet`. Doubling the coordinate spread
1427 /// (at fixed assignment mass) scales `s_tt` by 4 → logdet += `p·log(4)`;
1428 /// doubling all assignment masses scales BOTH `w_sum` and `s_tt` by 4 (they
1429 /// are quadratic in `a_k`) → logdet += `2p·log(4)`.
1430 #[test]
1431 fn linear_logdet_includes_weighted_coordinate_spread() {
1432 let n = 40;
1433 let p = 2usize;
1434 // Read the logdet back off a candidate with zero linear residual: the
1435 // response residual is exactly `a_k·(line)`, so the WLS line recovers it
1436 // with RSS == 0 and `NLE_linear = ½·logdet`.
1437 let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
1438 // A straight image; the response residual is the same line scaled by
1439 // the per-row assignment mass `a_k`, so the prediction `a_k·(b₀+dt·b₁)`
1440 // matches it exactly and linear_rss == 0.
1441 let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };
1442 let mut decoded = Array2::<f64>::zeros((n, p));
1443 let mut data = Array2::<f64>::zeros((n, p));
1444 for i in 0..n {
1445 let l = line(coords[i]);
1446 decoded[[i, 0]] = l[0];
1447 decoded[[i, 1]] = l[1];
1448 data[[i, 0]] = assign[i] * l[0];
1449 data[[i, 1]] = assign[i] * l[1];
1450 }
1451 let (linear, _curved, _) = build_atom_candidates(
1452 coords.view(),
1453 assign.view(),
1454 decoded.view(),
1455 data.view(),
1456 10,
1457 None,
1458 Some(0.0),
1459 )
1460 .expect("straight residual yields a pair");
1461 2.0 * linear.negative_log_evidence // = logdet (linear_rss == 0)
1462 };
1463
1464 let base_coords =
1465 Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1466 let ones = Array1::<f64>::ones(n);
1467
1468 // Doubling the coordinate spread → s_tt ×4, w_sum fixed → logdet += p·log(4).
1469 let wide_coords = base_coords.mapv(|t| 2.0 * t);
1470 let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
1471 assert!(
1472 (d_spread - (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
1473 "linear logdet must move by p·log(4) when coordinate spread doubles \
1474 (got {d_spread}); the spread term log(s_tt) must be present"
1475 );
1476
1477 // Doubling all assignment masses → w_sum ×4 AND s_tt ×4 (quadratic in a_k)
1478 // → logdet += 2p·log(4).
1479 let twos = Array1::<f64>::from_elem(n, 2.0);
1480 let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
1481 assert!(
1482 (d_weight - 2.0 * (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
1483 "linear logdet must move by 2p·log(4) when all assignment masses double \
1484 (got {d_weight})"
1485 );
1486 }
1487
1488 /// #1223 — the curved arm's Laplace complexity is the REAL weighted-design
1489 /// Gram log-determinant `p·log|ΦᵀWΦ|_+`, not a parameter-count proxy. Build a
1490 /// curved design whose columns are the constant and the centered coordinate
1491 /// (a 2-column basis), so `ΦᵀWΦ = diag(w_sum, s_tt)` exactly matches the
1492 /// linear arm's data Gram, and assert `curved_design_gram_logdet` returns
1493 /// `p·(log w_sum + log s_tt)` — the same determinant the linear arm reports
1494 /// on the same design weight. A proxy `M·log(w_sum)` would instead omit the
1495 /// `log(s_tt)` spread term, so this pins the genuine determinant.
1496 #[test]
1497 fn curved_gram_logdet_is_real_weighted_design_determinant() {
1498 let n = 40;
1499 let p = 3usize;
1500 let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1501 let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));
1502
1503 // Mass-weighted coordinate mean and spread under wᵢ = a_k².
1504 let mut w_sum = 0.0;
1505 let mut t_bar = 0.0;
1506 for i in 0..n {
1507 let w = assign[i] * assign[i];
1508 w_sum += w;
1509 t_bar += w * coords[i];
1510 }
1511 t_bar /= w_sum;
1512 let mut s_tt = 0.0;
1513 for i in 0..n {
1514 let dt = coords[i] - t_bar;
1515 s_tt += assign[i] * assign[i] * dt * dt;
1516 }
1517
1518 // Curved design columns: [1, (t − t̄)]. Its weighted Gram is exactly
1519 // diag(w_sum, s_tt) (the cross term Σ w·(t−t̄) vanishes by construction),
1520 // so log|ΦᵀWΦ| = log(w_sum) + log(s_tt).
1521 let mut phi = Array2::<f64>::zeros((n, 2));
1522 for i in 0..n {
1523 phi[[i, 0]] = 1.0;
1524 phi[[i, 1]] = coords[i] - t_bar;
1525 }
1526 let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
1527 .expect("non-degenerate curved design has a determinant");
1528 let want = (p as f64) * (w_sum.ln() + s_tt.ln());
1529 assert!(
1530 (got - want).abs() < 1e-9,
1531 "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
1532 );
1533
1534 // A rank-deficient design (a duplicated column) drops the null direction:
1535 // its determinant equals that of the single retained constant column,
1536 // p·log(w_sum), NOT a 2-column proxy.
1537 let mut phi_dup = Array2::<f64>::zeros((n, 2));
1538 for i in 0..n {
1539 phi_dup[[i, 0]] = 1.0;
1540 phi_dup[[i, 1]] = 1.0;
1541 }
1542 let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
1543 .expect("rank-1 design still has a positive determinant");
1544 let want_dup = (p as f64) * (2.0 * w_sum).ln();
1545 assert!(
1546 (got_dup - want_dup).abs() < 1e-9,
1547 "rank-deficient curved Gram must report only its positive direction \
1548 (p·log(2·w_sum) = {want_dup}), got {got_dup}"
1549 );
1550 }
1551
1552 /// #1051 NESTED MIN — the curved arm re-fit on the residual match-or-beats the
1553 /// best straight line: `curved_refit_rss ≤ best_line_rss` up to solver tolerance
1554 /// on a basis whose span contains the straight lane (here `Φ = [1, t, t²]`). A
1555 /// genuinely curved (quadratic) signal is STRICTLY preferred by the curved arm,
1556 /// while an exactly-straight signal ties near zero (so the cheaper linear lane
1557 /// wins downstream). This is the property the realized-curve heuristic could not
1558 /// establish: comparing a possibly-collapsed realized curve against min-over-lines
1559 /// did NOT guarantee `curved ≤ linear`; re-fitting the curved decoder does.
1560 #[test]
1561 fn refit_curved_rss_matches_or_beats_best_line_nested() {
1562 let n = 40usize;
1563 let p = 2usize;
1564 let coords =
1565 Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1566 let assign = Array1::<f64>::ones(n);
1567 // Φ = [1, t, t²]: its column span contains the straight lane [1, t], so the
1568 // decoder-only curved refit is a proper superset of the line fit.
1569 let mut phi = Array2::<f64>::zeros((n, 3));
1570 for i in 0..n {
1571 phi[[i, 0]] = 1.0;
1572 phi[[i, 1]] = coords[i];
1573 phi[[i, 2]] = coords[i] * coords[i];
1574 }
1575 // Best mass-weighted line RSS on `y` (assign = 1): fit design [1, (t − t̄)].
1576 let t_bar = coords.iter().sum::<f64>() / n as f64;
1577 let best_line_rss = |y: &Array2<f64>| -> f64 {
1578 let mut design = Array2::<f64>::zeros((n, 2));
1579 for i in 0..n {
1580 design[[i, 0]] = 1.0;
1581 design[[i, 1]] = coords[i] - t_bar;
1582 }
1583 let b = solve_design_least_squares(design.view(), y.view()).unwrap();
1584 let pred = design.dot(&b);
1585 let mut rss = 0.0_f64;
1586 for i in 0..n {
1587 for j in 0..p {
1588 let r = y[[i, j]] - pred[[i, j]];
1589 rss += r * r;
1590 }
1591 }
1592 rss
1593 };
1594 // (a) exactly straight, (b) quadratic curve, (c) noisy line.
1595 let mut y_line = Array2::<f64>::zeros((n, p));
1596 let mut y_curve = Array2::<f64>::zeros((n, p));
1597 let mut y_noisy = Array2::<f64>::zeros((n, p));
1598 for i in 0..n {
1599 let t = coords[i];
1600 y_line[[i, 0]] = 0.4 + 0.6 * t;
1601 y_line[[i, 1]] = -0.2 + 1.1 * t;
1602 y_curve[[i, 0]] = t * t;
1603 y_curve[[i, 1]] = 0.5 - t * t;
1604 y_noisy[[i, 0]] = 0.3 + 0.7 * t + 0.05 * (3.0 * t).sin();
1605 y_noisy[[i, 1]] = 0.9 * t;
1606 }
1607 for y in [&y_line, &y_curve, &y_noisy] {
1608 let curved = curved_refit_rss(phi.view(), assign.view(), y.view())
1609 .expect("non-degenerate refit");
1610 let line = best_line_rss(y);
1611 assert!(
1612 curved <= line + 1e-9 * (1.0 + line),
1613 "nested dominance: refit-curved RSS {curved} must be ≤ best-line RSS {line}"
1614 );
1615 }
1616 // A genuinely curved (quadratic) signal is STRICTLY preferred by the curve.
1617 let curved_c = curved_refit_rss(phi.view(), assign.view(), y_curve.view()).unwrap();
1618 let line_c = best_line_rss(&y_curve);
1619 assert!(
1620 curved_c < 0.5 * line_c,
1621 "a quadratic signal must be far better fit by the curve ({curved_c}) than \
1622 by the best line ({line_c})"
1623 );
1624 // A straight signal ties near zero — collapses to the cheaper linear lane.
1625 let curved_l = curved_refit_rss(phi.view(), assign.view(), y_line.view()).unwrap();
1626 assert!(
1627 curved_l < 1e-18 && best_line_rss(&y_line) < 1e-18,
1628 "an exactly-straight signal ties the two arms near zero (curved {curved_l})"
1629 );
1630 }
1631
1632 /// #1026 item-2 GLOBAL CROSS-TERM guard: two collapses with CORRELATED
1633 /// (parallel) reconstruction-change errors whose per-atom losses each sit under
1634 /// tolerance — and whose SUM `Σ Δ_k` is also under tolerance (so the OLD
1635 /// aggregate `Σ max(Δ_k,0)` guard would ACCEPT) — but whose cross term
1636 /// `2⟨δ_1,δ_2⟩` pushes the TRUE global loss over tolerance. The global guard must
1637 /// roll back the least-justified collapse. An ORTHOGONAL control (disjoint
1638 /// support) with the same per-atom losses is accepted, isolating the cross term.
1639 #[test]
1640 fn global_guard_rejects_correlated_collapses_the_aggregate_would_accept() {
1641 let n = 4usize;
1642 let p = 1usize;
1643 // R0 = 0 ⇒ Δ_k = ‖δ_k‖² and ΔRSS(S) = ‖Σ_S δ_k‖² exactly.
1644 let r0 = Array2::<f64>::zeros((n, p));
1645 // Parallel δ_1 = δ_2 = 1 on all rows: ‖δ_k‖² = 4 each, Σ Δ_k = 8,
1646 // ΔRSS_global = ‖δ_1+δ_2‖² = 16.
1647 let mut d1 = Array2::<f64>::zeros((n, p));
1648 let mut d2 = Array2::<f64>::zeros((n, p));
1649 for i in 0..n {
1650 d1[[i, 0]] = 1.0;
1651 d2[[i, 0]] = 1.0;
1652 }
1653 // tol = 10: each per-atom loss 4 ≤ 10, Σ Δ_k = 8 ≤ 10 (aggregate accepts),
1654 // but global 16 > 10 (cross term rejects).
1655 let global_tol = 10.0_f64;
1656 let eligible = vec![(0usize, 0.1_f64, &d1), (1usize, 0.2_f64, &d2)];
1657 let forced: Vec<&Array2<f64>> = Vec::new();
1658 let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
1659 assert_eq!(
1660 reverted,
1661 vec![1usize],
1662 "the global cross-term guard must roll back the least-justified collapse \
1663 (largest margin = slot 1) that the summed-loss aggregate (8 ≤ 10) accepts"
1664 );
1665
1666 // ORTHOGONAL control: same per-atom losses (δ on disjoint rows) ⇒ cross term
1667 // 0 ⇒ ΔRSS_global = 8 ≤ 10 ⇒ no rollback. Isolates the cross term as the cause.
1668 let mut o1 = Array2::<f64>::zeros((n, p));
1669 let mut o2 = Array2::<f64>::zeros((n, p));
1670 o1[[0, 0]] = 2.0; // ‖o1‖² = 4
1671 o2[[2, 0]] = 2.0; // ‖o2‖² = 4, disjoint support
1672 let eligible_o = vec![(0usize, 0.1_f64, &o1), (1usize, 0.2_f64, &o2)];
1673 let reverted_o = global_collapse_rollback(r0.view(), &eligible_o, &forced, global_tol);
1674 assert!(
1675 reverted_o.is_empty(),
1676 "uncorrelated collapses (cross term 0, global loss 8 ≤ 10) must be accepted"
1677 );
1678 }
1679
1680 /// A degenerate (single-point-mass) coordinate has no slope direction and is
1681 /// refused rather than adjudicated on a fabricated deviance.
1682 #[test]
1683 fn degenerate_coordinate_is_refused() {
1684 let n = 5;
1685 let coords = Array1::<f64>::from_elem(n, 0.5); // no spread
1686 let assign = Array1::<f64>::ones(n);
1687 let decoded = Array2::<f64>::zeros((n, 2));
1688 let data = Array2::<f64>::zeros((n, 2));
1689 assert!(
1690 build_atom_candidates(
1691 coords.view(),
1692 assign.view(),
1693 decoded.view(),
1694 data.view(),
1695 6,
1696 None,
1697 Some(0.0)
1698 )
1699 .is_none(),
1700 "a degenerate coordinate span must be refused"
1701 );
1702 }
1703}