Skip to main content

gam_solve/
evidence.rs

1//! Canonical Laplace evidence, IFT cascade, and topology selection.
2//!
3//! This module is the single canonical entry point for:
4//!
5//!   1. Laplace evidence `V(ρ, T) = F + (1/2) log|H| - (1/2) log|S(ρ)|+
6//!      - ((dim(H)-rank(S))/2) log(2π)`
7//!      evaluated at the arrow-Schur inner-loop fixed point.
8//!   2. The full IFT cascade `∂u*/∂β → ∂β*/∂ρ → ∂u*/∂ρ` through the three
9//!      continuous tiers `(u, β, ρ)`, per §2.2 / §2.4 / §2.6.
10//!   3. The per-`ρ` evidence gradient `∂V/∂ρ` via the arrow trace formula,
11//!      per §3.5 / §3.7 / §3.8.
12//!   4. Discrete topology selection across `{periodic, flat, sphere, torus}`,
13//!      per §4 (4.1 / 4.5 / 4.6).
14//!
15//! ## Crucial numerical invariants (proposal §1.7, §6.4, §6.5)
16//!
17//!   * Evidence log-determinants use **undamped** factors. The cached
18//!     `ArrowFactorCache::htt_factors_undamped` Cholesky factors of
19//!     `H_uu_i` (no `ridge_u`) are the ones that must enter
20//!     `Σ_i log|H_uu_i|`. Likewise a factored Schur log-det must be of
21//!     `A(0, 0) = H_ββ - Σ_i H_uβ_iᵀ H_uu_i⁻¹ H_uβ_i`, not the LM-damped
22//!     surrogate. Matrix-free evidence callers must provide the matching
23//!     undamped HVP so the same log-det is estimated by SLQ.
24//!   * IFT solves invert `H_uu`, not `H_uu + ridge_u I` (proposal §1.7,
25//!     §6.6). The evidence-side IFT predictor loop here uses the undamped
26//!     `htt_factors_undamped` factors for exactly this reason.
27//!   * Penalty pseudo-logdet `log|S(ρ)|+` is the prior penalty, distinct
28//!     from the arrow Schur complement (proposal §3.1, §3.6). The variable
29//!     names below preserve that distinction:
30//!       `arrow_schur_log_det`   = `log|A|` where `A` is the arrow Schur.
31//!       `penalty_log_det`       = `log|S_pen(ρ)|+` where `S_pen` is the
32//!                                 prior penalty matrix pseudo-logdet.
33//!
34//! ## Sign discipline (proposal §3.1, §4.3)
35//!
36//! `V` as written is the *negative log evidence* when `F` is the
37//! penalized negative log posterior. The maximizer of evidence is the
38//! minimizer of `V`. For the public API we expose **negative log
39//! evidence** under `laplace_evidence` and rank topologies by the
40//! **minimum** of the configured per-row or per-effective-dimension
41//! normalization (see `select_topology` below); equivalently the caller can
42//! negate and `argmax`.
43
44use faer::Side;
45use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
46
47use gam_linalg::faer_ndarray::FaerEigh;
48use gam_linalg::lanczos::{
49    SymmetricLanczosOptions, symmetric_lanczos_eigenpairs, symmetric_lanczos_log_quadrature,
50};
51use gam_linalg::triangular::cholesky_solve_vector;
52use crate::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
53use crate::priority_selection::{PriorityCandidate, rank_priority_candidates};
54
/// Dimension cutoff referenced by [`EvidenceLogDetSource::Hvp`]: that variant
/// documents operators at or below this size as being materialized exactly.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
// NOTE(review): the four constants below are not used in the visible part of
// this file. Their names suggest the SLQ probe / Lanczos step counts for the
// matrix-free log-det and the tolerance / probe count of an HVP symmetry
// check — confirm against the log-det estimator before relying on this.
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
60
/// Matrix-free SPD Hessian logdet source used when the arrow Schur factor is
/// not materialized. The callback must apply the same undamped Hessian whose
/// determinant enters the Laplace evidence.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
    /// Dimension of the operator behind `apply`. For example,
    /// `GaussianMixtureFit::laplace_negative_log_evidence` passes its
    /// free-parameter count here.
    pub dim: usize,
    /// Hessian-vector product callback: takes a vector `x` and returns `H x`
    /// as a freshly allocated `Vec`. The mixture caller supplies a closure
    /// that reads `dim` entries of `x` and returns `dim` entries.
    pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
69
/// Source for the Hessian log determinant in [`laplace_evidence`].
///
/// Both variants only hold borrows, so the value is `Copy` and cheap to pass
/// by value.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
    /// Use the exact arrow Cholesky factors, falling back to `fallback_hvp`
    /// when the Schur factor is absent on a matrix-free solve.
    FactoredArrow {
        /// Borrowed factor cache that supplies the arrow Cholesky factors.
        cache: &'a ArrowFactorCache,
        /// Optional matrix-free operator to use when the Schur factor is
        /// absent from `cache`.
        fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
    },
    /// Use an HVP callback directly. Dimensions at or below
    /// [`ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD`] are materialized exactly;
    /// larger operators use the same Rademacher-Lanczos SLQ constants as
    /// `FrozenAnalyticPenaltyOp`.
    Hvp(EvidenceHvpLogDet<'a>),
}
85
86// ---------------------------------------------------------------------------
87// Topology candidate enum and selection result
88// ---------------------------------------------------------------------------
89
/// Discrete topology choice for the latent coordinate domain.
///
/// The selector is four-way: `{periodic, flat, sphere, torus}`. No further
/// variants are carried alongside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
    /// `S¹` or periodic interval (cyclic B-spline / periodic Duchon).
    Periodic,
    /// `Rᵈ` Euclidean Duchon / Matérn / thin-plate patch.
    Flat,
    /// `S²` embedded in `R³`, spherical Wahba/Sobolev basis.
    Sphere,
    /// `S¹ × S¹` mixed-periodicity Duchon.
    Torus,
}

impl TopologyKind {
    /// Tie-break priority, where the smaller value wins. Per §4.6 the order is
    /// `flat < periodic < sphere < torus`, i.e. ranks `0, 1, 2, 3`.
    pub fn complexity_rank(self) -> u8 {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Torus => 3,
            Self::Sphere => 2,
            Self::Periodic => 1,
            Self::Flat => 0,
        }
    }
}
119
/// One topology candidate together with the evidence ingredients it
/// produced at its own fitted optimum.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
    /// Which of the four topologies this candidate represents.
    pub kind: TopologyKind,
    /// Negative-log-evidence `V(ρ_T*, T)` evaluated at the candidate's own
    /// fitted `(ρ_T*, β_T*, u_T*)`.
    pub negative_log_evidence: f64,
    /// Effective integrated dimension after rank/nullspace accounting. This
    /// is the dimension used for per-complexity topology normalization.
    pub effective_dim: f64,
    /// Number of response rows used to fit this topology candidate. This is
    /// the dimension used for per-observation topology normalization.
    pub n_obs: usize,
    /// `true` iff the candidate's continuous inner+outer fit converged
    /// cleanly. Failed candidates are excluded from ranking (proposal
    /// §4.4 item 7 and §6.11).
    pub converged: bool,
    /// Optional rationale string for excluded candidates (proposal
    /// §6.11): `"sphere input not on S²"`, `"torus periods missing"`, etc.
    pub exclusion_reason: Option<String>,
}
142
/// Outcome of [`select_topology`].
#[derive(Debug, Clone)]
pub struct SelectedTopology {
    /// Topology kind of the selected candidate.
    pub winner: TopologyKind,
    /// All candidates sorted from best (lowest negative log evidence)
    /// to worst, with excluded candidates appended last.
    pub ranking: Vec<TopologyCandidate>,
    /// `true` iff the top two finite scores fall within `tie_tolerance`.
    /// Per §4.6 we still pick one — the simpler topology — but expose
    /// the tie so callers can warn.
    pub tie: bool,
}
155
/// Tolerance options for the topology comparator.
// NOTE(review): no `Default` impl for this struct is visible in this part of
// the file; confirm where the `1e-3` default documented below is applied.
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
    /// Maximum `|V_a - V_b|` for which two candidates are treated as
    /// numerically tied after [`TopologyScoreScale`] normalization. Default
    /// `1e-3` per proposal §4.6 examples.
    pub tie_tolerance: f64,
    /// Score scale used for discrete topology comparison. Raw evidence is
    /// intentionally not a selector because candidates may have different
    /// row counts and basis/nullspace dimensions.
    pub score_scale: TopologyScoreScale,
}
168
/// Normalization applied before ranking topology candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
    /// Compare negative log evidence per observation row.
    /// (Presumably divides by `TopologyCandidate::n_obs` — confirm in
    /// `select_topology`.)
    PerObservation,
    /// Compare negative log evidence per effective integrated dimension.
    /// (Presumably divides by `TopologyCandidate::effective_dim` — confirm in
    /// `select_topology`.)
    PerEffectiveDim,
}
177
/// Convergence controls for stacking retained topology predictive densities.
#[derive(Debug, Clone, Copy)]
pub struct StackingConfig {
    /// Upper bound on the number of weight-update sweeps.
    pub max_iter: usize,
    /// Sweeps stop once the largest absolute weight change is at or below
    /// this value.
    pub weight_tol: f64,
}

impl Default for StackingConfig {
    /// Defaults: at most `1000` sweeps, weight tolerance `1e-10`.
    fn default() -> Self {
        let max_iter = 1000;
        let weight_tol = 1e-10;
        Self {
            max_iter,
            weight_tol,
        }
    }
}
193
/// Simplex weights for retained topology candidates plus the achieved held-out
/// mean log-score.
#[derive(Debug, Clone)]
pub struct StackingWeights {
    /// One weight per input candidate column, in input column order.
    /// `solve_stacking_weights` leaves the weight at zero for any column that
    /// had no finite held-out density.
    pub weights: Array1<f64>,
    /// Mean over usable rows of the log mixture density under `weights`;
    /// `-inf` when no row could be scored.
    pub mean_log_score: f64,
    /// Number of weight-update sweeps that were started.
    pub iterations: usize,
}
202
203/// Solve the stacking-of-predictive-distributions weight problem from a
204/// per-observation held-out log-density table `log_density[i, k] = log p_k(y_i)`.
205///
206/// This belongs on the evidence surface rather than in a separate solver: it is
207/// the topology/evidence consumer that replaces winner-take-all only when the
208/// caller has retained candidate fits and per-point held-out densities.
209pub fn solve_stacking_weights(
210    log_density: ArrayView2<'_, f64>,
211    config: StackingConfig,
212) -> Result<StackingWeights, String> {
213    let n_obs = log_density.nrows();
214    let n_cand = log_density.ncols();
215    if n_cand == 0 {
216        return Err("stacking requires at least one candidate column".to_string());
217    }
218    if n_obs == 0 {
219        return Err("stacking requires at least one held-out observation row".to_string());
220    }
221
222    let kept_cols: Vec<usize> = (0..n_cand)
223        .filter(|&k| (0..n_obs).any(|i| log_density[[i, k]].is_finite()))
224        .collect();
225    if kept_cols.is_empty() {
226        return Err("stacking found no candidate with any finite held-out density".to_string());
227    }
228    let rows: Vec<usize> = (0..n_obs)
229        .filter(|&i| kept_cols.iter().any(|&k| log_density[[i, k]].is_finite()))
230        .collect();
231    if rows.is_empty() {
232        return Err("stacking found no held-out row with a finite density".to_string());
233    }
234
235    let kept = kept_cols.len();
236    let mut weights = Array1::<f64>::from_elem(kept, 1.0 / kept as f64);
237    let mut next = Array1::<f64>::zeros(kept);
238    let mut iterations = 0usize;
239    for _ in 0..config.max_iter {
240        iterations += 1;
241        next.fill(0.0);
242        let mut active_rows = 0usize;
243        for &row in &rows {
244            let mut row_max = f64::NEG_INFINITY;
245            for (local_col, &source_col) in kept_cols.iter().enumerate() {
246                let log_p = log_density[[row, source_col]];
247                if log_p.is_finite() && weights[local_col] > 0.0 {
248                    row_max = row_max.max(weights[local_col].ln() + log_p);
249                }
250            }
251            if !row_max.is_finite() {
252                continue;
253            }
254            let mut denom = 0.0_f64;
255            for (local_col, &source_col) in kept_cols.iter().enumerate() {
256                let log_p = log_density[[row, source_col]];
257                if log_p.is_finite() && weights[local_col] > 0.0 {
258                    denom += (weights[local_col].ln() + log_p - row_max).exp();
259                }
260            }
261            if denom <= 0.0 {
262                continue;
263            }
264            active_rows += 1;
265            let log_mix = row_max + denom.ln();
266            for (local_col, &source_col) in kept_cols.iter().enumerate() {
267                let log_p = log_density[[row, source_col]];
268                if log_p.is_finite() && weights[local_col] > 0.0 {
269                    next[local_col] += (weights[local_col].ln() + log_p - log_mix).exp();
270                }
271            }
272        }
273        if active_rows == 0 {
274            break;
275        }
276        next.mapv_inplace(|value| value / active_rows as f64);
277        let total = next.sum();
278        if total > 0.0 {
279            next.mapv_inplace(|value| value / total);
280        }
281        let delta = next
282            .iter()
283            .zip(weights.iter())
284            .fold(0.0_f64, |acc, (a, b)| acc.max((a - b).abs()));
285        weights.assign(&next);
286        if delta <= config.weight_tol {
287            break;
288        }
289    }
290
291    let mean_log_score = stacking_mean_log_score(log_density, &rows, &kept_cols, weights.view());
292    let mut full = Array1::<f64>::zeros(n_cand);
293    for (local_col, &source_col) in kept_cols.iter().enumerate() {
294        full[source_col] = weights[local_col];
295    }
296    Ok(StackingWeights {
297        weights: full,
298        mean_log_score,
299        iterations,
300    })
301}
302
303fn stacking_mean_log_score(
304    log_density: ArrayView2<'_, f64>,
305    rows: &[usize],
306    kept_cols: &[usize],
307    weights: ArrayView1<'_, f64>,
308) -> f64 {
309    let mut score_sum = 0.0_f64;
310    let mut counted = 0usize;
311    for &row in rows {
312        let mut row_max = f64::NEG_INFINITY;
313        for (local_col, &source_col) in kept_cols.iter().enumerate() {
314            let log_p = log_density[[row, source_col]];
315            if log_p.is_finite() && weights[local_col] > 0.0 {
316                row_max = row_max.max(weights[local_col].ln() + log_p);
317            }
318        }
319        if !row_max.is_finite() {
320            continue;
321        }
322        let mut denom = 0.0_f64;
323        for (local_col, &source_col) in kept_cols.iter().enumerate() {
324            let log_p = log_density[[row, source_col]];
325            if log_p.is_finite() && weights[local_col] > 0.0 {
326                denom += (weights[local_col].ln() + log_p - row_max).exp();
327            }
328        }
329        if denom > 0.0 {
330            score_sum += row_max + denom.ln();
331            counted += 1;
332        }
333    }
334    if counted == 0 {
335        f64::NEG_INFINITY
336    } else {
337        score_sum / counted as f64
338    }
339}
340
341/// Combine retained candidate response-scale means with stacking weights.
342pub fn stacked_predictive_mean(
343    weights: &Array1<f64>,
344    candidate_means: &[Array1<f64>],
345) -> Result<Array1<f64>, String> {
346    if candidate_means.len() != weights.len() {
347        return Err(format!(
348            "stacked_predictive_mean: {} weights but {} candidate mean vectors",
349            weights.len(),
350            candidate_means.len()
351        ));
352    }
353    let Some(first) = candidate_means.first() else {
354        return Err("stacked_predictive_mean requires at least one candidate".to_string());
355    };
356    let n_rows = first.len();
357    if candidate_means.iter().any(|means| means.len() != n_rows) {
358        return Err(
359            "stacked_predictive_mean: candidate mean vectors disagree on row count".to_string(),
360        );
361    }
362    let mut out = Array1::<f64>::zeros(n_rows);
363    for (weight, means) in weights.iter().zip(candidate_means) {
364        if *weight != 0.0 {
365            out.scaled_add(*weight, means);
366        }
367    }
368    Ok(out)
369}
370
371// ---------------------------------------------------------------------------
372// Discrete mixture rung (Object 3a / WP-C)
373// ---------------------------------------------------------------------------
374//
375// A `k`-component full-covariance Gaussian mixture fitted by deterministic
376// k-means++-style seeding (reusing `terms::basis` farthest-point k-means) plus
377// EM to a tolerance. It is priced by its free-parameter count and scored
378// through the SAME rank-aware Laplace/Tierney-Kadane normalizer as the smooth
379// topology candidates: `−V = loglik − ½ log|H| + ½ P log(2π)` with the
380// `−½ (dim(H) − rank(S)) log(2π)` normalizer evaluated at `dim(H) = P`,
381// `rank(S) = 0` (a fully likelihood-identified, unpenalized parametric model,
382// so every free parameter is unpenalized null-space). The Hessian log-det
// `log|H|` is computed from the observed (empirical-Fisher / BHHH) information
// `H = I + Σ_i s_i s_iᵀ` at the EM optimum: the exact, finite score
// outer-product sum plus the unit ridge that `empirical_fisher_information`
// adds so that `H` is SPD for any `n`. It is fed through the same
// `laplace_evidence` entry point used by the smooth rungs so the two model
// classes are comparable on the evidence scale.
388
/// Convergence + ladder controls for the discrete-mixture rung. Every field is
/// a fixed value (no clock randomness, no env): with deterministic seeding the
/// fitted mixture is a pure function of the data and `k`.
#[derive(Debug, Clone, Copy)]
pub struct GaussianMixtureConfig {
    /// Maximum EM iterations.
    pub max_iter: usize,
    /// Relative mean-log-likelihood improvement tolerance for EM stopping.
    pub loglik_tol: f64,
    /// Ridge added to each component covariance for numerical SPD safety
    /// (variance floor). A small fixed value, not a tuned knob.
    pub covariance_floor: f64,
    /// Maximum iterations for the deterministic k-means seeding pass.
    pub kmeans_max_iter: usize,
}

impl Default for GaussianMixtureConfig {
    /// Defaults: `200` EM iterations, log-likelihood tolerance `1e-7`,
    /// covariance floor `1e-6`, `25` k-means iterations.
    fn default() -> Self {
        let (max_iter, kmeans_max_iter) = (200, 25);
        let (loglik_tol, covariance_floor) = (1e-7, 1e-6);
        Self {
            max_iter,
            loglik_tol,
            covariance_floor,
            kmeans_max_iter,
        }
    }
}
415
/// A fitted `k`-component full-covariance Gaussian mixture.
///
/// All fields are public. The methods on this type index `means` and
/// `covariances` by component `0..k` and compare input column counts against
/// `d`, so `k` and `d` must agree with the stored array shapes.
#[derive(Debug, Clone)]
pub struct GaussianMixtureFit {
    /// Mixing weights, length `k`, on the simplex.
    pub weights: Array1<f64>,
    /// Component means, `k × d`.
    pub means: Array2<f64>,
    /// Component covariances, `k` matrices of shape `d × d` (SPD).
    pub covariances: Vec<Array2<f64>>,
    /// Number of mixture components.
    pub k: usize,
    /// Data dimension.
    pub d: usize,
    /// Number of rows used to fit.
    pub n_obs: usize,
    /// Maximised total log-likelihood `Σ_i log Σ_j w_j N(y_i; μ_j, Σ_j)`.
    pub loglik: f64,
    /// EM iterations taken.
    pub iterations: usize,
}
436
437impl GaussianMixtureFit {
438    /// Free-parameter count `P` of a `k`-component full-covariance mixture in
439    /// `d` dimensions: `(k − 1)` mixing weights on the simplex, `k·d` mean
440    /// coordinates, and `k · d(d+1)/2` covariance entries. This is the exact
441    /// quantity that enters the rank-aware normalizer as `dim(H) − rank(S)`.
442    pub fn num_free_parameters(&self) -> usize {
443        let cov_per = self.d * (self.d + 1) / 2;
444        (self.k - 1) + self.k * self.d + self.k * cov_per
445    }
446
447    /// Per-observation log predictive density `log p(y_i)` under the fitted
448    /// mixture, length `n`. This is the held-out-density column source for
449    /// cross-class stacking when the mixture is evaluated on a held-out fold.
450    pub fn per_point_log_density(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
451        if data.ncols() != self.d {
452            return Err(format!(
453                "mixture log-density expects {} columns, got {}",
454                self.d,
455                data.ncols()
456            ));
457        }
458        let n = data.nrows();
459        let mut comp = vec![GaussianComponentEval::new(self.d); self.k];
460        for j in 0..self.k {
461            comp[j] = GaussianComponentEval::factor(self.means.row(j), &self.covariances[j])?;
462        }
463        let mut out = Array1::<f64>::zeros(n);
464        let log_w: Vec<f64> = self
465            .weights
466            .iter()
467            .map(|w| w.max(f64::MIN_POSITIVE).ln())
468            .collect();
469        for i in 0..n {
470            let row = data.row(i);
471            let mut log_terms = vec![f64::NEG_INFINITY; self.k];
472            let mut max_term = f64::NEG_INFINITY;
473            for j in 0..self.k {
474                let lt = log_w[j] + comp[j].log_density(row);
475                log_terms[j] = lt;
476                if lt > max_term {
477                    max_term = lt;
478                }
479            }
480            out[i] = log_sum_exp(&log_terms, max_term);
481        }
482        Ok(out)
483    }
484
485    /// Rank-aware Laplace **negative** log evidence on the SAME scale as the
486    /// smooth topology rungs. `−V = loglik − ½ log|H| + ½ P log(2π)`, realised
487    /// by calling [`laplace_evidence`] with `residual_objective = −loglik`,
488    /// `penalty_log_det = 0`, `penalty_rank = 0`, `effective_dim = P`, and
489    /// `log|H|` the observed empirical-Fisher information at the optimum.
490    pub fn laplace_negative_log_evidence(&self, data: ArrayView2<'_, f64>) -> Result<f64, String> {
491        let p = self.num_free_parameters();
492        let information = self.empirical_fisher_information(data)?;
493        if information.nrows() != p {
494            return Err(format!(
495                "mixture empirical-Fisher information has dim {} but expected free-parameter count {p}",
496                information.nrows()
497            ));
498        }
499        let apply_info = |x: &[f64]| -> Vec<f64> {
500            let mut out = vec![0.0_f64; p];
501            for r in 0..p {
502                let mut acc = 0.0_f64;
503                for c in 0..p {
504                    acc += information[[r, c]] * x[c];
505                }
506                out[r] = acc;
507            }
508            out
509        };
510        let hvp = EvidenceHvpLogDet {
511            dim: p,
512            apply: &apply_info,
513        };
514        let v = laplace_evidence(
515            EvidenceLogDetSource::Hvp(hvp),
516            0.0,
517            -self.loglik,
518            p as f64,
519            0.0,
520        );
521        if !v.is_finite() {
522            return Err("mixture Laplace evidence is not finite".to_string());
523        }
524        Ok(v)
525    }
526
527    /// Observed empirical-Fisher (BHHH) information `H = Σ_i s_i s_iᵀ`, where
528    /// `s_i = ∇_θ log p(y_i)` is the per-observation score in the
529    /// free-parameter coordinates: softmax-logit mixing weights (`k − 1`),
530    /// component means (`k·d`), and the lower-triangular covariance entries
531    /// (`k · d(d+1)/2`) of each component, in that block order. This SPD matrix
532    /// is the genuine observed-information surrogate evaluated at the EM
533    /// optimum — its dimension is exactly `P`, which is what enters the
534    /// rank-aware normalizer.
535    fn empirical_fisher_information(
536        &self,
537        data: ArrayView2<'_, f64>,
538    ) -> Result<Array2<f64>, String> {
539        if data.ncols() != self.d {
540            return Err(format!(
541                "mixture information expects {} columns, got {}",
542                self.d,
543                data.ncols()
544            ));
545        }
546        let n = data.nrows();
547        let p = self.num_free_parameters();
548        let cov_per = self.d * (self.d + 1) / 2;
549        // Precompute per-component evaluators (mean, precision = Σ⁻¹).
550        let mut comp = Vec::with_capacity(self.k);
551        for j in 0..self.k {
552            comp.push(GaussianComponentEval::factor(
553                self.means.row(j),
554                &self.covariances[j],
555            )?);
556        }
557        let log_w: Vec<f64> = self
558            .weights
559            .iter()
560            .map(|w| w.max(f64::MIN_POSITIVE).ln())
561            .collect();
562
563        let mean_base = self.k - 1;
564        let cov_base = mean_base + self.k * self.d;
565
566        let mut info = Array2::<f64>::zeros((p, p));
567        let mut score = vec![0.0_f64; p];
568        for i in 0..n {
569            let row = data.row(i);
570            // Responsibilities r_j = w_j N_j / Σ.
571            let mut log_terms = vec![0.0_f64; self.k];
572            let mut max_term = f64::NEG_INFINITY;
573            for j in 0..self.k {
574                let lt = log_w[j] + comp[j].log_density(row);
575                log_terms[j] = lt;
576                if lt > max_term {
577                    max_term = lt;
578                }
579            }
580            let log_mix = log_sum_exp(&log_terms, max_term);
581            let resp: Vec<f64> = log_terms.iter().map(|lt| (lt - log_mix).exp()).collect();
582
583            for s in score.iter_mut() {
584                *s = 0.0;
585            }
586            // Softmax-logit mixing score: ∂/∂α_j log p = r_j − w_j for the free
587            // logits j = 1..k-1 (component 0 is the reference / pinned logit).
588            for j in 1..self.k {
589                score[j - 1] = resp[j] - self.weights[j];
590            }
591            // Mean score: ∂/∂μ_j log p = r_j · Σ_j⁻¹ (y − μ_j).
592            // Covariance score (lower-tri entries): ∂/∂Σ_j contracted through
593            // the symmetric chain rule, r_j · ½ (Σ⁻¹ v vᵀ Σ⁻¹ − Σ⁻¹) with
594            // off-diagonal entries doubled for the symmetric parameterization.
595            for j in 0..self.k {
596                let prec_v = comp[j].precision_times_residual(row); // Σ⁻¹ (y − μ_j)
597                let mbo = mean_base + j * self.d;
598                for c in 0..self.d {
599                    score[mbo + c] = resp[j] * prec_v[c];
600                }
601                let cbo = cov_base + j * cov_per;
602                let mut idx = 0usize;
603                for a in 0..self.d {
604                    for b in 0..=a {
605                        let outer = prec_v[a] * prec_v[b];
606                        let prec_ab = comp[j].precision[[a, b]];
607                        let mut g = 0.5 * (outer - prec_ab);
608                        if a != b {
609                            // Off-diagonal entry appears twice in the symmetric
610                            // matrix, so its free-parameter derivative doubles.
611                            g *= 2.0;
612                        }
613                        score[cbo + idx] = resp[j] * g;
614                        idx += 1;
615                    }
616                }
617            }
618            // Accumulate outer product s_i s_iᵀ.
619            for r in 0..p {
620                let sr = score[r];
621                if sr == 0.0 {
622                    continue;
623                }
624                for c in 0..p {
625                    info[[r, c]] += sr * score[c];
626                }
627            }
628        }
629        // Symmetrize and add a unit-precision ridge `I`. This is the Hessian
630        // contribution of a standard-normal prior on the (natural) parameters,
631        // making the object a proper MAP observed-information `H = I_prior +
632        // Σ_i s_i s_iᵀ`. It guarantees SPD (well-defined `log|H|`) for any `n`
633        // and is a fixed prior, not a tuned knob — the same unit-information
634        // prior the rank-aware normalizer assumes when it credits each free
635        // parameter one `log(2π)` of integration volume.
636        for r in 0..p {
637            for c in (r + 1)..p {
638                let avg = 0.5 * (info[[r, c]] + info[[c, r]]);
639                info[[r, c]] = avg;
640                info[[c, r]] = avg;
641            }
642            info[[r, r]] += 1.0;
643        }
644        Ok(info)
645    }
646}
647
648/// Cached per-component Gaussian evaluator: mean, precision `Σ⁻¹`, and the
649/// log-normalizing constant `−½(d log 2π + log|Σ|)`.
650#[derive(Debug, Clone)]
651struct GaussianComponentEval {
652    mean: Array1<f64>,
653    precision: Array2<f64>,
654    log_norm: f64,
655    d: usize,
656}
657
658impl GaussianComponentEval {
659    fn new(d: usize) -> Self {
660        Self {
661            mean: Array1::zeros(d),
662            precision: Array2::eye(d),
663            log_norm: 0.0,
664            d,
665        }
666    }
667
668    fn factor(mean: ArrayView1<'_, f64>, cov: &Array2<f64>) -> Result<Self, String> {
669        let d = mean.len();
670        if cov.nrows() != d || cov.ncols() != d {
671            return Err(format!(
672                "mixture component covariance must be {d}x{d}, got {}x{}",
673                cov.nrows(),
674                cov.ncols()
675            ));
676        }
677        let (evals, evecs) = cov
678            .eigh(Side::Lower)
679            .map_err(|e| format!("mixture component covariance eigendecomposition failed: {e}"))?;
680        let mut log_det = 0.0_f64;
681        let mut inv_evals = Array1::<f64>::zeros(d);
682        for (idx, &ev) in evals.iter().enumerate() {
683            if !ev.is_finite() || ev <= 0.0 {
684                return Err(format!(
685                    "mixture component covariance is not SPD: eigenvalue {idx} is {ev:.3e}"
686                ));
687            }
688            log_det += ev.ln();
689            inv_evals[idx] = 1.0 / ev;
690        }
691        // Σ⁻¹ = V diag(1/λ) Vᵀ.
692        let mut precision = Array2::<f64>::zeros((d, d));
693        for a in 0..d {
694            for b in 0..d {
695                let mut acc = 0.0_f64;
696                for m in 0..d {
697                    acc += evecs[[a, m]] * inv_evals[m] * evecs[[b, m]];
698                }
699                precision[[a, b]] = acc;
700            }
701        }
702        let log_norm = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det);
703        Ok(Self {
704            mean: mean.to_owned(),
705            precision,
706            log_norm,
707            d,
708        })
709    }
710
711    #[inline]
712    fn log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
713        let pv = self.precision_times_residual(y);
714        let mut quad = 0.0_f64;
715        for c in 0..self.d {
716            quad += (y[c] - self.mean[c]) * pv[c];
717        }
718        self.log_norm - 0.5 * quad
719    }
720
721    /// `Σ⁻¹ (y − μ)`.
722    #[inline]
723    fn precision_times_residual(&self, y: ArrayView1<'_, f64>) -> Vec<f64> {
724        let mut out = vec![0.0_f64; self.d];
725        for a in 0..self.d {
726            let mut acc = 0.0_f64;
727            for b in 0..self.d {
728                acc += self.precision[[a, b]] * (y[b] - self.mean[b]);
729            }
730            out[a] = acc;
731        }
732        out
733    }
734}
735
/// `log Σ_t exp(t)` over `terms`, stabilized by factoring out `max_term`.
///
/// The caller supplies `max_term`; if it is not finite (for example `-inf`
/// when there were no usable terms) the result is `-inf`.
#[inline]
fn log_sum_exp(terms: &[f64], max_term: f64) -> f64 {
    if max_term.is_finite() {
        let scaled_sum = terms
            .iter()
            .fold(0.0_f64, |acc, &t| acc + (t - max_term).exp());
        max_term + scaled_sum.ln()
    } else {
        f64::NEG_INFINITY
    }
}
747
748/// Fit a `k`-component full-covariance Gaussian mixture by deterministic
749/// k-means++-style seeding (reusing the `terms::basis` farthest-point k-means,
750/// a pure function of the data — no clock randomness) followed by EM to the
751/// configured tolerance.
752///
753/// The fit is deterministic given `(data, k, config)`: the seed is the
754/// farthest-point/k-means center selection, EM is a deterministic map, so
755/// re-running yields the identical mixture.
756pub fn fit_gaussian_mixture(
757    data: ArrayView2<'_, f64>,
758    k: usize,
759    config: GaussianMixtureConfig,
760) -> Result<GaussianMixtureFit, String> {
761    let n = data.nrows();
762    let d = data.ncols();
763    if k == 0 {
764        return Err("gaussian mixture requires k >= 1".to_string());
765    }
766    if d == 0 {
767        return Err("gaussian mixture requires at least one column".to_string());
768    }
769    if k > n {
770        return Err(format!(
771            "gaussian mixture requested {k} components but data has {n} rows"
772        ));
773    }
774    // Deterministic k-means++-style seeding via the shared basis k-means
775    // (farthest-point init + Lloyd iterations). Fixed by construction.
776    let centers = gam_terms::basis::select_centers_by_strategy(
777        data,
778        &gam_terms::basis::CenterStrategy::KMeans {
779            num_centers: k,
780            max_iter: config.kmeans_max_iter,
781        },
782    )
783    .map_err(|e| format!("gaussian mixture k-means seeding failed: {e}"))?;
784    if centers.nrows() != k || centers.ncols() != d {
785        return Err(format!(
786            "gaussian mixture seeding returned {}x{} centers, expected {k}x{d}",
787            centers.nrows(),
788            centers.ncols()
789        ));
790    }
791
792    let mut means = centers;
793    // Seed covariances from the global data covariance (shared start).
794    let global_cov = data_covariance(data, config.covariance_floor);
795    let mut covariances = vec![global_cov; k];
796    let mut weights = Array1::<f64>::from_elem(k, 1.0 / k as f64);
797
798    let mut resp = Array2::<f64>::zeros((n, k));
799    let mut prev_mean_ll = f64::NEG_INFINITY;
800    let mut total_loglik = f64::NEG_INFINITY;
801    let mut iterations = 0usize;
802
803    for iter in 0..config.max_iter.max(1) {
804        iterations = iter + 1;
805        // E-step: responsibilities and total log-likelihood.
806        let mut comp = Vec::with_capacity(k);
807        for j in 0..k {
808            comp.push(GaussianComponentEval::factor(
809                means.row(j),
810                &covariances[j],
811            )?);
812        }
813        let log_w: Vec<f64> = weights
814            .iter()
815            .map(|w| w.max(f64::MIN_POSITIVE).ln())
816            .collect();
817        total_loglik = 0.0;
818        for i in 0..n {
819            let yrow = data.row(i);
820            let mut log_terms = vec![0.0_f64; k];
821            let mut max_term = f64::NEG_INFINITY;
822            for j in 0..k {
823                let lt = log_w[j] + comp[j].log_density(yrow);
824                log_terms[j] = lt;
825                if lt > max_term {
826                    max_term = lt;
827                }
828            }
829            let log_mix = log_sum_exp(&log_terms, max_term);
830            total_loglik += log_mix;
831            for j in 0..k {
832                resp[[i, j]] = (log_terms[j] - log_mix).exp();
833            }
834        }
835        let mean_ll = total_loglik / n as f64;
836        if iter > 0 {
837            let denom = prev_mean_ll.abs().max(1.0);
838            if (mean_ll - prev_mean_ll).abs() / denom <= config.loglik_tol {
839                break;
840            }
841        }
842        prev_mean_ll = mean_ll;
843
844        // M-step.
845        let mut nk = vec![0.0_f64; k];
846        for j in 0..k {
847            let mut sum = 0.0_f64;
848            for i in 0..n {
849                sum += resp[[i, j]];
850            }
851            nk[j] = sum.max(f64::MIN_POSITIVE);
852        }
853        for j in 0..k {
854            weights[j] = nk[j] / n as f64;
855            // Means.
856            let mut mu = Array1::<f64>::zeros(d);
857            for i in 0..n {
858                let r = resp[[i, j]];
859                if r == 0.0 {
860                    continue;
861                }
862                for c in 0..d {
863                    mu[c] += r * data[[i, c]];
864                }
865            }
866            mu.mapv_inplace(|v| v / nk[j]);
867            for c in 0..d {
868                means[[j, c]] = mu[c];
869            }
870            // Covariance with a fixed diagonal floor for SPD safety.
871            let mut cov = Array2::<f64>::zeros((d, d));
872            for i in 0..n {
873                let r = resp[[i, j]];
874                if r == 0.0 {
875                    continue;
876                }
877                for a in 0..d {
878                    let da = data[[i, a]] - mu[a];
879                    for b in 0..d {
880                        cov[[a, b]] += r * da * (data[[i, b]] - mu[b]);
881                    }
882                }
883            }
884            cov.mapv_inplace(|v| v / nk[j]);
885            for a in 0..d {
886                cov[[a, a]] += config.covariance_floor;
887            }
888            covariances[j] = cov;
889        }
890    }
891
892    Ok(GaussianMixtureFit {
893        weights,
894        means,
895        covariances,
896        k,
897        d,
898        n_obs: n,
899        loglik: total_loglik,
900        iterations,
901    })
902}
903
904/// Global data covariance with a fixed diagonal floor (used to seed EM).
905fn data_covariance(data: ArrayView2<'_, f64>, floor: f64) -> Array2<f64> {
906    let n = data.nrows();
907    let d = data.ncols();
908    let mut mean = Array1::<f64>::zeros(d);
909    for i in 0..n {
910        for c in 0..d {
911            mean[c] += data[[i, c]];
912        }
913    }
914    mean.mapv_inplace(|v| v / n.max(1) as f64);
915    let mut cov = Array2::<f64>::zeros((d, d));
916    for i in 0..n {
917        for a in 0..d {
918            let da = data[[i, a]] - mean[a];
919            for b in 0..d {
920                cov[[a, b]] += da * (data[[i, b]] - mean[b]);
921            }
922        }
923    }
924    let inv = 1.0 / (n.max(1) as f64);
925    cov.mapv_inplace(|v| v * inv);
926    for a in 0..d {
927        cov[[a, a]] += floor;
928    }
929    cov
930}
931
932// ---------------------------------------------------------------------------
933// Structured-union candidates (#907)
934// ---------------------------------------------------------------------------
935//
936// A *union* candidate is a small FIXED composite of named component structures
937// joined by a hard row-responsibility split. Unlike the discrete-mixture rung
938// (which is one free k-component Gaussian density), a union pins each component
939// to a specific generative STRUCTURE (a circle, a line, a point cluster) and
940// asks whether the data is better explained as the disjoint sum of those
941// structures than by any single pure rung.
942//
943// Each component is fit on its responsibility group as its own parametric
944// generative density and scored through the SAME rank-aware Laplace /
945// Tierney-Kadane normalizer used by the smooth rungs and the mixture rung:
946// `−V_c = loglik_c − ½ log|H_c| + ½ P_c log(2π)` with `H_c` the observed
947// empirical-Fisher (BHHH) information `I + Σ s_i s_iᵀ` at the component optimum
948// (`rank(S)=0`, fully likelihood-identified). The union's evidence is the SUM
949// `V = Σ_c V_c` (the components partition the rows, so their log-likelihoods
950// add and their Hessians are block-diagonal — `log|H| = Σ_c log|H_c|`). The
951// complexity price is the TOTAL free-parameter count across all components,
952// which is exactly what the summed `+ ½ Σ_c P_c log(2π)` normalizer charges.
953// A union is therefore strictly more expensive than either pure component, so
954// it can only win when the structured split buys enough likelihood to pay for
955// its extra parameters — the negative-control discipline of #907.
956
/// The closed set of structured-union composites the topology race may pick.
///
/// The list is deliberately fixed: open-ended structure search stays owned by
/// #976's move set, so only these three composites exist here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionStructure {
    /// Two circles (two well-separated periodic loops).
    CircleCircle,
    /// One circle plus one isolated point cluster (a loop with an outlier blob).
    CirclePointCluster,
    /// One line (anisotropic cluster) plus one isolated point cluster.
    LineCluster,
}

/// Every [`UnionStructure`] variant, in the stable order the ladder is fitted.
pub const UNION_STRUCTURE_LADDER: &[UnionStructure] = &[
    UnionStructure::CircleCircle,
    UnionStructure::CirclePointCluster,
    UnionStructure::LineCluster,
];

/// Generative structure a union pins a single responsibility group to.
///
/// `Line` and `PointCluster` are both scored with the full-covariance Gaussian
/// density (a line is just an anisotropic Gaussian — the fitted covariance is
/// what tells them apart). `Circle` is a different density on
/// `(radius, angle)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionComponentKind {
    Circle,
    Line,
    PointCluster,
}

impl UnionStructure {
    /// Stable display name, e.g. `"union_circle+circle"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::CircleCircle => "union_circle+circle",
            Self::CirclePointCluster => "union_circle+cluster",
            Self::LineCluster => "union_line+cluster",
        }
    }

    /// The ordered component structures that make up this composite.
    pub const fn components(self) -> &'static [UnionComponentKind] {
        const CIRCLE_CIRCLE: &[UnionComponentKind] =
            &[UnionComponentKind::Circle, UnionComponentKind::Circle];
        const CIRCLE_CLUSTER: &[UnionComponentKind] =
            &[UnionComponentKind::Circle, UnionComponentKind::PointCluster];
        const LINE_CLUSTER: &[UnionComponentKind] =
            &[UnionComponentKind::Line, UnionComponentKind::PointCluster];
        match self {
            Self::CircleCircle => CIRCLE_CIRCLE,
            Self::CirclePointCluster => CIRCLE_CLUSTER,
            Self::LineCluster => LINE_CLUSTER,
        }
    }

    /// Number of components, i.e. the order `m` of the responsibility split.
    pub const fn num_components(self) -> usize {
        self.components().len()
    }
}
1019
/// One fitted component of a union: its pinned structure, the rows it owns
/// (after the hard responsibility split), its free-parameter count, and its
/// rank-aware Laplace negative-log-evidence on the common scale.
#[derive(Debug, Clone)]
pub struct UnionComponentFit {
    /// Structure this component was pinned to.
    pub kind: UnionComponentKind,
    /// Number of rows the hard responsibility split assigned to this component.
    pub row_count: usize,
    /// Free-parameter count reported by the component fit.
    pub num_parameters: usize,
    /// This component's Laplace negative-log-evidence (lower is better).
    pub negative_log_evidence: f64,
}
1030
/// A fitted structured-union candidate: the composite kind, the per-component
/// fits, the SUMMED rank-aware Laplace negative-log-evidence, and the TOTAL
/// free-parameter count across components (the complexity price).
#[derive(Debug, Clone)]
pub struct UnionStructureFit {
    /// Which composite from the fixed ladder this fit belongs to.
    pub structure: UnionStructure,
    /// Per-component fits, in the same order as `structure.components()`.
    pub components: Vec<UnionComponentFit>,
    /// `Σ_c V_c` — summed rank-aware Laplace negative-log-evidence (lower wins).
    pub negative_log_evidence: f64,
    /// `Σ_c P_c` — total free-parameter count across components.
    pub total_parameters: usize,
}
1043
1044/// Hard responsibility split of `0..n` into `m` groups by argmax of the
1045/// deterministic `m`-component Gaussian-mixture responsibilities. Reuses the
1046/// mixture rung's seeding + EM so the split is a pure function of the data and
1047/// `m` (no clock). Returns one row-index vector per component.
1048pub fn union_responsibility_split(
1049    data: ArrayView2<'_, f64>,
1050    m: usize,
1051    config: GaussianMixtureConfig,
1052) -> Result<Vec<Vec<usize>>, String> {
1053    let n = data.nrows();
1054    if m == 0 {
1055        return Err("union split requires at least one component".to_string());
1056    }
1057    if m > n {
1058        return Err(format!(
1059            "union split requested {m} groups but data has {n} rows"
1060        ));
1061    }
1062    if m == 1 {
1063        return Ok(vec![(0..n).collect()]);
1064    }
1065    let fit = fit_gaussian_mixture(data, m, config)?;
1066    let mut groups: Vec<Vec<usize>> = vec![Vec::new(); m];
1067    // Hard assignment by argmax per-component log responsibility.
1068    let mut comp = Vec::with_capacity(m);
1069    for j in 0..m {
1070        comp.push(GaussianComponentEval::factor(
1071            fit.means.row(j),
1072            &fit.covariances[j],
1073        )?);
1074    }
1075    let log_w: Vec<f64> = fit
1076        .weights
1077        .iter()
1078        .map(|w| w.max(f64::MIN_POSITIVE).ln())
1079        .collect();
1080    for i in 0..n {
1081        let row = data.row(i);
1082        let mut best_j = 0usize;
1083        let mut best_lt = f64::NEG_INFINITY;
1084        for j in 0..m {
1085            let lt = log_w[j] + comp[j].log_density(row);
1086            if lt > best_lt {
1087                best_lt = lt;
1088                best_j = j;
1089            }
1090        }
1091        groups[best_j].push(i);
1092    }
1093    Ok(groups)
1094}
1095
1096/// Fit one structured-union candidate: hard-split the rows into one group per
1097/// component, fit each component's pinned density, and SUM the rank-aware
1098/// Laplace negative-log-evidence. The complexity price is the total
1099/// free-parameter count across components.
1100///
1101/// Returns an error if any component group is too small to identify its
1102/// structure (so an over-priced or non-identifiable composite simply does not
1103/// enter the race rather than scoring spuriously well).
1104pub fn fit_union_structure(
1105    data: ArrayView2<'_, f64>,
1106    structure: UnionStructure,
1107    config: GaussianMixtureConfig,
1108) -> Result<UnionStructureFit, String> {
1109    let comps = structure.components();
1110    let m = comps.len();
1111    let groups = union_responsibility_split(data, m, config)?;
1112    let mut fits = Vec::with_capacity(m);
1113    let mut total_nle = 0.0_f64;
1114    let mut total_parameters = 0usize;
1115    for (kind, rows) in comps.iter().zip(groups.iter()) {
1116        let group = gather_union_rows(data, rows);
1117        let (nle, p) = fit_union_component(group.view(), *kind, config)?;
1118        if !nle.is_finite() {
1119            return Err(format!(
1120                "union {} component {:?} produced non-finite evidence",
1121                structure.as_str(),
1122                kind
1123            ));
1124        }
1125        total_nle += nle;
1126        total_parameters += p;
1127        fits.push(UnionComponentFit {
1128            kind: *kind,
1129            row_count: rows.len(),
1130            num_parameters: p,
1131            negative_log_evidence: nle,
1132        });
1133    }
1134    Ok(UnionStructureFit {
1135        structure,
1136        components: fits,
1137        negative_log_evidence: total_nle,
1138        total_parameters,
1139    })
1140}
1141
1142/// Fit the whole fixed union ladder and rank in-class by summed rank-aware
1143/// Laplace evidence (lower wins). Composites that fail to fit (e.g. a group too
1144/// small to identify a circle) are skipped. Returns the fitted ladder sorted
1145/// best-first.
1146pub fn fit_union_ladder(
1147    data: ArrayView2<'_, f64>,
1148    config: GaussianMixtureConfig,
1149) -> Result<Vec<UnionStructureFit>, String> {
1150    let mut fits = Vec::new();
1151    let mut errors = Vec::new();
1152    for &structure in UNION_STRUCTURE_LADDER {
1153        match fit_union_structure(data, structure, config) {
1154            Ok(fit) => fits.push(fit),
1155            Err(e) => errors.push(format!("{}: {e}", structure.as_str())),
1156        }
1157    }
1158    if fits.is_empty() {
1159        return Err(format!(
1160            "union ladder produced no fittable composites{}",
1161            if errors.is_empty() {
1162                String::new()
1163            } else {
1164                format!(" ({})", errors.join("; "))
1165            }
1166        ));
1167    }
1168    let ranked = rank_priority_candidates(
1169        fits.into_iter()
1170            .enumerate()
1171            .map(|(idx, row)| {
1172                let score = row.negative_log_evidence;
1173                let tie = row.total_parameters; // cheaper composite wins ties
1174                PriorityCandidate::new(row, idx, score, tie)
1175            })
1176            .collect(),
1177    )
1178    .into_iter()
1179    .map(|row| row.item)
1180    .collect::<Vec<_>>();
1181    Ok(ranked)
1182}
1183
1184fn gather_union_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
1185    let d = data.ncols();
1186    let mut out = Array2::<f64>::zeros((idx.len(), d));
1187    for (r, &i) in idx.iter().enumerate() {
1188        for c in 0..d {
1189            out[[r, c]] = data[[i, c]];
1190        }
1191    }
1192    out
1193}
1194
1195/// Fit a single union component density on its responsibility group and return
1196/// `(rank_aware_negative_log_evidence, free_parameter_count)`. `Line` and
1197/// `PointCluster` use the full-covariance Gaussian density (a single mixture
1198/// component); `Circle` uses the radius/angle generative density below.
1199fn fit_union_component(
1200    group: ArrayView2<'_, f64>,
1201    kind: UnionComponentKind,
1202    config: GaussianMixtureConfig,
1203) -> Result<(f64, usize), String> {
1204    match kind {
1205        UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1206            // A single full-covariance Gaussian is the k=1 mixture: reuse its
1207            // exact rank-aware Laplace evidence so a union component is on the
1208            // identical scale as a mixture component.
1209            if group.nrows() < group.ncols() + 1 {
1210                return Err(format!(
1211                    "union gaussian component needs >= {} rows, got {}",
1212                    group.ncols() + 1,
1213                    group.nrows()
1214                ));
1215            }
1216            let fit = fit_gaussian_mixture(group, 1, config)?;
1217            let nle = fit.laplace_negative_log_evidence(group)?;
1218            Ok((nle, fit.num_free_parameters()))
1219        }
1220        UnionComponentKind::Circle => fit_circle_component_evidence(group, config),
1221    }
1222}
1223
1224/// Rank-aware Laplace negative-log-evidence of a 2-D *circle* component: data is
1225/// modelled as `(r, θ)` with `r ~ N(ρ, σ_r²)` around a fitted center+radius and
1226/// `θ` uniform on the circle. Free parameters: center `(cx, cy)`, radius `ρ`,
1227/// radial variance `σ_r²` — `P = 4`. The angle is an ancillary uniform with no
1228/// free parameter (it carries `−log(2π r)` of density). The Hessian is the
1229/// observed empirical-Fisher `I + Σ s_i s_iᵀ` in `(cx, cy, ρ, log σ_r²)`
1230/// coordinates, fed through the SAME [`laplace_evidence`] entry point.
1231fn fit_circle_component_evidence(
1232    group: ArrayView2<'_, f64>,
1233    config: GaussianMixtureConfig,
1234) -> Result<(f64, usize), String> {
1235    let d = group.ncols();
1236    if d != 2 {
1237        return Err(format!(
1238            "union circle component requires 2-D data, got {d} columns"
1239        ));
1240    }
1241    let n = group.nrows();
1242    let p = 4usize; // cx, cy, radius, radial-variance
1243    if n < p + 1 {
1244        return Err(format!(
1245            "union circle component needs >= {} rows, got {n}",
1246            p + 1
1247        ));
1248    }
1249    // Center = data centroid; radius = mean distance to centroid; radial
1250    // variance = mean squared radial residual (floored). This is the algebraic
1251    // circle-fit optimum for the isotropic radial-Gaussian model and is a pure
1252    // function of the data.
1253    let mut cx = 0.0_f64;
1254    let mut cy = 0.0_f64;
1255    for i in 0..n {
1256        cx += group[[i, 0]];
1257        cy += group[[i, 1]];
1258    }
1259    cx /= n as f64;
1260    cy /= n as f64;
1261    let mut radii = vec![0.0_f64; n];
1262    let mut radius = 0.0_f64;
1263    for i in 0..n {
1264        let dx = group[[i, 0]] - cx;
1265        let dy = group[[i, 1]] - cy;
1266        let r = (dx * dx + dy * dy).sqrt();
1267        radii[i] = r;
1268        radius += r;
1269    }
1270    radius /= n as f64;
1271    let mut var_r = 0.0_f64;
1272    for &r in &radii {
1273        let e = r - radius;
1274        var_r += e * e;
1275    }
1276    var_r = (var_r / n as f64).max(config.covariance_floor);
1277    let inv_var = 1.0 / var_r;
1278    // Total log-likelihood: Σ_i [ −½ log(2π σ_r²) − (r_i−ρ)²/(2σ_r²)
1279    //                             − log(2π r_i) ]  (radial Gaussian × uniform θ).
1280    let mut loglik = 0.0_f64;
1281    let log_2pi = (2.0 * std::f64::consts::PI).ln();
1282    for &r in &radii {
1283        let e = r - radius;
1284        let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e * inv_var;
1285        let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
1286        loglik += radial + angular;
1287    }
1288    // Observed empirical-Fisher in (cx, cy, ρ, s) with s = log σ_r².
1289    // Per-row scores:
1290    //   ∂/∂cx log = (e/σ_r²) · (−dx/r)            (r decreases as center moves +x toward point)
1291    //   ∂/∂cy log = (e/σ_r²) · (−dy/r)
1292    //   ∂/∂ρ  log = e/σ_r²
1293    //   ∂/∂s  log = −½ + e²/(2σ_r²)               (s = log σ_r²)
1294    let mut info = Array2::<f64>::zeros((p, p));
1295    let mut score = [0.0_f64; 4];
1296    for i in 0..n {
1297        let dx = group[[i, 0]] - cx;
1298        let dy = group[[i, 1]] - cy;
1299        let r = radii[i].max(f64::MIN_POSITIVE);
1300        let e = radii[i] - radius;
1301        let ee = e * inv_var;
1302        score[0] = ee * (-dx / r);
1303        score[1] = ee * (-dy / r);
1304        score[2] = ee;
1305        score[3] = -0.5 + 0.5 * e * e * inv_var;
1306        for a in 0..p {
1307            let sa = score[a];
1308            if sa == 0.0 {
1309                continue;
1310            }
1311            for b in 0..p {
1312                info[[a, b]] += sa * score[b];
1313            }
1314        }
1315    }
1316    // Symmetrize and add the unit-information prior ridge `I` (same fixed prior
1317    // as the mixture path) so `log|H|` is well-defined for any `n`.
1318    for a in 0..p {
1319        for b in (a + 1)..p {
1320            let avg = 0.5 * (info[[a, b]] + info[[b, a]]);
1321            info[[a, b]] = avg;
1322            info[[b, a]] = avg;
1323        }
1324        info[[a, a]] += 1.0;
1325    }
1326    let apply_info = |x: &[f64]| -> Vec<f64> {
1327        let mut out = vec![0.0_f64; p];
1328        for r in 0..p {
1329            let mut acc = 0.0_f64;
1330            for c in 0..p {
1331                acc += info[[r, c]] * x[c];
1332            }
1333            out[r] = acc;
1334        }
1335        out
1336    };
1337    let hvp = EvidenceHvpLogDet {
1338        dim: p,
1339        apply: &apply_info,
1340    };
1341    let v = laplace_evidence(EvidenceLogDetSource::Hvp(hvp), 0.0, -loglik, p as f64, 0.0);
1342    if !v.is_finite() {
1343        return Err("union circle component Laplace evidence is not finite".to_string());
1344    }
1345    Ok((v, p))
1346}
1347
1348/// A fitted union component as a *predictive density* (not just an evidence
1349/// scalar): either a full-covariance Gaussian (`Line`/`PointCluster`) or the
1350/// radial-Gaussian×uniform-angle circle density. Carries the mixing weight
1351/// `π_c = row_count_c / n_train` so a union can be evaluated as the soft mixture
1352/// `Σ_c π_c p_c(y)` at held-out rows for cross-class stacking.
1353#[derive(Debug, Clone)]
1354enum UnionComponentDensity {
1355    Gaussian {
1356        log_weight: f64,
1357        eval: GaussianComponentEval,
1358    },
1359    Circle {
1360        log_weight: f64,
1361        center: [f64; 2],
1362        radius: f64,
1363        var_r: f64,
1364    },
1365}
1366
1367impl UnionComponentDensity {
1368    /// `log π_c + log p_c(y)` for one eval row.
1369    fn weighted_log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
1370        match self {
1371            UnionComponentDensity::Gaussian { log_weight, eval } => {
1372                log_weight + eval.log_density(y)
1373            }
1374            UnionComponentDensity::Circle {
1375                log_weight,
1376                center,
1377                radius,
1378                var_r,
1379            } => {
1380                let dx = y[0] - center[0];
1381                let dy = y[1] - center[1];
1382                let r = (dx * dx + dy * dy).sqrt();
1383                let log_2pi = (2.0 * std::f64::consts::PI).ln();
1384                let e = r - radius;
1385                let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e / var_r;
1386                let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
1387                log_weight + radial + angular
1388            }
1389        }
1390    }
1391}
1392
1393/// Fit each union component's *density* on the training rows (hard
1394/// responsibility split) so the composite can be evaluated as the soft mixture
1395/// `Σ_c π_c p_c(y)` at new rows. Mixing weights are the training row shares.
1396fn fit_union_component_densities(
1397    train: ArrayView2<'_, f64>,
1398    structure: UnionStructure,
1399    config: GaussianMixtureConfig,
1400) -> Result<Vec<UnionComponentDensity>, String> {
1401    let comps = structure.components();
1402    let m = comps.len();
1403    let groups = union_responsibility_split(train, m, config)?;
1404    let n_train = train.nrows().max(1) as f64;
1405    let mut out = Vec::with_capacity(m);
1406    for (kind, rows) in comps.iter().zip(groups.iter()) {
1407        if rows.is_empty() {
1408            return Err(format!(
1409                "union {} held-out density: empty component group",
1410                structure.as_str()
1411            ));
1412        }
1413        let log_weight = (rows.len() as f64 / n_train).max(f64::MIN_POSITIVE).ln();
1414        let group = gather_union_rows(train, rows);
1415        match kind {
1416            UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1417                if group.nrows() < group.ncols() + 1 {
1418                    return Err(format!(
1419                        "union gaussian component density needs >= {} rows, got {}",
1420                        group.ncols() + 1,
1421                        group.nrows()
1422                    ));
1423                }
1424                let fit = fit_gaussian_mixture(group.view(), 1, config)?;
1425                let eval = GaussianComponentEval::factor(fit.means.row(0), &fit.covariances[0])?;
1426                out.push(UnionComponentDensity::Gaussian { log_weight, eval });
1427            }
1428            UnionComponentKind::Circle => {
1429                let d = group.ncols();
1430                if d != 2 {
1431                    return Err(format!(
1432                        "union circle component density requires 2-D data, got {d} columns"
1433                    ));
1434                }
1435                let n = group.nrows();
1436                if n < 5 {
1437                    return Err(format!(
1438                        "union circle component density needs >= 5 rows, got {n}"
1439                    ));
1440                }
1441                let mut cx = 0.0_f64;
1442                let mut cy = 0.0_f64;
1443                for i in 0..n {
1444                    cx += group[[i, 0]];
1445                    cy += group[[i, 1]];
1446                }
1447                cx /= n as f64;
1448                cy /= n as f64;
1449                let mut radius = 0.0_f64;
1450                let mut radii = vec![0.0_f64; n];
1451                for i in 0..n {
1452                    let dx = group[[i, 0]] - cx;
1453                    let dy = group[[i, 1]] - cy;
1454                    let r = (dx * dx + dy * dy).sqrt();
1455                    radii[i] = r;
1456                    radius += r;
1457                }
1458                radius /= n as f64;
1459                let mut var_r = 0.0_f64;
1460                for &r in &radii {
1461                    let e = r - radius;
1462                    var_r += e * e;
1463                }
1464                var_r = (var_r / n as f64).max(config.covariance_floor);
1465                out.push(UnionComponentDensity::Circle {
1466                    log_weight,
1467                    center: [cx, cy],
1468                    radius,
1469                    var_r,
1470                });
1471            }
1472        }
1473    }
1474    Ok(out)
1475}
1476
1477/// Per-point held-out log predictive density of a structured-union candidate:
1478/// fit the component densities on `train` and score each row of `eval` as the
1479/// soft mixture `log Σ_c π_c p_c(y)`. This is the cross-class stacking column
1480/// source for a union (the analogue of [`GaussianMixtureFit::per_point_log_density`]).
1481pub fn union_per_point_log_density(
1482    train: ArrayView2<'_, f64>,
1483    eval: ArrayView2<'_, f64>,
1484    structure: UnionStructure,
1485    config: GaussianMixtureConfig,
1486) -> Result<Array1<f64>, String> {
1487    if train.ncols() != eval.ncols() {
1488        return Err(format!(
1489            "union held-out density: train has {} columns, eval has {}",
1490            train.ncols(),
1491            eval.ncols()
1492        ));
1493    }
1494    let densities = fit_union_component_densities(train, structure, config)?;
1495    let mut out = Array1::<f64>::zeros(eval.nrows());
1496    let mut terms = vec![f64::NEG_INFINITY; densities.len()];
1497    for i in 0..eval.nrows() {
1498        let row = eval.row(i);
1499        let mut max_term = f64::NEG_INFINITY;
1500        for (c, dens) in densities.iter().enumerate() {
1501            let lt = dens.weighted_log_density(row);
1502            terms[c] = lt;
1503            if lt > max_term {
1504                max_term = lt;
1505            }
1506        }
1507        out[i] = log_sum_exp(&terms, max_term);
1508    }
1509    Ok(out)
1510}
1511
/// A single fitted model entering a REML/LAML evidence comparison.
#[derive(Clone, Debug)]
pub struct RemlCandidate {
    pub index: usize,
    pub name: String,
    /// Minimised REML/LAML cost; lower is better. This is the model's reported
    /// evidence headline (`Model.evidence`) and is kept verbatim in the score
    /// table.
    pub score: f64,
    pub edf: Option<f64>,
    /// Log-likelihood at the converged mode, on the engine's constants-omitted
    /// scale (same as `gam_inference::model_comparison`). `None` for legacy
    /// payloads that do not carry it.
    pub log_lik: Option<f64>,
    /// Response-family tag (e.g. "gaussian", "gamma", "binomial"). Lets
    /// `compare_reml_fits` REFUSE to rank fits whose REML/LAML scores live on
    /// incomparable base measures — a cross-family comparison is meaningless
    /// (#1384). `None` for legacy payloads that did not record it; those are
    /// not guarded (back-compatible), but every current FFI candidate sets it.
    pub family: Option<String>,
}

impl RemlCandidate {
    /// Cost used to RANK candidates and choose the winner.
    ///
    /// Returns the conditional AIC `−2ℓ + 2·edf` when both `log_lik` and `edf`
    /// are present and finite; otherwise falls back to the raw evidence
    /// headline `score`.
    ///
    /// Why not rank on `score` directly: the REML/LAML evidence does NOT
    /// reliably Occam-penalise an added pure-noise smooth. On `y ~ s(x)` vs
    /// `y ~ s(x) + s(z)` with `z ⟂ y`, the augmented model's evidence comes
    /// out a few nats *lower* (apparently better) on essentially every
    /// dataset, because the Gaussian REML Occam pair `½(log|H| − log|S|₊)`
    /// collapses toward zero for a finite-`λ̂` null term while that term still
    /// spends effective degrees of freedom fitting noise (issue #1362). The
    /// conditional AIC prices exactly those degrees of freedom: it penalises
    /// the noise smooth (Δ ≈ +15) yet still rewards a genuinely relevant
    /// smooth (Δ ≈ −650).
    ///
    /// The reported `score_table` keeps the unaltered evidence (`reml_score`),
    /// so `Model.evidence` / `bayes_factor_vs` stay consistent with the table.
    pub fn ranking_score(&self) -> f64 {
        self.log_lik
            .zip(self.edf)
            .filter(|(log_lik, edf)| log_lik.is_finite() && edf.is_finite())
            .map_or(self.score, |(log_lik, edf)| -2.0 * log_lik + 2.0 * edf)
    }
}
1561
1562#[derive(Clone, Debug)]
1563pub struct RemlComparison {
1564    pub ranking: Vec<RankedRow>,
1565    pub winner: String,
1566    pub evidence_summary: String,
1567    pub score_table: Vec<ScoreRow>,
1568}
1569
/// One candidate's entry in the ranking produced by `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RankedRow {
    /// Candidate name, copied from the input candidate.
    pub name: String,
    /// Raw REML/LAML evidence of the candidate (not the ranking cost), kept so
    /// the row stays consistent with `Model.evidence`.
    pub score: f64,
    /// Cost gap from the winning model, measured on the SAME scale that
    /// orders the ranking (`ranking_score`: the Occam-penalised conditional
    /// AIC where available, issue #1362). Because the winner is
    /// `argmin ranking_score`, the gap is `>= 0` for every row by
    /// construction and can never contradict the declared winner
    /// (issue #1465).
    pub delta: f64,
    /// Bayes factor of the winner over this row on the ranking scale,
    /// `exp(delta) >= 1` (issue #1465).
    pub bayes_factor: f64,
    /// Effective degrees of freedom of the candidate, when it was recorded.
    pub edf: Option<f64>,
}
1586
/// One candidate's entry in the raw-evidence table produced by
/// `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct ScoreRow {
    /// Candidate name, copied from the input candidate.
    pub name: String,
    /// Raw REML/LAML evidence score of the candidate.
    pub reml_score: f64,
    /// Gap between this row's raw score and the smallest raw score among all
    /// compared candidates.
    pub delta_reml: f64,
    /// `exp(delta_reml)`: Bayes factor of the best raw-evidence model over
    /// this row.
    pub bayes_factor_best_over_model: f64,
    /// Effective degrees of freedom of the candidate, when it was recorded.
    pub effective_dof: Option<f64>,
}
1595
/// Log Bayes factor of model `a` over model `b`, given their minimised
/// REML/LAML costs.
///
/// Costs are "lower is better", so the result is positive when `a` has the
/// smaller cost: it is simply `reml_score_b - reml_score_a`.
#[inline]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    let (favoured_cost, rival_cost) = (reml_score_a, reml_score_b);
    rival_cost - favoured_cost
}
1601
1602/// Compare fitted models by the single evidence ordering contract used by
1603/// topology ranking and seed screening: lower finite cost wins, with stable
1604/// original-order tie handling.
1605pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
1606    if candidates.is_empty() {
1607        return Err("compare_models requires at least one fit".to_string());
1608    }
1609    // Fail-loud comparability guard (#1384): REML/LAML evidence scores are only
1610    // comparable across fits of the SAME response family — a Gaussian score and
1611    // a Gamma score live on different log-density base measures, so their
1612    // difference is not a Bayes factor. Ranking them anyway returns a confident
1613    // but meaningless winner. Refuse when two candidates carry DIFFERENT family
1614    // tags. Candidates with no family tag (`None`, legacy payloads) are not
1615    // constrained, so this never spuriously rejects an older saved model.
1616    {
1617        let mut seen_family: Option<&str> = None;
1618        for cand in &candidates {
1619            if let Some(fam) = cand.family.as_deref() {
1620                match seen_family {
1621                    None => seen_family = Some(fam),
1622                    Some(prev) if prev != fam => {
1623                        return Err(format!(
1624                            "compare_models: cannot compare fits of different response families                              ('{prev}' vs '{fam}'); their REML/LAML evidence scores are on                              incomparable base measures. Compare models fit to the same response                              under the same family."
1625                        ));
1626                    }
1627                    Some(_) => {}
1628                }
1629            }
1630        }
1631    }
1632    candidates = rank_priority_candidates(
1633        candidates
1634            .into_iter()
1635            .enumerate()
1636            .map(|(idx, row)| {
1637                // Rank/winner on the Occam-penalised conditional AIC where it is
1638                // available (issue #1362); falls back to the raw evidence score.
1639                let ranking = row.ranking_score();
1640                PriorityCandidate::new(row, idx, ranking, 0)
1641            })
1642            .collect(),
1643    )
1644    .into_iter()
1645    .map(|row| row.item)
1646    .collect();
1647
1648    let winner = candidates[0].name.clone();
1649    // The ranking `delta` / `bayes_factor` must be measured on the SAME scale
1650    // that orders the table — the `ranking_score` (Occam-penalised conditional
1651    // AIC where available, issue #1362). `candidates[0]` is the winner =
1652    // `argmin ranking_score`, so its ranking score IS the minimum; every row's
1653    // ranking-scale gap is then `>= 0` and its Bayes factor `>= 1`, never
1654    // contradicting the declared winner (issue #1465). Computing these against
1655    // the AIC winner's *raw REML* — which is not the minimum raw REML once AIC
1656    // and REML disagree — produced negative deltas and Bayes factors < 1 for
1657    // non-winner rows.
1658    let best_ranking_score = candidates[0].ranking_score();
1659    // The raw-REML `score_table` stays on the raw evidence scale (consistent
1660    // with `Model.evidence` / `bayes_factor_vs`), but is referenced to the
1661    // genuine minimum raw REML so its best-over-model Bayes factors are also
1662    // coherent (`>= 1`), rather than to whichever row happens to sit at index 0.
1663    let best_raw_score = candidates
1664        .iter()
1665        .map(|c| c.score)
1666        .fold(f64::INFINITY, f64::min);
1667    let mut ranking = Vec::with_capacity(candidates.len());
1668    let mut score_table = Vec::with_capacity(candidates.len());
1669    for row in &candidates {
1670        let delta = log_bayes_factor(best_ranking_score, row.ranking_score());
1671        let bayes_factor = delta.exp();
1672        let delta_reml = log_bayes_factor(best_raw_score, row.score);
1673        ranking.push(RankedRow {
1674            name: row.name.clone(),
1675            score: row.score,
1676            delta,
1677            bayes_factor,
1678            edf: row.edf,
1679        });
1680        score_table.push(ScoreRow {
1681            name: row.name.clone(),
1682            reml_score: row.score,
1683            delta_reml,
1684            bayes_factor_best_over_model: delta_reml.exp(),
1685            effective_dof: row.edf,
1686        });
1687    }
1688    // The winner is decided by `ranking_score` (the Occam-penalised conditional
1689    // AIC where available, issue #1362), which can disagree in sign with the raw
1690    // evidence Bayes factor for a noise-augmented model. Summarise the actual
1691    // decision margin so the headline never contradicts the chosen winner.
1692    let evidence_summary = if let Some(runner_up) = candidates.get(1) {
1693        let margin = runner_up.ranking_score() - candidates[0].ranking_score();
1694        format!(
1695            "{} wins by Bayes factor {} over {}",
1696            winner,
1697            format_bayes_factor(margin),
1698            runner_up.name
1699        )
1700    } else {
1701        format!("{winner} (single fit; no comparison)")
1702    };
1703    Ok(RemlComparison {
1704        ranking,
1705        winner,
1706        evidence_summary,
1707        score_table,
1708    })
1709}
1710
1711pub fn format_bayes_factor(log_bf: f64) -> String {
1712    if !log_bf.is_finite() {
1713        return "inf".to_string();
1714    }
1715    if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
1716        return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
1717    }
1718    format_three_significant(log_bf.exp())
1719}
1720
/// Format `value` with three significant digits.
///
/// * `0` (either sign) prints as `"0"`; non-finite values print with their
///   default `Display` form.
/// * Magnitudes of `1000` and above use scientific notation with two
///   decimals (`1.23e3`).
/// * Smaller magnitudes use fixed notation with as many decimals as are
///   needed for three significant digits (`12.3`, `1.50`, `0.00123`).
///
/// Rounding that carries into the next decade keeps three significant digits
/// (`9.996` → `10.0`, `999.6` → `1.00e3`). Magnitudes so small that the
/// fixed-notation scale factor overflows fall back to scientific notation.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let exponent = value.abs().log10().floor() as i32;
    if exponent >= 3 {
        return format!("{value:.2e}");
    }
    // `exponent <= 2` here, so `2 - exponent` is never negative.
    let mut decimals = (2 - exponent).max(0) as usize;
    let scale = 10f64.powi(decimals as i32);
    if !scale.is_finite() {
        // Magnitude below ~1e-306: `10^decimals` overflows, so fixed notation
        // cannot be computed. Scientific notation still has three digits.
        return format!("{value:.2e}");
    }
    // `scaled` is the three-digit integer mantissa, nominally in 100..=999.
    let scaled = (value * scale).abs().round();
    let rounded = scaled / scale * value.signum();
    if scaled >= 1000.0 {
        // Rounding carried into the next decade: one fewer decimal is needed,
        // or scientific notation once the magnitude reaches 1000.
        if exponent + 1 >= 3 {
            return format!("{rounded:.2e}");
        }
        decimals -= 1;
    }
    format!("{rounded:.decimals$}")
}
1737
1738impl Default for TopologySelectOptions {
1739    fn default() -> Self {
1740        Self {
1741            tie_tolerance: 1e-3,
1742            score_scale: TopologyScoreScale::PerObservation,
1743        }
1744    }
1745}
1746
1747// ---------------------------------------------------------------------------
1748// Laplace evidence
1749// ---------------------------------------------------------------------------
1750
1751/// Single canonical Laplace evidence at the inner-loop fixed point.
1752///
1753/// Returns negative log evidence:
1754///
1755/// ```text
1756/// V(ρ, T) = F(β*, u*; ρ, T)
1757///         + 0.5 log|H|
1758///         - 0.5 log|S_pen(ρ)|+
1759///         - 0.5 (dim(H) - rank(S_pen)) log(2π).
1760/// ```
1761///
1762/// The last term is the rank-aware Tierney-Kadane normalizer:
1763/// `log p(y|T) ≈ -V`, with `0.5 log|2πH⁻¹| - 0.5 log|2πS⁻¹|`.
1764///
1765/// The `H` log-determinant is computed from the arrow factorization
1766///
1767/// ```text
1768/// log|H| = Σ_i log|H_uu_i| + log|A|
1769/// ```
1770///
1771/// (proposal §3.4 / §7) using the **undamped** per-row Cholesky factors
1772/// `cache.htt_factors_undamped` and the **undamped** Schur factor.
1773///
1774/// `penalty_log_det` is `log|S_pen(ρ)|+` — the prior penalty
1775/// pseudo-logdet from `crate::reml::penalty_logdet` (proposal
1776/// §3.6). It must NOT be confused with the arrow Schur log-det, which
1777/// this function recomputes internally from `logdet_source`.
1778///
1779/// `residual_objective` is `F(β*, u*; ρ, T)` at the inner optimum. The
1780/// envelope theorem (proposal §3.2) makes this the only `F`-related
1781/// contribution.
1782///
1783/// `effective_dim` is `dim(H)` after constraints/projections and
1784/// `penalty_rank` is `rank(S_pen)`. Their difference is the unpenalized
1785/// nullspace dimension that remains in the Laplace integral.
1786///
1787/// # Errors
1788///
1789/// Returns `f64::NAN` if the exact factor path is incoherent and no HVP
1790/// fallback is supplied, or if the supplied dimensions are non-finite.
1791pub fn laplace_evidence(
1792    logdet_source: EvidenceLogDetSource<'_>,
1793    penalty_log_det: f64,
1794    residual_objective: f64,
1795    effective_dim: f64,
1796    penalty_rank: f64,
1797) -> f64 {
1798    if !(effective_dim.is_finite() && penalty_rank.is_finite()) {
1799        return f64::NAN;
1800    }
1801    let log_det_h = match evidence_hessian_log_det(logdet_source) {
1802        Ok(v) => v,
1803        Err(_) => return f64::NAN,
1804    };
1805    let null_dim = effective_dim - penalty_rank;
1806    if !null_dim.is_finite() || null_dim < -1e-9 {
1807        return f64::NAN;
1808    }
1809    residual_objective + 0.5 * log_det_h
1810        - 0.5 * penalty_log_det
1811        - 0.5 * null_dim.max(0.0) * (2.0 * std::f64::consts::PI).ln()
1812}
1813
1814/// Compute the Hessian logdet from exact arrow factors or an HVP fallback.
1815pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
1816    match source {
1817        EvidenceLogDetSource::FactoredArrow {
1818            cache,
1819            fallback_hvp,
1820        } => match arrow_log_det_from_cache(cache) {
1821            Some(v) => Ok(v),
1822            None => match fallback_hvp {
1823                Some(hvp) => hessian_log_det_from_hvp(hvp),
1824                None => {
1825                    Err("evidence Hessian logdet requires exact factors or HVP fallback".into())
1826                }
1827            },
1828        },
1829        EvidenceLogDetSource::Hvp(hvp) => hessian_log_det_from_hvp(hvp),
1830    }
1831}
1832
1833/// Log determinant of an SPD operator supplied by HVP callback.
1834///
1835/// The dispatch boundary intentionally matches
1836/// `ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD` in `terms::analytic_penalties`:
1837/// small operators are materialized and diagonalized exactly; larger ones use
1838/// Rademacher stochastic Lanczos quadrature.
1839pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
1840    if hvp.dim == 0 {
1841        return Ok(0.0);
1842    }
1843    if hvp.dim <= ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
1844        let mut dense = Array2::<f64>::zeros((hvp.dim, hvp.dim));
1845        let mut basis = vec![0.0_f64; hvp.dim];
1846        for j in 0..hvp.dim {
1847            basis[j] = 1.0;
1848            let col = (hvp.apply)(&basis);
1849            basis[j] = 0.0;
1850            if col.len() != hvp.dim || col.iter().any(|v| !v.is_finite()) {
1851                return Err(format!(
1852                    "evidence HVP logdet expected finite column of length {}, got {}",
1853                    hvp.dim,
1854                    col.len()
1855                ));
1856            }
1857            for i in 0..hvp.dim {
1858                dense[[i, j]] = col[i];
1859            }
1860        }
1861        validate_dense_hvp_symmetry(&dense)?;
1862        for i in 0..hvp.dim {
1863            for j in (i + 1)..hvp.dim {
1864                let avg = 0.5 * (dense[[i, j]] + dense[[j, i]]);
1865                dense[[i, j]] = avg;
1866                dense[[j, i]] = avg;
1867            }
1868        }
1869        dense_spd_log_det(&dense)
1870    } else {
1871        stochastic_hvp_log_det(hvp)
1872    }
1873}
1874
1875fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
1876    if matrix.nrows() != matrix.ncols() {
1877        return Err(format!(
1878            "evidence dense logdet requires square matrix, got {}x{}",
1879            matrix.nrows(),
1880            matrix.ncols()
1881        ));
1882    }
1883    if gam_gpu::cuda_selected() {
1884        return crate::gpu::reml_gpu::evidence_derivatives_gpu(
1885            crate::gpu::reml_gpu::RemlGpuInput {
1886                penalized_hessian: matrix.view(),
1887                derivative_hessians: Vec::new(),
1888            },
1889        )
1890        .map(|evidence| evidence.logdet_hessian);
1891    }
1892    let (evals, _) = matrix
1893        .eigh(Side::Lower)
1894        .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
1895    let mut logdet = 0.0_f64;
1896    for (idx, &ev) in evals.iter().enumerate() {
1897        if !ev.is_finite() || ev <= 0.0 {
1898            return Err(format!(
1899                "evidence dense logdet expected SPD Hessian, eigenvalue {idx} is {ev:.3e}"
1900            ));
1901        }
1902        logdet += ev.ln();
1903    }
1904    Ok(logdet)
1905}
1906
1907fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
1908    let n = matrix.nrows();
1909    let mut norm_sq = 0.0_f64;
1910    for &value in matrix.iter() {
1911        norm_sq += value * value;
1912    }
1913
1914    let mut skew_sq = 0.0_f64;
1915    for i in 0..n {
1916        for j in (i + 1)..n {
1917            let skew = matrix[[i, j]] - matrix[[j, i]];
1918            skew_sq += 2.0 * skew * skew;
1919        }
1920    }
1921
1922    let rel_skew = skew_sq.sqrt() / norm_sq.sqrt().max(1.0);
1923    if !rel_skew.is_finite() || rel_skew > EVIDENCE_HVP_SYMMETRY_REL_TOL {
1924        return Err(format!(
1925            "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
1926        ));
1927    }
1928    Ok(())
1929}
1930
1931fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
1932    let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
1933    for probe in 0..EVIDENCE_HVP_SYMMETRY_PROBES.max(1) {
1934        let mut x = vec![0.0_f64; hvp.dim];
1935        let mut y = vec![0.0_f64; hvp.dim];
1936        rademacher_unit_probe_into_slice(&mut x, (2 * probe) as u64, inv_norm);
1937        rademacher_unit_probe_into_slice(&mut y, (2 * probe + 1) as u64, inv_norm);
1938
1939        let hx = (hvp.apply)(&x);
1940        let hy = (hvp.apply)(&y);
1941        if hx.len() != hvp.dim || hx.iter().any(|v| !v.is_finite()) {
1942            return Err(format!(
1943                "evidence HVP symmetry check expected finite vector of length {}, got {}",
1944                hvp.dim,
1945                hx.len()
1946            ));
1947        }
1948        if hy.len() != hvp.dim || hy.iter().any(|v| !v.is_finite()) {
1949            return Err(format!(
1950                "evidence HVP symmetry check expected finite vector of length {}, got {}",
1951                hvp.dim,
1952                hy.len()
1953            ));
1954        }
1955
1956        let lhs = dot_slice(&x, &hy);
1957        let rhs = dot_slice(&hx, &y);
1958        let scale = (norm2_slice(&hx) * norm2_slice(&y))
1959            .max(norm2_slice(&hy) * norm2_slice(&x))
1960            .max(lhs.abs())
1961            .max(rhs.abs())
1962            .max(1.0);
1963        let rel = (lhs - rhs).abs() / scale;
1964        if !rel.is_finite() || rel > EVIDENCE_HVP_SYMMETRY_REL_TOL {
1965            return Err(format!(
1966                "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {rel:.3e}"
1967            ));
1968        }
1969    }
1970    Ok(())
1971}
1972
1973fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
1974    validate_hvp_randomized_symmetry(hvp)?;
1975    let probes = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
1976    let steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(hvp.dim).max(1);
1977    let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
1978    let mut estimate = 0.0_f64;
1979    for probe in 0..probes {
1980        let mut q0 = vec![0.0_f64; hvp.dim];
1981        rademacher_unit_probe_into_slice(&mut q0, probe as u64, inv_norm);
1982        let quad = lanczos_log_quadrature_hvp(hvp, q0, steps)?;
1983        estimate += hvp.dim as f64 * quad;
1984    }
1985    Ok(estimate / probes as f64)
1986}
1987
1988fn lanczos_log_quadrature_hvp(
1989    hvp: EvidenceHvpLogDet<'_>,
1990    q: Vec<f64>,
1991    max_steps: usize,
1992) -> Result<f64, String> {
1993    let n = hvp.dim;
1994    let eigen = symmetric_lanczos_eigenpairs(
1995        n,
1996        &q,
1997        SymmetricLanczosOptions {
1998            max_steps,
1999            residual_tol: 1e-12,
2000            local_reorthogonalize: false,
2001            full_reorthogonalize: false,
2002        },
2003        |q, out| {
2004            let applied = (hvp.apply)(q);
2005            if applied.len() != n || applied.iter().any(|v| !v.is_finite()) {
2006                return Err(format!(
2007                    "evidence HVP SLQ expected finite vector of length {n}, got {}",
2008                    applied.len()
2009                ));
2010            }
2011            out.copy_from_slice(&applied);
2012            Ok(())
2013        },
2014    )
2015    .map_err(|e| format!("evidence HVP SLQ Lanczos failed: {e}"))?;
2016    symmetric_lanczos_log_quadrature(&eigen, "evidence HVP SLQ expected SPD Hessian")
2017}
2018
/// Dot product of two equal-length slices, accumulated left to right.
///
/// # Panics
///
/// Panics when the slices differ in length.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |sum, (x, y)| sum + x * y)
}
2028
/// Euclidean norm of a slice: the square root of the left-to-right sum of
/// squares.
#[inline]
fn norm2_slice(a: &[f64]) -> f64 {
    a.iter()
        .fold(0.0_f64, |sum_sq, v| sum_sq + v * v)
        .sqrt()
}
2033
2034fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
2035    let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
2036    let mut bits = 0_u64;
2037    let mut remaining_bits = 0_u32;
2038    for value in z.iter_mut() {
2039        if remaining_bits == 0 {
2040            bits = splitmix64(&mut state);
2041            remaining_bits = 64;
2042        }
2043        *value = if bits & 1 == 0 { scale } else { -scale };
2044        bits >>= 1;
2045        remaining_bits -= 1;
2046    }
2047}
2048
2049#[inline]
2050const fn splitmix64(state: &mut u64) -> u64 {
2051    gam_linalg::utils::splitmix64(state)
2052}
2053
2054/// Sum of per-row arrow log-determinants plus the Schur log-det.
2055///
2056/// `log|H| = Σ_i log|H_uu_i| + log|A|` using the undamped Cholesky
2057/// factors of `H_uu_i` and the cached Schur Cholesky factor.
2058///
2059/// Returns `None` if `cache.schur_factor` is absent (InexactPCG path) or
2060/// if a damped/incoherent cache is supplied. [`evidence_hessian_log_det`]
2061/// routes such matrix-free cases to an explicit HVP fallback.
2062pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
2063    if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2064        // Per proposal §6.4 / §6.5 — evidence must use the undamped
2065        // operator. The cache's Schur factor here was assembled under
2066        // ridge damping, which is a different operator. Reject loudly.
2067        return None;
2068    }
2069    if let Some(log_det) = cache.joint_hessian_log_det {
2070        return log_det.is_finite().then_some(log_det);
2071    }
2072    // A `k == 0` cache has no shared β block, so the dense Direct path forms no
2073    // reduced Schur complement and `schur_factor` is legitimately `None` (the
2074    // joint Hessian is block-diagonal in the latent rows). Its log-det is the
2075    // per-row sum with no Schur term. Only reject when `k > 0` and the factor
2076    // is absent — the InexactPCG case that never built the dense `K×K` factor.
2077    // (#1132 euclidean K=4: a β-profiled atom reaches here with `k == 0`.)
2078    let schur = match cache.schur_factor.as_ref() {
2079        Some(schur) => Some(schur),
2080        None if cache.k == 0 => None,
2081        None => return None,
2082    };
2083
2084    let mut acc = 0.0_f64;
2085    // Per-row arrow blocks: log|H_uu_i| = 2 Σ log diag(L_i).
2086    for l in cache.undamped_factors_iter() {
2087        acc += 2.0 * log_det_from_chol_lower(l);
2088    }
2089    // Schur block: log|A| = 2 Σ log diag(L_schur). Empty for the `k == 0` case.
2090    if let Some(schur) = schur {
2091        acc += 2.0 * log_det_from_chol_lower(schur.view());
2092    }
2093    // #1038 cross-row IBP: when the cache carries an exact rank-`R` Woodbury,
2094    // the per-row + Schur factors above are of the NO-SELF base `H₀'`, so the
2095    // exact `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`. The
2096    // correction is zero (no-op) for every non-IBP cache.
2097    let woodbury_correction = cache.cross_row_woodbury_log_det();
2098    if !woodbury_correction.is_finite() {
2099        // A non-PD capacitance (negative determinant) is a value↔gradient
2100        // desync the evidence must reject loudly, not paper over.
2101        return None;
2102    }
2103    acc += woodbury_correction;
2104    Some(acc)
2105}
2106
2107/// Twice-the-diagonal-log sum for a lower-triangular Cholesky factor.
2108fn log_det_from_chol_lower(l: ArrayView2<'_, f64>) -> f64 {
2109    let n = l.nrows();
2110    let mut acc = 0.0_f64;
2111    for i in 0..n {
2112        let d = l[[i, i]];
2113        if d > 0.0 {
2114            acc += d.ln();
2115        } else {
2116            // SAFETY: a valid lower-triangular Cholesky factor has a strictly
2117            // positive diagonal by construction. A non-positive diagonal means
2118            // the caller passed a corrupted / non-SPD factor — surface it loudly
2119            // rather than papering over with a corrupting NaN that silently
2120            // poisons the evidence log-det (callers do not check is_nan).
2121            panic!(
2122                "log_det_from_chol_lower: non-positive Cholesky diagonal {d} at index {i}; \
2123                 caller passed a corrupted or non-SPD factor"
2124            );
2125        }
2126    }
2127    acc
2128}
2129
2130// ---------------------------------------------------------------------------
2131// IFT cascade: ∂u*/∂β → ∂β*/∂ρ → ∂u*/∂ρ
2132// ---------------------------------------------------------------------------
2133
2134/// Tier-1 IFT sensitivity `∂u_i*/∂β = -H_uu_i⁻¹ H_uβ_i`.
2135///
2136/// Concatenated row-major to a single `(N·d) × K` dense matrix. Each
2137/// row block is solved with the **undamped** Cholesky factor. Proposal
2138/// §2.2 / §7.
2139pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
2140    let n = cache.undamped_factor_count();
2141    let total_len = cache.delta_t_len();
2142    let k = cache.k;
2143    if !cache.htbeta_available() {
2144        return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2145    }
2146    let mut out = Array2::<f64>::zeros((total_len, k));
2147    let mut beta_basis = Array1::<f64>::zeros(k);
2148    // Allocate scratch at max_d; per-row slice is ..di.
2149    let mut rhs = Array1::<f64>::zeros(cache.d);
2150    for i in 0..n {
2151        let di = cache.row_dims[i];
2152        let row_base = cache.row_offsets[i];
2153        let factor = cache.undamped_factor(i);
2154        // Solve H_uu_i Y = H_uβ_i column by column.
2155        for col in 0..k {
2156            beta_basis.fill(0.0);
2157            beta_basis[col] = 1.0;
2158            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
2159            // The Tier-2 IFT assembler is built only when the family's
2160            // capability surface promises cached `H_tβ` row products.
2161            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
2162                // SAFETY: reaching `false` means a family declared the cache
2163                // available but failed to populate it — contract violation.
2164                return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2165            }
2166            let y = cholesky_solve_vector(factor, &rhs_i);
2167            for c in 0..di {
2168                out[[row_base + c, col]] = -y[c];
2169            }
2170        }
2171    }
2172    out
2173}
2174
2175/// Coupling components of a symmetric coefficient Hessian: the connected
2176/// components of the graph whose vertices are coefficient indices `0..p` and
2177/// whose edges are the structurally nonzero off-diagonal entries of `H` (#779).
2178///
2179/// Returns a length-`p` vector of component labels in `0..num_components`,
2180/// where two indices share a label iff they are connected through a chain of
2181/// nonzero `H[i,j]` couplings. This is the exact structural partition the
2182/// cone-of-influence sensitivity reuse is keyed on: a smoothing-parameter move
2183/// whose stationarity-gradient derivative `∂g/∂ρ` is supported only inside one
2184/// component can change `β = -H⁻¹ ∂g/∂ρ` only inside that same component, so
2185/// the sensitivity of every *other* component is provably unchanged and may be
2186/// reused unrecomputed (lazy/local propagation).
2187///
2188/// The nonzero test is exact (`!= 0.0`), matching the structural-coupling gate
2189/// used elsewhere for the joint inner Hessian: a tolerance would risk dropping a
2190/// genuine (small) coupling edge and silently biasing the propagated sensitivity
2191/// — the failure mode #779/#740 explicitly guard against. A block-diagonal `H`
2192/// yields the all-singletons partition (one component per block-decoupled
2193/// coordinate); a fully coupled `H` yields a single component (no shortcut, the
2194/// full joint solve is required — and is what the non-coned path performs).
2195pub fn coupling_components(hessian: ArrayView2<'_, f64>) -> Vec<usize> {
2196    let p = hessian.nrows();
2197    if p == 0 || hessian.ncols() != p {
2198        return Vec::new();
2199    }
2200    // Union-find with path compression and union by size.
2201    let mut parent: Vec<usize> = (0..p).collect();
2202    let mut size: Vec<usize> = vec![1; p];
2203
2204    fn find(parent: &mut [usize], mut x: usize) -> usize {
2205        while parent[x] != x {
2206            parent[x] = parent[parent[x]];
2207            x = parent[x];
2208        }
2209        x
2210    }
2211
2212    for i in 0..p {
2213        for j in (i + 1)..p {
2214            // Symmetric structure: an edge exists if either triangle is nonzero,
2215            // so a numerically one-sided fill still couples the two indices.
2216            if hessian[[i, j]] != 0.0 || hessian[[j, i]] != 0.0 {
2217                let (ri, rj) = (find(&mut parent, i), find(&mut parent, j));
2218                if ri != rj {
2219                    let (small, large) = if size[ri] < size[rj] {
2220                        (ri, rj)
2221                    } else {
2222                        (rj, ri)
2223                    };
2224                    parent[small] = large;
2225                    size[large] += size[small];
2226                }
2227            }
2228        }
2229    }
2230
2231    // Relabel roots to a dense `0..num_components` range, preserving
2232    // first-seen order so labels are deterministic.
2233    let mut label_of_root: Vec<Option<usize>> = vec![None; p];
2234    let mut next_label = 0usize;
2235    let mut labels = vec![0usize; p];
2236    for idx in 0..p {
2237        let root = find(&mut parent, idx);
2238        let label = match label_of_root[root] {
2239            Some(l) => l,
2240            None => {
2241                let l = next_label;
2242                label_of_root[root] = Some(l);
2243                next_label += 1;
2244                l
2245            }
2246        };
2247        labels[idx] = label;
2248    }
2249    labels
2250}
2251
/// The cone of influence of a single stationarity-gradient derivative column
/// whose support (the coefficient indices where `∂g/∂ρ_k` is nonzero) lies in
/// `support`: every coefficient index whose component label matches the label
/// of at least one support index, given precomputed `labels` from
/// [`coupling_components`].
///
/// The result lists indices in ascending order. Support indices outside
/// `0..labels.len()` are ignored.
///
/// `β_k = -H⁻¹ ∂g/∂ρ_k` is exactly zero outside this cone, so a confined solve
/// (or reuse of a cached zero) is exact, not an approximation. An empty support
/// (a structurally inactive `ρ_k`, e.g. a rank-0 or out-of-range penalty block)
/// yields an empty cone: the sensitivity is identically zero and no solve is
/// needed at all.
pub fn cone_of_influence(labels: &[usize], support: &[usize]) -> Vec<usize> {
    // Distinct component labels touched by the support.
    let touched: std::collections::BTreeSet<usize> = support
        .iter()
        .filter_map(|&idx| labels.get(idx).copied())
        .collect();
    if touched.is_empty() {
        return Vec::new();
    }
    labels
        .iter()
        .enumerate()
        .filter_map(|(idx, label)| touched.contains(label).then_some(idx))
        .collect()
}
2279
2280/// Tier-2 IFT sensitivity `∂β*/∂ρ = -A⁻¹ ∂g_red/∂ρ` (proposal §2.4 /
2281/// §7).
2282///
2283/// `dg_red_drho` is the `K × R` matrix whose `a`-th column is `q_a =
2284/// ∂g_red/∂ρ_a`. Returns the `K × R` matrix `β_ρ`.
2285///
2286/// Returns `None` if the Schur factor is unavailable (PCG mode) or was
2287/// built from a damped operator, or if any solved entry is non-finite;
2288/// callers must not silently substitute an approximation. The solve is
2289/// the one sensitivity operator (#935) — this site holds no private H⁻¹
2290/// convention of its own.
2291pub fn ift_dbeta_drho(
2292    cache: &ArrowFactorCache,
2293    dg_red_drho: ArrayView2<'_, f64>,
2294) -> Option<Array2<f64>> {
2295    if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2296        return None;
2297    }
2298    let schur = cache.schur_factor.as_ref()?;
2299    if dg_red_drho.nrows() != cache.k || schur.nrows() != cache.k {
2300        return None;
2301    }
2302    crate::sensitivity::FitSensitivity::from_lower_triangular(schur)
2303        .mode_response(dg_red_drho)
2304}
2305
2306
2307// ---------------------------------------------------------------------------
2308// ∂V/∂ρ — analytic optimized-evidence gradient via IFT mode response
2309// ---------------------------------------------------------------------------
2310
2311/// IFT terms needed to differentiate the optimized Laplace evidence through
2312/// the fitted mode `(β*(ρ), u*(ρ))`.
2313///
2314/// For each hyperparameter `ρ_a`, the correction added to the direct trace is
2315///
2316/// ```text
2317/// F_β · β_a + F_u · u_a
2318/// + 0.5 (∂_β log|H| · β_a + ∂_u log|H| · u_a).
2319/// ```
2320///
2321/// At an exact KKT point the value-gradient pieces are zero, but they are
2322/// explicit here so the exported gradient matches the optimized objective
2323/// whenever callers carry a certified nonzero residual correction.
2324#[derive(Clone)]
2325pub struct EvidenceIftGradientTerms<'a> {
2326    pub dbeta_drho: ArrayView2<'a, f64>,
2327    pub du_drho: ArrayView2<'a, f64>,
2328    pub value_beta: ArrayView1<'a, f64>,
2329    pub value_u: ArrayView1<'a, f64>,
2330    pub logdet_h_beta: ArrayView1<'a, f64>,
2331    pub logdet_h_u: ArrayView1<'a, f64>,
2332}
2333
2334/// Contract the IFT mode-response columns into the optimized-evidence
2335/// gradient correction.
2336pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
2337    let k = terms.dbeta_drho.nrows();
2338    let nd = terms.du_drho.nrows();
2339    let r = terms.dbeta_drho.ncols();
2340    if terms.du_drho.ncols() != r
2341        || terms.value_beta.len() != k
2342        || terms.logdet_h_beta.len() != k
2343        || terms.value_u.len() != nd
2344        || terms.logdet_h_u.len() != nd
2345    {
2346        return Array1::<f64>::from_elem(r, f64::NAN);
2347    }
2348
2349    let mut out = Array1::<f64>::zeros(r);
2350    for a in 0..r {
2351        let mut acc = 0.0_f64;
2352        for j in 0..k {
2353            let mode = terms.dbeta_drho[[j, a]];
2354            acc += terms.value_beta[j] * mode;
2355            acc += 0.5 * terms.logdet_h_beta[j] * mode;
2356        }
2357        for j in 0..nd {
2358            let mode = terms.du_drho[[j, a]];
2359            acc += terms.value_u[j] * mode;
2360            acc += 0.5 * terms.logdet_h_u[j] * mode;
2361        }
2362        out[a] = acc;
2363    }
2364    out
2365}
2366
2367/// Per-`ρ` optimized-evidence gradient (proposal §3.7 / §3.8 split):
2368///
2369/// ```text
2370/// ∂V/∂ρ_a =
2371///       F_{ρ_a}                                  (value part)
2372///   + 0.5 tr(H⁻¹ H_{ρ_a})                        (direct Hessian)
2373///   + F_x · x_{ρ_a}
2374///   + 0.5 (∂_x log|H|) · x_{ρ_a}                 (IFT mode response)
2375///   - 0.5 tr(S_pen⁺ S_{pen,ρ_a})                 (penalty pseudo-logdet)
2376/// ```
2377/// where `x = (β, u)`.
2378///
2379/// The `tr(H⁻¹ H_{ρ_a})` trace is computed via the arrow structure
2380/// (proposal §3.5 / §3.10):
2381///
2382/// ```text
2383/// tr(H⁻¹ H_{ρ_a}) = Σ_i tr(H_uu_i⁻¹ ∂_{ρ_a} H_uu_i) + tr(A⁻¹ ∂_{ρ_a} A).
2384/// ```
2385///
2386/// `value_rho[a] = F_{ρ_a}` (envelope theorem, proposal §3.2).
2387/// `huu_drho[i][a]` is `∂H_uu_i/∂ρ_a` as a `d × d` matrix.
2388/// `hbb_drho[a]` is `∂H_ββ/∂ρ_a` as a `K × K` matrix.
2389/// `htbeta_drho[i][a]` is `∂H_uβ_i/∂ρ_a` as a `d × K` matrix.
2390/// `pen_logdet_drho[a]` is `∂_{ρ_a} log|S_pen|+`.
2391/// `ift_terms` carries `∂β*/∂ρ`, `∂u*/∂ρ`, and the already-contracted
2392/// mode derivatives of `F` and `log|H|`.
2393///
2394/// Returns the per-`ρ` gradient. Returns a NaN-filled vector when the
2395/// cache has no undamped Schur factor (PCG mode).
2396pub fn evidence_grad_rho(
2397    cache: &ArrowFactorCache,
2398    value_rho: ArrayView1<'_, f64>,
2399    huu_drho: &[Vec<Array2<f64>>],
2400    htbeta_drho: &[Vec<Array2<f64>>],
2401    hbb_drho: &[Array2<f64>],
2402    pen_logdet_drho: ArrayView1<'_, f64>,
2403    ift_terms: EvidenceIftGradientTerms<'_>,
2404) -> Array1<f64> {
2405    let r = value_rho.len();
2406    let n = cache.undamped_factor_count();
2407    let k = cache.k;
2408    let mut out = Array1::<f64>::zeros(r);
2409    if !cache.htbeta_available()
2410        || pen_logdet_drho.len() != r
2411        || huu_drho.len() != n
2412        || htbeta_drho.len() != n
2413        || hbb_drho.len() != r
2414        || huu_drho.iter().any(|row| row.len() != r)
2415        || htbeta_drho.iter().any(|row| row.len() != r)
2416        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
2417        || huu_drho.iter().enumerate().any(|(i, row)| {
2418            let di = cache.row_dims[i];
2419            row.iter().any(|m| m.nrows() != di || m.ncols() != di)
2420        })
2421        || htbeta_drho.iter().enumerate().any(|(i, row)| {
2422            let di = cache.row_dims[i];
2423            row.iter().any(|m| m.nrows() != di || m.ncols() != k)
2424        })
2425    {
2426        out.fill(f64::NAN);
2427        return out;
2428    }
2429    let ift_correction = evidence_ift_gradient_correction(ift_terms);
2430    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
2431        out.fill(f64::NAN);
2432        return out;
2433    }
2434
2435    let schur = match cache.schur_factor.as_ref() {
2436        Some(s) => s,
2437        None => {
2438            for a in 0..r {
2439                out[a] = f64::NAN;
2440            }
2441            return out;
2442        }
2443    };
2444
2445    // Precompute Y_i = H_uu_i⁻¹ H_uβ_i (di × K). Used by both the Schur
2446    // derivative formula (§3.5) and the row trace `tr(H_uu_i⁻¹ ∂H_uu_i)`.
2447    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
2448    let mut beta_basis = Array1::<f64>::zeros(k);
2449    // Scratch sized to max_d; per-row slice is ..di.
2450    let mut rhs = Array1::<f64>::zeros(cache.d);
2451    for i in 0..n {
2452        let di = cache.row_dims[i];
2453        let factor = cache.undamped_factor(i);
2454        let mut yi = Array2::<f64>::zeros((di, k));
2455        for col in 0..k {
2456            beta_basis.fill(0.0);
2457            beta_basis[col] = 1.0;
2458            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
2459            // Same H_tβ cache contract as the IFT du/dβ and du/dρ paths.
2460            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
2461                // SAFETY: `false` means the family declared the cache
2462                // available but did not populate it — contract violation.
2463                out.fill(f64::NAN);
2464                return out;
2465            }
2466            let v = cholesky_solve_vector(factor, &rhs_i);
2467            for c in 0..di {
2468                yi[[c, col]] = v[c];
2469            }
2470        }
2471        y_blocks.push(yi);
2472    }
2473
2474    // Outer-hoisted scratch reused across all (a, i) iterations.
2475    // Sized to max_d for trace_rhs and da_tmp; per-row slices used below.
2476    let mut trace_rhs = Array1::<f64>::zeros(cache.d);
2477    let mut da_tmp = Array2::<f64>::zeros((cache.d, k));
2478    let mut col_scratch = Array1::<f64>::zeros(k);
2479    for a in 0..r {
2480        // Part 1: F_{ρ_a} envelope contribution.
2481        let mut grad = value_rho[a];
2482
2483        // Part 2a: Σ_i tr(H_uu_i⁻¹ ∂H_uu_i).
2484        // tr(H_uu_i⁻¹ M_i) = tr(L_iᵀ⁻¹ L_i⁻¹ M_i). Compute as the sum
2485        // over columns: solve L_i Lᵀ x = e_c for the c-th column of
2486        // M_i, then take its c-th component. Equivalently and more
2487        // cheaply, build (H_uu_i⁻¹ M_i) by solving column-by-column
2488        // and take its diagonal sum.
2489        let mut row_trace_acc = 0.0_f64;
2490        for i in 0..n {
2491            let di = cache.row_dims[i];
2492            let m_i = &huu_drho[i][a];
2493            assert_eq!(m_i.shape(), &[di, di]);
2494            for col in 0..di {
2495                let mut tr_rhs_i = trace_rhs.slice_mut(ndarray::s![..di]).to_owned();
2496                for r0 in 0..di {
2497                    tr_rhs_i[r0] = m_i[[r0, col]];
2498                }
2499                let v = cholesky_solve_vector(cache.undamped_factor(i), &tr_rhs_i);
2500                row_trace_acc += v[col];
2501            }
2502        }
2503
2504        // Part 2b: tr(A⁻¹ ∂A) where (proposal §3.5)
2505        //     ∂A = ∂H_ββ
2506        //          - Σ_i (∂H_uβ_i)ᵀ Y_i
2507        //          - Σ_i Y_iᵀ (∂H_uβ_i)
2508        //          + Σ_i Y_iᵀ (∂H_uu_i) Y_i.
2509        // We accumulate ∂A as a dense `K × K` matrix, then evaluate
2510        // tr(A⁻¹ ∂A) by `Σ_j (A⁻¹ ∂A)[j, j]` via column solves of the
2511        // Schur Cholesky.
2512        let mut da = hbb_drho[a].clone();
2513        assert_eq!(da.shape(), &[k, k]);
2514        for i in 0..n {
2515            let di = cache.row_dims[i];
2516            let dhtb = &htbeta_drho[i][a]; // di × K
2517            let yi = &y_blocks[i]; // di × K
2518            // - (∂H_uβ_i)ᵀ Y_i
2519            for r0 in 0..k {
2520                for c0 in 0..k {
2521                    let mut acc = 0.0;
2522                    for cc in 0..di {
2523                        acc += dhtb[[cc, r0]] * yi[[cc, c0]];
2524                    }
2525                    da[[r0, c0]] -= acc;
2526                }
2527            }
2528            // - Y_iᵀ (∂H_uβ_i)
2529            for r0 in 0..k {
2530                for c0 in 0..k {
2531                    let mut acc = 0.0;
2532                    for cc in 0..di {
2533                        acc += yi[[cc, r0]] * dhtb[[cc, c0]];
2534                    }
2535                    da[[r0, c0]] -= acc;
2536                }
2537            }
2538            // + Y_iᵀ (∂H_uu_i) Y_i
2539            let dhuu = &huu_drho[i][a];
2540            // tmp = (∂H_uu_i) Y_i  (di × K) — use a slice of the hoisted buffer.
2541            let mut da_tmp_i = da_tmp.slice_mut(ndarray::s![..di, ..]).to_owned();
2542            for r0 in 0..di {
2543                for c0 in 0..k {
2544                    let mut acc = 0.0;
2545                    for cc in 0..di {
2546                        acc += dhuu[[r0, cc]] * yi[[cc, c0]];
2547                    }
2548                    da_tmp_i[[r0, c0]] = acc;
2549                }
2550            }
2551            // da += Y_iᵀ tmp
2552            for r0 in 0..k {
2553                for c0 in 0..k {
2554                    let mut acc = 0.0;
2555                    for cc in 0..di {
2556                        acc += yi[[cc, r0]] * da_tmp_i[[cc, c0]];
2557                    }
2558                    da[[r0, c0]] += acc;
2559                }
2560            }
2561        }
2562
2563        // tr(A⁻¹ ∂A) via column solves.
2564        let mut schur_trace_acc = 0.0_f64;
2565        for j in 0..k {
2566            for r0 in 0..k {
2567                col_scratch[r0] = da[[r0, j]];
2568            }
2569            let v = cholesky_solve_vector(schur, &col_scratch);
2570            schur_trace_acc += v[j];
2571        }
2572
2573        grad += 0.5 * (row_trace_acc + schur_trace_acc);
2574        grad += ift_correction[a];
2575
2576        // Part 3: -0.5 ∂_{ρ_a} log|S_pen|+.
2577        grad -= 0.5 * pen_logdet_drho[a];
2578
2579        out[a] = grad;
2580    }
2581    out
2582}
2583
2584// ---------------------------------------------------------------------------
2585// Topology selection
2586// ---------------------------------------------------------------------------
2587
2588/// Enumerate the candidate topologies, rank by normalized negative log
2589/// evidence, and return the winner. Failed/excluded candidates (proposal
2590/// §6.11) are appended at the end of `ranking` and are never the winner.
2591///
2592/// The caller fits each topology separately (proposal §4.2) and supplies
2593/// the resulting `TopologyCandidate` records. This function is purely
2594/// the discrete comparator + tie breaker.
2595///
2596/// # Tie-breaking
2597///
2598/// Per proposal §4.6: if normalized `|score_a - score_b| <= tie_tolerance`,
2599/// prefer the simpler topology by `TopologyKind::complexity_rank` (flat <
2600/// periodic < sphere < torus). The `tie` flag in the result records whether
2601/// such a tie occurred at the top of the ranking.
2602///
2603/// # Panics
2604///
2605/// Panics if `candidates` is empty after filtering out non-finite
2606/// scores. Proposal §6.11 explicitly forbids silent fallback to a
2607/// default topology; callers must handle the empty-candidate case
2608/// before invocation.
2609pub fn select_topology(
2610    candidates: &[TopologyCandidate],
2611    options: TopologySelectOptions,
2612) -> SelectedTopology {
2613    // Split valid and excluded.
2614    let mut valid: Vec<TopologyCandidate> = candidates
2615        .iter()
2616        .filter(|c| {
2617            c.converged
2618                && c.exclusion_reason.is_none()
2619                && c.negative_log_evidence.is_finite()
2620                && topology_selection_score(c, options.score_scale).is_finite()
2621        })
2622        .cloned()
2623        .collect();
2624    let mut excluded: Vec<TopologyCandidate> = candidates
2625        .iter()
2626        .filter(|c| {
2627            !(c.converged && c.exclusion_reason.is_none() && c.negative_log_evidence.is_finite())
2628                || !topology_selection_score(c, options.score_scale).is_finite()
2629        })
2630        .cloned()
2631        .collect();
2632
2633    assert!(
2634        !valid.is_empty(),
2635        "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
2636    );
2637
2638    // Sort by normalized negative log evidence (ascending = best first),
2639    // breaking ties by complexity_rank (smaller wins). The shared selector is
2640    // the single lower-is-better ordering contract used by topology ranking,
2641    // seed screening, and REML model comparison (#782).
2642    valid = rank_priority_candidates(
2643        valid
2644            .into_iter()
2645            .enumerate()
2646            .map(|(idx, row)| {
2647                let score = topology_selection_score(&row, options.score_scale);
2648                let tie_break = usize::from(row.kind.complexity_rank());
2649                PriorityCandidate::new(row, idx, score, tie_break)
2650            })
2651            .collect(),
2652    )
2653    .into_iter()
2654    .map(|row| row.item)
2655    .collect();
2656
2657    // Detect numerical tie at the top.
2658    let tie = if valid.len() >= 2 {
2659        let top = topology_selection_score(&valid[0], options.score_scale);
2660        let next = topology_selection_score(&valid[1], options.score_scale);
2661        (next - top).abs() <= options.tie_tolerance
2662    } else {
2663        false
2664    };
2665
2666    // If tied, prefer simpler topology among the tied prefix.
2667    if tie {
2668        let top_score = topology_selection_score(&valid[0], options.score_scale);
2669        // Find the tied prefix range.
2670        let tied_end = valid
2671            .iter()
2672            .position(|c| {
2673                (topology_selection_score(c, options.score_scale) - top_score).abs()
2674                    > options.tie_tolerance
2675            })
2676            .unwrap_or(valid.len());
2677        // Sort the tied prefix by complexity_rank ascending.
2678        valid[..tied_end].sort_by_key(|c| c.kind.complexity_rank());
2679    }
2680
2681    let winner = valid[0].kind;
2682    valid.append(&mut excluded);
2683    SelectedTopology {
2684        winner,
2685        ranking: valid,
2686        tie,
2687    }
2688}
2689
2690fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
2691    match scale {
2692        TopologyScoreScale::PerObservation => {
2693            if candidate.n_obs == 0 {
2694                f64::NAN
2695            } else {
2696                candidate.negative_log_evidence / candidate.n_obs as f64
2697            }
2698        }
2699        TopologyScoreScale::PerEffectiveDim => {
2700            if !(candidate.effective_dim.is_finite() && candidate.effective_dim > 0.0) {
2701                f64::NAN
2702            } else {
2703                candidate.negative_log_evidence / candidate.effective_dim
2704            }
2705        }
2706    }
2707}
2708
2709// ---------------------------------------------------------------------------
2710// Cache verification helpers
2711// ---------------------------------------------------------------------------
2712
2713
2714/// Verifies the `ArrowSchurSystem` dimensions match the cache. Used as
2715/// a debug-time precondition; never silently masks shape errors
2716/// (proposal §6.9 — sign and shape errors must be loud).
2717pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
2718    cache.d == sys.d
2719        && cache.k == sys.k
2720        && cache.n_rows() == sys.rows.len()
2721        && cache.undamped_factor_count() == sys.rows.len()
2722        && cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
2723        && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
2724}
2725
2726// ---------------------------------------------------------------------------
2727// #1026 hybrid curved + linear-tail dictionary split-selection
2728// ---------------------------------------------------------------------------
2729//
2730// COMMON-EVIDENCE NOTE (#1202): the candidates BOTH fit the same data — the
2731// atom's leave-this-atom-out response residual `y_resp` (the response with every
2732// other atom's contribution removed). The curved candidate predicts the atom's
2733// actual mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2734// mass-weighted straight line fit to `y_resp`. Because the curved family's
2735// `Θ = 0` member reproduces the linear prediction exactly, linear IS the nested
2736// `Θ = 0` sub-model on common data, so the "match-or-beat" statements below are a
2737// genuine data-level comparison: the curved candidate wins only when fitting the
2738// response residual better than its own straight projection pays for its extra
2739// parameters. See `crate::terms::sae::hybrid_split` for the residual assembly.
2740//
2741// The per-slot adjudication uses the SAME rank-aware Laplace evidence criterion
2742// the union/mixture rungs use (`−V = NLE`, lower wins), comparing the data-fit +
2743// complexity cost of the curved contribution against that of the straight line.
2744//
2745// ## The turning floor (Θ → 0) and the curved ceiling (Θ large)
2746//
2747// Per slot, the curved candidate fits the response residual with its actual
2748// mass-scaled contribution `a_k·γ_k` (data-fit `½·curved_rss`) and pays a larger
2749// free-parameter price `P_curved > P_linear`; the linear candidate fits the same
2750// residual with its best straight line (data-fit `½·linear_rss ≥ ½·curved_rss`
2751// whenever the curve beats its own straight projection) at a smaller price,
2752// charged with its genuine weighted Gram logdet `p·(log w_sum + log s_tt)`
2753// (#1203). Hence:
2754//
2755//   * Θ → 0 (the residual is straight): the curve and the line fit it equally, so
2756//     the cheaper LINEAR candidate wins — the turning floor / nested dominance. A
2757//     curved parameterization "buys nothing" on an already-straight residual.
2758//   * Θ large (a genuinely turning residual): the line's data-fit residual
2759//     exceeds the curved atom's extra parameter price, so CURVED wins. (Whether
2760//     curved wins also depends on the coordinate spread `s_tt` and amplitude, via
2761//     the honest logdet — a tightly-spread, mildly-curved residual can still
2762//     prefer the cheaper line.)
2763//
2764// The crossover is governed by the documented shatter law: a linear SAE shatters
2765// a feature of total turning Θ into `N(ε) ≈ Θ/(2√(2ε))` rank-1 directions at
2766// relative reconstruction error ε, so the curved advantage scales as `Θ/√ε`. We
2767// use the fitted turning Θ (`sae::chart_canonicalization::d1_atom_fitted_turning`)
2768// as the decision FEATURE: it both (a) sharpens the evidence comparison into a
2769// falsifiable per-atom prediction and (b) provides the exact-zero dominance
2770// guard — when an atom's fitted turning is identically zero, the curved fit has
2771// no curvature to price and the linear special case is selected by construction,
2772// independent of finite-sample evidence noise.
2773
/// The parameterization a hybrid-dictionary slot can select: either a CURVED
/// atom (a `latent_dim ≥ 1` curved basis whose decoded image may turn) or its
/// LINEAR special case (the euclidean-d=1-linear atom — one straight decoder
/// direction, `γ(t) = t·b`, fitted turning `Θ = 0`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HybridAtomParam {
    /// The curved atom (`latent_dim ≥ 1`), priced at its full coefficient count.
    Curved { latent_dim: usize },
    /// The linear special case: one decoder direction, zero turning.
    Linear,
}

impl HybridAtomParam {
    /// Stable display name for logs and tests: `"linear"` or `"curved"`
    /// (the latent dimension is not part of the name).
    pub const fn as_str(self) -> &'static str {
        if self.is_linear() {
            "linear"
        } else {
            "curved"
        }
    }

    /// `true` iff this is the linear special case (the linear tail).
    pub const fn is_linear(self) -> bool {
        match self {
            HybridAtomParam::Linear => true,
            HybridAtomParam::Curved { .. } => false,
        }
    }
}
2800
2801/// One fitted candidate parameterization for a single hybrid-dictionary atom
2802/// slot, scored on the COMMON rank-aware Laplace scale (`−V = NLE`, lower wins,
2803/// identical to the union/mixture rungs). The curved and linear candidates for
2804/// the SAME slot are fit on the same rows AND the same data (the atom's response
2805/// residual, #1202), so their NLEs are directly comparable; the structural
2806/// difference is the curved candidate's larger free-parameter price and whatever
2807/// data-fit it buys with its curvature.
2808#[derive(Debug, Clone, Copy)]
2809pub struct HybridAtomCandidate {
2810    pub param: HybridAtomParam,
2811    /// Rank-aware Laplace negative-log-evidence on the common scale (lower wins).
2812    pub negative_log_evidence: f64,
2813    /// Free-parameter count this candidate is charged for (the complexity price).
2814    pub num_parameters: usize,
2815    /// The candidate's fitted total turning `Θ = ∫κ ds` of its decoded curve, if
2816    /// the basis admits an analytic second jet. `Some(0.0)` for a linear atom (a
2817    /// straight image has no turning); `None` when the turning is honestly
2818    /// unavailable (no second jet / degenerate curve) — never fabricated.
2819    pub fitted_turning: Option<f64>,
2820}
2821
2822impl HybridAtomCandidate {
2823    /// A linear special-case candidate: exact zero turning by construction.
2824    pub fn linear(negative_log_evidence: f64, num_parameters: usize) -> Self {
2825        Self {
2826            param: HybridAtomParam::Linear,
2827            negative_log_evidence,
2828            num_parameters,
2829            fitted_turning: Some(0.0),
2830        }
2831    }
2832
2833    /// A curved candidate of the given latent dimension, with its fitted turning.
2834    pub fn curved(
2835        latent_dim: usize,
2836        negative_log_evidence: f64,
2837        num_parameters: usize,
2838        fitted_turning: Option<f64>,
2839    ) -> Self {
2840        Self {
2841            param: HybridAtomParam::Curved { latent_dim },
2842            negative_log_evidence,
2843            num_parameters,
2844            fitted_turning,
2845        }
2846    }
2847}
2848
2849/// The evidence-selected parameterization for one hybrid-dictionary atom slot:
2850/// the winning candidate, plus the curved/linear NLEs that decided it (for the
2851/// EV-vs-Θ diagnostic and the tie-break audit trail).
2852#[derive(Debug, Clone, Copy)]
2853pub struct HybridAtomChoice {
2854    pub param: HybridAtomParam,
2855    /// The winning candidate's NLE.
2856    pub negative_log_evidence: f64,
2857    /// The winning candidate's free-parameter price.
2858    pub num_parameters: usize,
2859    /// The curved candidate's fitted turning `Θ` (the decision feature). `None`
2860    /// when no curved candidate offered an analytic turning.
2861    pub curved_turning: Option<f64>,
2862    /// `NLE_linear − NLE_curved`: the evidence margin the curved fit won (or lost,
2863    /// if negative) over the linear special case at this slot. Positive ⇒ curved
2864    /// bought more evidence than its parameter price; ≤ 0 ⇒ the dominance floor
2865    /// keeps the linear tail.
2866    pub curved_evidence_margin: f64,
2867}
2868
/// Fitted-turning threshold at or below which a curved candidate is treated as
/// straight. Curvature this small is numerically indistinguishable from zero,
/// so the dominance floor (the linear special case is cheaper at equal
/// likelihood) is enforced by construction rather than left to finite-sample
/// evidence noise. This is the exact-zero guard from the
/// `Θ → 0 ⇒ N(ε) → 0` limit of the shatter law, not a tunable knob: it is the
/// curvature scale below which `‖γ' ∧ γ''‖` sits at the floor of the Simpson
/// quadrature for a genuinely straight image.
pub const HYBRID_LINEAR_TURNING_FLOOR: f64 = 1.0e-9;
2877
2878/// Adjudicate the curved-vs-linear parameterization for ONE hybrid-dictionary
2879/// atom slot by the common rank-aware Laplace evidence criterion.
2880///
2881/// Selection rule (all on the single `NLE = −V` scale, lower wins):
2882///
2883///  1. **Dominance floor (Θ → 0).** If the curved candidate's fitted turning is
2884///     `Some(Θ)` with `Θ ≤ HYBRID_LINEAR_TURNING_FLOOR` and a linear candidate
2885///     exists, select LINEAR. A straight curved fit recovers no likelihood the
2886///     linear special case does not, and the linear atom is strictly cheaper, so
2887///     it cannot lose — we enforce that exactly instead of trusting evidence
2888///     noise at the floor.
2889///  2. **Evidence comparison.** Otherwise select the candidate with the smaller
2890///     `NLE`. The curved candidate wins only when its extra curvature lowers the
2891///     NLE by MORE than its extra parameter price — the `Θ/√ε` crossover, decided
2892///     here by the evidence numbers themselves, not by fiat. This is a
2893///     common-data comparison (both candidates fit the atom's response residual,
2894///     see `crate::terms::sae::hybrid_split`) in which linear is the curved
2895///     family's nested `Θ = 0` sub-model (#1202): the curved candidate cannot be
2896///     charged its extra parameters to fit the residual no better than its own
2897///     straight projection, and a tightly-spread, mildly-curved residual can
2898///     still prefer the cheaper line.
2899///  3. **Tie-break.** Exact NLE ties go to the cheaper (fewer-parameter)
2900///     candidate — i.e. linear — preserving the strict-generalization guarantee
2901///     that the hybrid never pays for curvature it does not need.
2902///
2903/// `candidates` must contain at most one linear and at most one curved candidate
2904/// for the slot; returns `None` only if `candidates` is empty.
2905pub fn select_hybrid_atom(candidates: &[HybridAtomCandidate]) -> Option<HybridAtomChoice> {
2906    if candidates.is_empty() {
2907        return None;
2908    }
2909    let linear = candidates.iter().find(|c| c.param.is_linear());
2910    let curved = candidates.iter().find(|c| !c.param.is_linear());
2911    let curved_turning = curved.and_then(|c| c.fitted_turning);
2912    let curved_evidence_margin = match (linear, curved) {
2913        (Some(l), Some(c)) => l.negative_log_evidence - c.negative_log_evidence,
2914        _ => 0.0,
2915    };
2916
2917    // (1) Exact-zero dominance floor: a straight curved fit yields to the linear
2918    // special case by construction.
2919    if let (Some(l), Some(turning)) = (linear, curved_turning)
2920        && turning <= HYBRID_LINEAR_TURNING_FLOOR
2921    {
2922        return Some(HybridAtomChoice {
2923            param: l.param,
2924            negative_log_evidence: l.negative_log_evidence,
2925            num_parameters: l.num_parameters,
2926            curved_turning,
2927            curved_evidence_margin,
2928        });
2929    }
2930
2931    // (2)+(3) Evidence argmin with the cheaper candidate winning exact ties.
2932    let mut best = candidates[0];
2933    for cand in &candidates[1..] {
2934        let better_evidence = cand.negative_log_evidence < best.negative_log_evidence;
2935        let tied = cand.negative_log_evidence == best.negative_log_evidence;
2936        let cheaper_on_tie = tied && cand.num_parameters < best.num_parameters;
2937        if better_evidence || cheaper_on_tie {
2938            best = *cand;
2939        }
2940    }
2941    Some(HybridAtomChoice {
2942        param: best.param,
2943        negative_log_evidence: best.negative_log_evidence,
2944        num_parameters: best.num_parameters,
2945        curved_turning,
2946        curved_evidence_margin,
2947    })
2948}
2949
2950/// The evidence-selected split for a whole hybrid dictionary: the per-atom
2951/// curved-vs-linear choices and the dictionary-level aggregates the EV-vs-Θ
2952/// frontier reports against.
2953#[derive(Debug, Clone)]
2954pub struct HybridSplitSelection {
2955    /// One adjudicated choice per atom slot, in slot order.
2956    pub atoms: Vec<HybridAtomChoice>,
2957    /// `Σ NLE` across the selected per-atom parameterizations — the dictionary's
2958    /// summed rank-aware Laplace negative-log-evidence (lower wins). Because each
2959    /// slot picks the argmin over {curved contribution, best straight line to the
2960    /// response residual}, this is ≤ the sum of the per-slot LINEAR-candidate
2961    /// NLEs. The linear baseline is the best straight line fit to each atom's
2962    /// leave-this-atom-out RESPONSE residual (#1202), the curved family's nested
2963    /// `Θ = 0` member on common data — so this is a genuine data-level
2964    /// match-or-beat dominance, not a post-hoc curve-simplification one.
2965    pub total_negative_log_evidence: f64,
2966    /// `Σ P` across the selected parameterizations — the dictionary's total
2967    /// free-parameter price (the matched-active-budget accounting).
2968    pub total_parameters: usize,
2969    /// Count of slots that selected the curved parameterization.
2970    pub curved_atom_count: usize,
2971}
2972
2973impl HybridSplitSelection {
2974    /// Count of slots that selected the linear special case (the linear tail).
2975    pub fn linear_atom_count(&self) -> usize {
2976        self.atoms.len() - self.curved_atom_count
2977    }
2978
2979    /// `true` iff every slot selected linear — the pure-linear limit, reached
2980    /// when every feature is straight (all `Θ → 0`).
2981    pub fn is_pure_linear(&self) -> bool {
2982        self.curved_atom_count == 0 && !self.atoms.is_empty()
2983    }
2984
2985    /// `true` iff every slot selected curved — the pure-curved limit, reached
2986    /// when every feature turns enough to pay for curvature.
2987    pub fn is_pure_curved(&self) -> bool {
2988        self.curved_atom_count == self.atoms.len() && !self.atoms.is_empty()
2989    }
2990}
2991
2992/// Adjudicate the curved-vs-linear split across a whole hybrid dictionary by the
2993/// common evidence criterion. `slots[i]` holds the curved/linear candidates for
2994/// atom slot `i` (each scored on the same rows, on the common Laplace scale).
2995///
2996/// The result reduces EXACTLY to pure-linear when every slot's curved candidate
2997/// has `Θ → 0` (the turning floor fires everywhere) and to pure-curved when
2998/// every slot's curved candidate wins the evidence comparison. (Common-data
2999/// criterion, #1202 — both candidates fit the atom's response residual, with
3000/// linear nested as the curved family's `Θ = 0` sub-model; see the module header
3001/// above and `crate::terms::sae::hybrid_split`.)
3002///
3003/// Returns an error only if some slot has no candidates to adjudicate (an empty
3004/// dictionary slot is a caller bug, not a silent skip).
3005pub fn select_hybrid_split(
3006    slots: &[Vec<HybridAtomCandidate>],
3007) -> Result<HybridSplitSelection, String> {
3008    let mut atoms = Vec::with_capacity(slots.len());
3009    let mut total_nle = 0.0_f64;
3010    let mut total_parameters = 0usize;
3011    let mut curved_atom_count = 0usize;
3012    for (i, slot) in slots.iter().enumerate() {
3013        let choice = select_hybrid_atom(slot)
3014            .ok_or_else(|| format!("hybrid split slot {i} has no candidate parameterizations"))?;
3015        if !choice.negative_log_evidence.is_finite() {
3016            return Err(format!(
3017                "hybrid split slot {i} selected a non-finite evidence ({})",
3018                choice.negative_log_evidence
3019            ));
3020        }
3021        if !choice.param.is_linear() {
3022            curved_atom_count += 1;
3023        }
3024        total_nle += choice.negative_log_evidence;
3025        total_parameters += choice.num_parameters;
3026        atoms.push(choice);
3027    }
3028    Ok(HybridSplitSelection {
3029        atoms,
3030        total_negative_log_evidence: total_nle,
3031        total_parameters,
3032        curved_atom_count,
3033    })
3034}
3035
3036// ---------------------------------------------------------------------------
3037// Tests
3038//
3039// These are type-level / structural tests: per the task contract we do
3040// not compile or run them in this session. They document the expected
3041// shapes and degenerate-case behavior so a future maintainer running
3042// `cargo test` sees the contract written down.
3043// ---------------------------------------------------------------------------
3044
3045#[cfg(test)]
3046mod tests {
3047    use super::*;
3048    use crate::arrow_schur::ArrowFactorSlab;
3049
3050    // Dense `H⁻¹` apply via explicit inverse (test-only reference solver).
3051    fn dense_inverse(h: &Array2<f64>) -> Array2<f64> {
3052        let p = h.nrows();
3053        let mut aug = Array2::<f64>::zeros((p, 2 * p));
3054        for i in 0..p {
3055            for j in 0..p {
3056                aug[[i, j]] = h[[i, j]];
3057            }
3058            aug[[i, p + i]] = 1.0;
3059        }
3060        for col in 0..p {
3061            let mut pivot = col;
3062            for row in (col + 1)..p {
3063                if aug[[row, col]].abs() > aug[[pivot, col]].abs() {
3064                    pivot = row;
3065                }
3066            }
3067            if pivot != col {
3068                for j in 0..(2 * p) {
3069                    aug.swap([col, j], [pivot, j]);
3070                }
3071            }
3072            let d = aug[[col, col]];
3073            for j in 0..(2 * p) {
3074                aug[[col, j]] /= d;
3075            }
3076            for row in 0..p {
3077                if row == col {
3078                    continue;
3079                }
3080                let f = aug[[row, col]];
3081                if f != 0.0 {
3082                    for j in 0..(2 * p) {
3083                        aug[[row, j]] -= f * aug[[col, j]];
3084                    }
3085                }
3086            }
3087        }
3088        let mut inv = Array2::<f64>::zeros((p, p));
3089        for i in 0..p {
3090            for j in 0..p {
3091                inv[[i, j]] = aug[[i, p + j]];
3092            }
3093        }
3094        inv
3095    }
3096
3097    #[test]
3098    fn coupling_components_block_diagonal_is_all_singletons_by_block() {
3099        // Two decoupled 2x2 blocks: {0,1} and {2,3}.
3100        let mut h = Array2::<f64>::eye(4);
3101        h[[0, 1]] = 0.3;
3102        h[[1, 0]] = 0.3;
3103        h[[2, 3]] = 0.7;
3104        h[[3, 2]] = 0.7;
3105        let labels = coupling_components(h.view());
3106        assert_eq!(labels[0], labels[1]);
3107        assert_eq!(labels[2], labels[3]);
3108        assert_ne!(labels[0], labels[2]);
3109        // Exactly two components.
3110        let mut uniq = labels.clone();
3111        uniq.sort_unstable();
3112        uniq.dedup();
3113        assert_eq!(uniq.len(), 2);
3114    }
3115
3116    #[test]
3117    fn coupling_components_fully_coupled_is_one_component() {
3118        let mut h = Array2::<f64>::eye(3);
3119        for i in 0..3 {
3120            for j in 0..3 {
3121                if i != j {
3122                    h[[i, j]] = 0.1;
3123                }
3124            }
3125        }
3126        let labels = coupling_components(h.view());
3127        assert!(labels.iter().all(|&l| l == labels[0]));
3128    }
3129
3130    #[test]
3131    fn coupling_components_transitive_chain_merges() {
3132        // 0-1 and 1-2 coupled (but no direct 0-2 edge) must form one component.
3133        let mut h = Array2::<f64>::eye(3);
3134        h[[0, 1]] = 0.5;
3135        h[[1, 0]] = 0.5;
3136        h[[1, 2]] = 0.5;
3137        h[[2, 1]] = 0.5;
3138        let labels = coupling_components(h.view());
3139        assert_eq!(labels[0], labels[1]);
3140        assert_eq!(labels[1], labels[2]);
3141    }
3142
3143    #[test]
3144    fn compare_reml_fits_delta_and_bayes_factor_never_contradict_winner_gh1465() {
3145        // Regression for #1465: the ranking `delta` / `bayes_factor` must be
3146        // measured on the SAME scale that orders the table (the Occam-penalised
3147        // conditional AIC `ranking_score`), so every row's delta is >= 0 and its
3148        // Bayes factor >= 1 — the table must never claim a non-winner beats the
3149        // declared winner. The scenario is exactly the case the comparison
3150        // exists to handle: AIC and raw REML DISAGREE. `m1` is the AIC winner
3151        // but does NOT carry the minimum raw REML (`m2` does) — the noise
3152        // extra-term case from the issue.
3153        //
3154        // `ranking_score` = -2*log_lik + 2*edf; with log_lik = 0 it is `2*edf`,
3155        // so the AIC order is m1 < m2 < m3 while the raw-REML order has m2 lowest.
3156        let cand = |name: &str, score: f64, edf: f64| RemlCandidate {
3157            index: 0,
3158            name: name.to_string(),
3159            score,
3160            edf: Some(edf),
3161            log_lik: Some(0.0),
3162            family: Some("gaussian".to_string()),
3163        };
3164        // raw REML : m2 (41.605) < m1 (53.748) < m3 (120.011)
3165        // AIC=2*edf: m1 (100)    < m2 (102)    < m3 (130)
3166        let candidates = vec![
3167            cand("m1", 53.748, 50.0),
3168            cand("m2", 41.605, 51.0),
3169            cand("m3", 120.011, 65.0),
3170        ];
3171        let cmp = compare_reml_fits(candidates).expect("comparison");
3172
3173        assert_eq!(cmp.winner, "m1", "AIC winner");
3174        // No ranking row may contradict the declared winner.
3175        for row in &cmp.ranking {
3176            assert!(
3177                row.delta >= 0.0,
3178                "ranking delta for {} must be >= 0, got {}",
3179                row.name,
3180                row.delta
3181            );
3182            assert!(
3183                row.bayes_factor >= 1.0 - 1e-12,
3184                "ranking bayes_factor for {} must be >= 1, got {}",
3185                row.name,
3186                row.bayes_factor
3187            );
3188        }
3189        let winner_row = cmp.ranking.iter().find(|r| r.name == "m1").unwrap();
3190        assert!(winner_row.delta.abs() < 1e-12, "winner delta == 0");
3191        assert!(
3192            (winner_row.bayes_factor - 1.0).abs() < 1e-9,
3193            "winner bayes_factor == 1"
3194        );
3195
3196        // The raw-REML score table is referenced to the genuine minimum raw REML
3197        // (m2), so its best-over-model Bayes factors are also coherent (>= 1).
3198        for row in &cmp.score_table {
3199            assert!(
3200                row.delta_reml >= 0.0,
3201                "score-table delta_reml for {} must be >= 0, got {}",
3202                row.name,
3203                row.delta_reml
3204            );
3205            assert!(
3206                row.bayes_factor_best_over_model >= 1.0 - 1e-12,
3207                "score-table bayes_factor for {} must be >= 1, got {}",
3208                row.name,
3209                row.bayes_factor_best_over_model
3210            );
3211        }
3212        // m2 carries the minimum raw REML, so its raw delta is exactly 0.
3213        let m2 = cmp.score_table.iter().find(|r| r.name == "m2").unwrap();
3214        assert!(
3215            m2.delta_reml.abs() < 1e-12,
3216            "the minimum-raw-REML row has delta_reml 0"
3217        );
3218    }
3219
3220    #[test]
3221    fn cone_of_influence_empty_support_is_empty() {
3222        let labels = vec![0usize, 0, 1, 1];
3223        assert!(cone_of_influence(&labels, &[]).is_empty());
3224    }
3225
3226    #[test]
3227    fn cone_of_influence_returns_full_component() {
3228        let labels = vec![0usize, 0, 1, 1];
3229        // Support in component 0 -> cone is {0,1}.
3230        assert_eq!(cone_of_influence(&labels, &[0]), vec![0, 1]);
3231        // Support spanning both -> cone is everything.
3232        assert_eq!(cone_of_influence(&labels, &[1, 2]), vec![0, 1, 2, 3]);
3233    }
3234
3235    #[test]
3236    fn coned_matches_full_solve_on_fully_coupled_hessian() {
3237        // Fully coupled SPD H: cone is the whole space, result must equal the
3238        // unconfined sensitivity-operator mode response bit-for-bit.
3239        let h = Array2::from_shape_vec((3, 3), vec![4.0, 1.0, 0.5, 1.0, 3.0, 0.8, 0.5, 0.8, 2.5])
3240            .unwrap();
3241        let inv = dense_inverse(&h);
3242        // Two ρ-columns, each supported on a single coefficient.
3243        let mut dg = Array2::<f64>::zeros((3, 2));
3244        dg[[0, 0]] = 1.3;
3245        dg[[2, 1]] = -0.7;
3246        let supports = vec![0..1usize, 2..3usize];
3247
3248        let eye: Array2<f64> = Array2::eye(3);
3249        let op = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv);
3250        let full = op.mode_response(dg.view()).unwrap();
3251        let coned = op
3252            .mode_response_coned(h.view(), dg.view(), &supports)
3253            .unwrap();
3254        for i in 0..3 {
3255            for a in 0..2 {
3256                assert!(
3257                    (full[[i, a]] - coned[[i, a]]).abs() < 1e-12,
3258                    "fully-coupled mismatch at ({i},{a}): {} vs {}",
3259                    full[[i, a]],
3260                    coned[[i, a]]
3261                );
3262            }
3263        }
3264    }
3265
3266    #[test]
3267    fn coned_confines_to_component_on_decoupled_hessian() {
3268        // Block-decoupled H: blocks {0,1} and {2,3}. A column supported only in
3269        // block {0,1} must produce sensitivity zero in block {2,3}, and match
3270        // the exact solution within its own block.
3271        let mut h = Array2::<f64>::zeros((4, 4));
3272        // Block A.
3273        h[[0, 0]] = 4.0;
3274        h[[1, 1]] = 3.0;
3275        h[[0, 1]] = 1.0;
3276        h[[1, 0]] = 1.0;
3277        // Block B.
3278        h[[2, 2]] = 2.0;
3279        h[[3, 3]] = 5.0;
3280        h[[2, 3]] = 0.6;
3281        h[[3, 2]] = 0.6;
3282        let inv = dense_inverse(&h);
3283
3284        let mut dg = Array2::<f64>::zeros((4, 1));
3285        dg[[0, 0]] = 0.9;
3286        dg[[1, 0]] = -0.4;
3287        let support_range = 0..2usize;
3288        let supports = std::slice::from_ref(&support_range);
3289
3290        let eye: Array2<f64> = Array2::eye(4);
3291        let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv)
3292            .mode_response_coned(h.view(), dg.view(), supports)
3293            .unwrap();
3294        // Exact reference: -H⁻¹ q. Off-block entries are exactly zero already
3295        // (decoupled inverse), and the cone must preserve the in-block ones.
3296        let q = dg.column(0).to_owned();
3297        let exact = inv.dot(&q).mapv(|v| -v);
3298        for i in 0..4 {
3299            assert!(
3300                (coned[[i, 0]] - exact[[i]]).abs() < 1e-12,
3301                "decoupled mismatch at {i}: {} vs {}",
3302                coned[[i, 0]],
3303                exact[[i]]
3304            );
3305        }
3306        // Block B is outside the cone -> exactly zero.
3307        assert_eq!(coned[[2, 0]], 0.0);
3308        assert_eq!(coned[[3, 0]], 0.0);
3309    }
3310
3311    #[test]
3312    fn coned_skips_inactive_column_with_empty_support() {
3313        let h = Array2::<f64>::eye(2);
3314        let dg = Array2::<f64>::zeros((2, 1));
3315        // Inactive ρ: empty support, must be skipped without solving.
3316        let empty_support = 0..0usize;
3317        let supports = std::slice::from_ref(&empty_support);
3318        // A NaN inverse: an empty-support column must be skipped WITHOUT
3319        // solving, so the operator's finite-check never sees the NaN and the
3320        // result is `Some(zeros)`. Were the inactive column ever solved, the
3321        // NaN would propagate and `mode_response_coned` would return `None`.
3322        let eye: Array2<f64> = Array2::eye(2);
3323        let nan_inv = Array2::<f64>::from_elem((2, 2), f64::NAN);
3324        let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &nan_inv)
3325            .mode_response_coned(h.view(), dg.view(), supports)
3326            .unwrap();
3327        assert_eq!(coned[[0, 0]], 0.0);
3328        assert_eq!(coned[[1, 0]], 0.0);
3329    }
3330
3331    fn make_minimal_cache() -> ArrowFactorCache {
3332        // d = 1, k = 1, n = 1, H_uu_1 = [[2.0]] => L = [[sqrt(2)]],
3333        // H_uβ_1 = [[0.5]], A = 2 - 0.5 * 0.5 / 2 = 1.875.
3334        let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
3335        let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
3336        let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
3337        ArrowFactorCache {
3338            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
3339            htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
3340            schur_factor: Some(l_schur),
3341            joint_hessian_log_det: None,
3342            solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
3343            ridge_t: 0.0,
3344            ridge_beta: 0.0,
3345            htbeta: crate::arrow_schur::ArrowHtbetaCache::Dense {
3346                blocks: std::sync::Arc::from(vec![htbeta]),
3347                estimated_bytes: std::mem::size_of::<f64>(),
3348            },
3349            d: 1,
3350            row_dims: std::sync::Arc::from(vec![1usize]),
3351            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
3352            k: 1,
3353            manifold_mode_fingerprint: 0,
3354            row_hessian_fingerprint: 0,
3355            pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
3356            gauge_deflated_directions: 0,
3357            deflated_row_directions: std::sync::Arc::from(Vec::new()),
3358            deflation_row_spectra: std::sync::Arc::from(Vec::new()),
3359            cross_row_woodbury: None,
3360        }
3361    }
3362
3363    #[test]
3364    fn laplace_evidence_returns_finite_for_minimal_cache() {
3365        let cache = make_minimal_cache();
3366        // log|H| = log(2) + log(1.875). With dim(H)=2 and rank(S)=1,
3367        // V includes the rank-aware TK nullspace normalizer.
3368        let v = laplace_evidence(
3369            EvidenceLogDetSource::FactoredArrow {
3370                cache: &cache,
3371                fallback_hvp: None,
3372            },
3373            0.0,
3374            0.0,
3375            2.0,
3376            1.0,
3377        );
3378        assert!(v.is_finite());
3379        let expected =
3380            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
3381        assert!((v - expected).abs() < 1e-12);
3382    }
3383
3384    /// #1132 bug 2: a β-profiled atom (no shared `β` block, `k == 0`) reaches
3385    /// `arrow_log_det_from_cache` in the dense Direct path with
3386    /// `schur_factor = None` — there is no reduced Schur complement to form. The
3387    /// joint Hessian is then block-diagonal in the latent rows, so its log-det
3388    /// is exactly the per-row sum with NO Schur term. Before the fix this
3389    /// returned `None` (the `schur_factor.as_ref()?` bail), starving the REML
3390    /// Laplace normaliser and erroring "arrow_log_det_from_cache returned None
3391    /// at ridge=0 Direct mode". Now it returns `Some(Σ_i log|H_tt^(i)|)`.
3392    fn k0_direct_cache_no_schur(latent_diag: f64) -> ArrowFactorCache {
3393        let l_huu = Array2::from_shape_vec((1, 1), vec![latent_diag.sqrt()]).unwrap();
3394        ArrowFactorCache {
3395            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
3396            htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
3397            schur_factor: None,
3398            joint_hessian_log_det: None,
3399            solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
3400            ridge_t: 0.0,
3401            ridge_beta: 0.0,
3402            htbeta: crate::arrow_schur::ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
3403            d: 1,
3404            row_dims: std::sync::Arc::from(vec![1usize]),
3405            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
3406            k: 0,
3407            manifold_mode_fingerprint: 0,
3408            row_hessian_fingerprint: 0,
3409            pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
3410            gauge_deflated_directions: 0,
3411            deflated_row_directions: std::sync::Arc::from(Vec::new()),
3412            deflation_row_spectra: std::sync::Arc::from(Vec::new()),
3413            cross_row_woodbury: None,
3414        }
3415    }
3416
3417    #[test]
3418    fn arrow_log_det_some_for_k0_direct_cache_without_schur() {
3419        let cache = k0_direct_cache_no_schur(3.0);
3420        let log_det = arrow_log_det_from_cache(&cache)
3421            .expect("k==0 Direct cache must yield Some(per-row sum), not None (#1132)");
3422        // Single latent block H_tt = [[3.0]]; no Schur term for k == 0.
3423        assert!(
3424            (log_det - 3.0_f64.ln()).abs() < 1e-12,
3425            "log_det = {log_det}"
3426        );
3427        // The cache's own computation must agree bit-for-bit.
3428        let cached = cache
3429            .compute_undamped_arrow_log_det()
3430            .expect("compute_undamped_arrow_log_det must be Some for k==0");
3431        assert!((cached - 3.0_f64.ln()).abs() < 1e-12, "cached = {cached}");
3432    }
3433
3434    #[test]
3435    fn arrow_log_det_none_for_kpos_cache_without_schur() {
3436        // k > 0 but no dense Schur factor is the genuine InexactPCG case and
3437        // must still reject (the guard must not over-broaden to all `None`).
3438        let mut cache = k0_direct_cache_no_schur(3.0);
3439        cache.k = 1;
3440        cache.solver_mode = crate::arrow_schur::ArrowSolverMode::InexactPCG;
3441        assert!(arrow_log_det_from_cache(&cache).is_none());
3442        assert!(cache.compute_undamped_arrow_log_det().is_none());
3443    }
3444
3445    #[test]
3446    fn laplace_evidence_nan_when_ridge_is_nonzero() {
3447        let mut cache = make_minimal_cache();
3448        cache.ridge_t = 1e-3;
3449        assert!(
3450            laplace_evidence(
3451                EvidenceLogDetSource::FactoredArrow {
3452                    cache: &cache,
3453                    fallback_hvp: None,
3454                },
3455                0.0,
3456                0.0,
3457                2.0,
3458                1.0,
3459            )
3460            .is_nan()
3461        );
3462    }
3463
3464    #[test]
3465    fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
3466        let mut cache = make_minimal_cache();
3467        cache.schur_factor = None;
3468        let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
3469        let v = laplace_evidence(
3470            EvidenceLogDetSource::FactoredArrow {
3471                cache: &cache,
3472                fallback_hvp: Some(EvidenceHvpLogDet {
3473                    dim: 2,
3474                    apply: &hvp,
3475                }),
3476            },
3477            0.0,
3478            0.0,
3479            2.0,
3480            1.0,
3481        );
3482        let expected =
3483            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
3484        assert!((v - expected).abs() < 1e-12);
3485    }
3486
3487    #[test]
3488    fn ift_du_dbeta_has_expected_shape() {
3489        let cache = make_minimal_cache();
3490        let du_db = ift_du_dbeta(&cache);
3491        assert_eq!(du_db.shape(), &[1, 1]);
3492        // ∂u/∂β = -H_uu⁻¹ H_uβ = -0.5 / 2 = -0.25.
3493        assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
3494    }
3495
3496    #[test]
3497    fn ift_dbeta_drho_returns_some_for_direct_cache() {
3498        let cache = make_minimal_cache();
3499        let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
3500        let out = ift_dbeta_drho(&cache, q.view()).unwrap();
3501        assert_eq!(out.shape(), &[1, 1]);
3502        // ∂β/∂ρ = -A⁻¹ · 1 = -1/1.875.
3503        assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
3504    }
3505
3506    #[test]
3507    fn topology_select_picks_lowest_negative_log_evidence() {
3508        let candidates = vec![
3509            TopologyCandidate {
3510                kind: TopologyKind::Flat,
3511                negative_log_evidence: 10.0,
3512                effective_dim: 4.0,
3513                n_obs: 100,
3514                converged: true,
3515                exclusion_reason: None,
3516            },
3517            TopologyCandidate {
3518                kind: TopologyKind::Sphere,
3519                negative_log_evidence: 8.0,
3520                effective_dim: 5.0,
3521                n_obs: 100,
3522                converged: true,
3523                exclusion_reason: None,
3524            },
3525            TopologyCandidate {
3526                kind: TopologyKind::Torus,
3527                negative_log_evidence: f64::NAN,
3528                effective_dim: 6.0,
3529                n_obs: 100,
3530                converged: false,
3531                exclusion_reason: Some("torus periods missing".to_string()),
3532            },
3533        ];
3534        let sel = select_topology(&candidates, TopologySelectOptions::default());
3535        assert_eq!(sel.winner, TopologyKind::Sphere);
3536        assert!(!sel.tie);
3537    }
3538
3539    #[test]
3540    fn topology_select_tie_breaks_to_simpler() {
3541        let candidates = vec![
3542            TopologyCandidate {
3543                kind: TopologyKind::Sphere,
3544                negative_log_evidence: 5.0,
3545                effective_dim: 5.0,
3546                n_obs: 100,
3547                converged: true,
3548                exclusion_reason: None,
3549            },
3550            TopologyCandidate {
3551                kind: TopologyKind::Flat,
3552                negative_log_evidence: 5.0 + 1e-6,
3553                effective_dim: 4.0,
3554                n_obs: 100,
3555                converged: true,
3556                exclusion_reason: None,
3557            },
3558        ];
3559        let sel = select_topology(&candidates, TopologySelectOptions::default());
3560        assert_eq!(sel.winner, TopologyKind::Flat);
3561        assert!(sel.tie);
3562    }
3563
3564    fn gaussian_logpdf(y: f64, mean: f64, sd: f64) -> f64 {
3565        let z = (y - mean) / sd;
3566        -0.5 * (2.0 * std::f64::consts::PI).ln() - sd.ln() - 0.5 * z * z
3567    }
3568
3569    #[test]
3570    fn stacking_single_candidate_gets_full_weight() {
3571        let log_density = Array2::from_shape_vec((3, 1), vec![-1.0, -2.0, -0.5]).unwrap();
3572        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3573        assert!((out.weights[0] - 1.0).abs() < 1e-12);
3574        assert_eq!(out.weights.len(), 1);
3575    }
3576
3577    #[test]
3578    fn stacking_dominant_candidate_attracts_nearly_all_weight() {
3579        let mut log_density = Array2::<f64>::zeros((50, 2));
3580        for i in 0..50 {
3581            log_density[[i, 0]] = -0.1;
3582            log_density[[i, 1]] = -5.0;
3583        }
3584        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3585        assert!(out.weights[0] > 0.99, "w0 = {}", out.weights[0]);
3586        assert!(out.weights[1] < 0.01, "w1 = {}", out.weights[1]);
3587    }
3588
3589    #[test]
3590    fn stacking_complementary_candidates_share_weight() {
3591        // Each candidate is the better predictor on its own half of the data;
3592        // stacking keeps both, unlike winner-take-all.
3593        let n = 40;
3594        let mut log_density = Array2::<f64>::zeros((n, 2));
3595        for i in 0..n {
3596            if i < n / 2 {
3597                log_density[[i, 0]] = gaussian_logpdf(0.0, 0.0, 0.5);
3598                log_density[[i, 1]] = gaussian_logpdf(0.0, 1.5, 0.5);
3599            } else {
3600                log_density[[i, 0]] = gaussian_logpdf(0.0, 1.5, 0.5);
3601                log_density[[i, 1]] = gaussian_logpdf(0.0, 0.0, 0.5);
3602            }
3603        }
3604        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3605        assert!(
3606            out.weights[0] > 0.2 && out.weights[0] < 0.8,
3607            "w0 = {}",
3608            out.weights[0]
3609        );
3610        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3611    }
3612
3613    #[test]
3614    fn stacking_weights_stay_on_the_simplex() {
3615        let log_density = Array2::from_shape_vec(
3616            (3, 3),
3617            vec![-1.0, -2.0, -3.0, -2.5, -1.0, -2.0, -3.0, -2.0, -1.0],
3618        )
3619        .unwrap();
3620        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3621        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3622        assert!(out.weights.iter().all(|&w| w >= -1e-12));
3623    }
3624
3625    #[test]
3626    fn stacking_mean_log_score_is_monotone_under_more_iterations() {
3627        // The EM ascent is monotone in the held-out mean log-score, so allowing
3628        // more iterations never lowers it.
3629        let log_density =
3630            Array2::from_shape_vec((4, 2), vec![-0.2, -3.0, -3.0, -0.2, -0.5, -1.5, -1.5, -0.5])
3631                .unwrap();
3632        let mut prev = f64::NEG_INFINITY;
3633        for max_iter in [1usize, 2, 4, 8, 32] {
3634            let out = solve_stacking_weights(
3635                log_density.view(),
3636                StackingConfig {
3637                    max_iter,
3638                    weight_tol: 0.0,
3639                },
3640            )
3641            .unwrap();
3642            assert!(
3643                out.mean_log_score >= prev - 1e-12,
3644                "log-score decreased at max_iter={max_iter}: {prev} -> {}",
3645                out.mean_log_score
3646            );
3647            prev = out.mean_log_score;
3648        }
3649    }
3650
3651    #[test]
3652    fn stacking_dead_candidate_column_is_rejected_and_zero_weighted() {
3653        let log_density = Array2::from_shape_vec(
3654            (3, 2),
3655            vec![
3656                -1.0,
3657                f64::NEG_INFINITY,
3658                -2.0,
3659                f64::NAN,
3660                -0.5,
3661                f64::NEG_INFINITY,
3662            ],
3663        )
3664        .unwrap();
3665        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3666        assert_eq!(out.weights[1], 0.0);
3667        assert!((out.weights[0] - 1.0).abs() < 1e-12);
3668    }
3669
3670    #[test]
3671    fn stacking_rows_with_no_finite_density_are_dropped() {
3672        let log_density = Array2::from_shape_vec(
3673            (3, 2),
3674            vec![-1.0, -2.0, f64::NAN, f64::NEG_INFINITY, -2.0, -1.0],
3675        )
3676        .unwrap();
3677        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3678        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3679        assert!(out.mean_log_score.is_finite());
3680    }
3681
3682    #[test]
3683    fn stacking_all_dead_table_errors() {
3684        let log_density = Array2::from_elem((2, 2), f64::NEG_INFINITY);
3685        assert!(solve_stacking_weights(log_density.view(), StackingConfig::default()).is_err());
3686    }
3687
3688    #[test]
3689    fn stacked_mean_is_weighted_combination() {
3690        let weights = Array1::from_vec(vec![0.25, 0.75]);
3691        let means = vec![
3692            Array1::from_vec(vec![1.0, 2.0, 3.0]),
3693            Array1::from_vec(vec![5.0, 6.0, 7.0]),
3694        ];
3695        let out = stacked_predictive_mean(&weights, &means).unwrap();
3696        assert!((out[0] - (0.25 * 1.0 + 0.75 * 5.0)).abs() < 1e-12);
3697        assert!((out[2] - (0.25 * 3.0 + 0.75 * 7.0)).abs() < 1e-12);
3698    }
3699
3700    #[test]
3701    fn stacked_mean_rejects_shape_mismatch() {
3702        let weights = Array1::from_vec(vec![0.5, 0.5]);
3703        let means = vec![
3704            Array1::from_vec(vec![1.0, 2.0]),
3705            Array1::from_vec(vec![3.0]),
3706        ];
3707        assert!(stacked_predictive_mean(&weights, &means).is_err());
3708    }
3709
3710    // -----------------------------------------------------------------------
3711    // #1026 hybrid curved + linear-tail split-selection
3712    // -----------------------------------------------------------------------
3713
3714    /// Build the two candidate parameterizations for one atom slot the way the
3715    /// fit would: the linear special case (one decoder direction, `Θ = 0`,
3716    /// `P_linear` params) and the curved candidate (`latent_dim` ≥ 1, more
3717    /// params, fitted turning `theta`). The curved candidate's likelihood is the
3718    /// linear likelihood MINUS `curved_loglik_gain` of NLE (curvature it captures
3719    /// the secant cannot), so the nesting invariant `curved_loglik ≥ linear` is
3720    /// honored: a straight feature has zero gain, a turning feature a positive
3721    /// gain that grows with Θ. The rank-aware Laplace normalizer charges the
3722    /// extra `½(P_curved − P_linear)·log(2π)` for the curved parameters, so the
3723    /// evidence comparison is the real `Θ/√ε` crossover.
3724    fn hybrid_slot(
3725        linear_nle: f64,
3726        p_linear: usize,
3727        latent_dim: usize,
3728        p_curved: usize,
3729        theta: f64,
3730        curved_loglik_gain: f64,
3731    ) -> Vec<HybridAtomCandidate> {
3732        let param_price =
3733            0.5 * (p_curved as f64 - p_linear as f64) * (2.0 * std::f64::consts::PI).ln();
3734        let curved_nle = linear_nle - curved_loglik_gain + param_price;
3735        vec![
3736            HybridAtomCandidate::linear(linear_nle, p_linear),
3737            HybridAtomCandidate::curved(latent_dim, curved_nle, p_curved, Some(theta)),
3738        ]
3739    }
3740
3741    #[test]
3742    fn hybrid_dominance_floor_selects_linear_when_turning_is_zero() {
3743        // A perfectly straight curved fit (Θ = 0) gains no likelihood over its
3744        // linear sub-model but pays more parameters → linear must win, by
3745        // construction, even if finite-sample evidence noise nudged the curved
3746        // NLE slightly below linear.
3747        let slot = hybrid_slot(100.0, 2, 1, 5, 0.0, 0.0);
3748        let choice = select_hybrid_atom(&slot).unwrap();
3749        assert!(choice.param.is_linear());
3750        assert_eq!(choice.param, HybridAtomParam::Linear);
3751        // The exact-zero guard fires regardless of the evidence margin sign.
3752        assert!(choice.curved_turning.unwrap() <= HYBRID_LINEAR_TURNING_FLOOR);
3753    }
3754
3755    #[test]
3756    fn hybrid_selects_curved_when_turning_pays_for_itself() {
3757        // A genuinely turning feature (Θ = 2π, a full loop): the curved fit
3758        // captures enough curvature that, even charged the extra-parameter price,
3759        // its NLE drops below the linear secant's → curved wins.
3760        let slot = hybrid_slot(100.0, 2, 1, 5, 2.0 * std::f64::consts::PI, 30.0);
3761        let choice = select_hybrid_atom(&slot).unwrap();
3762        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
3763        // The curved fit won a strictly positive evidence margin.
3764        assert!(choice.curved_evidence_margin > 0.0);
3765    }
3766
3767    #[test]
3768    fn hybrid_keeps_linear_when_curvature_doesnt_pay_its_price() {
3769        // A barely-curved feature (small Θ): the curved fit recovers only a sliver
3770        // of likelihood, not enough to cover the extra-parameter price → the
3771        // dominance floor keeps the linear tail.
3772        let slot = hybrid_slot(100.0, 2, 1, 5, 0.05, 0.1);
3773        let choice = select_hybrid_atom(&slot).unwrap();
3774        assert!(choice.param.is_linear());
3775        assert!(choice.curved_evidence_margin <= 0.0);
3776    }
3777
3778    #[test]
3779    fn hybrid_tie_breaks_to_the_cheaper_linear_atom() {
3780        // Exact NLE tie (above the turning floor so the evidence path decides):
3781        // the cheaper linear atom wins, preserving strict generalization — the
3782        // hybrid never pays for curvature it does not need.
3783        let theta = 0.5; // above the floor → evidence path, not the exact guard
3784        let nle = 42.0;
3785        let slot = vec![
3786            HybridAtomCandidate::linear(nle, 2),
3787            HybridAtomCandidate::curved(1, nle, 5, Some(theta)),
3788        ];
3789        let choice = select_hybrid_atom(&slot).unwrap();
3790        assert!(choice.param.is_linear());
3791        assert_eq!(choice.num_parameters, 2);
3792    }
3793
3794    #[test]
3795    fn hybrid_split_reduces_to_pure_linear_when_all_features_are_straight() {
3796        // Every slot's curved candidate has Θ → 0 (flat features everywhere): the
3797        // dominance floor fires at every slot → the hybrid recovers the pure-
3798        // linear dictionary exactly. This is the `all Θ → 0` limit (3).
3799        let slots: Vec<Vec<HybridAtomCandidate>> = (0..6)
3800            .map(|i| hybrid_slot(50.0 + i as f64, 2, 1, 5, 0.0, 0.0))
3801            .collect();
3802        let split = select_hybrid_split(&slots).unwrap();
3803        assert!(split.is_pure_linear());
3804        assert_eq!(split.curved_atom_count, 0);
3805        assert_eq!(split.linear_atom_count(), 6);
3806        // Summed NLE equals the pure-linear baseline (every slot chose linear).
3807        let pure_linear: f64 = (0..6).map(|i| 50.0 + i as f64).sum();
3808        assert!((split.total_negative_log_evidence - pure_linear).abs() < 1e-12);
3809    }
3810
3811    #[test]
3812    fn hybrid_split_reduces_to_pure_curved_when_every_feature_curves() {
3813        // Every slot's feature turns enough (Θ = 2π, large likelihood gain) that
3814        // curved beats linear everywhere → the pure-curved limit (3).
3815        let slots: Vec<Vec<HybridAtomCandidate>> = (0..5)
3816            .map(|i| hybrid_slot(80.0 + i as f64, 2, 1, 5, 2.0 * std::f64::consts::PI, 40.0))
3817            .collect();
3818        let split = select_hybrid_split(&slots).unwrap();
3819        assert!(split.is_pure_curved());
3820        assert_eq!(split.curved_atom_count, 5);
3821        assert_eq!(split.linear_atom_count(), 0);
3822    }
3823
3824    #[test]
3825    fn hybrid_split_on_mixed_dictionary_picks_curved_for_circles_linear_for_directions() {
3826        // Mixed synthetic: slots 0..3 are CIRCLE features (high turning Θ = 2π,
3827        // the curved fit captures the loop), slots 3..7 are LINEAR DIRECTIONS
3828        // (straight, Θ = 0). The evidence split must select curved for the
3829        // circles and linear for the directions — and the hybrid's summed
3830        // evidence must be ≤ the summed per-slot LINEAR-candidate NLE (each
3831        // slot's best straight line fit to its response residual). This is a
3832        // data-level match-or-beat dominance (#1202: linear is the curved
3833        // family's nested Θ = 0 sub-model on common data), and holds because each
3834        // slot picks the argmin of its two common-data candidates.
3835        let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
3836        let mut pure_linear_baseline = 0.0_f64;
3837        // Three circle features: a curved atom replaces ~10-30 linear secants, so
3838        // the curved fit buys a large likelihood gain that dwarfs its param price.
3839        for i in 0..3 {
3840            let linear_nle = 120.0 + 3.0 * i as f64;
3841            pure_linear_baseline += linear_nle;
3842            slots.push(hybrid_slot(
3843                linear_nle,
3844                2,
3845                1,
3846                5,
3847                2.0 * std::f64::consts::PI,
3848                35.0,
3849            ));
3850        }
3851        // Four straight linear directions: zero turning, the linear special case
3852        // is optimal — a curved atom buys nothing and only costs parameters.
3853        for i in 0..4 {
3854            let linear_nle = 90.0 + 2.0 * i as f64;
3855            pure_linear_baseline += linear_nle;
3856            slots.push(hybrid_slot(linear_nle, 2, 1, 5, 0.0, 0.0));
3857        }
3858
3859        let split = select_hybrid_split(&slots).unwrap();
3860
3861        // The first three (circles) chose curved; the last four (directions) chose
3862        // linear.
3863        for (idx, choice) in split.atoms.iter().enumerate() {
3864            if idx < 3 {
3865                assert_eq!(
3866                    choice.param,
3867                    HybridAtomParam::Curved { latent_dim: 1 },
3868                    "circle slot {idx} should select curved"
3869                );
3870            } else {
3871                assert!(
3872                    choice.param.is_linear(),
3873                    "direction slot {idx} should select linear"
3874                );
3875            }
3876        }
3877        assert_eq!(split.curved_atom_count, 3);
3878        assert_eq!(split.linear_atom_count(), 4);
3879
3880        // The hybrid's summed negative-log-evidence is ≤ the summed per-slot
3881        // LINEAR-candidate NLE (each slot's best straight line fit to its response
3882        // residual): the per-slot argmin can only lower the sum. This is a
3883        // data-level match-or-beat dominance (#1202): linear is the curved
3884        // family's nested Θ = 0 sub-model on common data.
3885        assert!(
3886            split.total_negative_log_evidence <= pure_linear_baseline + 1e-9,
3887            "hybrid NLE {} must be <= summed linear-candidate NLE {}",
3888            split.total_negative_log_evidence,
3889            pure_linear_baseline
3890        );
3891        // And strictly better, because the curved circle slots paid off.
3892        assert!(split.total_negative_log_evidence < pure_linear_baseline);
3893    }
3894
3895    #[test]
3896    fn hybrid_split_rejects_empty_slot() {
3897        let slots = vec![hybrid_slot(10.0, 2, 1, 5, 0.0, 0.0), Vec::new()];
3898        assert!(select_hybrid_split(&slots).is_err());
3899    }
3900
3901    // ── #1362: compare_models must Occam-penalise a pure-noise smooth ────────
3902    //
3903    // These tests pin the ranking contract directly on `compare_reml_fits` with
3904    // controlled (score, edf, log_lik) inputs taken from the actual #1362
3905    // reproduction (Rust `reml_score` of `y ~ s(x)` vs `y ~ s(x) + s(z)` at
3906    // n=700). They do not need a fitted GAM or a Python wheel.
3907
3908    fn cand(name: &str, score: f64, edf: f64, log_lik: f64) -> RemlCandidate {
3909        RemlCandidate {
3910            index: 0,
3911            name: name.to_string(),
3912            score,
3913            edf: Some(edf),
3914            log_lik: Some(log_lik),
3915            family: None,
3916        }
3917    }
3918
3919    #[test]
3920    fn ranking_score_is_conditional_aic_when_loglik_and_edf_present() {
3921        // AIC = -2ℓ + 2·edf.
3922        let c = cand("m", /*score (ignored)*/ 999.0, 6.748, -32.0866);
3923        let expected = -2.0 * -32.0866 + 2.0 * 6.748;
3924        assert!((c.ranking_score() - expected).abs() < 1e-9);
3925    }
3926
3927    #[test]
3928    fn ranking_score_falls_back_to_evidence_without_loglik() {
3929        let c = RemlCandidate {
3930            index: 0,
3931            name: "m".to_string(),
3932            score: 151.28,
3933            edf: Some(6.0),
3934            log_lik: None,
3935            family: None,
3936        };
3937        assert_eq!(c.ranking_score(), 151.28);
3938    }
3939
3940    #[test]
3941    fn compare_models_rejects_pure_noise_smooth_despite_lower_evidence() {
3942        // Seed-3000 numbers from the #1362 Rust reproduction:
3943        //   small (y ~ s(x)):      reml=180.526, edf=6.748,  loglik=-32.0866
3944        //   big   (y ~ s(x)+s(z)): reml=177.404, edf=14.250, loglik=-32.1212
3945        // The big (noise-augmented) model has the LOWER (apparently better) raw
3946        // REML evidence, yet it spends ~7.5 extra EDF fitting noise without
3947        // improving the likelihood. The winner must be the SMALL model.
3948        let small = cand("small", 180.526, 6.748, -32.0866);
3949        let big = cand("big", 177.404, 14.250, -32.1212);
3950
3951        // Sanity: raw evidence (the broken headline) prefers big.
3952        assert!(big.score < small.score);
3953
3954        let cmp = compare_reml_fits(vec![small, big]).expect("compare");
3955        assert_eq!(
3956            cmp.winner, "small",
3957            "compare_models must Occam-penalise the pure-noise smooth and pick the smaller model"
3958        );
3959        // The score table still reports the raw evidence headline unchanged, so
3960        // Model.evidence / bayes_factor_vs stay consistent with the table.
3961        let small_row = cmp
3962            .score_table
3963            .iter()
3964            .find(|r| r.name == "small")
3965            .expect("small row");
3966        let big_row = cmp
3967            .score_table
3968            .iter()
3969            .find(|r| r.name == "big")
3970            .expect("big row");
3971        assert!((small_row.reml_score - 180.526).abs() < 1e-9);
3972        assert!((big_row.reml_score - 177.404).abs() < 1e-9);
3973    }
3974
3975    #[test]
3976    fn compare_models_keeps_power_for_a_relevant_smooth() {
3977        // Seed-3000 relevant-z numbers from the same reproduction:
3978        //   small: reml=1025.067, edf≈6.75,  loglik≈-368.99 (aic≈751.5)
3979        //   big:   reml=199.509,  edf≈14.25, loglik≈-33.16  (aic≈94.8)
3980        // A genuinely relevant smooth lowers BOTH the evidence and the AIC, so
3981        // the bigger model must still win — a fix cannot just always pick small.
3982        let small = cand("small", 1025.067, 6.75, -368.985);
3983        let big = cand("big", 199.509, 14.25, -33.165);
3984        let cmp = compare_reml_fits(vec![small, big]).expect("compare");
3985        assert_eq!(
3986            cmp.winner, "big",
3987            "compare_models must retain power: the relevant smooth's model must win"
3988        );
3989    }
3990}