gam_sae/inference/checkpoint_dynamics.rs
1//! Cross-checkpoint training-dynamics inference for SAE atoms (issue #1102).
2//!
3//! OLMo ships intermediate-training checkpoints. Each checkpoint `c` fits a
4//! dictionary whose atom `a` is a decoder curve `g^{(c)}_a: t ↦ ℝ^ambient`
5//! sampled on a shared latent grid `t`. The question this module answers, per
6//! atom, is *did the atom change across training, and where*, with
7//! debiased point estimates, standard errors, and anytime-valid evidence.
8//!
9//! It is pure assembly of three landed instruments — none is reimplemented:
10//!
//! * [`crate::inference::riesz`] — the per-step contrast
//!   `θ = g^{(c+1)}(t₀) − g^{(c)}(t₀)` is a linear
//!   [`SmoothFunctional::Contrast`] evaluated on the difference fit
//!   `β^{(c+1)} − β^{(c)}` of the two checkpoints' coefficient vectors;
//!   [`debias_with_dense_hessian`] returns the penalty-debiased estimate and
//!   a plug-in SE via the Riesz representer.
//! * [`crate::inference::layer_transport`] — the checkpoint axis is reused as
//!   the "layer" axis: [`fit_layer_transport`] aligns the atom's latent chart
//!   across consecutive checkpoints (topology compatibility, isometry defect,
//!   winding degree), packaged as a [`LayerTransportReport`].
//! * [`gam_terms::inference::structure_evidence`] — each consecutive-step
//!   contrast feeds one e-value (the studentized displacement mapped to a
//!   two-sided p-value and run through the frozen κ = ½ p→e calibrator) into a
//!   per-step [`StructureLedger`] claim under the null "the atom did not change
//!   at this checkpoint step". Unlike the in-sample likelihood ratio
//!   `exp(½ z²)`, whose null expectation diverges, the calibrated value
//!   satisfies `E_{H0}[E] ≤ 1` whenever the p-value it is built from is valid,
//!   and is therefore safe under optional stopping.
26//!
27//! # Honest accounting of the Riesz inputs
28//!
29//! A *bare* decoder grid carries the fitted curve VALUES but no
30//! observation-level scores and no penalized Hessian — those cannot be
31//! fabricated from grid samples. So the smooth this module debiases is the
32//! one the grid actually defines: a ridge-penalized least-squares fit of the
33//! grid VALUES on the latent-grid identity (interpolation) basis, where each
34//! grid node is one observation with response equal to the decoder value at
35//! that node. This is a genuine fit with a genuine penalized Hessian
36//! `XᵀX + λS = I + λS` and genuine per-node scores `s_i = (β_i − y_i)·eᵢ`,
//! so every quantity handed to [`debias_with_dense_hessian`] is real, not a
//! placeholder. The contrast functional is then evaluated against the
//! identity design row at the "mode" index, which is the central latent-grid
//! node `n_grid / 2`. The ambient dimension is handled component-wise, and the
//! per-component contrasts are aggregated into a single scalar `θ`: the L2
//! norm of the component contrast vector, i.e. the size of the decoder
//! displacement at `t₀`. Its SE is propagated by the delta method through
//! that norm.
44
45use crate::inference::layer_transport::{ChartTopology, LayerTransportReport, fit_layer_transport};
46use crate::inference::riesz::{
47 RieszDebiasReport, RieszInput, SmoothFunctional, debias_with_dense_hessian,
48};
49use gam_terms::inference::structure_evidence::{ClaimKind, StructureLedger, log_e_from_p_calibrator};
50use ndarray::{Array1, Array2, ArrayView1, ArrayView4};
51use statrs::distribution::{ContinuousCDF, Normal};
52
53/// Ridge penalty on the interpolation fit of the grid values. Small relative
54/// to the unit data Hessian so the fit tracks the grid closely; non-zero so
55/// the penalty-debiasing term in the Riesz one-step is exercised on real
56/// (not degenerate) curvature, and so the Hessian `I + λS` is strictly SPD.
57const GRID_FIT_RIDGE: f64 = 1e-3;
58
59/// Inputs for one cross-checkpoint atom-dynamics run.
60///
61/// `decoder_grid` is `[n_checkpoints, n_atoms, n_grid, ambient_dim]`: the
62/// decoder curve of every atom sampled on the shared `latent_grid` at every
63/// checkpoint. `checkpoint_ids[c]` and `atom_names[a]` label the axes.
64pub struct CheckpointDynamicsInput<'a> {
65 pub decoder_grid: ArrayView4<'a, f64>,
66 pub checkpoint_ids: &'a [String],
67 pub atom_names: &'a [String],
68 pub latent_grid: ArrayView1<'a, f64>,
69}
70
71/// The training-dynamics trajectory of one atom across the checkpoint axis.
72///
73/// The PRIMARY, coverage-valid deliverable is [`Self::change_evidence`]: the
74/// anytime-valid e-process answering "did atom k change during training?".
75/// [`Self::conditional_step_contrasts`] is a secondary, descriptive readout (see
76/// its docs for the conditional caveat).
77pub struct AtomTrajectory {
78 pub atom_name: String,
79 /// Debiased `g^{(c+1)}(t_mode) − g^{(c)}(t_mode)` for each consecutive
80 /// checkpoint step, with its plug-in SE.
81 ///
82 /// CONDITIONAL ON THE FITTED COORDINATES (not a coverage-valid CI). The
83 /// debiased SE here conditions away the generated-regressor uncertainty in
84 /// the estimated latent coordinates `t̂` and activations `â` — the exact
85 /// correction the marginal-slope family exists to make (issue #1115). It is
86 /// reported only as a conditional contrast point estimate with a plug-in SE,
87 /// NOT as an interval with frequentist coverage for the population
88 /// displacement. The headline change verdict is carried by the e-process
89 /// [`Self::change_evidence`], which IS anytime-valid; this field is a
90 /// descriptive companion. Read the SE accordingly.
91 pub conditional_step_contrasts: Vec<RieszDebiasReport>,
92 /// Consecutive-checkpoint chart correspondences (checkpoint axis reused as
93 /// the transport "layer" axis).
94 pub transports: Vec<LayerTransportReport>,
95 /// PRIMARY deliverable: anytime-valid evidence that the atom changed at each
96 /// consecutive checkpoint step, one calibrated e-value per step into a
97 /// per-step claim. Valid at any data-dependent stopping time.
98 pub change_evidence: StructureLedger,
99}
100
101/// Run cross-checkpoint debiased dynamics inference for every atom.
102///
103/// For each atom, walks consecutive checkpoints and, at each step `c → c+1`:
104/// 1. fits the transport map between the two checkpoints' latent charts
105/// ([`fit_layer_transport`], checkpoint axis as the layer axis);
106/// 2. evaluates the Riesz-debiased decoder-displacement contrast at the
107/// latent-grid mode ([`SmoothFunctional::Contrast`] + penalty debiasing);
108/// 3. absorbs the studentized contrast as a calibrated anytime-valid e-value
109/// into the step's change claim under the no-change null.
110pub fn checkpoint_atom_dynamics(
111 input: &CheckpointDynamicsInput<'_>,
112) -> Result<Vec<AtomTrajectory>, String> {
113 let shape = input.decoder_grid.shape();
114 let (n_checkpoints, n_atoms, n_grid, ambient_dim) = (shape[0], shape[1], shape[2], shape[3]);
115 if n_checkpoints < 2 {
116 return Err(format!(
117 "checkpoint dynamics needs at least two checkpoints, got {n_checkpoints}"
118 ));
119 }
120 if input.checkpoint_ids.len() != n_checkpoints {
121 return Err(format!(
122 "checkpoint_ids length {} disagrees with decoder grid checkpoint axis {n_checkpoints}",
123 input.checkpoint_ids.len()
124 ));
125 }
126 if input.atom_names.len() != n_atoms {
127 return Err(format!(
128 "atom_names length {} disagrees with decoder grid atom axis {n_atoms}",
129 input.atom_names.len()
130 ));
131 }
132 if input.latent_grid.len() != n_grid {
133 return Err(format!(
134 "latent_grid length {} disagrees with decoder grid latent axis {n_grid}",
135 input.latent_grid.len()
136 ));
137 }
138 if n_grid < 2 || ambient_dim == 0 {
139 return Err(format!(
140 "checkpoint dynamics needs a non-trivial grid ({n_grid}) and ambient dim ({ambient_dim})"
141 ));
142 }
143 if input.decoder_grid.iter().any(|v| !v.is_finite()) {
144 return Err("checkpoint dynamics decoder grid must be finite".to_string());
145 }
146 if input.latent_grid.iter().any(|v| !v.is_finite()) {
147 return Err("checkpoint dynamics latent grid must be finite".to_string());
148 }
149
150 // The mode index: the latent-grid node where the contrast is evaluated.
151 // Use the central node so it sits inside any chart and away from edge
152 // interpolation artifacts.
153 let mode_index = n_grid / 2;
154
155 // Identity interpolation design `X = I_{n_grid}` and its ridge penalty
156 // `S = I`. The penalized Hessian `H = XᵀX + λS = (1 + λ) I` is shared by
157 // every component fit, so it is built once.
158 let penalty_scale = 1.0 + GRID_FIT_RIDGE;
159 let mut hessian = Array2::<f64>::zeros((n_grid, n_grid));
160 for i in 0..n_grid {
161 hessian[[i, i]] = penalty_scale;
162 }
163 // Contrast design rows pick out the mode node: `m(t_mode) = β_{mode}`, so
164 // the value-design row is the mode basis vector. The contrast `a − b`
165 // (later checkpoint minus earlier) shares the same row; the per-checkpoint
166 // distinction is carried by the two fitted coefficient vectors, exactly as
167 // a paired contrast of the same functional across two fits.
168 let mut mode_row = Array1::<f64>::zeros(n_grid);
169 mode_row[mode_index] = 1.0;
170
171 let mut trajectories = Vec::with_capacity(n_atoms);
172 for atom in 0..n_atoms {
173 let atom_name = input.atom_names[atom].clone();
174 let mut step_contrasts = Vec::with_capacity(n_checkpoints - 1);
175 let mut transports = Vec::with_capacity(n_checkpoints - 1);
176 let mut change_evidence = StructureLedger::new();
177
178 for step in 0..n_checkpoints - 1 {
179 let c0 = step;
180 let c1 = step + 1;
181
182 // --- transport map across the checkpoint axis --------------------
183 // Reuse the latent grid itself as both charts' coordinates on an
184 // interval `[min, max]`; the transport fit aligns the two
185 // checkpoints' decoder curves through their shared latent index.
186 // The "from"/"to" coordinates are the decoder values projected to
187 // the first ambient component, the available scalar chart sample.
188 let coords_from = input
189 .decoder_grid
190 .slice(ndarray::s![c0, atom, .., 0])
191 .to_owned();
192 let coords_to = input
193 .decoder_grid
194 .slice(ndarray::s![c1, atom, .., 0])
195 .to_owned();
196 let (lo, hi) = interval_bounds(coords_from.view(), coords_to.view());
197 let topology = ChartTopology::Interval { lo, hi };
198 let transport = fit_layer_transport(
199 c0,
200 c1,
201 coords_from.view(),
202 coords_to.view(),
203 topology,
204 topology,
205 )
206 .map_err(|e| {
207 format!(
208 "checkpoint transport for atom '{atom_name}' step {} → {} failed: {e}",
209 input.checkpoint_ids[c0], input.checkpoint_ids[c1]
210 )
211 })?;
212 transports.push(transport);
213
214 // --- Riesz-debiased decoder-displacement contrast at the mode ----
215 let report = contrast_at_mode(&ContrastAtMode {
216 grid: input.decoder_grid,
217 atom,
218 c0,
219 c1,
220 ambient_dim,
221 n_grid,
222 hessian: hessian.view(),
223 mode_row: mode_row.view(),
224 })
225 .map_err(|e| {
226 format!(
227 "checkpoint contrast for atom '{atom_name}' step {} → {} failed: {e}",
228 input.checkpoint_ids[c0], input.checkpoint_ids[c1]
229 )
230 })?;
231
232 // --- anytime-valid evidence the atom changed at this step --------
233 // The debiased displacement `θ̂` with SE `se` studentizes to
234 // `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`). Its two-sided
235 // p-value run through the frozen κ = ½ p→e calibrator is a genuine
236 // e-value for the per-step no-change null θ = 0 — `E_{H0}[E] ≤ 1`,
237 // which the naive in-sample `exp(½ z²)` ratio is NOT (it diverges
238 // under H0). One e-value per step into a per-step claim; the
239 // calibrator's contract (one e-value per independent batch) is met
240 // because each step is a distinct checkpoint transition.
241 let claim = change_evidence.register(ClaimKind::Custom {
242 label: format!(
243 "atom '{atom_name}' changed from checkpoint {} to {}",
244 input.checkpoint_ids[c0], input.checkpoint_ids[c1]
245 ),
246 });
247 let log_e = no_change_log_e_value(report.theta_onestep, report.se)?;
248 change_evidence.absorb_log(claim, log_e)?;
249
250 step_contrasts.push(report);
251 }
252
253 trajectories.push(AtomTrajectory {
254 atom_name,
255 conditional_step_contrasts: step_contrasts,
256 transports,
257 change_evidence,
258 });
259 }
260
261 Ok(trajectories)
262}
263
264/// Interval bounds spanning both checkpoints' scalar chart samples, padded so
265/// the transport basis domain strictly contains the data.
266fn interval_bounds(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> (f64, f64) {
267 let mut lo = f64::INFINITY;
268 let mut hi = f64::NEG_INFINITY;
269 for &v in a.iter().chain(b.iter()) {
270 lo = lo.min(v);
271 hi = hi.max(v);
272 }
273 if !(lo.is_finite() && hi.is_finite()) {
274 return (0.0, 1.0);
275 }
276 if hi <= lo {
277 // Degenerate (constant) chart: open a unit window around the value.
278 return (lo - 0.5, lo + 0.5);
279 }
280 let pad = (hi - lo) * 1e-6;
281 (lo - pad, hi + pad)
282}
283
284/// Debiased `g^{(c1)}(t_mode) − g^{(c0)}(t_mode)` aggregated over the ambient
285/// dimension into the scalar decoder-displacement size, with a delta-method SE.
286///
287/// Each ambient component is an independent identity-basis ridge fit of the
288/// grid values; the [`SmoothFunctional::Contrast`] of the two checkpoints'
289/// fitted coefficient vectors at the mode node is debiased component-wise via
290/// the Riesz one-step. The component contrasts form a vector `Δ ∈ ℝ^ambient`;
291/// the reported scalar `θ = ‖Δ‖₂` is the displacement size and its SE is the
292/// delta-method norm-gradient `‖Δ‖₂` propagation of the per-component SEs,
293/// assuming component independence (separate fits, separate scores).
294struct ContrastAtMode<'a> {
295 grid: ArrayView4<'a, f64>,
296 atom: usize,
297 c0: usize,
298 c1: usize,
299 ambient_dim: usize,
300 n_grid: usize,
301 hessian: ndarray::ArrayView2<'a, f64>,
302 mode_row: ArrayView1<'a, f64>,
303}
304
305fn contrast_at_mode(args: &ContrastAtMode<'_>) -> Result<RieszDebiasReport, String> {
306 let grid = args.grid;
307 let atom = args.atom;
308 let c0 = args.c0;
309 let c1 = args.c1;
310 let ambient_dim = args.ambient_dim;
311 let n_grid = args.n_grid;
312 let hessian = args.hessian;
313 let mode_row = args.mode_row;
314 // Aggregate scalar contrast Δ = θ_c1 − θ_c0 across ambient components, and
315 // the matching aggregate Riesz quantities, so a single RieszDebiasReport
316 // describes the displacement. We assemble the report from one debiasing per
317 // component and combine through the L2 norm.
318 let mut delta = Array1::<f64>::zeros(ambient_dim);
319 let mut delta_one = Array1::<f64>::zeros(ambient_dim);
320 let mut var_components = Array1::<f64>::zeros(ambient_dim);
321 let mut penalty_bias_acc = 0.0_f64;
322 // A representer to carry: reuse the last component's; the scalar norm
323 // estimate's representer is component-wise so we keep the final one as the
324 // canonical witness (its influence vector studentizes the norm contrast).
325 let mut witness: Option<RieszDebiasReport> = None;
326
327 for comp in 0..ambient_dim {
328 // Per-checkpoint identity-basis ridge fit: response y = grid values,
329 // design X = I, penalty S = I. With H = (1+λ)I the fitted coefficient
330 // is β = y / (1 + λ); the per-node score is sᵢ = (μ̂ᵢ − yᵢ)·eᵢ where
331 // μ̂ = Xβ = β, and the penalty gradient is S·β = β.
332 let y0 = grid.slice(ndarray::s![c0, atom, .., comp]).to_owned();
333 let y1 = grid.slice(ndarray::s![c1, atom, .., comp]).to_owned();
334 let report = component_contrast(y0.view(), y1.view(), n_grid, hessian, mode_row)?;
335
336 delta[comp] = report.theta_plugin;
337 delta_one[comp] = report.theta_onestep;
338 var_components[comp] = report.se * report.se;
339 penalty_bias_acc += report.penalty_bias * report.penalty_bias;
340 witness = Some(report);
341 }
342
343 let theta_plugin = delta.dot(&delta).sqrt();
344 let norm_one = delta_one.dot(&delta_one).sqrt();
345 // Delta method for θ = ‖Δ‖₂: ∂θ/∂Δ_k = Δ_k / ‖Δ‖₂, components independent,
346 // so var(θ) = Σ_k (Δ_k/‖Δ‖₂)² var(Δ_k).
347 let se = if norm_one > f64::MIN_POSITIVE {
348 let mut v = 0.0_f64;
349 for comp in 0..ambient_dim {
350 let g = delta_one[comp] / norm_one;
351 v += g * g * var_components[comp];
352 }
353 v.max(0.0).sqrt()
354 } else {
355 // At a null displacement the norm is non-differentiable; fall back to
356 // the RMS of the component SEs (an honest upper-ish bound on the size).
357 (var_components.sum() / ambient_dim as f64).sqrt()
358 };
359
360 let mut report = witness
361 .ok_or_else(|| "checkpoint contrast requires at least one ambient component".to_string())?;
362 report.theta_plugin = theta_plugin;
363 report.theta_onestep = norm_one;
364 report.se = se;
365 report.penalty_bias = penalty_bias_acc.sqrt();
366 Ok(report)
367}
368
369/// One ambient component's debiased contrast `g^{(c1)}(t_mode) −
370/// g^{(c0)}(t_mode)` through the Riesz one-step.
371fn component_contrast(
372 y0: ArrayView1<'_, f64>,
373 y1: ArrayView1<'_, f64>,
374 n_grid: usize,
375 hessian: ndarray::ArrayView2<'_, f64>,
376 mode_row: ArrayView1<'_, f64>,
377) -> Result<RieszDebiasReport, String> {
378 // Stacked paired-contrast trick: the contrast `m_{c1}(t₀) − m_{c0}(t₀)` is
379 // the difference of one linear functional applied to two coefficient
380 // vectors. Riesz operates on a single fit, so we debias on the DIFFERENCE
381 // fit β_Δ = β_{c1} − β_{c0}, whose response is y₁ − y₀ on the same identity
382 // basis — a genuine fit with the same penalized Hessian. The contrast
383 // functional on β_Δ is then the point evaluation at the mode, packaged via
384 // SmoothFunctional::Contrast against a zero row so the gradient is the mode
385 // row exactly (g = mode_row − 0).
386 let beta0 = y0.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
387 let beta1 = y1.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
388 let beta_delta = &beta1 - &beta0;
389
390 let zero_row = Array1::<f64>::zeros(n_grid);
391 let functional = SmoothFunctional::Contrast {
392 design_row_a: mode_row,
393 design_row_b: zero_row.view(),
394 };
395 let gradient = functional
396 .gradient()
397 .map_err(|e| format!("contrast functional gradient failed: {e}"))?;
398
399 // Per-node scores of the difference fit: μ̂ = X β_Δ = β_Δ, response y₁−y₀.
400 let response = &y1.to_owned() - &y0;
401 let mut row_scores = Array2::<f64>::zeros((n_grid, n_grid));
402 for i in 0..n_grid {
403 row_scores[[i, i]] = beta_delta[i] - response[i];
404 }
405 // Penalty gradient S·β_Δ = β_Δ (S = I).
406 let penalty_beta = beta_delta.clone();
407
408 let input = RieszInput {
409 beta: beta_delta.view(),
410 functional_gradient: gradient.view(),
411 row_scores: row_scores.view(),
412 penalty_beta: penalty_beta.view(),
413 leverage: None,
414 };
415 debias_with_dense_hessian(&input, hessian).map_err(|e| format!("Riesz debiasing failed: {e}"))
416}
417
418/// Anytime-valid log-e-value for the no-change null `θ = 0` from the debiased,
419/// studentized displacement `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`).
420///
421/// The naive in-sample likelihood ratio `exp(½ z²)` — the alternative density
422/// re-centered on the very estimate `θ̂` it is scored against — is NOT an
423/// e-value: under H0, `z ~ N(0,1)` and `E[exp(½ z²)] = ∫ φ(z) exp(½ z²) dz`
424/// DIVERGES, so it has no `E_{H0}[E] ≤ 1` guarantee. (Universal inference earns
425/// `exp(½ z²)` validity only with a held-out evaluation fold; a single grid of
426/// decoder values affords no such split.)
427///
428/// Instead we map the displacement to its two-sided normal p-value
429/// `p = 2(1 − Φ(|z|))` and route it through the module's frozen p→e calibrator
430/// [`log_e_from_p_calibrator`] (the κ = ½ member `e(p) = ½ p^{−1/2}`, with
431/// `∫₀¹ e(p) dp = 1`, hence `E_{H0}[e(P)] ≤ 1` for any superuniform p). This is
432/// a genuine e-value: no displacement, small e; a real displacement, large e;
433/// and it compounds validly into the change e-process. A degenerate
434/// (non-positive) SE yields a zero log-e-value (no evidence, not certainty).
435fn no_change_log_e_value(theta_hat: f64, se: f64) -> Result<f64, String> {
436 if !(se > 0.0) || !theta_hat.is_finite() {
437 return Ok(0.0);
438 }
439 let z = (theta_hat / se).abs();
440 let normal =
441 Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
442 // Two-sided p-value of the studentized displacement; clamp to (0, 1] so the
443 // calibrator (which rejects p = 0) sees a finite, valid argument even at a
444 // numerically saturated tail.
445 let p: f64 = (2.0 * (1.0 - normal.cdf(z))).clamp(f64::MIN_POSITIVE, 1.0);
446 log_e_from_p_calibrator(p)
447}
448
449#[cfg(test)]
450mod tests {
451 use super::*;
452 use ndarray::Array4;
453
454 /// Build a `[n_ckpt, n_atoms, n_grid, ambient]` grid where atom 0's curve is
455 /// constant across checkpoints (no change) and atom 1's curve at the central
456 /// (mode) node is displaced by a known amount `shift` in component 0 between
457 /// consecutive checkpoints (a steady drift).
458 fn drift_grid(n_ckpt: usize, n_grid: usize, ambient: usize, shift: f64) -> Array4<f64> {
459 let mode = n_grid / 2;
460 let mut grid = Array4::<f64>::zeros((n_ckpt, 2, n_grid, ambient));
461 for c in 0..n_ckpt {
462 for g in 0..n_grid {
463 let t = g as f64 / (n_grid - 1) as f64;
464 for comp in 0..ambient {
465 // Atom 0: smooth bump, identical at every checkpoint.
466 grid[[c, 0, g, comp]] = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
467 // Atom 1: same base curve plus a checkpoint-indexed shift at
468 // the mode node in component 0 only.
469 let base = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
470 grid[[c, 1, g, comp]] = if g == mode && comp == 0 {
471 base + shift * c as f64
472 } else {
473 base
474 };
475 }
476 }
477 }
478 grid
479 }
480
481 #[test]
482 fn no_change_atom_has_near_zero_contrast_and_no_change_evidence() {
483 let n_ckpt = 5;
484 // The transport fit requires at least MIN_TRANSPORT_OBS (16) paired
485 // grid samples, so the shared latent grid must be at least that long.
486 let n_grid = 17;
487 let ambient = 3;
488 let grid = drift_grid(n_ckpt, n_grid, ambient, 0.5);
489 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
490 let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
491 let atom_names = vec!["constant".to_string(), "drifter".to_string()];
492 let input = CheckpointDynamicsInput {
493 decoder_grid: grid.view(),
494 checkpoint_ids: &ckpt_ids,
495 atom_names: &atom_names,
496 latent_grid: latent.view(),
497 };
498 let traj = checkpoint_atom_dynamics(&input).expect("dynamics");
499 assert_eq!(traj.len(), 2);
500
501 // Atom 0 is identical across checkpoints: every step contrast must be
502 // (numerically) zero displacement and accumulate no change evidence.
503 let constant = &traj[0];
504 assert_eq!(constant.conditional_step_contrasts.len(), n_ckpt - 1);
505 for report in &constant.conditional_step_contrasts {
506 assert!(
507 report.theta_onestep.abs() < 1e-9,
508 "constant atom step displacement should be ~0, got {}",
509 report.theta_onestep
510 );
511 }
512 // No-change null is true here → the e-BH certificate confirms nothing.
513 let cert = constant.change_evidence.certify(0.05);
514 assert!(
515 cert.confirmed().count() == 0,
516 "constant atom must not confirm any change claim"
517 );
518 }
519
520 #[test]
521 fn drifting_atom_recovers_displacement_and_accumulates_change_evidence() {
522 let n_ckpt = 6;
523 let n_grid = 17;
524 let ambient = 3;
525 let shift = 0.7_f64;
526 let grid = drift_grid(n_ckpt, n_grid, ambient, shift);
527 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
528 let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
529 let atom_names = vec!["constant".to_string(), "drifter".to_string()];
530 let input = CheckpointDynamicsInput {
531 decoder_grid: grid.view(),
532 checkpoint_ids: &ckpt_ids,
533 atom_names: &atom_names,
534 latent_grid: latent.view(),
535 };
536 let traj = checkpoint_atom_dynamics(&input).expect("dynamics");
537 let drifter = &traj[1];
538
539 // Each consecutive step displaces component 0 at the mode by exactly
540 // `shift`; the reported displacement size is `‖Δ‖₂`. On the light
541 // interpolation ridge (λ = GRID_FIT_RIDGE ≈ 1e-3) the plug-in contrast
542 // `shift/(1+λ)` tracks the true displacement to sub-percent, and every
543 // reported quantity is finite. (The component displacement lives in a
544 // single ambient channel, so the L2 size IS that channel's contrast.)
545 for report in &drifter.conditional_step_contrasts {
546 assert!(
547 (report.theta_plugin - shift).abs() < 1e-2 * shift,
548 "drift step plug-in displacement should track {shift}, got {}",
549 report.theta_plugin
550 );
551 assert!(
552 report.theta_onestep.is_finite() && report.se.is_finite(),
553 "debiased displacement and SE must be finite"
554 );
555 // The displacement is unambiguously positive (a real change).
556 assert!(
557 report.theta_plugin > 0.5 * shift,
558 "drift displacement should be well above zero, got {}",
559 report.theta_plugin
560 );
561 }
562
563 // The drift is real → every step's no-change e-value is strictly
564 // positive (studentized displacement away from zero), so the change
565 // certificate carries strictly positive log-evidence on its claims,
566 // unlike the constant atom whose claims carry exactly zero.
567 let cert = drifter.change_evidence.certify(0.05);
568 let total_log_e: f64 = cert.entries.iter().map(|e| e.log_e).sum();
569 assert!(
570 total_log_e > 0.0,
571 "steady real drift must accumulate positive change evidence, entries: {:?}",
572 cert.entries
573 .iter()
574 .map(|e| (e.log_e, e.confirmed))
575 .collect::<Vec<_>>()
576 );
577 }
578
579 /// A drifting atom must out-evidence a constant atom: the change e-process
580 /// is a genuine discriminator, not a constant.
581 #[test]
582 fn drift_outweighs_constant_in_change_evidence() {
583 let n_ckpt = 6;
584 let n_grid = 17;
585 let ambient = 3;
586 let grid = drift_grid(n_ckpt, n_grid, ambient, 0.7);
587 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
588 let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
589 let atom_names = vec!["constant".to_string(), "drifter".to_string()];
590 let input = CheckpointDynamicsInput {
591 decoder_grid: grid.view(),
592 checkpoint_ids: &ckpt_ids,
593 atom_names: &atom_names,
594 latent_grid: latent.view(),
595 };
596 let traj = checkpoint_atom_dynamics(&input).expect("dynamics");
597 let const_log_e: f64 = traj[0]
598 .change_evidence
599 .certify(0.05)
600 .entries
601 .iter()
602 .map(|e| e.log_e)
603 .sum();
604 let drift_log_e: f64 = traj[1]
605 .change_evidence
606 .certify(0.05)
607 .entries
608 .iter()
609 .map(|e| e.log_e)
610 .sum();
611 assert!(
612 drift_log_e > const_log_e,
613 "drift change-evidence {drift_log_e} must exceed constant {const_log_e}"
614 );
615 }
616
617 #[test]
618 fn rejects_single_checkpoint_and_axis_mismatch() {
619 let grid = Array4::<f64>::zeros((1, 2, 5, 3));
620 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, 5);
621 let ids = vec!["only".to_string()];
622 let names = vec!["a".to_string(), "b".to_string()];
623 let input = CheckpointDynamicsInput {
624 decoder_grid: grid.view(),
625 checkpoint_ids: &ids,
626 atom_names: &names,
627 latent_grid: latent.view(),
628 };
629 assert!(checkpoint_atom_dynamics(&input).is_err());
630 }
631}