Skip to main content

gam_sae/inference/
checkpoint_dynamics.rs

1//! Cross-checkpoint training-dynamics inference for SAE atoms (issue #1102).
2//!
3//! OLMo ships intermediate-training checkpoints. Each checkpoint `c` fits a
4//! dictionary whose atom `a` is a decoder curve `g^{(c)}_a: t ↦ ℝ^ambient`
5//! sampled on a shared latent grid `t`. The question this module answers, per
6//! atom, is *did the atom change across training, and where*, with
7//! debiased point estimates, standard errors, and anytime-valid evidence.
8//!
9//! It is pure assembly of three landed instruments — none is reimplemented:
10//!
11//! * [`crate::inference::riesz`] — the per-step contrast
12//!   `θ = g^{(c+1)}(t₀) − g^{(c)}(t₀)` is the linear
13//!   [`SmoothFunctional::Contrast`] of a stacked two-checkpoint coefficient
14//!   vector; [`debias_with_dense_hessian`] returns the penalty-debiased
15//!   estimate and a plug-in SE via the Riesz representer.
16//! * [`crate::inference::layer_transport`] — the checkpoint axis is reused as
//!   the "layer" axis: [`fit_layer_transport`] aligns the atom's latent chart
18//!   across consecutive checkpoints (topology compatibility, isometry defect,
19//!   winding degree), packaged as a [`LayerTransportReport`].
20//! * [`gam_terms::inference::structure_evidence`] — each consecutive-step contrast
21//!   feeds one anytime-valid e-value (the studentized displacement mapped to a
22//!   two-sided p-value and run through the frozen κ = ½ p→e calibrator) into a
23//!   per-step [`StructureLedger`] claim under the null "the atom did not change
24//!   at this checkpoint step". A genuine e-value (`E_{H0}[E] ≤ 1`), unlike the
25//!   divergent in-sample `exp(½ z²)` likelihood ratio; optional-stopping-safe.
26//!
27//! # Honest accounting of the Riesz inputs
28//!
29//! A *bare* decoder grid carries the fitted curve VALUES but no
30//! observation-level scores and no penalized Hessian — those cannot be
31//! fabricated from grid samples. So the smooth this module debiases is the
32//! one the grid actually defines: a ridge-penalized least-squares fit of the
33//! grid VALUES on the latent-grid identity (interpolation) basis, where each
34//! grid node is one observation with response equal to the decoder value at
35//! that node. This is a genuine fit with a genuine penalized Hessian
36//! `XᵀX + λS = I + λS` and genuine per-node scores `s_i = (β_i − y_i)·eᵢ`,
37//! so every quantity handed to [`debias_with_dense_hessian`] is real, not a
38//! placeholder. The contrast functional is then evaluated against the
39//! identity design row at the latent-grid mode index. The ambient dimension is
40//! handled component-wise and the per-component contrasts are aggregated into a
41//! single scalar `θ` by the L2 norm of the component contrast vector (the size
42//! of the decoder displacement at `t₀`); its SE is propagated by the
43//! delta method through that norm.
44
45use crate::inference::layer_transport::{ChartTopology, LayerTransportReport, fit_layer_transport};
46use crate::inference::riesz::{
47    RieszDebiasReport, RieszInput, SmoothFunctional, debias_with_dense_hessian,
48};
49use gam_terms::inference::structure_evidence::{
50    ClaimKind, StructureLedger, log_e_from_p_calibrator,
51};
52use ndarray::{Array1, Array2, ArrayView1, ArrayView4};
53use statrs::distribution::{ContinuousCDF, Normal};
54
/// Ridge weight `λ` of the identity-basis interpolation fit of the grid values.
///
/// Kept small against the unit data Hessian so the fitted values stay close to
/// the grid samples, yet strictly positive so that the penalized Hessian
/// `I + λS` is strictly SPD and the penalty-debiasing term of the Riesz
/// one-step acts on non-degenerate curvature.
const GRID_FIT_RIDGE: f64 = 1.0e-3;
60
61/// Inputs for one cross-checkpoint atom-dynamics run.
62///
63/// `decoder_grid` is `[n_checkpoints, n_atoms, n_grid, ambient_dim]`: the
64/// decoder curve of every atom sampled on the shared `latent_grid` at every
65/// checkpoint. `checkpoint_ids[c]` and `atom_names[a]` label the axes.
66pub struct CheckpointDynamicsInput<'a> {
67    pub decoder_grid: ArrayView4<'a, f64>,
68    pub checkpoint_ids: &'a [String],
69    pub atom_names: &'a [String],
70    pub latent_grid: ArrayView1<'a, f64>,
71}
72
73/// The training-dynamics trajectory of one atom across the checkpoint axis.
74///
75/// The PRIMARY, coverage-valid deliverable is [`Self::change_evidence`]: the
76/// anytime-valid e-process answering "did atom k change during training?".
77/// [`Self::conditional_step_contrasts`] is a secondary, descriptive readout (see
78/// its docs for the conditional caveat).
79pub struct AtomTrajectory {
80    pub atom_name: String,
81    /// Debiased `g^{(c+1)}(t_mode) − g^{(c)}(t_mode)` for each consecutive
82    /// checkpoint step, with its plug-in SE.
83    ///
84    /// CONDITIONAL ON THE FITTED COORDINATES (not a coverage-valid CI). The
85    /// debiased SE here conditions away the generated-regressor uncertainty in
86    /// the estimated latent coordinates `t̂` and activations `â` — the exact
87    /// correction the marginal-slope family exists to make (issue #1115). It is
88    /// reported only as a conditional contrast point estimate with a plug-in SE,
89    /// NOT as an interval with frequentist coverage for the population
90    /// displacement. The headline change verdict is carried by the e-process
91    /// [`Self::change_evidence`], which IS anytime-valid; this field is a
92    /// descriptive companion. Read the SE accordingly.
93    pub conditional_step_contrasts: Vec<RieszDebiasReport>,
94    /// Consecutive-checkpoint chart correspondences (checkpoint axis reused as
95    /// the transport "layer" axis).
96    pub transports: Vec<LayerTransportReport>,
97    /// PRIMARY deliverable: anytime-valid evidence that the atom changed at each
98    /// consecutive checkpoint step, one calibrated e-value per step into a
99    /// per-step claim. Valid at any data-dependent stopping time.
100    pub change_evidence: StructureLedger,
101}
102
103/// Run cross-checkpoint debiased dynamics inference for every atom.
104///
105/// For each atom, walks consecutive checkpoints and, at each step `c → c+1`:
106/// 1. fits the transport map between the two checkpoints' latent charts
107///    ([`fit_layer_transport`], checkpoint axis as the layer axis);
108/// 2. evaluates the Riesz-debiased decoder-displacement contrast at the
109///    latent-grid mode ([`SmoothFunctional::Contrast`] + penalty debiasing);
110/// 3. absorbs the studentized contrast as a calibrated anytime-valid e-value
111///    into the step's change claim under the no-change null.
112pub fn checkpoint_atom_dynamics(
113    input: &CheckpointDynamicsInput<'_>,
114) -> Result<Vec<AtomTrajectory>, String> {
115    let shape = input.decoder_grid.shape();
116    let (n_checkpoints, n_atoms, n_grid, ambient_dim) = (shape[0], shape[1], shape[2], shape[3]);
117    if n_checkpoints < 2 {
118        return Err(format!(
119            "checkpoint dynamics needs at least two checkpoints, got {n_checkpoints}"
120        ));
121    }
122    if input.checkpoint_ids.len() != n_checkpoints {
123        return Err(format!(
124            "checkpoint_ids length {} disagrees with decoder grid checkpoint axis {n_checkpoints}",
125            input.checkpoint_ids.len()
126        ));
127    }
128    if input.atom_names.len() != n_atoms {
129        return Err(format!(
130            "atom_names length {} disagrees with decoder grid atom axis {n_atoms}",
131            input.atom_names.len()
132        ));
133    }
134    if input.latent_grid.len() != n_grid {
135        return Err(format!(
136            "latent_grid length {} disagrees with decoder grid latent axis {n_grid}",
137            input.latent_grid.len()
138        ));
139    }
140    if n_grid < 2 || ambient_dim == 0 {
141        return Err(format!(
142            "checkpoint dynamics needs a non-trivial grid ({n_grid}) and ambient dim ({ambient_dim})"
143        ));
144    }
145    if input.decoder_grid.iter().any(|v| !v.is_finite()) {
146        return Err("checkpoint dynamics decoder grid must be finite".to_string());
147    }
148    if input.latent_grid.iter().any(|v| !v.is_finite()) {
149        return Err("checkpoint dynamics latent grid must be finite".to_string());
150    }
151
152    // The mode index: the latent-grid node where the contrast is evaluated.
153    // Use the central node so it sits inside any chart and away from edge
154    // interpolation artifacts.
155    let mode_index = n_grid / 2;
156
157    // Identity interpolation design `X = I_{n_grid}` and its ridge penalty
158    // `S = I`. The penalized Hessian `H = XᵀX + λS = (1 + λ) I` is shared by
159    // every component fit, so it is built once.
160    let penalty_scale = 1.0 + GRID_FIT_RIDGE;
161    let mut hessian = Array2::<f64>::zeros((n_grid, n_grid));
162    for i in 0..n_grid {
163        hessian[[i, i]] = penalty_scale;
164    }
165    // Contrast design rows pick out the mode node: `m(t_mode) = β_{mode}`, so
166    // the value-design row is the mode basis vector. The contrast `a − b`
167    // (later checkpoint minus earlier) shares the same row; the per-checkpoint
168    // distinction is carried by the two fitted coefficient vectors, exactly as
169    // a paired contrast of the same functional across two fits.
170    let mut mode_row = Array1::<f64>::zeros(n_grid);
171    mode_row[mode_index] = 1.0;
172
173    let mut trajectories = Vec::with_capacity(n_atoms);
174    for atom in 0..n_atoms {
175        let atom_name = input.atom_names[atom].clone();
176        let mut step_contrasts = Vec::with_capacity(n_checkpoints - 1);
177        let mut transports = Vec::with_capacity(n_checkpoints - 1);
178        let mut change_evidence = StructureLedger::new();
179
180        for step in 0..n_checkpoints - 1 {
181            let c0 = step;
182            let c1 = step + 1;
183
184            // --- transport map across the checkpoint axis --------------------
185            // Reuse the latent grid itself as both charts' coordinates on an
186            // interval `[min, max]`; the transport fit aligns the two
187            // checkpoints' decoder curves through their shared latent index.
188            // The "from"/"to" coordinates are the decoder values projected to
189            // the first ambient component, the available scalar chart sample.
190            let coords_from = input
191                .decoder_grid
192                .slice(ndarray::s![c0, atom, .., 0])
193                .to_owned();
194            let coords_to = input
195                .decoder_grid
196                .slice(ndarray::s![c1, atom, .., 0])
197                .to_owned();
198            let (lo, hi) = interval_bounds(coords_from.view(), coords_to.view());
199            let topology = ChartTopology::Interval { lo, hi };
200            let transport = fit_layer_transport(
201                c0,
202                c1,
203                coords_from.view(),
204                coords_to.view(),
205                topology,
206                topology,
207            )
208            .map_err(|e| {
209                format!(
210                    "checkpoint transport for atom '{atom_name}' step {} → {} failed: {e}",
211                    input.checkpoint_ids[c0], input.checkpoint_ids[c1]
212                )
213            })?;
214            transports.push(transport);
215
216            // --- Riesz-debiased decoder-displacement contrast at the mode ----
217            let report = contrast_at_mode(&ContrastAtMode {
218                grid: input.decoder_grid,
219                atom,
220                c0,
221                c1,
222                ambient_dim,
223                n_grid,
224                hessian: hessian.view(),
225                mode_row: mode_row.view(),
226            })
227            .map_err(|e| {
228                format!(
229                    "checkpoint contrast for atom '{atom_name}' step {} → {} failed: {e}",
230                    input.checkpoint_ids[c0], input.checkpoint_ids[c1]
231                )
232            })?;
233
234            // --- anytime-valid evidence the atom changed at this step --------
235            // The debiased displacement `θ̂` with SE `se` studentizes to
236            // `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`). Its two-sided
237            // p-value run through the frozen κ = ½ p→e calibrator is a genuine
238            // e-value for the per-step no-change null θ = 0 — `E_{H0}[E] ≤ 1`,
239            // which the naive in-sample `exp(½ z²)` ratio is NOT (it diverges
240            // under H0). One e-value per step into a per-step claim; the
241            // calibrator's contract (one e-value per independent batch) is met
242            // because each step is a distinct checkpoint transition.
243            let claim = change_evidence.register(ClaimKind::Custom {
244                label: format!(
245                    "atom '{atom_name}' changed from checkpoint {} to {}",
246                    input.checkpoint_ids[c0], input.checkpoint_ids[c1]
247                ),
248            });
249            let log_e = no_change_log_e_value(report.theta_onestep, report.se)?;
250            change_evidence.absorb_log(claim, log_e)?;
251
252            step_contrasts.push(report);
253        }
254
255        trajectories.push(AtomTrajectory {
256            atom_name,
257            conditional_step_contrasts: step_contrasts,
258            transports,
259            change_evidence,
260        });
261    }
262
263    Ok(trajectories)
264}
265
266/// Interval bounds spanning both checkpoints' scalar chart samples, padded so
267/// the transport basis domain strictly contains the data.
268fn interval_bounds(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> (f64, f64) {
269    let mut lo = f64::INFINITY;
270    let mut hi = f64::NEG_INFINITY;
271    for &v in a.iter().chain(b.iter()) {
272        lo = lo.min(v);
273        hi = hi.max(v);
274    }
275    if !(lo.is_finite() && hi.is_finite()) {
276        return (0.0, 1.0);
277    }
278    if hi <= lo {
279        // Degenerate (constant) chart: open a unit window around the value.
280        return (lo - 0.5, lo + 0.5);
281    }
282    let pad = (hi - lo) * 1e-6;
283    (lo - pad, hi + pad)
284}
285
286/// Debiased `g^{(c1)}(t_mode) − g^{(c0)}(t_mode)` aggregated over the ambient
287/// dimension into the scalar decoder-displacement size, with a delta-method SE.
288///
289/// Each ambient component is an independent identity-basis ridge fit of the
290/// grid values; the [`SmoothFunctional::Contrast`] of the two checkpoints'
291/// fitted coefficient vectors at the mode node is debiased component-wise via
292/// the Riesz one-step. The component contrasts form a vector `Δ ∈ ℝ^ambient`;
293/// the reported scalar `θ = ‖Δ‖₂` is the displacement size and its SE is the
294/// delta-method norm-gradient `‖Δ‖₂` propagation of the per-component SEs,
295/// assuming component independence (separate fits, separate scores).
296struct ContrastAtMode<'a> {
297    grid: ArrayView4<'a, f64>,
298    atom: usize,
299    c0: usize,
300    c1: usize,
301    ambient_dim: usize,
302    n_grid: usize,
303    hessian: ndarray::ArrayView2<'a, f64>,
304    mode_row: ArrayView1<'a, f64>,
305}
306
307fn contrast_at_mode(args: &ContrastAtMode<'_>) -> Result<RieszDebiasReport, String> {
308    let grid = args.grid;
309    let atom = args.atom;
310    let c0 = args.c0;
311    let c1 = args.c1;
312    let ambient_dim = args.ambient_dim;
313    let n_grid = args.n_grid;
314    let hessian = args.hessian;
315    let mode_row = args.mode_row;
316    // Aggregate scalar contrast Δ = θ_c1 − θ_c0 across ambient components, and
317    // the matching aggregate Riesz quantities, so a single RieszDebiasReport
318    // describes the displacement. We assemble the report from one debiasing per
319    // component and combine through the L2 norm.
320    let mut delta = Array1::<f64>::zeros(ambient_dim);
321    let mut delta_one = Array1::<f64>::zeros(ambient_dim);
322    let mut var_components = Array1::<f64>::zeros(ambient_dim);
323    let mut penalty_bias_acc = 0.0_f64;
324    // A representer to carry: reuse the last component's; the scalar norm
325    // estimate's representer is component-wise so we keep the final one as the
326    // canonical witness (its influence vector studentizes the norm contrast).
327    let mut witness: Option<RieszDebiasReport> = None;
328
329    for comp in 0..ambient_dim {
330        // Per-checkpoint identity-basis ridge fit: response y = grid values,
331        // design X = I, penalty S = I. With H = (1+λ)I the fitted coefficient
332        // is β = y / (1 + λ); the per-node score is sᵢ = (μ̂ᵢ − yᵢ)·eᵢ where
333        // μ̂ = Xβ = β, and the penalty gradient is S·β = β.
334        let y0 = grid.slice(ndarray::s![c0, atom, .., comp]).to_owned();
335        let y1 = grid.slice(ndarray::s![c1, atom, .., comp]).to_owned();
336        let report = component_contrast(y0.view(), y1.view(), n_grid, hessian, mode_row)?;
337
338        delta[comp] = report.theta_plugin;
339        delta_one[comp] = report.theta_onestep;
340        var_components[comp] = report.se * report.se;
341        penalty_bias_acc += report.penalty_bias * report.penalty_bias;
342        witness = Some(report);
343    }
344
345    let theta_plugin = delta.dot(&delta).sqrt();
346    let norm_one = delta_one.dot(&delta_one).sqrt();
347    // Delta method for θ = ‖Δ‖₂: ∂θ/∂Δ_k = Δ_k / ‖Δ‖₂, components independent,
348    // so var(θ) = Σ_k (Δ_k/‖Δ‖₂)² var(Δ_k).
349    let se = if norm_one > f64::MIN_POSITIVE {
350        let mut v = 0.0_f64;
351        for comp in 0..ambient_dim {
352            let g = delta_one[comp] / norm_one;
353            v += g * g * var_components[comp];
354        }
355        v.max(0.0).sqrt()
356    } else {
357        // At a null displacement the norm is non-differentiable; fall back to
358        // the RMS of the component SEs (an honest upper-ish bound on the size).
359        (var_components.sum() / ambient_dim as f64).sqrt()
360    };
361
362    let mut report = witness
363        .ok_or_else(|| "checkpoint contrast requires at least one ambient component".to_string())?;
364    report.theta_plugin = theta_plugin;
365    report.theta_onestep = norm_one;
366    report.se = se;
367    report.penalty_bias = penalty_bias_acc.sqrt();
368    Ok(report)
369}
370
371/// One ambient component's debiased contrast `g^{(c1)}(t_mode) −
372/// g^{(c0)}(t_mode)` through the Riesz one-step.
373fn component_contrast(
374    y0: ArrayView1<'_, f64>,
375    y1: ArrayView1<'_, f64>,
376    n_grid: usize,
377    hessian: ndarray::ArrayView2<'_, f64>,
378    mode_row: ArrayView1<'_, f64>,
379) -> Result<RieszDebiasReport, String> {
380    // Stacked paired-contrast trick: the contrast `m_{c1}(t₀) − m_{c0}(t₀)` is
381    // the difference of one linear functional applied to two coefficient
382    // vectors. Riesz operates on a single fit, so we debias on the DIFFERENCE
383    // fit β_Δ = β_{c1} − β_{c0}, whose response is y₁ − y₀ on the same identity
384    // basis — a genuine fit with the same penalized Hessian. The contrast
385    // functional on β_Δ is then the point evaluation at the mode, packaged via
386    // SmoothFunctional::Contrast against a zero row so the gradient is the mode
387    // row exactly (g = mode_row − 0).
388    let beta0 = y0.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
389    let beta1 = y1.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
390    let beta_delta = &beta1 - &beta0;
391
392    let zero_row = Array1::<f64>::zeros(n_grid);
393    let functional = SmoothFunctional::Contrast {
394        design_row_a: mode_row,
395        design_row_b: zero_row.view(),
396    };
397    let gradient = functional
398        .gradient()
399        .map_err(|e| format!("contrast functional gradient failed: {e}"))?;
400
401    // Per-node scores of the difference fit: μ̂ = X β_Δ = β_Δ, response y₁−y₀.
402    let response = &y1.to_owned() - &y0;
403    let mut row_scores = Array2::<f64>::zeros((n_grid, n_grid));
404    for i in 0..n_grid {
405        row_scores[[i, i]] = beta_delta[i] - response[i];
406    }
407    // Penalty gradient S·β_Δ = β_Δ (S = I).
408    let penalty_beta = beta_delta.clone();
409
410    let input = RieszInput {
411        beta: beta_delta.view(),
412        functional_gradient: gradient.view(),
413        row_scores: row_scores.view(),
414        penalty_beta: penalty_beta.view(),
415        leverage: None,
416    };
417    debias_with_dense_hessian(&input, hessian).map_err(|e| format!("Riesz debiasing failed: {e}"))
418}
419
420/// Anytime-valid log-e-value for the no-change null `θ = 0` from the debiased,
421/// studentized displacement `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`).
422///
423/// The naive in-sample likelihood ratio `exp(½ z²)` — the alternative density
424/// re-centered on the very estimate `θ̂` it is scored against — is NOT an
425/// e-value: under H0, `z ~ N(0,1)` and `E[exp(½ z²)] = ∫ φ(z) exp(½ z²) dz`
426/// DIVERGES, so it has no `E_{H0}[E] ≤ 1` guarantee. (Universal inference earns
427/// `exp(½ z²)` validity only with a held-out evaluation fold; a single grid of
428/// decoder values affords no such split.)
429///
430/// Instead we map the displacement to its two-sided normal p-value
431/// `p = 2(1 − Φ(|z|))` and route it through the module's frozen p→e calibrator
432/// [`log_e_from_p_calibrator`] (the κ = ½ member `e(p) = ½ p^{−1/2}`, with
433/// `∫₀¹ e(p) dp = 1`, hence `E_{H0}[e(P)] ≤ 1` for any superuniform p). This is
434/// a genuine e-value: no displacement, small e; a real displacement, large e;
435/// and it compounds validly into the change e-process. A degenerate
436/// (non-positive) SE yields a zero log-e-value (no evidence, not certainty).
437fn no_change_log_e_value(theta_hat: f64, se: f64) -> Result<f64, String> {
438    if !(se > 0.0) || !theta_hat.is_finite() {
439        return Ok(0.0);
440    }
441    let z = (theta_hat / se).abs();
442    let normal =
443        Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
444    // Two-sided p-value of the studentized displacement; clamp to (0, 1] so the
445    // calibrator (which rejects p = 0) sees a finite, valid argument even at a
446    // numerically saturated tail.
447    let p: f64 = (2.0 * (1.0 - normal.cdf(z))).clamp(f64::MIN_POSITIVE, 1.0);
448    log_e_from_p_calibrator(p)
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array4;

    /// Build a `[n_ckpt, 2, n_grid, ambient]` grid where atom 0's curve is
    /// constant across checkpoints (no change) and atom 1's curve at the
    /// central (mode) node is displaced by a known amount `shift` in
    /// component 0 between consecutive checkpoints (a steady drift).
    fn drift_grid(n_ckpt: usize, n_grid: usize, ambient: usize, shift: f64) -> Array4<f64> {
        let mode = n_grid / 2;
        let mut grid = Array4::<f64>::zeros((n_ckpt, 2, n_grid, ambient));
        for c in 0..n_ckpt {
            for g in 0..n_grid {
                let t = g as f64 / (n_grid - 1) as f64;
                for comp in 0..ambient {
                    // Shared base curve: a smooth bump scaled per component.
                    let base = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
                    // Atom 0: identical at every checkpoint.
                    grid[[c, 0, g, comp]] = base;
                    // Atom 1: base curve plus a checkpoint-indexed shift at the
                    // mode node in component 0 only.
                    grid[[c, 1, g, comp]] = if g == mode && comp == 0 {
                        base + shift * c as f64
                    } else {
                        base
                    };
                }
            }
        }
        grid
    }

    /// Runs the dynamics on a two-atom grid from [`drift_grid`], labelling the
    /// checkpoints `dev0, dev1, …` and the atoms `constant` / `drifter`.
    fn run_two_atom_dynamics(grid: &Array4<f64>) -> Vec<AtomTrajectory> {
        let (n_ckpt, _, n_grid, _) = grid.dim();
        let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
        let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
        let atom_names = vec!["constant".to_string(), "drifter".to_string()];
        let input = CheckpointDynamicsInput {
            decoder_grid: grid.view(),
            checkpoint_ids: &ckpt_ids,
            atom_names: &atom_names,
            latent_grid: latent.view(),
        };
        checkpoint_atom_dynamics(&input).expect("dynamics")
    }

    /// Sum of the per-claim log-e-values in the 0.05-level certificate.
    fn total_log_e(trajectory: &AtomTrajectory) -> f64 {
        let cert = trajectory.change_evidence.certify(0.05);
        cert.entries.iter().map(|e| e.log_e).sum()
    }

    /// Asserts that `input` is rejected and that the error names the expected
    /// cause, so a later, unrelated failure cannot make the test pass.
    fn assert_rejected(input: &CheckpointDynamicsInput<'_>, expected_fragment: &str) {
        match checkpoint_atom_dynamics(input) {
            Ok(_) => panic!("input should be rejected ({expected_fragment})"),
            Err(message) => assert!(
                message.contains(expected_fragment),
                "rejection should mention '{expected_fragment}', got: {message}"
            ),
        }
    }

    #[test]
    fn no_change_atom_has_near_zero_contrast_and_no_change_evidence() {
        let n_ckpt = 5;
        // The transport fit requires at least MIN_TRANSPORT_OBS (16) paired
        // grid samples, so the shared latent grid must be at least that long.
        let n_grid = 17;
        let ambient = 3;
        let grid = drift_grid(n_ckpt, n_grid, ambient, 0.5);
        let traj = run_two_atom_dynamics(&grid);
        assert_eq!(traj.len(), 2);

        // Atom 0 is identical across checkpoints: every step contrast must be
        // (numerically) zero displacement and accumulate no change evidence.
        let constant = &traj[0];
        assert_eq!(constant.conditional_step_contrasts.len(), n_ckpt - 1);
        for report in &constant.conditional_step_contrasts {
            assert!(
                report.theta_onestep.abs() < 1e-9,
                "constant atom step displacement should be ~0, got {}",
                report.theta_onestep
            );
        }
        // No-change null is true here → the e-BH certificate confirms nothing.
        let cert = constant.change_evidence.certify(0.05);
        assert!(
            cert.confirmed().count() == 0,
            "constant atom must not confirm any change claim"
        );
    }

    #[test]
    fn drifting_atom_recovers_displacement_and_accumulates_change_evidence() {
        let n_ckpt = 6;
        let n_grid = 17;
        let ambient = 3;
        let shift = 0.7_f64;
        let grid = drift_grid(n_ckpt, n_grid, ambient, shift);
        let traj = run_two_atom_dynamics(&grid);
        let drifter = &traj[1];

        // Each consecutive step displaces component 0 at the mode by exactly
        // `shift`; the reported displacement size is `‖Δ‖₂`. On the light
        // interpolation ridge (λ = GRID_FIT_RIDGE ≈ 1e-3) the plug-in contrast
        // `shift/(1+λ)` tracks the true displacement to sub-percent, and every
        // reported quantity is finite. (The component displacement lives in a
        // single ambient channel, so the L2 size IS that channel's contrast.)
        for report in &drifter.conditional_step_contrasts {
            assert!(
                (report.theta_plugin - shift).abs() < 1e-2 * shift,
                "drift step plug-in displacement should track {shift}, got {}",
                report.theta_plugin
            );
            assert!(
                report.theta_onestep.is_finite() && report.se.is_finite(),
                "debiased displacement and SE must be finite"
            );
            // The displacement is unambiguously positive (a real change).
            assert!(
                report.theta_plugin > 0.5 * shift,
                "drift displacement should be well above zero, got {}",
                report.theta_plugin
            );
        }

        // The drift is real → every step's no-change e-value is strictly
        // positive (studentized displacement away from zero), so the change
        // certificate carries strictly positive log-evidence on its claims,
        // unlike the constant atom whose claims carry exactly zero.
        let cert = drifter.change_evidence.certify(0.05);
        let total: f64 = cert.entries.iter().map(|e| e.log_e).sum();
        assert!(
            total > 0.0,
            "steady real drift must accumulate positive change evidence, entries: {:?}",
            cert.entries
                .iter()
                .map(|e| (e.log_e, e.confirmed))
                .collect::<Vec<_>>()
        );
    }

    /// A drifting atom must out-evidence a constant atom: the change e-process
    /// is a genuine discriminator, not a constant.
    #[test]
    fn drift_outweighs_constant_in_change_evidence() {
        let grid = drift_grid(6, 17, 3, 0.7);
        let traj = run_two_atom_dynamics(&grid);
        let const_log_e = total_log_e(&traj[0]);
        let drift_log_e = total_log_e(&traj[1]);
        assert!(
            drift_log_e > const_log_e,
            "drift change-evidence {drift_log_e} must exceed constant {const_log_e}"
        );
    }

    #[test]
    fn rejects_single_checkpoint_and_axis_mismatch() {
        let names = vec!["a".to_string(), "b".to_string()];
        let latent: Array1<f64> = Array1::linspace(0.0, 1.0, 5);

        // A single checkpoint has no consecutive step to contrast.
        let single = Array4::<f64>::zeros((1, 2, 5, 3));
        let one_id = vec!["only".to_string()];
        assert_rejected(
            &CheckpointDynamicsInput {
                decoder_grid: single.view(),
                checkpoint_ids: &one_id,
                atom_names: &names,
                latent_grid: latent.view(),
            },
            "at least two checkpoints",
        );

        // From here on the grid itself is acceptable in shape; each case
        // breaks exactly one label/grid axis.
        let pair = Array4::<f64>::zeros((2, 2, 5, 3));
        let two_ids = vec!["first".to_string(), "second".to_string()];

        let three_ids = vec![
            "first".to_string(),
            "second".to_string(),
            "third".to_string(),
        ];
        assert_rejected(
            &CheckpointDynamicsInput {
                decoder_grid: pair.view(),
                checkpoint_ids: &three_ids,
                atom_names: &names,
                latent_grid: latent.view(),
            },
            "checkpoint_ids length",
        );

        let one_name = vec!["a".to_string()];
        assert_rejected(
            &CheckpointDynamicsInput {
                decoder_grid: pair.view(),
                checkpoint_ids: &two_ids,
                atom_names: &one_name,
                latent_grid: latent.view(),
            },
            "atom_names length",
        );

        let short_latent: Array1<f64> = Array1::linspace(0.0, 1.0, 4);
        assert_rejected(
            &CheckpointDynamicsInput {
                decoder_grid: pair.view(),
                checkpoint_ids: &two_ids,
                atom_names: &names,
                latent_grid: short_latent.view(),
            },
            "latent_grid length",
        );
    }
}