use ndarray::{Array1, Array2, ArrayView1};
use crate::inference::row_metric::{MetricProvenance, RowMetric};
use crate::terms::sae_manifold::SaeManifoldTerm;
/// Number of evenly spaced samples taken along the straight latent path: the
/// midpoint-rule step count in `path_integrated_dose` and the number of scan
/// points in `validity_radius`.
const STEER_PATH_STEPS: usize = 64;
/// Relative gap between the chord value and the quadratic prediction above
/// which `validity_radius` stops scanning and reports the current distance.
const VALIDITY_DIVERGENCE_FRACTION: f64 = 0.1;
/// Assignment mass a row must exceed to count as "active" for an atom when
/// `steer_delta` averages the steering amplitude.
const ACTIVE_MASS_FLOOR: f64 = 1e-6;
/// Result of `steer_delta`: an additive output-space edit that moves one atom
/// between two latent coordinates, plus diagnostics on how it was measured.
#[derive(Clone, Debug, PartialEq)]
pub struct SteerPlan {
    /// Index of the steered atom within the term.
    pub atom: usize,
    /// Name of the steered atom (copied from the atom).
    pub atom_name: String,
    /// Starting latent coordinate (length = the atom's latent dimension).
    pub t_from: Vec<f64>,
    /// Target latent coordinate (same length as `t_from`).
    pub t_to: Vec<f64>,
    /// Scale applied to the decoder difference: mean assignment mass over rows
    /// above `ACTIVE_MASS_FLOOR`, or 1.0 when no row qualifies.
    pub amplitude: f64,
    /// Row at which behavioural quantities are measured: the first row with
    /// the largest assignment mass for this atom (0 if no row was scanned).
    pub measured_row: usize,
    /// The edit itself, `amplitude * (g(t_to) - g(t_from))`, where `g` is the
    /// atom's decoder; length = the atom's output dimension.
    pub delta: Array1<f64>,
    /// Path-integrated dose along `t_from -> t_to`; `None` when the metric
    /// could not be used for behavioural prediction (see `steer_delta`).
    pub predicted_nats: Option<f64>,
    /// Latent distance along the path over which the quadratic prediction is
    /// trusted; `None` under the same conditions as `predicted_nats`.
    pub validity_radius: Option<f64>,
    /// Norm of the part of `delta` not spanned by the decoder tangents at the
    /// latent midpoint of the path.
    pub off_manifold_norm: f64,
    /// Provenance of the metric that was supplied.
    pub metric_provenance: MetricProvenance,
}
pub fn steer_delta(
model: &SaeManifoldTerm,
metric: &RowMetric,
atom_k: usize,
t_from: &[f64],
t_to: &[f64],
) -> Result<SteerPlan, String> {
let k = model.k_atoms();
if atom_k >= k {
return Err(format!(
"steer_delta: atom index {atom_k} out of range (term has {k} atoms)"
));
}
let atom = &model.atoms[atom_k];
let d = atom.latent_dim;
let p = atom.output_dim();
if t_from.len() != d || t_to.len() != d {
return Err(format!(
"steer_delta: t_from/t_to must have length latent_dim={d}; got {} and {}",
t_from.len(),
t_to.len()
));
}
let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
format!(
"steer_delta: atom {atom_k} ('{}') has no installed basis evaluator; \
arbitrary-t decoder evaluation requires one",
atom.name
)
})?;
let assignments = model.assignment.assignments();
let n = model.n_obs();
let mut mass_sum = 0.0_f64;
let mut active_count = 0.0_f64;
let mut best_row = 0usize;
let mut best_mass = f64::NEG_INFINITY;
for row in 0..n {
let mass = assignments[[row, atom_k]];
if mass > best_mass {
best_mass = mass;
best_row = row;
}
if mass > ACTIVE_MASS_FLOOR {
mass_sum += mass;
active_count += 1.0;
}
}
let amplitude = if active_count > 0.0 {
mass_sum / active_count
} else {
1.0
};
let g_from = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_from, p)?;
let g_to = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_to, p)?;
let mut delta = Array1::<f64>::zeros(p);
for i in 0..p {
delta[i] = amplitude * (g_to[i] - g_from[i]);
}
let provenance = metric.provenance();
let behavior_available =
metric_carries_behavior(provenance) && metric.n_rows() == n && metric.p_out() == p;
let mut t_mid = vec![0.0_f64; d];
for a in 0..d {
t_mid[a] = 0.5 * (t_from[a] + t_to[a]);
}
let tangents =
decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, &t_mid, p, d)?;
let off_manifold_norm = off_manifold_residual_norm(&tangents, delta.view());
let (predicted_nats, validity_radius) = if !behavior_available {
(None, None)
} else {
let ctx = SteerContext {
evaluator: evaluator.as_ref(),
decoder: &atom.decoder_coefficients,
metric,
row: best_row,
p,
d,
amplitude,
};
let dose = path_integrated_dose(&ctx, t_from, t_to)?;
let radius = validity_radius(&ctx, t_from, t_to)?;
(Some(dose), Some(radius))
};
Ok(SteerPlan {
atom: atom_k,
atom_name: atom.name.clone(),
t_from: t_from.to_vec(),
t_to: t_to.to_vec(),
amplitude,
measured_row: best_row,
delta,
predicted_nats,
validity_radius,
off_manifold_norm,
metric_provenance: provenance,
})
}
/// Whether behavioural predictions (dose in nats, validity radius) should be
/// computed for a metric of this provenance.
///
/// Returns `false` for `Euclidean` and `true` for the `OutputFisher` and
/// `WhitenedStructured` variants. The match is kept exhaustive so that adding
/// a variant forces a decision here.
fn metric_carries_behavior(provenance: MetricProvenance) -> bool {
    match provenance {
        MetricProvenance::OutputFisher { .. } => true,
        MetricProvenance::WhitenedStructured { .. } => true,
        MetricProvenance::Euclidean => false,
    }
}
/// Evaluate a decoder at a single latent coordinate `t`.
///
/// Returns `g[j] = Σ_m phi[m] · decoder[m, j]` for `j < p`, where `phi` is the
/// evaluator's basis row at `t` and there is one basis value per decoder row.
/// Basis values that are exactly zero are skipped, so they never multiply a
/// decoder coefficient.
///
/// # Errors
/// Returns `Err` in these cases (previously some of the shape mismatches
/// panicked on indexing instead):
/// - the decoder has fewer than `p` output columns;
/// - `t` cannot be shaped into a `1 × t.len()` row;
/// - the evaluator fails;
/// - `phi` is not `1 × decoder.nrows()`.
fn decode_at(
    evaluator: &dyn crate::terms::sae_manifold::SaeBasisEvaluator,
    decoder: &Array2<f64>,
    t: &[f64],
    p: usize,
) -> Result<Array1<f64>, String> {
    if decoder.ncols() < p {
        return Err(format!(
            "steer_delta::decode_at: decoder has {} output cols but {p} were requested",
            decoder.ncols()
        ));
    }
    let d = t.len();
    let coords = Array2::from_shape_vec((1, d), t.to_vec())
        .map_err(|e| format!("steer_delta::decode_at: coord shape: {e}"))?;
    let (phi, _jet) = evaluator.evaluate(coords.view())?;
    let m = decoder.nrows();
    // One coordinate row in, so exactly one basis row out; mirrors the exact
    // jet-shape check in `decode_tangents_at`.
    if phi.nrows() != 1 {
        return Err(format!(
            "steer_delta::decode_at: evaluator returned {} basis rows for a single coordinate",
            phi.nrows()
        ));
    }
    if phi.ncols() != m {
        return Err(format!(
            "steer_delta::decode_at: evaluator returned {} basis cols but decoder has {m} rows",
            phi.ncols()
        ));
    }
    let mut g = Array1::<f64>::zeros(p);
    for (basis_col, &phi_v) in phi.row(0).iter().enumerate() {
        if phi_v == 0.0 {
            continue;
        }
        // `zip` stops after the first `p` decoder columns (width checked above).
        for (g_j, &coef) in g.iter_mut().zip(decoder.row(basis_col).iter()) {
            *g_j += phi_v * coef;
        }
    }
    Ok(g)
}
/// Decoder tangents at a single latent coordinate `t`, as a `p × d` matrix.
///
/// Column `a` is `Σ_m jet[0, m, a] · decoder[m, ..p]`. The evaluator's jet is
/// read as `jet[[0, basis, axis]]`, presumably the derivative of basis value
/// `basis` with respect to latent axis `axis` (TODO confirm against the
/// evaluator). Zero jet entries are skipped.
///
/// # Errors
/// Returns `Err` in these cases:
/// - the decoder has fewer than `p` output columns (previously an indexing
///   panic);
/// - `t` cannot be shaped into a `1 × d` row;
/// - the evaluator fails;
/// - the jet is not `1 × decoder.nrows() × d`.
fn decode_tangents_at(
    evaluator: &dyn crate::terms::sae_manifold::SaeBasisEvaluator,
    decoder: &Array2<f64>,
    t: &[f64],
    p: usize,
    d: usize,
) -> Result<Array2<f64>, String> {
    if decoder.ncols() < p {
        return Err(format!(
            "steer_delta::decode_tangents_at: decoder has {} output cols but {p} were requested",
            decoder.ncols()
        ));
    }
    let coords = Array2::from_shape_vec((1, d), t.to_vec())
        .map_err(|e| format!("steer_delta::decode_tangents_at: coord shape: {e}"))?;
    let (_phi, jet) = evaluator.evaluate(coords.view())?;
    let m = decoder.nrows();
    if jet.dim() != (1, m, d) {
        return Err(format!(
            "steer_delta::decode_tangents_at: evaluator jet {:?} != (1, {m}, {d})",
            jet.dim()
        ));
    }
    let mut tang = Array2::<f64>::zeros((p, d));
    for axis in 0..d {
        let mut column = tang.column_mut(axis);
        for basis_col in 0..m {
            let dphi = jet[[0, basis_col, axis]];
            if dphi == 0.0 {
                continue;
            }
            // `zip` stops after the first `p` decoder columns (width checked above).
            for (t_j, &coef) in column.iter_mut().zip(decoder.row(basis_col).iter()) {
                *t_j += dphi * coef;
            }
        }
    }
    Ok(tang)
}
/// Euclidean norm of the component of `delta` orthogonal to the column span of
/// `tangents` (a `p × d` matrix).
///
/// The steps are:
/// 1. Solve the normal equations `(TᵀT + εI) c = Tᵀ·delta`. The ridge
///    `ε = 1e-12 · trace(TᵀT)` (or `1e-12` when the trace is not positive)
///    keeps rank-deficient tangents factorable.
/// 2. Return `‖delta − T c‖₂`.
///
/// With `d == 0` there is no tangent span and `‖delta‖₂` is returned.
///
/// Non-finite inputs yield NaN rather than a misleading `0.0`.
fn off_manifold_residual_norm(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> f64 {
    let p = tangents.nrows();
    let d = tangents.ncols();
    if d == 0 {
        return delta.iter().map(|&v| v * v).sum::<f64>().sqrt();
    }
    let mut gram = Array2::<f64>::zeros((d, d));
    let mut rhs = Array1::<f64>::zeros(d);
    for a in 0..d {
        let col_a = tangents.column(a);
        rhs[a] = col_a
            .iter()
            .zip(delta.iter())
            .fold(0.0_f64, |acc, (&t, &v)| acc + t * v);
        for b in a..d {
            let dot = col_a
                .iter()
                .zip(tangents.column(b).iter())
                .fold(0.0_f64, |acc, (&x, &y)| acc + x * y);
            gram[[a, b]] = dot;
            gram[[b, a]] = dot;
        }
    }
    let trace: f64 = (0..d).map(|a| gram[[a, a]]).sum();
    let jitter = if trace > 0.0 { 1e-12 * trace } else { 1e-12 };
    for a in 0..d {
        gram[[a, a]] += jitter;
    }
    let coeffs = solve_spd_small(&gram, &rhs);
    let mut res_sq = 0.0_f64;
    for i in 0..p {
        let proj = tangents
            .row(i)
            .iter()
            .zip(coeffs.iter())
            .fold(0.0_f64, |acc, (&t, &c)| acc + t * c);
        let r = delta[i] - proj;
        res_sq += r * r;
    }
    // A sum of squares is never negative. The former `.max(0.0)` clamp only
    // ever changed NaN into 0.0 (`f64::max` ignores NaN), which reported a
    // broken computation as "fully on-manifold".
    res_sq.sqrt()
}
/// Solve `gram · x = rhs` for a small symmetric positive-definite `gram`.
///
/// Factors `gram = L Lᵀ` (Cholesky), then performs forward substitution
/// `L y = rhs` and back substitution `Lᵀ x = y`.
///
/// If a diagonal pivot is `<= 0.0` (the matrix is not positive definite), the
/// zero vector is returned. A NaN pivot does not trip that test and
/// propagates into the result.
fn solve_spd_small(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
    let n = gram.nrows();

    // Lower-triangular Cholesky factor, filled row by row.
    let mut chol = Array2::<f64>::zeros((n, n));
    for row in 0..n {
        for col in 0..=row {
            let residual = (0..col).fold(gram[[row, col]], |acc, k| {
                acc - chol[[row, k]] * chol[[col, k]]
            });
            if row != col {
                chol[[row, col]] = residual / chol[[col, col]];
                continue;
            }
            if residual <= 0.0 {
                return Array1::<f64>::zeros(n);
            }
            chol[[row, col]] = residual.sqrt();
        }
    }

    // Forward substitution: L y = rhs.
    let mut y = Array1::<f64>::zeros(n);
    for row in 0..n {
        let partial = (0..row).fold(rhs[row], |acc, k| acc - chol[[row, k]] * y[k]);
        y[row] = partial / chol[[row, row]];
    }

    // Back substitution: Lᵀ x = y.
    let mut x = Array1::<f64>::zeros(n);
    for row in (0..n).rev() {
        let partial = ((row + 1)..n).fold(y[row], |acc, k| acc - chol[[k, row]] * x[k]);
        x[row] = partial / chol[[row, row]];
    }
    x
}
/// Bundle of borrowed inputs shared by `path_integrated_dose` and
/// `validity_radius`. It lets them evaluate one atom's decoder and measure the
/// result under `metric` at a single row.
struct SteerContext<'a> {
    /// Metric used to measure output-space changes.
    metric: &'a RowMetric,
    /// Row of `metric` at which every measurement is taken.
    row: usize,
    /// Scale applied to decoder differences and tangents before measuring.
    amplitude: f64,
    /// Basis evaluator of the steered atom.
    evaluator: &'a dyn crate::terms::sae_manifold::SaeBasisEvaluator,
    /// Decoder coefficients of the steered atom (one row per basis function).
    decoder: &'a Array2<f64>,
    /// Output width.
    p: usize,
    /// Latent dimension.
    d: usize,
}
/// Path-integrated quadratic dose for steering along the straight latent
/// segment `t(τ) = t_from + τ·(t_to − t_from)`, `τ ∈ [0, 1]`.
///
/// Computes `∫₀¹ ½·amplitude²·Δtᵀ G(t(τ)) Δt dτ` using a
/// `STEER_PATH_STEPS`-point midpoint rule, where:
/// - `Δt = t_to − t_from`;
/// - `G` is `ctx.metric.pullback` at `ctx.row` applied to the `p × d` decoder
///   tangents at `t(τ)`, passed as a row-major buffer.
///
/// # Errors
/// Propagates any error from tangent evaluation.
fn path_integrated_dose(
    ctx: &SteerContext<'_>,
    t_from: &[f64],
    t_to: &[f64],
) -> Result<f64, String> {
    let (d, p) = (ctx.d, ctx.p);
    let steps = STEER_PATH_STEPS;
    let dtau = 1.0 / steps as f64;
    let amp2 = ctx.amplitude * ctx.amplitude;
    let dt: Vec<f64> = (0..d).map(|a| t_to[a] - t_from[a]).collect();

    // Reused across quadrature nodes; every entry is overwritten each step.
    let mut t_node = vec![0.0_f64; d];
    let mut jacobian = vec![0.0_f64; p * d];

    let mut dose = 0.0_f64;
    for step in 0..steps {
        let tau = (step as f64 + 0.5) * dtau;
        for (a, slot) in t_node.iter_mut().enumerate() {
            *slot = t_from[a] + tau * dt[a];
        }
        let tang = decode_tangents_at(ctx.evaluator, ctx.decoder, &t_node, p, d)?;
        for ((i, a), &value) in tang.indexed_iter() {
            jacobian[i * d + a] = value;
        }
        let g_ab = ctx.metric.pullback(ctx.row, &jacobian, d);
        let mut speed_sq = 0.0_f64;
        for (a, &dt_a) in dt.iter().enumerate() {
            for (b, &dt_b) in dt.iter().enumerate() {
                speed_sq += dt_a * g_ab[[a, b]] * dt_b;
            }
        }
        dose += 0.5 * amp2 * speed_sq * dtau;
    }
    Ok(dose)
}
/// Latent distance along `t_from -> t_to` over which the quadratic dose
/// prediction stays within `VALIDITY_DIVERGENCE_FRACTION` of the value
/// measured on the actual decoded chord.
///
/// The two quantities compared at fraction `τ` are:
/// - the prediction `τ²·lin_coeff`, where
///   `lin_coeff = ½·amp²·fisher_mass(J(t_from)·Δt)`;
/// - the chord value `½·fisher_mass(amp·(g(t(τ)) − g(t_from)))`.
///
/// `τ` is scanned at `STEER_PATH_STEPS` evenly spaced points in `(0, 1]`. The
/// first `τ` whose relative gap exceeds the tolerance, or is NaN, returns
/// `τ·‖t_to − t_from‖₂`.
///
/// Other return values:
/// - `0.0` for a zero-length path;
/// - the full length when `lin_coeff` is not positive or no sample diverges.
///
/// # Errors
/// Propagates any error from decoder or tangent evaluation.
fn validity_radius(ctx: &SteerContext<'_>, t_from: &[f64], t_to: &[f64]) -> Result<f64, String> {
    let d = ctx.d;
    let p = ctx.p;
    let full_len: f64 = t_from
        .iter()
        .zip(t_to.iter())
        .map(|(&a, &b)| (b - a) * (b - a))
        .sum::<f64>()
        .sqrt();
    if full_len == 0.0 {
        return Ok(0.0);
    }
    let dt: Vec<f64> = (0..d).map(|a| t_to[a] - t_from[a]).collect();
    let amp = ctx.amplitude;

    // First-order output velocity at the start of the path: J(t_from)·Δt.
    let tang0 = decode_tangents_at(ctx.evaluator, ctx.decoder, t_from, p, d)?;
    let v0 = Array1::from_shape_fn(p, |i| {
        (0..d).fold(0.0_f64, |acc, a| acc + tang0[[i, a]] * dt[a])
    });
    let lin_coeff = 0.5 * amp * amp * ctx.metric.fisher_mass(ctx.row, v0.view());
    if !(lin_coeff > 0.0) {
        return Ok(full_len);
    }

    let g_from = decode_at(ctx.evaluator, ctx.decoder, t_from, p)?;
    let steps = STEER_PATH_STEPS;
    // Reused across scan points; every entry is overwritten each step.
    let mut t_tau = vec![0.0_f64; d];
    let mut chord = Array1::<f64>::zeros(p);
    for s in 0..steps {
        let tau = (s as f64 + 1.0) / steps as f64;
        for a in 0..d {
            t_tau[a] = t_from[a] + tau * dt[a];
        }
        let g_tau = decode_at(ctx.evaluator, ctx.decoder, &t_tau, p)?;
        for i in 0..p {
            chord[i] = amp * (g_tau[i] - g_from[i]);
        }
        let chord_kl = 0.5 * ctx.metric.fisher_mass(ctx.row, chord.view());
        let lin_kl = tau * tau * lin_coeff;
        let rel = (chord_kl - lin_kl).abs() / lin_kl;
        // A NaN gap (non-finite decoded point or metric value) cannot certify
        // the approximation. `rel > tol` alone is false for NaN and would
        // silently mark the rest of the path as valid.
        if rel.is_nan() || rel > VALIDITY_DIVERGENCE_FRACTION {
            return Ok(tau * full_len);
        }
    }
    Ok(full_len)
}