use ndarray::{Array1, Array2, ArrayView1};
use crate::encode::EncodeAtlas;
use crate::manifold::SaeManifoldTerm;
use gam_problem::{MetricProvenance, RowMetric};
use gam_terms::inference::structure_evidence::log_e_from_p_calibrator;
/// Number of steps taken along the straight latent path `t_from -> t_to`; used both as the
/// midpoint-rule quadrature count in `path_integrated_dose` and as the probe count in
/// `validity_radius`.
const STEER_PATH_STEPS: usize = 64;
/// Relative mismatch between the exact chord quadratic form and its linearised prediction
/// above which `validity_radius` stops scanning and reports the current distance.
const VALIDITY_DIVERGENCE_FRACTION: f64 = 0.1;
/// Assignment mass above which a row counts as "active" when the default steering amplitude
/// (mean active mass) is estimated in `steer_delta_impl`.
const ACTIVE_MASS_FLOOR: f64 = 1e-6;
/// Result of planning a steering edit for a single atom along its decoder manifold.
#[derive(Clone, Debug, PartialEq)]
pub struct SteerPlan {
    /// Index of the steered atom within the manifold term.
    pub atom: usize,
    /// Name of the steered atom (copied from the atom).
    pub atom_name: String,
    /// Latent start coordinate, exactly as supplied by the caller.
    pub t_from: Vec<f64>,
    /// Latent target coordinate, exactly as supplied by the caller.
    pub t_to: Vec<f64>,
    /// Amplitude multiplying the decoder difference. It is either the caller's override or the
    /// mean assignment mass over rows whose mass exceeds `ACTIVE_MASS_FLOOR` (1.0 if none).
    pub amplitude: f64,
    /// Row with the largest assignment mass for this atom; behavioural quadratic forms are
    /// evaluated with the metric at this row.
    pub measured_row: usize,
    /// Output-space edit `amplitude * (g(t_to) - g(t_from))`, length = atom output dimension.
    pub delta: Array1<f64>,
    /// Path-integrated `0.5 * amplitude^2 * ∫ dtᵀ G(t) dt dτ` from the metric pullback along the
    /// straight latent path. `None` when the metric does not carry behaviour or its shape does
    /// not match the model.
    pub predicted_nats: Option<f64>,
    /// Latent distance along the path at which the chord quadratic form first deviates from its
    /// linearised prediction by more than `VALIDITY_DIVERGENCE_FRACTION`. `None` under the same
    /// conditions as `predicted_nats`.
    pub validity_radius: Option<f64>,
    /// Euclidean norm of the part of `delta` orthogonal to the decoder tangent span at the path
    /// midpoint.
    pub off_manifold_norm: f64,
    /// Provenance of the row metric used to build this plan.
    pub metric_provenance: MetricProvenance,
}
/// Output of `set_coordinate`: the edited row plus the encoding and steering details behind it.
#[derive(Clone, Debug)]
pub struct CoordinateSetResult {
    /// Input row with the steering delta added element-wise.
    pub edited: Array1<f64>,
    /// Latent coordinate of the input row as returned by the atlas' certified encode.
    pub t_from_certified: Array1<f64>,
    /// Certificate returned alongside `t_from_certified` by the atlas.
    pub encode_certificate: crate::encode::RowCertificate,
    /// Steering plan that produced the applied delta.
    pub steer: SteerPlan,
}
/// Moves row `x` along atom `atom_k`'s decoder manifold from its certified encoding to `t_to`.
///
/// The row is first encoded with `atlas` (yielding `t_from` and a certificate). The steering
/// delta `amplitude * (g(t_to) - g(t_from))` is then planned with `amplitude` as an explicit
/// override and added to a copy of the row.
///
/// # Errors
/// Returns `Err` when `atom_k` is out of range, when `x` does not have the atom's output
/// dimension, when the certified encode fails, or when steering fails (for example a
/// non-finite/non-positive amplitude, a `t_to` of the wrong length, or a missing basis
/// evaluator).
pub fn set_coordinate(
    model: &SaeManifoldTerm,
    metric: &RowMetric,
    atlas: &EncodeAtlas,
    x: ArrayView1<'_, f64>,
    atom_k: usize,
    amplitude: f64,
    t_to: &[f64],
) -> Result<CoordinateSetResult, String> {
    let atom = model.atoms.get(atom_k).ok_or_else(|| {
        format!(
            "set_coordinate: atom index {atom_k} out of range (term has {} atoms)",
            model.k_atoms()
        )
    })?;
    if x.len() != atom.output_dim() {
        return Err(format!(
            "set_coordinate: input row has length {} but atom {atom_k} output_dim is {}",
            x.len(),
            atom.output_dim()
        ));
    }
    let (t_from, cert) = atlas.certified_encode_row(atom, atom_k, x, amplitude)?;
    // `to_vec` copies in logical order for any memory layout. `as_slice()` is `None` for
    // non-standard strides, and falling back to `&[]` would surface as a bogus length error.
    let t_from_vec = t_from.to_vec();
    let steer = steer_delta_with_amplitude(model, metric, atom_k, &t_from_vec, t_to, amplitude)?;
    if x.len() != steer.delta.len() {
        return Err(format!(
            "set_coordinate: steering delta length {} does not match row length {}",
            steer.delta.len(),
            x.len()
        ));
    }
    let mut edited = x.to_owned();
    edited += &steer.delta;
    Ok(CoordinateSetResult {
        edited,
        t_from_certified: t_from,
        encode_certificate: cert,
        steer,
    })
}
/// Output of `interchange`: a target row whose atom coordinate was replaced by a donor's.
#[derive(Clone, Debug)]
pub struct InterchangeResult {
    /// Target row after the steering delta was applied.
    pub edited_target: Array1<f64>,
    /// Certified encoding of the source (donor) row; used as the steering destination.
    pub donor_t: Array1<f64>,
    /// Certified encoding of the target row before the edit.
    pub target_t_before: Array1<f64>,
    /// Certified re-encoding of the edited target row.
    pub target_t_after: Array1<f64>,
    /// Copied from the underlying steering plan.
    pub predicted_nats: Option<f64>,
    /// Copied from the underlying steering plan.
    pub off_manifold_norm: f64,
    /// Copied from the underlying steering plan.
    pub validity_radius: Option<f64>,
    /// Output of `log_e_from_p_calibrator` applied to `p = exp(-z^2 / 2)`, where
    /// `z = scale / landing_error` (capped at 1e6). `landing_error` is the distance between
    /// `donor_t` and `target_t_after`. `scale` is the validity radius, or the before-to-donor
    /// distance when no radius is available.
    pub counterfactual_consistency_log_e: f64,
    /// Full result of the underlying `set_coordinate` call.
    pub set_result: CoordinateSetResult,
}
/// Copies atom `atom_k`'s latent coordinate from `x_source` into `x_target`.
///
/// The source row is encoded to obtain the donor coordinate. The target is then steered to that
/// coordinate with `set_coordinate`, and the edited target is re-encoded to check where it
/// actually landed. The landing error (latent distance donor ↔ re-encoded target) is compared
/// against a scale (the plan's validity radius, or the before-to-donor distance when the metric
/// carries no behaviour). The ratio is turned into a p-value `exp(-z^2 / 2)` and then into a
/// log e-value with `log_e_from_p_calibrator`.
///
/// # Errors
/// Returns `Err` when `atom_k` is out of range, when `x_source` does not have the atom's output
/// dimension, when any encode or steering step fails, when coordinate lengths disagree, or
/// when the landing error is not finite.
pub fn interchange(
    model: &SaeManifoldTerm,
    metric: &RowMetric,
    atlas: &EncodeAtlas,
    x_target: ArrayView1<'_, f64>,
    target_amplitude: f64,
    x_source: ArrayView1<'_, f64>,
    source_amplitude: f64,
    atom_k: usize,
) -> Result<InterchangeResult, String> {
    let atom = model.atoms.get(atom_k).ok_or_else(|| {
        format!(
            "interchange: atom index {atom_k} out of range (term has {} atoms)",
            model.k_atoms()
        )
    })?;
    // The target row is validated inside `set_coordinate`; validate the donor row the same way.
    if x_source.len() != atom.output_dim() {
        return Err(format!(
            "interchange: source row has length {} but atom {atom_k} output_dim is {}",
            x_source.len(),
            atom.output_dim()
        ));
    }
    let (donor_t, _donor_cert) =
        atlas.certified_encode_row(atom, atom_k, x_source, source_amplitude)?;
    // Layout-independent copy; `as_slice()` would be `None` for non-standard strides.
    let donor_coords = donor_t.to_vec();
    let set = set_coordinate(
        model,
        metric,
        atlas,
        x_target,
        atom_k,
        target_amplitude,
        &donor_coords,
    )?;
    let (target_t_after, _after_cert) =
        atlas.certified_encode_row(atom, atom_k, set.edited.view(), target_amplitude)?;
    let landing_error = l2_distance(donor_t.view(), target_t_after.view())?;
    // `f64::max` maps NaN to the other operand, so a NaN landing error would become 1e-12 and be
    // scored as near-perfect consistency. Refuse instead.
    if !landing_error.is_finite() {
        return Err(format!(
            "interchange: landing error between donor and re-encoded target is not finite ({landing_error})"
        ));
    }
    let scale = match set.steer.validity_radius {
        Some(radius) => radius,
        None => l2_distance(set.t_from_certified.view(), donor_t.view())?,
    }
    .max(1e-12);
    let z = (scale / landing_error.max(1e-12)).min(1.0e6);
    let p_value = (-0.5 * z * z).exp().clamp(f64::MIN_POSITIVE, 1.0);
    let log_e = log_e_from_p_calibrator(p_value)?;
    Ok(InterchangeResult {
        edited_target: set.edited.clone(),
        donor_t,
        target_t_before: set.t_from_certified.clone(),
        target_t_after,
        predicted_nats: set.steer.predicted_nats,
        off_manifold_norm: set.steer.off_manifold_norm,
        validity_radius: set.steer.validity_radius,
        counterfactual_consistency_log_e: log_e,
        set_result: set,
    })
}
fn l2_distance(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> Result<f64, String> {
if a.len() != b.len() {
return Err(format!(
"coordinate distance length mismatch: {} vs {}",
a.len(),
b.len()
));
}
let mut ss = 0.0;
for i in 0..a.len() {
let r = a[i] - b[i];
ss += r * r;
}
Ok(ss.sqrt())
}
/// Plans a steering edit for atom `atom_k` from `t_from` to `t_to`.
///
/// The amplitude is not supplied by the caller. It is estimated as the mean assignment mass
/// over rows whose mass for this atom exceeds `ACTIVE_MASS_FLOOR` (1.0 when no row qualifies).
///
/// # Errors
/// Propagates every error of the shared implementation (bad atom index, wrong coordinate
/// lengths, missing basis evaluator, decoder/metric failures).
pub fn steer_delta(
    model: &SaeManifoldTerm,
    metric: &RowMetric,
    atom_k: usize,
    t_from: &[f64],
    t_to: &[f64],
) -> Result<SteerPlan, String> {
    let no_override: Option<f64> = None;
    steer_delta_impl(model, metric, atom_k, t_from, t_to, no_override)
}
/// Plans a steering edit for atom `atom_k` using a caller-supplied `amplitude`.
///
/// # Errors
/// Returns `Err` unless `amplitude` is finite and strictly positive; otherwise propagates the
/// errors of the shared implementation.
fn steer_delta_with_amplitude(
    model: &SaeManifoldTerm,
    metric: &RowMetric,
    atom_k: usize,
    t_from: &[f64],
    t_to: &[f64],
    amplitude: f64,
) -> Result<SteerPlan, String> {
    // NaN fails `is_finite`, so it is rejected along with infinities and non-positive values.
    let usable = amplitude.is_finite() && amplitude > 0.0;
    if usable {
        steer_delta_impl(model, metric, atom_k, t_from, t_to, Some(amplitude))
    } else {
        Err(format!(
            "steer_delta_with_amplitude: amplitude must be finite and positive, got {amplitude}"
        ))
    }
}
/// Shared implementation of `steer_delta` and `steer_delta_with_amplitude`.
///
/// Builds the output-space delta `amplitude * (g(t_to) - g(t_from))`, where `g` is atom
/// `atom_k`'s decoder evaluated through its basis evaluator. It also measures:
/// * the off-manifold part of that delta, against the tangent span at the path midpoint;
/// * when the metric carries behaviour and matches the model's shape, the path-integrated
///   quadratic cost and the linearisation validity radius. Both are evaluated at the row with
///   the largest assignment mass for the atom.
///
/// `amplitude_override = None` uses the mean assignment mass over rows with mass above
/// `ACTIVE_MASS_FLOOR` (1.0 if none).
///
/// # Errors
/// Returns `Err` for an out-of-range atom, coordinates of the wrong length or containing
/// non-finite values, a missing basis evaluator, or any decoder/tangent evaluation failure.
fn steer_delta_impl(
    model: &SaeManifoldTerm,
    metric: &RowMetric,
    atom_k: usize,
    t_from: &[f64],
    t_to: &[f64],
    amplitude_override: Option<f64>,
) -> Result<SteerPlan, String> {
    let k = model.k_atoms();
    if atom_k >= k {
        return Err(format!(
            "steer_delta: atom index {atom_k} out of range (term has {k} atoms)"
        ));
    }
    let atom = &model.atoms[atom_k];
    let d = atom.latent_dim;
    let p = atom.output_dim();
    if t_from.len() != d || t_to.len() != d {
        return Err(format!(
            "steer_delta: t_from/t_to must have length latent_dim={d}; got {} and {}",
            t_from.len(),
            t_to.len()
        ));
    }
    // Non-finite coordinates would silently yield a NaN delta and meaningless diagnostics.
    if t_from.iter().chain(t_to).any(|v| !v.is_finite()) {
        return Err("steer_delta: t_from/t_to must contain only finite coordinates".to_string());
    }
    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
        format!(
            "steer_delta: atom {atom_k} ('{}') has no installed basis evaluator; \
             arbitrary-t decoder evaluation requires one",
            atom.name
        )
    })?;

    // One pass over the fitted assignments: locate the row where this atom is most active
    // (the measurement row) and accumulate the mean mass of "active" rows (the default
    // amplitude).
    let assignments = model.assignment.assignments();
    let n = model.n_obs();
    let mut mass_sum = 0.0_f64;
    let mut active_count = 0.0_f64;
    let mut best_row = 0usize;
    let mut best_mass = f64::NEG_INFINITY;
    for row in 0..n {
        let mass = assignments[[row, atom_k]];
        if mass > best_mass {
            best_mass = mass;
            best_row = row;
        }
        if mass > ACTIVE_MASS_FLOOR {
            mass_sum += mass;
            active_count += 1.0;
        }
    }
    let amplitude = amplitude_override.unwrap_or_else(|| {
        if active_count > 0.0 {
            mass_sum / active_count
        } else {
            1.0
        }
    });

    let g_from = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_from, p)?;
    let g_to = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_to, p)?;
    let delta = Array1::from_shape_fn(p, |i| amplitude * (g_to[i] - g_from[i]));

    let provenance = metric.provenance();
    // With no observations there is no real measurement row: `best_row` would be a placeholder
    // index into an empty metric, so behavioural diagnostics are skipped.
    let behavior_available = n > 0
        && metric_carries_behavior(provenance)
        && metric.n_rows() == n
        && metric.p_out() == p;

    let t_mid: Vec<f64> = t_from
        .iter()
        .zip(t_to)
        .map(|(&a, &b)| 0.5 * (a + b))
        .collect();
    let tangents =
        decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, &t_mid, p, d)?;
    let off_manifold_norm = off_manifold_residual_norm(&tangents, delta.view());

    let (predicted_nats, validity_radius) = if behavior_available {
        let ctx = SteerContext {
            evaluator: evaluator.as_ref(),
            decoder: &atom.decoder_coefficients,
            metric,
            row: best_row,
            p,
            d,
            amplitude,
        };
        let dose = path_integrated_dose(&ctx, t_from, t_to)?;
        let radius = validity_radius(&ctx, t_from, t_to)?;
        (Some(dose), Some(radius))
    } else {
        (None, None)
    };
    Ok(SteerPlan {
        atom: atom_k,
        atom_name: atom.name.clone(),
        t_from: t_from.to_vec(),
        t_to: t_to.to_vec(),
        amplitude,
        measured_row: best_row,
        delta,
        predicted_nats,
        validity_radius,
        off_manifold_norm,
        metric_provenance: provenance,
    })
}
/// Projects an output-space `delta` onto atom `atom_k`'s decoder tangent span at `t_at`.
///
/// The result is the part of `delta` that a small move along the atom's manifold at `t_at` can
/// represent.
///
/// # Errors
/// Returns `Err` for an out-of-range atom, a `t_at` whose length is not the latent dimension, a
/// `delta` whose length is not the output dimension, a missing basis evaluator, or a tangent
/// evaluation failure. The checks run in that order.
pub fn predicted_response(
    model: &SaeManifoldTerm,
    atom_k: usize,
    t_at: &[f64],
    delta: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, String> {
    let k = model.k_atoms();
    if atom_k >= k {
        return Err(format!(
            "predicted_response: atom index {atom_k} out of range (term has {k} atoms)"
        ));
    }
    let atom = &model.atoms[atom_k];
    let (d, p) = (atom.latent_dim, atom.output_dim());
    if t_at.len() != d {
        return Err(format!(
            "predicted_response: t_at must have length latent_dim={d}; got {}",
            t_at.len()
        ));
    }
    if delta.len() != p {
        return Err(format!(
            "predicted_response: delta must have length output_dim={p}; got {}",
            delta.len()
        ));
    }
    let evaluator = match atom.basis_evaluator.as_ref() {
        Some(installed) => installed,
        None => {
            return Err(format!(
                "predicted_response: atom {atom_k} ('{}') has no installed basis evaluator",
                atom.name
            ))
        }
    };
    decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, t_at, p, d)
        .map(|frame| project_onto_tangent_span(&frame, delta))
}
/// Whether a metric with provenance `p` encodes behavioural (non-Euclidean) information.
///
/// Only the plain Euclidean metric is considered behaviour-free. The match is kept exhaustive
/// so that a new provenance variant forces an explicit decision here.
fn metric_carries_behavior(p: MetricProvenance) -> bool {
    let is_plain_euclidean = match p {
        MetricProvenance::Euclidean => true,
        MetricProvenance::OutputFisher { .. }
        | MetricProvenance::OutputFisherDownstream { .. }
        | MetricProvenance::BehavioralFisher { .. }
        | MetricProvenance::WhitenedStructured { .. } => false,
    };
    !is_plain_euclidean
}
/// Evaluates the decoder at a single latent point: `g(t) = phi(t) · decoder[:, ..p]`.
///
/// `decoder` is indexed as `[basis_col, out_col]`. Only its first `p` output columns are
/// used.
///
/// # Errors
/// Returns `Err` if the coordinate row cannot be shaped, the evaluator fails, the evaluator
/// returns a basis matrix that is not a single row with one column per decoder row, or the
/// decoder has fewer than `p` output columns (all of which would otherwise panic on indexing).
fn decode_at(
    evaluator: &dyn crate::manifold::SaeBasisEvaluator,
    decoder: &Array2<f64>,
    t: &[f64],
    p: usize,
) -> Result<Array1<f64>, String> {
    let d = t.len();
    let coords = Array2::from_shape_vec((1, d), t.to_vec())
        .map_err(|e| format!("steer_delta::decode_at: coord shape: {e}"))?;
    let (phi, _jet) = evaluator.evaluate(coords.view())?;
    let m = decoder.nrows();
    if phi.ncols() != m {
        return Err(format!(
            "steer_delta::decode_at: evaluator returned {} basis cols but decoder has {m} rows",
            phi.ncols()
        ));
    }
    if phi.nrows() != 1 {
        return Err(format!(
            "steer_delta::decode_at: evaluator returned {} basis rows for 1 coordinate",
            phi.nrows()
        ));
    }
    if decoder.ncols() < p {
        return Err(format!(
            "steer_delta::decode_at: decoder has {} output cols but {p} were requested",
            decoder.ncols()
        ));
    }
    let mut g = Array1::<f64>::zeros(p);
    for basis_col in 0..m {
        let phi_v = phi[[0, basis_col]];
        // Basis functions are often locally supported; skip exact zeros.
        if phi_v == 0.0 {
            continue;
        }
        for out_col in 0..p {
            g[out_col] += phi_v * decoder[[basis_col, out_col]];
        }
    }
    Ok(g)
}
/// Evaluates the decoder Jacobian at a single latent point.
///
/// Returns a `p x d` matrix whose column `axis` is `∂g/∂t_axis = Σ_basis jet[0, basis, axis] ·
/// decoder[basis, ..p]`.
///
/// # Errors
/// Returns `Err` if the coordinate row cannot be shaped to `(1, d)`, the evaluator fails, the
/// evaluator jet is not `(1, decoder.nrows(), d)`, or the decoder has fewer than `p` output
/// columns (which would otherwise panic on indexing).
fn decode_tangents_at(
    evaluator: &dyn crate::manifold::SaeBasisEvaluator,
    decoder: &Array2<f64>,
    t: &[f64],
    p: usize,
    d: usize,
) -> Result<Array2<f64>, String> {
    let coords = Array2::from_shape_vec((1, d), t.to_vec())
        .map_err(|e| format!("steer_delta::decode_tangents_at: coord shape: {e}"))?;
    let (_phi, jet) = evaluator.evaluate(coords.view())?;
    let m = decoder.nrows();
    if jet.dim() != (1, m, d) {
        return Err(format!(
            "steer_delta::decode_tangents_at: evaluator jet {:?} != (1, {m}, {d})",
            jet.dim()
        ));
    }
    if decoder.ncols() < p {
        return Err(format!(
            "steer_delta::decode_tangents_at: decoder has {} output cols but {p} were requested",
            decoder.ncols()
        ));
    }
    let mut tang = Array2::<f64>::zeros((p, d));
    for axis in 0..d {
        for basis_col in 0..m {
            let dphi = jet[[0, basis_col, axis]];
            // Skip basis functions whose derivative vanishes at this point.
            if dphi == 0.0 {
                continue;
            }
            for out_col in 0..p {
                tang[[out_col, axis]] += dphi * decoder[[basis_col, out_col]];
            }
        }
    }
    Ok(tang)
}
/// Orthogonal projection of `delta` onto the column span of `tangents` (a `p x d` frame).
///
/// The projection works as follows:
/// * Solve the jittered normal equations `(TᵀT + εI) c = Tᵀδ`, with `ε = 1e-12 · trace(TᵀT)`,
///   or `1e-12` when the trace is not positive.
/// * Return `T c`.
///
/// With `d == 0` the span is empty and the zero vector is returned. If the small Cholesky solve
/// rejects the system, its coefficients are zero and so is the projection.
fn project_onto_tangent_span(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> Array1<f64> {
    let (rows, cols) = (tangents.nrows(), tangents.ncols());
    if cols == 0 {
        return Array1::<f64>::zeros(rows);
    }
    // Plain left-to-right sums starting at +0.0 (no unrolled/BLAS dot products).
    let col_dot_delta =
        |a: usize| (0..rows).fold(0.0_f64, |s, i| s + tangents[[i, a]] * delta[i]);
    let col_dot_col = |a: usize, b: usize| {
        (0..rows).fold(0.0_f64, |s, i| s + tangents[[i, a]] * tangents[[i, b]])
    };

    let rhs = Array1::from_shape_fn(cols, col_dot_delta);
    let mut gram = Array2::<f64>::zeros((cols, cols));
    for a in 0..cols {
        for b in a..cols {
            let entry = col_dot_col(a, b);
            gram[[a, b]] = entry;
            gram[[b, a]] = entry;
        }
    }

    let trace: f64 = (0..cols).map(|a| gram[[a, a]]).sum();
    let jitter = 1e-12 * if trace > 0.0 { trace } else { 1.0 };
    for a in 0..cols {
        gram[[a, a]] += jitter;
    }

    let coeffs = solve_spd_small(&gram, &rhs);
    Array1::from_shape_fn(rows, |i| {
        (0..cols).fold(0.0_f64, |s, a| s + tangents[[i, a]] * coeffs[a])
    })
}
/// Euclidean norm of the component of `delta` orthogonal to the column span of `tangents`.
///
/// A sum of squares is never negative, so no clamping is applied. Non-finite inputs produce a
/// NaN norm instead of being reported as a perfectly on-manifold `0.0`. The previous
/// `.max(0.0)` did exactly that, because `f64::max(NaN, 0.0) == 0.0`.
fn off_manifold_residual_norm(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> f64 {
    let proj = project_onto_tangent_span(tangents, delta);
    delta
        .iter()
        .zip(proj.iter())
        .fold(0.0_f64, |acc, (&x, &y)| {
            let r = x - y;
            acc + r * r
        })
        .sqrt()
}
/// Solves `gram · x = rhs` for a small symmetric positive-definite `gram` via Cholesky.
///
/// It factors `gram = L Lᵀ`, then performs a forward substitution and a back substitution. If a
/// pivot is `<= 0.0` the matrix is treated as not positive definite and the zero vector is
/// returned. A NaN pivot is not caught by that test and flows through the result.
fn solve_spd_small(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
    let n = gram.nrows();

    // Lower-triangular Cholesky factor, built row by row.
    let mut chol = Array2::<f64>::zeros((n, n));
    for row in 0..n {
        for col in 0..=row {
            let reduced = (0..col).fold(gram[[row, col]], |acc, k| {
                acc - chol[[row, k]] * chol[[col, k]]
            });
            if row != col {
                chol[[row, col]] = reduced / chol[[col, col]];
            } else if reduced <= 0.0 {
                return Array1::zeros(n);
            } else {
                chol[[row, col]] = reduced.sqrt();
            }
        }
    }

    // Forward substitution: L y = rhs.
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let partial = (0..i).fold(rhs[i], |acc, k| acc - chol[[i, k]] * y[k]);
        y[i] = partial / chol[[i, i]];
    }

    // Back substitution: Lᵀ x = y.
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let partial = ((i + 1)..n).fold(y[i], |acc, k| acc - chol[[k, i]] * x[k]);
        x[i] = partial / chol[[i, i]];
    }
    x
}
/// Evaluation context shared by `path_integrated_dose` and `validity_radius` for one plan.
struct SteerContext<'a> {
    /// Basis evaluator of the steered atom.
    evaluator: &'a dyn crate::manifold::SaeBasisEvaluator,
    /// Decoder coefficients of the steered atom, indexed `[basis_col, out_col]`.
    decoder: &'a Array2<f64>,
    /// Row metric supplying `pullback` and `fisher_mass` quadratic forms.
    metric: &'a RowMetric,
    /// Metric row at which every quadratic form is evaluated.
    row: usize,
    /// Output dimension of the atom.
    p: usize,
    /// Latent dimension of the atom.
    d: usize,
    /// Steering amplitude scaling decoder differences and tangents.
    amplitude: f64,
}
/// Midpoint-rule estimate of `0.5 · A² · ∫₀¹ dtᵀ G(t(τ)) dt dτ` along the straight path.
///
/// `t(τ) = t_from + τ·dt` with `dt = t_to - t_from`, `A = ctx.amplitude`, and `G` the metric
/// pullback of the decoder Jacobian at `ctx.row`. Uses `STEER_PATH_STEPS` equal sub-intervals.
///
/// # Errors
/// Propagates decoder-Jacobian evaluation failures.
fn path_integrated_dose(
    ctx: &SteerContext<'_>,
    t_from: &[f64],
    t_to: &[f64],
) -> Result<f64, String> {
    let (d, p) = (ctx.d, ctx.p);
    let n_steps = STEER_PATH_STEPS;
    let dtau = 1.0 / n_steps as f64;
    let velocity: Vec<f64> = (0..d).map(|a| t_to[a] - t_from[a]).collect();
    let half_amp_sq = 0.5 * (ctx.amplitude * ctx.amplitude);

    // Scratch buffers reused across steps; every entry is overwritten each iteration.
    let mut point = vec![0.0_f64; d];
    let mut jac_row_major = vec![0.0_f64; p * d];
    let mut energy = 0.0_f64;
    for step in 0..n_steps {
        let tau = (step as f64 + 0.5) * dtau;
        for (a, slot) in point.iter_mut().enumerate() {
            *slot = t_from[a] + tau * velocity[a];
        }
        let tang = decode_tangents_at(ctx.evaluator, ctx.decoder, &point, p, d)?;
        // Flatten the p x d Jacobian row-major, as `pullback` expects.
        for (idx, slot) in jac_row_major.iter_mut().enumerate() {
            *slot = tang[[idx / d, idx % d]];
        }
        let g_ab = ctx.metric.pullback(ctx.row, &jac_row_major, d);
        let quad = (0..d).fold(0.0_f64, |outer, a| {
            (0..d).fold(outer, |inner, b| inner + velocity[a] * g_ab[[a, b]] * velocity[b])
        });
        energy += half_amp_sq * quad * dtau;
    }
    Ok(energy)
}
/// Latent distance along `t_from -> t_to` over which the linearised quadratic cost stays valid.
///
/// The linear prediction at fraction `τ` is `τ² · lin_coeff`, where
/// `lin_coeff = 0.5 · A² · fisher_mass(J(t_from)·dt)`. The exact chord value is
/// `0.5 · fisher_mass(A·(g(t(τ)) - g(t_from)))`. Probing `τ = 1/N, 2/N, …, 1`, the first `τ`
/// whose relative mismatch exceeds `VALIDITY_DIVERGENCE_FRACTION` gives `τ · |dt|`. The full
/// length `|dt|` is returned when no probe fails or when `lin_coeff` is not positive. A
/// zero-length path gives `0.0`.
///
/// # Errors
/// Propagates decoder and tangent evaluation failures.
fn validity_radius(ctx: &SteerContext<'_>, t_from: &[f64], t_to: &[f64]) -> Result<f64, String> {
    let (d, p) = (ctx.d, ctx.p);
    let span = t_from
        .iter()
        .zip(t_to)
        .map(|(&a, &b)| (b - a) * (b - a))
        .sum::<f64>()
        .sqrt();
    if span == 0.0 {
        return Ok(0.0);
    }
    let direction: Vec<f64> = (0..d).map(|a| t_to[a] - t_from[a]).collect();
    let amp = ctx.amplitude;

    // First-order output displacement J(t_from) · dt.
    let jac0 = decode_tangents_at(ctx.evaluator, ctx.decoder, t_from, p, d)?;
    let v0 = Array1::from_shape_fn(p, |i| {
        (0..d).fold(0.0_f64, |acc, a| acc + jac0[[i, a]] * direction[a])
    });
    let lin_coeff = 0.5 * amp * amp * ctx.metric.fisher_mass(ctx.row, v0.view());
    let has_linear_term = lin_coeff > 0.0;
    if !has_linear_term {
        // Also covers NaN.
        return Ok(span);
    }

    let g_from = decode_at(ctx.evaluator, ctx.decoder, t_from, p)?;
    let n_steps = STEER_PATH_STEPS;
    let mut probe = vec![0.0_f64; d];
    for s in 1..=n_steps {
        let tau = s as f64 / n_steps as f64;
        for (a, slot) in probe.iter_mut().enumerate() {
            *slot = t_from[a] + tau * direction[a];
        }
        let g_tau = decode_at(ctx.evaluator, ctx.decoder, &probe, p)?;
        let chord = Array1::from_shape_fn(p, |i| amp * (g_tau[i] - g_from[i]));
        let chord_kl = 0.5 * ctx.metric.fisher_mass(ctx.row, chord.view());
        let lin_kl = tau * tau * lin_coeff;
        let rel = (chord_kl - lin_kl).abs() / lin_kl;
        if rel > VALIDITY_DIVERGENCE_FRACTION {
            return Ok(tau * span);
        }
    }
    Ok(span)
}