use crate::inference::layer_transport::{ChartTopology, LayerTransportReport, fit_layer_transport};
use crate::inference::riesz::{
RieszDebiasReport, RieszInput, SmoothFunctional, debias_with_dense_hessian,
};
use crate::inference::structure_evidence::{ClaimKind, StructureLedger, log_e_from_p_calibrator};
use ndarray::{Array1, Array2, ArrayView1, ArrayView4};
use statrs::distribution::{ContinuousCDF, Normal};
/// Shrinkage constant of the per-component grid fit.
///
/// `component_contrast` fits coefficients as `y / (1 + GRID_FIT_RIDGE)`, and
/// `checkpoint_atom_dynamics` hands the debiaser the matching dense Hessian
/// `(1 + GRID_FIT_RIDGE) · I`.
const GRID_FIT_RIDGE: f64 = 1e-3;
/// Decoded atom curves for a sequence of checkpoints, plus their labels.
///
/// Every slice must match the corresponding axis of `decoder_grid`;
/// [`checkpoint_atom_dynamics`] validates this before doing any work.
/// The type only holds borrowed views and slices, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct CheckpointDynamicsInput<'a> {
    /// Decoder evaluations indexed `[checkpoint, atom, grid point, ambient component]`.
    pub decoder_grid: ArrayView4<'a, f64>,
    /// One identifier per checkpoint (axis 0); used in error messages and claim labels.
    pub checkpoint_ids: &'a [String],
    /// One name per atom (axis 1).
    pub atom_names: &'a [String],
    /// One latent coordinate per grid point (axis 2).
    pub latent_grid: ArrayView1<'a, f64>,
}
/// Per-atom output of `checkpoint_atom_dynamics`, with one entry per consecutive
/// checkpoint pair `c → c + 1` in each vector.
pub struct AtomTrajectory {
/// Atom name copied from `CheckpointDynamicsInput::atom_names`.
pub atom_name: String,
/// Debiased contrast of the atom at the middle grid point for each step. The
/// `theta_plugin` / `theta_onestep` / `se` / `penalty_bias` fields hold norms
/// aggregated over ambient components; the remaining report fields come from the
/// last ambient component.
pub conditional_step_contrasts: Vec<RieszDebiasReport>,
/// Layer transport fitted between the atom's ambient component 0 at `c` and `c + 1`.
pub transports: Vec<LayerTransportReport>,
/// Ledger holding one "atom changed" claim per step, each absorbing a log e-value
/// computed against the no-change null.
pub change_evidence: StructureLedger,
}
pub fn checkpoint_atom_dynamics(
input: &CheckpointDynamicsInput<'_>,
) -> Result<Vec<AtomTrajectory>, String> {
let shape = input.decoder_grid.shape();
let (n_checkpoints, n_atoms, n_grid, ambient_dim) = (shape[0], shape[1], shape[2], shape[3]);
if n_checkpoints < 2 {
return Err(format!(
"checkpoint dynamics needs at least two checkpoints, got {n_checkpoints}"
));
}
if input.checkpoint_ids.len() != n_checkpoints {
return Err(format!(
"checkpoint_ids length {} disagrees with decoder grid checkpoint axis {n_checkpoints}",
input.checkpoint_ids.len()
));
}
if input.atom_names.len() != n_atoms {
return Err(format!(
"atom_names length {} disagrees with decoder grid atom axis {n_atoms}",
input.atom_names.len()
));
}
if input.latent_grid.len() != n_grid {
return Err(format!(
"latent_grid length {} disagrees with decoder grid latent axis {n_grid}",
input.latent_grid.len()
));
}
if n_grid < 2 || ambient_dim == 0 {
return Err(format!(
"checkpoint dynamics needs a non-trivial grid ({n_grid}) and ambient dim ({ambient_dim})"
));
}
if input.decoder_grid.iter().any(|v| !v.is_finite()) {
return Err("checkpoint dynamics decoder grid must be finite".to_string());
}
if input.latent_grid.iter().any(|v| !v.is_finite()) {
return Err("checkpoint dynamics latent grid must be finite".to_string());
}
let mode_index = n_grid / 2;
let penalty_scale = 1.0 + GRID_FIT_RIDGE;
let mut hessian = Array2::<f64>::zeros((n_grid, n_grid));
for i in 0..n_grid {
hessian[[i, i]] = penalty_scale;
}
let mut mode_row = Array1::<f64>::zeros(n_grid);
mode_row[mode_index] = 1.0;
let mut trajectories = Vec::with_capacity(n_atoms);
for atom in 0..n_atoms {
let atom_name = input.atom_names[atom].clone();
let mut step_contrasts = Vec::with_capacity(n_checkpoints - 1);
let mut transports = Vec::with_capacity(n_checkpoints - 1);
let mut change_evidence = StructureLedger::new();
for step in 0..n_checkpoints - 1 {
let c0 = step;
let c1 = step + 1;
let coords_from = input
.decoder_grid
.slice(ndarray::s![c0, atom, .., 0])
.to_owned();
let coords_to = input
.decoder_grid
.slice(ndarray::s![c1, atom, .., 0])
.to_owned();
let (lo, hi) = interval_bounds(coords_from.view(), coords_to.view());
let topology = ChartTopology::Interval { lo, hi };
let transport = fit_layer_transport(
c0,
c1,
coords_from.view(),
coords_to.view(),
topology,
topology,
)
.map_err(|e| {
format!(
"checkpoint transport for atom '{atom_name}' step {} → {} failed: {e}",
input.checkpoint_ids[c0], input.checkpoint_ids[c1]
)
})?;
transports.push(transport);
let report = contrast_at_mode(&ContrastAtMode {
grid: input.decoder_grid,
atom,
c0,
c1,
ambient_dim,
n_grid,
hessian: hessian.view(),
mode_row: mode_row.view(),
})
.map_err(|e| {
format!(
"checkpoint contrast for atom '{atom_name}' step {} → {} failed: {e}",
input.checkpoint_ids[c0], input.checkpoint_ids[c1]
)
})?;
let claim = change_evidence.register(ClaimKind::Custom {
label: format!(
"atom '{atom_name}' changed from checkpoint {} to {}",
input.checkpoint_ids[c0], input.checkpoint_ids[c1]
),
});
let log_e = no_change_log_e_value(report.theta_onestep, report.se)?;
change_evidence.absorb_log(claim, log_e)?;
step_contrasts.push(report);
}
trajectories.push(AtomTrajectory {
atom_name,
conditional_step_contrasts: step_contrasts,
transports,
change_evidence,
});
}
Ok(trajectories)
}
/// Returns an interval `(lo, hi)` enclosing every value of `a` and `b`.
///
/// The observed range is widened on both sides by `1e-6` times its width. If all
/// values coincide, the interval is that value `± 0.5`. If no finite extent exists
/// (e.g. both views are empty, or an infinite value is present), it falls back to
/// `(0.0, 1.0)`. NaN entries are ignored by the min/max scan.
fn interval_bounds(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> (f64, f64) {
    let (min, max) = a
        .iter()
        .chain(b.iter())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &value| {
            (min.min(value), max.max(value))
        });
    if !(min.is_finite() && max.is_finite()) {
        (0.0, 1.0)
    } else if max > min {
        let pad = (max - min) * 1e-6;
        (min - pad, max + pad)
    } else {
        // Degenerate (single-valued) range: give the chart unit width around it.
        (min - 0.5, min + 0.5)
    }
}
/// Argument bundle for `contrast_at_mode`.
struct ContrastAtMode<'a> {
/// Full decoder grid indexed `[checkpoint, atom, grid point, ambient component]`.
grid: ArrayView4<'a, f64>,
/// Atom index (axis 1 of `grid`).
atom: usize,
/// Earlier checkpoint index of the contrast.
c0: usize,
/// Later checkpoint index of the contrast.
c1: usize,
/// Number of ambient components (axis 3) aggregated over.
ambient_dim: usize,
/// Number of grid points (axis 2).
n_grid: usize,
/// Dense Hessian forwarded to the debiaser; `checkpoint_atom_dynamics` passes
/// `(1 + GRID_FIT_RIDGE) · I`.
hessian: ndarray::ArrayView2<'a, f64>,
/// Design row selecting the grid point at which the contrast is evaluated.
mode_row: ArrayView1<'a, f64>,
}
fn contrast_at_mode(args: &ContrastAtMode<'_>) -> Result<RieszDebiasReport, String> {
let grid = args.grid;
let atom = args.atom;
let c0 = args.c0;
let c1 = args.c1;
let ambient_dim = args.ambient_dim;
let n_grid = args.n_grid;
let hessian = args.hessian;
let mode_row = args.mode_row;
let mut delta = Array1::<f64>::zeros(ambient_dim);
let mut delta_one = Array1::<f64>::zeros(ambient_dim);
let mut var_components = Array1::<f64>::zeros(ambient_dim);
let mut penalty_bias_acc = 0.0_f64;
let mut witness: Option<RieszDebiasReport> = None;
for comp in 0..ambient_dim {
let y0 = grid.slice(ndarray::s![c0, atom, .., comp]).to_owned();
let y1 = grid.slice(ndarray::s![c1, atom, .., comp]).to_owned();
let report = component_contrast(y0.view(), y1.view(), n_grid, hessian, mode_row)?;
delta[comp] = report.theta_plugin;
delta_one[comp] = report.theta_onestep;
var_components[comp] = report.se * report.se;
penalty_bias_acc += report.penalty_bias * report.penalty_bias;
witness = Some(report);
}
let theta_plugin = delta.dot(&delta).sqrt();
let norm_one = delta_one.dot(&delta_one).sqrt();
let se = if norm_one > f64::MIN_POSITIVE {
let mut v = 0.0_f64;
for comp in 0..ambient_dim {
let g = delta_one[comp] / norm_one;
v += g * g * var_components[comp];
}
v.max(0.0).sqrt()
} else {
(var_components.sum() / ambient_dim as f64).sqrt()
};
let mut report = witness
.ok_or_else(|| "checkpoint contrast requires at least one ambient component".to_string())?;
report.theta_plugin = theta_plugin;
report.theta_onestep = norm_one;
report.se = se;
report.penalty_bias = penalty_bias_acc.sqrt();
Ok(report)
}
/// Riesz-debiased contrast of a single ambient component between two checkpoints.
///
/// Both curves are fitted on the identity design with shrinkage, i.e. coefficients
/// `y / (1 + GRID_FIT_RIDGE)`. The target functional is `mode_row · (β₁ − β₀)`,
/// expressed as a contrast against an all-zero row. Grid point `i` contributes a
/// score row whose only nonzero entry, on the diagonal, is its fitted-minus-observed
/// difference. The fitted difference also serves as the penalty coefficients.
///
/// # Errors
/// Returns `Err` when the functional gradient or the dense-Hessian debiasing fails.
fn component_contrast(
    y0: ArrayView1<'_, f64>,
    y1: ArrayView1<'_, f64>,
    n_grid: usize,
    hessian: ndarray::ArrayView2<'_, f64>,
    mode_row: ArrayView1<'_, f64>,
) -> Result<RieszDebiasReport, String> {
    let shrink = |v: f64| v / (1.0 + GRID_FIT_RIDGE);
    let fitted_delta = &y1.mapv(shrink) - &y0.mapv(shrink);
    let zero_row = Array1::<f64>::zeros(n_grid);
    let functional = SmoothFunctional::Contrast {
        design_row_a: mode_row,
        design_row_b: zero_row.view(),
    };
    let gradient = functional
        .gradient()
        .map_err(|e| format!("contrast functional gradient failed: {e}"))?;
    let observed_delta = &y1.to_owned() - &y0;
    let mut row_scores = Array2::<f64>::zeros((n_grid, n_grid));
    (0..n_grid).for_each(|i| row_scores[[i, i]] = fitted_delta[i] - observed_delta[i]);
    let input = RieszInput {
        beta: fitted_delta.view(),
        functional_gradient: gradient.view(),
        row_scores: row_scores.view(),
        penalty_beta: fitted_delta.view(),
        leverage: None,
    };
    debias_with_dense_hessian(&input, hessian).map_err(|e| format!("Riesz debiasing failed: {e}"))
}
/// Converts a displacement estimate into a log e-value against the "no change" null.
///
/// Computes the two-sided Wald p-value `p = 2 · Φ(-|theta_hat / se|)` under a standard
/// normal reference, clamps it to `[f64::MIN_POSITIVE, 1]`, and passes it through
/// `log_e_from_p_calibrator`. A non-positive or NaN `se`, or a non-finite `theta_hat`,
/// yields `0.0` without consulting the calibrator.
///
/// # Errors
/// Propagates failures from constructing the standard normal or from the calibrator.
fn no_change_log_e_value(theta_hat: f64, se: f64) -> Result<f64, String> {
    if !(se > 0.0) || !theta_hat.is_finite() {
        return Ok(0.0);
    }
    let z = (theta_hat / se).abs();
    let normal =
        Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
    // Use the lower tail Φ(-z) rather than `1 - Φ(z)`. Once Φ(z) rounds to 1.0
    // (z ≳ 8.3), the subtraction cancels to exactly 0, which collapses p to the clamp
    // floor and wildly overstates evidence. Φ(-z) stays accurate until it underflows.
    let p = (2.0 * normal.cdf(-z)).clamp(f64::MIN_POSITIVE, 1.0);
    log_e_from_p_calibrator(p)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array4;

    /// Builds a `[n_ckpt, 2, n_grid, ambient]` grid of half-sine curves scaled per
    /// component. Atom 0 is identical at every checkpoint. Atom 1 additionally moves
    /// its component-0 value at the middle grid point by `shift` per checkpoint.
    fn drift_grid(n_ckpt: usize, n_grid: usize, ambient: usize, shift: f64) -> Array4<f64> {
        let mode = n_grid / 2;
        let mut grid = Array4::<f64>::zeros((n_ckpt, 2, n_grid, ambient));
        for c in 0..n_ckpt {
            for g in 0..n_grid {
                let t = g as f64 / (n_grid - 1) as f64;
                for comp in 0..ambient {
                    let base = (t * std::f64::consts::PI).sin() * (comp as f64 + 1.0);
                    grid[[c, 0, g, comp]] = base;
                    grid[[c, 1, g, comp]] = if g == mode && comp == 0 {
                        base + shift * c as f64
                    } else {
                        base
                    };
                }
            }
        }
        grid
    }

    /// Runs the dynamics on a `drift_grid`-shaped grid with default ids, names and a
    /// unit latent grid.
    fn dynamics(grid: &Array4<f64>) -> Vec<AtomTrajectory> {
        let (n_ckpt, _, n_grid, _) = grid.dim();
        let latent: Array1<f64> = Array1::linspace(0.0, 1.0, n_grid);
        let ckpt_ids: Vec<String> = (0..n_ckpt).map(|c| format!("dev{c}")).collect();
        let atom_names = vec!["constant".to_string(), "drifter".to_string()];
        let input = CheckpointDynamicsInput {
            decoder_grid: grid.view(),
            checkpoint_ids: &ckpt_ids,
            atom_names: &atom_names,
            latent_grid: latent.view(),
        };
        checkpoint_atom_dynamics(&input).expect("dynamics")
    }

    /// Sum of the certified log e-values of an atom's change claims.
    fn total_log_e(trajectory: &AtomTrajectory) -> f64 {
        trajectory
            .change_evidence
            .certify(0.05)
            .entries
            .iter()
            .map(|e| e.log_e)
            .sum()
    }

    /// Whether `checkpoint_atom_dynamics` rejects the given input combination.
    fn is_rejected(
        grid: &Array4<f64>,
        ids: &[String],
        names: &[String],
        latent: &Array1<f64>,
    ) -> bool {
        checkpoint_atom_dynamics(&CheckpointDynamicsInput {
            decoder_grid: grid.view(),
            checkpoint_ids: ids,
            atom_names: names,
            latent_grid: latent.view(),
        })
        .is_err()
    }

    #[test]
    fn no_change_atom_has_near_zero_contrast_and_no_change_evidence() {
        let n_ckpt = 5;
        let traj = dynamics(&drift_grid(n_ckpt, 17, 3, 0.5));
        assert_eq!(traj.len(), 2);
        let constant = &traj[0];
        assert_eq!(constant.conditional_step_contrasts.len(), n_ckpt - 1);
        for report in &constant.conditional_step_contrasts {
            assert!(
                report.theta_onestep.abs() < 1e-9,
                "constant atom step displacement should be ~0, got {}",
                report.theta_onestep
            );
        }
        let cert = constant.change_evidence.certify(0.05);
        assert_eq!(
            cert.confirmed().count(),
            0,
            "constant atom must not confirm any change claim"
        );
    }

    #[test]
    fn drifting_atom_recovers_displacement_and_accumulates_change_evidence() {
        let n_ckpt = 6;
        let shift = 0.7_f64;
        let traj = dynamics(&drift_grid(n_ckpt, 17, 3, shift));
        let drifter = &traj[1];
        assert_eq!(drifter.conditional_step_contrasts.len(), n_ckpt - 1);
        for report in &drifter.conditional_step_contrasts {
            assert!(
                (report.theta_plugin - shift).abs() < 1e-2 * shift,
                "drift step plug-in displacement should track {shift}, got {}",
                report.theta_plugin
            );
            assert!(
                report.theta_onestep.is_finite() && report.se.is_finite(),
                "debiased displacement and SE must be finite"
            );
        }
        let cert = drifter.change_evidence.certify(0.05);
        let total: f64 = cert.entries.iter().map(|e| e.log_e).sum();
        assert!(
            total > 0.0,
            "steady real drift must accumulate positive change evidence, entries: {:?}",
            cert.entries
                .iter()
                .map(|e| (e.log_e, e.confirmed))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn drift_outweighs_constant_in_change_evidence() {
        let traj = dynamics(&drift_grid(6, 17, 3, 0.7));
        let const_log_e = total_log_e(&traj[0]);
        let drift_log_e = total_log_e(&traj[1]);
        assert!(
            drift_log_e > const_log_e,
            "drift change-evidence {drift_log_e} must exceed constant {const_log_e}"
        );
    }

    #[test]
    fn rejects_single_checkpoint_and_axis_mismatch() {
        let names = vec!["a".to_string(), "b".to_string()];
        let latent: Array1<f64> = Array1::linspace(0.0, 1.0, 5);

        // A single checkpoint defines no step at all.
        let single = Array4::<f64>::zeros((1, 2, 5, 3));
        let one_id = vec!["only".to_string()];
        assert!(is_rejected(&single, &one_id, &names, &latent));

        // Each auxiliary input must match its grid axis.
        let grid = Array4::<f64>::zeros((2, 2, 5, 3));
        let two_ids = vec!["c0".to_string(), "c1".to_string()];
        let one_name = vec!["a".to_string()];
        let short_latent: Array1<f64> = Array1::linspace(0.0, 1.0, 4);
        assert!(is_rejected(&grid, &one_id, &names, &latent));
        assert!(is_rejected(&grid, &two_ids, &one_name, &latent));
        assert!(is_rejected(&grid, &two_ids, &names, &short_latent));

        // Non-finite values are rejected in either the grid or the latent coordinates.
        let mut nan_grid = grid.clone();
        nan_grid[[0, 0, 0, 0]] = f64::NAN;
        assert!(is_rejected(&nan_grid, &two_ids, &names, &latent));
        let mut nan_latent = latent.clone();
        nan_latent[2] = f64::NAN;
        assert!(is_rejected(&grid, &two_ids, &names, &nan_latent));
    }
}