gam_sae/hybrid_split.rs
1//! #1026 — load-bearing curved-vs-linear hybrid split for the fitted SAE
2//! dictionary.
3//!
4//! The selection machinery ([`gam_solve::evidence::select_hybrid_atom`],
5//! [`gam_solve::evidence::select_hybrid_split`]) and the per-atom
6//! integration helper
7//! ([`crate::assignment::select_hybrid_atom_parameterization`]) are
8//! correct and tested, but until now were called nowhere in the fitter: the
9//! post-fit pass only *logged* each `d = 1` atom's fitted turning `Θ`. This
10//! module makes the split LOAD-BEARING by building, per fitted `d = 1` atom, the
11//! two already-realized candidates and adjudicating them by the common
12//! evidence criterion.
13//!
14//! ## The common-evidence comparison on the data (#1202)
15//!
16//! Both candidates are scored against the SAME data: the portion of the
17//! response matrix the atom is responsible for reconstructing, namely its
18//! **leave-this-atom-out residual**
19//!
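//! ```text
//! y_resp[i] = target[i] − ( Σ_j a[i,j]·γ_j(t_ij) − a[i,k]·γ_k(t_ik) )
//!           = target[i] − without_k[i],
//! ```
//!
//! where `without_k[i]` is the full reconstruction of row `i` with atom `k`'s
//! own term removed;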
22//!
23//! the response with every OTHER atom's contribution subtracted. Over the rows
24//! assigned to atom `k` (assignment mass `a[i,k] = a_k`), the two candidates
25//! predict that residual:
26//!
//! * the CURVED candidate predicts `a_k · γ_k(t)` — the atom's actual
//!   already-fitted contribution. Its data-fit deviance is the plain
//!   (unweighted) RSS of that contribution against `y_resp`, which is no longer
//!   zero by construction; the assignment mass enters through the prediction.
//! * the LINEAR candidate predicts `a_k · (b₀ + (t − t̄)·b₁)`, the least-squares
//!   straight line fit to `y_resp`. Its design column is scaled by the same
//!   assignment mass `a_k`, i.e. row weight `a_k²`. Its data-fit deviance is the
//!   plain RSS of that best line against the SAME residual.
34//!
35//! Because the curved family's `Θ = 0` member reproduces exactly the linear
36//! prediction `a_k·(b₀ + (t − t̄)·b₁)` on this data, linear IS the nested `Θ = 0`
37//! sub-model of the curved family on common data — so the per-slot evidence
38//! argmin is a genuine "match-or-beat" comparison: the curved candidate is
39//! preferred only when its extra curvature lowers the data-fit deviance by more
40//! than its extra Laplace parameter price, and the linear special case wins
41//! whenever a straight line already explains the residual.
42//!
43//! This replaces the earlier post-hoc curve-simplification diagnostic, in which
44//! both candidates targeted the atom's already-fitted decoded image `γ_k(t)`
45//! (giving the curved arm a free zero residual against itself) rather than the
46//! response data. That version could not nest linear in curved on common data
47//! and so carried no real dominance guarantee (#1202); it is removed. The
48//! comparison here re-fits nothing in the (broken under #1051) euclidean /
49//! multi-atom outer continuation — it scores the already-realized curved
50//! contribution and the closed-form linear lane against the realized residual,
51//! both on the data, with no joint Hessian or continuation spine.
52
53use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
54
55use gam_linalg::faer_ndarray::FaerEigh;
56use gam_solve::evidence::{
57 HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
58};
59use gam_terms::latent::LatentManifold;
60use crate::chart_canonicalization::d1_atom_fitted_turning;
61use crate::manifold::SaeManifoldAtom;
62
/// Reduced Laplace negative log-evidence of a per-atom Gaussian reconstruction
/// sub-model.
///
/// Computes `residual_objective + ½·log|H|`: the data-fit term plus half the
/// Hessian log-determinant. It adds no smoothing-penalty log-determinant and no
/// null-space correction. This is the form
/// [`gam_solve::evidence::laplace_evidence`] reduces to for this comparison.
///
/// It is computed inline rather than via `EvidenceLogDetSource` because both
/// candidates' Hessian log-determinants are already closed-form scalars of
/// their shared design. There is no factor cache or Hessian-vector-product
/// callback to assemble.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    let half_log_det = 0.5 * log_det_h;
    residual_objective + half_log_det
}
73
74/// Rank-aware `log|ΦᵀWΦ|_+` of the curved atom's weighted design Gram over its
75/// `M` decoder basis columns, with per-row weight `wᵢ = a_k²` (the same
76/// assignment-mass design weight the linear arm uses), summed over the
77/// eigenvalues above a relative spectral floor (#1223). This is the genuine
78/// weighted-design determinant the linear arm already reports — `log|XᵀWX|` —
79/// assembled for the curved basis so the two arms' Laplace complexity prices are
80/// computed on the SAME footing instead of pricing the curved arm with a
81/// parameter-count proxy `M·log(Σw)`.
82///
83/// Mirrors the linear arm exactly in what it does NOT include: no smoothing-
84/// penalty `λS` normalizer (the linear arm's Gram is the bare data Gram
85/// `diag(w_sum, s_tt)` too), so the comparison stays symmetric. The Gram is the
86/// design's outer Gram over its basis columns; it is identical across the `p`
87/// output channels (every channel shares the design `Φ`), so the per-channel
88/// `log|G|_+` is multiplied by `p` — matching the linear arm's `p·(…)` form.
89///
90/// `phi` is the curved design `Φ(t)` evaluated on the atom's assigned rows
91/// (`n × M`); `assign` is the per-row assignment mass `a_k` (NOT squared).
92/// Returns `None` when `Φ` is missing rows, the Gram is non-finite, or it has no
93/// positive eigenvalues (a fully rank-deficient design carries no determinant);
94/// the caller then falls back to the parameter-count proxy rather than fabricate
95/// a determinant.
96fn curved_design_gram_logdet(
97 phi: ArrayView2<'_, f64>,
98 assign: ArrayView1<'_, f64>,
99 p: usize,
100) -> Option<f64> {
101 let n = phi.nrows();
102 let m = phi.ncols();
103 if m == 0 || assign.len() != n || n == 0 {
104 return None;
105 }
106 // G = Φᵀ diag(a²) Φ (M×M, symmetric PSD).
107 let mut gram = Array2::<f64>::zeros((m, m));
108 for i in 0..n {
109 let w = assign[i] * assign[i];
110 if !(w.is_finite() && w >= 0.0) {
111 return None;
112 }
113 if w == 0.0 {
114 continue;
115 }
116 let row = phi.row(i);
117 for a in 0..m {
118 let wa = w * row[a];
119 for b in a..m {
120 gram[[a, b]] += wa * row[b];
121 }
122 }
123 }
124 // Symmetrize the lower triangle (we only filled the upper).
125 for a in 0..m {
126 for b in 0..a {
127 gram[[a, b]] = gram[[b, a]];
128 }
129 }
130 if gram.iter().any(|v| !v.is_finite()) {
131 return None;
132 }
133 let (vals, _vecs) = gram.eigh(faer::Side::Lower).ok()?;
134 // Rank-aware log-determinant: sum log of eigenvalues above a relative floor
135 // tied to the largest eigenvalue, dropping the numerically-null directions
136 // (the curved design's null space, analogous to the linear arm's full-rank
137 // 2-D Gram). A design with no positive eigenvalue carries no determinant.
138 let lambda_max = vals.iter().cloned().fold(0.0_f64, f64::max);
139 if !(lambda_max > 0.0 && lambda_max.is_finite()) {
140 return None;
141 }
142 let floor = lambda_max * 1e-12;
143 let mut log_det = 0.0_f64;
144 let mut rank = 0usize;
145 for &lambda in vals.iter() {
146 if lambda > floor {
147 log_det += lambda.ln();
148 rank += 1;
149 }
150 }
151 if rank == 0 || !log_det.is_finite() {
152 return None;
153 }
154 // The design Gram is shared across the p output channels.
155 Some((p as f64) * log_det)
156}
157
158/// The fitted straight sub-model `γ̃(t) = b₀ + (t − t̄)·b₁` of one `d = 1` atom:
159/// the exact assignment-mass-weighted least-squares line fit to the atom's
160/// leave-this-atom-out RESPONSE residual `y_resp` over its assigned rows (the
161/// curved family's nested `Θ = 0` sub-model on common data, #1202). Carried on a
162/// verdict that selects LINEAR so the collapsed reconstruction can replace the
163/// curved decoded row with this straight image at any coordinate WITHOUT
164/// re-entering the (broken, #1051) outer fit — the coefficients are already
165/// realized inside the adjudication.
166#[derive(Clone, Debug)]
167pub struct AtomLinearImage {
168 /// The atom's slot index in the dictionary (so the collapsed assembly knows
169 /// which atom's decoded row to substitute).
170 pub atom_idx: usize,
171 /// The mass-weighted coordinate mean `t̄` the line is centered on.
172 pub t_bar: f64,
173 /// Per-output-channel centered intercept `b₀ = γ̄` at `t̄` (length `p`).
174 pub b0: Array1<f64>,
175 /// Per-output-channel slope `b₁` (length `p`).
176 pub b1: Array1<f64>,
177 /// #1026 collapse-rescue per-row coordinates. `None` for the ordinary path:
178 /// the line is evaluated at the atom's OWN realized coordinate `t`. `Some(u)`
179 /// only when the atom's circle codes had collapsed to a single point
180 /// (`s_tt ≈ 0`) so its own coordinate carries no spread — then the line is fit
181 /// against, and reconstruct evaluates it at, these FRESH per-row codes `uᵢ`
182 /// (the projection of the leave-this-atom-out residual onto its top
183 /// mass-weighted output direction). This is what lets a circle atom that the
184 /// joint fit drove into the degenerate "chord-through-the-arc" fixed point
185 /// still reconstruct its residual's best linear direction — recovering the
186 /// linear-tail reach the hybrid-split was designed to deliver — instead of a
187 /// constant (its collapsed curve), which is the real-OLMo rank-1 co-collapse
188 /// (held-out EV ≈ 0.13 vs the linear ceiling ≈ 0.74). Length `n` (one per
189 /// reconstructed row); unassigned rows are gated to zero by `a_k` anyway.
190 pub row_codes: Option<Array1<f64>>,
191}
192
193impl AtomLinearImage {
194 /// Evaluate the straight sub-model `b₀ + (t − t̄)·b₁` into `out` (length `p`).
195 pub fn fill_row(&self, t: f64, out: &mut [f64]) {
196 let dt = t - self.t_bar;
197 for (j, slot) in out.iter_mut().enumerate() {
198 *slot = self.b0[j] + dt * self.b1[j];
199 }
200 }
201
202 /// The coordinate at which row `row` should evaluate this image: the
203 /// collapse-rescue fresh code `uᵢ` when present (#1026), else the atom's own
204 /// realized coordinate `own_t` passed by the caller.
205 pub fn coordinate_for_row(&self, row: usize, own_t: f64) -> f64 {
206 match &self.row_codes {
207 Some(u) if row < u.len() => u[row],
208 _ => own_t,
209 }
210 }
211}
212
213/// One fitted `d = 1` atom's hybrid-split verdict, surfaced in the model output.
214#[derive(Clone, Debug)]
215pub struct AtomHybridVerdict {
216 /// The atom's name (slot identity in the dictionary).
217 pub atom_name: String,
218 /// The evidence-selected parameterization choice for this slot.
219 pub choice: HybridAtomChoice,
220 /// `true` iff the slot kept the CURVED parameterization (the fitted atom);
221 /// `false` iff it yielded to the LINEAR special case (the straight tail).
222 pub kept_curved: bool,
223 /// The atom's fitted turning `Θ = ∫|κ| ds` (radians), the novel geometric
224 /// quantity #1026 pairs against reconstruction EV: `Θ ≈ 0` is a linear-tail
225 /// direction wearing a curved basis, `Θ ≈ 2π` is a full curved loop. `None`
226 /// iff the evaluator has no analytic second jet or the curve is degenerate.
227 /// Captured here (not just logged) so the EV-vs-Θ frontier is queryable
228 /// structured data on the persisted report rather than a transient log line.
229 pub fitted_turning: Option<f64>,
230 /// The atom's training leave-one-atom-out explained-variance contribution
231 /// `ΔEV_k = EV(full) − EV(full∖{k})` — how much reconstruction EV this single
232 /// atom earns. Paired with [`Self::fitted_turning`] this is the `(Θ, ΔEV)`
233 /// point the #1026 frontier reports: a `Θ ≈ 0` atom with large `ΔEV` is a
234 /// genuine linear-tail direction; a high-`Θ` atom with large `ΔEV` is a
235 /// genuine curved family. `None` iff the caller did not supply LOAO EV.
236 pub train_loao_delta_ev: Option<f64>,
237 /// The fitted straight sub-model for this slot, present iff the verdict
238 /// selected LINEAR (`kept_curved == false`). The collapsed reconstruction
239 /// substitutes this for the atom's curved decoded image, making the verdict
240 /// load-bearing on the reconstruction rather than a passive diagnostic.
241 pub linear_image: Option<AtomLinearImage>,
242}
243
244/// The whole dictionary's hybrid-split report: one verdict per eligible `d = 1`
245/// atom, plus the dictionary-level aggregates the EV-vs-Θ frontier reports
246/// against.
247#[derive(Clone, Debug)]
248pub struct SaeHybridSplitReport {
249 /// One adjudicated verdict per eligible `d = 1` atom, in slot order. Atoms
250 /// that are not eligible (wrong dim, no evaluator, mid-homotopy) are absent
251 /// — they carry no curved/linear adjudication.
252 pub verdicts: Vec<AtomHybridVerdict>,
253 /// The dictionary-level rolled-up selection (summed NLE, total parameters,
254 /// curved/linear counts) over the eligible atoms.
255 pub selection: HybridSplitSelection,
256}
257
/// Minimum number of assigned rows a `d = 1` atom needs for its linear candidate.
///
/// A straight line spends two parameters per channel (intercept and slope).
/// At least one more row is needed to leave a residual to estimate. With fewer
/// rows the linear deviance is undefined, so the atom is skipped (left out of
/// the report) rather than adjudicated on a made-up deviance.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 3;
263
/// #1610/#1026 EV-preservation gate tolerance.
///
/// A `d = 1` slot may collapse to its linear tail only if doing so costs at
/// most this fraction of the target's total (centered) variance in
/// full-reconstruction explained variance.
///
/// Why a gate is needed: the evidence argmin trades data fit against the curved
/// arm's parameter price in `NLE` units. On small or low-amplitude fixtures that
/// trade can favour the cheaper line even when the curve carries real
/// reconstruction signal. Collapsing then loses EV (the observed 1.0 → 0.748
/// over-collapse). The gate guards the regressed quantity directly.
///
/// What the gate measures: collapsing atom `k` raises the full reconstruction
/// SSR by exactly `linear_rss − curved_rss`, and `SST_full` is fixed. The
/// per-atom EV cost is therefore `(linear_rss − curved_rss) / SST_full`, and a
/// collapse is vetoed when that exceeds this tolerance.
///
/// Scale invariance: EV is a dimensionless fraction, so an absolute EV-loss
/// tolerance is scale-invariant. It does not reintroduce the
/// scale-incommensurability that sank the evidence-reformulation attempt.
///
/// Choice of value: a genuinely straight atom (the curved fit already is its
/// own line) loses `≈ 0` EV and still collapses losslessly. A curve doing real
/// reconstruction work (loss `≫ 1e-3`) is kept. `1e-3` (0.1% of total
/// variance) sits well above f64 round-off on an exact-line collapse and well
/// below any material EV loss, so it separates the two regimes without tuning.
const SAE_HYBRID_COLLAPSE_MAX_EV_LOSS: f64 = 1e-3;
285
286/// The full-reconstruction SSR INCREASE from collapsing one `d = 1` atom to its
287/// fitted straight sub-model: `linear_rss − curved_rss`, where both arms are
288/// scored against the atom's leave-this-atom-out response residual `y_resp`
289/// (`target_resid`) exactly as [`build_atom_candidates`] scores them. Because the
290/// full reconstruction differs from the collapsed one ONLY in this atom's
291/// contribution (`a_k·γ_k` → `a_k·line`) on its assigned rows, this scalar is the
292/// exact amount the full reconstruction's SSR rises when the slot is collapsed;
293/// dividing by the fixed total target variance gives the full-EV loss the
294/// EV-preservation gate keys on. Positive ⇒ the curve out-fits its straight
295/// projection (collapsing hurts); `≤ 0` ⇒ the line is at least as good (collapse
296/// is lossless or improving, never gated). Mirrors the `curved_rss` / `linear_rss`
297/// accumulation in [`build_atom_candidates`] bit-for-bit so the gate and the
298/// evidence comparison see the same residuals.
299fn collapse_ssr_increase(
300 coords: ArrayView1<'_, f64>,
301 assign: ArrayView1<'_, f64>,
302 decoded: ArrayView2<'_, f64>,
303 target_resid: ArrayView2<'_, f64>,
304 t_bar: f64,
305 b0: &Array1<f64>,
306 b1: &Array1<f64>,
307) -> f64 {
308 let n = assign.len();
309 let p = target_resid.ncols();
310 let mut curved_rss = 0.0_f64;
311 let mut linear_rss = 0.0_f64;
312 for i in 0..n {
313 let a = assign[i];
314 let dt = coords[i] - t_bar;
315 for j in 0..p {
316 let y = target_resid[[i, j]];
317 let r_curved = y - a * decoded[[i, j]];
318 curved_rss += r_curved * r_curved;
319 let r_linear = y - a * (b0[j] + dt * b1[j]);
320 linear_rss += r_linear * r_linear;
321 }
322 }
323 linear_rss - curved_rss
324}
325
326/// Build the curved + linear candidates for ONE fitted `d = 1` atom and return
327/// them as `(linear, curved, (t̄, b₀, b₁))`, or `None` if the atom cannot present
328/// an honest pair (too few rows, degenerate coordinate span, or non-finite
329/// numbers). Both candidates are scored against the SAME data — the atom's
330/// leave-this-atom-out response residual `y_resp` — so the comparison is a
331/// genuine common-evidence one with linear nested as the curved family's `Θ = 0`
332/// sub-model (#1202).
333///
334/// Inputs over the atom's assigned rows:
335/// * `coords` — the fitted on-atom coordinate `t`.
336/// * `assign` — the per-row assignment mass `a_k` (NOT squared; this routine
337/// squares it where the design weight `a_k²` is needed).
338/// * `decoded` — the atom's fitted decoded image `γ_k(t) = Φ(t) B_k` (`p` cols),
339/// whose mass-scaled value `a_k·γ_k` is the curved candidate's PREDICTION.
340/// * `target_resid` — the atom's leave-this-atom-out response residual `y_resp`
341/// (`p` cols): the response with every OTHER atom's contribution removed.
342/// This is the data both candidates fit.
343///
344/// The curved candidate's data-fit deviance is `½ Σ ‖y_resp − a_k·γ_k‖²` (the
345/// plain reconstruction SSE, matching the joint loss — the mass already lives in
346/// the prediction `a_k·γ_k`); the linear candidate fits the best mass-weighted
347/// straight line to `y_resp` and pays `½ Σ ‖y_resp − a_k·(b₀ + (t − t̄)·b₁)‖²`.
348/// Because the curved family's `Θ = 0` member reproduces the linear prediction
349/// exactly, linear is the nested sub-model and the argmin is the honest
350/// match-or-beat criterion.
351fn build_atom_candidates(
352 coords: ArrayView1<'_, f64>,
353 assign: ArrayView1<'_, f64>,
354 decoded: ArrayView2<'_, f64>,
355 target_resid: ArrayView2<'_, f64>,
356 curved_num_params: usize,
357 curved_phi: Option<ArrayView2<'_, f64>>,
358 fitted_turning: Option<f64>,
359) -> Option<(
360 HybridAtomCandidate,
361 HybridAtomCandidate,
362 (f64, Array1<f64>, Array1<f64>),
363)> {
364 let n = coords.len();
365 let p = decoded.ncols();
366 if n < MIN_ROWS_FOR_LINEAR_FIT
367 || decoded.nrows() != n
368 || assign.len() != n
369 || target_resid.nrows() != n
370 || target_resid.ncols() != p
371 || p == 0
372 {
373 return None;
374 }
375
376 // The LINEAR candidate fits `a_k·(b₀ + (t − t̄)·b₁)` to the residual `y_resp`,
377 // so the natural design column is `a_k·[1, (t − t̄)]` and the per-row Gram
378 // weight is `wᵢ = a_k²`. We accumulate the mass-weighted coordinate mean `t̄`
379 // and spread `s_tt` under that weight; a row that barely belongs to the atom
380 // (`a_k ≈ 0`) contributes ≈ nothing, exactly as in the joint loss.
381 let mut w_sum = 0.0_f64;
382 let mut t_bar = 0.0_f64;
383 for i in 0..n {
384 let a = assign[i];
385 if !(a.is_finite() && a >= 0.0) {
386 return None;
387 }
388 let w = a * a;
389 w_sum += w;
390 t_bar += w * coords[i];
391 }
392 if !(w_sum > 0.0) {
393 return None;
394 }
395 t_bar /= w_sum;
396
397 // Weighted Σ wᵢ·(t − t̄)² with `wᵢ = a_k²` — the coordinate spread under the
398 // line's design weight. A degenerate (single-point mass) coordinate has no
399 // slope direction; refuse rather than divide by ~0.
400 let mut s_tt = 0.0_f64;
401 for i in 0..n {
402 let dt = coords[i] - t_bar;
403 s_tt += assign[i] * assign[i] * dt * dt;
404 }
405 if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
406 return None;
407 }
408
409 // Per-output-channel mass-weighted least squares for the line fit to the
410 // RESIDUAL `y_resp`. Minimizing `Σᵢ ‖y_resp[i] − a_k·(b₀ + (t − t̄)·b₁)‖²` in
411 // the centered basis has the diagonal normal equations
412 // b₀[j] = (Σ a_k·y_resp[i,j]) / w_sum, (recall the design intercept is a_k)
413 // b₁[j] = (Σ a_k·(t − t̄)·y_resp[i,j]) / s_tt.
414 let mut b0 = Array1::<f64>::zeros(p);
415 let mut b1 = Array1::<f64>::zeros(p);
416 for j in 0..p {
417 let mut s_1y = 0.0_f64;
418 let mut s_ty = 0.0_f64;
419 for i in 0..n {
420 let a = assign[i];
421 let dt = coords[i] - t_bar;
422 let y = target_resid[[i, j]];
423 s_1y += a * y;
424 s_ty += a * dt * y;
425 }
426 b0[j] = s_1y / w_sum;
427 b1[j] = s_ty / s_tt;
428 }
429
430 // Data-fit residual sums of squares of BOTH candidates against `y_resp`, the
431 // common data. The curved candidate predicts the atom's actual mass-scaled
432 // contribution `a_k·γ_k`; the linear candidate predicts the best line
433 // `a_k·(b₀ + (t − t̄)·b₁)`. These are no longer trivially zero for the curved
434 // arm — both are real misfits to the response residual, so the argmin is a
435 // genuine common-evidence comparison (#1202).
436 let mut curved_rss = 0.0_f64;
437 let mut linear_rss = 0.0_f64;
438 for i in 0..n {
439 let a = assign[i];
440 let dt = coords[i] - t_bar;
441 for j in 0..p {
442 let y = target_resid[[i, j]];
443 let r_curved = y - a * decoded[[i, j]];
444 curved_rss += r_curved * r_curved;
445 let r_linear = y - a * (b0[j] + dt * b1[j]);
446 linear_rss += r_linear * r_linear;
447 }
448 }
449
450 // Gaussian-reconstruction deviance: the residual objective `½ RSS` the
451 // Laplace normalizer is added to. The curved arm pays `½·curved_rss` (how
452 // well its fitted curve explains the residual) plus its larger `M·p`
453 // parameter price; the linear arm pays `½·linear_rss` plus a `2·p` price.
454 // Because the curved family's `Θ = 0` member equals the linear prediction,
455 // `curved_rss ≤ linear_rss` whenever the fitted curve is at least as good a
456 // residual fit as its own straight projection — the match-or-beat floor — and
457 // the argmin trades that data-fit gain against the curvature parameter price.
458 let curved_residual_objective = 0.5 * curved_rss;
459 let linear_residual_objective = 0.5 * linear_rss;
460
461 // Linear candidate parameter price: intercept + slope per output channel.
462 let linear_num_params = 2 * p;
463
464 // Laplace logdet of the (weighted) design Gram for the LINEAR candidate.
465 //
466 // For the centered weighted line fit `a_k·(b₀ + (t − t̄)·b₁)`, the per-output-
467 // channel design column is `a_k·[1, (t − t̄)]`, whose Gram is DIAGONAL in the
468 // centered basis: `diag(Σ a_k², Σ a_k²(t − t̄)²) = diag(w_sum, s_tt)`. Its log
469 // determinant is `log(w_sum) + log(s_tt)` PER output channel, i.e.
470 //
471 // log|H_linear| = p · ( log(w_sum) + log(s_tt) ).
472 //
473 // The `log(s_tt)` term is the slope direction's information: a line through a
474 // wide, heavily-massed coordinate spread is better-determined than one through
475 // a tiny spread, and the Laplace evidence must reflect that (#1203).
476 //
477 // The curved arm's Laplace determinant is now the genuine weighted-design
478 // Gram log-determinant `p · log|ΦᵀWΦ|_+` (#1223): the SAME quantity the
479 // linear arm reports (`p·(log w_sum + log s_tt) = p·log|XᵀWX|`), assembled
480 // from the curved basis `Φ` on the atom's assigned rows under the same
481 // assignment-mass design weight `wᵢ = a_k²`. Both arms omit the smoothing
482 // `λS` normalizer, so the complexity price is computed on a symmetric
483 // footing — no parameter-count proxy. Only when `Φ` is unavailable (the
484 // caller could not evaluate the basis) or its Gram is fully rank-deficient do
485 // we fall back to the historical `curved_num_params · log(w_sum)` proxy, so
486 // the comparison degrades gracefully rather than fabricating a determinant.
487 if !(w_sum > 0.0 && w_sum.is_finite() && s_tt.is_finite()) {
488 return None;
489 }
490 let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
491 let curved_log_det_h = curved_phi
492 .and_then(|phi| {
493 if phi.nrows() == n {
494 curved_design_gram_logdet(phi, assign, p)
495 } else {
496 None
497 }
498 })
499 .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());
500
501 // Reduced Laplace NLE `residual_objective + ½ log|H|`. Both omit an explicit
502 // smoothing-penalty logdet (the intrinsic smoothness penalty is
503 // reparameterization-invariant and identical in expectation across the two
504 // parameterizations of the same image).
505 let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
506 let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
507 if !(linear_nle.is_finite() && curved_nle.is_finite()) {
508 return None;
509 }
510
511 let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
512 let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
513 Some((linear, curved, (t_bar, b0, b1)))
514}
515
516/// #1026 collapse rescue. When a `d = 1` atom's own coordinate has collapsed to a
517/// single point (`build_atom_candidates` refuses because `s_tt ≈ 0`), the atom is
518/// stuck in the degenerate "chord-through-the-arc" fixed point and its curved
519/// decode is a constant — the rank-1 dictionary co-collapse (real-OLMo held-out EV
520/// ≈ 0.13 vs the rank-K linear ceiling ≈ 0.74). The hybrid-split was DESIGNED to
521/// let such a linear-tail atom decode as a straight line; the only reason it can't
522/// here is that its own codes carry no spread to fit a slope against.
523///
524/// Recover FRESH per-row codes from the data instead: `uᵢ = yᵢ·v`, the projection
525/// of the leave-this-atom-out residual onto its top mass-weighted output direction
526/// `v` (the rank-1 of `Σᵢ wᵢ yᵢyᵢᵀ`, `wᵢ = a_k²` — the SAME design weight the line
527/// fit uses). These codes span the residual's strongest linear axis by
528/// construction, so the straight image `b₀ + (uᵢ − ū)·b₁` fit against them
529/// reconstructs that axis at LINEAR quality — exactly the linear-tail reach the
530/// split owes. Returns the forced-LINEAR candidate plus the image carrying `uᵢ`,
531/// or `None` when the residual itself carries no usable direction (a genuine zero
532/// atom the mass/decoder guards own).
533fn build_collapse_rescue_linear_image(
534 atom_idx: usize,
535 assign: ArrayView1<'_, f64>,
536 target_resid: ArrayView2<'_, f64>,
537) -> Option<(HybridAtomCandidate, AtomLinearImage)> {
538 let n = assign.len();
539 let p = target_resid.ncols();
540 if n < MIN_ROWS_FOR_LINEAR_FIT || target_resid.nrows() != n || p == 0 {
541 return None;
542 }
543 let mut w_sum = 0.0_f64;
544 for i in 0..n {
545 let a = assign[i];
546 if !(a.is_finite() && a >= 0.0) {
547 return None;
548 }
549 w_sum += a * a;
550 }
551 if !(w_sum > 0.0) {
552 return None;
553 }
554 // Top mass-weighted output direction `v` of the residual via power iteration on
555 // `M = Σᵢ wᵢ yᵢyᵢᵀ` (p×p, never materialized): `v ← normalize(Σᵢ wᵢ yᵢ (yᵢ·v))`.
556 // Seed from the per-channel weighted energy so a rank-1 residual converges in
557 // one step and the seed is deterministic (no RNG).
558 let mut v = Array1::<f64>::zeros(p);
559 for j in 0..p {
560 let mut e = 0.0_f64;
561 for i in 0..n {
562 let a = assign[i];
563 let y = target_resid[[i, j]];
564 e += a * a * y * y;
565 }
566 v[j] = e;
567 }
568 let mut vnorm = v.dot(&v).sqrt();
569 if !(vnorm > 0.0) {
570 return None;
571 }
572 v.mapv_inplace(|x| x / vnorm);
573 for _ in 0..32 {
574 let mut mv = Array1::<f64>::zeros(p);
575 for i in 0..n {
576 let a = assign[i];
577 let w = a * a;
578 let mut proj = 0.0_f64;
579 for j in 0..p {
580 proj += target_resid[[i, j]] * v[j];
581 }
582 let wp = w * proj;
583 for j in 0..p {
584 mv[j] += wp * target_resid[[i, j]];
585 }
586 }
587 vnorm = mv.dot(&mv).sqrt();
588 if !(vnorm > 0.0) {
589 return None;
590 }
591 mv.mapv_inplace(|x| x / vnorm);
592 let cos = mv.dot(&v).abs();
593 v = mv;
594 if cos > 1.0 - 1e-12 {
595 break;
596 }
597 }
598 // Fresh per-row codes `uᵢ = yᵢ·v` and the weighted line fit against them.
599 let mut u = Array1::<f64>::zeros(n);
600 let mut t_bar = 0.0_f64;
601 for i in 0..n {
602 let mut proj = 0.0_f64;
603 for j in 0..p {
604 proj += target_resid[[i, j]] * v[j];
605 }
606 u[i] = proj;
607 t_bar += assign[i] * assign[i] * proj;
608 }
609 t_bar /= w_sum;
610 let mut s_tt = 0.0_f64;
611 for i in 0..n {
612 let dt = u[i] - t_bar;
613 s_tt += assign[i] * assign[i] * dt * dt;
614 }
615 if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
616 return None;
617 }
618 let mut b0 = Array1::<f64>::zeros(p);
619 let mut b1 = Array1::<f64>::zeros(p);
620 let mut linear_rss = 0.0_f64;
621 for j in 0..p {
622 let mut s_1y = 0.0_f64;
623 let mut s_ty = 0.0_f64;
624 for i in 0..n {
625 let a = assign[i];
626 let dt = u[i] - t_bar;
627 let y = target_resid[[i, j]];
628 s_1y += a * y;
629 s_ty += a * dt * y;
630 }
631 b0[j] = s_1y / w_sum;
632 b1[j] = s_ty / s_tt;
633 }
634 for i in 0..n {
635 let a = assign[i];
636 let dt = u[i] - t_bar;
637 for j in 0..p {
638 let r = target_resid[[i, j]] - a * (b0[j] + dt * b1[j]);
639 linear_rss += r * r;
640 }
641 }
642 let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
643 let linear_nle = reduced_laplace_nle(0.5 * linear_rss, linear_log_det_h);
644 if !linear_nle.is_finite() {
645 return None;
646 }
647 let linear = HybridAtomCandidate::linear(linear_nle, 2 * p);
648 let image = AtomLinearImage {
649 atom_idx,
650 t_bar,
651 b0,
652 b1,
653 row_codes: Some(u),
654 };
655 Some((linear, image))
656}
657
658/// Assemble the per-atom candidate slots for [`select_hybrid_split`] from the
659/// fitted `d = 1` atoms, run the adjudication, and return the report.
660///
661/// `atoms` are the fitted dictionary atoms; `coords_for` yields the on-atom
662/// coordinate column for a slot, `assign_for` the per-row assignment mass `a_k`,
663/// `decoded_for` the fitted decoded image rows `γ_k`, and `target_resid_for` the
664/// atom's leave-this-atom-out response residual `y_resp` (the data both
665/// candidates are scored against, #1202). `manifold_for` yields the atom's chart
666/// manifold (a flat / Euclidean chart can present only the linear candidate,
667/// enforced inside the selector).
668///
669/// Returns `None` (no report) when no atom is eligible — there is nothing to
670/// adjudicate.
671pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
672 atoms: &'a [SaeManifoldAtom],
673 eligible_d1: impl Iterator<Item = usize>,
674 mut coords_for: C,
675 mut assign_for: W,
676 mut decoded_for: D,
677 mut target_resid_for: R,
678 mut manifold_for: M,
679 mut delta_ev_for: E,
680 // #1026 — the full target's total (column-centered) variance `SST_full`, the
681 // fixed denominator of the EV-preservation gate. `≤ 0` / non-finite disables
682 // the gate (a degenerate, varianceless target has no EV to preserve).
683 total_centered_variance: f64,
684) -> Result<Option<SaeHybridSplitReport>, String>
685where
686 C: FnMut(usize) -> Array1<f64>,
687 W: FnMut(usize) -> Array1<f64>,
688 D: FnMut(usize) -> Array2<f64>,
689 R: FnMut(usize) -> Array2<f64>,
690 M: FnMut(usize) -> LatentManifold,
691 // The atom's held-out LOAO `ΔEV_k`, keyed by atom index. `None` when LOAO EV
692 // is unavailable (e.g. the caller has no target to measure against).
693 E: FnMut(usize) -> Option<f64>,
694{
695 let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
696 let mut names: Vec<String> = Vec::new();
697 let mut manifolds: Vec<LatentManifold> = Vec::new();
698 // Per-slot fitted straight sub-model `(atom_idx, t̄, b₀, b₁)`, surfaced onto
699 // the verdict iff the slot selects LINEAR so the collapsed reconstruction can
700 // substitute it for the curved decoded image.
701 let mut linear_images: Vec<AtomLinearImage> = Vec::new();
702 // Per-slot `(Θ, ΔEV)` — the #1026 frontier point — carried onto each verdict
703 // so the geometry/EV pairing is structured report data, not a log line.
704 let mut turnings: Vec<Option<f64>> = Vec::new();
705 let mut delta_evs: Vec<Option<f64>> = Vec::new();
706
707 for atom_idx in eligible_d1 {
708 let atom = &atoms[atom_idx];
709 let coords = coords_for(atom_idx);
710 let assign = assign_for(atom_idx);
711 let decoded = decoded_for(atom_idx);
712 let target_resid = target_resid_for(atom_idx);
713 // Curved parameter price = the decoder's `M · p` coefficients.
714 let curved_num_params = atom.decoder_coefficients.len();
715 let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
716 d1_atom_fitted_turning(
717 evaluator.as_ref(),
718 atom.decoder_coefficients.view(),
719 coords.view(),
720 )
721 .ok()
722 .flatten()
723 });
724 // Evaluate the curved design `Φ(t)` on this atom's assigned rows so the
725 // curved arm's Laplace complexity is the real weighted-design Gram
726 // log-determinant rather than a parameter-count proxy (#1223). A `d = 1`
727 // atom's coordinate column is presented as an `n × 1` design input. If
728 // the evaluator is absent or refuses, `curved_phi` stays `None` and
729 // `build_atom_candidates` falls back to the proxy.
730 let coords_col = coords
731 .view()
732 .into_shape_with_order((coords.len(), 1))
733 .ok()
734 .map(|v| v.to_owned());
735 let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
736 (Some(evaluator), Some(col)) => {
737 evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
738 }
739 _ => None,
740 };
741 // A flat (Euclidean) chart cannot honestly present a curved candidate;
742 // the selector drops it. Present both for curveable charts.
743 let manifold = manifold_for(atom_idx);
744 match build_atom_candidates(
745 coords.view(),
746 assign.view(),
747 decoded.view(),
748 target_resid.view(),
749 curved_num_params,
750 curved_phi.as_ref().map(|phi| phi.view()),
751 fitted_turning,
752 ) {
753 Some((linear, curved, (t_bar, b0, b1))) => {
754 // #1026 EV-PRESERVATION gate. Collapsing this slot raises the full
755 // reconstruction SSR by `linear_rss − curved_rss`; if that is more
756 // than `SAE_HYBRID_COLLAPSE_MAX_EV_LOSS` of the fixed total target
757 // variance the collapse would DROP reconstruction EV materially, so
758 // veto it by presenting only the curved candidate (the selector must
759 // keep curved). A lossless / improving collapse (`≤ 0`) and a
760 // negligible one stay free to collapse — EV-neutral cases (the
761 // top-k / birth-topology lines) are untouched. Only curveable charts
762 // are gated; a euclidean chart never had a curved option.
763 let collapse_loses_ev = total_centered_variance.is_finite()
764 && total_centered_variance > 0.0
765 && collapse_ssr_increase(
766 coords.view(),
767 assign.view(),
768 decoded.view(),
769 target_resid.view(),
770 t_bar,
771 &b0,
772 &b1,
773 ) > SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
774 let slot = if manifold.is_euclidean() {
775 vec![linear]
776 } else if collapse_loses_ev {
777 vec![curved]
778 } else {
779 vec![linear, curved]
780 };
781 slots.push(slot);
782 names.push(atom.name.clone());
783 manifolds.push(manifold);
784 turnings.push(fitted_turning);
785 delta_evs.push(delta_ev_for(atom_idx));
786 linear_images.push(AtomLinearImage {
787 atom_idx,
788 t_bar,
789 b0,
790 b1,
791 row_codes: None,
792 });
793 }
794 // #1026 collapse rescue: `build_atom_candidates` refused because the
795 // atom's own coordinate collapsed (`s_tt ≈ 0`) — the rank-1 co-collapse
796 // fixed point. Recover a FRESH linear image from the residual's top
797 // direction (fresh per-row codes) and force the LINEAR verdict (a
798 // single-option slot the selector must take) so the slot reconstructs
799 // its residual's best linear axis at linear quality instead of the
800 // collapsed-curve constant. `None` only when the residual itself is
801 // degenerate — then there is genuinely nothing to recover and we skip.
802 None => match build_collapse_rescue_linear_image(
803 atom_idx,
804 assign.view(),
805 target_resid.view(),
806 ) {
807 Some((linear, image)) => {
808 slots.push(vec![linear]);
809 names.push(atom.name.clone());
810 manifolds.push(manifold);
811 turnings.push(fitted_turning);
812 delta_evs.push(delta_ev_for(atom_idx));
813 linear_images.push(image);
814 }
815 None => continue,
816 },
817 }
818 }
819
820 if slots.is_empty() {
821 return Ok(None);
822 }
823
824 let selection = select_hybrid_split(&slots)?;
825 let verdicts: Vec<AtomHybridVerdict> = names
826 .into_iter()
827 .zip(selection.atoms.iter().copied())
828 .zip(linear_images.into_iter())
829 .zip(turnings.into_iter())
830 .zip(delta_evs.into_iter())
831 .map(
832 |((((atom_name, choice), linear_image), fitted_turning), train_loao_delta_ev)| {
833 let kept_curved = !choice.param.is_linear();
834 AtomHybridVerdict {
835 atom_name,
836 choice,
837 kept_curved,
838 fitted_turning,
839 train_loao_delta_ev,
840 // Carry the straight sub-model only when the verdict collapses
841 // this slot to linear — the curved slots keep their fitted image.
842 linear_image: if kept_curved {
843 None
844 } else {
845 Some(linear_image)
846 },
847 }
848 },
849 )
850 .collect();
851
852 Ok(Some(SaeHybridSplitReport {
853 verdicts,
854 selection,
855 }))
856}
857
858#[cfg(test)]
859mod tests {
860 use super::*;
861 use std::f64::consts::PI;
862
863 /// A straight RESPONSE residual (the atom's data is a line) is explained
864 /// equally well by both candidates, so the cheaper linear special case wins.
865 /// With `a_k = 1` the curved decoded image is straight too (Θ = 0), so both
866 /// the dominance floor and the evidence argmin select linear. This is the
867 /// common-data nested comparison (#1202): linear is the curved family's
868 /// `Θ = 0` member, so it cannot lose when a line already explains the data.
869 #[test]
870 fn straight_residual_selects_linear() {
871 let n = 40;
872 let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
873 let assign = Array1::<f64>::ones(n);
874 // The data the atom must explain is a straight line in ℝ²; the curved
875 // decoded image equals that same line (a Θ = 0 curved fit).
876 let mut data = Array2::<f64>::zeros((n, 2));
877 let mut decoded = Array2::<f64>::zeros((n, 2));
878 for i in 0..n {
879 data[[i, 0]] = coords[i];
880 data[[i, 1]] = 0.6 * coords[i];
881 decoded[[i, 0]] = coords[i];
882 decoded[[i, 1]] = 0.6 * coords[i];
883 }
884 let (linear, curved, _) = build_atom_candidates(
885 coords.view(),
886 assign.view(),
887 decoded.view(),
888 data.view(),
889 // a generous curved parameter price (M·p)
890 10,
891 None,
892 Some(0.0),
893 )
894 .expect("straight residual yields a candidate pair");
895 let choice =
896 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
897 assert!(
898 choice.param.is_linear(),
899 "a straight response residual must keep the linear special case"
900 );
901 }
902
903 /// A turning RESPONSE residual (the atom's data traces a full circle) is fit
904 /// well by the curved decoded image (curved_rss ≈ 0) but poorly by any
905 /// straight line (large linear_rss), so the curved candidate wins the common
906 /// evidence comparison once its data-fit gain exceeds its extra parameter
907 /// price (#1202).
908 #[test]
909 fn turning_residual_selects_curved_on_evidence() {
910 let n = 60;
911 let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
912 let assign = Array1::<f64>::ones(n);
913 // The data is a full circle; the curved decoded image is that same
914 // circle (the curved atom reconstructs its assigned residual), so the
915 // curved candidate has ≈ zero data-fit residual while a straight line
916 // cannot follow the loop.
917 let mut data = Array2::<f64>::zeros((n, 2));
918 let mut decoded = Array2::<f64>::zeros((n, 2));
919 for i in 0..n {
920 let theta = 2.0 * PI * coords[i];
921 data[[i, 0]] = theta.cos();
922 data[[i, 1]] = theta.sin();
923 decoded[[i, 0]] = theta.cos();
924 decoded[[i, 1]] = theta.sin();
925 }
926 // The curved atom has 5 parameters (just above the 4 = 2·p linear budget);
927 // the full-circle linear residual exceeds the extra-parameter overhead, so
928 // curved wins on evidence.
929 let (linear, curved, _) = build_atom_candidates(
930 coords.view(),
931 assign.view(),
932 decoded.view(),
933 data.view(),
934 5,
935 None,
936 Some(2.0 * PI),
937 )
938 .expect("turning residual yields a candidate pair");
939 assert!(
940 linear.negative_log_evidence > curved.negative_log_evidence,
941 "the line must misfit the circular residual worse than the curve does \
942 (linear NLE {} should exceed curved NLE {})",
943 linear.negative_log_evidence,
944 curved.negative_log_evidence
945 );
946 let choice =
947 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
948 assert_eq!(
949 choice.param,
950 gam_solve::evidence::HybridAtomParam::Curved { latent_dim: 1 },
951 "a full-circle response residual must keep the curved parameterization"
952 );
953 assert!(
954 choice.curved_evidence_margin > 0.0,
955 "curved must win a positive evidence margin over the linear secant"
956 );
957 }
958
959 /// The nested-dominance floor on common data (#1202): when the curved decoded
960 /// image is a WORSE fit to the response residual than its own best straight
961 /// projection, linear must win — the curved family cannot be charged extra
962 /// parameters to fit the residual no better than its `Θ = 0` member. Here the
963 /// data is a line but the curved image bends away from it, so curved_rss >
964 /// linear_rss and the cheaper, better-fitting line is selected.
965 #[test]
966 fn linear_beats_curved_when_curve_misfits_residual() {
967 let n = 50;
968 let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
969 let assign = Array1::<f64>::ones(n);
970 // Data is a straight line; the curved decoded image is a parabola that
971 // departs from it, so a straight line fits the data strictly better.
972 let mut data = Array2::<f64>::zeros((n, 2));
973 let mut decoded = Array2::<f64>::zeros((n, 2));
974 for i in 0..n {
975 let t = coords[i];
976 data[[i, 0]] = t;
977 data[[i, 1]] = 0.5 * t;
978 decoded[[i, 0]] = t;
979 decoded[[i, 1]] = t * t; // bends away from the linear data
980 }
981 let (linear, curved, _) = build_atom_candidates(
982 coords.view(),
983 assign.view(),
984 decoded.view(),
985 data.view(),
986 // a real curved Θ above the floor so the dominance floor does not fire
987 6,
988 None,
989 Some(1.0),
990 )
991 .expect("candidate pair");
992 let choice =
993 gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
994 assert!(
995 choice.param.is_linear(),
996 "a curved image that fits the data worse than its own line must yield \
997 to the linear special case on common-data evidence (#1202)"
998 );
999 }
1000
1001 /// The LINEAR candidate's Laplace logdet is the genuine weighted-design Gram
1002 /// determinant `p·(log w_sum + log s_tt)` with `w_sum = Σ a_k²`, `s_tt =
1003 /// Σ a_k²(t − t̄)²` — it INCLUDES the coordinate-spread term `log(s_tt)`
1004 /// (#1203). Verify both contributions are present by reading the logdet off a
1005 /// candidate whose linear residual is exactly zero (response residual = the
1006 /// fitted line), so `NLE_linear = ½·logdet`. Doubling the coordinate spread
1007 /// (at fixed assignment mass) scales `s_tt` by 4 → logdet += `p·log(4)`;
1008 /// doubling all assignment masses scales BOTH `w_sum` and `s_tt` by 4 (they
1009 /// are quadratic in `a_k`) → logdet += `2p·log(4)`.
1010 #[test]
1011 fn linear_logdet_includes_weighted_coordinate_spread() {
1012 let n = 40;
1013 let p = 2usize;
1014 // Read the logdet back off a candidate with zero linear residual: the
1015 // response residual is exactly `a_k·(line)`, so the WLS line recovers it
1016 // with RSS == 0 and `NLE_linear = ½·logdet`.
1017 let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
1018 // A straight image; the response residual is the same line scaled by
1019 // the per-row assignment mass `a_k`, so the prediction `a_k·(b₀+dt·b₁)`
1020 // matches it exactly and linear_rss == 0.
1021 let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };
1022 let mut decoded = Array2::<f64>::zeros((n, p));
1023 let mut data = Array2::<f64>::zeros((n, p));
1024 for i in 0..n {
1025 let l = line(coords[i]);
1026 decoded[[i, 0]] = l[0];
1027 decoded[[i, 1]] = l[1];
1028 data[[i, 0]] = assign[i] * l[0];
1029 data[[i, 1]] = assign[i] * l[1];
1030 }
1031 let (linear, _curved, _) = build_atom_candidates(
1032 coords.view(),
1033 assign.view(),
1034 decoded.view(),
1035 data.view(),
1036 10,
1037 None,
1038 Some(0.0),
1039 )
1040 .expect("straight residual yields a pair");
1041 2.0 * linear.negative_log_evidence // = logdet (linear_rss == 0)
1042 };
1043
1044 let base_coords =
1045 Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1046 let ones = Array1::<f64>::ones(n);
1047
1048 // Doubling the coordinate spread → s_tt ×4, w_sum fixed → logdet += p·log(4).
1049 let wide_coords = base_coords.mapv(|t| 2.0 * t);
1050 let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
1051 assert!(
1052 (d_spread - (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
1053 "linear logdet must move by p·log(4) when coordinate spread doubles \
1054 (got {d_spread}); the spread term log(s_tt) must be present"
1055 );
1056
1057 // Doubling all assignment masses → w_sum ×4 AND s_tt ×4 (quadratic in a_k)
1058 // → logdet += 2p·log(4).
1059 let twos = Array1::<f64>::from_elem(n, 2.0);
1060 let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
1061 assert!(
1062 (d_weight - 2.0 * (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
1063 "linear logdet must move by 2p·log(4) when all assignment masses double \
1064 (got {d_weight})"
1065 );
1066 }
1067
1068 /// #1223 — the curved arm's Laplace complexity is the REAL weighted-design
1069 /// Gram log-determinant `p·log|ΦᵀWΦ|_+`, not a parameter-count proxy. Build a
1070 /// curved design whose columns are the constant and the centered coordinate
1071 /// (a 2-column basis), so `ΦᵀWΦ = diag(w_sum, s_tt)` exactly matches the
1072 /// linear arm's data Gram, and assert `curved_design_gram_logdet` returns
1073 /// `p·(log w_sum + log s_tt)` — the same determinant the linear arm reports
1074 /// on the same design weight. A proxy `M·log(w_sum)` would instead omit the
1075 /// `log(s_tt)` spread term, so this pins the genuine determinant.
1076 #[test]
1077 fn curved_gram_logdet_is_real_weighted_design_determinant() {
1078 let n = 40;
1079 let p = 3usize;
1080 let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1081 let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));
1082
1083 // Mass-weighted coordinate mean and spread under wᵢ = a_k².
1084 let mut w_sum = 0.0;
1085 let mut t_bar = 0.0;
1086 for i in 0..n {
1087 let w = assign[i] * assign[i];
1088 w_sum += w;
1089 t_bar += w * coords[i];
1090 }
1091 t_bar /= w_sum;
1092 let mut s_tt = 0.0;
1093 for i in 0..n {
1094 let dt = coords[i] - t_bar;
1095 s_tt += assign[i] * assign[i] * dt * dt;
1096 }
1097
1098 // Curved design columns: [1, (t − t̄)]. Its weighted Gram is exactly
1099 // diag(w_sum, s_tt) (the cross term Σ w·(t−t̄) vanishes by construction),
1100 // so log|ΦᵀWΦ| = log(w_sum) + log(s_tt).
1101 let mut phi = Array2::<f64>::zeros((n, 2));
1102 for i in 0..n {
1103 phi[[i, 0]] = 1.0;
1104 phi[[i, 1]] = coords[i] - t_bar;
1105 }
1106 let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
1107 .expect("non-degenerate curved design has a determinant");
1108 let want = (p as f64) * (w_sum.ln() + s_tt.ln());
1109 assert!(
1110 (got - want).abs() < 1e-9,
1111 "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
1112 );
1113
1114 // A rank-deficient design (a duplicated column) drops the null direction:
1115 // its determinant equals that of the single retained constant column,
1116 // p·log(w_sum), NOT a 2-column proxy.
1117 let mut phi_dup = Array2::<f64>::zeros((n, 2));
1118 for i in 0..n {
1119 phi_dup[[i, 0]] = 1.0;
1120 phi_dup[[i, 1]] = 1.0;
1121 }
1122 let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
1123 .expect("rank-1 design still has a positive determinant");
1124 let want_dup = (p as f64) * (2.0 * w_sum).ln();
1125 assert!(
1126 (got_dup - want_dup).abs() < 1e-9,
1127 "rank-deficient curved Gram must report only its positive direction \
1128 (p·log(2·w_sum) = {want_dup}), got {got_dup}"
1129 );
1130 }
1131
1132 /// A degenerate (single-point-mass) coordinate has no slope direction and is
1133 /// refused rather than adjudicated on a fabricated deviance.
1134 #[test]
1135 fn degenerate_coordinate_is_refused() {
1136 let n = 5;
1137 let coords = Array1::<f64>::from_elem(n, 0.5); // no spread
1138 let assign = Array1::<f64>::ones(n);
1139 let decoded = Array2::<f64>::zeros((n, 2));
1140 let data = Array2::<f64>::zeros((n, 2));
1141 assert!(
1142 build_atom_candidates(
1143 coords.view(),
1144 assign.view(),
1145 decoded.view(),
1146 data.view(),
1147 6,
1148 None,
1149 Some(0.0)
1150 )
1151 .is_none(),
1152 "a degenerate coordinate span must be refused"
1153 );
1154 }
1155}