Skip to main content

gam_sae/inference/
steering.rs

1//! `steer_delta` — the **steering primitive with output dosimetry**: the
2//! actionable LLM payload of the SAE-manifold machine.
3//!
4//! # What this computes
5//!
6//! Given a fitted [`SaeManifoldTerm`] and the per-row output-Fisher
7//! [`RowMetric`], a *steering move* is "drive atom `k`'s latent coordinate from
8//! `t_from` to `t_to`". The atom's decoder curve `g_k(t) = Φ_k(t) B_k` maps that
9//! latent move to an **activation-space delta** — the actual vector you add to
10//! the residual stream / reconstruction to realize the move *on the manifold*:
11//!
12//! ```text
13//! δ = a · ( g_k(t_to) − g_k(t_from) )          (the on-manifold move)
14//! ```
15//!
16//! where `a` is the atom's amplitude (how loudly the atom is expressed). This is
17//! the thing a downstream consumer adds to a hidden state.
18//!
19//! # Dosimetry — how big is this push, in nats?
20//!
21//! The headline number is the **predicted output effect**: how much behavioral
22//! change (in nats of KL on the model's output distribution) the move induces.
//! For a locally-quadratic output readout the KL of a parameter move `Δ` is
//! `½ Δᵀ F Δ` with `F` the output-Fisher information — exactly the inner product
//! [`RowMetric`] carries. The dose is the Fisher quadratic form of the move,
//! **integrated along the decoder curve** rather than read only at the endpoints.
//! With the straight latent path `t(τ) = t_from + τ Δt`, `Δt = t_to − t_from`,
//! `τ ∈ [0, 1]`, and the decoder Jacobian `J(t) = ∂g_k/∂t` (`p × d`):
//!
//! ```text
//! predicted_nats = ½ a² ∫₀¹ Δtᵀ J(t(τ))ᵀ M_n J(t(τ)) Δt dτ
//! ```
//!
//! where `M_n` is the output-Fisher metric of the measured row `n`. It is
//! evaluated by the midpoint rule over a fixed number of sub-steps through the
//! per-row [`RowMetric::pullback`]. On a flat (linear-in-`t`) decoder it reduces
//! to the endpoint quadratic form `½ ‖δ‖²_M`; on a curved one it follows the
//! surface, so a long arc that doubles back is not under-counted the way a
//! straight endpoint chord would be.
36//!
37//! # Validity radius — where local linearization stops being trusted
38//!
//! A consumer must know *how far* the move can be trusted as a linear push. The
//! **validity radius** is the latent step length `τ·‖Δt‖` at the first prefix
//! `τ` (checked on a fixed grid) where the chord quadratic form of the actual
//! curved output move, `½ a² ‖g_k(t_from + τΔt) − g_k(t_from)‖²_M`, diverges from
//! the local-linear prediction `τ² · ½ a² ‖(∂g_k/∂t)|_{t_from} Δt‖²_M` by more than
//! [`VALIDITY_DIVERGENCE_FRACTION`] (relative to the linear prediction). Beyond it
//! the surface has curved enough that the initial tangent no longer represents
//! the move. We **report** it; we do not silently clip to it.
46//!
47//! # Off-manifold guard
48//!
//! `δ` is, by construction, a chord of the decoder curve, so it should lie in the
//! atom's local tangent frame up to second-order curvature. The **off-manifold
//! norm** projects `δ` onto the span of the local decoder tangents `∂g_k/∂t`
//! evaluated at the move's latent midpoint `½(t_from + t_to)` and reports the
//! residual norm — a self-check that the steering move stays on the learned
//! surface. The midpoint frame is used because a chord's component transverse to
//! the midpoint tangent is the second-order sagitta, whereas the tangent at
//! `t_from` already differs from the chord direction at first order. The norm is
//! `≈ 0` for small steps and grows with arc curvature; a large value means the
//! requested move left the manifold and the dose number is not to be trusted.
56//!
57//! # Read-only / no loss contact
58//!
59//! This module is a **pure read** over the fitted term and the metric. It calls
60//! only `g_k(t)` evaluation ([`SaeManifoldAtom`]'s decoder + installed
61//! [`SaeBasisEvaluator`]) and the criterion-facing
62//! [`RowMetric::fisher_mass`] / [`RowMetric::pullback`]. It never mutates the
63//! model, never touches a likelihood / criterion / penalty, and the solver floor
64//! `δ` of [`RowMetric`] never enters any number it reports (the fisher-mass /
65//! pullback face is `δ`-free, #747).
66
67use ndarray::{Array1, Array2, ArrayView1};
68
69use gam_problem::{MetricProvenance, RowMetric};
70use crate::manifold::SaeManifoldTerm;
71
/// Number of sub-steps the latent path `[t_from, t_to]` is divided into, both
/// for the midpoint-rule dosimetry integral and for the prefix grid the
/// validity radius is searched on. The decoder curve is smooth, so a modest
/// fixed grid resolves the arc; it is fixed (no clock / no adaptivity) so the
/// reported numbers are deterministic.
const STEER_PATH_STEPS: usize = 64;

/// Relative tolerance defining the validity radius. For each prefix `τ` of the
/// move, the chord KL of the actual curved output move is compared with the
/// local-linear (initial-tangent) KL prediction. The linearization is trusted
/// while `|chord − linear| / linear` stays at or below this fraction (10%). The
/// first prefix that exceeds it marks the validity radius.
const VALIDITY_DIVERGENCE_FRACTION: f64 = 0.1;

/// Active-mass floor: a row whose assignment mass for an atom is at or below
/// this value counts as "off" for that atom and is excluded from the amplitude
/// average. Mirrors `atom_lens::ACTIVE_MASS_FLOOR`.
const ACTIVE_MASS_FLOOR: f64 = 1e-6;
87
88/// The actionable output of a steering query over one atom.
89#[derive(Clone, Debug, PartialEq)]
90pub struct SteerPlan {
91    /// Which atom was steered (index into [`SaeManifoldTerm::atoms`]).
92    pub atom: usize,
93    /// The atom's name (mirrors [`crate::manifold::SaeManifoldAtom::name`]).
94    pub atom_name: String,
95    /// The source latent coordinate `t_from` (length = atom's `latent_dim`).
96    pub t_from: Vec<f64>,
97    /// The target latent coordinate `t_to` (length = atom's `latent_dim`).
98    pub t_to: Vec<f64>,
99    /// The amplitude `a` the on-manifold move was scaled by (the atom's mean
100    /// active assignment mass; `1.0` if the atom is active on no row).
101    pub amplitude: f64,
102    /// The row whose per-row output-Fisher metric the dose was measured through
103    /// (the atom's most-active row; `0` if active nowhere).
104    pub measured_row: usize,
105    /// **The activation-space delta**: `δ = a · (g_k(t_to) − g_k(t_from))`, a
106    /// length-`p` vector in the reconstruction/output space — the actual move to
107    /// add to a hidden state.
108    pub delta: Array1<f64>,
109    /// **DOSIMETRY**: predicted output effect of the move in **nats** of KL,
110    /// integrated along the decoder curve through the output-Fisher metric.
111    /// `None` when the metric carries no behavioral information (Euclidean
112    /// provenance) — the dose is *not available*, not zero.
113    pub predicted_nats: Option<f64>,
114    /// **VALIDITY RADIUS**: the latent step size (Euclidean norm of the move from
115    /// `t_from`) at which the path-integrated dose first diverges from the
116    /// straight endpoint quadratic form by more than
117    /// [`VALIDITY_DIVERGENCE_FRACTION`]. Equals the full move length when the
118    /// linearization is trusted all the way to `t_to`. `None` under a no-behavior
119    /// metric (there is no dose to validate).
120    pub validity_radius: Option<f64>,
121    /// **OFF-MANIFOLD GUARD**: the norm of `δ`'s component outside the span of
122    /// the atom's local decoder tangents `∂g_k/∂t` at `t_from`. `≈ 0` by
123    /// construction (the move is a chord of the curve); a large value flags a
124    /// move that left the learned surface.
125    pub off_manifold_norm: f64,
126    /// The provenance of the metric the dose was read through, echoed so a
127    /// consumer can certify *why* `predicted_nats` is `None` when it is.
128    pub metric_provenance: MetricProvenance,
129}
130
131/// Build a [`SteerPlan`] for driving atom `atom_k` from `t_from` to `t_to`.
132///
133/// `model` is the fitted term (read only); `metric` is the per-row output-Fisher
134/// inner product the dose is measured through (typically `model.row_metric()`'s
135/// own metric, or any metric whose row/output dims match the term). `t_from` and
136/// `t_to` are latent coordinates of length `atom.latent_dim`.
137///
138/// Errors when the atom index is out of range, the coordinate lengths do not
139/// match the atom's latent dimension, the atom has no installed
140/// [`crate::manifold::SaeBasisEvaluator`] (arbitrary-`t` evaluation
141/// requires one), or the metric dimensions do not match the term. Under a
142/// Euclidean (no-behavior) metric the geometry is still produced but
143/// `predicted_nats` / `validity_radius` degrade to `None`.
144pub fn steer_delta(
145    model: &SaeManifoldTerm,
146    metric: &RowMetric,
147    atom_k: usize,
148    t_from: &[f64],
149    t_to: &[f64],
150) -> Result<SteerPlan, String> {
151    let k = model.k_atoms();
152    if atom_k >= k {
153        return Err(format!(
154            "steer_delta: atom index {atom_k} out of range (term has {k} atoms)"
155        ));
156    }
157    let atom = &model.atoms[atom_k];
158    let d = atom.latent_dim;
159    let p = atom.output_dim();
160    if t_from.len() != d || t_to.len() != d {
161        return Err(format!(
162            "steer_delta: t_from/t_to must have length latent_dim={d}; got {} and {}",
163            t_from.len(),
164            t_to.len()
165        ));
166    }
167    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
168        format!(
169            "steer_delta: atom {atom_k} ('{}') has no installed basis evaluator; \
170             arbitrary-t decoder evaluation requires one",
171            atom.name
172        )
173    })?;
174
175    // --- amplitude & the row the dose is measured through -------------------
176    // The amplitude is the atom's mean active assignment mass (how loudly the
177    // atom is expressed), mirroring the presence weighting in `atom_lens`. The
178    // measured row is the atom's single most-active row: the per-row metric
179    // there is the most representative behavioral inner product for "this atom
180    // is on". Both fall back gracefully when the atom is active nowhere.
181    let assignments = model.assignment.assignments();
182    let n = model.n_obs();
183    let mut mass_sum = 0.0_f64;
184    let mut active_count = 0.0_f64;
185    let mut best_row = 0usize;
186    let mut best_mass = f64::NEG_INFINITY;
187    for row in 0..n {
188        let mass = assignments[[row, atom_k]];
189        if mass > best_mass {
190            best_mass = mass;
191            best_row = row;
192        }
193        if mass > ACTIVE_MASS_FLOOR {
194            mass_sum += mass;
195            active_count += 1.0;
196        }
197    }
198    let amplitude = if active_count > 0.0 {
199        mass_sum / active_count
200    } else {
201        1.0
202    };
203
204    // --- the on-manifold activation-space delta -----------------------------
205    let g_from = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_from, p)?;
206    let g_to = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_to, p)?;
207    let mut delta = Array1::<f64>::zeros(p);
208    for i in 0..p {
209        delta[i] = amplitude * (g_to[i] - g_from[i]);
210    }
211
212    // Whether the metric can/does match this term and carries behavior.
213    let provenance = metric.provenance();
214    let behavior_available =
215        metric_carries_behavior(provenance) && metric.n_rows() == n && metric.p_out() == p;
216
217    // --- off-manifold guard -------------------------------------------------
218    // Project δ onto the span of the local decoder tangents ∂g_k/∂t and report
219    // the residual norm. The tangents are evaluated at the move's MIDPOINT, not
220    // at t_from: the chord of a curve is symmetric about its midpoint, so its
221    // component transverse to the midpoint tangent is the true second-order
222    // sagitta (`O(‖Δt‖²)`), whereas the endpoint tangent differs from the chord
223    // direction already at first order. Measuring against the midpoint frame is
224    // therefore the honest "did the move stay on the surface" self-check: it is
225    // `≈ 0` for an on-manifold move and grows only with genuine arc curvature.
226    let mut t_mid = vec![0.0_f64; d];
227    for a in 0..d {
228        t_mid[a] = 0.5 * (t_from[a] + t_to[a]);
229    }
230    let tangents =
231        decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, &t_mid, p, d)?;
232    let off_manifold_norm = off_manifold_residual_norm(&tangents, delta.view());
233
234    // --- dosimetry: path-integrated Fisher dose -----------------------------
235    let (predicted_nats, validity_radius) = if !behavior_available {
236        (None, None)
237    } else {
238        let ctx = SteerContext {
239            evaluator: evaluator.as_ref(),
240            decoder: &atom.decoder_coefficients,
241            metric,
242            row: best_row,
243            p,
244            d,
245            amplitude,
246        };
247        let dose = path_integrated_dose(&ctx, t_from, t_to)?;
248        let radius = validity_radius(&ctx, t_from, t_to)?;
249        (Some(dose), Some(radius))
250    };
251
252    Ok(SteerPlan {
253        atom: atom_k,
254        atom_name: atom.name.clone(),
255        t_from: t_from.to_vec(),
256        t_to: t_to.to_vec(),
257        amplitude,
258        measured_row: best_row,
259        delta,
260        predicted_nats,
261        validity_radius,
262        off_manifold_norm,
263        metric_provenance: provenance,
264    })
265}
266
267/// The model's predicted output-mean response to an applied activation push
268/// `δ`, under the LOCAL-LINEAR reading of its fitted surface: the projection
269/// of `δ` onto the span of atom `atom_k`'s decoder tangents `∂g_k/∂t` at the
270/// operating point `t_at`. A dictionary "predicts" exactly the component of a
271/// push it can carry along its learned surface; the transverse component is
272/// off-manifold and predicted to die (this is the same local model the
273/// off-manifold guard and the dosimetry chord trust, used in the same radius).
274///
275/// This is `μ(δ)` for the design loop of
276/// [`gam_terms::inference::structure_evidence`]: two structural hypotheses about
277/// the same activations (e.g. "one curved atom" vs "two flat atoms") are two
278/// fitted terms whose tangent spans differ, so they predict DIFFERENT
279/// responses to the same probe — and that disagreement, in the output-Fisher
280/// metric, is what `select_probe_by_expected_evidence` maximizes.
281pub fn predicted_response(
282    model: &SaeManifoldTerm,
283    atom_k: usize,
284    t_at: &[f64],
285    delta: ArrayView1<'_, f64>,
286) -> Result<Array1<f64>, String> {
287    let k = model.k_atoms();
288    if atom_k >= k {
289        return Err(format!(
290            "predicted_response: atom index {atom_k} out of range (term has {k} atoms)"
291        ));
292    }
293    let atom = &model.atoms[atom_k];
294    let d = atom.latent_dim;
295    let p = atom.output_dim();
296    if t_at.len() != d {
297        return Err(format!(
298            "predicted_response: t_at must have length latent_dim={d}; got {}",
299            t_at.len()
300        ));
301    }
302    if delta.len() != p {
303        return Err(format!(
304            "predicted_response: delta must have length output_dim={p}; got {}",
305            delta.len()
306        ));
307    }
308    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
309        format!(
310            "predicted_response: atom {atom_k} ('{}') has no installed basis evaluator",
311            atom.name
312        )
313    })?;
314    let tangents = decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, t_at, p, d)?;
315    Ok(project_onto_tangent_span(&tangents, delta))
316}
317
318/// Does this provenance carry behavioral (output-Fisher) information? Euclidean
319/// is the isotropic activation-only path and carries none; the factored
320/// provenances do. (Mirrors `atom_lens::metric_carries_behavior`.)
321fn metric_carries_behavior(p: MetricProvenance) -> bool {
322    match p {
323        MetricProvenance::Euclidean => false,
324        MetricProvenance::OutputFisher { .. }
325        | MetricProvenance::OutputFisherDownstream { .. }
326        | MetricProvenance::WhitenedStructured { .. } => true,
327    }
328}
329
330/// Evaluate the decoder output `g_k(t) = Φ_k(t) B_k ∈ ℝ^p` at an arbitrary
331/// latent coordinate `t` (length `d`) via the atom's installed evaluator.
332fn decode_at(
333    evaluator: &dyn crate::manifold::SaeBasisEvaluator,
334    decoder: &Array2<f64>,
335    t: &[f64],
336    p: usize,
337) -> Result<Array1<f64>, String> {
338    let d = t.len();
339    let coords = Array2::from_shape_vec((1, d), t.to_vec())
340        .map_err(|e| format!("steer_delta::decode_at: coord shape: {e}"))?;
341    let (phi, _jet) = evaluator.evaluate(coords.view())?;
342    let m = decoder.nrows();
343    if phi.ncols() != m {
344        return Err(format!(
345            "steer_delta::decode_at: evaluator returned {} basis cols but decoder has {m} rows",
346            phi.ncols()
347        ));
348    }
349    let mut g = Array1::<f64>::zeros(p);
350    for basis_col in 0..m {
351        let phi_v = phi[[0, basis_col]];
352        if phi_v == 0.0 {
353            continue;
354        }
355        for out_col in 0..p {
356            g[out_col] += phi_v * decoder[[basis_col, out_col]];
357        }
358    }
359    Ok(g)
360}
361
362/// Evaluate the decoder tangents `∂g_k/∂t_a = Φ_k'(t) B_k ∈ ℝ^p`, one per latent
363/// axis `a ∈ 0..d`, at an arbitrary latent coordinate `t`. Returned as a
364/// `(p × d)` matrix whose column `a` is the tangent along axis `a`.
365fn decode_tangents_at(
366    evaluator: &dyn crate::manifold::SaeBasisEvaluator,
367    decoder: &Array2<f64>,
368    t: &[f64],
369    p: usize,
370    d: usize,
371) -> Result<Array2<f64>, String> {
372    let coords = Array2::from_shape_vec((1, d), t.to_vec())
373        .map_err(|e| format!("steer_delta::decode_tangents_at: coord shape: {e}"))?;
374    let (_phi, jet) = evaluator.evaluate(coords.view())?;
375    let m = decoder.nrows();
376    if jet.dim() != (1, m, d) {
377        return Err(format!(
378            "steer_delta::decode_tangents_at: evaluator jet {:?} != (1, {m}, {d})",
379            jet.dim()
380        ));
381    }
382    let mut tang = Array2::<f64>::zeros((p, d));
383    for axis in 0..d {
384        for basis_col in 0..m {
385            let dphi = jet[[0, basis_col, axis]];
386            if dphi == 0.0 {
387                continue;
388            }
389            for out_col in 0..p {
390                tang[[out_col, axis]] += dphi * decoder[[basis_col, out_col]];
391            }
392        }
393    }
394    Ok(tang)
395}
396
397/// Least-squares projection of `δ` onto the span of the local tangents
398/// (columns of `tangents`, shape `p × d`): `δ̂ = T (TᵀT)⁻¹ Tᵀ δ` via a small
399/// `d × d` Gram solve (with a tiny diagonal jitter to absorb a rank-deficient
400/// tangent frame; the jitter only shrinks the projection, never inflates it).
401fn project_onto_tangent_span(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> Array1<f64> {
402    let p = tangents.nrows();
403    let d = tangents.ncols();
404    if d == 0 {
405        return Array1::<f64>::zeros(p);
406    }
407    // Gram = TᵀT (d × d) and rhs = Tᵀδ (d).
408    let mut gram = Array2::<f64>::zeros((d, d));
409    let mut rhs = Array1::<f64>::zeros(d);
410    for a in 0..d {
411        let mut r = 0.0_f64;
412        for i in 0..p {
413            r += tangents[[i, a]] * delta[i];
414        }
415        rhs[a] = r;
416        for b in a..d {
417            let mut acc = 0.0_f64;
418            for i in 0..p {
419                acc += tangents[[i, a]] * tangents[[i, b]];
420            }
421            gram[[a, b]] = acc;
422            gram[[b, a]] = acc;
423        }
424    }
425    let trace: f64 = (0..d).map(|a| gram[[a, a]]).sum();
426    let jitter = if trace > 0.0 { 1e-12 * trace } else { 1e-12 };
427    for a in 0..d {
428        gram[[a, a]] += jitter;
429    }
430    let coeffs = solve_spd_small(&gram, &rhs);
431    let mut proj = Array1::<f64>::zeros(p);
432    for i in 0..p {
433        for a in 0..d {
434            proj[i] += tangents[[i, a]] * coeffs[a];
435        }
436    }
437    proj
438}
439
440/// Norm of `δ`'s component orthogonal to the span of the local tangents:
441/// `‖δ − δ̂‖` with `δ̂` the [`project_onto_tangent_span`] projection.
442fn off_manifold_residual_norm(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> f64 {
443    let proj = project_onto_tangent_span(tangents, delta);
444    let mut res_sq = 0.0_f64;
445    for i in 0..delta.len() {
446        let r = delta[i] - proj[i];
447        res_sq += r * r;
448    }
449    res_sq.max(0.0).sqrt()
450}
451
452/// Tiny symmetric-positive-definite solve via Cholesky for the `d × d` tangent
453/// Gram (`d` is the atom's latent dim, typically 1–3). Falls back to the bare rhs
454/// if the factorization fails (a fully degenerate frame), which only inflates the
455/// reported off-manifold residual — never deflates it.
456fn solve_spd_small(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
457    let d = gram.nrows();
458    // Cholesky L LᵀT = gram.
459    let mut l = Array2::<f64>::zeros((d, d));
460    for i in 0..d {
461        for j in 0..=i {
462            let mut sum = gram[[i, j]];
463            for k in 0..j {
464                sum -= l[[i, k]] * l[[j, k]];
465            }
466            if i == j {
467                if sum <= 0.0 {
468                    return Array1::<f64>::zeros(d);
469                }
470                l[[i, j]] = sum.sqrt();
471            } else {
472                l[[i, j]] = sum / l[[j, j]];
473            }
474        }
475    }
476    // Forward solve L y = rhs.
477    let mut y = Array1::<f64>::zeros(d);
478    for i in 0..d {
479        let mut sum = rhs[i];
480        for k in 0..i {
481            sum -= l[[i, k]] * y[k];
482        }
483        y[i] = sum / l[[i, i]];
484    }
485    // Back solve Lᵀ x = y.
486    let mut x = Array1::<f64>::zeros(d);
487    for i in (0..d).rev() {
488        let mut sum = y[i];
489        for k in (i + 1)..d {
490            sum -= l[[k, i]] * x[k];
491        }
492        x[i] = sum / l[[i, i]];
493    }
494    x
495}
496
497/// The fixed geometry of one steering query, bundled so the dose integrator and
498/// its helpers take a single context rather than a long argument list.
499struct SteerContext<'a> {
500    evaluator: &'a dyn crate::manifold::SaeBasisEvaluator,
501    decoder: &'a Array2<f64>,
502    metric: &'a RowMetric,
503    /// The row whose per-row metric the dose is measured through.
504    row: usize,
505    /// Output dimension `p`.
506    p: usize,
507    /// Latent dimension `d`.
508    d: usize,
509    /// Amplitude `a` the move is scaled by.
510    amplitude: f64,
511}
512
513/// Path-integrated Fisher dose
514/// `½ a² ∫ g_k'(t)ᵀ M g_k'(t) dt` along the straight latent segment
515/// `t(τ) = t_from + τ (t_to − t_from)`, `τ ∈ [0, 1]`, by the midpoint rule over
516/// [`STEER_PATH_STEPS`] sub-steps.
517///
518/// The local quadratic `g'(t)ᵀ M g'(t)` is the [`RowMetric::pullback`] of the
519/// per-axis decoder tangents contracted with the latent velocity `Δt`, so this
520/// uses only the criterion-facing pullback (no loss / no solver floor).
521fn path_integrated_dose(
522    ctx: &SteerContext<'_>,
523    t_from: &[f64],
524    t_to: &[f64],
525) -> Result<f64, String> {
526    let d = ctx.d;
527    let p = ctx.p;
528    let steps = STEER_PATH_STEPS;
529    let dtau = 1.0 / steps as f64;
530    // Latent velocity Δt (constant along the straight segment).
531    let mut dt = vec![0.0_f64; d];
532    for a in 0..d {
533        dt[a] = t_to[a] - t_from[a];
534    }
535    let mut acc = 0.0_f64;
536    let amp2 = ctx.amplitude * ctx.amplitude;
537    for s in 0..steps {
538        // Midpoint of sub-step s in τ, mapped to a latent coordinate.
539        let tau_mid = (s as f64 + 0.5) * dtau;
540        let mut t_mid = vec![0.0_f64; d];
541        for a in 0..d {
542            t_mid[a] = t_from[a] + tau_mid * dt[a];
543        }
544        // Decoder tangents at the midpoint: ∂g/∂t_a, columns of a (p × d) matrix.
545        let tang = decode_tangents_at(ctx.evaluator, ctx.decoder, &t_mid, p, d)?;
546        // The pulled-back metric at this point is g_{ab} = (∂g/∂t)ᵀ M (∂g/∂t),
547        // the d × d local inner product of latent motion *in output-Fisher
548        // units*. We form it through the criterion-facing `RowMetric::pullback`
549        // (which never materializes the p × p M and never sees the solver δ),
550        // then contract the latent velocity Δt twice: the squared output-Fisher
551        // speed along the path is Δtᵀ g Δt. The decoder Jacobian is passed flat
552        // row-major (J[i, a] = j_row[i * d + a]) as `pullback` expects.
553        let mut j_row = vec![0.0_f64; p * d];
554        for i in 0..p {
555            for a in 0..d {
556                j_row[i * d + a] = tang[[i, a]];
557            }
558        }
559        let g_ab = ctx.metric.pullback(ctx.row, &j_row, d);
560        let mut speed_sq = 0.0_f64;
561        for a in 0..d {
562            for b in 0..d {
563                speed_sq += dt[a] * g_ab[[a, b]] * dt[b];
564            }
565        }
566        acc += 0.5 * amp2 * speed_sq * dtau;
567    }
568    Ok(acc)
569}
570
571/// The validity radius: the latent step length (Euclidean distance from
572/// `t_from`) at which **local linearization stops being trusted**.
573///
574/// Linearizing the steering move means predicting the output effect of a prefix
575/// step `τ·Δt` from the initial tangent alone: the first-order output move is
576/// `δ_lin(τ) = a · (∂g/∂t|_{t_from}) · (τ Δt)`, whose output-Fisher KL is the
577/// quadratic form `½ ‖δ_lin(τ)‖²_M = τ² · ½ a² ‖∂g/∂t·Δt‖²_M`. The **true**
578/// effect of that prefix is the chord quadratic form of the *actual* curved
579/// output move `½ a² ‖g(t_from + τΔt) − g(t_from)‖²_M`.
580///
581/// The radius is the chord length `τ* · ‖Δt‖` at the first prefix `τ*` where the
582/// true chord KL diverges from the linear prediction by more than
583/// [`VALIDITY_DIVERGENCE_FRACTION`] (relative to the linear prediction). This is
584/// pure surface curvature: on a flat decoder the two agree for every `τ` and the
585/// radius is the whole move. If the metric kills the tangent (no linear effect to
586/// validate), the move is trusted to its full length.
587fn validity_radius(ctx: &SteerContext<'_>, t_from: &[f64], t_to: &[f64]) -> Result<f64, String> {
588    let d = ctx.d;
589    let p = ctx.p;
590    let full_len: f64 = t_from
591        .iter()
592        .zip(t_to.iter())
593        .map(|(&a, &b)| (b - a) * (b - a))
594        .sum::<f64>()
595        .sqrt();
596    if full_len == 0.0 {
597        return Ok(0.0);
598    }
599    let mut dt = vec![0.0_f64; d];
600    for a in 0..d {
601        dt[a] = t_to[a] - t_from[a];
602    }
603    let amp = ctx.amplitude;
604
605    // Initial-tangent linear output move per unit τ: v0 = (∂g/∂t|_{t_from}) Δt.
606    let tang0 = decode_tangents_at(ctx.evaluator, ctx.decoder, t_from, p, d)?;
607    let mut v0 = Array1::<f64>::zeros(p);
608    for i in 0..p {
609        let mut acc = 0.0_f64;
610        for a in 0..d {
611            acc += tang0[[i, a]] * dt[a];
612        }
613        v0[i] = acc;
614    }
615    // ½ a² ‖v0‖²_M — the per-τ² linear KL coefficient.
616    let lin_coeff = 0.5 * amp * amp * ctx.metric.fisher_mass(ctx.row, v0.view());
617    // No linear effect to validate against ⇒ trust the full move.
618    if !(lin_coeff > 0.0) {
619        return Ok(full_len);
620    }
621
622    let g_from = decode_at(ctx.evaluator, ctx.decoder, t_from, p)?;
623    let steps = STEER_PATH_STEPS;
624    for s in 0..steps {
625        let tau = (s as f64 + 1.0) / steps as f64;
626        let mut t_mid = vec![0.0_f64; d];
627        for a in 0..d {
628            t_mid[a] = t_from[a] + tau * dt[a];
629        }
630        let g_tau = decode_at(ctx.evaluator, ctx.decoder, &t_mid, p)?;
631        let mut chord = Array1::<f64>::zeros(p);
632        for i in 0..p {
633            chord[i] = amp * (g_tau[i] - g_from[i]);
634        }
635        // True chord KL of the prefix, and the linear prediction τ²·lin_coeff.
636        let chord_kl = 0.5 * ctx.metric.fisher_mass(ctx.row, chord.view());
637        let lin_kl = tau * tau * lin_coeff;
638        let rel = (chord_kl - lin_kl).abs() / lin_kl;
639        if rel > VALIDITY_DIVERGENCE_FRACTION {
640            return Ok(tau * full_len);
641        }
642    }
643    Ok(full_len)
644}