Skip to main content

gam_sae/inference/
checkpoint_dynamics.rs

1//! Cross-checkpoint training-dynamics inference for SAE atoms (issue #1102).
2//!
3//! OLMo ships intermediate-training checkpoints. Each checkpoint `c` fits a
4//! dictionary whose atom `a` is a decoder curve `g^{(c)}_a: t ↦ ℝ^ambient`
5//! sampled on a shared latent grid `t`. The question this module answers, per
6//! atom, is *did the atom change across training, and where*, with
7//! debiased point estimates, standard errors, and anytime-valid evidence.
8//!
9//! It is pure assembly of three landed instruments — none is reimplemented:
10//!
//! * [`crate::inference::riesz`] — the per-step contrast
//!   `θ = g^{(c+1)}(t₀) − g^{(c)}(t₀)` is the linear
//!   [`SmoothFunctional::Contrast`] evaluated on the difference fit
//!   `β^{(c+1)} − β^{(c)}` (same identity basis, same penalized Hessian);
//!   [`debias_with_dense_hessian`] returns the penalty-debiased estimate and a
//!   plug-in SE via the Riesz representer.
16//! * [`crate::inference::layer_transport`] — the checkpoint axis is reused as
//!   the "layer" axis: [`fit_layer_transport`] aligns the atom's latent chart
18//!   across consecutive checkpoints (topology compatibility, isometry defect,
19//!   winding degree), packaged as a [`LayerTransportReport`].
20//! * [`gam_terms::inference::structure_evidence`] — each consecutive-step contrast
21//!   feeds one anytime-valid e-value (the studentized displacement mapped to a
22//!   two-sided p-value and run through the frozen κ = ½ p→e calibrator) into a
23//!   per-step [`StructureLedger`] claim under the null "the atom did not change
//!   at this checkpoint step". Each such value is a genuine e-value
//!   (`E_{H0}[E] ≤ 1`), unlike the in-sample `exp(½ z²)` likelihood ratio,
//!   whose null expectation diverges; it therefore remains valid under
//!   optional stopping.
26//!
27//! # Honest accounting of the Riesz inputs
28//!
29//! A *bare* decoder grid carries the fitted curve VALUES but no
30//! observation-level scores and no penalized Hessian — those cannot be
31//! fabricated from grid samples. So the smooth this module debiases is the
32//! one the grid actually defines: a ridge-penalized least-squares fit of the
33//! grid VALUES on the latent-grid identity (interpolation) basis, where each
34//! grid node is one observation with response equal to the decoder value at
35//! that node. This is a genuine fit with a genuine penalized Hessian
36//! `XᵀX + λS = I + λS` and genuine per-node scores `s_i = (β_i − y_i)·eᵢ`,
37//! so every quantity handed to [`debias_with_dense_hessian`] is real, not a
38//! placeholder. The contrast functional is then evaluated against the
39//! identity design row at the latent-grid mode index. The ambient dimension is
40//! handled component-wise and the per-component contrasts are aggregated into a
41//! single scalar `θ` by the L2 norm of the component contrast vector (the size
42//! of the decoder displacement at `t₀`); its SE is propagated by the
43//! delta method through that norm.
44
45use crate::inference::layer_transport::{ChartTopology, LayerTransportReport, fit_layer_transport};
46use crate::inference::riesz::{
47    RieszDebiasReport, RieszInput, SmoothFunctional, debias_with_dense_hessian,
48};
49use gam_terms::inference::structure_evidence::{ClaimKind, StructureLedger, log_e_from_p_calibrator};
50use ndarray::{Array1, Array2, ArrayView1, ArrayView4};
51use statrs::distribution::{ContinuousCDF, Normal};
52
/// Ridge strength `λ` of the identity-basis interpolation fit applied to each
/// grid column.
///
/// It is kept small next to the unit data Hessian, so fitted coefficients
/// `y / (1 + λ)` stay very close to the grid values. It is kept strictly
/// positive so that the penalized Hessian `I + λS` is SPD and the Riesz
/// penalty-debiasing term acts on genuine, non-degenerate curvature.
const GRID_FIT_RIDGE: f64 = 0.001;
58
59/// Inputs for one cross-checkpoint atom-dynamics run.
60///
61/// `decoder_grid` is `[n_checkpoints, n_atoms, n_grid, ambient_dim]`: the
62/// decoder curve of every atom sampled on the shared `latent_grid` at every
63/// checkpoint. `checkpoint_ids[c]` and `atom_names[a]` label the axes.
64pub struct CheckpointDynamicsInput<'a> {
65    pub decoder_grid: ArrayView4<'a, f64>,
66    pub checkpoint_ids: &'a [String],
67    pub atom_names: &'a [String],
68    pub latent_grid: ArrayView1<'a, f64>,
69}
70
71/// The training-dynamics trajectory of one atom across the checkpoint axis.
72///
73/// The PRIMARY, coverage-valid deliverable is [`Self::change_evidence`]: the
74/// anytime-valid e-process answering "did atom k change during training?".
75/// [`Self::conditional_step_contrasts`] is a secondary, descriptive readout (see
76/// its docs for the conditional caveat).
77pub struct AtomTrajectory {
78    pub atom_name: String,
79    /// Debiased `g^{(c+1)}(t_mode) − g^{(c)}(t_mode)` for each consecutive
80    /// checkpoint step, with its plug-in SE.
81    ///
82    /// CONDITIONAL ON THE FITTED COORDINATES (not a coverage-valid CI). The
83    /// debiased SE here conditions away the generated-regressor uncertainty in
84    /// the estimated latent coordinates `t̂` and activations `â` — the exact
85    /// correction the marginal-slope family exists to make (issue #1115). It is
86    /// reported only as a conditional contrast point estimate with a plug-in SE,
87    /// NOT as an interval with frequentist coverage for the population
88    /// displacement. The headline change verdict is carried by the e-process
89    /// [`Self::change_evidence`], which IS anytime-valid; this field is a
90    /// descriptive companion. Read the SE accordingly.
91    pub conditional_step_contrasts: Vec<RieszDebiasReport>,
92    /// Consecutive-checkpoint chart correspondences (checkpoint axis reused as
93    /// the transport "layer" axis).
94    pub transports: Vec<LayerTransportReport>,
95    /// PRIMARY deliverable: anytime-valid evidence that the atom changed at each
96    /// consecutive checkpoint step, one calibrated e-value per step into a
97    /// per-step claim. Valid at any data-dependent stopping time.
98    pub change_evidence: StructureLedger,
99}
100
101/// Run cross-checkpoint debiased dynamics inference for every atom.
102///
103/// For each atom, walks consecutive checkpoints and, at each step `c → c+1`:
104/// 1. fits the transport map between the two checkpoints' latent charts
105///    ([`fit_layer_transport`], checkpoint axis as the layer axis);
106/// 2. evaluates the Riesz-debiased decoder-displacement contrast at the
107///    latent-grid mode ([`SmoothFunctional::Contrast`] + penalty debiasing);
108/// 3. absorbs the studentized contrast as a calibrated anytime-valid e-value
109///    into the step's change claim under the no-change null.
110pub fn checkpoint_atom_dynamics(
111    input: &CheckpointDynamicsInput<'_>,
112) -> Result<Vec<AtomTrajectory>, String> {
113    let shape = input.decoder_grid.shape();
114    let (n_checkpoints, n_atoms, n_grid, ambient_dim) = (shape[0], shape[1], shape[2], shape[3]);
115    if n_checkpoints < 2 {
116        return Err(format!(
117            "checkpoint dynamics needs at least two checkpoints, got {n_checkpoints}"
118        ));
119    }
120    if input.checkpoint_ids.len() != n_checkpoints {
121        return Err(format!(
122            "checkpoint_ids length {} disagrees with decoder grid checkpoint axis {n_checkpoints}",
123            input.checkpoint_ids.len()
124        ));
125    }
126    if input.atom_names.len() != n_atoms {
127        return Err(format!(
128            "atom_names length {} disagrees with decoder grid atom axis {n_atoms}",
129            input.atom_names.len()
130        ));
131    }
132    if input.latent_grid.len() != n_grid {
133        return Err(format!(
134            "latent_grid length {} disagrees with decoder grid latent axis {n_grid}",
135            input.latent_grid.len()
136        ));
137    }
138    if n_grid < 2 || ambient_dim == 0 {
139        return Err(format!(
140            "checkpoint dynamics needs a non-trivial grid ({n_grid}) and ambient dim ({ambient_dim})"
141        ));
142    }
143    if input.decoder_grid.iter().any(|v| !v.is_finite()) {
144        return Err("checkpoint dynamics decoder grid must be finite".to_string());
145    }
146    if input.latent_grid.iter().any(|v| !v.is_finite()) {
147        return Err("checkpoint dynamics latent grid must be finite".to_string());
148    }
149
150    // The mode index: the latent-grid node where the contrast is evaluated.
151    // Use the central node so it sits inside any chart and away from edge
152    // interpolation artifacts.
153    let mode_index = n_grid / 2;
154
155    // Identity interpolation design `X = I_{n_grid}` and its ridge penalty
156    // `S = I`. The penalized Hessian `H = XᵀX + λS = (1 + λ) I` is shared by
157    // every component fit, so it is built once.
158    let penalty_scale = 1.0 + GRID_FIT_RIDGE;
159    let mut hessian = Array2::<f64>::zeros((n_grid, n_grid));
160    for i in 0..n_grid {
161        hessian[[i, i]] = penalty_scale;
162    }
163    // Contrast design rows pick out the mode node: `m(t_mode) = β_{mode}`, so
164    // the value-design row is the mode basis vector. The contrast `a − b`
165    // (later checkpoint minus earlier) shares the same row; the per-checkpoint
166    // distinction is carried by the two fitted coefficient vectors, exactly as
167    // a paired contrast of the same functional across two fits.
168    let mut mode_row = Array1::<f64>::zeros(n_grid);
169    mode_row[mode_index] = 1.0;
170
171    let mut trajectories = Vec::with_capacity(n_atoms);
172    for atom in 0..n_atoms {
173        let atom_name = input.atom_names[atom].clone();
174        let mut step_contrasts = Vec::with_capacity(n_checkpoints - 1);
175        let mut transports = Vec::with_capacity(n_checkpoints - 1);
176        let mut change_evidence = StructureLedger::new();
177
178        for step in 0..n_checkpoints - 1 {
179            let c0 = step;
180            let c1 = step + 1;
181
182            // --- transport map across the checkpoint axis --------------------
183            // Reuse the latent grid itself as both charts' coordinates on an
184            // interval `[min, max]`; the transport fit aligns the two
185            // checkpoints' decoder curves through their shared latent index.
186            // The "from"/"to" coordinates are the decoder values projected to
187            // the first ambient component, the available scalar chart sample.
188            let coords_from = input
189                .decoder_grid
190                .slice(ndarray::s![c0, atom, .., 0])
191                .to_owned();
192            let coords_to = input
193                .decoder_grid
194                .slice(ndarray::s![c1, atom, .., 0])
195                .to_owned();
196            let (lo, hi) = interval_bounds(coords_from.view(), coords_to.view());
197            let topology = ChartTopology::Interval { lo, hi };
198            let transport = fit_layer_transport(
199                c0,
200                c1,
201                coords_from.view(),
202                coords_to.view(),
203                topology,
204                topology,
205            )
206            .map_err(|e| {
207                format!(
208                    "checkpoint transport for atom '{atom_name}' step {} → {} failed: {e}",
209                    input.checkpoint_ids[c0], input.checkpoint_ids[c1]
210                )
211            })?;
212            transports.push(transport);
213
214            // --- Riesz-debiased decoder-displacement contrast at the mode ----
215            let report = contrast_at_mode(&ContrastAtMode {
216                grid: input.decoder_grid,
217                atom,
218                c0,
219                c1,
220                ambient_dim,
221                n_grid,
222                hessian: hessian.view(),
223                mode_row: mode_row.view(),
224            })
225            .map_err(|e| {
226                format!(
227                    "checkpoint contrast for atom '{atom_name}' step {} → {} failed: {e}",
228                    input.checkpoint_ids[c0], input.checkpoint_ids[c1]
229                )
230            })?;
231
232            // --- anytime-valid evidence the atom changed at this step --------
233            // The debiased displacement `θ̂` with SE `se` studentizes to
234            // `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`). Its two-sided
235            // p-value run through the frozen κ = ½ p→e calibrator is a genuine
236            // e-value for the per-step no-change null θ = 0 — `E_{H0}[E] ≤ 1`,
237            // which the naive in-sample `exp(½ z²)` ratio is NOT (it diverges
238            // under H0). One e-value per step into a per-step claim; the
239            // calibrator's contract (one e-value per independent batch) is met
240            // because each step is a distinct checkpoint transition.
241            let claim = change_evidence.register(ClaimKind::Custom {
242                label: format!(
243                    "atom '{atom_name}' changed from checkpoint {} to {}",
244                    input.checkpoint_ids[c0], input.checkpoint_ids[c1]
245                ),
246            });
247            let log_e = no_change_log_e_value(report.theta_onestep, report.se)?;
248            change_evidence.absorb_log(claim, log_e)?;
249
250            step_contrasts.push(report);
251        }
252
253        trajectories.push(AtomTrajectory {
254            atom_name,
255            conditional_step_contrasts: step_contrasts,
256            transports,
257            change_evidence,
258        });
259    }
260
261    Ok(trajectories)
262}
263
264/// Interval bounds spanning both checkpoints' scalar chart samples, padded so
265/// the transport basis domain strictly contains the data.
266fn interval_bounds(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> (f64, f64) {
267    let mut lo = f64::INFINITY;
268    let mut hi = f64::NEG_INFINITY;
269    for &v in a.iter().chain(b.iter()) {
270        lo = lo.min(v);
271        hi = hi.max(v);
272    }
273    if !(lo.is_finite() && hi.is_finite()) {
274        return (0.0, 1.0);
275    }
276    if hi <= lo {
277        // Degenerate (constant) chart: open a unit window around the value.
278        return (lo - 0.5, lo + 0.5);
279    }
280    let pad = (hi - lo) * 1e-6;
281    (lo - pad, hi + pad)
282}
283
284/// Debiased `g^{(c1)}(t_mode) − g^{(c0)}(t_mode)` aggregated over the ambient
285/// dimension into the scalar decoder-displacement size, with a delta-method SE.
286///
287/// Each ambient component is an independent identity-basis ridge fit of the
288/// grid values; the [`SmoothFunctional::Contrast`] of the two checkpoints'
289/// fitted coefficient vectors at the mode node is debiased component-wise via
290/// the Riesz one-step. The component contrasts form a vector `Δ ∈ ℝ^ambient`;
291/// the reported scalar `θ = ‖Δ‖₂` is the displacement size and its SE is the
292/// delta-method norm-gradient `‖Δ‖₂` propagation of the per-component SEs,
293/// assuming component independence (separate fits, separate scores).
294struct ContrastAtMode<'a> {
295    grid: ArrayView4<'a, f64>,
296    atom: usize,
297    c0: usize,
298    c1: usize,
299    ambient_dim: usize,
300    n_grid: usize,
301    hessian: ndarray::ArrayView2<'a, f64>,
302    mode_row: ArrayView1<'a, f64>,
303}
304
305fn contrast_at_mode(args: &ContrastAtMode<'_>) -> Result<RieszDebiasReport, String> {
306    let grid = args.grid;
307    let atom = args.atom;
308    let c0 = args.c0;
309    let c1 = args.c1;
310    let ambient_dim = args.ambient_dim;
311    let n_grid = args.n_grid;
312    let hessian = args.hessian;
313    let mode_row = args.mode_row;
314    // Aggregate scalar contrast Δ = θ_c1 − θ_c0 across ambient components, and
315    // the matching aggregate Riesz quantities, so a single RieszDebiasReport
316    // describes the displacement. We assemble the report from one debiasing per
317    // component and combine through the L2 norm.
318    let mut delta = Array1::<f64>::zeros(ambient_dim);
319    let mut delta_one = Array1::<f64>::zeros(ambient_dim);
320    let mut var_components = Array1::<f64>::zeros(ambient_dim);
321    let mut penalty_bias_acc = 0.0_f64;
322    // A representer to carry: reuse the last component's; the scalar norm
323    // estimate's representer is component-wise so we keep the final one as the
324    // canonical witness (its influence vector studentizes the norm contrast).
325    let mut witness: Option<RieszDebiasReport> = None;
326
327    for comp in 0..ambient_dim {
328        // Per-checkpoint identity-basis ridge fit: response y = grid values,
329        // design X = I, penalty S = I. With H = (1+λ)I the fitted coefficient
330        // is β = y / (1 + λ); the per-node score is sᵢ = (μ̂ᵢ − yᵢ)·eᵢ where
331        // μ̂ = Xβ = β, and the penalty gradient is S·β = β.
332        let y0 = grid.slice(ndarray::s![c0, atom, .., comp]).to_owned();
333        let y1 = grid.slice(ndarray::s![c1, atom, .., comp]).to_owned();
334        let report = component_contrast(y0.view(), y1.view(), n_grid, hessian, mode_row)?;
335
336        delta[comp] = report.theta_plugin;
337        delta_one[comp] = report.theta_onestep;
338        var_components[comp] = report.se * report.se;
339        penalty_bias_acc += report.penalty_bias * report.penalty_bias;
340        witness = Some(report);
341    }
342
343    let theta_plugin = delta.dot(&delta).sqrt();
344    let norm_one = delta_one.dot(&delta_one).sqrt();
345    // Delta method for θ = ‖Δ‖₂: ∂θ/∂Δ_k = Δ_k / ‖Δ‖₂, components independent,
346    // so var(θ) = Σ_k (Δ_k/‖Δ‖₂)² var(Δ_k).
347    let se = if norm_one > f64::MIN_POSITIVE {
348        let mut v = 0.0_f64;
349        for comp in 0..ambient_dim {
350            let g = delta_one[comp] / norm_one;
351            v += g * g * var_components[comp];
352        }
353        v.max(0.0).sqrt()
354    } else {
355        // At a null displacement the norm is non-differentiable; fall back to
356        // the RMS of the component SEs (an honest upper-ish bound on the size).
357        (var_components.sum() / ambient_dim as f64).sqrt()
358    };
359
360    let mut report = witness
361        .ok_or_else(|| "checkpoint contrast requires at least one ambient component".to_string())?;
362    report.theta_plugin = theta_plugin;
363    report.theta_onestep = norm_one;
364    report.se = se;
365    report.penalty_bias = penalty_bias_acc.sqrt();
366    Ok(report)
367}
368
369/// One ambient component's debiased contrast `g^{(c1)}(t_mode) −
370/// g^{(c0)}(t_mode)` through the Riesz one-step.
371fn component_contrast(
372    y0: ArrayView1<'_, f64>,
373    y1: ArrayView1<'_, f64>,
374    n_grid: usize,
375    hessian: ndarray::ArrayView2<'_, f64>,
376    mode_row: ArrayView1<'_, f64>,
377) -> Result<RieszDebiasReport, String> {
378    // Stacked paired-contrast trick: the contrast `m_{c1}(t₀) − m_{c0}(t₀)` is
379    // the difference of one linear functional applied to two coefficient
380    // vectors. Riesz operates on a single fit, so we debias on the DIFFERENCE
381    // fit β_Δ = β_{c1} − β_{c0}, whose response is y₁ − y₀ on the same identity
382    // basis — a genuine fit with the same penalized Hessian. The contrast
383    // functional on β_Δ is then the point evaluation at the mode, packaged via
384    // SmoothFunctional::Contrast against a zero row so the gradient is the mode
385    // row exactly (g = mode_row − 0).
386    let beta0 = y0.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
387    let beta1 = y1.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
388    let beta_delta = &beta1 - &beta0;
389
390    let zero_row = Array1::<f64>::zeros(n_grid);
391    let functional = SmoothFunctional::Contrast {
392        design_row_a: mode_row,
393        design_row_b: zero_row.view(),
394    };
395    let gradient = functional
396        .gradient()
397        .map_err(|e| format!("contrast functional gradient failed: {e}"))?;
398
399    // Per-node scores of the difference fit: μ̂ = X β_Δ = β_Δ, response y₁−y₀.
400    let response = &y1.to_owned() - &y0;
401    let mut row_scores = Array2::<f64>::zeros((n_grid, n_grid));
402    for i in 0..n_grid {
403        row_scores[[i, i]] = beta_delta[i] - response[i];
404    }
405    // Penalty gradient S·β_Δ = β_Δ (S = I).
406    let penalty_beta = beta_delta.clone();
407
408    let input = RieszInput {
409        beta: beta_delta.view(),
410        functional_gradient: gradient.view(),
411        row_scores: row_scores.view(),
412        penalty_beta: penalty_beta.view(),
413        leverage: None,
414    };
415    debias_with_dense_hessian(&input, hessian).map_err(|e| format!("Riesz debiasing failed: {e}"))
416}
417
418/// Anytime-valid log-e-value for the no-change null `θ = 0` from the debiased,
419/// studentized displacement `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`).
420///
421/// The naive in-sample likelihood ratio `exp(½ z²)` — the alternative density
422/// re-centered on the very estimate `θ̂` it is scored against — is NOT an
423/// e-value: under H0, `z ~ N(0,1)` and `E[exp(½ z²)] = ∫ φ(z) exp(½ z²) dz`
424/// DIVERGES, so it has no `E_{H0}[E] ≤ 1` guarantee. (Universal inference earns
425/// `exp(½ z²)` validity only with a held-out evaluation fold; a single grid of
426/// decoder values affords no such split.)
427///
428/// Instead we map the displacement to its two-sided normal p-value
429/// `p = 2(1 − Φ(|z|))` and route it through the module's frozen p→e calibrator
430/// [`log_e_from_p_calibrator`] (the κ = ½ member `e(p) = ½ p^{−1/2}`, with
431/// `∫₀¹ e(p) dp = 1`, hence `E_{H0}[e(P)] ≤ 1` for any superuniform p). This is
432/// a genuine e-value: no displacement, small e; a real displacement, large e;
433/// and it compounds validly into the change e-process. A degenerate
434/// (non-positive) SE yields a zero log-e-value (no evidence, not certainty).
435fn no_change_log_e_value(theta_hat: f64, se: f64) -> Result<f64, String> {
436    if !(se > 0.0) || !theta_hat.is_finite() {
437        return Ok(0.0);
438    }
439    let z = (theta_hat / se).abs();
440    let normal =
441        Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
442    // Two-sided p-value of the studentized displacement; clamp to (0, 1] so the
443    // calibrator (which rejects p = 0) sees a finite, valid argument even at a
444    // numerically saturated tail.
445    let p: f64 = (2.0 * (1.0 - normal.cdf(z))).clamp(f64::MIN_POSITIVE, 1.0);
446    log_e_from_p_calibrator(p)
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array4;

    /// Build a `[n_ckpt, 2, n_grid, ambient]` decoder grid.
    ///
    /// Atom 0's curve is identical at every checkpoint (no change). Atom 1's
    /// curve is the same base curve plus `shift * c` at the central (mode)
    /// node in component 0 at checkpoint `c`, i.e. a steady drift of `shift`
    /// per step.
    fn drift_grid(n_ckpt: usize, n_grid: usize, ambient: usize, shift: f64) -> Array4<f64> {
        let mode = n_grid / 2;
        let mut grid = Array4::<f64>::zeros((n_ckpt, 2, n_grid, ambient));
        for c in 0..n_ckpt {
            for g in 0..n_grid {
                let t = g as f64 / (n_grid - 1) as f64;
                for comp in 0..ambient {
                    // Smooth bump, scaled per component.
                    let base = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
                    grid[[c, 0, g, comp]] = base;
                    grid[[c, 1, g, comp]] = if g == mode && comp == 0 {
                        base + shift * c as f64
                    } else {
                        base
                    };
                }
            }
        }
        grid
    }

    /// Run the dynamics on `drift_grid(n_ckpt, n_grid, ambient, shift)`.
    ///
    /// Uses checkpoint ids `dev0, dev1, …`, atoms `["constant", "drifter"]`,
    /// and a uniform latent grid on `[0, 1]`.
    fn run_drift(n_ckpt: usize, n_grid: usize, ambient: usize, shift: f64) -> Vec<AtomTrajectory> {
        let grid = drift_grid(n_ckpt, n_grid, ambient, shift);
        let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
        let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
        let atom_names = vec!["constant".to_string(), "drifter".to_string()];
        let input = CheckpointDynamicsInput {
            decoder_grid: grid.view(),
            checkpoint_ids: &ckpt_ids,
            atom_names: &atom_names,
            latent_grid: latent.view(),
        };
        checkpoint_atom_dynamics(&input).expect("dynamics")
    }

    /// Sum of the per-claim log-e-values in an atom's change certificate at
    /// level 0.05.
    fn total_change_log_e(trajectory: &AtomTrajectory) -> f64 {
        trajectory
            .change_evidence
            .certify(0.05)
            .entries
            .iter()
            .map(|e| e.log_e)
            .sum()
    }

    #[test]
    fn no_change_atom_has_near_zero_contrast_and_no_change_evidence() {
        let n_ckpt = 5;
        // The transport fit requires at least MIN_TRANSPORT_OBS (16) paired
        // grid samples, so the shared latent grid must be at least that long.
        let traj = run_drift(n_ckpt, 17, 3, 0.5);
        assert_eq!(traj.len(), 2);

        // Atom 0 is identical across checkpoints: every step contrast must be
        // (numerically) zero displacement and accumulate no change evidence.
        let constant = &traj[0];
        assert_eq!(constant.conditional_step_contrasts.len(), n_ckpt - 1);
        for report in &constant.conditional_step_contrasts {
            assert!(
                report.theta_onestep.abs() < 1e-9,
                "constant atom step displacement should be ~0, got {}",
                report.theta_onestep
            );
        }
        // The no-change null is true here → the certificate confirms nothing.
        let cert = constant.change_evidence.certify(0.05);
        assert_eq!(
            cert.confirmed().count(),
            0,
            "constant atom must not confirm any change claim"
        );
    }

    #[test]
    fn drifting_atom_recovers_displacement_and_accumulates_change_evidence() {
        let n_ckpt = 6;
        let shift = 0.7_f64;
        let traj = run_drift(n_ckpt, 17, 3, shift);
        let drifter = &traj[1];
        assert_eq!(drifter.conditional_step_contrasts.len(), n_ckpt - 1);

        // Each step displaces component 0 at the mode by exactly `shift`, and
        // the reported size is `‖Δ‖₂`. On the light interpolation ridge
        // (λ = GRID_FIT_RIDGE) the plug-in `shift/(1+λ)` is within 1% of
        // `shift`. The displacement lives in a single ambient component, so
        // the L2 size IS that component's contrast.
        for report in &drifter.conditional_step_contrasts {
            assert!(
                (report.theta_plugin - shift).abs() < 1e-2 * shift,
                "drift step plug-in displacement should track {shift}, got {}",
                report.theta_plugin
            );
            assert!(
                report.theta_onestep.is_finite() && report.se.is_finite(),
                "debiased displacement and SE must be finite"
            );
        }

        // The drift is real → the change certificate carries strictly positive
        // total log-evidence across its claims.
        let cert = drifter.change_evidence.certify(0.05);
        let total_log_e: f64 = cert.entries.iter().map(|e| e.log_e).sum();
        assert!(
            total_log_e > 0.0,
            "steady real drift must accumulate positive change evidence, entries: {:?}",
            cert.entries
                .iter()
                .map(|e| (e.log_e, e.confirmed))
                .collect::<Vec<_>>()
        );
    }

    /// A drifting atom must out-evidence a constant atom: the change e-process
    /// is a genuine discriminator, not a constant.
    #[test]
    fn drift_outweighs_constant_in_change_evidence() {
        let traj = run_drift(6, 17, 3, 0.7);
        let const_log_e = total_change_log_e(&traj[0]);
        let drift_log_e = total_change_log_e(&traj[1]);
        assert!(
            drift_log_e > const_log_e,
            "drift change-evidence {drift_log_e} must exceed constant {const_log_e}"
        );
    }

    #[test]
    fn rejects_single_checkpoint_and_axis_mismatch() {
        let names = vec!["a".to_string(), "b".to_string()];
        let one_id = vec!["only".to_string()];
        let latent5: Array1<f64> = Array1::linspace(0.0, 1.0, 5);

        // A single checkpoint has no step to contrast.
        let single = Array4::<f64>::zeros((1, 2, 5, 3));
        let err = checkpoint_atom_dynamics(&CheckpointDynamicsInput {
            decoder_grid: single.view(),
            checkpoint_ids: &one_id,
            atom_names: &names,
            latent_grid: latent5.view(),
        })
        .err()
        .expect("a single checkpoint must be rejected");
        assert!(err.contains("at least two checkpoints"), "{err}");

        // Two checkpoints, but each label axis in turn disagrees with the grid.
        let grid = Array4::<f64>::zeros((2, 2, 5, 3));
        let two_ids = vec!["c0".to_string(), "c1".to_string()];
        let one_name = vec!["a".to_string()];
        let latent4: Array1<f64> = Array1::linspace(0.0, 1.0, 4);
        let cases = [
            (&one_id[..], &names[..], latent5.view(), "checkpoint_ids length"),
            (&two_ids[..], &one_name[..], latent5.view(), "atom_names length"),
            (&two_ids[..], &names[..], latent4.view(), "latent_grid length"),
        ];
        for (ids, atom_names, latent, needle) in cases {
            let err = checkpoint_atom_dynamics(&CheckpointDynamicsInput {
                decoder_grid: grid.view(),
                checkpoint_ids: ids,
                atom_names,
                latent_grid: latent,
            })
            .err()
            .expect("an axis mismatch must be rejected");
            assert!(err.contains(needle), "expected '{needle}' in: {err}");
        }
    }

    #[test]
    fn rejects_non_finite_inputs() {
        let ids = vec!["c0".to_string(), "c1".to_string()];
        let names = vec!["a".to_string(), "b".to_string()];
        let latent: Array1<f64> = Array1::linspace(0.0, 1.0, 5);

        let mut bad_grid = Array4::<f64>::zeros((2, 2, 5, 3));
        bad_grid[[1, 0, 2, 1]] = f64::NAN;
        let err = checkpoint_atom_dynamics(&CheckpointDynamicsInput {
            decoder_grid: bad_grid.view(),
            checkpoint_ids: &ids,
            atom_names: &names,
            latent_grid: latent.view(),
        })
        .err()
        .expect("a non-finite decoder value must be rejected");
        assert!(err.contains("decoder grid must be finite"), "{err}");

        let clean_grid = Array4::<f64>::zeros((2, 2, 5, 3));
        let mut bad_latent = latent.clone();
        bad_latent[3] = f64::INFINITY;
        let err = checkpoint_atom_dynamics(&CheckpointDynamicsInput {
            decoder_grid: clean_grid.view(),
            checkpoint_ids: &ids,
            atom_names: &names,
            latent_grid: bad_latent.view(),
        })
        .err()
        .expect("a non-finite latent value must be rejected");
        assert!(err.contains("latent grid must be finite"), "{err}");
    }

    #[test]
    fn interval_bounds_strictly_contain_samples() {
        let a = Array1::from(vec![0.0, 0.25, 1.0]);
        let b = Array1::from(vec![0.5, -0.5]);
        let (lo, hi) = interval_bounds(a.view(), b.view());
        assert!(
            lo < -0.5 && hi > 1.0,
            "bounds ({lo}, {hi}) must strictly contain [-0.5, 1.0]"
        );

        // A constant chart opens a unit window around the value.
        let c = Array1::from(vec![2.0, 2.0]);
        let (lo, hi) = interval_bounds(c.view(), c.view());
        assert_eq!((lo, hi), (1.5, 2.5));
    }
}