gam_sae/hybrid_split.rs
1//! #1026 — load-bearing curved-vs-linear hybrid split for the fitted SAE
2//! dictionary.
3//!
4//! The selection machinery ([`gam_solve::evidence::select_hybrid_atom`],
5//! [`gam_solve::evidence::select_hybrid_split`]) and the per-atom
6//! integration helper
7//! ([`crate::assignment::select_hybrid_atom_parameterization`]) are
8//! correct and tested, but until now were called nowhere in the fitter: the
9//! post-fit pass only *logged* each `d = 1` atom's fitted turning `Θ`. This
10//! module makes the split LOAD-BEARING by building, per fitted `d = 1` atom, the
11//! two already-realized candidates and adjudicating them by the common
12//! evidence criterion.
13//!
14//! ## The common-evidence comparison on the data (#1202)
15//!
16//! Both candidates are scored against the SAME data: the portion of the
17//! response matrix the atom is responsible for reconstructing, namely its
18//! **leave-this-atom-out residual**
19//!
//! ```text
//! y_resp[i] = target[i] − ( Σ_j a[i,j]·γ_j(t_{ij}) − a[i,k]·γ_k(t_{ik}) )
//!           = target[i] − without_k[i],
//! ```
22//!
23//! the response with every OTHER atom's contribution subtracted. Over the rows
24//! assigned to atom `k` (assignment mass `a[i,k] = a_k`), the two candidates
25//! predict that residual:
26//!
27//! * the CURVED candidate is scored at its CONSTRAINED MINIMUM over the atom's
28//! decoder coefficients — re-fit on `y_resp` as
29//! `min_B ‖y_resp − a_k·(Φ(t)·B)‖²` ([`curved_refit_rss`]) — so its data-fit
30//! deviance is the smallest weighted RSS the curved family can attain on this
31//! residual at the realized codes, not the possibly-collapsed realized curve.
32//! * the LINEAR candidate predicts `a_k · (b₀ + (t − t̄)·b₁)`, the best
33//! weighted least-squares straight line fit to `y_resp` (design column
34//! scaled by the same assignment mass `a_k`), so its data-fit deviance is the
35//! weighted RSS of the best line against the SAME residual.
36//!
37//! ## A genuine NESTED min-vs-min comparison (#1051)
38//!
39//! Both arms are now at their constrained minimum on the SAME leave-one-atom-out
40//! residual: the linear arm is the closed-form min-over-lines, and the curved arm
41//! is the min-over-decoders refit ([`curved_refit_rss`]). The linear special case
42//! `a_k·(b₀ + (t−t̄)·b₁)` is a MEMBER of the curved family whenever the straight
//! lane `[1, (t−t̄)]` lies in the column span of the curved basis `Φ` — exactly for
//! the interval / line-segment charts (whose basis carries the constant and linear
//! terms), and only approximately, to the extent the basis can express that lane,
//! for the periodic charts. So after the
46//! refit `curved_rss ≤ linear_rss` up to the least-squares solver tolerance: the
47//! curved family CANNOT do worse than its own `Θ = 0` sub-model. That is the
48//! nested-dominance property restored here — "curved match-or-beats linear" is now
49//! a floor established by re-optimizing the curved arm, not merely asserted.
50//!
51//! The broken (#1051) euclidean / multi-atom OUTER continuation is deliberately
52//! NOT re-entered: the direct per-atom `d = 1` decoder-only refit at the realized
53//! codes is sufficient to score the curved arm at its constrained minimum. When
//! `Φ` is unavailable (no evaluator) or the refit solve is degenerate, the arm
//! falls back to the already-realized curve's RSS — an honest degradation rather
//! than a fabricated fit. The argmin then trades the curved arm's (minimized)
57//! data fit against its larger parameter / complexity price, so a genuinely
58//! curved signal is preferred while a straight-line signal ties and collapses to
59//! the cheaper linear lane.
60//!
61//! This replaces an earlier diagnostic in which both candidates targeted the
62//! atom's already-fitted decoded image `γ_k(t)` (giving the curved arm a free
63//! zero residual against itself, #1202) and its successor which scored the
64//! already-REALIZED curved contribution (a post-fit heuristic that did not
65//! establish nested dominance); the curved arm is now re-fit to its minimum.
66
67use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
68
69use gam_linalg::faer_ndarray::FaerEigh;
70use gam_solve::evidence::{
71 HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
72};
73use gam_terms::latent::LatentManifold;
74use crate::chart_canonicalization::d1_atom_fitted_turning;
75use crate::manifold::{SaeManifoldAtom, solve_design_least_squares};
76
/// Reduced Laplace negative-log-evidence of a per-atom Gaussian reconstruction
/// sub-model: `residual_objective + ½·log|H|`.
///
/// This is the shape [`gam_solve::evidence::laplace_evidence`] collapses to for the
/// curved-vs-linear comparison. It carries no smoothing-penalty log-determinant and
/// treats the design as full rank (no null space). It is computed inline instead of
/// through `EvidenceLogDetSource`: both candidates already supply their Hessian
/// log-determinant as a closed-form scalar of their shared design, so there is no
/// factor cache or HVP callback to assemble.
///
/// # Scale caveat
///
/// The criterion is FIXED-DISPERSION (`σ² = 1` after preprocessing), not a profiled
/// REML marginal likelihood. `residual_objective` is the bare `½ RSS`: there is no
/// `RSS/(2σ²)` rescaling and no profiled-σ² log-determinant correction.
///
/// Rescaling the response `y → s·y` multiplies the data term by `s²` but leaves
/// `½·log|H|` unchanged, so the curved-vs-linear trade-off is NOT scale-invariant.
/// Callers must supply targets on a consistent, roughly unit-scale footing.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    // Halving is exact in IEEE-754, so `/ 2.0` agrees bit-for-bit with `* 0.5`.
    let complexity = log_det_h / 2.0;
    residual_objective + complexity
}
97
98/// Rank-aware `log|ΦᵀWΦ|_+` of the curved atom's weighted design Gram over its
99/// `M` decoder basis columns, with per-row weight `wᵢ = a_k²` (the same
100/// assignment-mass design weight the linear arm uses), summed over the
101/// eigenvalues above a relative spectral floor (#1223). This is the genuine
102/// weighted-design determinant the linear arm already reports — `log|XᵀWX|` —
103/// assembled for the curved basis so the two arms' Laplace complexity prices are
104/// computed on the SAME footing instead of pricing the curved arm with a
105/// parameter-count proxy `M·log(Σw)`.
106///
107/// Mirrors the linear arm exactly in what it does NOT include: no smoothing-
108/// penalty `λS` normalizer (the linear arm's Gram is the bare data Gram
109/// `diag(w_sum, s_tt)` too), so the comparison stays symmetric. The Gram is the
110/// design's outer Gram over its basis columns; it is identical across the `p`
111/// output channels (every channel shares the design `Φ`), so the per-channel
112/// `log|G|_+` is multiplied by `p` — matching the linear arm's `p·(…)` form.
113///
114/// `phi` is the curved design `Φ(t)` evaluated on the atom's assigned rows
115/// (`n × M`); `assign` is the per-row assignment mass `a_k` (NOT squared).
116/// Returns `None` when `Φ` is missing rows, the Gram is non-finite, or it has no
117/// positive eigenvalues (a fully rank-deficient design carries no determinant);
118/// the caller then falls back to the parameter-count proxy rather than fabricate
119/// a determinant.
120fn curved_design_gram_logdet(
121 phi: ArrayView2<'_, f64>,
122 assign: ArrayView1<'_, f64>,
123 p: usize,
124) -> Option<f64> {
125 let n = phi.nrows();
126 let m = phi.ncols();
127 if m == 0 || assign.len() != n || n == 0 {
128 return None;
129 }
130 // G = Φᵀ diag(a²) Φ (M×M, symmetric PSD).
131 let mut gram = Array2::<f64>::zeros((m, m));
132 for i in 0..n {
133 let w = assign[i] * assign[i];
134 if !(w.is_finite() && w >= 0.0) {
135 return None;
136 }
137 if w == 0.0 {
138 continue;
139 }
140 let row = phi.row(i);
141 for a in 0..m {
142 let wa = w * row[a];
143 for b in a..m {
144 gram[[a, b]] += wa * row[b];
145 }
146 }
147 }
148 // Symmetrize the lower triangle (we only filled the upper).
149 for a in 0..m {
150 for b in 0..a {
151 gram[[a, b]] = gram[[b, a]];
152 }
153 }
154 if gram.iter().any(|v| !v.is_finite()) {
155 return None;
156 }
157 let (vals, _vecs) = gram.eigh(faer::Side::Lower).ok()?;
158 // Rank-aware log-determinant: sum log of eigenvalues above a relative floor
159 // tied to the largest eigenvalue, dropping the numerically-null directions
160 // (the curved design's null space, analogous to the linear arm's full-rank
161 // 2-D Gram). A design with no positive eigenvalue carries no determinant.
162 let lambda_max = vals.iter().cloned().fold(0.0_f64, f64::max);
163 if !(lambda_max > 0.0 && lambda_max.is_finite()) {
164 return None;
165 }
166 let floor = lambda_max * 1e-12;
167 let mut log_det = 0.0_f64;
168 let mut rank = 0usize;
169 for &lambda in vals.iter() {
170 if lambda > floor {
171 log_det += lambda.ln();
172 rank += 1;
173 }
174 }
175 if rank == 0 || !log_det.is_finite() {
176 return None;
177 }
178 // The design Gram is shared across the p output channels.
179 Some((p as f64) * log_det)
180}
181
182/// The fitted straight sub-model `γ̃(t) = b₀ + (t − t̄)·b₁` of one `d = 1` atom:
183/// the exact assignment-mass-weighted least-squares line fit to the atom's
184/// leave-this-atom-out RESPONSE residual `y_resp` over its assigned rows (the
185/// curved family's nested `Θ = 0` sub-model on common data, #1202). Carried on a
186/// verdict that selects LINEAR so the collapsed reconstruction can replace the
187/// curved decoded row with this straight image at any coordinate WITHOUT
188/// re-entering the (broken, #1051) outer fit — the coefficients are already
189/// realized inside the adjudication.
190#[derive(Clone, Debug)]
191pub struct AtomLinearImage {
192 /// The atom's slot index in the dictionary (so the collapsed assembly knows
193 /// which atom's decoded row to substitute).
194 pub atom_idx: usize,
195 /// The mass-weighted coordinate mean `t̄` the line is centered on.
196 pub t_bar: f64,
197 /// Per-output-channel centered intercept `b₀ = γ̄` at `t̄` (length `p`).
198 pub b0: Array1<f64>,
199 /// Per-output-channel slope `b₁` (length `p`).
200 pub b1: Array1<f64>,
201 /// #1026 collapse-rescue per-row coordinates. `None` for the ordinary path:
202 /// the line is evaluated at the atom's OWN realized coordinate `t`. `Some(u)`
203 /// only when the atom's circle codes had collapsed to a single point
204 /// (`s_tt ≈ 0`) so its own coordinate carries no spread — then the line is fit
205 /// against, and reconstruct evaluates it at, these FRESH per-row codes `uᵢ`
206 /// (the projection of the leave-this-atom-out residual onto its top
207 /// mass-weighted output direction). This is what lets a circle atom that the
208 /// joint fit drove into the degenerate "chord-through-the-arc" fixed point
209 /// still reconstruct its residual's best linear direction — recovering the
210 /// linear-tail reach the hybrid-split was designed to deliver — instead of a
211 /// constant (its collapsed curve), which is the real-OLMo rank-1 co-collapse
212 /// (held-out EV ≈ 0.13 vs the linear ceiling ≈ 0.74). Length `n` (one per
213 /// reconstructed row); unassigned rows are gated to zero by `a_k` anyway.
214 ///
215 /// TRAIN-ONLY CAVEAT (#1777): these are the TRAIN rows' codes. They are only
216 /// meaningful for the exact rows the split was fit on; a held-out row has no
217 /// entry here and used to fall back to the atom's own (collapsed) coordinate
218 /// `own_t` — a DIFFERENT, degraded model out of sample. Prefer [`Self::v`]:
219 /// projecting a held-out row's leave-this-atom-out residual onto `v` recovers
220 /// that row's coordinate by the SAME math the train codes were built with, so
221 /// train and OOS use one model. `row_codes` is retained for back-compat and as
222 /// the exact cached train projection.
223 pub row_codes: Option<Array1<f64>>,
224 /// #1777 collapse-rescue projection DIRECTION `v` (length `p`, unit norm), the
225 /// top mass-weighted output direction of the atom's leave-this-atom-out
226 /// residual. `Some` exactly when this is a collapse-rescued image (paired with
227 /// `row_codes`); `None` for the ordinary straight-image path (which decodes at
228 /// the atom's own coordinate). This is the SERIALIZABLE quantity the FFI must
229 /// persist so an OOS term can recompute any row's coordinate as
230 /// `uᵢ = ⟨y_i − Σ_{j≠k} f_j(x_i), v⟩` — identical to the train code
231 /// `row_codes[i]` on a train row, and the correct held-out coordinate on an OOS
232 /// row (see [`Self::coordinate_from_residual`]). Length must equal `b0`/`b1`.
233 pub v: Option<Array1<f64>>,
234}
235
236impl AtomLinearImage {
237 /// Evaluate the straight sub-model `b₀ + (t − t̄)·b₁` into `out` (length `p`).
238 pub fn fill_row(&self, t: f64, out: &mut [f64]) {
239 let dt = t - self.t_bar;
240 for (j, slot) in out.iter_mut().enumerate() {
241 *slot = self.b0[j] + dt * self.b1[j];
242 }
243 }
244
245 /// The coordinate at which row `row` should evaluate this image: the
246 /// collapse-rescue fresh code `uᵢ` when present (#1026), else the atom's own
247 /// realized coordinate `own_t` passed by the caller.
248 ///
249 /// TRAIN-ONLY: `row_codes` is indexed by TRAIN row, so this is correct only
250 /// for the rows the split was fit on. Out of sample use
251 /// [`Self::coordinate_from_residual`] (target-aware, model-identical to train).
252 pub fn coordinate_for_row(&self, row: usize, own_t: f64) -> f64 {
253 match &self.row_codes {
254 Some(u) if row < u.len() => u[row],
255 _ => own_t,
256 }
257 }
258
259 /// #1777 — the collapse-rescue coordinate of a row from ITS OWN
260 /// leave-this-atom-out residual `resid = y_i − Σ_{j≠k} f_j(x_i)` (length `p`),
261 /// namely `uᵢ = ⟨resid, v⟩`. `Some(uᵢ)` exactly when this is a collapse-rescued
262 /// image (`v` is set); `None` for the ordinary straight-image path (which has
263 /// no projection direction and decodes at the atom's own coordinate).
264 ///
265 /// This is the SAME math [`build_collapse_rescue_linear_image`] used to build
266 /// the train `row_codes` (`row_codes[i] = ⟨target_resid[i], v⟩`), so on a TRAIN
267 /// row it reproduces `row_codes[i]` exactly, and on a HELD-OUT row it yields
268 /// that row's correct coordinate — train and OOS share one model. Returns
269 /// `None` if `resid`'s length disagrees with `v`.
270 pub fn coordinate_from_residual(&self, resid: &[f64]) -> Option<f64> {
271 let v = self.v.as_ref()?;
272 if resid.len() != v.len() {
273 return None;
274 }
275 Some(v.iter().zip(resid).map(|(&vj, &rj)| vj * rj).sum())
276 }
277
278 /// Whether this image is a #1777 collapse-rescued image (carries a projection
279 /// direction `v` and per-row train codes) rather than an ordinary straight
280 /// image evaluated at the atom's own coordinate.
281 pub fn is_collapse_rescued(&self) -> bool {
282 self.v.is_some()
283 }
284}
285
286/// One fitted `d = 1` atom's hybrid-split verdict, surfaced in the model output.
287#[derive(Clone, Debug)]
288pub struct AtomHybridVerdict {
289 /// The atom's name (slot identity in the dictionary).
290 pub atom_name: String,
291 /// The evidence-selected parameterization choice for this slot.
292 pub choice: HybridAtomChoice,
293 /// `true` iff the slot kept the CURVED parameterization (the fitted atom);
294 /// `false` iff it yielded to the LINEAR special case (the straight tail).
295 pub kept_curved: bool,
296 /// The atom's fitted turning `Θ = ∫|κ| ds` (radians), the novel geometric
297 /// quantity #1026 pairs against reconstruction EV: `Θ ≈ 0` is a linear-tail
298 /// direction wearing a curved basis, `Θ ≈ 2π` is a full curved loop. `None`
299 /// iff the evaluator has no analytic second jet or the curve is degenerate.
300 /// Captured here (not just logged) so the EV-vs-Θ frontier is queryable
301 /// structured data on the persisted report rather than a transient log line.
302 pub fitted_turning: Option<f64>,
303 /// The atom's training leave-one-atom-out explained-variance contribution
304 /// `ΔEV_k = EV(full) − EV(full∖{k})` — how much reconstruction EV this single
305 /// atom earns. Paired with [`Self::fitted_turning`] this is the `(Θ, ΔEV)`
306 /// point the #1026 frontier reports: a `Θ ≈ 0` atom with large `ΔEV` is a
307 /// genuine linear-tail direction; a high-`Θ` atom with large `ΔEV` is a
308 /// genuine curved family. `None` iff the caller did not supply LOAO EV.
309 pub train_loao_delta_ev: Option<f64>,
310 /// The fitted straight sub-model for this slot, present iff the verdict
311 /// selected LINEAR (`kept_curved == false`). The collapsed reconstruction
312 /// substitutes this for the atom's curved decoded image, making the verdict
313 /// load-bearing on the reconstruction rather than a passive diagnostic.
314 pub linear_image: Option<AtomLinearImage>,
315}
316
317/// The whole dictionary's hybrid-split report: one verdict per eligible `d = 1`
318/// atom, plus the dictionary-level aggregates the EV-vs-Θ frontier reports
319/// against.
320#[derive(Clone, Debug)]
321pub struct SaeHybridSplitReport {
322 /// One adjudicated verdict per eligible `d = 1` atom, in slot order. Atoms
323 /// that are not eligible (wrong dim, no evaluator, mid-homotopy) are absent
324 /// — they carry no curved/linear adjudication.
325 pub verdicts: Vec<AtomHybridVerdict>,
326 /// The dictionary-level rolled-up selection (summed NLE, total parameters,
327 /// curved/linear counts) over the eligible atoms.
328 pub selection: HybridSplitSelection,
329}
330
/// Minimum number of assigned rows a `d = 1` atom needs before its linear candidate
/// is defined.
///
/// With fewer rows, a two-parameter (intercept + slope) straight line leaves no room
/// for a residual estimate, so the linear deviance is undefined. Such atoms are
/// dropped from the report rather than adjudicated on a made-up deviance.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 3;
336
/// #1610/#1026 — EV-PRESERVATION gate tolerance.
///
/// A `d = 1` slot may collapse to its linear tail only if doing so costs at most
/// this fraction of the target's total (centered) variance in full-reconstruction
/// explained variance.
///
/// # Why a gate is needed
///
/// The evidence argmin trades data fit against the curved arm's parameter price in
/// `NLE` units. On a small or low-amplitude fixture that trade can favour the
/// cheaper line even when the curve carries real reconstruction signal, and the
/// collapse then DROPS EV (the observed 1.0 → 0.748 over-collapse). This gate guards
/// the quantity that actually regressed.
///
/// # What it measures
///
/// Collapsing atom `k` replaces its REALIZED curved contribution with its fitted
/// line. That raises the full reconstruction SSR by exactly
/// `linear_rss − curved_rss`, with `curved_rss` here being the realized curve's RSS
/// on the leave-this-atom-out residual. The total target variance `SST_full` is
/// fixed, so the per-atom EV impact is exactly `(linear_rss − curved_rss)/SST_full`.
/// A collapse is vetoed when that loss exceeds this fraction.
///
/// Because EV is dimensionless and lies in `[0, 1]`, an absolute EV-loss tolerance
/// is itself scale-invariant. It does not reintroduce the scale incommensurability
/// that sank the evidence-reformulation attempt.
///
/// # Choice of value
///
/// A genuinely straight atom (whose curved fit already IS its line) loses `≈ 0` EV
/// and still collapses losslessly; only a curve doing real reconstruction work
/// (loss `≫ 1e-3`) is kept. `1e-3` = 0.1% of total variance: comfortably above f64
/// round-off on an exact-line collapse, yet far below any material EV loss. It
/// separates the lossless and load-bearing regimes without tuning.
///
/// # Per-atom scope
///
/// Applied slot by slot, this bounds only ONE atom's individual EV loss. It does
/// NOT on its own bound the dictionary-level loss when several atoms collapse
/// together. The true global RSS change is
/// `Σ_k Δ_k + 2 Σ_{j<k} <Δrecon_j, Δrecon_k>`, and the per-atom gate sees neither
/// the accumulation of many individually tolerable `Δ_k` nor the cross terms
/// between co-active atoms.
///
/// [`build_hybrid_split_report`] layers an aggregate global guard on top, reading
/// this same fraction as a bound on `Σ_k max(Δ_k, 0)`; see there for what that
/// guard does and does not prove.
const SAE_HYBRID_COLLAPSE_MAX_EV_LOSS: f64 = 1e-3;
368
369/// The full-reconstruction SSR INCREASE from collapsing one `d = 1` atom to its
370/// fitted straight sub-model: `linear_rss − curved_rss`, where both arms are
371/// scored against the atom's leave-this-atom-out response residual `y_resp`
372/// (`target_resid`) exactly as [`build_atom_candidates`] scores them. Because the
373/// full reconstruction differs from the collapsed one ONLY in this atom's
374/// contribution (`a_k·γ_k` → `a_k·line`) on its assigned rows, this scalar is the
375/// exact amount the full reconstruction's SSR rises when the slot is collapsed;
376/// dividing by the fixed total target variance gives the full-EV loss the
377/// EV-preservation gate keys on. Positive ⇒ the curve out-fits its straight
378/// projection (collapsing hurts); `≤ 0` ⇒ the line is at least as good (collapse
379/// is lossless or improving, never gated). Mirrors the `curved_rss` / `linear_rss`
380/// accumulation in [`build_atom_candidates`] bit-for-bit so the gate and the
381/// evidence comparison see the same residuals.
382fn collapse_ssr_increase(
383 coords: ArrayView1<'_, f64>,
384 assign: ArrayView1<'_, f64>,
385 decoded: ArrayView2<'_, f64>,
386 target_resid: ArrayView2<'_, f64>,
387 t_bar: f64,
388 b0: &Array1<f64>,
389 b1: &Array1<f64>,
390) -> f64 {
391 let n = assign.len();
392 let p = target_resid.ncols();
393 let mut curved_rss = 0.0_f64;
394 let mut linear_rss = 0.0_f64;
395 for i in 0..n {
396 let a = assign[i];
397 let dt = coords[i] - t_bar;
398 for j in 0..p {
399 let y = target_resid[[i, j]];
400 let r_curved = y - a * decoded[[i, j]];
401 curved_rss += r_curved * r_curved;
402 let r_linear = y - a * (b0[j] + dt * b1[j]);
403 linear_rss += r_linear * r_linear;
404 }
405 }
406 linear_rss - curved_rss
407}
408
409/// #1051/#1026 NESTED MIN — re-fit the curved atom's decoder on the SAME
410/// leave-this-atom-out residual `y_resp` and return its **minimum** weighted
411/// reconstruction RSS at the atom's realized codes.
412///
413/// This is the genuine constrained minimum of the curved family over its free
414/// decoder coefficients `B`:
415///
416/// ```text
417/// min_B Σᵢ ‖ y_resp[i] − a_k·( Φ(t_i)·B ) ‖² (design = diag(a)·Φ, rhs = y_resp).
418/// ```
419///
420/// It restores the nested-dominance floor `curved_rss ≤ linear_rss`. The linear
421/// special case `a_k·(b₀ + (t−t̄)·b₁)` is itself a member of this family whenever
422/// the straight lane `[1, (t−t̄)]` lies in the column span of the curved basis
423/// `Φ` — exactly (interval / line-segment charts, whose basis carries the
424/// constant and linear terms) or to the basis's expressiveness (the periodic
425/// charts). So `min_B` over `Φ` cannot do WORSE than the best straight line: the
426/// returned RSS is `≤` the linear arm's RSS up to the least-squares solver
427/// tolerance. This is the direct per-atom `d = 1` refit the module owes; the
428/// broken (#1051) euclidean / multi-atom OUTER continuation is deliberately NOT
429/// re-entered — the decoder-only refit at the realized codes is sufficient to
430/// score the curved arm at its constrained minimum.
431///
432/// `phi` is the curved design `Φ(t)` on the atom's rows (`n × M`); `assign` the
433/// per-row mass `a_k` (NOT squared — the design weight is the mass itself, so the
434/// residual is on the SAME footing the linear arm and the joint loss use).
435/// Returns `None` when the solve is degenerate or non-finite; the caller then
436/// falls back to the already-realized curve's RSS rather than fabricate a value.
437fn curved_refit_rss(
438 phi: ArrayView2<'_, f64>,
439 assign: ArrayView1<'_, f64>,
440 target_resid: ArrayView2<'_, f64>,
441) -> Option<f64> {
442 let n = phi.nrows();
443 let m = phi.ncols();
444 let p = target_resid.ncols();
445 if m == 0 || n == 0 || assign.len() != n || target_resid.nrows() != n || p == 0 {
446 return None;
447 }
448 // Weighted design `diag(a)·Φ` (n×M). The refit minimizes ‖diag(a)·Φ·B − y_resp‖²,
449 // so the fitted prediction is diag(a)·Φ·B — the curved contribution `a_k·γ_k`
450 // at its best decoder `B` on this residual.
451 let mut design = Array2::<f64>::zeros((n, m));
452 for i in 0..n {
453 let a = assign[i];
454 if !a.is_finite() {
455 return None;
456 }
457 for c in 0..m {
458 design[[i, c]] = a * phi[[i, c]];
459 }
460 }
461 let b = solve_design_least_squares(design.view(), target_resid).ok()?;
462 if b.iter().any(|v| !v.is_finite()) {
463 return None;
464 }
465 let pred = design.dot(&b);
466 let mut rss = 0.0_f64;
467 for i in 0..n {
468 for j in 0..p {
469 let r = target_resid[[i, j]] - pred[[i, j]];
470 rss += r * r;
471 }
472 }
473 rss.is_finite().then_some(rss)
474}
475
476/// Build the curved + linear candidates for ONE fitted `d = 1` atom and return
477/// them as `(linear, curved, (t̄, b₀, b₁))`, or `None` if the atom cannot present
478/// an honest pair (too few rows, degenerate coordinate span, or non-finite
479/// numbers). Both candidates are scored against the SAME data — the atom's
480/// leave-this-atom-out response residual `y_resp` — at their CONSTRAINED MINIMUM:
481/// the linear arm is the freshly-fit min-over-lines and the curved arm is the
482/// min-over-decoders refit ([`curved_refit_rss`]), so the comparison is a genuine
483/// nested min-vs-min one and `curved_rss ≤ linear_rss` holds up to solver
484/// tolerance for a basis whose span contains the straight lane (#1051).
485///
486/// Inputs over the atom's assigned rows:
487/// * `coords` — the fitted on-atom coordinate `t`.
488/// * `assign` — the per-row assignment mass `a_k` (NOT squared; this routine
489/// squares it where the design weight `a_k²` is needed).
490/// * `decoded` — the atom's fitted decoded image `γ_k(t) = Φ(t) B_k` (`p` cols),
491/// whose mass-scaled value `a_k·γ_k` is the curved candidate's PREDICTION.
492/// * `target_resid` — the atom's leave-this-atom-out response residual `y_resp`
493/// (`p` cols): the response with every OTHER atom's contribution removed.
494/// This is the data both candidates fit.
495///
496/// The curved candidate's data-fit deviance is `½·min_B Σ ‖y_resp − a_k·(Φ·B)‖²`
497/// (its constrained minimum over the decoder; the mass lives in the prediction);
498/// the linear candidate fits the best mass-weighted straight line to `y_resp` and
499/// pays `½ Σ ‖y_resp − a_k·(b₀ + (t − t̄)·b₁)‖²`. Because the linear prediction is
500/// itself a curved-family member (the straight lane lies in `span(Φ)` for the
501/// eligible charts), the curved arm's minimized RSS is `≤` the linear arm's up to
502/// solver tolerance, so the argmin is a genuine nested min-vs-min dominance
503/// comparison, not a post-fit compression heuristic.
504fn build_atom_candidates(
505 coords: ArrayView1<'_, f64>,
506 assign: ArrayView1<'_, f64>,
507 decoded: ArrayView2<'_, f64>,
508 target_resid: ArrayView2<'_, f64>,
509 curved_num_params: usize,
510 curved_phi: Option<ArrayView2<'_, f64>>,
511 fitted_turning: Option<f64>,
512) -> Option<(
513 HybridAtomCandidate,
514 HybridAtomCandidate,
515 (f64, Array1<f64>, Array1<f64>),
516)> {
517 let n = coords.len();
518 let p = decoded.ncols();
519 if n < MIN_ROWS_FOR_LINEAR_FIT
520 || decoded.nrows() != n
521 || assign.len() != n
522 || target_resid.nrows() != n
523 || target_resid.ncols() != p
524 || p == 0
525 {
526 return None;
527 }
528
529 // The LINEAR candidate fits `a_k·(b₀ + (t − t̄)·b₁)` to the residual `y_resp`,
530 // so the natural design column is `a_k·[1, (t − t̄)]` and the per-row Gram
531 // weight is `wᵢ = a_k²`. We accumulate the mass-weighted coordinate mean `t̄`
532 // and spread `s_tt` under that weight; a row that barely belongs to the atom
533 // (`a_k ≈ 0`) contributes ≈ nothing, exactly as in the joint loss.
534 let mut w_sum = 0.0_f64;
535 let mut t_bar = 0.0_f64;
536 for i in 0..n {
537 let a = assign[i];
538 if !(a.is_finite() && a >= 0.0) {
539 return None;
540 }
541 let w = a * a;
542 w_sum += w;
543 t_bar += w * coords[i];
544 }
545 if !(w_sum > 0.0) {
546 return None;
547 }
548 t_bar /= w_sum;
549
550 // Weighted Σ wᵢ·(t − t̄)² with `wᵢ = a_k²` — the coordinate spread under the
551 // line's design weight. A degenerate (single-point mass) coordinate has no
552 // slope direction; refuse rather than divide by ~0.
553 let mut s_tt = 0.0_f64;
554 for i in 0..n {
555 let dt = coords[i] - t_bar;
556 s_tt += assign[i] * assign[i] * dt * dt;
557 }
558 if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
559 return None;
560 }
561
562 // Per-output-channel mass-weighted least squares for the line fit to the
563 // RESIDUAL `y_resp`. Minimizing `Σᵢ ‖y_resp[i] − a_k·(b₀ + (t − t̄)·b₁)‖²` in
564 // the centered basis has the diagonal normal equations
565 // b₀[j] = (Σ a_k·y_resp[i,j]) / w_sum, (recall the design intercept is a_k)
566 // b₁[j] = (Σ a_k·(t − t̄)·y_resp[i,j]) / s_tt.
567 let mut b0 = Array1::<f64>::zeros(p);
568 let mut b1 = Array1::<f64>::zeros(p);
569 for j in 0..p {
570 let mut s_1y = 0.0_f64;
571 let mut s_ty = 0.0_f64;
572 for i in 0..n {
573 let a = assign[i];
574 let dt = coords[i] - t_bar;
575 let y = target_resid[[i, j]];
576 s_1y += a * y;
577 s_ty += a * dt * y;
578 }
579 b0[j] = s_1y / w_sum;
580 b1[j] = s_ty / s_tt;
581 }
582
583 // Data-fit residual sums of squares of BOTH candidates against `y_resp`, the
584 // common data. The linear candidate predicts the best line
585 // `a_k·(b₀ + (t − t̄)·b₁)`; the curved candidate is scored at its CONSTRAINED
586 // MINIMUM over the decoder coefficients (nested min-vs-min, #1051), re-fit on
587 // this same residual — not the possibly-collapsed already-realized curve. We
588 // also carry the realized curve's RSS as the honest fallback when the basis
589 // `Φ` is unavailable or its refit solve is degenerate.
590 let mut linear_rss = 0.0_f64;
591 let mut realized_curved_rss = 0.0_f64;
592 for i in 0..n {
593 let a = assign[i];
594 let dt = coords[i] - t_bar;
595 for j in 0..p {
596 let y = target_resid[[i, j]];
597 let r_linear = y - a * (b0[j] + dt * b1[j]);
598 linear_rss += r_linear * r_linear;
599 let r_curved = y - a * decoded[[i, j]];
600 realized_curved_rss += r_curved * r_curved;
601 }
602 }
603 // #1051 NESTED MIN — the curved arm's data fit is `min_B ‖y_resp − diag(a)Φ B‖²`,
604 // its constrained minimum over the decoder. Because the linear lane is a member
605 // of the curved family (the straight columns lie in `span(Φ)` for the eligible
606 // charts), this min-curved RSS is `≤ linear_rss` up to solver tolerance — the
607 // "curved match-or-beats linear" floor. Fall back to the realized curve's RSS
608 // only when Φ is absent or the refit is degenerate.
609 let curved_rss = match curved_phi {
610 Some(phi) if phi.nrows() == n => {
611 curved_refit_rss(phi, assign, target_resid).unwrap_or(realized_curved_rss)
612 }
613 _ => realized_curved_rss,
614 };
615
    // Gaussian-reconstruction deviance: the residual objective `½ RSS` the
    // Laplace normalizer is added to. The curved arm pays `½·curved_rss` — its
    // min-over-decoders refit RSS on `y_resp`, or the realized curve's RSS only
    // when `Φ` is absent or the refit is degenerate — plus its larger Laplace
    // complexity price; the linear arm pays `½·linear_rss` plus its own
    // (`2·p`-parameter) price. With the refit, `curved_rss ≤ linear_rss` up to
    // solver tolerance whenever the straight lane lies in `span(Φ)`. On the
    // realized-curve fallback (or a basis that does not span the lane) that floor
    // is NOT guaranteed, and the cheaper line can win on data fit alone. The
    // argmin then trades the curved arm's data-fit gain against its complexity
    // price.
625 let curved_residual_objective = 0.5 * curved_rss;
626 let linear_residual_objective = 0.5 * linear_rss;
627
628 // Linear candidate parameter price: intercept + slope per output channel.
629 let linear_num_params = 2 * p;
630
631 // Laplace logdet of the (weighted) design Gram for the LINEAR candidate.
632 //
633 // For the centered weighted line fit `a_k·(b₀ + (t − t̄)·b₁)`, the per-output-
634 // channel design column is `a_k·[1, (t − t̄)]`, whose Gram is DIAGONAL in the
635 // centered basis: `diag(Σ a_k², Σ a_k²(t − t̄)²) = diag(w_sum, s_tt)`. Its log
636 // determinant is `log(w_sum) + log(s_tt)` PER output channel, i.e.
637 //
638 // log|H_linear| = p · ( log(w_sum) + log(s_tt) ).
639 //
640 // The `log(s_tt)` term is the slope direction's information: a line through a
641 // wide, heavily-massed coordinate spread is better-determined than one through
642 // a tiny spread, and the Laplace evidence must reflect that (#1203).
643 //
644 // The curved arm's Laplace determinant is now the genuine weighted-design
645 // Gram log-determinant `p · log|ΦᵀWΦ|_+` (#1223): the SAME quantity the
646 // linear arm reports (`p·(log w_sum + log s_tt) = p·log|XᵀWX|`), assembled
647 // from the curved basis `Φ` on the atom's assigned rows under the same
648 // assignment-mass design weight `wᵢ = a_k²`. Both arms omit the smoothing
649 // `λS` normalizer, so the complexity price is computed on a symmetric
650 // footing — no parameter-count proxy. Only when `Φ` is unavailable (the
651 // caller could not evaluate the basis) or its Gram is fully rank-deficient do
652 // we fall back to the historical `curved_num_params · log(w_sum)` proxy, so
653 // the comparison degrades gracefully rather than fabricating a determinant.
654 if !(w_sum > 0.0 && w_sum.is_finite() && s_tt.is_finite()) {
655 return None;
656 }
657 let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
658 let curved_log_det_h = curved_phi
659 .and_then(|phi| {
660 if phi.nrows() == n {
661 curved_design_gram_logdet(phi, assign, p)
662 } else {
663 None
664 }
665 })
666 .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());
667
668 // Reduced Laplace NLE `residual_objective + ½ log|H|`. Both omit an explicit
669 // smoothing-penalty logdet (the intrinsic smoothness penalty is
670 // reparameterization-invariant and identical in expectation across the two
671 // parameterizations of the same image).
672 let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
673 let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
674 if !(linear_nle.is_finite() && curved_nle.is_finite()) {
675 return None;
676 }
677
678 let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
679 let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
680 Some((linear, curved, (t_bar, b0, b1)))
681}
682
683/// #1026 collapse rescue. When a `d = 1` atom's own coordinate has collapsed to a
684/// single point (`build_atom_candidates` refuses because `s_tt ≈ 0`), the atom is
685/// stuck in the degenerate "chord-through-the-arc" fixed point and its curved
686/// decode is a constant — the rank-1 dictionary co-collapse (real-OLMo held-out EV
687/// ≈ 0.13 vs the rank-K linear ceiling ≈ 0.74). The hybrid-split was DESIGNED to
688/// let such a linear-tail atom decode as a straight line; the only reason it can't
689/// here is that its own codes carry no spread to fit a slope against.
690///
691/// Recover FRESH per-row codes from the data instead: `uᵢ = yᵢ·v`, the projection
692/// of the leave-this-atom-out residual onto its top mass-weighted output direction
693/// `v` (the rank-1 of `Σᵢ wᵢ yᵢyᵢᵀ`, `wᵢ = a_k²` — the SAME design weight the line
694/// fit uses). These codes span the residual's strongest linear axis by
695/// construction, so the straight image `b₀ + (uᵢ − ū)·b₁` fit against them
696/// reconstructs that axis at LINEAR quality — exactly the linear-tail reach the
697/// split owes. Returns the forced-LINEAR candidate plus the image carrying `uᵢ`,
698/// or `None` when the residual itself carries no usable direction (a genuine zero
699/// atom the mass/decoder guards own).
700fn build_collapse_rescue_linear_image(
701 atom_idx: usize,
702 assign: ArrayView1<'_, f64>,
703 target_resid: ArrayView2<'_, f64>,
704) -> Option<(HybridAtomCandidate, AtomLinearImage)> {
705 let n = assign.len();
706 let p = target_resid.ncols();
707 if n < MIN_ROWS_FOR_LINEAR_FIT || target_resid.nrows() != n || p == 0 {
708 return None;
709 }
710 let mut w_sum = 0.0_f64;
711 for i in 0..n {
712 let a = assign[i];
713 if !(a.is_finite() && a >= 0.0) {
714 return None;
715 }
716 w_sum += a * a;
717 }
718 if !(w_sum > 0.0) {
719 return None;
720 }
721 // Top mass-weighted output direction `v` of the residual via power iteration on
722 // `M = Σᵢ wᵢ yᵢyᵢᵀ` (p×p, never materialized): `v ← normalize(Σᵢ wᵢ yᵢ (yᵢ·v))`.
723 // Seed from the per-channel weighted energy so a rank-1 residual converges in
724 // one step and the seed is deterministic (no RNG).
725 let mut v = Array1::<f64>::zeros(p);
726 for j in 0..p {
727 let mut e = 0.0_f64;
728 for i in 0..n {
729 let a = assign[i];
730 let y = target_resid[[i, j]];
731 e += a * a * y * y;
732 }
733 v[j] = e;
734 }
735 let mut vnorm = v.dot(&v).sqrt();
736 if !(vnorm > 0.0) {
737 return None;
738 }
739 v.mapv_inplace(|x| x / vnorm);
740 for _ in 0..32 {
741 let mut mv = Array1::<f64>::zeros(p);
742 for i in 0..n {
743 let a = assign[i];
744 let w = a * a;
745 let mut proj = 0.0_f64;
746 for j in 0..p {
747 proj += target_resid[[i, j]] * v[j];
748 }
749 let wp = w * proj;
750 for j in 0..p {
751 mv[j] += wp * target_resid[[i, j]];
752 }
753 }
754 vnorm = mv.dot(&mv).sqrt();
755 if !(vnorm > 0.0) {
756 return None;
757 }
758 mv.mapv_inplace(|x| x / vnorm);
759 let cos = mv.dot(&v).abs();
760 v = mv;
761 if cos > 1.0 - 1e-12 {
762 break;
763 }
764 }
765 // Fresh per-row codes `uᵢ = yᵢ·v` and the weighted line fit against them.
766 let mut u = Array1::<f64>::zeros(n);
767 let mut t_bar = 0.0_f64;
768 for i in 0..n {
769 let mut proj = 0.0_f64;
770 for j in 0..p {
771 proj += target_resid[[i, j]] * v[j];
772 }
773 u[i] = proj;
774 t_bar += assign[i] * assign[i] * proj;
775 }
776 t_bar /= w_sum;
777 let mut s_tt = 0.0_f64;
778 for i in 0..n {
779 let dt = u[i] - t_bar;
780 s_tt += assign[i] * assign[i] * dt * dt;
781 }
782 if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
783 return None;
784 }
785 let mut b0 = Array1::<f64>::zeros(p);
786 let mut b1 = Array1::<f64>::zeros(p);
787 let mut linear_rss = 0.0_f64;
788 for j in 0..p {
789 let mut s_1y = 0.0_f64;
790 let mut s_ty = 0.0_f64;
791 for i in 0..n {
792 let a = assign[i];
793 let dt = u[i] - t_bar;
794 let y = target_resid[[i, j]];
795 s_1y += a * y;
796 s_ty += a * dt * y;
797 }
798 b0[j] = s_1y / w_sum;
799 b1[j] = s_ty / s_tt;
800 }
801 for i in 0..n {
802 let a = assign[i];
803 let dt = u[i] - t_bar;
804 for j in 0..p {
805 let r = target_resid[[i, j]] - a * (b0[j] + dt * b1[j]);
806 linear_rss += r * r;
807 }
808 }
809 let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
810 let linear_nle = reduced_laplace_nle(0.5 * linear_rss, linear_log_det_h);
811 if !linear_nle.is_finite() {
812 return None;
813 }
814 let linear = HybridAtomCandidate::linear(linear_nle, 2 * p);
815 let image = AtomLinearImage {
816 atom_idx,
817 t_bar,
818 b0,
819 b1,
820 row_codes: Some(u),
821 // #1777 — persist the projection direction so an OOS row's coordinate can
822 // be recomputed as ⟨residual, v⟩ (identical to the train `row_codes`),
823 // rather than falling back to the atom's collapsed own coordinate.
824 v: Some(v),
825 };
826 Some((linear, image))
827}
828
829/// #1026 item-2 — one collapsed slot's TRUE dictionary-level reconstruction
830/// change `δ_k[i,j] = a_k·(γ_k(t_i) − line_k(t_i))` over ALL globally-aligned rows
831/// (the caller presents `coords`/`assign`/`decoded`/`image` on the same `n` rows,
832/// so these δ vectors ARE cross-atom aligned and their inner products are the
833/// genuine cross terms). `line_k` is evaluated at the image's own coordinate
834/// (collapse-rescue slots evaluate at their fresh per-row codes via
835/// [`AtomLinearImage::coordinate_for_row`], exactly as the collapsed reconstruction
836/// does). Collapsing atom `k` shifts the full reconstruction residual by `+δ_k`.
837fn slot_delta(
838 coords: &Array1<f64>,
839 assign: &Array1<f64>,
840 decoded: &Array2<f64>,
841 image: &AtomLinearImage,
842) -> Array2<f64> {
843 let n = decoded.nrows();
844 let p = decoded.ncols();
845 let mut d = Array2::<f64>::zeros((n, p));
846 for i in 0..n {
847 let a = assign[i];
848 let coord = image.coordinate_for_row(i, coords[i]);
849 let dt = coord - image.t_bar;
850 for j in 0..p {
851 let line = image.b0[j] + dt * image.b1[j];
852 d[[i, j]] = a * (decoded[[i, j]] - line);
853 }
854 }
855 d
856}
857
858/// #1026 item-2 — the ALL-CURVED global reconstruction residual
859/// `R0 = target − Σ_all a·γ` recovered from any single atom's leave-this-atom-out
860/// residual: `R0 = y_resp_k − a_k·γ_k = target_resid − a_k·decoded`. Identical for
861/// every atom (each `target_resid` adds back exactly that atom's own contribution),
862/// so the caller computes it once from the first slot.
863fn slot_r0(assign: &Array1<f64>, decoded: &Array2<f64>, target_resid: &Array2<f64>) -> Array2<f64> {
864 let n = decoded.nrows();
865 let p = decoded.ncols();
866 let mut r0 = Array2::<f64>::zeros((n, p));
867 for i in 0..n {
868 let a = assign[i];
869 for j in 0..p {
870 r0[[i, j]] = target_resid[[i, j]] - a * decoded[[i, j]];
871 }
872 }
873 r0
874}
875
876/// #1026 item-2 — GLOBAL cross-term collapse guard. Given the collapsed slots'
877/// dictionary-level reconstruction-change vectors `δ_k` (globally row-aligned,
878/// `n × p`) and the all-curved global residual `R0 = target − Σ_all a·γ`, decide
879/// which collapses must revert to curved so the TRUE global reconstruction SSR
880/// increase
881///
882/// ```text
883/// ΔRSS(S) = ‖R0 + Σ_{k∈S} δ_k‖² − ‖R0‖²
884/// = 2⟨R0, Σ_{k∈S} δ_k⟩ + ‖Σ_{k∈S} δ_k‖²
885/// = Σ_{k∈S} Δ_k + 2 Σ_{j<k∈S} ⟨δ_j, δ_k⟩
886/// ```
887///
888/// stays within `global_tol`. Unlike the per-atom / summed-loss guard this
889/// INCLUDES the cross terms `2 Σ_{j<k} ⟨δ_j, δ_k⟩` between simultaneously-collapsed
890/// atoms (correlated collapse errors), which the aggregate `Σ max(Δ_k, 0)` bound is
891/// blind to. `forced` are the always-collapsed δ (euclidean / collapse-rescue slots
892/// with no curved alternative): they stay in the reconstruction but cannot be
893/// rolled back. `eligible` are `(slot, curved_evidence_margin, δ)` for
894/// rollback-eligible collapses. When `ΔRSS` over ALL collapses exceeds tolerance,
895/// the least-justified eligible collapses (largest margin — the most marginal
896/// linear win) are reverted one at a time, RECOMPUTING the true global increase
897/// (cross terms and all) after each revert, until within tolerance or none remain.
898/// Returns the slot indices to revert to curved.
899fn global_collapse_rollback(
900 r0: ArrayView2<'_, f64>,
901 eligible: &[(usize, f64, &Array2<f64>)],
902 forced: &[&Array2<f64>],
903 global_tol: f64,
904) -> Vec<usize> {
905 let (n, p) = r0.dim();
906 // True global SSR increase for the active δ set: 2⟨R0, Σδ⟩ + ‖Σδ‖².
907 let increase = |active: &[&Array2<f64>]| -> f64 {
908 let mut cross = 0.0_f64;
909 let mut self_sq = 0.0_f64;
910 for i in 0..n {
911 for j in 0..p {
912 let mut s = 0.0_f64;
913 for d in active {
914 s += d[[i, j]];
915 }
916 cross += r0[[i, j]] * s;
917 self_sq += s * s;
918 }
919 }
920 2.0 * cross + self_sq
921 };
922 let mut kept = vec![true; eligible.len()];
923 let build_active = |kept: &[bool]| -> Vec<&Array2<f64>> {
924 let mut v: Vec<&Array2<f64>> = forced.to_vec();
925 for (idx, &(_, _, d)) in eligible.iter().enumerate() {
926 if kept[idx] {
927 v.push(d);
928 }
929 }
930 v
931 };
932 if increase(&build_active(&kept)) <= global_tol {
933 return Vec::new();
934 }
935 // Revert least-justified first: largest curved_evidence_margin (the most
936 // marginal linear win is the cheapest to give back to curved).
937 let mut order: Vec<usize> = (0..eligible.len()).collect();
938 order.sort_by(|&a, &b| {
939 eligible[b]
940 .1
941 .partial_cmp(&eligible[a].1)
942 .unwrap_or(std::cmp::Ordering::Equal)
943 });
944 let mut reverted = Vec::new();
945 for idx in order {
946 kept[idx] = false;
947 reverted.push(eligible[idx].0);
948 if increase(&build_active(&kept)) <= global_tol {
949 break;
950 }
951 }
952 reverted
953}
954
955/// Assemble the per-atom candidate slots for [`select_hybrid_split`] from the
956/// fitted `d = 1` atoms, run the adjudication, and return the report.
957///
958/// `atoms` are the fitted dictionary atoms; `coords_for` yields the on-atom
959/// coordinate column for a slot, `assign_for` the per-row assignment mass `a_k`,
960/// `decoded_for` the fitted decoded image rows `γ_k`, and `target_resid_for` the
961/// atom's leave-this-atom-out response residual `y_resp` (the data both
962/// candidates are scored against, #1202). `manifold_for` yields the atom's chart
963/// manifold (a flat / Euclidean chart can present only the linear candidate,
964/// enforced inside the selector).
965///
966/// Returns `None` (no report) when no atom is eligible — there is nothing to
967/// adjudicate.
968pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
969 atoms: &'a [SaeManifoldAtom],
970 eligible_d1: impl Iterator<Item = usize>,
971 mut coords_for: C,
972 mut assign_for: W,
973 mut decoded_for: D,
974 mut target_resid_for: R,
975 mut manifold_for: M,
976 mut delta_ev_for: E,
977 // #1026 — the full target's total (column-centered) variance `SST_full`, the
978 // fixed denominator of the EV-preservation gate. `≤ 0` / non-finite disables
979 // the gate (a degenerate, varianceless target has no EV to preserve).
980 total_centered_variance: f64,
981) -> Result<Option<SaeHybridSplitReport>, String>
982where
983 C: FnMut(usize) -> Array1<f64>,
984 W: FnMut(usize) -> Array1<f64>,
985 D: FnMut(usize) -> Array2<f64>,
986 R: FnMut(usize) -> Array2<f64>,
987 M: FnMut(usize) -> LatentManifold,
988 // The atom's held-out LOAO `ΔEV_k`, keyed by atom index. `None` when LOAO EV
989 // is unavailable (e.g. the caller has no target to measure against).
990 E: FnMut(usize) -> Option<f64>,
991{
992 let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
993 let mut names: Vec<String> = Vec::new();
994 let mut manifolds: Vec<LatentManifold> = Vec::new();
995 // Per-slot fitted straight sub-model `(atom_idx, t̄, b₀, b₁)`, surfaced onto
996 // the verdict iff the slot selects LINEAR so the collapsed reconstruction can
997 // substitute it for the curved decoded image.
998 let mut linear_images: Vec<AtomLinearImage> = Vec::new();
999 // Per-slot `(Θ, ΔEV)` — the #1026 frontier point — carried onto each verdict
1000 // so the geometry/EV pairing is structured report data, not a log line.
1001 let mut turnings: Vec<Option<f64>> = Vec::new();
1002 let mut delta_evs: Vec<Option<f64>> = Vec::new();
1003 // #1026 item-2 — per-slot collapse loss `Δ_k = linear_rss − curved_rss` for the
1004 // GLOBAL EV-preservation guard below. `Some(Δ_k)` for a curveable slot that
1005 // retained a curved alternative (so a chosen collapse there can be rolled back);
1006 // `None` for euclidean slots (no curved option) and collapse-rescue slots (the
1007 // curve was already degenerate — collapsing recovers EV rather than losing it).
1008 let mut collapse_loss: Vec<Option<f64>> = Vec::new();
1009 // #1026 item-2 — per-slot dictionary-level reconstruction-change vector δ_k
1010 // (globally row-aligned n×p), and the all-curved global residual R0 (computed
1011 // once from the first slot). These feed the GLOBAL cross-term collapse guard,
1012 // which reconstructs the full dictionary with the selected collapses applied
1013 // and measures the TRUE global EV degradation (cross terms and all).
1014 let mut deltas: Vec<Array2<f64>> = Vec::new();
1015 let mut r0: Option<Array2<f64>> = None;
1016
1017 for atom_idx in eligible_d1 {
1018 let atom = &atoms[atom_idx];
1019 let coords = coords_for(atom_idx);
1020 let assign = assign_for(atom_idx);
1021 let decoded = decoded_for(atom_idx);
1022 let target_resid = target_resid_for(atom_idx);
1023 // Curved parameter price = the decoder's `M · p` coefficients.
1024 let curved_num_params = atom.decoder_coefficients.len();
1025 let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
1026 d1_atom_fitted_turning(
1027 evaluator.as_ref(),
1028 atom.decoder_coefficients.view(),
1029 coords.view(),
1030 )
1031 .ok()
1032 .flatten()
1033 });
1034 // Evaluate the curved design `Φ(t)` on this atom's assigned rows so the
1035 // curved arm's Laplace complexity is the real weighted-design Gram
1036 // log-determinant rather than a parameter-count proxy (#1223). A `d = 1`
1037 // atom's coordinate column is presented as an `n × 1` design input. If
1038 // the evaluator is absent or refuses, `curved_phi` stays `None` and
1039 // `build_atom_candidates` falls back to the proxy.
1040 let coords_col = coords
1041 .view()
1042 .into_shape_with_order((coords.len(), 1))
1043 .ok()
1044 .map(|v| v.to_owned());
1045 let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
1046 (Some(evaluator), Some(col)) => {
1047 evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
1048 }
1049 _ => None,
1050 };
1051 // A flat (Euclidean) chart cannot honestly present a curved candidate;
1052 // the selector drops it. Present both for curveable charts.
1053 let manifold = manifold_for(atom_idx);
1054 match build_atom_candidates(
1055 coords.view(),
1056 assign.view(),
1057 decoded.view(),
1058 target_resid.view(),
1059 curved_num_params,
1060 curved_phi.as_ref().map(|phi| phi.view()),
1061 fitted_turning,
1062 ) {
1063 Some((linear, curved, (t_bar, b0, b1))) => {
1064 // #1026 PER-ATOM EV-PRESERVATION gate. Collapsing this slot raises
1065 // the full reconstruction SSR by `linear_rss − curved_rss`; if that
1066 // is more than `SAE_HYBRID_COLLAPSE_MAX_EV_LOSS` of the fixed total
1067 // target variance the collapse would DROP this ONE atom's EV
1068 // materially, so veto it by presenting only the curved candidate (the
1069 // selector must keep curved). A lossless / improving collapse (`≤ 0`)
1070 // and a negligible one stay free to collapse — EV-neutral cases (the
1071 // top-k / birth-topology lines) are untouched. Only curveable charts
1072 // are gated; a euclidean chart never had a curved option. NOTE: this
1073 // gate is PER-ATOM only — the accumulation of many small collapses
1074 // and the dictionary-level cross terms are handled by the aggregate
1075 // guard after selection.
1076 let loss = collapse_ssr_increase(
1077 coords.view(),
1078 assign.view(),
1079 decoded.view(),
1080 target_resid.view(),
1081 t_bar,
1082 &b0,
1083 &b1,
1084 );
1085 let collapse_loses_ev = total_centered_variance.is_finite()
1086 && total_centered_variance > 0.0
1087 && loss > SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
1088 let euclidean = manifold.is_euclidean();
1089 let slot = if euclidean {
1090 vec![linear]
1091 } else if collapse_loses_ev {
1092 vec![curved]
1093 } else {
1094 vec![linear, curved]
1095 };
1096 // Build the straight image, then its globally-aligned δ_k for the
1097 // GLOBAL cross-term guard, before moving it into the report.
1098 let image = AtomLinearImage {
1099 atom_idx,
1100 t_bar,
1101 b0,
1102 b1,
1103 row_codes: None,
1104 // Ordinary straight image: decoded at the atom's own
1105 // coordinate, so it carries no residual-projection direction.
1106 v: None,
1107 };
1108 let delta = slot_delta(&coords, &assign, &decoded, &image);
1109 if r0.is_none() {
1110 r0 = Some(slot_r0(&assign, &decoded, &target_resid));
1111 }
1112 slots.push(slot);
1113 // A euclidean slot never had a curved alternative, so its collapse
1114 // carries no recoverable EV loss for the global guard; record `None`.
1115 collapse_loss.push(if euclidean { None } else { Some(loss) });
1116 names.push(atom.name.clone());
1117 manifolds.push(manifold);
1118 turnings.push(fitted_turning);
1119 delta_evs.push(delta_ev_for(atom_idx));
1120 deltas.push(delta);
1121 linear_images.push(image);
1122 }
1123 // #1026 collapse rescue: `build_atom_candidates` refused because the
1124 // atom's own coordinate collapsed (`s_tt ≈ 0`) — the rank-1 co-collapse
1125 // fixed point. Recover a FRESH linear image from the residual's top
1126 // direction (fresh per-row codes) and force the LINEAR verdict (a
1127 // single-option slot the selector must take) so the slot reconstructs
1128 // its residual's best linear axis at linear quality instead of the
1129 // collapsed-curve constant. `None` only when the residual itself is
1130 // degenerate — then there is genuinely nothing to recover and we skip.
1131 None => match build_collapse_rescue_linear_image(
1132 atom_idx,
1133 assign.view(),
1134 target_resid.view(),
1135 ) {
1136 Some((linear, image)) => {
1137 let delta = slot_delta(&coords, &assign, &decoded, &image);
1138 if r0.is_none() {
1139 r0 = Some(slot_r0(&assign, &decoded, &target_resid));
1140 }
1141 slots.push(vec![linear]);
1142 // Forced-linear rescue: the curve was degenerate, so there is no
1143 // curved alternative to roll back to and no recoverable EV loss.
1144 collapse_loss.push(None);
1145 names.push(atom.name.clone());
1146 manifolds.push(manifold);
1147 turnings.push(fitted_turning);
1148 delta_evs.push(delta_ev_for(atom_idx));
1149 deltas.push(delta);
1150 linear_images.push(image);
1151 }
1152 None => continue,
1153 },
1154 }
1155 }
1156
1157 if slots.is_empty() {
1158 return Ok(None);
1159 }
1160
1161 let mut selection = select_hybrid_split(&slots)?;
1162
1163 // #1026 item-2 — GLOBAL CROSS-TERM EV-preservation guard over the SELECTED
1164 // collapses.
1165 //
1166 // The per-atom gate above bounds each atom's individual EV loss, but the TRUE
1167 // dictionary-level RSS increase from collapsing a SET of atoms is
1168 // ΔRSS = Σ_k Δ_k + 2 Σ_{j<k} ⟨δ_j, δ_k⟩,
1169 // so two effects escape any per-atom or summed-loss bound: (1) the accumulation
1170 // of many individually-tolerable `Δ_k`, and (2) the CROSS TERMS between
1171 // simultaneously-collapsed atoms. The old aggregate `Σ max(Δ_k, 0)` guard bounded
1172 // (1) but was blind to (2): correlated collapse errors whose per-atom losses each
1173 // sit under tolerance can still push the true global loss over it.
1174 //
1175 // Every slot now carries its exact dictionary-level reconstruction-change vector
1176 // δ_k (globally row-aligned — the caller presents all per-atom arrays on the same
1177 // n rows), so we reconstruct the full dictionary WITH the selected collapse set
1178 // applied and measure the real global increase `ΔRSS` DIRECTLY (cross terms
1179 // captured), reverting the least-justified collapses until the degradation is
1180 // within the same EV tolerance and re-adjudicating so `selection` stays consistent.
1181 if total_centered_variance.is_finite() && total_centered_variance > 0.0 {
1182 if let Some(r0) = r0.as_ref() {
1183 let global_tol = SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
1184 // Partition the SELECTED-collapsed slots: rollback-eligible ones (a curved
1185 // alternative still present and a finite per-atom loss) vs forced ones
1186 // (euclidean / collapse-rescue — no curved fallback, stay collapsed but
1187 // still enter the global reconstruction so their cross terms are counted).
1188 let mut eligible: Vec<(usize, f64, &Array2<f64>)> = Vec::new();
1189 let mut forced: Vec<&Array2<f64>> = Vec::new();
1190 for (i, choice) in selection.atoms.iter().enumerate() {
1191 if !choice.param.is_linear() {
1192 continue;
1193 }
1194 let has_curved_alt = slots[i].iter().any(|c| !c.param.is_linear());
1195 let loss_finite = collapse_loss[i].map(|l| l.is_finite()).unwrap_or(false);
1196 if has_curved_alt && loss_finite {
1197 eligible.push((i, choice.curved_evidence_margin, &deltas[i]));
1198 } else {
1199 forced.push(&deltas[i]);
1200 }
1201 }
1202 let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
1203 if !reverted.is_empty() {
1204 for slot in reverted {
1205 if let Some(curved) =
1206 slots[slot].iter().find(|c| !c.param.is_linear()).copied()
1207 {
1208 slots[slot] = vec![curved];
1209 }
1210 }
1211 selection = select_hybrid_split(&slots)?;
1212 }
1213 }
1214 }
1215
1216 let verdicts: Vec<AtomHybridVerdict> = names
1217 .into_iter()
1218 .zip(selection.atoms.iter().copied())
1219 .zip(linear_images.into_iter())
1220 .zip(turnings.into_iter())
1221 .zip(delta_evs.into_iter())
1222 .map(
1223 |((((atom_name, choice), linear_image), fitted_turning), train_loao_delta_ev)| {
1224 let kept_curved = !choice.param.is_linear();
1225 AtomHybridVerdict {
1226 atom_name,
1227 choice,
1228 kept_curved,
1229 fitted_turning,
1230 train_loao_delta_ev,
1231 // Carry the straight sub-model only when the verdict collapses
1232 // this slot to linear — the curved slots keep their fitted image.
1233 linear_image: if kept_curved {
1234 None
1235 } else {
1236 Some(linear_image)
1237 },
1238 }
1239 },
1240 )
1241 .collect();
1242
1243 Ok(Some(SaeHybridSplitReport {
1244 verdicts,
1245 selection,
1246 }))
1247}
1248
1249#[cfg(test)]
1250mod tests {
1251 use super::*;
1252 use std::f64::consts::PI;
1253
1254 /// A straight RESPONSE residual (the atom's data is a line) is explained
1255 /// equally well by both candidates, so the cheaper linear special case wins.
1256 /// With `a_k = 1` the curved decoded image is straight too (Θ = 0), so both
1257 /// the dominance floor and the evidence argmin select linear. This is the
1258 /// common-data nested comparison (#1202): linear is the curved family's
1259 /// `Θ = 0` member, so it cannot lose when a line already explains the data.
1260 #[test]
1261 fn straight_residual_selects_linear() {
1262 let n = 40;
1263 let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1264 let assign = Array1::<f64>::ones(n);
1265 // The data the atom must explain is a straight line in ℝ²; the curved
1266 // decoded image equals that same line (a Θ = 0 curved fit).
1267 let mut data = Array2::<f64>::zeros((n, 2));
1268 let mut decoded = Array2::<f64>::zeros((n, 2));
1269 for i in 0..n {
1270 data[[i, 0]] = coords[i];
1271 data[[i, 1]] = 0.6 * coords[i];
1272 decoded[[i, 0]] = coords[i];
1273 decoded[[i, 1]] = 0.6 * coords[i];
1274 }
1275 let (linear, curved, _) = build_atom_candidates(
1276 coords.view(),
1277 assign.view(),
1278 decoded.view(),
1279 data.view(),
1280 // a generous curved parameter price (M·p)
1281 10,
1282 None,
1283 Some(0.0),
1284 )
1285 .expect("straight residual yields a candidate pair");
1286 let choice =
1287 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1288 assert!(
1289 choice.param.is_linear(),
1290 "a straight response residual must keep the linear special case"
1291 );
1292 }
1293
1294 /// A turning RESPONSE residual (the atom's data traces a full circle) is fit
1295 /// well by the curved decoded image (curved_rss ≈ 0) but poorly by any
1296 /// straight line (large linear_rss), so the curved candidate wins the common
1297 /// evidence comparison once its data-fit gain exceeds its extra parameter
1298 /// price (#1202).
1299 #[test]
1300 fn turning_residual_selects_curved_on_evidence() {
1301 let n = 60;
1302 let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
1303 let assign = Array1::<f64>::ones(n);
1304 // The data is a full circle; the curved decoded image is that same
1305 // circle (the curved atom reconstructs its assigned residual), so the
1306 // curved candidate has ≈ zero data-fit residual while a straight line
1307 // cannot follow the loop.
1308 let mut data = Array2::<f64>::zeros((n, 2));
1309 let mut decoded = Array2::<f64>::zeros((n, 2));
1310 for i in 0..n {
1311 let theta = 2.0 * PI * coords[i];
1312 data[[i, 0]] = theta.cos();
1313 data[[i, 1]] = theta.sin();
1314 decoded[[i, 0]] = theta.cos();
1315 decoded[[i, 1]] = theta.sin();
1316 }
1317 // The curved atom has 5 parameters (just above the 4 = 2·p linear budget);
1318 // the full-circle linear residual exceeds the extra-parameter overhead, so
1319 // curved wins on evidence.
1320 let (linear, curved, _) = build_atom_candidates(
1321 coords.view(),
1322 assign.view(),
1323 decoded.view(),
1324 data.view(),
1325 5,
1326 None,
1327 Some(2.0 * PI),
1328 )
1329 .expect("turning residual yields a candidate pair");
1330 assert!(
1331 linear.negative_log_evidence > curved.negative_log_evidence,
1332 "the line must misfit the circular residual worse than the curve does \
1333 (linear NLE {} should exceed curved NLE {})",
1334 linear.negative_log_evidence,
1335 curved.negative_log_evidence
1336 );
1337 let choice =
1338 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1339 assert_eq!(
1340 choice.param,
1341 gam_solve::evidence::HybridAtomParam::Curved { latent_dim: 1 },
1342 "a full-circle response residual must keep the curved parameterization"
1343 );
1344 assert!(
1345 choice.curved_evidence_margin > 0.0,
1346 "curved must win a positive evidence margin over the linear secant"
1347 );
1348 }
1349
1350 /// The nested-dominance floor on common data (#1202): when the curved decoded
1351 /// image is a WORSE fit to the response residual than its own best straight
1352 /// projection, linear must win — the curved family cannot be charged extra
1353 /// parameters to fit the residual no better than its `Θ = 0` member. Here the
1354 /// data is a line but the curved image bends away from it, so curved_rss >
1355 /// linear_rss and the cheaper, better-fitting line is selected.
1356 #[test]
1357 fn linear_beats_curved_when_curve_misfits_residual() {
1358 let n = 50;
1359 let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
1360 let assign = Array1::<f64>::ones(n);
1361 // Data is a straight line; the curved decoded image is a parabola that
1362 // departs from it, so a straight line fits the data strictly better.
1363 let mut data = Array2::<f64>::zeros((n, 2));
1364 let mut decoded = Array2::<f64>::zeros((n, 2));
1365 for i in 0..n {
1366 let t = coords[i];
1367 data[[i, 0]] = t;
1368 data[[i, 1]] = 0.5 * t;
1369 decoded[[i, 0]] = t;
1370 decoded[[i, 1]] = t * t; // bends away from the linear data
1371 }
1372 let (linear, curved, _) = build_atom_candidates(
1373 coords.view(),
1374 assign.view(),
1375 decoded.view(),
1376 data.view(),
1377 // a real curved Θ above the floor so the dominance floor does not fire
1378 6,
1379 None,
1380 Some(1.0),
1381 )
1382 .expect("candidate pair");
1383 let choice =
1384 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1385 assert!(
1386 choice.param.is_linear(),
1387 "a curved image that fits the data worse than its own line must yield \
1388 to the linear special case on common-data evidence (#1202)"
1389 );
1390 }
1391
/// #1203 — the LINEAR candidate's Laplace logdet is the genuine weighted-design
/// Gram determinant `p·(log w_sum + log s_tt)`, with `w_sum = Σ a_k²` and
/// `s_tt = Σ a_k²(t − t̄)²`; in particular it carries the coordinate-spread term
/// `log(s_tt)`.
///
/// The logdet is read off a candidate whose linear residual is exactly zero
/// (the response residual IS the fitted line), so `NLE_linear = ½·logdet`.
/// Two perturbations then isolate the two factors:
/// * doubling the coordinate spread at fixed assignment mass scales `s_tt` by 4,
///   so the logdet grows by `p·log(4)`;
/// * doubling every assignment mass scales BOTH `w_sum` and `s_tt` by 4 (each is
///   quadratic in `a_k`), so the logdet grows by `2p·log(4)`.
#[test]
fn linear_logdet_includes_weighted_coordinate_spread() {
    let n = 40;
    let p = 2usize;

    // A straight image `t ↦ (t, 0.6·t)`. Each data row is this line scaled by
    // the row's assignment mass `a_k`, so the linear prediction
    // `a_k·(b₀ + dt·b₁)` reproduces it exactly and linear_rss == 0.
    let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };

    // 2·NLE_linear of the zero-residual candidate, i.e. the logdet itself.
    let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
        let decoded = Array2::<f64>::from_shape_fn((n, p), |(i, j)| line(coords[i])[j]);
        let data =
            Array2::<f64>::from_shape_fn((n, p), |(i, j)| assign[i] * line(coords[i])[j]);
        let (linear, _curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            10,
            None,
            Some(0.0),
        )
        .expect("straight residual yields a pair");
        2.0 * linear.negative_log_evidence
    };

    let base_coords =
        Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
    let ones = Array1::<f64>::ones(n);
    let twos = Array1::<f64>::from_elem(n, 2.0);
    let p_f = p as f64;
    let log4 = 4.0_f64.ln();

    // Spread doubled, mass fixed: s_tt ×4 and w_sum unchanged ⇒ Δ = p·log(4).
    let wide_coords = base_coords.mapv(|t| 2.0 * t);
    let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
    assert!(
        (d_spread - p_f * log4).abs() < 1e-9,
        "linear logdet must move by p·log(4) when coordinate spread doubles \
         (got {d_spread}); the spread term log(s_tt) must be present"
    );

    // Every mass doubled: w_sum ×4 AND s_tt ×4 ⇒ Δ = 2p·log(4).
    let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
    assert!(
        (d_weight - 2.0 * p_f * log4).abs() < 1e-9,
        "linear logdet must move by 2p·log(4) when all assignment masses double \
         (got {d_weight})"
    );
}
1458
/// #1223 — the curved arm's Laplace complexity is the REAL weighted-design
/// Gram log-determinant `p·log|ΦᵀWΦ|_+`, not a parameter-count proxy.
///
/// Build a curved design whose columns are the constant and the centered
/// coordinate (a 2-column basis), so `ΦᵀWΦ = diag(w_sum, s_tt)` exactly matches
/// the linear arm's data Gram. Then assert `curved_design_gram_logdet` returns
/// `p·(log w_sum + log s_tt)`, the same determinant the linear arm reports on
/// the same design weight. A proxy `M·log(w_sum)` would instead omit the
/// `log(s_tt)` spread term, so this pins the genuine determinant.
///
/// A second, rank-deficient design (two identical constant columns) pins the
/// pseudo-determinant `|·|_+`. Its Gram `w_sum·[[1, 1], [1, 1]]` has eigenvalues
/// `2·w_sum` and `0`, so only `p·log(2·w_sum)` may be reported.
#[test]
fn curved_gram_logdet_is_real_weighted_design_determinant() {
    let n = 40;
    let p = 3usize;
    let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
    let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));

    // Mass-weighted coordinate mean and spread under wᵢ = a_k².
    let mut w_sum = 0.0;
    let mut t_bar = 0.0;
    for i in 0..n {
        let w = assign[i] * assign[i];
        w_sum += w;
        t_bar += w * coords[i];
    }
    t_bar /= w_sum;
    let mut s_tt = 0.0;
    for i in 0..n {
        let dt = coords[i] - t_bar;
        s_tt += assign[i] * assign[i] * dt * dt;
    }

    // Curved design columns: [1, (t − t̄)]. Its weighted Gram is exactly
    // diag(w_sum, s_tt): the cross term Σ w·(t − t̄) vanishes because t̄ is the
    // w-weighted mean. Hence log|ΦᵀWΦ| = log(w_sum) + log(s_tt).
    let mut phi = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        phi[[i, 0]] = 1.0;
        phi[[i, 1]] = coords[i] - t_bar;
    }
    let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
        .expect("non-degenerate curved design has a determinant");
    let want = (p as f64) * (w_sum.ln() + s_tt.ln());
    assert!(
        (got - want).abs() < 1e-9,
        "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
    );

    // A rank-deficient design (a duplicated constant column) has Gram
    // w_sum·[[1, 1], [1, 1]]. It has ONE positive eigenvalue, 2·w_sum, along
    // (1, 1)/√2, and a null direction along (1, −1)/√2. The pseudo-determinant
    // drops the null direction and keeps 2·w_sum, so the expected value is
    // p·log(2·w_sum). That is neither a 2-column proxy nor the single-column
    // p·log(w_sum).
    let mut phi_dup = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        phi_dup[[i, 0]] = 1.0;
        phi_dup[[i, 1]] = 1.0;
    }
    let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
        .expect("rank-1 design still has a positive determinant");
    let want_dup = (p as f64) * (2.0 * w_sum).ln();
    assert!(
        (got_dup - want_dup).abs() < 1e-9,
        "rank-deficient curved Gram must report only its positive direction \
         (p·log(2·w_sum) = {want_dup}), got {got_dup}"
    );
}
1522
/// #1051 NESTED MIN — the curved arm re-fit on the residual matches or beats
/// the best straight line. `curved_refit_rss ≤ best_line_rss` holds up to solver
/// tolerance on a basis whose span contains the straight lane (here
/// `Φ = [1, t, t²]`). A genuinely curved (quadratic) signal is STRICTLY
/// preferred by the curved arm. An exactly-straight signal ties near zero, so
/// the cheaper linear lane wins downstream.
///
/// The realized-curve heuristic could not establish this property: comparing a
/// possibly-collapsed realized curve against min-over-lines did NOT guarantee
/// `curved ≤ linear`. Re-fitting the curved decoder does.
#[test]
fn refit_curved_rss_matches_or_beats_best_line_nested() {
    let n = 40usize;
    let p = 2usize;
    let coords =
        Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
    let assign = Array1::<f64>::ones(n);
    // Φ = [1, t, t²]: its column span contains the straight lane [1, t], so the
    // decoder-only curved refit is a proper superset of the line fit.
    let mut phi = Array2::<f64>::zeros((n, 3));
    for i in 0..n {
        phi[[i, 0]] = 1.0;
        phi[[i, 1]] = coords[i];
        phi[[i, 2]] = coords[i] * coords[i];
    }

    // Best line RSS on `y` with unit mass (assign = 1): least squares on the
    // design [1, (t − t̄)]. The design does not depend on `y`, so it is built
    // once instead of on every call.
    let t_bar = coords.iter().sum::<f64>() / n as f64;
    let mut line_design = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        line_design[[i, 0]] = 1.0;
        line_design[[i, 1]] = coords[i] - t_bar;
    }
    let best_line_rss = |y: &Array2<f64>| -> f64 {
        let b = solve_design_least_squares(line_design.view(), y.view())
            .expect("the centered line design has full column rank");
        let pred = line_design.dot(&b);
        y.iter()
            .zip(pred.iter())
            .map(|(obs, fit)| {
                let r = obs - fit;
                r * r
            })
            .sum::<f64>()
    };

    // (a) exactly straight, (b) quadratic curve, (c) line plus a smooth wiggle.
    let mut y_line = Array2::<f64>::zeros((n, p));
    let mut y_curve = Array2::<f64>::zeros((n, p));
    let mut y_noisy = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let t = coords[i];
        y_line[[i, 0]] = 0.4 + 0.6 * t;
        y_line[[i, 1]] = -0.2 + 1.1 * t;
        y_curve[[i, 0]] = t * t;
        y_curve[[i, 1]] = 0.5 - t * t;
        y_noisy[[i, 0]] = 0.3 + 0.7 * t + 0.05 * (3.0 * t).sin();
        y_noisy[[i, 1]] = 0.9 * t;
    }

    // Nested dominance on every signal.
    for y in [&y_line, &y_curve, &y_noisy] {
        let curved = curved_refit_rss(phi.view(), assign.view(), y.view())
            .expect("non-degenerate refit");
        let line = best_line_rss(y);
        assert!(
            curved <= line + 1e-9 * (1.0 + line),
            "nested dominance: refit-curved RSS {curved} must be ≤ best-line RSS {line}"
        );
    }

    // A genuinely curved (quadratic) signal is STRICTLY preferred by the curve.
    let curved_c = curved_refit_rss(phi.view(), assign.view(), y_curve.view()).unwrap();
    let line_c = best_line_rss(&y_curve);
    assert!(
        curved_c < 0.5 * line_c,
        "a quadratic signal must be far better fit by the curve ({curved_c}) than \
         by the best line ({line_c})"
    );

    // A straight signal ties near zero and collapses to the cheaper linear lane.
    // Both arms are checked, so the failure message reports both.
    let curved_l = curved_refit_rss(phi.view(), assign.view(), y_line.view()).unwrap();
    let line_l = best_line_rss(&y_line);
    assert!(
        curved_l < 1e-18 && line_l < 1e-18,
        "an exactly-straight signal ties the two arms near zero \
         (curved {curved_l}, line {line_l})"
    );
}
1602
/// #1026 item-2 GLOBAL CROSS-TERM guard.
///
/// Two collapses have CORRELATED (parallel) reconstruction-change errors. Each
/// per-atom loss sits under the tolerance, and so does their SUM `Σ Δ_k`, so the
/// OLD aggregate guard `Σ max(Δ_k, 0)` would ACCEPT both. Yet the cross term
/// `2⟨δ_1, δ_2⟩` pushes the TRUE global loss over the tolerance, so the global
/// guard must roll back the least-justified collapse.
///
/// An ORTHOGONAL control (disjoint support) with the same per-atom losses is
/// accepted, which isolates the cross term as the cause.
#[test]
fn global_guard_rejects_correlated_collapses_the_aggregate_would_accept() {
    let (n, p) = (4usize, 1usize);
    // R0 = 0 ⇒ Δ_k = ‖δ_k‖² and ΔRSS(S) = ‖Σ_S δ_k‖² exactly.
    let r0 = Array2::<f64>::zeros((n, p));
    // Per-atom losses of 4 and a summed loss of 8 both clear this tolerance.
    let global_tol = 10.0_f64;
    // Neither scenario has forced collapses.
    let forced: Vec<&Array2<f64>> = Vec::new();

    // Correlated scenario: δ_1 = δ_2 = 1 on every row. Then ‖δ_k‖² = 4 each and
    // Σ Δ_k = 8 ≤ 10 (the aggregate accepts), but ‖δ_1 + δ_2‖² = 16 > 10 (the
    // cross term rejects).
    let mut d1 = Array2::<f64>::zeros((n, p));
    d1.column_mut(0).fill(1.0);
    let d2 = d1.clone();
    // Entries are (slot, margin, δ); slot 1 carries the larger margin.
    let eligible = vec![(0usize, 0.1_f64, &d1), (1usize, 0.2_f64, &d2)];
    assert_eq!(
        global_collapse_rollback(r0.view(), &eligible, &forced, global_tol),
        vec![1usize],
        "the global cross-term guard must roll back the least-justified collapse \
         (largest margin = slot 1) that the summed-loss aggregate (8 ≤ 10) accepts"
    );

    // ORTHOGONAL control: the same per-atom losses on disjoint rows. The cross
    // term is 0, so ΔRSS_global = 8 ≤ 10 and nothing is rolled back.
    let mut o1 = Array2::<f64>::zeros((n, p));
    o1[[0, 0]] = 2.0; // ‖o1‖² = 4
    let mut o2 = Array2::<f64>::zeros((n, p));
    o2[[2, 0]] = 2.0; // ‖o2‖² = 4, support disjoint from o1
    let eligible_o = vec![(0usize, 0.1_f64, &o1), (1usize, 0.2_f64, &o2)];
    let reverted_o = global_collapse_rollback(r0.view(), &eligible_o, &forced, global_tol);
    assert!(
        reverted_o.is_empty(),
        "uncorrelated collapses (cross term 0, global loss 8 ≤ 10) must be accepted"
    );
}
1650
/// A coordinate collapsed onto a single point mass has no slope direction.
/// The candidate builder must therefore refuse it instead of adjudicating on a
/// fabricated deviance.
#[test]
fn degenerate_coordinate_is_refused() {
    let n = 5;
    // Every row sits at t = 0.5, so the coordinate has zero spread.
    let coords = Array1::<f64>::from_elem(n, 0.5);
    let assign = Array1::<f64>::ones(n);
    let (decoded, data) = (Array2::<f64>::zeros((n, 2)), Array2::<f64>::zeros((n, 2)));

    let pair = build_atom_candidates(
        coords.view(),
        assign.view(),
        decoded.view(),
        data.view(),
        6,
        None,
        Some(0.0),
    );
    assert!(pair.is_none(), "a degenerate coordinate span must be refused");
}
1674}