Skip to main content

gam_solve/
evidence.rs

1//! Canonical Laplace evidence, IFT cascade, and topology selection.
2//!
3//! This module is the single canonical entry point for:
4//!
5//!   1. Laplace evidence `V(ρ, T) = F + (1/2) log|H| - (1/2) log|S(ρ)|+
6//!      - ((dim(H)-rank(S))/2) log(2π)`
7//!      evaluated at the arrow-Schur inner-loop fixed point.
8//!   2. The full IFT cascade `∂u*/∂β → ∂β*/∂ρ → ∂u*/∂ρ` through the three
9//!      continuous tiers `(u, β, ρ)`, per §2.2 / §2.4 / §2.6.
10//!   3. The per-`ρ` evidence gradient `∂V/∂ρ` via the arrow trace formula,
11//!      per §3.5 / §3.7 / §3.8.
12//!   4. Discrete topology selection across `{periodic, flat, sphere, torus}`,
13//!      per §4 (4.1 / 4.5 / 4.6).
14//!
15//! ## Crucial numerical invariants (proposal §1.7, §6.4, §6.5)
16//!
17//!   * Evidence log-determinants use **undamped** factors. The cached
18//!     `ArrowFactorCache::htt_factors_undamped` Cholesky factors of
19//!     `H_uu_i` (no `ridge_u`) are the ones that must enter
20//!     `Σ_i log|H_uu_i|`. Likewise a factored Schur log-det must be of
21//!     `A(0, 0) = H_ββ - Σ_i H_uβ_iᵀ H_uu_i⁻¹ H_uβ_i`, not the LM-damped
22//!     surrogate. Matrix-free evidence callers must provide the matching
23//!     undamped HVP so the same log-det is estimated by SLQ.
24//!   * IFT solves invert `H_uu`, not `H_uu + ridge_u I` (proposal §1.7,
25//!     §6.6). The evidence-side IFT predictor loop here uses the undamped
26//!     `htt_factors_undamped` factors for exactly this reason.
27//!   * Penalty pseudo-logdet `log|S(ρ)|+` is the prior penalty, distinct
28//!     from the arrow Schur complement (proposal §3.1, §3.6). The variable
29//!     names below preserve that distinction:
30//!       `arrow_schur_log_det`   = `log|A|` where `A` is the arrow Schur.
31//!       `penalty_log_det`       = `log|S_pen(ρ)|+` where `S_pen` is the
32//!                                 prior penalty matrix pseudo-logdet.
33//!
34//! ## Sign discipline (proposal §3.1, §4.3)
35//!
36//! `V` as written is the *negative log evidence* when `F` is the
37//! penalized negative log posterior. The maximizer of evidence is the
38//! minimizer of `V`. For the public API we expose **negative log
39//! evidence** under `laplace_evidence` and rank topologies by the
40//! **minimum** of the configured per-row or per-effective-dimension
41//! normalization (see `select_topology` below); equivalently the caller can
42//! negate and `argmax`.
43
44use faer::Side;
45use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
46
47use gam_linalg::faer_ndarray::FaerEigh;
48use gam_linalg::lanczos::{
49    SymmetricLanczosOptions, symmetric_lanczos_eigenpairs, symmetric_lanczos_log_quadrature,
50};
51use gam_linalg::triangular::cholesky_solve_vector;
52use crate::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
53use crate::priority_selection::{PriorityCandidate, rank_priority_candidates};
54
/// Dimension cutoff for the HVP log-determinant path. Per the
/// [`EvidenceLogDetSource::Hvp`] documentation, operators whose dimension is
/// at or below this value are materialized exactly; larger ones use SLQ.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
// NOTE(review): the four private constants below are not referenced in this
// part of the file; the descriptions are inferred from their names and should
// be confirmed at the use sites.
/// Presumably the number of random probe vectors per SLQ log-det estimate.
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
/// Presumably the number of Lanczos steps run per SLQ probe.
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
/// Presumably the relative tolerance for the HVP symmetry check.
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
/// Presumably the number of probe vectors used by the HVP symmetry check.
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
60
/// Matrix-free SPD Hessian logdet source used when the arrow Schur factor is
/// not materialized. The callback must apply the same undamped Hessian whose
/// determinant enters the Laplace evidence.
///
/// The struct only borrows the callback, so it is `Copy` and cheap to pass
/// by value.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
    /// Dimension of the (square) operator behind `apply`. The in-file caller
    /// sets this to the free-parameter count and its callback reads exactly
    /// that many input entries.
    pub dim: usize,
    /// Hessian-vector product callback: takes an input vector as a slice and
    /// returns the product as a freshly allocated `Vec<f64>`.
    pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
69
/// Source for the Hessian log determinant in [`laplace_evidence`].
///
/// Both variants only hold borrows, so the enum is `Copy`.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
    /// Use the exact arrow Cholesky factors, falling back to `fallback_hvp`
    /// when the Schur factor is absent on a matrix-free solve.
    FactoredArrow {
        /// Factor cache produced by the arrow-Schur solve.
        cache: &'a ArrowFactorCache,
        /// Optional matrix-free operator to use when the cache has no Schur
        /// factor; `None` means no fallback is available.
        fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
    },
    /// Use an HVP callback directly. Dimensions at or below
    /// [`ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD`] are materialized exactly;
    /// larger operators use the same Rademacher-Lanczos SLQ constants as
    /// `FrozenAnalyticPenaltyOp`.
    Hvp(EvidenceHvpLogDet<'a>),
}
85
86// ---------------------------------------------------------------------------
87// Topology candidate enum and selection result
88// ---------------------------------------------------------------------------
89
/// The discrete geometry chosen for the latent coordinate domain.
///
/// The selector is a closed four-way choice over
/// `{periodic, flat, sphere, torus}`; no other candidate variants are
/// carried alongside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
    /// `S¹` or periodic interval (cyclic B-spline / periodic Duchon).
    Periodic,
    /// `Rᵈ` Euclidean Duchon / Matérn / thin-plate patch.
    Flat,
    /// `S²` embedded in `R³`, spherical Wahba/Sobolev basis.
    Sphere,
    /// `S¹ × S¹` mixed-periodicity Duchon.
    Torus,
}

impl TopologyKind {
    /// Tie-break priority used when candidates score as equal: the smaller
    /// rank wins. The order is `flat (0) < periodic (1) < sphere (2) <
    /// torus (3)`, per §4.6.
    pub fn complexity_rank(self) -> u8 {
        match self {
            Self::Torus => 3,
            Self::Sphere => 2,
            Self::Periodic => 1,
            Self::Flat => 0,
        }
    }
}
119
/// One topology candidate together with the evidence ingredients it
/// produced at its own fitted optimum.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
    /// Which of the four topologies this candidate represents.
    pub kind: TopologyKind,
    /// Negative-log-evidence `V(ρ_T*, T)` evaluated at the candidate's own
    /// fitted `(ρ_T*, β_T*, u_T*)`.
    pub negative_log_evidence: f64,
    /// Effective integrated dimension after rank/nullspace accounting. This
    /// is the dimension used for per-complexity topology normalization.
    pub effective_dim: f64,
    /// Number of response rows used to fit this topology candidate. This is
    /// the dimension used for per-observation topology normalization.
    pub n_obs: usize,
    /// `true` iff the candidate's continuous inner+outer fit converged
    /// cleanly. Failed candidates are excluded from ranking (proposal
    /// §4.4 item 7 and §6.11).
    pub converged: bool,
    /// Optional rationale string for excluded candidates (proposal
    /// §6.11): `"sphere input not on S²"`, `"torus periods missing"`, etc.
    pub exclusion_reason: Option<String>,
}
142
/// Outcome of [`select_topology`].
#[derive(Debug, Clone)]
pub struct SelectedTopology {
    /// The topology that was selected.
    pub winner: TopologyKind,
    /// All candidates sorted from best (lowest negative log evidence)
    /// to worst, with excluded candidates appended last.
    ///
    /// NOTE(review): [`TopologySelectOptions`] states that raw evidence is
    /// not the selector, so "lowest" here presumably refers to the normalized
    /// score — confirm against `select_topology`.
    pub ranking: Vec<TopologyCandidate>,
    /// `true` iff the top two finite scores fall within `tie_tolerance`.
    /// Per §4.6 we still pick one — the simpler topology — but expose
    /// the tie so callers can warn.
    pub tie: bool,
}
155
/// Tolerance and scoring options for the topology comparator.
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
    /// Maximum `|V_a - V_b|` for which two candidates are treated as
    /// numerically tied after [`TopologyScoreScale`] normalization. Default
    /// `1e-3` per proposal §4.6 examples.
    ///
    /// NOTE(review): no `Default` impl for this struct is visible in this
    /// part of the file — confirm where the `1e-3` default is applied.
    pub tie_tolerance: f64,
    /// Score scale used for discrete topology comparison. Raw evidence is
    /// intentionally not a selector because candidates may have different
    /// row counts and basis/nullspace dimensions.
    pub score_scale: TopologyScoreScale,
}
168
/// Normalization applied before ranking topology candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
    /// Compare negative log evidence per observation row (the
    /// [`TopologyCandidate::n_obs`] field is documented as the dimension used
    /// for this normalization).
    PerObservation,
    /// Compare negative log evidence per effective integrated dimension (the
    /// [`TopologyCandidate::effective_dim`] field is documented as the
    /// dimension used for this normalization).
    PerEffectiveDim,
}
177
/// Stopping controls for the iterative stacking-weight solver
/// ([`solve_stacking_weights`]).
#[derive(Debug, Clone, Copy)]
pub struct StackingConfig {
    /// Upper bound on the number of weight-update iterations.
    pub max_iter: usize,
    /// The solver stops once the largest absolute change in any weight
    /// between two consecutive iterations is at most this value.
    pub weight_tol: f64,
}

impl Default for StackingConfig {
    /// At most 1000 iterations, with a weight-change tolerance of `1e-10`.
    fn default() -> Self {
        let max_iter = 1000;
        let weight_tol = 1e-10;
        Self {
            max_iter,
            weight_tol,
        }
    }
}
193
/// Simplex weights for retained topology candidates plus the achieved held-out
/// mean log-score.
#[derive(Debug, Clone)]
pub struct StackingWeights {
    /// One weight per candidate column of the input table. Columns that
    /// [`solve_stacking_weights`] drops (no finite density in any row) are
    /// left at `0.0`.
    pub weights: Array1<f64>,
    /// Mean over scored rows of the log mixture density under `weights`;
    /// `-inf` when no row could be scored.
    pub mean_log_score: f64,
    /// Number of weight-update iterations the solver started.
    pub iterations: usize,
}
202
203/// Solve the stacking-of-predictive-distributions weight problem from a
204/// per-observation held-out log-density table `log_density[i, k] = log p_k(y_i)`.
205///
206/// This belongs on the evidence surface rather than in a separate solver: it is
207/// the topology/evidence consumer that replaces winner-take-all only when the
208/// caller has retained candidate fits and per-point held-out densities.
209pub fn solve_stacking_weights(
210    log_density: ArrayView2<'_, f64>,
211    config: StackingConfig,
212) -> Result<StackingWeights, String> {
213    let n_obs = log_density.nrows();
214    let n_cand = log_density.ncols();
215    if n_cand == 0 {
216        return Err("stacking requires at least one candidate column".to_string());
217    }
218    if n_obs == 0 {
219        return Err("stacking requires at least one held-out observation row".to_string());
220    }
221
222    let kept_cols: Vec<usize> = (0..n_cand)
223        .filter(|&k| (0..n_obs).any(|i| log_density[[i, k]].is_finite()))
224        .collect();
225    if kept_cols.is_empty() {
226        return Err("stacking found no candidate with any finite held-out density".to_string());
227    }
228    let rows: Vec<usize> = (0..n_obs)
229        .filter(|&i| kept_cols.iter().any(|&k| log_density[[i, k]].is_finite()))
230        .collect();
231    if rows.is_empty() {
232        return Err("stacking found no held-out row with a finite density".to_string());
233    }
234
235    let kept = kept_cols.len();
236    let mut weights = Array1::<f64>::from_elem(kept, 1.0 / kept as f64);
237    let mut next = Array1::<f64>::zeros(kept);
238    let mut iterations = 0usize;
239    for _ in 0..config.max_iter {
240        iterations += 1;
241        next.fill(0.0);
242        let mut active_rows = 0usize;
243        for &row in &rows {
244            let mut row_max = f64::NEG_INFINITY;
245            for (local_col, &source_col) in kept_cols.iter().enumerate() {
246                let log_p = log_density[[row, source_col]];
247                if log_p.is_finite() && weights[local_col] > 0.0 {
248                    row_max = row_max.max(weights[local_col].ln() + log_p);
249                }
250            }
251            if !row_max.is_finite() {
252                continue;
253            }
254            let mut denom = 0.0_f64;
255            for (local_col, &source_col) in kept_cols.iter().enumerate() {
256                let log_p = log_density[[row, source_col]];
257                if log_p.is_finite() && weights[local_col] > 0.0 {
258                    denom += (weights[local_col].ln() + log_p - row_max).exp();
259                }
260            }
261            if denom <= 0.0 {
262                continue;
263            }
264            active_rows += 1;
265            let log_mix = row_max + denom.ln();
266            for (local_col, &source_col) in kept_cols.iter().enumerate() {
267                let log_p = log_density[[row, source_col]];
268                if log_p.is_finite() && weights[local_col] > 0.0 {
269                    next[local_col] += (weights[local_col].ln() + log_p - log_mix).exp();
270                }
271            }
272        }
273        if active_rows == 0 {
274            break;
275        }
276        next.mapv_inplace(|value| value / active_rows as f64);
277        let total = next.sum();
278        if total > 0.0 {
279            next.mapv_inplace(|value| value / total);
280        }
281        let delta = next
282            .iter()
283            .zip(weights.iter())
284            .fold(0.0_f64, |acc, (a, b)| acc.max((a - b).abs()));
285        weights.assign(&next);
286        if delta <= config.weight_tol {
287            break;
288        }
289    }
290
291    let mean_log_score = stacking_mean_log_score(log_density, &rows, &kept_cols, weights.view());
292    let mut full = Array1::<f64>::zeros(n_cand);
293    for (local_col, &source_col) in kept_cols.iter().enumerate() {
294        full[source_col] = weights[local_col];
295    }
296    Ok(StackingWeights {
297        weights: full,
298        mean_log_score,
299        iterations,
300    })
301}
302
303fn stacking_mean_log_score(
304    log_density: ArrayView2<'_, f64>,
305    rows: &[usize],
306    kept_cols: &[usize],
307    weights: ArrayView1<'_, f64>,
308) -> f64 {
309    let mut score_sum = 0.0_f64;
310    let mut counted = 0usize;
311    for &row in rows {
312        let mut row_max = f64::NEG_INFINITY;
313        for (local_col, &source_col) in kept_cols.iter().enumerate() {
314            let log_p = log_density[[row, source_col]];
315            if log_p.is_finite() && weights[local_col] > 0.0 {
316                row_max = row_max.max(weights[local_col].ln() + log_p);
317            }
318        }
319        if !row_max.is_finite() {
320            continue;
321        }
322        let mut denom = 0.0_f64;
323        for (local_col, &source_col) in kept_cols.iter().enumerate() {
324            let log_p = log_density[[row, source_col]];
325            if log_p.is_finite() && weights[local_col] > 0.0 {
326                denom += (weights[local_col].ln() + log_p - row_max).exp();
327            }
328        }
329        if denom > 0.0 {
330            score_sum += row_max + denom.ln();
331            counted += 1;
332        }
333    }
334    if counted == 0 {
335        f64::NEG_INFINITY
336    } else {
337        score_sum / counted as f64
338    }
339}
340
341/// Combine retained candidate response-scale means with stacking weights.
342pub fn stacked_predictive_mean(
343    weights: &Array1<f64>,
344    candidate_means: &[Array1<f64>],
345) -> Result<Array1<f64>, String> {
346    if candidate_means.len() != weights.len() {
347        return Err(format!(
348            "stacked_predictive_mean: {} weights but {} candidate mean vectors",
349            weights.len(),
350            candidate_means.len()
351        ));
352    }
353    let Some(first) = candidate_means.first() else {
354        return Err("stacked_predictive_mean requires at least one candidate".to_string());
355    };
356    let n_rows = first.len();
357    if candidate_means.iter().any(|means| means.len() != n_rows) {
358        return Err(
359            "stacked_predictive_mean: candidate mean vectors disagree on row count".to_string(),
360        );
361    }
362    let mut out = Array1::<f64>::zeros(n_rows);
363    for (weight, means) in weights.iter().zip(candidate_means) {
364        if *weight != 0.0 {
365            out.scaled_add(*weight, means);
366        }
367    }
368    Ok(out)
369}
370
371// ---------------------------------------------------------------------------
372// Discrete mixture rung (Object 3a / WP-C)
373// ---------------------------------------------------------------------------
374//
375// A `k`-component full-covariance Gaussian mixture fitted by deterministic
376// k-means++-style seeding (reusing `terms::basis` farthest-point k-means) plus
377// EM to a tolerance. It is priced by its free-parameter count and scored
378// through the SAME rank-aware Laplace/Tierney-Kadane normalizer as the smooth
379// topology candidates: `−V = loglik − ½ log|H| + ½ P log(2π)` with the
380// `−½ (dim(H) − rank(S)) log(2π)` normalizer evaluated at `dim(H) = P`,
381// `rank(S) = 0` (a fully likelihood-identified, unpenalized parametric model,
// so every free parameter is unpenalized null-space). The Hessian log-det
// `log|H|` is taken from the observed (empirical-Fisher / BHHH) information
// `Σ_i s_i s_iᵀ` plus a unit ridge, `H = I + Σ_i s_i s_iᵀ`, which is finite
// and SPD for any `n` at the EM optimum. It is fed through the same
// `laplace_evidence` entry point used by the smooth rungs so the two model
// classes are comparable on the evidence scale.
388
/// Fixed controls for fitting the discrete-mixture rung. Nothing here reads
/// the clock or the environment: with deterministic seeding, the fitted
/// mixture is a pure function of the data and `k`.
#[derive(Debug, Clone, Copy)]
pub struct GaussianMixtureConfig {
    /// Cap on the number of EM iterations.
    pub max_iter: usize,
    /// EM stops once the relative improvement of the mean log-likelihood
    /// falls to this tolerance.
    pub loglik_tol: f64,
    /// Variance floor: a ridge added to every component covariance to keep it
    /// numerically SPD. A small fixed value, not a tuned knob.
    pub covariance_floor: f64,
    /// Cap on the number of iterations of the deterministic k-means seeding
    /// pass.
    pub kmeans_max_iter: usize,
}

impl Default for GaussianMixtureConfig {
    /// 200 EM iterations, `1e-7` log-likelihood tolerance, `1e-6` covariance
    /// floor, and 25 k-means seeding iterations.
    fn default() -> Self {
        let max_iter = 200;
        let loglik_tol = 1e-7;
        let covariance_floor = 1e-6;
        let kmeans_max_iter = 25;
        Self {
            max_iter,
            loglik_tol,
            covariance_floor,
            kmeans_max_iter,
        }
    }
}
415
/// A fitted `k`-component full-covariance Gaussian mixture.
///
/// All fields are public; the methods on this type index `means`,
/// `covariances`, and `weights` by component `0..k`, so those containers must
/// each hold `k` entries.
#[derive(Debug, Clone)]
pub struct GaussianMixtureFit {
    /// Mixing weights, length `k`, on the simplex.
    pub weights: Array1<f64>,
    /// Component means, `k × d`.
    pub means: Array2<f64>,
    /// Component covariances, `k` matrices of shape `d × d` (SPD).
    pub covariances: Vec<Array2<f64>>,
    /// Number of mixture components.
    pub k: usize,
    /// Data dimension.
    pub d: usize,
    /// Number of rows used to fit.
    pub n_obs: usize,
    /// Maximised total log-likelihood `Σ_i log Σ_j w_j N(y_i; μ_j, Σ_j)`.
    /// `laplace_negative_log_evidence` uses this stored value as the fit
    /// term, independently of the data passed to it.
    pub loglik: f64,
    /// EM iterations taken.
    pub iterations: usize,
}
436
437impl GaussianMixtureFit {
438    /// Free-parameter count `P` of a `k`-component full-covariance mixture in
439    /// `d` dimensions: `(k − 1)` mixing weights on the simplex, `k·d` mean
440    /// coordinates, and `k · d(d+1)/2` covariance entries. This is the exact
441    /// quantity that enters the rank-aware normalizer as `dim(H) − rank(S)`.
442    pub fn num_free_parameters(&self) -> usize {
443        let cov_per = self.d * (self.d + 1) / 2;
444        (self.k - 1) + self.k * self.d + self.k * cov_per
445    }
446
447    /// Per-observation log predictive density `log p(y_i)` under the fitted
448    /// mixture, length `n`. This is the held-out-density column source for
449    /// cross-class stacking when the mixture is evaluated on a held-out fold.
450    pub fn per_point_log_density(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
451        if data.ncols() != self.d {
452            return Err(format!(
453                "mixture log-density expects {} columns, got {}",
454                self.d,
455                data.ncols()
456            ));
457        }
458        let n = data.nrows();
459        let mut comp = vec![GaussianComponentEval::new(self.d); self.k];
460        for j in 0..self.k {
461            comp[j] = GaussianComponentEval::factor(self.means.row(j), &self.covariances[j])?;
462        }
463        let mut out = Array1::<f64>::zeros(n);
464        let log_w: Vec<f64> = self
465            .weights
466            .iter()
467            .map(|w| w.max(f64::MIN_POSITIVE).ln())
468            .collect();
469        for i in 0..n {
470            let row = data.row(i);
471            let mut log_terms = vec![f64::NEG_INFINITY; self.k];
472            let mut max_term = f64::NEG_INFINITY;
473            for j in 0..self.k {
474                let lt = log_w[j] + comp[j].log_density(row);
475                log_terms[j] = lt;
476                if lt > max_term {
477                    max_term = lt;
478                }
479            }
480            out[i] = log_sum_exp(&log_terms, max_term);
481        }
482        Ok(out)
483    }
484
485    /// Rank-aware Laplace **negative** log evidence on the SAME scale as the
486    /// smooth topology rungs. `−V = loglik − ½ log|H| + ½ P log(2π)`, realised
487    /// by calling [`laplace_evidence`] with `residual_objective = −loglik`,
488    /// `penalty_log_det = 0`, `penalty_rank = 0`, `effective_dim = P`, and
489    /// `log|H|` the observed empirical-Fisher information at the optimum.
490    pub fn laplace_negative_log_evidence(&self, data: ArrayView2<'_, f64>) -> Result<f64, String> {
491        let p = self.num_free_parameters();
492        let information = self.empirical_fisher_information(data)?;
493        if information.nrows() != p {
494            return Err(format!(
495                "mixture empirical-Fisher information has dim {} but expected free-parameter count {p}",
496                information.nrows()
497            ));
498        }
499        let apply_info = |x: &[f64]| -> Vec<f64> {
500            let mut out = vec![0.0_f64; p];
501            for r in 0..p {
502                let mut acc = 0.0_f64;
503                for c in 0..p {
504                    acc += information[[r, c]] * x[c];
505                }
506                out[r] = acc;
507            }
508            out
509        };
510        let hvp = EvidenceHvpLogDet {
511            dim: p,
512            apply: &apply_info,
513        };
514        let v = laplace_evidence(
515            EvidenceLogDetSource::Hvp(hvp),
516            0.0,
517            -self.loglik,
518            p as f64,
519            0.0,
520        );
521        if !v.is_finite() {
522            return Err("mixture Laplace evidence is not finite".to_string());
523        }
524        Ok(v)
525    }
526
527    /// Observed empirical-Fisher (BHHH) information `H = Σ_i s_i s_iᵀ`, where
528    /// `s_i = ∇_θ log p(y_i)` is the per-observation score in the
529    /// free-parameter coordinates: softmax-logit mixing weights (`k − 1`),
530    /// component means (`k·d`), and the lower-triangular covariance entries
531    /// (`k · d(d+1)/2`) of each component, in that block order. This SPD matrix
532    /// is the genuine observed-information surrogate evaluated at the EM
533    /// optimum — its dimension is exactly `P`, which is what enters the
534    /// rank-aware normalizer.
535    fn empirical_fisher_information(
536        &self,
537        data: ArrayView2<'_, f64>,
538    ) -> Result<Array2<f64>, String> {
539        if data.ncols() != self.d {
540            return Err(format!(
541                "mixture information expects {} columns, got {}",
542                self.d,
543                data.ncols()
544            ));
545        }
546        let n = data.nrows();
547        let p = self.num_free_parameters();
548        let cov_per = self.d * (self.d + 1) / 2;
549        // Precompute per-component evaluators (mean, precision = Σ⁻¹).
550        let mut comp = Vec::with_capacity(self.k);
551        for j in 0..self.k {
552            comp.push(GaussianComponentEval::factor(
553                self.means.row(j),
554                &self.covariances[j],
555            )?);
556        }
557        let log_w: Vec<f64> = self
558            .weights
559            .iter()
560            .map(|w| w.max(f64::MIN_POSITIVE).ln())
561            .collect();
562
563        let mean_base = self.k - 1;
564        let cov_base = mean_base + self.k * self.d;
565
566        let mut info = Array2::<f64>::zeros((p, p));
567        let mut score = vec![0.0_f64; p];
568        for i in 0..n {
569            let row = data.row(i);
570            // Responsibilities r_j = w_j N_j / Σ.
571            let mut log_terms = vec![0.0_f64; self.k];
572            let mut max_term = f64::NEG_INFINITY;
573            for j in 0..self.k {
574                let lt = log_w[j] + comp[j].log_density(row);
575                log_terms[j] = lt;
576                if lt > max_term {
577                    max_term = lt;
578                }
579            }
580            let log_mix = log_sum_exp(&log_terms, max_term);
581            let resp: Vec<f64> = log_terms.iter().map(|lt| (lt - log_mix).exp()).collect();
582
583            for s in score.iter_mut() {
584                *s = 0.0;
585            }
586            // Softmax-logit mixing score: ∂/∂α_j log p = r_j − w_j for the free
587            // logits j = 1..k-1 (component 0 is the reference / pinned logit).
588            for j in 1..self.k {
589                score[j - 1] = resp[j] - self.weights[j];
590            }
591            // Mean score: ∂/∂μ_j log p = r_j · Σ_j⁻¹ (y − μ_j).
592            // Covariance score (lower-tri entries): ∂/∂Σ_j contracted through
593            // the symmetric chain rule, r_j · ½ (Σ⁻¹ v vᵀ Σ⁻¹ − Σ⁻¹) with
594            // off-diagonal entries doubled for the symmetric parameterization.
595            for j in 0..self.k {
596                let prec_v = comp[j].precision_times_residual(row); // Σ⁻¹ (y − μ_j)
597                let mbo = mean_base + j * self.d;
598                for c in 0..self.d {
599                    score[mbo + c] = resp[j] * prec_v[c];
600                }
601                let cbo = cov_base + j * cov_per;
602                let mut idx = 0usize;
603                for a in 0..self.d {
604                    for b in 0..=a {
605                        let outer = prec_v[a] * prec_v[b];
606                        let prec_ab = comp[j].precision[[a, b]];
607                        let mut g = 0.5 * (outer - prec_ab);
608                        if a != b {
609                            // Off-diagonal entry appears twice in the symmetric
610                            // matrix, so its free-parameter derivative doubles.
611                            g *= 2.0;
612                        }
613                        score[cbo + idx] = resp[j] * g;
614                        idx += 1;
615                    }
616                }
617            }
618            // Accumulate outer product s_i s_iᵀ.
619            for r in 0..p {
620                let sr = score[r];
621                if sr == 0.0 {
622                    continue;
623                }
624                for c in 0..p {
625                    info[[r, c]] += sr * score[c];
626                }
627            }
628        }
629        // Symmetrize and add a unit-precision ridge `I`. This is the Hessian
630        // contribution of a standard-normal prior on the (natural) parameters,
631        // making the object a proper MAP observed-information `H = I_prior +
632        // Σ_i s_i s_iᵀ`. It guarantees SPD (well-defined `log|H|`) for any `n`
633        // and is a fixed prior, not a tuned knob — the same unit-information
634        // prior the rank-aware normalizer assumes when it credits each free
635        // parameter one `log(2π)` of integration volume.
636        for r in 0..p {
637            for c in (r + 1)..p {
638                let avg = 0.5 * (info[[r, c]] + info[[c, r]]);
639                info[[r, c]] = avg;
640                info[[c, r]] = avg;
641            }
642            info[[r, r]] += 1.0;
643        }
644        Ok(info)
645    }
646}
647
648/// Cached per-component Gaussian evaluator: mean, precision `Σ⁻¹`, and the
649/// log-normalizing constant `−½(d log 2π + log|Σ|)`.
650#[derive(Debug, Clone)]
651struct GaussianComponentEval {
652    mean: Array1<f64>,
653    precision: Array2<f64>,
654    log_norm: f64,
655    d: usize,
656}
657
658impl GaussianComponentEval {
659    fn new(d: usize) -> Self {
660        Self {
661            mean: Array1::zeros(d),
662            precision: Array2::eye(d),
663            log_norm: 0.0,
664            d,
665        }
666    }
667
668    fn factor(mean: ArrayView1<'_, f64>, cov: &Array2<f64>) -> Result<Self, String> {
669        let d = mean.len();
670        if cov.nrows() != d || cov.ncols() != d {
671            return Err(format!(
672                "mixture component covariance must be {d}x{d}, got {}x{}",
673                cov.nrows(),
674                cov.ncols()
675            ));
676        }
677        let (evals, evecs) = cov
678            .eigh(Side::Lower)
679            .map_err(|e| format!("mixture component covariance eigendecomposition failed: {e}"))?;
680        let mut log_det = 0.0_f64;
681        let mut inv_evals = Array1::<f64>::zeros(d);
682        for (idx, &ev) in evals.iter().enumerate() {
683            if !ev.is_finite() || ev <= 0.0 {
684                return Err(format!(
685                    "mixture component covariance is not SPD: eigenvalue {idx} is {ev:.3e}"
686                ));
687            }
688            log_det += ev.ln();
689            inv_evals[idx] = 1.0 / ev;
690        }
691        // Σ⁻¹ = V diag(1/λ) Vᵀ.
692        let mut precision = Array2::<f64>::zeros((d, d));
693        for a in 0..d {
694            for b in 0..d {
695                let mut acc = 0.0_f64;
696                for m in 0..d {
697                    acc += evecs[[a, m]] * inv_evals[m] * evecs[[b, m]];
698                }
699                precision[[a, b]] = acc;
700            }
701        }
702        let log_norm = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det);
703        Ok(Self {
704            mean: mean.to_owned(),
705            precision,
706            log_norm,
707            d,
708        })
709    }
710
711    #[inline]
712    fn log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
713        let pv = self.precision_times_residual(y);
714        let mut quad = 0.0_f64;
715        for c in 0..self.d {
716            quad += (y[c] - self.mean[c]) * pv[c];
717        }
718        self.log_norm - 0.5 * quad
719    }
720
721    /// `Σ⁻¹ (y − μ)`.
722    #[inline]
723    fn precision_times_residual(&self, y: ArrayView1<'_, f64>) -> Vec<f64> {
724        let mut out = vec![0.0_f64; self.d];
725        for a in 0..self.d {
726            let mut acc = 0.0_f64;
727            for b in 0..self.d {
728                acc += self.precision[[a, b]] * (y[b] - self.mean[b]);
729            }
730            out[a] = acc;
731        }
732        out
733    }
734}
735
/// Numerically stable `log Σ_t exp(t)` over `terms`, shifted by `max_term`.
///
/// `max_term` must be the maximum of `terms`; it is supplied by the caller,
/// which has already found it while filling the slice.
///
/// Non-finite maxima are handled explicitly rather than run through the
/// shifted sum, where `∞ − ∞` would produce NaN:
/// * `-∞` (every term is `-∞`, i.e. no mass, or `terms` is empty) yields `-∞`;
/// * `+∞` yields `+∞`, since a sum containing an infinite term is infinite;
/// * NaN is propagated as NaN.
#[inline]
fn log_sum_exp(terms: &[f64], max_term: f64) -> f64 {
    if max_term == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    if !max_term.is_finite() {
        // `+∞` or NaN: return it as-is instead of reporting zero mass.
        return max_term;
    }
    let shifted_sum = terms
        .iter()
        .fold(0.0_f64, |acc, &t| acc + (t - max_term).exp());
    max_term + shifted_sum.ln()
}
747
748/// Fit a `k`-component full-covariance Gaussian mixture by deterministic
749/// k-means++-style seeding (reusing the `terms::basis` farthest-point k-means,
750/// a pure function of the data — no clock randomness) followed by EM to the
751/// configured tolerance.
752///
753/// The fit is deterministic given `(data, k, config)`: the seed is the
754/// farthest-point/k-means center selection, EM is a deterministic map, so
755/// re-running yields the identical mixture.
756pub fn fit_gaussian_mixture(
757    data: ArrayView2<'_, f64>,
758    k: usize,
759    config: GaussianMixtureConfig,
760) -> Result<GaussianMixtureFit, String> {
761    let n = data.nrows();
762    let d = data.ncols();
763    if k == 0 {
764        return Err("gaussian mixture requires k >= 1".to_string());
765    }
766    if d == 0 {
767        return Err("gaussian mixture requires at least one column".to_string());
768    }
769    if k > n {
770        return Err(format!(
771            "gaussian mixture requested {k} components but data has {n} rows"
772        ));
773    }
774    // Deterministic k-means++-style seeding via the shared basis k-means
775    // (farthest-point init + Lloyd iterations). Fixed by construction.
776    let centers = gam_terms::basis::select_centers_by_strategy(
777        data,
778        &gam_terms::basis::CenterStrategy::KMeans {
779            num_centers: k,
780            max_iter: config.kmeans_max_iter,
781        },
782    )
783    .map_err(|e| format!("gaussian mixture k-means seeding failed: {e}"))?;
784    if centers.nrows() != k || centers.ncols() != d {
785        return Err(format!(
786            "gaussian mixture seeding returned {}x{} centers, expected {k}x{d}",
787            centers.nrows(),
788            centers.ncols()
789        ));
790    }
791
792    let mut means = centers;
793    // Seed covariances from the global data covariance (shared start).
794    let global_cov = data_covariance(data, config.covariance_floor);
795    let mut covariances = vec![global_cov; k];
796    let mut weights = Array1::<f64>::from_elem(k, 1.0 / k as f64);
797
798    let mut resp = Array2::<f64>::zeros((n, k));
799    let mut prev_mean_ll = f64::NEG_INFINITY;
800    let mut total_loglik = f64::NEG_INFINITY;
801    let mut iterations = 0usize;
802
803    for iter in 0..config.max_iter.max(1) {
804        iterations = iter + 1;
805        // E-step: responsibilities and total log-likelihood.
806        let mut comp = Vec::with_capacity(k);
807        for j in 0..k {
808            comp.push(GaussianComponentEval::factor(
809                means.row(j),
810                &covariances[j],
811            )?);
812        }
813        let log_w: Vec<f64> = weights
814            .iter()
815            .map(|w| w.max(f64::MIN_POSITIVE).ln())
816            .collect();
817        total_loglik = 0.0;
818        for i in 0..n {
819            let yrow = data.row(i);
820            let mut log_terms = vec![0.0_f64; k];
821            let mut max_term = f64::NEG_INFINITY;
822            for j in 0..k {
823                let lt = log_w[j] + comp[j].log_density(yrow);
824                log_terms[j] = lt;
825                if lt > max_term {
826                    max_term = lt;
827                }
828            }
829            let log_mix = log_sum_exp(&log_terms, max_term);
830            total_loglik += log_mix;
831            for j in 0..k {
832                resp[[i, j]] = (log_terms[j] - log_mix).exp();
833            }
834        }
835        let mean_ll = total_loglik / n as f64;
836        if iter > 0 {
837            let denom = prev_mean_ll.abs().max(1.0);
838            if (mean_ll - prev_mean_ll).abs() / denom <= config.loglik_tol {
839                break;
840            }
841        }
842        prev_mean_ll = mean_ll;
843
844        // M-step.
845        let mut nk = vec![0.0_f64; k];
846        for j in 0..k {
847            let mut sum = 0.0_f64;
848            for i in 0..n {
849                sum += resp[[i, j]];
850            }
851            nk[j] = sum.max(f64::MIN_POSITIVE);
852        }
853        for j in 0..k {
854            weights[j] = nk[j] / n as f64;
855            // Means.
856            let mut mu = Array1::<f64>::zeros(d);
857            for i in 0..n {
858                let r = resp[[i, j]];
859                if r == 0.0 {
860                    continue;
861                }
862                for c in 0..d {
863                    mu[c] += r * data[[i, c]];
864                }
865            }
866            mu.mapv_inplace(|v| v / nk[j]);
867            for c in 0..d {
868                means[[j, c]] = mu[c];
869            }
870            // Covariance with a fixed diagonal floor for SPD safety.
871            let mut cov = Array2::<f64>::zeros((d, d));
872            for i in 0..n {
873                let r = resp[[i, j]];
874                if r == 0.0 {
875                    continue;
876                }
877                for a in 0..d {
878                    let da = data[[i, a]] - mu[a];
879                    for b in 0..d {
880                        cov[[a, b]] += r * da * (data[[i, b]] - mu[b]);
881                    }
882                }
883            }
884            cov.mapv_inplace(|v| v / nk[j]);
885            for a in 0..d {
886                cov[[a, a]] += config.covariance_floor;
887            }
888            covariances[j] = cov;
889        }
890    }
891
892    Ok(GaussianMixtureFit {
893        weights,
894        means,
895        covariances,
896        k,
897        d,
898        n_obs: n,
899        loglik: total_loglik,
900        iterations,
901    })
902}
903
904/// Global data covariance with a fixed diagonal floor (used to seed EM).
905fn data_covariance(data: ArrayView2<'_, f64>, floor: f64) -> Array2<f64> {
906    let n = data.nrows();
907    let d = data.ncols();
908    let mut mean = Array1::<f64>::zeros(d);
909    for i in 0..n {
910        for c in 0..d {
911            mean[c] += data[[i, c]];
912        }
913    }
914    mean.mapv_inplace(|v| v / n.max(1) as f64);
915    let mut cov = Array2::<f64>::zeros((d, d));
916    for i in 0..n {
917        for a in 0..d {
918            let da = data[[i, a]] - mean[a];
919            for b in 0..d {
920                cov[[a, b]] += da * (data[[i, b]] - mean[b]);
921            }
922        }
923    }
924    let inv = 1.0 / (n.max(1) as f64);
925    cov.mapv_inplace(|v| v * inv);
926    for a in 0..d {
927        cov[[a, a]] += floor;
928    }
929    cov
930}
931
932// ---------------------------------------------------------------------------
933// Structured-union candidates (#907)
934// ---------------------------------------------------------------------------
935//
936// A *union* candidate is a small FIXED composite of named component structures
937// joined by a hard row-responsibility split. Unlike the discrete-mixture rung
938// (which is one free k-component Gaussian density), a union pins each component
939// to a specific generative STRUCTURE (a circle, a line, a point cluster) and
940// asks whether the data is better explained as the disjoint sum of those
941// structures than by any single pure rung.
942//
943// Each component is fit on its responsibility group as its own parametric
944// generative density and scored through the SAME rank-aware Laplace /
945// Tierney-Kadane normalizer used by the smooth rungs and the mixture rung:
946// `−V_c = loglik_c − ½ log|H_c| + ½ P_c log(2π)` with `H_c` the observed
947// empirical-Fisher (BHHH) information `I + Σ s_i s_iᵀ` at the component optimum
948// (`rank(S)=0`, fully likelihood-identified). The union's evidence is the SUM
949// `V = Σ_c V_c` (the components partition the rows, so their log-likelihoods
950// add and their Hessians are block-diagonal — `log|H| = Σ_c log|H_c|`). The
951// complexity price is the TOTAL free-parameter count across all components,
952// which is exactly what the summed `+ ½ Σ_c P_c log(2π)` normalizer charges.
953// A union is therefore strictly more expensive than either pure component, so
954// it can only win when the structured split buys enough likelihood to pay for
955// its extra parameters — the negative-control discipline of #907.
956
/// Closed set of structured-union composites the topology race may pick from.
///
/// The set is deliberately fixed and deterministic; open-ended structure
/// search belongs to #976's move set, not to this enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnionStructure {
    /// A pair of circles: two well-separated periodic loops.
    CircleCircle,
    /// A circle together with an isolated point cluster (loop + outlier blob).
    CirclePointCluster,
    /// A line (anisotropic cluster) together with an isolated point cluster.
    LineCluster,
}
969
970/// The fixed structured-union ladder, in stable order.
971pub const UNION_STRUCTURE_LADDER: &[UnionStructure] = &[
972    UnionStructure::CircleCircle,
973    UnionStructure::CirclePointCluster,
974    UnionStructure::LineCluster,
975];
976
/// Generative structure assigned to one responsibility group of a union.
///
/// `Line` and `PointCluster` are both scored with the full-covariance
/// Gaussian density: a line is simply an anisotropic Gaussian, so only the
/// fitted covariance tells them apart, not the parameterization. `Circle` is
/// a genuinely different density over `(radius, angle)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnionComponentKind {
    Circle,
    Line,
    PointCluster,
}
988
989impl UnionStructure {
990    /// Stable display name, e.g. `"union_circle+circle"`.
991    pub const fn as_str(self) -> &'static str {
992        match self {
993            UnionStructure::CircleCircle => "union_circle+circle",
994            UnionStructure::CirclePointCluster => "union_circle+cluster",
995            UnionStructure::LineCluster => "union_line+cluster",
996        }
997    }
998
999    /// The fixed ordered component structures of this union.
1000    pub const fn components(self) -> &'static [UnionComponentKind] {
1001        match self {
1002            UnionStructure::CircleCircle => {
1003                &[UnionComponentKind::Circle, UnionComponentKind::Circle]
1004            }
1005            UnionStructure::CirclePointCluster => {
1006                &[UnionComponentKind::Circle, UnionComponentKind::PointCluster]
1007            }
1008            UnionStructure::LineCluster => {
1009                &[UnionComponentKind::Line, UnionComponentKind::PointCluster]
1010            }
1011        }
1012    }
1013
1014    /// Number of components (= the responsibility-split order `m`).
1015    pub const fn num_components(self) -> usize {
1016        self.components().len()
1017    }
1018}
1019
1020/// One fitted component of a union: its pinned structure, the rows it owns
1021/// (after the hard responsibility split), its free-parameter count, and its
1022/// rank-aware Laplace negative-log-evidence on the common scale.
1023#[derive(Debug, Clone)]
1024pub struct UnionComponentFit {
1025    pub kind: UnionComponentKind,
1026    pub row_count: usize,
1027    pub num_parameters: usize,
1028    pub negative_log_evidence: f64,
1029}
1030
1031/// A fitted structured-union candidate: the composite kind, the per-component
1032/// fits, the SUMMED rank-aware Laplace negative-log-evidence, and the TOTAL
1033/// free-parameter count across components (the complexity price).
1034#[derive(Debug, Clone)]
1035pub struct UnionStructureFit {
1036    pub structure: UnionStructure,
1037    pub components: Vec<UnionComponentFit>,
1038    /// `Σ_c V_c` — summed rank-aware Laplace negative-log-evidence (lower wins).
1039    pub negative_log_evidence: f64,
1040    /// `Σ_c P_c` — total free-parameter count across components.
1041    pub total_parameters: usize,
1042}
1043
1044/// Hard responsibility split of `0..n` into `m` groups by argmax of the
1045/// deterministic `m`-component Gaussian-mixture responsibilities. Reuses the
1046/// mixture rung's seeding + EM so the split is a pure function of the data and
1047/// `m` (no clock). Returns one row-index vector per component.
1048pub fn union_responsibility_split(
1049    data: ArrayView2<'_, f64>,
1050    m: usize,
1051    config: GaussianMixtureConfig,
1052) -> Result<Vec<Vec<usize>>, String> {
1053    let n = data.nrows();
1054    if m == 0 {
1055        return Err("union split requires at least one component".to_string());
1056    }
1057    if m > n {
1058        return Err(format!(
1059            "union split requested {m} groups but data has {n} rows"
1060        ));
1061    }
1062    if m == 1 {
1063        return Ok(vec![(0..n).collect()]);
1064    }
1065    let fit = fit_gaussian_mixture(data, m, config)?;
1066    let mut groups: Vec<Vec<usize>> = vec![Vec::new(); m];
1067    // Hard assignment by argmax per-component log responsibility.
1068    let mut comp = Vec::with_capacity(m);
1069    for j in 0..m {
1070        comp.push(GaussianComponentEval::factor(
1071            fit.means.row(j),
1072            &fit.covariances[j],
1073        )?);
1074    }
1075    let log_w: Vec<f64> = fit
1076        .weights
1077        .iter()
1078        .map(|w| w.max(f64::MIN_POSITIVE).ln())
1079        .collect();
1080    for i in 0..n {
1081        let row = data.row(i);
1082        let mut best_j = 0usize;
1083        let mut best_lt = f64::NEG_INFINITY;
1084        for j in 0..m {
1085            let lt = log_w[j] + comp[j].log_density(row);
1086            if lt > best_lt {
1087                best_lt = lt;
1088                best_j = j;
1089            }
1090        }
1091        groups[best_j].push(i);
1092    }
1093    Ok(groups)
1094}
1095
1096/// Fit one structured-union candidate: hard-split the rows into one group per
1097/// component, fit each component's pinned density, and SUM the rank-aware
1098/// Laplace negative-log-evidence. The complexity price is the total
1099/// free-parameter count across components.
1100///
1101/// Returns an error if any component group is too small to identify its
1102/// structure (so an over-priced or non-identifiable composite simply does not
1103/// enter the race rather than scoring spuriously well).
1104pub fn fit_union_structure(
1105    data: ArrayView2<'_, f64>,
1106    structure: UnionStructure,
1107    config: GaussianMixtureConfig,
1108) -> Result<UnionStructureFit, String> {
1109    let comps = structure.components();
1110    let m = comps.len();
1111    let groups = union_responsibility_split(data, m, config)?;
1112    let mut fits = Vec::with_capacity(m);
1113    let mut total_nle = 0.0_f64;
1114    let mut total_parameters = 0usize;
1115    for (kind, rows) in comps.iter().zip(groups.iter()) {
1116        let group = gather_union_rows(data, rows);
1117        let (nle, p) = fit_union_component(group.view(), *kind, config)?;
1118        if !nle.is_finite() {
1119            return Err(format!(
1120                "union {} component {:?} produced non-finite evidence",
1121                structure.as_str(),
1122                kind
1123            ));
1124        }
1125        total_nle += nle;
1126        total_parameters += p;
1127        fits.push(UnionComponentFit {
1128            kind: *kind,
1129            row_count: rows.len(),
1130            num_parameters: p,
1131            negative_log_evidence: nle,
1132        });
1133    }
1134    Ok(UnionStructureFit {
1135        structure,
1136        components: fits,
1137        negative_log_evidence: total_nle,
1138        total_parameters,
1139    })
1140}
1141
1142/// Fit the whole fixed union ladder and rank in-class by summed rank-aware
1143/// Laplace evidence (lower wins). Composites that fail to fit (e.g. a group too
1144/// small to identify a circle) are skipped. Returns the fitted ladder sorted
1145/// best-first.
1146pub fn fit_union_ladder(
1147    data: ArrayView2<'_, f64>,
1148    config: GaussianMixtureConfig,
1149) -> Result<Vec<UnionStructureFit>, String> {
1150    let mut fits = Vec::new();
1151    let mut errors = Vec::new();
1152    for &structure in UNION_STRUCTURE_LADDER {
1153        match fit_union_structure(data, structure, config) {
1154            Ok(fit) => fits.push(fit),
1155            Err(e) => errors.push(format!("{}: {e}", structure.as_str())),
1156        }
1157    }
1158    if fits.is_empty() {
1159        return Err(format!(
1160            "union ladder produced no fittable composites{}",
1161            if errors.is_empty() {
1162                String::new()
1163            } else {
1164                format!(" ({})", errors.join("; "))
1165            }
1166        ));
1167    }
1168    let ranked = rank_priority_candidates(
1169        fits.into_iter()
1170            .enumerate()
1171            .map(|(idx, row)| {
1172                let score = row.negative_log_evidence;
1173                let tie = row.total_parameters; // cheaper composite wins ties
1174                PriorityCandidate::new(row, idx, score, tie)
1175            })
1176            .collect(),
1177    )
1178    .into_iter()
1179    .map(|row| row.item)
1180    .collect::<Vec<_>>();
1181    Ok(ranked)
1182}
1183
1184fn gather_union_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
1185    let d = data.ncols();
1186    let mut out = Array2::<f64>::zeros((idx.len(), d));
1187    for (r, &i) in idx.iter().enumerate() {
1188        for c in 0..d {
1189            out[[r, c]] = data[[i, c]];
1190        }
1191    }
1192    out
1193}
1194
1195/// Fit a single union component density on its responsibility group and return
1196/// `(rank_aware_negative_log_evidence, free_parameter_count)`. `Line` and
1197/// `PointCluster` use the full-covariance Gaussian density (a single mixture
1198/// component); `Circle` uses the radius/angle generative density below.
1199fn fit_union_component(
1200    group: ArrayView2<'_, f64>,
1201    kind: UnionComponentKind,
1202    config: GaussianMixtureConfig,
1203) -> Result<(f64, usize), String> {
1204    match kind {
1205        UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1206            // A single full-covariance Gaussian is the k=1 mixture: reuse its
1207            // exact rank-aware Laplace evidence so a union component is on the
1208            // identical scale as a mixture component.
1209            if group.nrows() < group.ncols() + 1 {
1210                return Err(format!(
1211                    "union gaussian component needs >= {} rows, got {}",
1212                    group.ncols() + 1,
1213                    group.nrows()
1214                ));
1215            }
1216            let fit = fit_gaussian_mixture(group, 1, config)?;
1217            let nle = fit.laplace_negative_log_evidence(group)?;
1218            Ok((nle, fit.num_free_parameters()))
1219        }
1220        UnionComponentKind::Circle => fit_circle_component_evidence(group, config),
1221    }
1222}
1223
1224/// Rank-aware Laplace negative-log-evidence of a 2-D *circle* component: data is
1225/// modelled as `(r, θ)` with `r ~ N(ρ, σ_r²)` around a fitted center+radius and
1226/// `θ` uniform on the circle. Free parameters: center `(cx, cy)`, radius `ρ`,
1227/// radial variance `σ_r²` — `P = 4`. The angle is an ancillary uniform with no
1228/// free parameter (it carries `−log(2π r)` of density). The Hessian is the
1229/// observed empirical-Fisher `I + Σ s_i s_iᵀ` in `(cx, cy, ρ, log σ_r²)`
1230/// coordinates, fed through the SAME [`laplace_evidence`] entry point.
1231fn fit_circle_component_evidence(
1232    group: ArrayView2<'_, f64>,
1233    config: GaussianMixtureConfig,
1234) -> Result<(f64, usize), String> {
1235    let d = group.ncols();
1236    if d != 2 {
1237        return Err(format!(
1238            "union circle component requires 2-D data, got {d} columns"
1239        ));
1240    }
1241    let n = group.nrows();
1242    let p = 4usize; // cx, cy, radius, radial-variance
1243    if n < p + 1 {
1244        return Err(format!(
1245            "union circle component needs >= {} rows, got {n}",
1246            p + 1
1247        ));
1248    }
1249    // Center = data centroid; radius = mean distance to centroid; radial
1250    // variance = mean squared radial residual (floored). This is the algebraic
1251    // circle-fit optimum for the isotropic radial-Gaussian model and is a pure
1252    // function of the data.
1253    let mut cx = 0.0_f64;
1254    let mut cy = 0.0_f64;
1255    for i in 0..n {
1256        cx += group[[i, 0]];
1257        cy += group[[i, 1]];
1258    }
1259    cx /= n as f64;
1260    cy /= n as f64;
1261    let mut radii = vec![0.0_f64; n];
1262    let mut radius = 0.0_f64;
1263    for i in 0..n {
1264        let dx = group[[i, 0]] - cx;
1265        let dy = group[[i, 1]] - cy;
1266        let r = (dx * dx + dy * dy).sqrt();
1267        radii[i] = r;
1268        radius += r;
1269    }
1270    radius /= n as f64;
1271    let mut var_r = 0.0_f64;
1272    for &r in &radii {
1273        let e = r - radius;
1274        var_r += e * e;
1275    }
1276    var_r = (var_r / n as f64).max(config.covariance_floor);
1277    let inv_var = 1.0 / var_r;
1278    // Total log-likelihood: Σ_i [ −½ log(2π σ_r²) − (r_i−ρ)²/(2σ_r²)
1279    //                             − log(2π r_i) ]  (radial Gaussian × uniform θ).
1280    let mut loglik = 0.0_f64;
1281    let log_2pi = (2.0 * std::f64::consts::PI).ln();
1282    for &r in &radii {
1283        let e = r - radius;
1284        let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e * inv_var;
1285        let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
1286        loglik += radial + angular;
1287    }
1288    // Observed empirical-Fisher in (cx, cy, ρ, s) with s = log σ_r².
1289    // Per-row scores:
1290    //   ∂/∂cx log = (e/σ_r²) · (−dx/r)            (r decreases as center moves +x toward point)
1291    //   ∂/∂cy log = (e/σ_r²) · (−dy/r)
1292    //   ∂/∂ρ  log = e/σ_r²
1293    //   ∂/∂s  log = −½ + e²/(2σ_r²)               (s = log σ_r²)
1294    let mut info = Array2::<f64>::zeros((p, p));
1295    let mut score = [0.0_f64; 4];
1296    for i in 0..n {
1297        let dx = group[[i, 0]] - cx;
1298        let dy = group[[i, 1]] - cy;
1299        let r = radii[i].max(f64::MIN_POSITIVE);
1300        let e = radii[i] - radius;
1301        let ee = e * inv_var;
1302        score[0] = ee * (-dx / r);
1303        score[1] = ee * (-dy / r);
1304        score[2] = ee;
1305        score[3] = -0.5 + 0.5 * e * e * inv_var;
1306        for a in 0..p {
1307            let sa = score[a];
1308            if sa == 0.0 {
1309                continue;
1310            }
1311            for b in 0..p {
1312                info[[a, b]] += sa * score[b];
1313            }
1314        }
1315    }
1316    // Symmetrize and add the unit-information prior ridge `I` (same fixed prior
1317    // as the mixture path) so `log|H|` is well-defined for any `n`.
1318    for a in 0..p {
1319        for b in (a + 1)..p {
1320            let avg = 0.5 * (info[[a, b]] + info[[b, a]]);
1321            info[[a, b]] = avg;
1322            info[[b, a]] = avg;
1323        }
1324        info[[a, a]] += 1.0;
1325    }
1326    let apply_info = |x: &[f64]| -> Vec<f64> {
1327        let mut out = vec![0.0_f64; p];
1328        for r in 0..p {
1329            let mut acc = 0.0_f64;
1330            for c in 0..p {
1331                acc += info[[r, c]] * x[c];
1332            }
1333            out[r] = acc;
1334        }
1335        out
1336    };
1337    let hvp = EvidenceHvpLogDet {
1338        dim: p,
1339        apply: &apply_info,
1340    };
1341    let v = laplace_evidence(EvidenceLogDetSource::Hvp(hvp), 0.0, -loglik, p as f64, 0.0);
1342    if !v.is_finite() {
1343        return Err("union circle component Laplace evidence is not finite".to_string());
1344    }
1345    Ok((v, p))
1346}
1347
1348/// A fitted union component as a *predictive density* (not just an evidence
1349/// scalar): either a full-covariance Gaussian (`Line`/`PointCluster`) or the
1350/// radial-Gaussian×uniform-angle circle density. Carries the mixing weight
1351/// `π_c = row_count_c / n_train` so a union can be evaluated as the soft mixture
1352/// `Σ_c π_c p_c(y)` at held-out rows for cross-class stacking.
1353#[derive(Debug, Clone)]
1354enum UnionComponentDensity {
1355    Gaussian {
1356        log_weight: f64,
1357        eval: GaussianComponentEval,
1358    },
1359    Circle {
1360        log_weight: f64,
1361        center: [f64; 2],
1362        radius: f64,
1363        var_r: f64,
1364    },
1365}
1366
1367impl UnionComponentDensity {
1368    /// `log π_c + log p_c(y)` for one eval row.
1369    fn weighted_log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
1370        match self {
1371            UnionComponentDensity::Gaussian { log_weight, eval } => {
1372                log_weight + eval.log_density(y)
1373            }
1374            UnionComponentDensity::Circle {
1375                log_weight,
1376                center,
1377                radius,
1378                var_r,
1379            } => {
1380                let dx = y[0] - center[0];
1381                let dy = y[1] - center[1];
1382                let r = (dx * dx + dy * dy).sqrt();
1383                let log_2pi = (2.0 * std::f64::consts::PI).ln();
1384                let e = r - radius;
1385                let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e / var_r;
1386                let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
1387                log_weight + radial + angular
1388            }
1389        }
1390    }
1391}
1392
1393/// Fit each union component's *density* on the training rows (hard
1394/// responsibility split) so the composite can be evaluated as the soft mixture
1395/// `Σ_c π_c p_c(y)` at new rows. Mixing weights are the training row shares.
1396fn fit_union_component_densities(
1397    train: ArrayView2<'_, f64>,
1398    structure: UnionStructure,
1399    config: GaussianMixtureConfig,
1400) -> Result<Vec<UnionComponentDensity>, String> {
1401    let comps = structure.components();
1402    let m = comps.len();
1403    let groups = union_responsibility_split(train, m, config)?;
1404    let n_train = train.nrows().max(1) as f64;
1405    let mut out = Vec::with_capacity(m);
1406    for (kind, rows) in comps.iter().zip(groups.iter()) {
1407        if rows.is_empty() {
1408            return Err(format!(
1409                "union {} held-out density: empty component group",
1410                structure.as_str()
1411            ));
1412        }
1413        let log_weight = (rows.len() as f64 / n_train).max(f64::MIN_POSITIVE).ln();
1414        let group = gather_union_rows(train, rows);
1415        match kind {
1416            UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1417                if group.nrows() < group.ncols() + 1 {
1418                    return Err(format!(
1419                        "union gaussian component density needs >= {} rows, got {}",
1420                        group.ncols() + 1,
1421                        group.nrows()
1422                    ));
1423                }
1424                let fit = fit_gaussian_mixture(group.view(), 1, config)?;
1425                let eval = GaussianComponentEval::factor(fit.means.row(0), &fit.covariances[0])?;
1426                out.push(UnionComponentDensity::Gaussian { log_weight, eval });
1427            }
1428            UnionComponentKind::Circle => {
1429                let d = group.ncols();
1430                if d != 2 {
1431                    return Err(format!(
1432                        "union circle component density requires 2-D data, got {d} columns"
1433                    ));
1434                }
1435                let n = group.nrows();
1436                if n < 5 {
1437                    return Err(format!(
1438                        "union circle component density needs >= 5 rows, got {n}"
1439                    ));
1440                }
1441                let mut cx = 0.0_f64;
1442                let mut cy = 0.0_f64;
1443                for i in 0..n {
1444                    cx += group[[i, 0]];
1445                    cy += group[[i, 1]];
1446                }
1447                cx /= n as f64;
1448                cy /= n as f64;
1449                let mut radius = 0.0_f64;
1450                let mut radii = vec![0.0_f64; n];
1451                for i in 0..n {
1452                    let dx = group[[i, 0]] - cx;
1453                    let dy = group[[i, 1]] - cy;
1454                    let r = (dx * dx + dy * dy).sqrt();
1455                    radii[i] = r;
1456                    radius += r;
1457                }
1458                radius /= n as f64;
1459                let mut var_r = 0.0_f64;
1460                for &r in &radii {
1461                    let e = r - radius;
1462                    var_r += e * e;
1463                }
1464                var_r = (var_r / n as f64).max(config.covariance_floor);
1465                out.push(UnionComponentDensity::Circle {
1466                    log_weight,
1467                    center: [cx, cy],
1468                    radius,
1469                    var_r,
1470                });
1471            }
1472        }
1473    }
1474    Ok(out)
1475}
1476
1477/// Per-point held-out log predictive density of a structured-union candidate:
1478/// fit the component densities on `train` and score each row of `eval` as the
1479/// soft mixture `log Σ_c π_c p_c(y)`. This is the cross-class stacking column
1480/// source for a union (the analogue of [`GaussianMixtureFit::per_point_log_density`]).
1481pub fn union_per_point_log_density(
1482    train: ArrayView2<'_, f64>,
1483    eval: ArrayView2<'_, f64>,
1484    structure: UnionStructure,
1485    config: GaussianMixtureConfig,
1486) -> Result<Array1<f64>, String> {
1487    if train.ncols() != eval.ncols() {
1488        return Err(format!(
1489            "union held-out density: train has {} columns, eval has {}",
1490            train.ncols(),
1491            eval.ncols()
1492        ));
1493    }
1494    let densities = fit_union_component_densities(train, structure, config)?;
1495    let mut out = Array1::<f64>::zeros(eval.nrows());
1496    let mut terms = vec![f64::NEG_INFINITY; densities.len()];
1497    for i in 0..eval.nrows() {
1498        let row = eval.row(i);
1499        let mut max_term = f64::NEG_INFINITY;
1500        for (c, dens) in densities.iter().enumerate() {
1501            let lt = dens.weighted_log_density(row);
1502            terms[c] = lt;
1503            if lt > max_term {
1504                max_term = lt;
1505            }
1506        }
1507        out[i] = log_sum_exp(&terms, max_term);
1508    }
1509    Ok(out)
1510}
1511
/// A single fitted model taking part in a REML/LAML evidence comparison.
#[derive(Debug, Clone)]
pub struct RemlCandidate {
    pub index: usize,
    pub name: String,
    /// The minimised REML/LAML cost (lower is better). It is the model's
    /// reported evidence headline (`Model.evidence`) and is copied verbatim
    /// into the score table.
    pub score: f64,
    pub edf: Option<f64>,
    /// Log-likelihood at the converged mode, on the engine's
    /// constants-omitted scale (the same one used by
    /// `gam_inference::model_comparison`). `None` for legacy payloads that
    /// do not carry it.
    pub log_lik: Option<f64>,
    /// Tag of the response family, e.g. "gaussian", "gamma", "binomial".
    /// It lets `compare_reml_fits` REFUSE to rank fits whose REML/LAML scores
    /// live on incomparable base measures, since a cross-family comparison
    /// is meaningless (#1384). Legacy payloads without the tag carry `None`
    /// and are not guarded (back-compatible); every current FFI candidate
    /// records it.
    pub family: Option<String>,
    /// How many observations the fit was trained on. It lets
    /// `compare_reml_fits` REFUSE to rank fits made on a different number of
    /// observations (hence different data): both `−2·loglik` and the
    /// REML/LAML evidence grow with `n`, so a score gap between fits with
    /// different `n` is not a Bayes factor — the same incomparability the
    /// family guard rejects. Payloads that do not record it (legacy / O(n)
    /// scan smoothers) carry `None`, which the guard treats as unconstrained.
    pub n_obs: Option<usize>,
}
1540
1541impl RemlCandidate {
1542    /// Cost used to RANK candidates and pick the winner.
1543    ///
1544    /// The REML/LAML marginal-likelihood evidence headline (`score`) does NOT
1545    /// reliably Occam-penalise an added pure-noise smooth: on `y ~ s(x)` vs
1546    /// `y ~ s(x) + s(z)` with `z ⟂ y`, the augmented model's evidence is
1547    /// *lower* (apparently better) by a few nats on essentially every dataset,
1548    /// because the Gaussian REML Occam pair `½(log|H| − log|S|₊)` collapses
1549    /// toward zero for a finite-`λ̂` null term while that term still spends a
1550    /// few effective degrees of freedom fitting noise (issue #1362).
1551    ///
1552    /// The conditional AIC `−2ℓ + 2·edf` prices exactly those spent degrees of
1553    /// freedom and discriminates correctly: it penalises the noise smooth
1554    /// (Δ ≈ +15 nats) yet rewards a genuinely relevant smooth (Δ ≈ −650),
1555    /// preserving power. We therefore rank on the conditional AIC whenever both
1556    /// the log-likelihood and the effective degrees of freedom are available,
1557    /// and fall back to the raw evidence headline otherwise. The reported
1558    /// `score_table` still carries the unaltered evidence (`reml_score`), so
1559    /// `Model.evidence` / `bayes_factor_vs` stay consistent with the table.
1560    pub fn ranking_score(&self) -> f64 {
1561        match (self.log_lik, self.edf) {
1562            (Some(log_lik), Some(edf)) if log_lik.is_finite() && edf.is_finite() => {
1563                -2.0 * log_lik + 2.0 * edf
1564            }
1565            _ => self.score,
1566        }
1567    }
1568}
1569
1570#[derive(Clone, Debug)]
1571pub struct RemlComparison {
1572    pub ranking: Vec<RankedRow>,
1573    pub winner: String,
1574    pub evidence_summary: String,
1575    pub score_table: Vec<ScoreRow>,
1576}
1577
/// One candidate's entry in [`RemlComparison::ranking`].
#[derive(Debug, Clone)]
pub struct RankedRow {
    /// Candidate name.
    pub name: String,
    /// Raw REML/LAML evidence of the candidate, kept on the raw scale so it
    /// stays consistent with `Model.evidence`.
    pub score: f64,
    /// Cost gap from the winning model, measured on the SAME scale that orders
    /// the ranking (`ranking_score`: the Occam-penalised conditional AIC where
    /// available, issue #1362). Because the winner is `argmin ranking_score`,
    /// this is `>= 0` for every row by construction and never contradicts the
    /// declared winner (issue #1465).
    pub delta: f64,
    /// Bayes factor of the winner over this row on the ranking scale,
    /// `exp(delta) >= 1` (issue #1465).
    pub bayes_factor: f64,
    /// Effective degrees of freedom of the candidate, when recorded.
    pub edf: Option<f64>,
}
1594
/// One candidate's entry in [`RemlComparison::score_table`], expressed on the
/// raw REML/LAML evidence scale.
#[derive(Debug, Clone)]
pub struct ScoreRow {
    /// Candidate name.
    pub name: String,
    /// Raw REML/LAML evidence score of the candidate.
    pub reml_score: f64,
    /// Raw-score gap from the smallest raw score among the candidates.
    pub delta_reml: f64,
    /// `exp(delta_reml)`: Bayes factor of the best raw-score model over this
    /// one.
    pub bayes_factor_best_over_model: f64,
    /// Effective degrees of freedom of the candidate, when recorded.
    pub effective_dof: Option<f64>,
}
1603
/// Log Bayes factor of model `a` over model `b`, given their minimised
/// REML/LAML costs.
///
/// Costs are "lower is better", so the result is positive exactly when `a` has
/// the smaller cost.
#[inline]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    let (cost_a, cost_b) = (reml_score_a, reml_score_b);
    cost_b - cost_a
}
1609
1610/// Compare fitted models by the single evidence ordering contract used by
1611/// topology ranking and seed screening: lower finite cost wins, with stable
1612/// original-order tie handling.
1613pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
1614    if candidates.is_empty() {
1615        return Err("compare_models requires at least one fit".to_string());
1616    }
1617    // Fail-loud comparability guard (#1384): REML/LAML evidence scores are only
1618    // comparable across fits of the SAME response family — a Gaussian score and
1619    // a Gamma score live on different log-density base measures, so their
1620    // difference is not a Bayes factor. Ranking them anyway returns a confident
1621    // but meaningless winner. Refuse when two candidates carry DIFFERENT family
1622    // tags. Candidates with no family tag (`None`, legacy payloads) are not
1623    // constrained, so this never spuriously rejects an older saved model.
1624    {
1625        let mut seen_family: Option<&str> = None;
1626        for cand in &candidates {
1627            if let Some(fam) = cand.family.as_deref() {
1628                match seen_family {
1629                    None => seen_family = Some(fam),
1630                    Some(prev) if prev != fam => {
1631                        return Err(format!(
1632                            "compare_models: cannot compare fits of different response families                              ('{prev}' vs '{fam}'); their REML/LAML evidence scores are on                              incomparable base measures. Compare models fit to the same response                              under the same family."
1633                        ));
1634                    }
1635                    Some(_) => {}
1636                }
1637            }
1638        }
1639    }
1640    // Fail-loud comparability guard (#1384 sibling): AIC / REML-LAML evidence are
1641    // only comparable across fits of the SAME response on the SAME observations.
1642    // `−2·loglik` (and the marginal-likelihood headline) grow with the number of
1643    // observations `n`, so two fits with different `n` live on incomparable
1644    // scales and their score gap is not a Bayes factor — comparing an n=500 and
1645    // an n=100 fit of the same DGP otherwise declares the n=100 model the winner
1646    // purely because fewer points give a less-negative total log-likelihood.
1647    // Refuse when two candidates carry DIFFERENT observation counts. Candidates
1648    // with no count (`None`, legacy / O(n) scan payloads) are unconstrained, so
1649    // this never spuriously rejects a fit that simply did not record `n`.
1650    {
1651        let mut seen_n: Option<usize> = None;
1652        for cand in &candidates {
1653            if let Some(n) = cand.n_obs {
1654                match seen_n {
1655                    None => seen_n = Some(n),
1656                    Some(prev) if prev != n => {
1657                        return Err(format!(
1658                            "compare_models: cannot compare fits made on a different number of \
1659                             observations (n={prev} vs n={n}); AIC / REML-LAML evidence scales \
1660                             with the sample size, so their score difference is not a Bayes \
1661                             factor. Compare models fit to the same response on the same data."
1662                        ));
1663                    }
1664                    Some(_) => {}
1665                }
1666            }
1667        }
1668    }
1669    candidates = rank_priority_candidates(
1670        candidates
1671            .into_iter()
1672            .enumerate()
1673            .map(|(idx, row)| {
1674                // Rank/winner on the Occam-penalised conditional AIC where it is
1675                // available (issue #1362); falls back to the raw evidence score.
1676                let ranking = row.ranking_score();
1677                PriorityCandidate::new(row, idx, ranking, 0)
1678            })
1679            .collect(),
1680    )
1681    .into_iter()
1682    .map(|row| row.item)
1683    .collect();
1684
1685    let winner = candidates[0].name.clone();
1686    // The ranking `delta` / `bayes_factor` must be measured on the SAME scale
1687    // that orders the table — the `ranking_score` (Occam-penalised conditional
1688    // AIC where available, issue #1362). `candidates[0]` is the winner =
1689    // `argmin ranking_score`, so its ranking score IS the minimum; every row's
1690    // ranking-scale gap is then `>= 0` and its Bayes factor `>= 1`, never
1691    // contradicting the declared winner (issue #1465). Computing these against
1692    // the AIC winner's *raw REML* — which is not the minimum raw REML once AIC
1693    // and REML disagree — produced negative deltas and Bayes factors < 1 for
1694    // non-winner rows.
1695    let best_ranking_score = candidates[0].ranking_score();
1696    // The raw-REML `score_table` stays on the raw evidence scale (consistent
1697    // with `Model.evidence` / `bayes_factor_vs`), but is referenced to the
1698    // genuine minimum raw REML so its best-over-model Bayes factors are also
1699    // coherent (`>= 1`), rather than to whichever row happens to sit at index 0.
1700    let best_raw_score = candidates
1701        .iter()
1702        .map(|c| c.score)
1703        .fold(f64::INFINITY, f64::min);
1704    let mut ranking = Vec::with_capacity(candidates.len());
1705    let mut score_table = Vec::with_capacity(candidates.len());
1706    for row in &candidates {
1707        let delta = log_bayes_factor(best_ranking_score, row.ranking_score());
1708        let bayes_factor = delta.exp();
1709        let delta_reml = log_bayes_factor(best_raw_score, row.score);
1710        ranking.push(RankedRow {
1711            name: row.name.clone(),
1712            score: row.score,
1713            delta,
1714            bayes_factor,
1715            edf: row.edf,
1716        });
1717        score_table.push(ScoreRow {
1718            name: row.name.clone(),
1719            reml_score: row.score,
1720            delta_reml,
1721            bayes_factor_best_over_model: delta_reml.exp(),
1722            effective_dof: row.edf,
1723        });
1724    }
1725    // The winner is decided by `ranking_score` (the Occam-penalised conditional
1726    // AIC where available, issue #1362), which can disagree in sign with the raw
1727    // evidence Bayes factor for a noise-augmented model. Summarise the actual
1728    // decision margin so the headline never contradicts the chosen winner.
1729    let evidence_summary = if let Some(runner_up) = candidates.get(1) {
1730        let margin = runner_up.ranking_score() - candidates[0].ranking_score();
1731        format!(
1732            "{} wins by Bayes factor {} over {}",
1733            winner,
1734            format_bayes_factor(margin),
1735            runner_up.name
1736        )
1737    } else {
1738        format!("{winner} (single fit; no comparison)")
1739    };
1740    Ok(RemlComparison {
1741        ranking,
1742        winner,
1743        evidence_summary,
1744        score_table,
1745    })
1746}
1747
1748pub fn format_bayes_factor(log_bf: f64) -> String {
1749    if !log_bf.is_finite() {
1750        return "inf".to_string();
1751    }
1752    if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
1753        return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
1754    }
1755    format_three_significant(log_bf.exp())
1756}
1757
/// Format `value` with roughly three significant digits.
///
/// * `0.0` renders as `"0"`.
/// * Non-finite values use their default `Display` form (`NaN`, `inf`,
///   `-inf`).
/// * Magnitudes `>= 1000` use scientific notation with two decimals
///   (`{:.2e}`).
/// * Smaller magnitudes are rounded half-away-from-zero to
///   `max(2 - floor(log10|value|), 0)` decimals and printed in fixed notation.
/// * Magnitudes so small that the rounding scale `10^decimals` overflows `f64`
///   (around `1e-307` and below) fall back to scientific notation; previously
///   the overflow turned into `inf / inf` and the function printed `"NaN"` for
///   a finite input.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let exponent = value.abs().log10().floor() as i32;
    if exponent >= 3 {
        return format!("{value:.2e}");
    }
    let decimals = (2 - exponent).max(0) as usize;
    let scale = 10f64.powi(decimals as i32);
    let scaled = (value * scale).abs();
    if !scaled.is_finite() {
        // `scale` overflowed to infinity: fixed-point rounding is impossible
        // here, so keep three significant digits via scientific notation.
        return format!("{value:.2e}");
    }
    let rounded = scaled.round() / scale * value.signum();
    format!("{rounded:.decimals$}")
}
1774
1775impl Default for TopologySelectOptions {
1776    fn default() -> Self {
1777        Self {
1778            tie_tolerance: 1e-3,
1779            score_scale: TopologyScoreScale::PerObservation,
1780        }
1781    }
1782}
1783
1784// ---------------------------------------------------------------------------
1785// Laplace evidence
1786// ---------------------------------------------------------------------------
1787
1788/// Single canonical Laplace evidence at the inner-loop fixed point.
1789///
1790/// Returns negative log evidence:
1791///
1792/// ```text
1793/// V(ρ, T) = F(β*, u*; ρ, T)
1794///         + 0.5 log|H|
1795///         - 0.5 log|S_pen(ρ)|+
1796///         - 0.5 (dim(H) - rank(S_pen)) log(2π).
1797/// ```
1798///
1799/// The last term is the rank-aware Tierney-Kadane normalizer:
1800/// `log p(y|T) ≈ -V`, with `0.5 log|2πH⁻¹| - 0.5 log|2πS⁻¹|`.
1801///
1802/// The `H` log-determinant is computed from the arrow factorization
1803///
1804/// ```text
1805/// log|H| = Σ_i log|H_uu_i| + log|A|
1806/// ```
1807///
1808/// (proposal §3.4 / §7) using the **undamped** per-row Cholesky factors
1809/// `cache.htt_factors_undamped` and the **undamped** Schur factor.
1810///
1811/// `penalty_log_det` is `log|S_pen(ρ)|+` — the prior penalty
1812/// pseudo-logdet from `crate::reml::penalty_logdet` (proposal
1813/// §3.6). It must NOT be confused with the arrow Schur log-det, which
1814/// this function recomputes internally from `logdet_source`.
1815///
1816/// `residual_objective` is `F(β*, u*; ρ, T)` at the inner optimum. The
1817/// envelope theorem (proposal §3.2) makes this the only `F`-related
1818/// contribution.
1819///
1820/// `effective_dim` is `dim(H)` after constraints/projections and
1821/// `penalty_rank` is `rank(S_pen)`. Their difference is the unpenalized
1822/// nullspace dimension that remains in the Laplace integral.
1823///
1824/// # Errors
1825///
1826/// Returns `f64::NAN` if the exact factor path is incoherent and no HVP
1827/// fallback is supplied, or if the supplied dimensions are non-finite.
1828pub fn laplace_evidence(
1829    logdet_source: EvidenceLogDetSource<'_>,
1830    penalty_log_det: f64,
1831    residual_objective: f64,
1832    effective_dim: f64,
1833    penalty_rank: f64,
1834) -> f64 {
1835    if !(effective_dim.is_finite() && penalty_rank.is_finite()) {
1836        return f64::NAN;
1837    }
1838    let log_det_h = match evidence_hessian_log_det(logdet_source) {
1839        Ok(v) => v,
1840        Err(_) => return f64::NAN,
1841    };
1842    let null_dim = effective_dim - penalty_rank;
1843    if !null_dim.is_finite() || null_dim < -1e-9 {
1844        return f64::NAN;
1845    }
1846    residual_objective + 0.5 * log_det_h
1847        - 0.5 * penalty_log_det
1848        - 0.5 * null_dim.max(0.0) * (2.0 * std::f64::consts::PI).ln()
1849}
1850
1851/// Compute the Hessian logdet from exact arrow factors or an HVP fallback.
1852pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
1853    match source {
1854        EvidenceLogDetSource::FactoredArrow {
1855            cache,
1856            fallback_hvp,
1857        } => match arrow_log_det_from_cache(cache) {
1858            Some(v) => Ok(v),
1859            None => match fallback_hvp {
1860                Some(hvp) => hessian_log_det_from_hvp(hvp),
1861                None => {
1862                    Err("evidence Hessian logdet requires exact factors or HVP fallback".into())
1863                }
1864            },
1865        },
1866        EvidenceLogDetSource::Hvp(hvp) => hessian_log_det_from_hvp(hvp),
1867    }
1868}
1869
1870/// Log determinant of an SPD operator supplied by HVP callback.
1871///
1872/// The dispatch boundary intentionally matches
1873/// `ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD` in `terms::analytic_penalties`:
1874/// small operators are materialized and diagonalized exactly; larger ones use
1875/// Rademacher stochastic Lanczos quadrature.
1876pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
1877    if hvp.dim == 0 {
1878        return Ok(0.0);
1879    }
1880    if hvp.dim <= ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
1881        let mut dense = Array2::<f64>::zeros((hvp.dim, hvp.dim));
1882        let mut basis = vec![0.0_f64; hvp.dim];
1883        for j in 0..hvp.dim {
1884            basis[j] = 1.0;
1885            let col = (hvp.apply)(&basis);
1886            basis[j] = 0.0;
1887            if col.len() != hvp.dim || col.iter().any(|v| !v.is_finite()) {
1888                return Err(format!(
1889                    "evidence HVP logdet expected finite column of length {}, got {}",
1890                    hvp.dim,
1891                    col.len()
1892                ));
1893            }
1894            for i in 0..hvp.dim {
1895                dense[[i, j]] = col[i];
1896            }
1897        }
1898        validate_dense_hvp_symmetry(&dense)?;
1899        for i in 0..hvp.dim {
1900            for j in (i + 1)..hvp.dim {
1901                let avg = 0.5 * (dense[[i, j]] + dense[[j, i]]);
1902                dense[[i, j]] = avg;
1903                dense[[j, i]] = avg;
1904            }
1905        }
1906        dense_spd_log_det(&dense)
1907    } else {
1908        stochastic_hvp_log_det(hvp)
1909    }
1910}
1911
1912fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
1913    if matrix.nrows() != matrix.ncols() {
1914        return Err(format!(
1915            "evidence dense logdet requires square matrix, got {}x{}",
1916            matrix.nrows(),
1917            matrix.ncols()
1918        ));
1919    }
1920    if gam_gpu::cuda_selected() {
1921        return crate::gpu::reml_gpu::evidence_derivatives_gpu(
1922            crate::gpu::reml_gpu::RemlGpuInput {
1923                penalized_hessian: matrix.view(),
1924                derivative_hessians: Vec::new(),
1925            },
1926        )
1927        .map(|evidence| evidence.logdet_hessian);
1928    }
1929    let (evals, _) = matrix
1930        .eigh(Side::Lower)
1931        .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
1932    let mut logdet = 0.0_f64;
1933    for (idx, &ev) in evals.iter().enumerate() {
1934        if !ev.is_finite() || ev <= 0.0 {
1935            return Err(format!(
1936                "evidence dense logdet expected SPD Hessian, eigenvalue {idx} is {ev:.3e}"
1937            ));
1938        }
1939        logdet += ev.ln();
1940    }
1941    Ok(logdet)
1942}
1943
1944fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
1945    let n = matrix.nrows();
1946    let mut norm_sq = 0.0_f64;
1947    for &value in matrix.iter() {
1948        norm_sq += value * value;
1949    }
1950
1951    let mut skew_sq = 0.0_f64;
1952    for i in 0..n {
1953        for j in (i + 1)..n {
1954            let skew = matrix[[i, j]] - matrix[[j, i]];
1955            skew_sq += 2.0 * skew * skew;
1956        }
1957    }
1958
1959    let rel_skew = skew_sq.sqrt() / norm_sq.sqrt().max(1.0);
1960    if !rel_skew.is_finite() || rel_skew > EVIDENCE_HVP_SYMMETRY_REL_TOL {
1961        return Err(format!(
1962            "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
1963        ));
1964    }
1965    Ok(())
1966}
1967
1968fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
1969    let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
1970    for probe in 0..EVIDENCE_HVP_SYMMETRY_PROBES.max(1) {
1971        let mut x = vec![0.0_f64; hvp.dim];
1972        let mut y = vec![0.0_f64; hvp.dim];
1973        rademacher_unit_probe_into_slice(&mut x, (2 * probe) as u64, inv_norm);
1974        rademacher_unit_probe_into_slice(&mut y, (2 * probe + 1) as u64, inv_norm);
1975
1976        let hx = (hvp.apply)(&x);
1977        let hy = (hvp.apply)(&y);
1978        if hx.len() != hvp.dim || hx.iter().any(|v| !v.is_finite()) {
1979            return Err(format!(
1980                "evidence HVP symmetry check expected finite vector of length {}, got {}",
1981                hvp.dim,
1982                hx.len()
1983            ));
1984        }
1985        if hy.len() != hvp.dim || hy.iter().any(|v| !v.is_finite()) {
1986            return Err(format!(
1987                "evidence HVP symmetry check expected finite vector of length {}, got {}",
1988                hvp.dim,
1989                hy.len()
1990            ));
1991        }
1992
1993        let lhs = dot_slice(&x, &hy);
1994        let rhs = dot_slice(&hx, &y);
1995        let scale = (norm2_slice(&hx) * norm2_slice(&y))
1996            .max(norm2_slice(&hy) * norm2_slice(&x))
1997            .max(lhs.abs())
1998            .max(rhs.abs())
1999            .max(1.0);
2000        let rel = (lhs - rhs).abs() / scale;
2001        if !rel.is_finite() || rel > EVIDENCE_HVP_SYMMETRY_REL_TOL {
2002            return Err(format!(
2003                "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {rel:.3e}"
2004            ));
2005        }
2006    }
2007    Ok(())
2008}
2009
2010fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
2011    validate_hvp_randomized_symmetry(hvp)?;
2012    let probes = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
2013    let steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(hvp.dim).max(1);
2014    let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
2015    let mut estimate = 0.0_f64;
2016    for probe in 0..probes {
2017        let mut q0 = vec![0.0_f64; hvp.dim];
2018        rademacher_unit_probe_into_slice(&mut q0, probe as u64, inv_norm);
2019        let quad = lanczos_log_quadrature_hvp(hvp, q0, steps)?;
2020        estimate += hvp.dim as f64 * quad;
2021    }
2022    Ok(estimate / probes as f64)
2023}
2024
2025fn lanczos_log_quadrature_hvp(
2026    hvp: EvidenceHvpLogDet<'_>,
2027    q: Vec<f64>,
2028    max_steps: usize,
2029) -> Result<f64, String> {
2030    let n = hvp.dim;
2031    let eigen = symmetric_lanczos_eigenpairs(
2032        n,
2033        &q,
2034        SymmetricLanczosOptions {
2035            max_steps,
2036            residual_tol: 1e-12,
2037            local_reorthogonalize: false,
2038            full_reorthogonalize: false,
2039        },
2040        |q, out| {
2041            let applied = (hvp.apply)(q);
2042            if applied.len() != n || applied.iter().any(|v| !v.is_finite()) {
2043                return Err(format!(
2044                    "evidence HVP SLQ expected finite vector of length {n}, got {}",
2045                    applied.len()
2046                ));
2047            }
2048            out.copy_from_slice(&applied);
2049            Ok(())
2050        },
2051    )
2052    .map_err(|e| format!("evidence HVP SLQ Lanczos failed: {e}"))?;
2053    symmetric_lanczos_log_quadrature(&eigen, "evidence HVP SLQ expected SPD Hessian")
2054}
2055
/// Dot product of two equal-length slices, accumulated left to right.
///
/// # Panics
///
/// Panics if the slices differ in length.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
2065
2066#[inline]
2067fn norm2_slice(a: &[f64]) -> f64 {
2068    dot_slice(a, a).sqrt()
2069}
2070
2071fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
2072    let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
2073    let mut bits = 0_u64;
2074    let mut remaining_bits = 0_u32;
2075    for value in z.iter_mut() {
2076        if remaining_bits == 0 {
2077            bits = splitmix64(&mut state);
2078            remaining_bits = 64;
2079        }
2080        *value = if bits & 1 == 0 { scale } else { -scale };
2081        bits >>= 1;
2082        remaining_bits -= 1;
2083    }
2084}
2085
2086#[inline]
2087const fn splitmix64(state: &mut u64) -> u64 {
2088    gam_linalg::utils::splitmix64(state)
2089}
2090
2091/// Sum of per-row arrow log-determinants plus the Schur log-det.
2092///
2093/// `log|H| = Σ_i log|H_uu_i| + log|A|` using the undamped Cholesky
2094/// factors of `H_uu_i` and the cached Schur Cholesky factor.
2095///
2096/// Returns `None` if `cache.schur_factor` is absent (InexactPCG path) or
2097/// if a damped/incoherent cache is supplied. [`evidence_hessian_log_det`]
2098/// routes such matrix-free cases to an explicit HVP fallback.
2099pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
2100    if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2101        // Per proposal §6.4 / §6.5 — evidence must use the undamped
2102        // operator. The cache's Schur factor here was assembled under
2103        // ridge damping, which is a different operator. Reject loudly.
2104        return None;
2105    }
2106    if let Some(log_det) = cache.joint_hessian_log_det {
2107        return log_det.is_finite().then_some(log_det);
2108    }
2109    // A `k == 0` cache has no shared β block, so the dense Direct path forms no
2110    // reduced Schur complement and `schur_factor` is legitimately `None` (the
2111    // joint Hessian is block-diagonal in the latent rows). Its log-det is the
2112    // per-row sum with no Schur term. Only reject when `k > 0` and the factor
2113    // is absent — the InexactPCG case that never built the dense `K×K` factor.
2114    // (#1132 euclidean K=4: a β-profiled atom reaches here with `k == 0`.)
2115    let schur = match cache.schur_factor.as_ref() {
2116        Some(schur) => Some(schur),
2117        None if cache.k == 0 => None,
2118        None => return None,
2119    };
2120
2121    let mut acc = 0.0_f64;
2122    // Per-row arrow blocks: log|H_uu_i| = 2 Σ log diag(L_i).
2123    for l in cache.undamped_factors_iter() {
2124        acc += 2.0 * log_det_from_chol_lower(l);
2125    }
2126    // Schur block: log|A| = 2 Σ log diag(L_schur). Empty for the `k == 0` case.
2127    if let Some(schur) = schur {
2128        acc += 2.0 * log_det_from_chol_lower(schur.view());
2129    }
2130    // #1038 cross-row IBP: when the cache carries an exact rank-`R` Woodbury,
2131    // the per-row + Schur factors above are of the NO-SELF base `H₀'`, so the
2132    // exact `log det H_full = log det H₀' + log det(I_R + D Uᵀ H₀'⁻¹ U)`. The
2133    // correction is zero (no-op) for every non-IBP cache.
2134    let woodbury_correction = cache.cross_row_woodbury_log_det();
2135    if !woodbury_correction.is_finite() {
2136        // A non-PD capacitance (negative determinant) is a value↔gradient
2137        // desync the evidence must reject loudly, not paper over.
2138        return None;
2139    }
2140    acc += woodbury_correction;
2141    Some(acc)
2142}
2143
2144/// Twice-the-diagonal-log sum for a lower-triangular Cholesky factor.
2145fn log_det_from_chol_lower(l: ArrayView2<'_, f64>) -> f64 {
2146    let n = l.nrows();
2147    let mut acc = 0.0_f64;
2148    for i in 0..n {
2149        let d = l[[i, i]];
2150        if d > 0.0 {
2151            acc += d.ln();
2152        } else {
2153            // SAFETY: a valid lower-triangular Cholesky factor has a strictly
2154            // positive diagonal by construction. A non-positive diagonal means
2155            // the caller passed a corrupted / non-SPD factor — surface it loudly
2156            // rather than papering over with a corrupting NaN that silently
2157            // poisons the evidence log-det (callers do not check is_nan).
2158            panic!(
2159                "log_det_from_chol_lower: non-positive Cholesky diagonal {d} at index {i}; \
2160                 caller passed a corrupted or non-SPD factor"
2161            );
2162        }
2163    }
2164    acc
2165}
2166
2167// ---------------------------------------------------------------------------
2168// IFT cascade: ∂u*/∂β → ∂β*/∂ρ → ∂u*/∂ρ
2169// ---------------------------------------------------------------------------
2170
2171/// Tier-1 IFT sensitivity `∂u_i*/∂β = -H_uu_i⁻¹ H_uβ_i`.
2172///
2173/// Concatenated row-major to a single `(N·d) × K` dense matrix. Each
2174/// row block is solved with the **undamped** Cholesky factor. Proposal
2175/// §2.2 / §7.
2176pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
2177    let n = cache.undamped_factor_count();
2178    let total_len = cache.delta_t_len();
2179    let k = cache.k;
2180    if !cache.htbeta_available() {
2181        return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2182    }
2183    let mut out = Array2::<f64>::zeros((total_len, k));
2184    let mut beta_basis = Array1::<f64>::zeros(k);
2185    // Allocate scratch at max_d; per-row slice is ..di.
2186    let mut rhs = Array1::<f64>::zeros(cache.d);
2187    for i in 0..n {
2188        let di = cache.row_dims[i];
2189        let row_base = cache.row_offsets[i];
2190        let factor = cache.undamped_factor(i);
2191        // Solve H_uu_i Y = H_uβ_i column by column.
2192        for col in 0..k {
2193            beta_basis.fill(0.0);
2194            beta_basis[col] = 1.0;
2195            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
2196            // The Tier-2 IFT assembler is built only when the family's
2197            // capability surface promises cached `H_tβ` row products.
2198            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
2199                // SAFETY: reaching `false` means a family declared the cache
2200                // available but failed to populate it — contract violation.
2201                return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2202            }
2203            let y = cholesky_solve_vector(factor, &rhs_i);
2204            for c in 0..di {
2205                out[[row_base + c, col]] = -y[c];
2206            }
2207        }
2208    }
2209    out
2210}
2211
2212/// Coupling components of a symmetric coefficient Hessian: the connected
2213/// components of the graph whose vertices are coefficient indices `0..p` and
2214/// whose edges are the structurally nonzero off-diagonal entries of `H` (#779).
2215///
2216/// Returns a length-`p` vector of component labels in `0..num_components`,
2217/// where two indices share a label iff they are connected through a chain of
2218/// nonzero `H[i,j]` couplings. This is the exact structural partition the
2219/// cone-of-influence sensitivity reuse is keyed on: a smoothing-parameter move
2220/// whose stationarity-gradient derivative `∂g/∂ρ` is supported only inside one
2221/// component can change `β = -H⁻¹ ∂g/∂ρ` only inside that same component, so
2222/// the sensitivity of every *other* component is provably unchanged and may be
2223/// reused unrecomputed (lazy/local propagation).
2224///
2225/// The nonzero test is exact (`!= 0.0`), matching the structural-coupling gate
2226/// used elsewhere for the joint inner Hessian: a tolerance would risk dropping a
2227/// genuine (small) coupling edge and silently biasing the propagated sensitivity
2228/// — the failure mode #779/#740 explicitly guard against. A block-diagonal `H`
2229/// yields the all-singletons partition (one component per block-decoupled
2230/// coordinate); a fully coupled `H` yields a single component (no shortcut, the
2231/// full joint solve is required — and is what the non-coned path performs).
2232pub fn coupling_components(hessian: ArrayView2<'_, f64>) -> Vec<usize> {
2233    let p = hessian.nrows();
2234    if p == 0 || hessian.ncols() != p {
2235        return Vec::new();
2236    }
2237    // Union-find with path compression and union by size.
2238    let mut parent: Vec<usize> = (0..p).collect();
2239    let mut size: Vec<usize> = vec![1; p];
2240
2241    fn find(parent: &mut [usize], mut x: usize) -> usize {
2242        while parent[x] != x {
2243            parent[x] = parent[parent[x]];
2244            x = parent[x];
2245        }
2246        x
2247    }
2248
2249    for i in 0..p {
2250        for j in (i + 1)..p {
2251            // Symmetric structure: an edge exists if either triangle is nonzero,
2252            // so a numerically one-sided fill still couples the two indices.
2253            if hessian[[i, j]] != 0.0 || hessian[[j, i]] != 0.0 {
2254                let (ri, rj) = (find(&mut parent, i), find(&mut parent, j));
2255                if ri != rj {
2256                    let (small, large) = if size[ri] < size[rj] {
2257                        (ri, rj)
2258                    } else {
2259                        (rj, ri)
2260                    };
2261                    parent[small] = large;
2262                    size[large] += size[small];
2263                }
2264            }
2265        }
2266    }
2267
2268    // Relabel roots to a dense `0..num_components` range, preserving
2269    // first-seen order so labels are deterministic.
2270    let mut label_of_root: Vec<Option<usize>> = vec![None; p];
2271    let mut next_label = 0usize;
2272    let mut labels = vec![0usize; p];
2273    for idx in 0..p {
2274        let root = find(&mut parent, idx);
2275        let label = match label_of_root[root] {
2276            Some(l) => l,
2277            None => {
2278                let l = next_label;
2279                label_of_root[root] = Some(l);
2280                next_label += 1;
2281                l
2282            }
2283        };
2284        labels[idx] = label;
2285    }
2286    labels
2287}
2288
/// Cone of influence of one stationarity-gradient derivative column.
///
/// `support` lists the coefficient indices where `∂g/∂ρ_k` is nonzero and
/// `labels` is the component labelling from [`coupling_components`]. The
/// result is every coefficient index, in ascending order, whose label matches
/// the label of some in-range support index. Support indices beyond
/// `labels.len()` are ignored.
///
/// `β_k = -H⁻¹ ∂g/∂ρ_k` is exactly zero outside this cone, so a confined solve
/// (or reuse of a cached zero) is exact rather than approximate. An empty
/// support — a structurally inactive `ρ_k`, e.g. a rank-0 or out-of-range
/// penalty block — yields an empty cone: the sensitivity is identically zero
/// and no solve is needed.
pub fn cone_of_influence(labels: &[usize], support: &[usize]) -> Vec<usize> {
    let active_labels: std::collections::BTreeSet<usize> = support
        .iter()
        .filter_map(|&idx| labels.get(idx).copied())
        .collect();
    if active_labels.is_empty() {
        return Vec::new();
    }
    labels
        .iter()
        .enumerate()
        .filter(|&(_, label)| active_labels.contains(label))
        .map(|(idx, _)| idx)
        .collect()
}
2316
2317/// Tier-2 IFT sensitivity `∂β*/∂ρ = -A⁻¹ ∂g_red/∂ρ` (proposal §2.4 /
2318/// §7).
2319///
2320/// `dg_red_drho` is the `K × R` matrix whose `a`-th column is `q_a =
2321/// ∂g_red/∂ρ_a`. Returns the `K × R` matrix `β_ρ`.
2322///
2323/// Returns `None` if the Schur factor is unavailable (PCG mode) or was
2324/// built from a damped operator, or if any solved entry is non-finite;
2325/// callers must not silently substitute an approximation. The solve is
2326/// the one sensitivity operator (#935) — this site holds no private H⁻¹
2327/// convention of its own.
2328pub fn ift_dbeta_drho(
2329    cache: &ArrowFactorCache,
2330    dg_red_drho: ArrayView2<'_, f64>,
2331) -> Option<Array2<f64>> {
2332    if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2333        return None;
2334    }
2335    let schur = cache.schur_factor.as_ref()?;
2336    if dg_red_drho.nrows() != cache.k || schur.nrows() != cache.k {
2337        return None;
2338    }
2339    crate::sensitivity::FitSensitivity::from_lower_triangular(schur)
2340        .mode_response(dg_red_drho)
2341}
2342
2343
2344// ---------------------------------------------------------------------------
2345// ∂V/∂ρ — analytic optimized-evidence gradient via IFT mode response
2346// ---------------------------------------------------------------------------
2347
/// IFT terms needed to differentiate the optimized Laplace evidence through
/// the fitted mode `(β*(ρ), u*(ρ))`.
///
/// For each hyperparameter `ρ_a`, the correction added to the direct trace is
///
/// ```text
/// F_β · β_a + F_u · u_a
/// + 0.5 (∂_β log|H| · β_a + ∂_u log|H| · u_a).
/// ```
///
/// At an exact KKT point the value-gradient pieces are zero, but they are
/// explicit here so the exported gradient matches the optimized objective
/// whenever callers carry a certified nonzero residual correction.
///
/// Shape contract (checked by [`evidence_ift_gradient_correction`], which
/// returns a NaN-filled vector on any mismatch): both mode-response matrices
/// have one column per `ρ_a`; the `β`-side vectors have as many entries as
/// `dbeta_drho` has rows, and the `u`-side vectors as many as `du_drho` has
/// rows.
#[derive(Clone)]
pub struct EvidenceIftGradientTerms<'a> {
    /// Mode response `∂β*/∂ρ`: row `j` is a `β` coefficient, column `a` is `ρ_a`.
    pub dbeta_drho: ArrayView2<'a, f64>,
    /// Mode response `∂u*/∂ρ`: row `j` is a `u` coordinate, column `a` is `ρ_a`.
    pub du_drho: ArrayView2<'a, f64>,
    /// `F_β`, contracted against each column of `dbeta_drho`.
    pub value_beta: ArrayView1<'a, f64>,
    /// `F_u`, contracted against each column of `du_drho`.
    pub value_u: ArrayView1<'a, f64>,
    /// `∂_β log|H|`; enters the correction with a factor `0.5`.
    pub logdet_h_beta: ArrayView1<'a, f64>,
    /// `∂_u log|H|`; enters the correction with a factor `0.5`.
    pub logdet_h_u: ArrayView1<'a, f64>,
}
2370
2371/// Contract the IFT mode-response columns into the optimized-evidence
2372/// gradient correction.
2373pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
2374    let k = terms.dbeta_drho.nrows();
2375    let nd = terms.du_drho.nrows();
2376    let r = terms.dbeta_drho.ncols();
2377    if terms.du_drho.ncols() != r
2378        || terms.value_beta.len() != k
2379        || terms.logdet_h_beta.len() != k
2380        || terms.value_u.len() != nd
2381        || terms.logdet_h_u.len() != nd
2382    {
2383        return Array1::<f64>::from_elem(r, f64::NAN);
2384    }
2385
2386    let mut out = Array1::<f64>::zeros(r);
2387    for a in 0..r {
2388        let mut acc = 0.0_f64;
2389        for j in 0..k {
2390            let mode = terms.dbeta_drho[[j, a]];
2391            acc += terms.value_beta[j] * mode;
2392            acc += 0.5 * terms.logdet_h_beta[j] * mode;
2393        }
2394        for j in 0..nd {
2395            let mode = terms.du_drho[[j, a]];
2396            acc += terms.value_u[j] * mode;
2397            acc += 0.5 * terms.logdet_h_u[j] * mode;
2398        }
2399        out[a] = acc;
2400    }
2401    out
2402}
2403
/// Per-`ρ` optimized-evidence gradient (proposal §3.7 / §3.8 split):
///
/// ```text
/// ∂V/∂ρ_a =
///       F_{ρ_a}                                  (value part)
///   + 0.5 tr(H⁻¹ H_{ρ_a})                        (direct Hessian)
///   + F_x · x_{ρ_a}
///   + 0.5 (∂_x log|H|) · x_{ρ_a}                 (IFT mode response)
///   - 0.5 tr(S_pen⁺ S_{pen,ρ_a})                 (penalty pseudo-logdet)
/// ```
/// where `x = (β, u)`.
///
/// The `tr(H⁻¹ H_{ρ_a})` trace is computed via the arrow structure
/// (proposal §3.5 / §3.10):
///
/// ```text
/// tr(H⁻¹ H_{ρ_a}) = Σ_i tr(H_uu_i⁻¹ ∂_{ρ_a} H_uu_i) + tr(A⁻¹ ∂_{ρ_a} A).
/// ```
///
/// # Inputs
///
/// With `R = value_rho.len()`, `n = cache.undamped_factor_count()`,
/// `K = cache.k` and `d_i = cache.row_dims[i]`:
///
/// * `value_rho[a] = F_{ρ_a}` (envelope theorem, proposal §3.2).
/// * `huu_drho[i][a]` is `∂H_uu_i/∂ρ_a` as a `d_i × d_i` matrix (`n` rows of
///   `R` matrices).
/// * `htbeta_drho[i][a]` is `∂H_uβ_i/∂ρ_a` as a `d_i × K` matrix (`n` rows of
///   `R` matrices).
/// * `hbb_drho[a]` is `∂H_ββ/∂ρ_a` as a `K × K` matrix (`R` matrices).
/// * `pen_logdet_drho[a]` is `∂_{ρ_a} log|S_pen|+` (length `R`).
/// * `ift_terms` carries `∂β*/∂ρ`, `∂u*/∂ρ`, and the already-contracted
///   mode derivatives of `F` and `log|H|`.
///
/// # Returns
///
/// The per-`ρ` gradient, of length `R`. The whole vector is NaN-filled when
///
/// * `cache.htbeta_available()` is false,
/// * any of the shapes listed above does not match,
/// * the IFT correction has the wrong length or contains a NaN,
/// * the cache has no Schur factor (PCG mode), or
/// * `cache.apply_htbeta_row` reports failure for some row.
///
/// # Panics
///
/// The two `assert_eq!` shape checks in the body repeat conditions already
/// validated on entry.
///
/// NOTE(review): unlike `ift_dbeta_drho`, this function does not test
/// `cache.ridge_t` / `cache.ridge_beta` before using `cache.schur_factor` —
/// confirm the factor is guaranteed to be the undamped one on this path.
pub fn evidence_grad_rho(
    cache: &ArrowFactorCache,
    value_rho: ArrayView1<'_, f64>,
    huu_drho: &[Vec<Array2<f64>>],
    htbeta_drho: &[Vec<Array2<f64>>],
    hbb_drho: &[Array2<f64>],
    pen_logdet_drho: ArrayView1<'_, f64>,
    ift_terms: EvidenceIftGradientTerms<'_>,
) -> Array1<f64> {
    let r = value_rho.len();
    let n = cache.undamped_factor_count();
    let k = cache.k;
    let mut out = Array1::<f64>::zeros(r);
    // Up-front validation: H_tβ cache availability plus every outer/inner
    // length and matrix shape. The length checks come first in the `||` chain,
    // so the per-row shape closures only run once `huu_drho` / `htbeta_drho`
    // are known to have exactly `n` rows.
    if !cache.htbeta_available()
        || pen_logdet_drho.len() != r
        || huu_drho.len() != n
        || htbeta_drho.len() != n
        || hbb_drho.len() != r
        || huu_drho.iter().any(|row| row.len() != r)
        || htbeta_drho.iter().any(|row| row.len() != r)
        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
        || huu_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != di)
        })
        || htbeta_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != k)
        })
    {
        out.fill(f64::NAN);
        return out;
    }
    // IFT mode-response contribution; a shape mismatch inside `ift_terms`
    // surfaces here as NaN entries and poisons the whole gradient.
    let ift_correction = evidence_ift_gradient_correction(ift_terms);
    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
        out.fill(f64::NAN);
        return out;
    }

    // No cached Schur factor (PCG mode): the trace tr(A⁻¹ ∂A) cannot be
    // evaluated, so report NaN for every ρ.
    let schur = match cache.schur_factor.as_ref() {
        Some(s) => s,
        None => {
            for a in 0..r {
                out[a] = f64::NAN;
            }
            return out;
        }
    };

    // Precompute Y_i = H_uu_i⁻¹ H_uβ_i (di × K), used by the Schur
    // derivative formula (§3.5) below. Column `col` of Y_i is obtained by
    // applying row i's H_tβ block to the basis vector e_col and solving with
    // the undamped row factor.
    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
    let mut beta_basis = Array1::<f64>::zeros(k);
    // Scratch sized to max_d; per-row slice is ..di.
    // NOTE(review): the `.to_owned()` below copies the (always zero) slice, so
    // each column works on a fresh di-length array and `rhs` itself is never
    // written — the buffer is not actually reused.
    let mut rhs = Array1::<f64>::zeros(cache.d);
    for i in 0..n {
        let di = cache.row_dims[i];
        let factor = cache.undamped_factor(i);
        let mut yi = Array2::<f64>::zeros((di, k));
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
            // Same H_tβ cache contract as the IFT du/dβ and du/dρ paths.
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
                // Contract violation: `false` means the family declared the
                // cache available but did not populate it.
                out.fill(f64::NAN);
                return out;
            }
            let v = cholesky_solve_vector(factor, &rhs_i);
            for c in 0..di {
                yi[[c, col]] = v[c];
            }
        }
        y_blocks.push(yi);
    }

    // Scratch declared once outside the (a, i) loops, sized to max_d.
    // NOTE(review): as with `rhs`, the per-row `.to_owned()` calls copy the
    // slice, so `trace_rhs` and `da_tmp` stay zero and a fresh array is
    // allocated on every use.
    let mut trace_rhs = Array1::<f64>::zeros(cache.d);
    let mut da_tmp = Array2::<f64>::zeros((cache.d, k));
    let mut col_scratch = Array1::<f64>::zeros(k);
    for a in 0..r {
        // Part 1: F_{ρ_a} envelope contribution.
        let mut grad = value_rho[a];

        // Part 2a: Σ_i tr(H_uu_i⁻¹ ∂H_uu_i).
        // With M_i = ∂H_uu_i/∂ρ_a, the trace is the diagonal sum of
        // H_uu_i⁻¹ M_i: for each column `col`, solve H_uu_i x = M_i[:, col]
        // with the undamped Cholesky factor and add x[col].
        let mut row_trace_acc = 0.0_f64;
        for i in 0..n {
            let di = cache.row_dims[i];
            let m_i = &huu_drho[i][a];
            assert_eq!(m_i.shape(), &[di, di]);
            for col in 0..di {
                let mut tr_rhs_i = trace_rhs.slice_mut(ndarray::s![..di]).to_owned();
                for r0 in 0..di {
                    tr_rhs_i[r0] = m_i[[r0, col]];
                }
                let v = cholesky_solve_vector(cache.undamped_factor(i), &tr_rhs_i);
                row_trace_acc += v[col];
            }
        }

        // Part 2b: tr(A⁻¹ ∂A) where (proposal §3.5)
        //     ∂A = ∂H_ββ
        //          - Σ_i (∂H_uβ_i)ᵀ Y_i
        //          - Σ_i Y_iᵀ (∂H_uβ_i)
        //          + Σ_i Y_iᵀ (∂H_uu_i) Y_i.
        // We accumulate ∂A as a dense `K × K` matrix, then evaluate
        // tr(A⁻¹ ∂A) by `Σ_j (A⁻¹ ∂A)[j, j]` via column solves of the
        // Schur Cholesky.
        let mut da = hbb_drho[a].clone();
        assert_eq!(da.shape(), &[k, k]);
        for i in 0..n {
            let di = cache.row_dims[i];
            let dhtb = &htbeta_drho[i][a]; // di × K
            let yi = &y_blocks[i]; // di × K
            // - (∂H_uβ_i)ᵀ Y_i
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhtb[[cc, r0]] * yi[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // - Y_iᵀ (∂H_uβ_i)
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * dhtb[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // + Y_iᵀ (∂H_uu_i) Y_i
            let dhuu = &huu_drho[i][a];
            // tmp = (∂H_uu_i) Y_i  (di × K); every entry is overwritten below.
            let mut da_tmp_i = da_tmp.slice_mut(ndarray::s![..di, ..]).to_owned();
            for r0 in 0..di {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhuu[[r0, cc]] * yi[[cc, c0]];
                    }
                    da_tmp_i[[r0, c0]] = acc;
                }
            }
            // da += Y_iᵀ tmp
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * da_tmp_i[[cc, c0]];
                    }
                    da[[r0, c0]] += acc;
                }
            }
        }

        // tr(A⁻¹ ∂A) via column solves: column j of ∂A is copied into
        // `col_scratch` (fully overwritten each pass), solved against the
        // Schur factor, and the j-th component of the solution is the j-th
        // diagonal entry of A⁻¹ ∂A.
        let mut schur_trace_acc = 0.0_f64;
        for j in 0..k {
            for r0 in 0..k {
                col_scratch[r0] = da[[r0, j]];
            }
            let v = cholesky_solve_vector(schur, &col_scratch);
            schur_trace_acc += v[j];
        }

        // Direct Hessian trace (row blocks + Schur block), then the IFT
        // mode-response correction computed before the loop.
        grad += 0.5 * (row_trace_acc + schur_trace_acc);
        grad += ift_correction[a];

        // Part 3: -0.5 ∂_{ρ_a} log|S_pen|+.
        grad -= 0.5 * pen_logdet_drho[a];

        out[a] = grad;
    }
    out
}
2620
2621// ---------------------------------------------------------------------------
2622// Topology selection
2623// ---------------------------------------------------------------------------
2624
2625/// Enumerate the candidate topologies, rank by normalized negative log
2626/// evidence, and return the winner. Failed/excluded candidates (proposal
2627/// §6.11) are appended at the end of `ranking` and are never the winner.
2628///
2629/// The caller fits each topology separately (proposal §4.2) and supplies
2630/// the resulting `TopologyCandidate` records. This function is purely
2631/// the discrete comparator + tie breaker.
2632///
2633/// # Tie-breaking
2634///
2635/// Per proposal §4.6: if normalized `|score_a - score_b| <= tie_tolerance`,
2636/// prefer the simpler topology by `TopologyKind::complexity_rank` (flat <
2637/// periodic < sphere < torus). The `tie` flag in the result records whether
2638/// such a tie occurred at the top of the ranking.
2639///
2640/// # Panics
2641///
2642/// Panics if `candidates` is empty after filtering out non-finite
2643/// scores. Proposal §6.11 explicitly forbids silent fallback to a
2644/// default topology; callers must handle the empty-candidate case
2645/// before invocation.
2646pub fn select_topology(
2647    candidates: &[TopologyCandidate],
2648    options: TopologySelectOptions,
2649) -> SelectedTopology {
2650    // Split valid and excluded.
2651    let mut valid: Vec<TopologyCandidate> = candidates
2652        .iter()
2653        .filter(|c| {
2654            c.converged
2655                && c.exclusion_reason.is_none()
2656                && c.negative_log_evidence.is_finite()
2657                && topology_selection_score(c, options.score_scale).is_finite()
2658        })
2659        .cloned()
2660        .collect();
2661    let mut excluded: Vec<TopologyCandidate> = candidates
2662        .iter()
2663        .filter(|c| {
2664            !(c.converged && c.exclusion_reason.is_none() && c.negative_log_evidence.is_finite())
2665                || !topology_selection_score(c, options.score_scale).is_finite()
2666        })
2667        .cloned()
2668        .collect();
2669
2670    assert!(
2671        !valid.is_empty(),
2672        "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
2673    );
2674
2675    // Sort by normalized negative log evidence (ascending = best first),
2676    // breaking ties by complexity_rank (smaller wins). The shared selector is
2677    // the single lower-is-better ordering contract used by topology ranking,
2678    // seed screening, and REML model comparison (#782).
2679    valid = rank_priority_candidates(
2680        valid
2681            .into_iter()
2682            .enumerate()
2683            .map(|(idx, row)| {
2684                let score = topology_selection_score(&row, options.score_scale);
2685                let tie_break = usize::from(row.kind.complexity_rank());
2686                PriorityCandidate::new(row, idx, score, tie_break)
2687            })
2688            .collect(),
2689    )
2690    .into_iter()
2691    .map(|row| row.item)
2692    .collect();
2693
2694    // Detect numerical tie at the top.
2695    let tie = if valid.len() >= 2 {
2696        let top = topology_selection_score(&valid[0], options.score_scale);
2697        let next = topology_selection_score(&valid[1], options.score_scale);
2698        (next - top).abs() <= options.tie_tolerance
2699    } else {
2700        false
2701    };
2702
2703    // If tied, prefer simpler topology among the tied prefix.
2704    if tie {
2705        let top_score = topology_selection_score(&valid[0], options.score_scale);
2706        // Find the tied prefix range.
2707        let tied_end = valid
2708            .iter()
2709            .position(|c| {
2710                (topology_selection_score(c, options.score_scale) - top_score).abs()
2711                    > options.tie_tolerance
2712            })
2713            .unwrap_or(valid.len());
2714        // Sort the tied prefix by complexity_rank ascending.
2715        valid[..tied_end].sort_by_key(|c| c.kind.complexity_rank());
2716    }
2717
2718    let winner = valid[0].kind;
2719    valid.append(&mut excluded);
2720    SelectedTopology {
2721        winner,
2722        ranking: valid,
2723        tie,
2724    }
2725}
2726
2727fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
2728    match scale {
2729        TopologyScoreScale::PerObservation => {
2730            if candidate.n_obs == 0 {
2731                f64::NAN
2732            } else {
2733                candidate.negative_log_evidence / candidate.n_obs as f64
2734            }
2735        }
2736        TopologyScoreScale::PerEffectiveDim => {
2737            if !(candidate.effective_dim.is_finite() && candidate.effective_dim > 0.0) {
2738                f64::NAN
2739            } else {
2740                candidate.negative_log_evidence / candidate.effective_dim
2741            }
2742        }
2743    }
2744}
2745
2746// ---------------------------------------------------------------------------
2747// Cache verification helpers
2748// ---------------------------------------------------------------------------
2749
2750
2751/// Verifies the `ArrowSchurSystem` dimensions match the cache. Used as
2752/// a debug-time precondition; never silently masks shape errors
2753/// (proposal §6.9 — sign and shape errors must be loud).
2754pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
2755    cache.d == sys.d
2756        && cache.k == sys.k
2757        && cache.n_rows() == sys.rows.len()
2758        && cache.undamped_factor_count() == sys.rows.len()
2759        && cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
2760        && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
2761}
2762
2763// ---------------------------------------------------------------------------
2764// #1026 hybrid curved + linear-tail dictionary split-selection
2765// ---------------------------------------------------------------------------
2766//
2767// COMMON-EVIDENCE NOTE (#1202): the candidates BOTH fit the same data — the
2768// atom's leave-this-atom-out response residual `y_resp` (the response with every
2769// other atom's contribution removed). The curved candidate predicts the atom's
2770// actual mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2771// mass-weighted straight line fit to `y_resp`. Because the curved family's
2772// `Θ = 0` member reproduces the linear prediction exactly, linear IS the nested
2773// `Θ = 0` sub-model on common data, so the "match-or-beat" statements below are a
2774// genuine data-level comparison: the curved candidate wins only when fitting the
2775// response residual better than its own straight projection pays for its extra
2776// parameters. See `crate::terms::sae::hybrid_split` for the residual assembly.
2777//
2778// The per-slot adjudication uses the SAME rank-aware Laplace evidence criterion
2779// the union/mixture rungs use (`−V = NLE`, lower wins), comparing the data-fit +
2780// complexity cost of the curved contribution against that of the straight line.
2781//
2782// ## The turning floor (Θ → 0) and the curved ceiling (Θ large)
2783//
2784// Per slot, the curved candidate fits the response residual with its actual
2785// mass-scaled contribution `a_k·γ_k` (data-fit `½·curved_rss`) and pays a larger
2786// free-parameter price `P_curved > P_linear`; the linear candidate fits the same
2787// residual with its best straight line (data-fit `½·linear_rss ≥ ½·curved_rss`
2788// whenever the curve beats its own straight projection) at a smaller price,
2789// charged with its genuine weighted Gram logdet `p·(log w_sum + log s_tt)`
2790// (#1203). Hence:
2791//
2792//   * Θ → 0 (the residual is straight): the curve and the line fit it equally, so
2793//     the cheaper LINEAR candidate wins — the turning floor / nested dominance. A
2794//     curved parameterization "buys nothing" on an already-straight residual.
2795//   * Θ large (a genuinely turning residual): the line's data-fit residual
2796//     exceeds the curved atom's extra parameter price, so CURVED wins. (Whether
2797//     curved wins also depends on the coordinate spread `s_tt` and amplitude, via
2798//     the honest logdet — a tightly-spread, mildly-curved residual can still
2799//     prefer the cheaper line.)
2800//
2801// The crossover is governed by the documented shatter law: a linear SAE shatters
2802// a feature of total turning Θ into `N(ε) ≈ Θ/(2√(2ε))` rank-1 directions at
2803// relative reconstruction error ε, so the curved advantage scales as `Θ/√ε`. We
2804// use the fitted turning Θ (`sae::chart_canonicalization::d1_atom_fitted_turning`)
2805// as the decision FEATURE: it both (a) sharpens the evidence comparison into a
2806// falsifiable per-atom prediction and (b) provides the exact-zero dominance
2807// guard — when an atom's fitted turning is identically zero, the curved fit has
2808// no curvature to price and the linear special case is selected by construction,
2809// independent of finite-sample evidence noise.
2810
/// The parameterization a hybrid-dictionary slot can select: either a CURVED
/// atom (a `latent_dim ≥ 1` curved basis whose decoded image may turn) or its
/// LINEAR special case (the euclidean-d=1-linear atom — one straight decoder
/// direction, `γ(t) = t·b`, fitted turning `Θ = 0`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HybridAtomParam {
    /// The curved atom (`latent_dim ≥ 1`), priced at its full coefficient count.
    Curved { latent_dim: usize },
    /// The linear special case: one decoder direction, zero turning.
    Linear,
}

impl HybridAtomParam {
    /// Stable display name for logs and tests: `"linear"` or `"curved"`.
    pub const fn as_str(self) -> &'static str {
        if self.is_linear() {
            "linear"
        } else {
            "curved"
        }
    }

    /// `true` iff this is the linear special case (the linear tail).
    pub const fn is_linear(self) -> bool {
        match self {
            Self::Linear => true,
            Self::Curved { .. } => false,
        }
    }
}
2837
2838/// One fitted candidate parameterization for a single hybrid-dictionary atom
2839/// slot, scored on the COMMON rank-aware Laplace scale (`−V = NLE`, lower wins,
2840/// identical to the union/mixture rungs). The curved and linear candidates for
2841/// the SAME slot are fit on the same rows AND the same data (the atom's response
2842/// residual, #1202), so their NLEs are directly comparable; the structural
2843/// difference is the curved candidate's larger free-parameter price and whatever
2844/// data-fit it buys with its curvature.
2845#[derive(Debug, Clone, Copy)]
2846pub struct HybridAtomCandidate {
2847    pub param: HybridAtomParam,
2848    /// Rank-aware Laplace negative-log-evidence on the common scale (lower wins).
2849    pub negative_log_evidence: f64,
2850    /// Free-parameter count this candidate is charged for (the complexity price).
2851    pub num_parameters: usize,
2852    /// The candidate's fitted total turning `Θ = ∫κ ds` of its decoded curve, if
2853    /// the basis admits an analytic second jet. `Some(0.0)` for a linear atom (a
2854    /// straight image has no turning); `None` when the turning is honestly
2855    /// unavailable (no second jet / degenerate curve) — never fabricated.
2856    pub fitted_turning: Option<f64>,
2857}
2858
2859impl HybridAtomCandidate {
2860    /// A linear special-case candidate: exact zero turning by construction.
2861    pub fn linear(negative_log_evidence: f64, num_parameters: usize) -> Self {
2862        Self {
2863            param: HybridAtomParam::Linear,
2864            negative_log_evidence,
2865            num_parameters,
2866            fitted_turning: Some(0.0),
2867        }
2868    }
2869
2870    /// A curved candidate of the given latent dimension, with its fitted turning.
2871    pub fn curved(
2872        latent_dim: usize,
2873        negative_log_evidence: f64,
2874        num_parameters: usize,
2875        fitted_turning: Option<f64>,
2876    ) -> Self {
2877        Self {
2878            param: HybridAtomParam::Curved { latent_dim },
2879            negative_log_evidence,
2880            num_parameters,
2881            fitted_turning,
2882        }
2883    }
2884}
2885
/// The evidence-selected parameterization for one hybrid-dictionary atom slot:
/// the winning candidate, plus the curved/linear NLEs that decided it (for the
/// EV-vs-Θ diagnostic and the tie-break audit trail).
///
/// Produced by [`select_hybrid_atom`]; the first three fields are copied from
/// the winning candidate, the last two describe the slot's curved candidate.
#[derive(Debug, Clone, Copy)]
pub struct HybridAtomChoice {
    /// The winning candidate's parameterization.
    pub param: HybridAtomParam,
    /// The winning candidate's NLE.
    pub negative_log_evidence: f64,
    /// The winning candidate's free-parameter price.
    pub num_parameters: usize,
    /// The curved candidate's fitted turning `Θ` (the decision feature). `None`
    /// when the slot has no curved candidate or that candidate offered no
    /// analytic turning.
    pub curved_turning: Option<f64>,
    /// `NLE_linear − NLE_curved`: the evidence margin the curved fit won (or lost,
    /// if negative) over the linear special case at this slot. Positive ⇒ curved
    /// bought more evidence than its parameter price; ≤ 0 ⇒ the dominance floor
    /// keeps the linear tail. Reported as `0.0` when the slot lacks either the
    /// linear or the curved candidate.
    pub curved_evidence_margin: f64,
}
2905
2906/// Below this fitted turning the curved candidate is treated as straight: its
2907/// curvature is numerically indistinguishable from zero, so the dominance floor
2908/// (the linear special case is cheaper at equal likelihood) is enforced by
2909/// construction rather than left to finite-sample evidence noise. This is the
2910/// exact-zero guard from the `Θ → 0 ⇒ N(ε) → 0` limit of the shatter law, not a
2911/// tunable knob: it is the curvature scale below which `‖γ' ∧ γ''‖` is at the
2912/// floor of the Simpson quadrature for a genuinely straight image.
2913pub const HYBRID_LINEAR_TURNING_FLOOR: f64 = 1e-9;
2914
2915/// Adjudicate the curved-vs-linear parameterization for ONE hybrid-dictionary
2916/// atom slot by the common rank-aware Laplace evidence criterion.
2917///
2918/// Selection rule (all on the single `NLE = −V` scale, lower wins):
2919///
2920///  1. **Dominance floor (Θ → 0).** If the curved candidate's fitted turning is
2921///     `Some(Θ)` with `Θ ≤ HYBRID_LINEAR_TURNING_FLOOR` and a linear candidate
2922///     exists, select LINEAR. A straight curved fit recovers no likelihood the
2923///     linear special case does not, and the linear atom is strictly cheaper, so
2924///     it cannot lose — we enforce that exactly instead of trusting evidence
2925///     noise at the floor.
2926///  2. **Evidence comparison.** Otherwise select the candidate with the smaller
2927///     `NLE`. The curved candidate wins only when its extra curvature lowers the
2928///     NLE by MORE than its extra parameter price — the `Θ/√ε` crossover, decided
2929///     here by the evidence numbers themselves, not by fiat. This is a
2930///     common-data comparison (both candidates fit the atom's response residual,
2931///     see `crate::terms::sae::hybrid_split`) in which linear is the curved
2932///     family's nested `Θ = 0` sub-model (#1202): the curved candidate cannot be
2933///     charged its extra parameters to fit the residual no better than its own
2934///     straight projection, and a tightly-spread, mildly-curved residual can
2935///     still prefer the cheaper line.
2936///  3. **Tie-break.** Exact NLE ties go to the cheaper (fewer-parameter)
2937///     candidate — i.e. linear — preserving the strict-generalization guarantee
2938///     that the hybrid never pays for curvature it does not need.
2939///
2940/// `candidates` must contain at most one linear and at most one curved candidate
2941/// for the slot; returns `None` only if `candidates` is empty.
2942pub fn select_hybrid_atom(candidates: &[HybridAtomCandidate]) -> Option<HybridAtomChoice> {
2943    if candidates.is_empty() {
2944        return None;
2945    }
2946    let linear = candidates.iter().find(|c| c.param.is_linear());
2947    let curved = candidates.iter().find(|c| !c.param.is_linear());
2948    let curved_turning = curved.and_then(|c| c.fitted_turning);
2949    let curved_evidence_margin = match (linear, curved) {
2950        (Some(l), Some(c)) => l.negative_log_evidence - c.negative_log_evidence,
2951        _ => 0.0,
2952    };
2953
2954    // (1) Exact-zero dominance floor: a straight curved fit yields to the linear
2955    // special case by construction.
2956    if let (Some(l), Some(turning)) = (linear, curved_turning)
2957        && turning <= HYBRID_LINEAR_TURNING_FLOOR
2958    {
2959        return Some(HybridAtomChoice {
2960            param: l.param,
2961            negative_log_evidence: l.negative_log_evidence,
2962            num_parameters: l.num_parameters,
2963            curved_turning,
2964            curved_evidence_margin,
2965        });
2966    }
2967
2968    // (2)+(3) Evidence argmin with the cheaper candidate winning exact ties.
2969    let mut best = candidates[0];
2970    for cand in &candidates[1..] {
2971        let better_evidence = cand.negative_log_evidence < best.negative_log_evidence;
2972        let tied = cand.negative_log_evidence == best.negative_log_evidence;
2973        let cheaper_on_tie = tied && cand.num_parameters < best.num_parameters;
2974        if better_evidence || cheaper_on_tie {
2975            best = *cand;
2976        }
2977    }
2978    Some(HybridAtomChoice {
2979        param: best.param,
2980        negative_log_evidence: best.negative_log_evidence,
2981        num_parameters: best.num_parameters,
2982        curved_turning,
2983        curved_evidence_margin,
2984    })
2985}
2986
/// The evidence-selected split for a whole hybrid dictionary: the per-atom
/// curved-vs-linear choices and the dictionary-level aggregates the EV-vs-Θ
/// frontier reports against.
///
/// `select_hybrid_split` builds this by adjudicating every slot in order and
/// accumulating the three aggregate fields from the per-slot choices.
#[derive(Debug, Clone)]
pub struct HybridSplitSelection {
    /// One adjudicated choice per atom slot, in slot order.
    pub atoms: Vec<HybridAtomChoice>,
    /// `Σ NLE` across the selected per-atom parameterizations — the dictionary's
    /// summed rank-aware Laplace negative-log-evidence (lower wins). Because each
    /// slot picks the argmin over {curved contribution, best straight line to the
    /// response residual}, this is ≤ the sum of the per-slot LINEAR-candidate
    /// NLEs. The linear baseline is the best straight line fit to each atom's
    /// leave-this-atom-out RESPONSE residual (#1202), the curved family's nested
    /// `Θ = 0` member on common data — so this is a genuine data-level
    /// match-or-beat dominance, not a post-hoc curve-simplification one.
    pub total_negative_log_evidence: f64,
    /// `Σ P` across the selected parameterizations — the dictionary's total
    /// free-parameter price (the matched-active-budget accounting).
    pub total_parameters: usize,
    /// Count of slots that selected the curved parameterization, i.e. slots
    /// whose chosen `param` is not the linear special case.
    pub curved_atom_count: usize,
}
3009
3010impl HybridSplitSelection {
3011    /// Count of slots that selected the linear special case (the linear tail).
3012    pub fn linear_atom_count(&self) -> usize {
3013        self.atoms.len() - self.curved_atom_count
3014    }
3015
3016    /// `true` iff every slot selected linear — the pure-linear limit, reached
3017    /// when every feature is straight (all `Θ → 0`).
3018    pub fn is_pure_linear(&self) -> bool {
3019        self.curved_atom_count == 0 && !self.atoms.is_empty()
3020    }
3021
3022    /// `true` iff every slot selected curved — the pure-curved limit, reached
3023    /// when every feature turns enough to pay for curvature.
3024    pub fn is_pure_curved(&self) -> bool {
3025        self.curved_atom_count == self.atoms.len() && !self.atoms.is_empty()
3026    }
3027}
3028
3029/// Adjudicate the curved-vs-linear split across a whole hybrid dictionary by the
3030/// common evidence criterion. `slots[i]` holds the curved/linear candidates for
3031/// atom slot `i` (each scored on the same rows, on the common Laplace scale).
3032///
3033/// The result reduces EXACTLY to pure-linear when every slot's curved candidate
3034/// has `Θ → 0` (the turning floor fires everywhere) and to pure-curved when
3035/// every slot's curved candidate wins the evidence comparison. (Common-data
3036/// criterion, #1202 — both candidates fit the atom's response residual, with
3037/// linear nested as the curved family's `Θ = 0` sub-model; see the module header
3038/// above and `crate::terms::sae::hybrid_split`.)
3039///
3040/// Returns an error only if some slot has no candidates to adjudicate (an empty
3041/// dictionary slot is a caller bug, not a silent skip).
3042pub fn select_hybrid_split(
3043    slots: &[Vec<HybridAtomCandidate>],
3044) -> Result<HybridSplitSelection, String> {
3045    let mut atoms = Vec::with_capacity(slots.len());
3046    let mut total_nle = 0.0_f64;
3047    let mut total_parameters = 0usize;
3048    let mut curved_atom_count = 0usize;
3049    for (i, slot) in slots.iter().enumerate() {
3050        let choice = select_hybrid_atom(slot)
3051            .ok_or_else(|| format!("hybrid split slot {i} has no candidate parameterizations"))?;
3052        if !choice.negative_log_evidence.is_finite() {
3053            return Err(format!(
3054                "hybrid split slot {i} selected a non-finite evidence ({})",
3055                choice.negative_log_evidence
3056            ));
3057        }
3058        if !choice.param.is_linear() {
3059            curved_atom_count += 1;
3060        }
3061        total_nle += choice.negative_log_evidence;
3062        total_parameters += choice.num_parameters;
3063        atoms.push(choice);
3064    }
3065    Ok(HybridSplitSelection {
3066        atoms,
3067        total_negative_log_evidence: total_nle,
3068        total_parameters,
3069        curved_atom_count,
3070    })
3071}
3072
3073// ---------------------------------------------------------------------------
3074// Tests
3075//
3076// These are type-level / structural tests: per the task contract we do
3077// not compile or run them in this session. They document the expected
3078// shapes and degenerate-case behavior so a future maintainer running
3079// `cargo test` sees the contract written down.
3080// ---------------------------------------------------------------------------
3081
3082#[cfg(test)]
3083mod tests {
3084    use super::*;
3085    use crate::arrow_schur::ArrowFactorSlab;
3086
3087    // Dense `H⁻¹` apply via explicit inverse (test-only reference solver).
3088    fn dense_inverse(h: &Array2<f64>) -> Array2<f64> {
3089        let p = h.nrows();
3090        let mut aug = Array2::<f64>::zeros((p, 2 * p));
3091        for i in 0..p {
3092            for j in 0..p {
3093                aug[[i, j]] = h[[i, j]];
3094            }
3095            aug[[i, p + i]] = 1.0;
3096        }
3097        for col in 0..p {
3098            let mut pivot = col;
3099            for row in (col + 1)..p {
3100                if aug[[row, col]].abs() > aug[[pivot, col]].abs() {
3101                    pivot = row;
3102                }
3103            }
3104            if pivot != col {
3105                for j in 0..(2 * p) {
3106                    aug.swap([col, j], [pivot, j]);
3107                }
3108            }
3109            let d = aug[[col, col]];
3110            for j in 0..(2 * p) {
3111                aug[[col, j]] /= d;
3112            }
3113            for row in 0..p {
3114                if row == col {
3115                    continue;
3116                }
3117                let f = aug[[row, col]];
3118                if f != 0.0 {
3119                    for j in 0..(2 * p) {
3120                        aug[[row, j]] -= f * aug[[col, j]];
3121                    }
3122                }
3123            }
3124        }
3125        let mut inv = Array2::<f64>::zeros((p, p));
3126        for i in 0..p {
3127            for j in 0..p {
3128                inv[[i, j]] = aug[[i, p + j]];
3129            }
3130        }
3131        inv
3132    }
3133
3134    #[test]
3135    fn coupling_components_block_diagonal_is_all_singletons_by_block() {
3136        // Two decoupled 2x2 blocks: {0,1} and {2,3}.
3137        let mut h = Array2::<f64>::eye(4);
3138        h[[0, 1]] = 0.3;
3139        h[[1, 0]] = 0.3;
3140        h[[2, 3]] = 0.7;
3141        h[[3, 2]] = 0.7;
3142        let labels = coupling_components(h.view());
3143        assert_eq!(labels[0], labels[1]);
3144        assert_eq!(labels[2], labels[3]);
3145        assert_ne!(labels[0], labels[2]);
3146        // Exactly two components.
3147        let mut uniq = labels.clone();
3148        uniq.sort_unstable();
3149        uniq.dedup();
3150        assert_eq!(uniq.len(), 2);
3151    }
3152
3153    #[test]
3154    fn coupling_components_fully_coupled_is_one_component() {
3155        let mut h = Array2::<f64>::eye(3);
3156        for i in 0..3 {
3157            for j in 0..3 {
3158                if i != j {
3159                    h[[i, j]] = 0.1;
3160                }
3161            }
3162        }
3163        let labels = coupling_components(h.view());
3164        assert!(labels.iter().all(|&l| l == labels[0]));
3165    }
3166
3167    #[test]
3168    fn coupling_components_transitive_chain_merges() {
3169        // 0-1 and 1-2 coupled (but no direct 0-2 edge) must form one component.
3170        let mut h = Array2::<f64>::eye(3);
3171        h[[0, 1]] = 0.5;
3172        h[[1, 0]] = 0.5;
3173        h[[1, 2]] = 0.5;
3174        h[[2, 1]] = 0.5;
3175        let labels = coupling_components(h.view());
3176        assert_eq!(labels[0], labels[1]);
3177        assert_eq!(labels[1], labels[2]);
3178    }
3179
3180    #[test]
3181    fn compare_reml_fits_delta_and_bayes_factor_never_contradict_winner_gh1465() {
3182        // Regression for #1465: the ranking `delta` / `bayes_factor` must be
3183        // measured on the SAME scale that orders the table (the Occam-penalised
3184        // conditional AIC `ranking_score`), so every row's delta is >= 0 and its
3185        // Bayes factor >= 1 — the table must never claim a non-winner beats the
3186        // declared winner. The scenario is exactly the case the comparison
3187        // exists to handle: AIC and raw REML DISAGREE. `m1` is the AIC winner
3188        // but does NOT carry the minimum raw REML (`m2` does) — the noise
3189        // extra-term case from the issue.
3190        //
3191        // `ranking_score` = -2*log_lik + 2*edf; with log_lik = 0 it is `2*edf`,
3192        // so the AIC order is m1 < m2 < m3 while the raw-REML order has m2 lowest.
3193        let cand = |name: &str, score: f64, edf: f64| RemlCandidate {
3194            index: 0,
3195            name: name.to_string(),
3196            score,
3197            edf: Some(edf),
3198            log_lik: Some(0.0),
3199            family: Some("gaussian".to_string()),
3200            n_obs: Some(100),
3201        };
3202        // raw REML : m2 (41.605) < m1 (53.748) < m3 (120.011)
3203        // AIC=2*edf: m1 (100)    < m2 (102)    < m3 (130)
3204        let candidates = vec![
3205            cand("m1", 53.748, 50.0),
3206            cand("m2", 41.605, 51.0),
3207            cand("m3", 120.011, 65.0),
3208        ];
3209        let cmp = compare_reml_fits(candidates).expect("comparison");
3210
3211        assert_eq!(cmp.winner, "m1", "AIC winner");
3212        // No ranking row may contradict the declared winner.
3213        for row in &cmp.ranking {
3214            assert!(
3215                row.delta >= 0.0,
3216                "ranking delta for {} must be >= 0, got {}",
3217                row.name,
3218                row.delta
3219            );
3220            assert!(
3221                row.bayes_factor >= 1.0 - 1e-12,
3222                "ranking bayes_factor for {} must be >= 1, got {}",
3223                row.name,
3224                row.bayes_factor
3225            );
3226        }
3227        let winner_row = cmp.ranking.iter().find(|r| r.name == "m1").unwrap();
3228        assert!(winner_row.delta.abs() < 1e-12, "winner delta == 0");
3229        assert!(
3230            (winner_row.bayes_factor - 1.0).abs() < 1e-9,
3231            "winner bayes_factor == 1"
3232        );
3233
3234        // The raw-REML score table is referenced to the genuine minimum raw REML
3235        // (m2), so its best-over-model Bayes factors are also coherent (>= 1).
3236        for row in &cmp.score_table {
3237            assert!(
3238                row.delta_reml >= 0.0,
3239                "score-table delta_reml for {} must be >= 0, got {}",
3240                row.name,
3241                row.delta_reml
3242            );
3243            assert!(
3244                row.bayes_factor_best_over_model >= 1.0 - 1e-12,
3245                "score-table bayes_factor for {} must be >= 1, got {}",
3246                row.name,
3247                row.bayes_factor_best_over_model
3248            );
3249        }
3250        // m2 carries the minimum raw REML, so its raw delta is exactly 0.
3251        let m2 = cmp.score_table.iter().find(|r| r.name == "m2").unwrap();
3252        assert!(
3253            m2.delta_reml.abs() < 1e-12,
3254            "the minimum-raw-REML row has delta_reml 0"
3255        );
3256    }
3257
3258    #[test]
3259    fn cone_of_influence_empty_support_is_empty() {
3260        let labels = vec![0usize, 0, 1, 1];
3261        assert!(cone_of_influence(&labels, &[]).is_empty());
3262    }
3263
3264    #[test]
3265    fn cone_of_influence_returns_full_component() {
3266        let labels = vec![0usize, 0, 1, 1];
3267        // Support in component 0 -> cone is {0,1}.
3268        assert_eq!(cone_of_influence(&labels, &[0]), vec![0, 1]);
3269        // Support spanning both -> cone is everything.
3270        assert_eq!(cone_of_influence(&labels, &[1, 2]), vec![0, 1, 2, 3]);
3271    }
3272
3273    #[test]
3274    fn coned_matches_full_solve_on_fully_coupled_hessian() {
3275        // Fully coupled SPD H: cone is the whole space, result must equal the
3276        // unconfined sensitivity-operator mode response bit-for-bit.
3277        let h = Array2::from_shape_vec((3, 3), vec![4.0, 1.0, 0.5, 1.0, 3.0, 0.8, 0.5, 0.8, 2.5])
3278            .unwrap();
3279        let inv = dense_inverse(&h);
3280        // Two ρ-columns, each supported on a single coefficient.
3281        let mut dg = Array2::<f64>::zeros((3, 2));
3282        dg[[0, 0]] = 1.3;
3283        dg[[2, 1]] = -0.7;
3284        let supports = vec![0..1usize, 2..3usize];
3285
3286        let eye: Array2<f64> = Array2::eye(3);
3287        let op = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv);
3288        let full = op.mode_response(dg.view()).unwrap();
3289        let coned = op
3290            .mode_response_coned(h.view(), dg.view(), &supports)
3291            .unwrap();
3292        for i in 0..3 {
3293            for a in 0..2 {
3294                assert!(
3295                    (full[[i, a]] - coned[[i, a]]).abs() < 1e-12,
3296                    "fully-coupled mismatch at ({i},{a}): {} vs {}",
3297                    full[[i, a]],
3298                    coned[[i, a]]
3299                );
3300            }
3301        }
3302    }
3303
3304    #[test]
3305    fn coned_confines_to_component_on_decoupled_hessian() {
3306        // Block-decoupled H: blocks {0,1} and {2,3}. A column supported only in
3307        // block {0,1} must produce sensitivity zero in block {2,3}, and match
3308        // the exact solution within its own block.
3309        let mut h = Array2::<f64>::zeros((4, 4));
3310        // Block A.
3311        h[[0, 0]] = 4.0;
3312        h[[1, 1]] = 3.0;
3313        h[[0, 1]] = 1.0;
3314        h[[1, 0]] = 1.0;
3315        // Block B.
3316        h[[2, 2]] = 2.0;
3317        h[[3, 3]] = 5.0;
3318        h[[2, 3]] = 0.6;
3319        h[[3, 2]] = 0.6;
3320        let inv = dense_inverse(&h);
3321
3322        let mut dg = Array2::<f64>::zeros((4, 1));
3323        dg[[0, 0]] = 0.9;
3324        dg[[1, 0]] = -0.4;
3325        let support_range = 0..2usize;
3326        let supports = std::slice::from_ref(&support_range);
3327
3328        let eye: Array2<f64> = Array2::eye(4);
3329        let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv)
3330            .mode_response_coned(h.view(), dg.view(), supports)
3331            .unwrap();
3332        // Exact reference: -H⁻¹ q. Off-block entries are exactly zero already
3333        // (decoupled inverse), and the cone must preserve the in-block ones.
3334        let q = dg.column(0).to_owned();
3335        let exact = inv.dot(&q).mapv(|v| -v);
3336        for i in 0..4 {
3337            assert!(
3338                (coned[[i, 0]] - exact[[i]]).abs() < 1e-12,
3339                "decoupled mismatch at {i}: {} vs {}",
3340                coned[[i, 0]],
3341                exact[[i]]
3342            );
3343        }
3344        // Block B is outside the cone -> exactly zero.
3345        assert_eq!(coned[[2, 0]], 0.0);
3346        assert_eq!(coned[[3, 0]], 0.0);
3347    }
3348
3349    #[test]
3350    fn coned_skips_inactive_column_with_empty_support() {
3351        let h = Array2::<f64>::eye(2);
3352        let dg = Array2::<f64>::zeros((2, 1));
3353        // Inactive ρ: empty support, must be skipped without solving.
3354        let empty_support = 0..0usize;
3355        let supports = std::slice::from_ref(&empty_support);
3356        // A NaN inverse: an empty-support column must be skipped WITHOUT
3357        // solving, so the operator's finite-check never sees the NaN and the
3358        // result is `Some(zeros)`. Were the inactive column ever solved, the
3359        // NaN would propagate and `mode_response_coned` would return `None`.
3360        let eye: Array2<f64> = Array2::eye(2);
3361        let nan_inv = Array2::<f64>::from_elem((2, 2), f64::NAN);
3362        let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &nan_inv)
3363            .mode_response_coned(h.view(), dg.view(), supports)
3364            .unwrap();
3365        assert_eq!(coned[[0, 0]], 0.0);
3366        assert_eq!(coned[[1, 0]], 0.0);
3367    }
3368
3369    fn make_minimal_cache() -> ArrowFactorCache {
3370        // d = 1, k = 1, n = 1, H_uu_1 = [[2.0]] => L = [[sqrt(2)]],
3371        // H_uβ_1 = [[0.5]], A = 2 - 0.5 * 0.5 / 2 = 1.875.
3372        let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
3373        let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
3374        let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
3375        ArrowFactorCache {
3376            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
3377            htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
3378            schur_factor: Some(l_schur),
3379            joint_hessian_log_det: None,
3380            solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
3381            ridge_t: 0.0,
3382            ridge_beta: 0.0,
3383            htbeta: crate::arrow_schur::ArrowHtbetaCache::Dense {
3384                blocks: std::sync::Arc::from(vec![htbeta]),
3385                estimated_bytes: std::mem::size_of::<f64>(),
3386            },
3387            d: 1,
3388            row_dims: std::sync::Arc::from(vec![1usize]),
3389            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
3390            k: 1,
3391            manifold_mode_fingerprint: 0,
3392            row_hessian_fingerprint: 0,
3393            pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
3394            gauge_deflated_directions: 0,
3395            deflated_row_directions: std::sync::Arc::from(Vec::new()),
3396            deflation_row_spectra: std::sync::Arc::from(Vec::new()),
3397            cross_row_woodbury: None,
3398        }
3399    }
3400
3401    #[test]
3402    fn laplace_evidence_returns_finite_for_minimal_cache() {
3403        let cache = make_minimal_cache();
3404        // log|H| = log(2) + log(1.875). With dim(H)=2 and rank(S)=1,
3405        // V includes the rank-aware TK nullspace normalizer.
3406        let v = laplace_evidence(
3407            EvidenceLogDetSource::FactoredArrow {
3408                cache: &cache,
3409                fallback_hvp: None,
3410            },
3411            0.0,
3412            0.0,
3413            2.0,
3414            1.0,
3415        );
3416        assert!(v.is_finite());
3417        let expected =
3418            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
3419        assert!((v - expected).abs() < 1e-12);
3420    }
3421
3422    /// #1132 bug 2: a β-profiled atom (no shared `β` block, `k == 0`) reaches
3423    /// `arrow_log_det_from_cache` in the dense Direct path with
3424    /// `schur_factor = None` — there is no reduced Schur complement to form. The
3425    /// joint Hessian is then block-diagonal in the latent rows, so its log-det
3426    /// is exactly the per-row sum with NO Schur term. Before the fix this
3427    /// returned `None` (the `schur_factor.as_ref()?` bail), starving the REML
3428    /// Laplace normaliser and erroring "arrow_log_det_from_cache returned None
3429    /// at ridge=0 Direct mode". Now it returns `Some(Σ_i log|H_tt^(i)|)`.
3430    fn k0_direct_cache_no_schur(latent_diag: f64) -> ArrowFactorCache {
3431        let l_huu = Array2::from_shape_vec((1, 1), vec![latent_diag.sqrt()]).unwrap();
3432        ArrowFactorCache {
3433            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
3434            htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
3435            schur_factor: None,
3436            joint_hessian_log_det: None,
3437            solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
3438            ridge_t: 0.0,
3439            ridge_beta: 0.0,
3440            htbeta: crate::arrow_schur::ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
3441            d: 1,
3442            row_dims: std::sync::Arc::from(vec![1usize]),
3443            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
3444            k: 0,
3445            manifold_mode_fingerprint: 0,
3446            row_hessian_fingerprint: 0,
3447            pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
3448            gauge_deflated_directions: 0,
3449            deflated_row_directions: std::sync::Arc::from(Vec::new()),
3450            deflation_row_spectra: std::sync::Arc::from(Vec::new()),
3451            cross_row_woodbury: None,
3452        }
3453    }
3454
3455    #[test]
3456    fn arrow_log_det_some_for_k0_direct_cache_without_schur() {
3457        let cache = k0_direct_cache_no_schur(3.0);
3458        let log_det = arrow_log_det_from_cache(&cache)
3459            .expect("k==0 Direct cache must yield Some(per-row sum), not None (#1132)");
3460        // Single latent block H_tt = [[3.0]]; no Schur term for k == 0.
3461        assert!(
3462            (log_det - 3.0_f64.ln()).abs() < 1e-12,
3463            "log_det = {log_det}"
3464        );
3465        // The cache's own computation must agree bit-for-bit.
3466        let cached = cache
3467            .compute_undamped_arrow_log_det()
3468            .expect("compute_undamped_arrow_log_det must be Some for k==0");
3469        assert!((cached - 3.0_f64.ln()).abs() < 1e-12, "cached = {cached}");
3470    }
3471
3472    #[test]
3473    fn arrow_log_det_none_for_kpos_cache_without_schur() {
3474        // k > 0 but no dense Schur factor is the genuine InexactPCG case and
3475        // must still reject (the guard must not over-broaden to all `None`).
3476        let mut cache = k0_direct_cache_no_schur(3.0);
3477        cache.k = 1;
3478        cache.solver_mode = crate::arrow_schur::ArrowSolverMode::InexactPCG;
3479        assert!(arrow_log_det_from_cache(&cache).is_none());
3480        assert!(cache.compute_undamped_arrow_log_det().is_none());
3481    }
3482
3483    #[test]
3484    fn laplace_evidence_nan_when_ridge_is_nonzero() {
3485        let mut cache = make_minimal_cache();
3486        cache.ridge_t = 1e-3;
3487        assert!(
3488            laplace_evidence(
3489                EvidenceLogDetSource::FactoredArrow {
3490                    cache: &cache,
3491                    fallback_hvp: None,
3492                },
3493                0.0,
3494                0.0,
3495                2.0,
3496                1.0,
3497            )
3498            .is_nan()
3499        );
3500    }
3501
3502    #[test]
3503    fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
3504        let mut cache = make_minimal_cache();
3505        cache.schur_factor = None;
3506        let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
3507        let v = laplace_evidence(
3508            EvidenceLogDetSource::FactoredArrow {
3509                cache: &cache,
3510                fallback_hvp: Some(EvidenceHvpLogDet {
3511                    dim: 2,
3512                    apply: &hvp,
3513                }),
3514            },
3515            0.0,
3516            0.0,
3517            2.0,
3518            1.0,
3519        );
3520        let expected =
3521            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
3522        assert!((v - expected).abs() < 1e-12);
3523    }
3524
3525    #[test]
3526    fn ift_du_dbeta_has_expected_shape() {
3527        let cache = make_minimal_cache();
3528        let du_db = ift_du_dbeta(&cache);
3529        assert_eq!(du_db.shape(), &[1, 1]);
3530        // ∂u/∂β = -H_uu⁻¹ H_uβ = -0.5 / 2 = -0.25.
3531        assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
3532    }
3533
3534    #[test]
3535    fn ift_dbeta_drho_returns_some_for_direct_cache() {
3536        let cache = make_minimal_cache();
3537        let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
3538        let out = ift_dbeta_drho(&cache, q.view()).unwrap();
3539        assert_eq!(out.shape(), &[1, 1]);
3540        // ∂β/∂ρ = -A⁻¹ · 1 = -1/1.875.
3541        assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
3542    }
3543
3544    #[test]
3545    fn topology_select_picks_lowest_negative_log_evidence() {
3546        let candidates = vec![
3547            TopologyCandidate {
3548                kind: TopologyKind::Flat,
3549                negative_log_evidence: 10.0,
3550                effective_dim: 4.0,
3551                n_obs: 100,
3552                converged: true,
3553                exclusion_reason: None,
3554            },
3555            TopologyCandidate {
3556                kind: TopologyKind::Sphere,
3557                negative_log_evidence: 8.0,
3558                effective_dim: 5.0,
3559                n_obs: 100,
3560                converged: true,
3561                exclusion_reason: None,
3562            },
3563            TopologyCandidate {
3564                kind: TopologyKind::Torus,
3565                negative_log_evidence: f64::NAN,
3566                effective_dim: 6.0,
3567                n_obs: 100,
3568                converged: false,
3569                exclusion_reason: Some("torus periods missing".to_string()),
3570            },
3571        ];
3572        let sel = select_topology(&candidates, TopologySelectOptions::default());
3573        assert_eq!(sel.winner, TopologyKind::Sphere);
3574        assert!(!sel.tie);
3575    }
3576
3577    #[test]
3578    fn topology_select_tie_breaks_to_simpler() {
3579        let candidates = vec![
3580            TopologyCandidate {
3581                kind: TopologyKind::Sphere,
3582                negative_log_evidence: 5.0,
3583                effective_dim: 5.0,
3584                n_obs: 100,
3585                converged: true,
3586                exclusion_reason: None,
3587            },
3588            TopologyCandidate {
3589                kind: TopologyKind::Flat,
3590                negative_log_evidence: 5.0 + 1e-6,
3591                effective_dim: 4.0,
3592                n_obs: 100,
3593                converged: true,
3594                exclusion_reason: None,
3595            },
3596        ];
3597        let sel = select_topology(&candidates, TopologySelectOptions::default());
3598        assert_eq!(sel.winner, TopologyKind::Flat);
3599        assert!(sel.tie);
3600    }
3601
3602    fn gaussian_logpdf(y: f64, mean: f64, sd: f64) -> f64 {
3603        let z = (y - mean) / sd;
3604        -0.5 * (2.0 * std::f64::consts::PI).ln() - sd.ln() - 0.5 * z * z
3605    }
3606
3607    #[test]
3608    fn stacking_single_candidate_gets_full_weight() {
3609        let log_density = Array2::from_shape_vec((3, 1), vec![-1.0, -2.0, -0.5]).unwrap();
3610        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3611        assert!((out.weights[0] - 1.0).abs() < 1e-12);
3612        assert_eq!(out.weights.len(), 1);
3613    }
3614
3615    #[test]
3616    fn stacking_dominant_candidate_attracts_nearly_all_weight() {
3617        let mut log_density = Array2::<f64>::zeros((50, 2));
3618        for i in 0..50 {
3619            log_density[[i, 0]] = -0.1;
3620            log_density[[i, 1]] = -5.0;
3621        }
3622        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3623        assert!(out.weights[0] > 0.99, "w0 = {}", out.weights[0]);
3624        assert!(out.weights[1] < 0.01, "w1 = {}", out.weights[1]);
3625    }
3626
3627    #[test]
3628    fn stacking_complementary_candidates_share_weight() {
3629        // Each candidate is the better predictor on its own half of the data;
3630        // stacking keeps both, unlike winner-take-all.
3631        let n = 40;
3632        let mut log_density = Array2::<f64>::zeros((n, 2));
3633        for i in 0..n {
3634            if i < n / 2 {
3635                log_density[[i, 0]] = gaussian_logpdf(0.0, 0.0, 0.5);
3636                log_density[[i, 1]] = gaussian_logpdf(0.0, 1.5, 0.5);
3637            } else {
3638                log_density[[i, 0]] = gaussian_logpdf(0.0, 1.5, 0.5);
3639                log_density[[i, 1]] = gaussian_logpdf(0.0, 0.0, 0.5);
3640            }
3641        }
3642        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3643        assert!(
3644            out.weights[0] > 0.2 && out.weights[0] < 0.8,
3645            "w0 = {}",
3646            out.weights[0]
3647        );
3648        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3649    }
3650
3651    #[test]
3652    fn stacking_weights_stay_on_the_simplex() {
3653        let log_density = Array2::from_shape_vec(
3654            (3, 3),
3655            vec![-1.0, -2.0, -3.0, -2.5, -1.0, -2.0, -3.0, -2.0, -1.0],
3656        )
3657        .unwrap();
3658        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3659        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3660        assert!(out.weights.iter().all(|&w| w >= -1e-12));
3661    }
3662
3663    #[test]
3664    fn stacking_mean_log_score_is_monotone_under_more_iterations() {
3665        // The EM ascent is monotone in the held-out mean log-score, so allowing
3666        // more iterations never lowers it.
3667        let log_density =
3668            Array2::from_shape_vec((4, 2), vec![-0.2, -3.0, -3.0, -0.2, -0.5, -1.5, -1.5, -0.5])
3669                .unwrap();
3670        let mut prev = f64::NEG_INFINITY;
3671        for max_iter in [1usize, 2, 4, 8, 32] {
3672            let out = solve_stacking_weights(
3673                log_density.view(),
3674                StackingConfig {
3675                    max_iter,
3676                    weight_tol: 0.0,
3677                },
3678            )
3679            .unwrap();
3680            assert!(
3681                out.mean_log_score >= prev - 1e-12,
3682                "log-score decreased at max_iter={max_iter}: {prev} -> {}",
3683                out.mean_log_score
3684            );
3685            prev = out.mean_log_score;
3686        }
3687    }
3688
3689    #[test]
3690    fn stacking_dead_candidate_column_is_rejected_and_zero_weighted() {
3691        let log_density = Array2::from_shape_vec(
3692            (3, 2),
3693            vec![
3694                -1.0,
3695                f64::NEG_INFINITY,
3696                -2.0,
3697                f64::NAN,
3698                -0.5,
3699                f64::NEG_INFINITY,
3700            ],
3701        )
3702        .unwrap();
3703        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3704        assert_eq!(out.weights[1], 0.0);
3705        assert!((out.weights[0] - 1.0).abs() < 1e-12);
3706    }
3707
3708    #[test]
3709    fn stacking_rows_with_no_finite_density_are_dropped() {
3710        let log_density = Array2::from_shape_vec(
3711            (3, 2),
3712            vec![-1.0, -2.0, f64::NAN, f64::NEG_INFINITY, -2.0, -1.0],
3713        )
3714        .unwrap();
3715        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3716        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3717        assert!(out.mean_log_score.is_finite());
3718    }
3719
3720    #[test]
3721    fn stacking_all_dead_table_errors() {
3722        let log_density = Array2::from_elem((2, 2), f64::NEG_INFINITY);
3723        assert!(solve_stacking_weights(log_density.view(), StackingConfig::default()).is_err());
3724    }
3725
3726    #[test]
3727    fn stacked_mean_is_weighted_combination() {
3728        let weights = Array1::from_vec(vec![0.25, 0.75]);
3729        let means = vec![
3730            Array1::from_vec(vec![1.0, 2.0, 3.0]),
3731            Array1::from_vec(vec![5.0, 6.0, 7.0]),
3732        ];
3733        let out = stacked_predictive_mean(&weights, &means).unwrap();
3734        assert!((out[0] - (0.25 * 1.0 + 0.75 * 5.0)).abs() < 1e-12);
3735        assert!((out[2] - (0.25 * 3.0 + 0.75 * 7.0)).abs() < 1e-12);
3736    }
3737
3738    #[test]
3739    fn stacked_mean_rejects_shape_mismatch() {
3740        let weights = Array1::from_vec(vec![0.5, 0.5]);
3741        let means = vec![
3742            Array1::from_vec(vec![1.0, 2.0]),
3743            Array1::from_vec(vec![3.0]),
3744        ];
3745        assert!(stacked_predictive_mean(&weights, &means).is_err());
3746    }
3747
3748    // -----------------------------------------------------------------------
3749    // #1026 hybrid curved + linear-tail split-selection
3750    // -----------------------------------------------------------------------
3751
3752    /// Build the two candidate parameterizations for one atom slot the way the
3753    /// fit would: the linear special case (one decoder direction, `Θ = 0`,
3754    /// `P_linear` params) and the curved candidate (`latent_dim` ≥ 1, more
3755    /// params, fitted turning `theta`). The curved candidate's likelihood is the
3756    /// linear likelihood MINUS `curved_loglik_gain` of NLE (curvature it captures
3757    /// the secant cannot), so the nesting invariant `curved_loglik ≥ linear` is
3758    /// honored: a straight feature has zero gain, a turning feature a positive
3759    /// gain that grows with Θ. The rank-aware Laplace normalizer charges the
3760    /// extra `½(P_curved − P_linear)·log(2π)` for the curved parameters, so the
3761    /// evidence comparison is the real `Θ/√ε` crossover.
3762    fn hybrid_slot(
3763        linear_nle: f64,
3764        p_linear: usize,
3765        latent_dim: usize,
3766        p_curved: usize,
3767        theta: f64,
3768        curved_loglik_gain: f64,
3769    ) -> Vec<HybridAtomCandidate> {
3770        let param_price =
3771            0.5 * (p_curved as f64 - p_linear as f64) * (2.0 * std::f64::consts::PI).ln();
3772        let curved_nle = linear_nle - curved_loglik_gain + param_price;
3773        vec![
3774            HybridAtomCandidate::linear(linear_nle, p_linear),
3775            HybridAtomCandidate::curved(latent_dim, curved_nle, p_curved, Some(theta)),
3776        ]
3777    }
3778
3779    #[test]
3780    fn hybrid_dominance_floor_selects_linear_when_turning_is_zero() {
3781        // A perfectly straight curved fit (Θ = 0) gains no likelihood over its
3782        // linear sub-model but pays more parameters → linear must win, by
3783        // construction, even if finite-sample evidence noise nudged the curved
3784        // NLE slightly below linear.
3785        let slot = hybrid_slot(100.0, 2, 1, 5, 0.0, 0.0);
3786        let choice = select_hybrid_atom(&slot).unwrap();
3787        assert!(choice.param.is_linear());
3788        assert_eq!(choice.param, HybridAtomParam::Linear);
3789        // The exact-zero guard fires regardless of the evidence margin sign.
3790        assert!(choice.curved_turning.unwrap() <= HYBRID_LINEAR_TURNING_FLOOR);
3791    }
3792
3793    #[test]
3794    fn hybrid_selects_curved_when_turning_pays_for_itself() {
3795        // A genuinely turning feature (Θ = 2π, a full loop): the curved fit
3796        // captures enough curvature that, even charged the extra-parameter price,
3797        // its NLE drops below the linear secant's → curved wins.
3798        let slot = hybrid_slot(100.0, 2, 1, 5, 2.0 * std::f64::consts::PI, 30.0);
3799        let choice = select_hybrid_atom(&slot).unwrap();
3800        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
3801        // The curved fit won a strictly positive evidence margin.
3802        assert!(choice.curved_evidence_margin > 0.0);
3803    }
3804
3805    #[test]
3806    fn hybrid_keeps_linear_when_curvature_doesnt_pay_its_price() {
3807        // A barely-curved feature (small Θ): the curved fit recovers only a sliver
3808        // of likelihood, not enough to cover the extra-parameter price → the
3809        // dominance floor keeps the linear tail.
3810        let slot = hybrid_slot(100.0, 2, 1, 5, 0.05, 0.1);
3811        let choice = select_hybrid_atom(&slot).unwrap();
3812        assert!(choice.param.is_linear());
3813        assert!(choice.curved_evidence_margin <= 0.0);
3814    }
3815
3816    #[test]
3817    fn hybrid_tie_breaks_to_the_cheaper_linear_atom() {
3818        // Exact NLE tie (above the turning floor so the evidence path decides):
3819        // the cheaper linear atom wins, preserving strict generalization — the
3820        // hybrid never pays for curvature it does not need.
3821        let theta = 0.5; // above the floor → evidence path, not the exact guard
3822        let nle = 42.0;
3823        let slot = vec![
3824            HybridAtomCandidate::linear(nle, 2),
3825            HybridAtomCandidate::curved(1, nle, 5, Some(theta)),
3826        ];
3827        let choice = select_hybrid_atom(&slot).unwrap();
3828        assert!(choice.param.is_linear());
3829        assert_eq!(choice.num_parameters, 2);
3830    }
3831
3832    #[test]
3833    fn hybrid_split_reduces_to_pure_linear_when_all_features_are_straight() {
3834        // Every slot's curved candidate has Θ → 0 (flat features everywhere): the
3835        // dominance floor fires at every slot → the hybrid recovers the pure-
3836        // linear dictionary exactly. This is the `all Θ → 0` limit (3).
3837        let slots: Vec<Vec<HybridAtomCandidate>> = (0..6)
3838            .map(|i| hybrid_slot(50.0 + i as f64, 2, 1, 5, 0.0, 0.0))
3839            .collect();
3840        let split = select_hybrid_split(&slots).unwrap();
3841        assert!(split.is_pure_linear());
3842        assert_eq!(split.curved_atom_count, 0);
3843        assert_eq!(split.linear_atom_count(), 6);
3844        // Summed NLE equals the pure-linear baseline (every slot chose linear).
3845        let pure_linear: f64 = (0..6).map(|i| 50.0 + i as f64).sum();
3846        assert!((split.total_negative_log_evidence - pure_linear).abs() < 1e-12);
3847    }
3848
3849    #[test]
3850    fn hybrid_split_reduces_to_pure_curved_when_every_feature_curves() {
3851        // Every slot's feature turns enough (Θ = 2π, large likelihood gain) that
3852        // curved beats linear everywhere → the pure-curved limit (3).
3853        let slots: Vec<Vec<HybridAtomCandidate>> = (0..5)
3854            .map(|i| hybrid_slot(80.0 + i as f64, 2, 1, 5, 2.0 * std::f64::consts::PI, 40.0))
3855            .collect();
3856        let split = select_hybrid_split(&slots).unwrap();
3857        assert!(split.is_pure_curved());
3858        assert_eq!(split.curved_atom_count, 5);
3859        assert_eq!(split.linear_atom_count(), 0);
3860    }
3861
3862    #[test]
3863    fn hybrid_split_on_mixed_dictionary_picks_curved_for_circles_linear_for_directions() {
3864        // Mixed synthetic: slots 0..3 are CIRCLE features (high turning Θ = 2π,
3865        // the curved fit captures the loop), slots 3..7 are LINEAR DIRECTIONS
3866        // (straight, Θ = 0). The evidence split must select curved for the
3867        // circles and linear for the directions — and the hybrid's summed
3868        // evidence must be ≤ the summed per-slot LINEAR-candidate NLE (each
3869        // slot's best straight line fit to its response residual). This is a
3870        // data-level match-or-beat dominance (#1202: linear is the curved
3871        // family's nested Θ = 0 sub-model on common data), and holds because each
3872        // slot picks the argmin of its two common-data candidates.
3873        let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
3874        let mut pure_linear_baseline = 0.0_f64;
3875        // Three circle features: a curved atom replaces ~10-30 linear secants, so
3876        // the curved fit buys a large likelihood gain that dwarfs its param price.
3877        for i in 0..3 {
3878            let linear_nle = 120.0 + 3.0 * i as f64;
3879            pure_linear_baseline += linear_nle;
3880            slots.push(hybrid_slot(
3881                linear_nle,
3882                2,
3883                1,
3884                5,
3885                2.0 * std::f64::consts::PI,
3886                35.0,
3887            ));
3888        }
3889        // Four straight linear directions: zero turning, the linear special case
3890        // is optimal — a curved atom buys nothing and only costs parameters.
3891        for i in 0..4 {
3892            let linear_nle = 90.0 + 2.0 * i as f64;
3893            pure_linear_baseline += linear_nle;
3894            slots.push(hybrid_slot(linear_nle, 2, 1, 5, 0.0, 0.0));
3895        }
3896
3897        let split = select_hybrid_split(&slots).unwrap();
3898
3899        // The first three (circles) chose curved; the last four (directions) chose
3900        // linear.
3901        for (idx, choice) in split.atoms.iter().enumerate() {
3902            if idx < 3 {
3903                assert_eq!(
3904                    choice.param,
3905                    HybridAtomParam::Curved { latent_dim: 1 },
3906                    "circle slot {idx} should select curved"
3907                );
3908            } else {
3909                assert!(
3910                    choice.param.is_linear(),
3911                    "direction slot {idx} should select linear"
3912                );
3913            }
3914        }
3915        assert_eq!(split.curved_atom_count, 3);
3916        assert_eq!(split.linear_atom_count(), 4);
3917
3918        // The hybrid's summed negative-log-evidence is ≤ the summed per-slot
3919        // LINEAR-candidate NLE (each slot's best straight line fit to its response
3920        // residual): the per-slot argmin can only lower the sum. This is a
3921        // data-level match-or-beat dominance (#1202): linear is the curved
3922        // family's nested Θ = 0 sub-model on common data.
3923        assert!(
3924            split.total_negative_log_evidence <= pure_linear_baseline + 1e-9,
3925            "hybrid NLE {} must be <= summed linear-candidate NLE {}",
3926            split.total_negative_log_evidence,
3927            pure_linear_baseline
3928        );
3929        // And strictly better, because the curved circle slots paid off.
3930        assert!(split.total_negative_log_evidence < pure_linear_baseline);
3931    }
3932
3933    #[test]
3934    fn hybrid_split_rejects_empty_slot() {
3935        let slots = vec![hybrid_slot(10.0, 2, 1, 5, 0.0, 0.0), Vec::new()];
3936        assert!(select_hybrid_split(&slots).is_err());
3937    }
3938
3939    // ── #1362: compare_models must Occam-penalise a pure-noise smooth ────────
3940    //
3941    // These tests pin the ranking contract directly on `compare_reml_fits` with
3942    // controlled (score, edf, log_lik) inputs taken from the actual #1362
3943    // reproduction (Rust `reml_score` of `y ~ s(x)` vs `y ~ s(x) + s(z)` at
3944    // n=700). They do not need a fitted GAM or a Python wheel.
3945
3946    fn cand(name: &str, score: f64, edf: f64, log_lik: f64) -> RemlCandidate {
3947        RemlCandidate {
3948            index: 0,
3949            name: name.to_string(),
3950            score,
3951            edf: Some(edf),
3952            log_lik: Some(log_lik),
3953            family: None,
3954            n_obs: None,
3955        }
3956    }
3957
3958    #[test]
3959    fn ranking_score_is_conditional_aic_when_loglik_and_edf_present() {
3960        // AIC = -2ℓ + 2·edf.
3961        let c = cand("m", /*score (ignored)*/ 999.0, 6.748, -32.0866);
3962        let expected = -2.0 * -32.0866 + 2.0 * 6.748;
3963        assert!((c.ranking_score() - expected).abs() < 1e-9);
3964    }
3965
3966    #[test]
3967    fn ranking_score_falls_back_to_evidence_without_loglik() {
3968        let c = RemlCandidate {
3969            index: 0,
3970            name: "m".to_string(),
3971            score: 151.28,
3972            edf: Some(6.0),
3973            log_lik: None,
3974            family: None,
3975            n_obs: None,
3976        };
3977        assert_eq!(c.ranking_score(), 151.28);
3978    }
3979
3980    #[test]
3981    fn compare_models_rejects_pure_noise_smooth_despite_lower_evidence() {
3982        // Seed-3000 numbers from the #1362 Rust reproduction:
3983        //   small (y ~ s(x)):      reml=180.526, edf=6.748,  loglik=-32.0866
3984        //   big   (y ~ s(x)+s(z)): reml=177.404, edf=14.250, loglik=-32.1212
3985        // The big (noise-augmented) model has the LOWER (apparently better) raw
3986        // REML evidence, yet it spends ~7.5 extra EDF fitting noise without
3987        // improving the likelihood. The winner must be the SMALL model.
3988        let small = cand("small", 180.526, 6.748, -32.0866);
3989        let big = cand("big", 177.404, 14.250, -32.1212);
3990
3991        // Sanity: raw evidence (the broken headline) prefers big.
3992        assert!(big.score < small.score);
3993
3994        let cmp = compare_reml_fits(vec![small, big]).expect("compare");
3995        assert_eq!(
3996            cmp.winner, "small",
3997            "compare_models must Occam-penalise the pure-noise smooth and pick the smaller model"
3998        );
3999        // The score table still reports the raw evidence headline unchanged, so
4000        // Model.evidence / bayes_factor_vs stay consistent with the table.
4001        let small_row = cmp
4002            .score_table
4003            .iter()
4004            .find(|r| r.name == "small")
4005            .expect("small row");
4006        let big_row = cmp
4007            .score_table
4008            .iter()
4009            .find(|r| r.name == "big")
4010            .expect("big row");
4011        assert!((small_row.reml_score - 180.526).abs() < 1e-9);
4012        assert!((big_row.reml_score - 177.404).abs() < 1e-9);
4013    }
4014
4015    #[test]
4016    fn compare_models_keeps_power_for_a_relevant_smooth() {
4017        // Seed-3000 relevant-z numbers from the same reproduction:
4018        //   small: reml=1025.067, edf≈6.75,  loglik≈-368.99 (aic≈751.5)
4019        //   big:   reml=199.509,  edf≈14.25, loglik≈-33.16  (aic≈94.8)
4020        // A genuinely relevant smooth lowers BOTH the evidence and the AIC, so
4021        // the bigger model must still win — a fix cannot just always pick small.
4022        let small = cand("small", 1025.067, 6.75, -368.985);
4023        let big = cand("big", 199.509, 14.25, -33.165);
4024        let cmp = compare_reml_fits(vec![small, big]).expect("compare");
4025        assert_eq!(
4026            cmp.winner, "big",
4027            "compare_models must retain power: the relevant smooth's model must win"
4028        );
4029    }
4030
4031    #[test]
4032    fn compare_models_rejects_mismatched_observation_counts() {
4033        // Two same-family fits on different-sized data are not comparable by
4034        // AIC / evidence; the comparison must fail loud, mirroring the family
4035        // guard, rather than declare a sample-size-driven winner.
4036        let with_n = |name: &str, n: usize| RemlCandidate {
4037            index: 0,
4038            name: name.to_string(),
4039            score: 100.0,
4040            edf: Some(5.0),
4041            log_lik: Some(-40.0),
4042            family: Some("gaussian".to_string()),
4043            n_obs: Some(n),
4044        };
4045        let err = compare_reml_fits(vec![with_n("big", 500), with_n("small", 100)])
4046            .expect_err("cross-n comparison must be rejected");
4047        assert!(
4048            err.contains("number of observations") && err.contains("500") && err.contains("100"),
4049            "n-guard error should name the incomparable counts, got: {err}"
4050        );
4051
4052        // Same n is comparable.
4053        compare_reml_fits(vec![with_n("a", 250), with_n("b", 250)])
4054            .expect("same-n comparison must succeed");
4055
4056        // A missing count (`None`) is unconstrained: it must not block a
4057        // comparison against a fit that does carry one (legacy / scan payloads).
4058        let without_n = RemlCandidate {
4059            index: 0,
4060            name: "legacy".to_string(),
4061            score: 90.0,
4062            edf: Some(4.0),
4063            log_lik: Some(-35.0),
4064            family: Some("gaussian".to_string()),
4065            n_obs: None,
4066        };
4067        compare_reml_fits(vec![with_n("counted", 500), without_n])
4068            .expect("an unconstrained (None) count must not trip the guard");
4069    }
4070}