gam_sae/inference/checkpoint_dynamics.rs
1//! Cross-checkpoint training-dynamics inference for SAE atoms (issue #1102).
2//!
3//! OLMo ships intermediate-training checkpoints. Each checkpoint `c` fits a
4//! dictionary whose atom `a` is a decoder curve `g^{(c)}_a: t ↦ ℝ^ambient`
5//! sampled on a shared latent grid `t`. The question this module answers, per
6//! atom, is *did the atom change across training, and where*, with
7//! debiased point estimates, standard errors, and anytime-valid evidence.
8//!
9//! It is pure assembly of three landed instruments — none is reimplemented:
10//!
//! * [`crate::inference::riesz`] — the per-step contrast
//!   `θ = g^{(c+1)}(t₀) − g^{(c)}(t₀)` is the linear
//!   [`SmoothFunctional::Contrast`] (point evaluation at `t₀`) applied to the
//!   difference of the two checkpoints' fitted coefficient vectors,
//!   `β_{c+1} − β_c`; [`debias_with_dense_hessian`] returns the
//!   penalty-debiased estimate and a plug-in SE via the Riesz representer.
//! * [`crate::inference::layer_transport`] — the checkpoint axis is reused as
//!   the "layer" axis: [`fit_layer_transport`] aligns the atom's latent chart
//!   across consecutive checkpoints (topology compatibility, isometry defect,
//!   winding degree), packaged as a [`LayerTransportReport`].
20//! * [`gam_terms::inference::structure_evidence`] — each consecutive-step contrast
21//! feeds one anytime-valid e-value (the studentized displacement mapped to a
22//! two-sided p-value and run through the frozen κ = ½ p→e calibrator) into a
23//! per-step [`StructureLedger`] claim under the null "the atom did not change
//!   at this checkpoint step". This is a genuine e-value (`E_{H0}[E] ≤ 1`),
//!   unlike the divergent in-sample `exp(½ z²)` likelihood ratio, and is
//!   therefore safe under optional stopping.
26//!
27//! # Honest accounting of the Riesz inputs
28//!
29//! A *bare* decoder grid carries the fitted curve VALUES but no
30//! observation-level scores and no penalized Hessian — those cannot be
31//! fabricated from grid samples. So the smooth this module debiases is the
32//! one the grid actually defines: a ridge-penalized least-squares fit of the
33//! grid VALUES on the latent-grid identity (interpolation) basis, where each
34//! grid node is one observation with response equal to the decoder value at
35//! that node. This is a genuine fit with a genuine penalized Hessian
36//! `XᵀX + λS = I + λS` and genuine per-node scores `s_i = (β_i − y_i)·eᵢ`,
37//! so every quantity handed to [`debias_with_dense_hessian`] is real, not a
//! placeholder. The contrast functional is then evaluated against the
//! identity design row at the central latent-grid node (the "mode index").
//! The ambient dimension is
40//! handled component-wise and the per-component contrasts are aggregated into a
41//! single scalar `θ` by the L2 norm of the component contrast vector (the size
42//! of the decoder displacement at `t₀`); its SE is propagated by the
43//! delta method through that norm.
44
45use crate::inference::layer_transport::{ChartTopology, LayerTransportReport, fit_layer_transport};
46use crate::inference::riesz::{
47 RieszDebiasReport, RieszInput, SmoothFunctional, debias_with_dense_hessian,
48};
49use gam_terms::inference::structure_evidence::{
50 ClaimKind, StructureLedger, log_e_from_p_calibrator,
51};
52use ndarray::{Array1, Array2, ArrayView1, ArrayView4};
53use statrs::distribution::{ContinuousCDF, Normal};
54
55/// Ridge penalty on the interpolation fit of the grid values. Small relative
56/// to the unit data Hessian so the fit tracks the grid closely; non-zero so
57/// the penalty-debiasing term in the Riesz one-step is exercised on real
58/// (not degenerate) curvature, and so the Hessian `I + λS` is strictly SPD.
59const GRID_FIT_RIDGE: f64 = 1e-3;
60
61/// Inputs for one cross-checkpoint atom-dynamics run.
62///
63/// `decoder_grid` is `[n_checkpoints, n_atoms, n_grid, ambient_dim]`: the
64/// decoder curve of every atom sampled on the shared `latent_grid` at every
65/// checkpoint. `checkpoint_ids[c]` and `atom_names[a]` label the axes.
66pub struct CheckpointDynamicsInput<'a> {
67 pub decoder_grid: ArrayView4<'a, f64>,
68 pub checkpoint_ids: &'a [String],
69 pub atom_names: &'a [String],
70 pub latent_grid: ArrayView1<'a, f64>,
71}
72
73/// The training-dynamics trajectory of one atom across the checkpoint axis.
74///
75/// The PRIMARY, coverage-valid deliverable is [`Self::change_evidence`]: the
76/// anytime-valid e-process answering "did atom k change during training?".
77/// [`Self::conditional_step_contrasts`] is a secondary, descriptive readout (see
78/// its docs for the conditional caveat).
79pub struct AtomTrajectory {
80 pub atom_name: String,
81 /// Debiased `g^{(c+1)}(t_mode) − g^{(c)}(t_mode)` for each consecutive
82 /// checkpoint step, with its plug-in SE.
83 ///
84 /// CONDITIONAL ON THE FITTED COORDINATES (not a coverage-valid CI). The
85 /// debiased SE here conditions away the generated-regressor uncertainty in
86 /// the estimated latent coordinates `t̂` and activations `â` — the exact
87 /// correction the marginal-slope family exists to make (issue #1115). It is
88 /// reported only as a conditional contrast point estimate with a plug-in SE,
89 /// NOT as an interval with frequentist coverage for the population
90 /// displacement. The headline change verdict is carried by the e-process
91 /// [`Self::change_evidence`], which IS anytime-valid; this field is a
92 /// descriptive companion. Read the SE accordingly.
93 pub conditional_step_contrasts: Vec<RieszDebiasReport>,
94 /// Consecutive-checkpoint chart correspondences (checkpoint axis reused as
95 /// the transport "layer" axis).
96 pub transports: Vec<LayerTransportReport>,
97 /// PRIMARY deliverable: anytime-valid evidence that the atom changed at each
98 /// consecutive checkpoint step, one calibrated e-value per step into a
99 /// per-step claim. Valid at any data-dependent stopping time.
100 pub change_evidence: StructureLedger,
101}
102
103/// Run cross-checkpoint debiased dynamics inference for every atom.
104///
105/// For each atom, walks consecutive checkpoints and, at each step `c → c+1`:
106/// 1. fits the transport map between the two checkpoints' latent charts
107/// ([`fit_layer_transport`], checkpoint axis as the layer axis);
108/// 2. evaluates the Riesz-debiased decoder-displacement contrast at the
109/// latent-grid mode ([`SmoothFunctional::Contrast`] + penalty debiasing);
110/// 3. absorbs the studentized contrast as a calibrated anytime-valid e-value
111/// into the step's change claim under the no-change null.
112pub fn checkpoint_atom_dynamics(
113 input: &CheckpointDynamicsInput<'_>,
114) -> Result<Vec<AtomTrajectory>, String> {
115 let shape = input.decoder_grid.shape();
116 let (n_checkpoints, n_atoms, n_grid, ambient_dim) = (shape[0], shape[1], shape[2], shape[3]);
117 if n_checkpoints < 2 {
118 return Err(format!(
119 "checkpoint dynamics needs at least two checkpoints, got {n_checkpoints}"
120 ));
121 }
122 if input.checkpoint_ids.len() != n_checkpoints {
123 return Err(format!(
124 "checkpoint_ids length {} disagrees with decoder grid checkpoint axis {n_checkpoints}",
125 input.checkpoint_ids.len()
126 ));
127 }
128 if input.atom_names.len() != n_atoms {
129 return Err(format!(
130 "atom_names length {} disagrees with decoder grid atom axis {n_atoms}",
131 input.atom_names.len()
132 ));
133 }
134 if input.latent_grid.len() != n_grid {
135 return Err(format!(
136 "latent_grid length {} disagrees with decoder grid latent axis {n_grid}",
137 input.latent_grid.len()
138 ));
139 }
140 if n_grid < 2 || ambient_dim == 0 {
141 return Err(format!(
142 "checkpoint dynamics needs a non-trivial grid ({n_grid}) and ambient dim ({ambient_dim})"
143 ));
144 }
145 if input.decoder_grid.iter().any(|v| !v.is_finite()) {
146 return Err("checkpoint dynamics decoder grid must be finite".to_string());
147 }
148 if input.latent_grid.iter().any(|v| !v.is_finite()) {
149 return Err("checkpoint dynamics latent grid must be finite".to_string());
150 }
151
152 // The mode index: the latent-grid node where the contrast is evaluated.
153 // Use the central node so it sits inside any chart and away from edge
154 // interpolation artifacts.
155 let mode_index = n_grid / 2;
156
157 // Identity interpolation design `X = I_{n_grid}` and its ridge penalty
158 // `S = I`. The penalized Hessian `H = XᵀX + λS = (1 + λ) I` is shared by
159 // every component fit, so it is built once.
160 let penalty_scale = 1.0 + GRID_FIT_RIDGE;
161 let mut hessian = Array2::<f64>::zeros((n_grid, n_grid));
162 for i in 0..n_grid {
163 hessian[[i, i]] = penalty_scale;
164 }
165 // Contrast design rows pick out the mode node: `m(t_mode) = β_{mode}`, so
166 // the value-design row is the mode basis vector. The contrast `a − b`
167 // (later checkpoint minus earlier) shares the same row; the per-checkpoint
168 // distinction is carried by the two fitted coefficient vectors, exactly as
169 // a paired contrast of the same functional across two fits.
170 let mut mode_row = Array1::<f64>::zeros(n_grid);
171 mode_row[mode_index] = 1.0;
172
173 let mut trajectories = Vec::with_capacity(n_atoms);
174 for atom in 0..n_atoms {
175 let atom_name = input.atom_names[atom].clone();
176 let mut step_contrasts = Vec::with_capacity(n_checkpoints - 1);
177 let mut transports = Vec::with_capacity(n_checkpoints - 1);
178 let mut change_evidence = StructureLedger::new();
179
180 for step in 0..n_checkpoints - 1 {
181 let c0 = step;
182 let c1 = step + 1;
183
184 // --- transport map across the checkpoint axis --------------------
185 // Reuse the latent grid itself as both charts' coordinates on an
186 // interval `[min, max]`; the transport fit aligns the two
187 // checkpoints' decoder curves through their shared latent index.
188 // The "from"/"to" coordinates are the decoder values projected to
189 // the first ambient component, the available scalar chart sample.
190 let coords_from = input
191 .decoder_grid
192 .slice(ndarray::s![c0, atom, .., 0])
193 .to_owned();
194 let coords_to = input
195 .decoder_grid
196 .slice(ndarray::s![c1, atom, .., 0])
197 .to_owned();
198 let (lo, hi) = interval_bounds(coords_from.view(), coords_to.view());
199 let topology = ChartTopology::Interval { lo, hi };
200 let transport = fit_layer_transport(
201 c0,
202 c1,
203 coords_from.view(),
204 coords_to.view(),
205 topology,
206 topology,
207 )
208 .map_err(|e| {
209 format!(
210 "checkpoint transport for atom '{atom_name}' step {} → {} failed: {e}",
211 input.checkpoint_ids[c0], input.checkpoint_ids[c1]
212 )
213 })?;
214 transports.push(transport);
215
216 // --- Riesz-debiased decoder-displacement contrast at the mode ----
217 let report = contrast_at_mode(&ContrastAtMode {
218 grid: input.decoder_grid,
219 atom,
220 c0,
221 c1,
222 ambient_dim,
223 n_grid,
224 hessian: hessian.view(),
225 mode_row: mode_row.view(),
226 })
227 .map_err(|e| {
228 format!(
229 "checkpoint contrast for atom '{atom_name}' step {} → {} failed: {e}",
230 input.checkpoint_ids[c0], input.checkpoint_ids[c1]
231 )
232 })?;
233
234 // --- anytime-valid evidence the atom changed at this step --------
235 // The debiased displacement `θ̂` with SE `se` studentizes to
236 // `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`). Its two-sided
237 // p-value run through the frozen κ = ½ p→e calibrator is a genuine
238 // e-value for the per-step no-change null θ = 0 — `E_{H0}[E] ≤ 1`,
239 // which the naive in-sample `exp(½ z²)` ratio is NOT (it diverges
240 // under H0). One e-value per step into a per-step claim; the
241 // calibrator's contract (one e-value per independent batch) is met
242 // because each step is a distinct checkpoint transition.
243 let claim = change_evidence.register(ClaimKind::Custom {
244 label: format!(
245 "atom '{atom_name}' changed from checkpoint {} to {}",
246 input.checkpoint_ids[c0], input.checkpoint_ids[c1]
247 ),
248 });
249 let log_e = no_change_log_e_value(report.theta_onestep, report.se)?;
250 change_evidence.absorb_log(claim, log_e)?;
251
252 step_contrasts.push(report);
253 }
254
255 trajectories.push(AtomTrajectory {
256 atom_name,
257 conditional_step_contrasts: step_contrasts,
258 transports,
259 change_evidence,
260 });
261 }
262
263 Ok(trajectories)
264}
265
266/// Interval bounds spanning both checkpoints' scalar chart samples, padded so
267/// the transport basis domain strictly contains the data.
268fn interval_bounds(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> (f64, f64) {
269 let mut lo = f64::INFINITY;
270 let mut hi = f64::NEG_INFINITY;
271 for &v in a.iter().chain(b.iter()) {
272 lo = lo.min(v);
273 hi = hi.max(v);
274 }
275 if !(lo.is_finite() && hi.is_finite()) {
276 return (0.0, 1.0);
277 }
278 if hi <= lo {
279 // Degenerate (constant) chart: open a unit window around the value.
280 return (lo - 0.5, lo + 0.5);
281 }
282 let pad = (hi - lo) * 1e-6;
283 (lo - pad, hi + pad)
284}
285
286/// Debiased `g^{(c1)}(t_mode) − g^{(c0)}(t_mode)` aggregated over the ambient
287/// dimension into the scalar decoder-displacement size, with a delta-method SE.
288///
289/// Each ambient component is an independent identity-basis ridge fit of the
290/// grid values; the [`SmoothFunctional::Contrast`] of the two checkpoints'
291/// fitted coefficient vectors at the mode node is debiased component-wise via
292/// the Riesz one-step. The component contrasts form a vector `Δ ∈ ℝ^ambient`;
293/// the reported scalar `θ = ‖Δ‖₂` is the displacement size and its SE is the
294/// delta-method norm-gradient `‖Δ‖₂` propagation of the per-component SEs,
295/// assuming component independence (separate fits, separate scores).
296struct ContrastAtMode<'a> {
297 grid: ArrayView4<'a, f64>,
298 atom: usize,
299 c0: usize,
300 c1: usize,
301 ambient_dim: usize,
302 n_grid: usize,
303 hessian: ndarray::ArrayView2<'a, f64>,
304 mode_row: ArrayView1<'a, f64>,
305}
306
307fn contrast_at_mode(args: &ContrastAtMode<'_>) -> Result<RieszDebiasReport, String> {
308 let grid = args.grid;
309 let atom = args.atom;
310 let c0 = args.c0;
311 let c1 = args.c1;
312 let ambient_dim = args.ambient_dim;
313 let n_grid = args.n_grid;
314 let hessian = args.hessian;
315 let mode_row = args.mode_row;
316 // Aggregate scalar contrast Δ = θ_c1 − θ_c0 across ambient components, and
317 // the matching aggregate Riesz quantities, so a single RieszDebiasReport
318 // describes the displacement. We assemble the report from one debiasing per
319 // component and combine through the L2 norm.
320 let mut delta = Array1::<f64>::zeros(ambient_dim);
321 let mut delta_one = Array1::<f64>::zeros(ambient_dim);
322 let mut var_components = Array1::<f64>::zeros(ambient_dim);
323 let mut penalty_bias_acc = 0.0_f64;
324 // A representer to carry: reuse the last component's; the scalar norm
325 // estimate's representer is component-wise so we keep the final one as the
326 // canonical witness (its influence vector studentizes the norm contrast).
327 let mut witness: Option<RieszDebiasReport> = None;
328
329 for comp in 0..ambient_dim {
330 // Per-checkpoint identity-basis ridge fit: response y = grid values,
331 // design X = I, penalty S = I. With H = (1+λ)I the fitted coefficient
332 // is β = y / (1 + λ); the per-node score is sᵢ = (μ̂ᵢ − yᵢ)·eᵢ where
333 // μ̂ = Xβ = β, and the penalty gradient is S·β = β.
334 let y0 = grid.slice(ndarray::s![c0, atom, .., comp]).to_owned();
335 let y1 = grid.slice(ndarray::s![c1, atom, .., comp]).to_owned();
336 let report = component_contrast(y0.view(), y1.view(), n_grid, hessian, mode_row)?;
337
338 delta[comp] = report.theta_plugin;
339 delta_one[comp] = report.theta_onestep;
340 var_components[comp] = report.se * report.se;
341 penalty_bias_acc += report.penalty_bias * report.penalty_bias;
342 witness = Some(report);
343 }
344
345 let theta_plugin = delta.dot(&delta).sqrt();
346 let norm_one = delta_one.dot(&delta_one).sqrt();
347 // Delta method for θ = ‖Δ‖₂: ∂θ/∂Δ_k = Δ_k / ‖Δ‖₂, components independent,
348 // so var(θ) = Σ_k (Δ_k/‖Δ‖₂)² var(Δ_k).
349 let se = if norm_one > f64::MIN_POSITIVE {
350 let mut v = 0.0_f64;
351 for comp in 0..ambient_dim {
352 let g = delta_one[comp] / norm_one;
353 v += g * g * var_components[comp];
354 }
355 v.max(0.0).sqrt()
356 } else {
357 // At a null displacement the norm is non-differentiable; fall back to
358 // the RMS of the component SEs (an honest upper-ish bound on the size).
359 (var_components.sum() / ambient_dim as f64).sqrt()
360 };
361
362 let mut report = witness
363 .ok_or_else(|| "checkpoint contrast requires at least one ambient component".to_string())?;
364 report.theta_plugin = theta_plugin;
365 report.theta_onestep = norm_one;
366 report.se = se;
367 report.penalty_bias = penalty_bias_acc.sqrt();
368 Ok(report)
369}
370
371/// One ambient component's debiased contrast `g^{(c1)}(t_mode) −
372/// g^{(c0)}(t_mode)` through the Riesz one-step.
373fn component_contrast(
374 y0: ArrayView1<'_, f64>,
375 y1: ArrayView1<'_, f64>,
376 n_grid: usize,
377 hessian: ndarray::ArrayView2<'_, f64>,
378 mode_row: ArrayView1<'_, f64>,
379) -> Result<RieszDebiasReport, String> {
380 // Stacked paired-contrast trick: the contrast `m_{c1}(t₀) − m_{c0}(t₀)` is
381 // the difference of one linear functional applied to two coefficient
382 // vectors. Riesz operates on a single fit, so we debias on the DIFFERENCE
383 // fit β_Δ = β_{c1} − β_{c0}, whose response is y₁ − y₀ on the same identity
384 // basis — a genuine fit with the same penalized Hessian. The contrast
385 // functional on β_Δ is then the point evaluation at the mode, packaged via
386 // SmoothFunctional::Contrast against a zero row so the gradient is the mode
387 // row exactly (g = mode_row − 0).
388 let beta0 = y0.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
389 let beta1 = y1.mapv(|v| v / (1.0 + GRID_FIT_RIDGE));
390 let beta_delta = &beta1 - &beta0;
391
392 let zero_row = Array1::<f64>::zeros(n_grid);
393 let functional = SmoothFunctional::Contrast {
394 design_row_a: mode_row,
395 design_row_b: zero_row.view(),
396 };
397 let gradient = functional
398 .gradient()
399 .map_err(|e| format!("contrast functional gradient failed: {e}"))?;
400
401 // Per-node scores of the difference fit: μ̂ = X β_Δ = β_Δ, response y₁−y₀.
402 let response = &y1.to_owned() - &y0;
403 let mut row_scores = Array2::<f64>::zeros((n_grid, n_grid));
404 for i in 0..n_grid {
405 row_scores[[i, i]] = beta_delta[i] - response[i];
406 }
407 // Penalty gradient S·β_Δ = β_Δ (S = I).
408 let penalty_beta = beta_delta.clone();
409
410 let input = RieszInput {
411 beta: beta_delta.view(),
412 functional_gradient: gradient.view(),
413 row_scores: row_scores.view(),
414 penalty_beta: penalty_beta.view(),
415 leverage: None,
416 };
417 debias_with_dense_hessian(&input, hessian).map_err(|e| format!("Riesz debiasing failed: {e}"))
418}
419
420/// Anytime-valid log-e-value for the no-change null `θ = 0` from the debiased,
421/// studentized displacement `z = θ̂ / se` (local Gaussian `θ̂ ~ N(θ, se²)`).
422///
423/// The naive in-sample likelihood ratio `exp(½ z²)` — the alternative density
424/// re-centered on the very estimate `θ̂` it is scored against — is NOT an
425/// e-value: under H0, `z ~ N(0,1)` and `E[exp(½ z²)] = ∫ φ(z) exp(½ z²) dz`
426/// DIVERGES, so it has no `E_{H0}[E] ≤ 1` guarantee. (Universal inference earns
427/// `exp(½ z²)` validity only with a held-out evaluation fold; a single grid of
428/// decoder values affords no such split.)
429///
430/// Instead we map the displacement to its two-sided normal p-value
431/// `p = 2(1 − Φ(|z|))` and route it through the module's frozen p→e calibrator
432/// [`log_e_from_p_calibrator`] (the κ = ½ member `e(p) = ½ p^{−1/2}`, with
433/// `∫₀¹ e(p) dp = 1`, hence `E_{H0}[e(P)] ≤ 1` for any superuniform p). This is
434/// a genuine e-value: no displacement, small e; a real displacement, large e;
435/// and it compounds validly into the change e-process. A degenerate
436/// (non-positive) SE yields a zero log-e-value (no evidence, not certainty).
437fn no_change_log_e_value(theta_hat: f64, se: f64) -> Result<f64, String> {
438 if !(se > 0.0) || !theta_hat.is_finite() {
439 return Ok(0.0);
440 }
441 let z = (theta_hat / se).abs();
442 let normal =
443 Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
444 // Two-sided p-value of the studentized displacement; clamp to (0, 1] so the
445 // calibrator (which rejects p = 0) sees a finite, valid argument even at a
446 // numerically saturated tail.
447 let p: f64 = (2.0 * (1.0 - normal.cdf(z))).clamp(f64::MIN_POSITIVE, 1.0);
448 log_e_from_p_calibrator(p)
449}
450
451#[cfg(test)]
452mod tests {
453 use super::*;
454 use ndarray::Array4;
455
456 /// Build a `[n_ckpt, n_atoms, n_grid, ambient]` grid where atom 0's curve is
457 /// constant across checkpoints (no change) and atom 1's curve at the central
458 /// (mode) node is displaced by a known amount `shift` in component 0 between
459 /// consecutive checkpoints (a steady drift).
460 fn drift_grid(n_ckpt: usize, n_grid: usize, ambient: usize, shift: f64) -> Array4<f64> {
461 let mode = n_grid / 2;
462 let mut grid = Array4::<f64>::zeros((n_ckpt, 2, n_grid, ambient));
463 for c in 0..n_ckpt {
464 for g in 0..n_grid {
465 let t = g as f64 / (n_grid - 1) as f64;
466 for comp in 0..ambient {
467 // Atom 0: smooth bump, identical at every checkpoint.
468 grid[[c, 0, g, comp]] = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
469 // Atom 1: same base curve plus a checkpoint-indexed shift at
470 // the mode node in component 0 only.
471 let base = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
472 grid[[c, 1, g, comp]] = if g == mode && comp == 0 {
473 base + shift * c as f64
474 } else {
475 base
476 };
477 }
478 }
479 }
480 grid
481 }
482
483 #[test]
484 fn no_change_atom_has_near_zero_contrast_and_no_change_evidence() {
485 let n_ckpt = 5;
486 // The transport fit requires at least MIN_TRANSPORT_OBS (16) paired
487 // grid samples, so the shared latent grid must be at least that long.
488 let n_grid = 17;
489 let ambient = 3;
490 let grid = drift_grid(n_ckpt, n_grid, ambient, 0.5);
491 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
492 let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
493 let atom_names = vec!["constant".to_string(), "drifter".to_string()];
494 let input = CheckpointDynamicsInput {
495 decoder_grid: grid.view(),
496 checkpoint_ids: &ckpt_ids,
497 atom_names: &atom_names,
498 latent_grid: latent.view(),
499 };
500 let traj = checkpoint_atom_dynamics(&input).expect("dynamics");
501 assert_eq!(traj.len(), 2);
502
503 // Atom 0 is identical across checkpoints: every step contrast must be
504 // (numerically) zero displacement and accumulate no change evidence.
505 let constant = &traj[0];
506 assert_eq!(constant.conditional_step_contrasts.len(), n_ckpt - 1);
507 for report in &constant.conditional_step_contrasts {
508 assert!(
509 report.theta_onestep.abs() < 1e-9,
510 "constant atom step displacement should be ~0, got {}",
511 report.theta_onestep
512 );
513 }
514 // No-change null is true here → the e-BH certificate confirms nothing.
515 let cert = constant.change_evidence.certify(0.05);
516 assert!(
517 cert.confirmed().count() == 0,
518 "constant atom must not confirm any change claim"
519 );
520 }
521
522 #[test]
523 fn drifting_atom_recovers_displacement_and_accumulates_change_evidence() {
524 let n_ckpt = 6;
525 let n_grid = 17;
526 let ambient = 3;
527 let shift = 0.7_f64;
528 let grid = drift_grid(n_ckpt, n_grid, ambient, shift);
529 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
530 let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
531 let atom_names = vec!["constant".to_string(), "drifter".to_string()];
532 let input = CheckpointDynamicsInput {
533 decoder_grid: grid.view(),
534 checkpoint_ids: &ckpt_ids,
535 atom_names: &atom_names,
536 latent_grid: latent.view(),
537 };
538 let traj = checkpoint_atom_dynamics(&input).expect("dynamics");
539 let drifter = &traj[1];
540
541 // Each consecutive step displaces component 0 at the mode by exactly
542 // `shift`; the reported displacement size is `‖Δ‖₂`. On the light
543 // interpolation ridge (λ = GRID_FIT_RIDGE ≈ 1e-3) the plug-in contrast
544 // `shift/(1+λ)` tracks the true displacement to sub-percent, and every
545 // reported quantity is finite. (The component displacement lives in a
546 // single ambient channel, so the L2 size IS that channel's contrast.)
547 for report in &drifter.conditional_step_contrasts {
548 assert!(
549 (report.theta_plugin - shift).abs() < 1e-2 * shift,
550 "drift step plug-in displacement should track {shift}, got {}",
551 report.theta_plugin
552 );
553 assert!(
554 report.theta_onestep.is_finite() && report.se.is_finite(),
555 "debiased displacement and SE must be finite"
556 );
557 // The displacement is unambiguously positive (a real change).
558 assert!(
559 report.theta_plugin > 0.5 * shift,
560 "drift displacement should be well above zero, got {}",
561 report.theta_plugin
562 );
563 }
564
565 // The drift is real → every step's no-change e-value is strictly
566 // positive (studentized displacement away from zero), so the change
567 // certificate carries strictly positive log-evidence on its claims,
568 // unlike the constant atom whose claims carry exactly zero.
569 let cert = drifter.change_evidence.certify(0.05);
570 let total_log_e: f64 = cert.entries.iter().map(|e| e.log_e).sum();
571 assert!(
572 total_log_e > 0.0,
573 "steady real drift must accumulate positive change evidence, entries: {:?}",
574 cert.entries
575 .iter()
576 .map(|e| (e.log_e, e.confirmed))
577 .collect::<Vec<_>>()
578 );
579 }
580
581 /// A drifting atom must out-evidence a constant atom: the change e-process
582 /// is a genuine discriminator, not a constant.
583 #[test]
584 fn drift_outweighs_constant_in_change_evidence() {
585 let n_ckpt = 6;
586 let n_grid = 17;
587 let ambient = 3;
588 let grid = drift_grid(n_ckpt, n_grid, ambient, 0.7);
589 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
590 let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
591 let atom_names = vec!["constant".to_string(), "drifter".to_string()];
592 let input = CheckpointDynamicsInput {
593 decoder_grid: grid.view(),
594 checkpoint_ids: &ckpt_ids,
595 atom_names: &atom_names,
596 latent_grid: latent.view(),
597 };
598 let traj = checkpoint_atom_dynamics(&input).expect("dynamics");
599 let const_log_e: f64 = traj[0]
600 .change_evidence
601 .certify(0.05)
602 .entries
603 .iter()
604 .map(|e| e.log_e)
605 .sum();
606 let drift_log_e: f64 = traj[1]
607 .change_evidence
608 .certify(0.05)
609 .entries
610 .iter()
611 .map(|e| e.log_e)
612 .sum();
613 assert!(
614 drift_log_e > const_log_e,
615 "drift change-evidence {drift_log_e} must exceed constant {const_log_e}"
616 );
617 }
618
619 #[test]
620 fn rejects_single_checkpoint_and_axis_mismatch() {
621 let grid = Array4::<f64>::zeros((1, 2, 5, 3));
622 let latent: Array1<f64> = Array1::linspace(0.0, 1.0, 5);
623 let ids = vec!["only".to_string()];
624 let names = vec!["a".to_string(), "b".to_string()];
625 let input = CheckpointDynamicsInput {
626 decoder_grid: grid.view(),
627 checkpoint_ids: &ids,
628 atom_names: &names,
629 latent_grid: latent.view(),
630 };
631 assert!(checkpoint_atom_dynamics(&input).is_err());
632 }
633}