Skip to main content

gam_sae/inference/
steering.rs

1//! `steer_delta` — the **steering primitive with output dosimetry**: the
2//! actionable LLM payload of the SAE-manifold machine.
3//!
4//! # What this computes
5//!
6//! Given a fitted [`SaeManifoldTerm`] and the per-row output-Fisher
7//! [`RowMetric`], a *steering move* is "drive atom `k`'s latent coordinate from
8//! `t_from` to `t_to`". The atom's decoder curve `g_k(t) = Φ_k(t) B_k` maps that
9//! latent move to an **activation-space delta** — the actual vector you add to
10//! the residual stream / reconstruction to realize the move *on the manifold*:
11//!
12//! ```text
13//! δ = a · ( g_k(t_to) − g_k(t_from) )          (the on-manifold move)
14//! ```
15//!
16//! where `a` is the atom's amplitude (how loudly the atom is expressed). This is
17//! the thing a downstream consumer adds to a hidden state.
18//!
19//! # Dosimetry — how big is this push, in nats?
20//!
21//! The headline number is the **predicted output effect**: how much behavioral
22//! change (in nats of KL on the model's output distribution) the move induces.
23//! For a locally-quadratic output readout the KL of a parameter move `Δ` is
24//! `½ Δᵀ F Δ` with `F` the output-Fisher information — exactly the inner product
25//! [`RowMetric`] carries. The dose is the Fisher quadratic form of the move,
26//! **integrated along the decoder curve** rather than read only at the endpoints:
27//!
28//! ```text
29//! predicted_nats = ½ ∫_{t_from}^{t_to} a² · g_k'(t)ᵀ M_n g_k'(t) dt
30//! ```
31//!
32//! evaluated in small steps via the per-row pullback / fisher-mass methods. The
33//! path integral is the honest dose: it follows the curved surface, so a long arc
34//! that doubles back is not under-counted the way a straight endpoint chord would
35//! be.
36//!
37//! # Validity radius — where local linearization stops being trusted
38//!
39//! A consumer must know *how far* the move can be trusted as a linear push. The
40//! **validity radius** is the latent step size at which the path-integrated dose
41//! diverges from the straight endpoint quadratic form
42//! `½ a² δ̂ᵀ M δ̂` (the local-linear prediction) by more than
43//! [`VALIDITY_DIVERGENCE_FRACTION`]. Beyond it the surface has curved enough that
44//! the endpoint chord no longer represents the move. We **report** it; we do not
45//! silently clip to it.
46//!
47//! # Off-manifold guard
48//!
//! `δ` is, by construction, a chord of the decoder curve, so it lies in the
//! atom's local tangent frame at the move's midpoint `(t_from + t_to) / 2` up to
//! second-order curvature (the sagitta). The **off-manifold norm** projects `δ`
//! onto the span of the local decoder tangents `∂g_k/∂t` at that midpoint and
//! reports the residual norm — a self-check that the steering move stays on the
//! learned surface. It is `≈ 0` for small steps and grows with arc curvature; a
//! large value means the requested move left the manifold and the dose number is
//! not to be trusted.
56//!
57//! # Read-only / no loss contact
58//!
59//! This module is a **pure read** over the fitted term and the metric. It calls
60//! only `g_k(t)` evaluation ([`SaeManifoldAtom`]'s decoder + installed
61//! [`SaeBasisEvaluator`]) and the criterion-facing
62//! [`RowMetric::fisher_mass`] / [`RowMetric::pullback`]. It never mutates the
63//! model, never touches a likelihood / criterion / penalty, and the solver floor
64//! `δ` of [`RowMetric`] never enters any number it reports (the fisher-mass /
65//! pullback face is `δ`-free, #747).
66
67use ndarray::{Array1, Array2, ArrayView1};
68
69use crate::encode::EncodeAtlas;
70use crate::manifold::SaeManifoldTerm;
71use gam_problem::{MetricProvenance, RowMetric};
72use gam_terms::inference::structure_evidence::log_e_from_p_calibrator;
73
/// Midpoint-rule resolution of the dosimetry path integral: the latent segment
/// `[t_from, t_to]` is cut into this many equal sub-steps. The grid is fixed
/// (no adaptivity, no wall-clock budget), so the reported dose is a
/// deterministic function of its inputs; the decoder curve is smooth enough
/// that a modest grid resolves the arc.
const STEER_PATH_STEPS: usize = 64;

/// Relative tolerance between the path-integrated dose and the straight
/// endpoint quadratic form. While the curved-path dose stays within this
/// fraction (10%) of the chord dose the local linearization is trusted; the
/// step size at which it is exceeded is what gets reported as the validity
/// radius.
const VALIDITY_DIVERGENCE_FRACTION: f64 = 0.1;

/// Assignment mass at or below which an atom counts as "off" on a row; such
/// rows are left out of the atom's mean-amplitude average. Kept in step with
/// `atom_lens::ACTIVE_MASS_FLOOR`.
const ACTIVE_MASS_FLOOR: f64 = 1e-6;
89
90/// The actionable output of a steering query over one atom.
91#[derive(Clone, Debug, PartialEq)]
92pub struct SteerPlan {
93    /// Which atom was steered (index into [`SaeManifoldTerm::atoms`]).
94    pub atom: usize,
95    /// The atom's name (mirrors [`crate::manifold::SaeManifoldAtom::name`]).
96    pub atom_name: String,
97    /// The source latent coordinate `t_from` (length = atom's `latent_dim`).
98    pub t_from: Vec<f64>,
99    /// The target latent coordinate `t_to` (length = atom's `latent_dim`).
100    pub t_to: Vec<f64>,
101    /// The amplitude `a` the on-manifold move was scaled by (the atom's mean
102    /// active assignment mass; `1.0` if the atom is active on no row).
103    pub amplitude: f64,
104    /// The row whose per-row output-Fisher metric the dose was measured through
105    /// (the atom's most-active row; `0` if active nowhere).
106    pub measured_row: usize,
107    /// **The activation-space delta**: `δ = a · (g_k(t_to) − g_k(t_from))`, a
108    /// length-`p` vector in the reconstruction/output space — the actual move to
109    /// add to a hidden state.
110    pub delta: Array1<f64>,
111    /// **DOSIMETRY**: predicted output effect of the move in **nats** of KL,
112    /// integrated along the decoder curve through the output-Fisher metric.
113    /// `None` when the metric carries no behavioral information (Euclidean
114    /// provenance) — the dose is *not available*, not zero.
115    pub predicted_nats: Option<f64>,
116    /// **VALIDITY RADIUS**: the latent step size (Euclidean norm of the move from
117    /// `t_from`) at which the path-integrated dose first diverges from the
118    /// straight endpoint quadratic form by more than
119    /// [`VALIDITY_DIVERGENCE_FRACTION`]. Equals the full move length when the
120    /// linearization is trusted all the way to `t_to`. `None` under a no-behavior
121    /// metric (there is no dose to validate).
122    pub validity_radius: Option<f64>,
123    /// **OFF-MANIFOLD GUARD**: the norm of `δ`'s component outside the span of
124    /// the atom's local decoder tangents `∂g_k/∂t` at `t_from`. `≈ 0` by
125    /// construction (the move is a chord of the curve); a large value flags a
126    /// move that left the learned surface.
127    pub off_manifold_norm: f64,
128    /// The provenance of the metric the dose was read through, echoed so a
129    /// consumer can certify *why* `predicted_nats` is `None` when it is.
130    pub metric_provenance: MetricProvenance,
131}
132
133/// Result of writing one certified chart coordinate into an activation row.
134///
135/// The edited row is always `x + δ`, where `δ` is the delta returned by
136/// [`steer_delta`] for the atom's current encoded coordinate and the requested
137/// target coordinate. Because only the on-manifold atom chord is added, every
138/// component of `x` outside this atom's chart residual is preserved exactly; this
139/// is the locality guarantee missing from whole-residual linear-steering
140/// baselines.
141#[derive(Clone, Debug)]
142pub struct CoordinateSetResult {
143    /// The edited activation/reconstruction row.
144    pub edited: Array1<f64>,
145    /// Certified coordinate read from the input row before the write.
146    pub t_from_certified: Array1<f64>,
147    /// Certificate attached to `t_from_certified`.
148    pub encode_certificate: crate::encode::RowCertificate,
149    /// Steering plan whose `delta` was added to the row.
150    pub steer: SteerPlan,
151}
152
153/// Write atom `atom_k`'s chart coordinate in row `x` to `t_to` by delta
154/// steering, preserving the row's off-atom/off-subspace residual exactly.
155///
156/// `amplitude` is the assignment/intensity with which the row expresses this
157/// atom; callers that have already separated existence/intensity/position should
158/// pass the intensity and only swap the position coordinate. The certified read
159/// uses [`EncodeAtlas::certified_encode_row`]; the write uses [`steer_delta`].
160pub fn set_coordinate(
161    model: &SaeManifoldTerm,
162    metric: &RowMetric,
163    atlas: &EncodeAtlas,
164    x: ArrayView1<'_, f64>,
165    atom_k: usize,
166    amplitude: f64,
167    t_to: &[f64],
168) -> Result<CoordinateSetResult, String> {
169    let atom = model.atoms.get(atom_k).ok_or_else(|| {
170        format!(
171            "set_coordinate: atom index {atom_k} out of range (term has {} atoms)",
172            model.k_atoms()
173        )
174    })?;
175    if x.len() != atom.output_dim() {
176        return Err(format!(
177            "set_coordinate: input row has length {} but atom {atom_k} output_dim is {}",
178            x.len(),
179            atom.output_dim()
180        ));
181    }
182    let (t_from, cert) = atlas.certified_encode_row(atom, atom_k, x, amplitude)?;
183    let steer = steer_delta_with_amplitude(
184        model,
185        metric,
186        atom_k,
187        t_from.as_slice().unwrap_or(&[]),
188        t_to,
189        amplitude,
190    )?;
191    let mut edited = x.to_owned();
192    if edited.len() != steer.delta.len() {
193        return Err(format!(
194            "set_coordinate: steering delta length {} does not match row length {}",
195            steer.delta.len(),
196            edited.len()
197        ));
198    }
199    for i in 0..edited.len() {
200        edited[i] += steer.delta[i];
201    }
202    Ok(CoordinateSetResult {
203        edited,
204        t_from_certified: t_from,
205        encode_certificate: cert,
206        steer,
207    })
208}
209
210/// Result of a coordinate interchange: donor position read from `x_source`, then
211/// written into `x_target` while preserving the target residual and intensity.
212#[derive(Clone, Debug)]
213pub struct InterchangeResult {
214    /// Target row after the donor coordinate has been delta-written into it.
215    pub edited_target: Array1<f64>,
216    /// Donor/source coordinate that was transplanted.
217    pub donor_t: Array1<f64>,
218    /// Target coordinate before the transplant.
219    pub target_t_before: Array1<f64>,
220    /// Target behavior coordinate after the transplant, re-read from the edit.
221    pub target_t_after: Array1<f64>,
222    /// Steering dose in nats, when a behavioral metric is available.
223    pub predicted_nats: Option<f64>,
224    /// Norm of the steering delta outside the local atom tangent frame.
225    pub off_manifold_norm: f64,
226    /// Reported steering validity radius.
227    pub validity_radius: Option<f64>,
228    /// Calibrated log e-value for counterfactual consistency: larger means the
229    /// post-edit target coordinate landed closer to the donor coordinate.
230    pub counterfactual_consistency_log_e: f64,
231    /// Underlying coordinate-write plan.
232    pub set_result: CoordinateSetResult,
233}
234
235/// Interchange atom `atom_k`'s chart coordinate from `x_source` into `x_target`.
236///
237/// The source coordinate is certified with `source_amplitude`; the target write
238/// is performed with `target_amplitude`, so swapping a position coordinate cannot
239/// silently smuggle donor intensity into the target. The returned consistency
240/// e-value is computed by re-encoding the edited target and calibrating the
241/// coordinate landing error into the existing structure-evidence e-currency.
242pub fn interchange(
243    model: &SaeManifoldTerm,
244    metric: &RowMetric,
245    atlas: &EncodeAtlas,
246    x_target: ArrayView1<'_, f64>,
247    target_amplitude: f64,
248    x_source: ArrayView1<'_, f64>,
249    source_amplitude: f64,
250    atom_k: usize,
251) -> Result<InterchangeResult, String> {
252    let atom = model.atoms.get(atom_k).ok_or_else(|| {
253        format!(
254            "interchange: atom index {atom_k} out of range (term has {} atoms)",
255            model.k_atoms()
256        )
257    })?;
258    let (donor_t, _donor_cert) =
259        atlas.certified_encode_row(atom, atom_k, x_source, source_amplitude)?;
260    let set = set_coordinate(
261        model,
262        metric,
263        atlas,
264        x_target,
265        atom_k,
266        target_amplitude,
267        donor_t.as_slice().unwrap_or(&[]),
268    )?;
269    let (target_t_after, _after_cert) =
270        atlas.certified_encode_row(atom, atom_k, set.edited.view(), target_amplitude)?;
271    let landing_error = l2_distance(donor_t.view(), target_t_after.view())?;
272    let scale = set
273        .steer
274        .validity_radius
275        .unwrap_or_else(|| {
276            l2_distance(set.t_from_certified.view(), donor_t.view())
277                .unwrap_or(1.0)
278                .max(1e-12)
279        })
280        .max(1e-12);
281    // Convert closeness into a superuniform-shaped p-value and then into the
282    // repository's standard e-value currency. Exact hits approach machine-small
283    // p-values; errors at/above the validity radius produce e-values near or
284    // below one, so shuffled-chart negative controls do not accumulate evidence.
285    let z = (scale / landing_error.max(1e-12)).min(1.0e6);
286    let p_value = (-0.5 * z * z).exp().clamp(f64::MIN_POSITIVE, 1.0);
287    let log_e = log_e_from_p_calibrator(p_value)?;
288    Ok(InterchangeResult {
289        edited_target: set.edited.clone(),
290        donor_t,
291        target_t_before: set.t_from_certified.clone(),
292        target_t_after,
293        predicted_nats: set.steer.predicted_nats,
294        off_manifold_norm: set.steer.off_manifold_norm,
295        validity_radius: set.steer.validity_radius,
296        counterfactual_consistency_log_e: log_e,
297        set_result: set,
298    })
299}
300
301fn l2_distance(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> Result<f64, String> {
302    if a.len() != b.len() {
303        return Err(format!(
304            "coordinate distance length mismatch: {} vs {}",
305            a.len(),
306            b.len()
307        ));
308    }
309    let mut ss = 0.0;
310    for i in 0..a.len() {
311        let r = a[i] - b[i];
312        ss += r * r;
313    }
314    Ok(ss.sqrt())
315}
316
317/// Build a [`SteerPlan`] for driving atom `atom_k` from `t_from` to `t_to`.
318///
319/// `model` is the fitted term (read only); `metric` is the per-row output-Fisher
320/// inner product the dose is measured through (typically `model.row_metric()`'s
321/// own metric, or any metric whose row/output dims match the term). `t_from` and
322/// `t_to` are latent coordinates of length `atom.latent_dim`.
323///
324/// Errors when the atom index is out of range, the coordinate lengths do not
325/// match the atom's latent dimension, the atom has no installed
326/// [`crate::manifold::SaeBasisEvaluator`] (arbitrary-`t` evaluation
327/// requires one), or the metric dimensions do not match the term. Under a
328/// Euclidean (no-behavior) metric the geometry is still produced but
329/// `predicted_nats` / `validity_radius` degrade to `None`.
330pub fn steer_delta(
331    model: &SaeManifoldTerm,
332    metric: &RowMetric,
333    atom_k: usize,
334    t_from: &[f64],
335    t_to: &[f64],
336) -> Result<SteerPlan, String> {
337    steer_delta_impl(model, metric, atom_k, t_from, t_to, None)
338}
339
340fn steer_delta_with_amplitude(
341    model: &SaeManifoldTerm,
342    metric: &RowMetric,
343    atom_k: usize,
344    t_from: &[f64],
345    t_to: &[f64],
346    amplitude: f64,
347) -> Result<SteerPlan, String> {
348    if !(amplitude.is_finite() && amplitude > 0.0) {
349        return Err(format!(
350            "steer_delta_with_amplitude: amplitude must be finite and positive, got {amplitude}"
351        ));
352    }
353    steer_delta_impl(model, metric, atom_k, t_from, t_to, Some(amplitude))
354}
355
356fn steer_delta_impl(
357    model: &SaeManifoldTerm,
358    metric: &RowMetric,
359    atom_k: usize,
360    t_from: &[f64],
361    t_to: &[f64],
362    amplitude_override: Option<f64>,
363) -> Result<SteerPlan, String> {
364    let k = model.k_atoms();
365    if atom_k >= k {
366        return Err(format!(
367            "steer_delta: atom index {atom_k} out of range (term has {k} atoms)"
368        ));
369    }
370    let atom = &model.atoms[atom_k];
371    let d = atom.latent_dim;
372    let p = atom.output_dim();
373    if t_from.len() != d || t_to.len() != d {
374        return Err(format!(
375            "steer_delta: t_from/t_to must have length latent_dim={d}; got {} and {}",
376            t_from.len(),
377            t_to.len()
378        ));
379    }
380    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
381        format!(
382            "steer_delta: atom {atom_k} ('{}') has no installed basis evaluator; \
383             arbitrary-t decoder evaluation requires one",
384            atom.name
385        )
386    })?;
387
388    // --- amplitude & the row the dose is measured through -------------------
389    // The amplitude is the atom's mean active assignment mass (how loudly the
390    // atom is expressed), mirroring the presence weighting in `atom_lens`. The
391    // measured row is the atom's single most-active row: the per-row metric
392    // there is the most representative behavioral inner product for "this atom
393    // is on". Both fall back gracefully when the atom is active nowhere.
394    let assignments = model.assignment.assignments();
395    let n = model.n_obs();
396    let mut mass_sum = 0.0_f64;
397    let mut active_count = 0.0_f64;
398    let mut best_row = 0usize;
399    let mut best_mass = f64::NEG_INFINITY;
400    for row in 0..n {
401        let mass = assignments[[row, atom_k]];
402        if mass > best_mass {
403            best_mass = mass;
404            best_row = row;
405        }
406        if mass > ACTIVE_MASS_FLOOR {
407            mass_sum += mass;
408            active_count += 1.0;
409        }
410    }
411    let amplitude = amplitude_override.unwrap_or_else(|| {
412        if active_count > 0.0 {
413            mass_sum / active_count
414        } else {
415            1.0
416        }
417    });
418
419    // --- the on-manifold activation-space delta -----------------------------
420    let g_from = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_from, p)?;
421    let g_to = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_to, p)?;
422    let mut delta = Array1::<f64>::zeros(p);
423    for i in 0..p {
424        delta[i] = amplitude * (g_to[i] - g_from[i]);
425    }
426
427    // Whether the metric can/does match this term and carries behavior.
428    let provenance = metric.provenance();
429    let behavior_available =
430        metric_carries_behavior(provenance) && metric.n_rows() == n && metric.p_out() == p;
431
432    // --- off-manifold guard -------------------------------------------------
433    // Project δ onto the span of the local decoder tangents ∂g_k/∂t and report
434    // the residual norm. The tangents are evaluated at the move's MIDPOINT, not
435    // at t_from: the chord of a curve is symmetric about its midpoint, so its
436    // component transverse to the midpoint tangent is the true second-order
437    // sagitta (`O(‖Δt‖²)`), whereas the endpoint tangent differs from the chord
438    // direction already at first order. Measuring against the midpoint frame is
439    // therefore the honest "did the move stay on the surface" self-check: it is
440    // `≈ 0` for an on-manifold move and grows only with genuine arc curvature.
441    let mut t_mid = vec![0.0_f64; d];
442    for a in 0..d {
443        t_mid[a] = 0.5 * (t_from[a] + t_to[a]);
444    }
445    let tangents =
446        decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, &t_mid, p, d)?;
447    let off_manifold_norm = off_manifold_residual_norm(&tangents, delta.view());
448
449    // --- dosimetry: path-integrated Fisher dose -----------------------------
450    let (predicted_nats, validity_radius) = if !behavior_available {
451        (None, None)
452    } else {
453        let ctx = SteerContext {
454            evaluator: evaluator.as_ref(),
455            decoder: &atom.decoder_coefficients,
456            metric,
457            row: best_row,
458            p,
459            d,
460            amplitude,
461        };
462        let dose = path_integrated_dose(&ctx, t_from, t_to)?;
463        let radius = validity_radius(&ctx, t_from, t_to)?;
464        (Some(dose), Some(radius))
465    };
466
467    Ok(SteerPlan {
468        atom: atom_k,
469        atom_name: atom.name.clone(),
470        t_from: t_from.to_vec(),
471        t_to: t_to.to_vec(),
472        amplitude,
473        measured_row: best_row,
474        delta,
475        predicted_nats,
476        validity_radius,
477        off_manifold_norm,
478        metric_provenance: provenance,
479    })
480}
481
482/// The model's predicted output-mean response to an applied activation push
483/// `δ`, under the LOCAL-LINEAR reading of its fitted surface: the projection
484/// of `δ` onto the span of atom `atom_k`'s decoder tangents `∂g_k/∂t` at the
485/// operating point `t_at`. A dictionary "predicts" exactly the component of a
486/// push it can carry along its learned surface; the transverse component is
487/// off-manifold and predicted to die (this is the same local model the
488/// off-manifold guard and the dosimetry chord trust, used in the same radius).
489///
490/// This is `μ(δ)` for the design loop of
491/// [`gam_terms::inference::structure_evidence`]: two structural hypotheses about
492/// the same activations (e.g. "one curved atom" vs "two flat atoms") are two
493/// fitted terms whose tangent spans differ, so they predict DIFFERENT
494/// responses to the same probe — and that disagreement, in the output-Fisher
495/// metric, is what `select_probe_by_expected_evidence` maximizes.
496pub fn predicted_response(
497    model: &SaeManifoldTerm,
498    atom_k: usize,
499    t_at: &[f64],
500    delta: ArrayView1<'_, f64>,
501) -> Result<Array1<f64>, String> {
502    let k = model.k_atoms();
503    if atom_k >= k {
504        return Err(format!(
505            "predicted_response: atom index {atom_k} out of range (term has {k} atoms)"
506        ));
507    }
508    let atom = &model.atoms[atom_k];
509    let d = atom.latent_dim;
510    let p = atom.output_dim();
511    if t_at.len() != d {
512        return Err(format!(
513            "predicted_response: t_at must have length latent_dim={d}; got {}",
514            t_at.len()
515        ));
516    }
517    if delta.len() != p {
518        return Err(format!(
519            "predicted_response: delta must have length output_dim={p}; got {}",
520            delta.len()
521        ));
522    }
523    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
524        format!(
525            "predicted_response: atom {atom_k} ('{}') has no installed basis evaluator",
526            atom.name
527        )
528    })?;
529    let tangents = decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, t_at, p, d)?;
530    Ok(project_onto_tangent_span(&tangents, delta))
531}
532
533/// Does this provenance carry behavioral (output-Fisher) information? Euclidean
534/// is the isotropic activation-only path and carries none; the factored
535/// provenances do. (Mirrors `atom_lens::metric_carries_behavior`.)
536fn metric_carries_behavior(p: MetricProvenance) -> bool {
537    match p {
538        MetricProvenance::Euclidean => false,
539        MetricProvenance::OutputFisher { .. }
540        | MetricProvenance::OutputFisherDownstream { .. }
541        | MetricProvenance::BehavioralFisher { .. }
542        | MetricProvenance::WhitenedStructured { .. } => true,
543    }
544}
545
546/// Evaluate the decoder output `g_k(t) = Φ_k(t) B_k ∈ ℝ^p` at an arbitrary
547/// latent coordinate `t` (length `d`) via the atom's installed evaluator.
548fn decode_at(
549    evaluator: &dyn crate::manifold::SaeBasisEvaluator,
550    decoder: &Array2<f64>,
551    t: &[f64],
552    p: usize,
553) -> Result<Array1<f64>, String> {
554    let d = t.len();
555    let coords = Array2::from_shape_vec((1, d), t.to_vec())
556        .map_err(|e| format!("steer_delta::decode_at: coord shape: {e}"))?;
557    let (phi, _jet) = evaluator.evaluate(coords.view())?;
558    let m = decoder.nrows();
559    if phi.ncols() != m {
560        return Err(format!(
561            "steer_delta::decode_at: evaluator returned {} basis cols but decoder has {m} rows",
562            phi.ncols()
563        ));
564    }
565    let mut g = Array1::<f64>::zeros(p);
566    for basis_col in 0..m {
567        let phi_v = phi[[0, basis_col]];
568        if phi_v == 0.0 {
569            continue;
570        }
571        for out_col in 0..p {
572            g[out_col] += phi_v * decoder[[basis_col, out_col]];
573        }
574    }
575    Ok(g)
576}
577
578/// Evaluate the decoder tangents `∂g_k/∂t_a = Φ_k'(t) B_k ∈ ℝ^p`, one per latent
579/// axis `a ∈ 0..d`, at an arbitrary latent coordinate `t`. Returned as a
580/// `(p × d)` matrix whose column `a` is the tangent along axis `a`.
581fn decode_tangents_at(
582    evaluator: &dyn crate::manifold::SaeBasisEvaluator,
583    decoder: &Array2<f64>,
584    t: &[f64],
585    p: usize,
586    d: usize,
587) -> Result<Array2<f64>, String> {
588    let coords = Array2::from_shape_vec((1, d), t.to_vec())
589        .map_err(|e| format!("steer_delta::decode_tangents_at: coord shape: {e}"))?;
590    let (_phi, jet) = evaluator.evaluate(coords.view())?;
591    let m = decoder.nrows();
592    if jet.dim() != (1, m, d) {
593        return Err(format!(
594            "steer_delta::decode_tangents_at: evaluator jet {:?} != (1, {m}, {d})",
595            jet.dim()
596        ));
597    }
598    let mut tang = Array2::<f64>::zeros((p, d));
599    for axis in 0..d {
600        for basis_col in 0..m {
601            let dphi = jet[[0, basis_col, axis]];
602            if dphi == 0.0 {
603                continue;
604            }
605            for out_col in 0..p {
606                tang[[out_col, axis]] += dphi * decoder[[basis_col, out_col]];
607            }
608        }
609    }
610    Ok(tang)
611}
612
/// Least-squares projection of `δ` onto the span of the local tangents
/// (columns of `tangents`, shape `p × d`): `δ̂ = T (TᵀT)⁻¹ Tᵀ δ` via a small
/// `d × d` Gram solve (with a tiny diagonal jitter to absorb a rank-deficient
/// tangent frame; the jitter only shrinks the projection, never inflates it).
///
/// Returns a length-`p` vector (`p = tangents.nrows()`). With no tangent
/// columns (`d == 0`) the projection is the zero vector; if the Gram
/// factorization in [`solve_spd_small`] fails, its zero coefficients also make
/// the projection zero. `delta` is indexed for `i < p`, so it is expected to
/// have length `p` (callers check this).
fn project_onto_tangent_span(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> Array1<f64> {
    let p = tangents.nrows();
    let d = tangents.ncols();
    if d == 0 {
        return Array1::<f64>::zeros(p);
    }
    // Gram = TᵀT (d × d) and rhs = Tᵀδ (d).
    let mut gram = Array2::<f64>::zeros((d, d));
    let mut rhs = Array1::<f64>::zeros(d);
    for a in 0..d {
        let mut r = 0.0_f64;
        for i in 0..p {
            r += tangents[[i, a]] * delta[i];
        }
        rhs[a] = r;
        // Only the upper triangle is computed; each entry is mirrored.
        for b in a..d {
            let mut acc = 0.0_f64;
            for i in 0..p {
                acc += tangents[[i, a]] * tangents[[i, b]];
            }
            gram[[a, b]] = acc;
            gram[[b, a]] = acc;
        }
    }
    // Jitter is relative to the Gram trace so it stays negligible at any
    // tangent scale; a non-positive (or NaN) trace falls back to an absolute
    // 1e-12.
    let trace: f64 = (0..d).map(|a| gram[[a, a]]).sum();
    let jitter = if trace > 0.0 { 1e-12 * trace } else { 1e-12 };
    for a in 0..d {
        gram[[a, a]] += jitter;
    }
    let coeffs = solve_spd_small(&gram, &rhs);
    // δ̂ = T · coeffs.
    let mut proj = Array1::<f64>::zeros(p);
    for i in 0..p {
        for a in 0..d {
            proj[i] += tangents[[i, a]] * coeffs[a];
        }
    }
    proj
}
655
656/// Norm of `δ`'s component orthogonal to the span of the local tangents:
657/// `‖δ − δ̂‖` with `δ̂` the [`project_onto_tangent_span`] projection.
658fn off_manifold_residual_norm(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> f64 {
659    let proj = project_onto_tangent_span(tangents, delta);
660    let mut res_sq = 0.0_f64;
661    for i in 0..delta.len() {
662        let r = delta[i] - proj[i];
663        res_sq += r * r;
664    }
665    res_sq.max(0.0).sqrt()
666}
667
668/// Tiny symmetric-positive-definite solve via Cholesky for the `d × d` tangent
669/// Gram (`d` is the atom's latent dim, typically 1–3). Falls back to the bare rhs
670/// if the factorization fails (a fully degenerate frame), which only inflates the
671/// reported off-manifold residual — never deflates it.
672fn solve_spd_small(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
673    let d = gram.nrows();
674    // Cholesky L LᵀT = gram.
675    let mut l = Array2::<f64>::zeros((d, d));
676    for i in 0..d {
677        for j in 0..=i {
678            let mut sum = gram[[i, j]];
679            for k in 0..j {
680                sum -= l[[i, k]] * l[[j, k]];
681            }
682            if i == j {
683                if sum <= 0.0 {
684                    return Array1::<f64>::zeros(d);
685                }
686                l[[i, j]] = sum.sqrt();
687            } else {
688                l[[i, j]] = sum / l[[j, j]];
689            }
690        }
691    }
692    // Forward solve L y = rhs.
693    let mut y = Array1::<f64>::zeros(d);
694    for i in 0..d {
695        let mut sum = rhs[i];
696        for k in 0..i {
697            sum -= l[[i, k]] * y[k];
698        }
699        y[i] = sum / l[[i, i]];
700    }
701    // Back solve Lᵀ x = y.
702    let mut x = Array1::<f64>::zeros(d);
703    for i in (0..d).rev() {
704        let mut sum = y[i];
705        for k in (i + 1)..d {
706            sum -= l[[k, i]] * x[k];
707        }
708        x[i] = sum / l[[i, i]];
709    }
710    x
711}
712
713/// The fixed geometry of one steering query, bundled so the dose integrator and
714/// its helpers take a single context rather than a long argument list.
715struct SteerContext<'a> {
716    evaluator: &'a dyn crate::manifold::SaeBasisEvaluator,
717    decoder: &'a Array2<f64>,
718    metric: &'a RowMetric,
719    /// The row whose per-row metric the dose is measured through.
720    row: usize,
721    /// Output dimension `p`.
722    p: usize,
723    /// Latent dimension `d`.
724    d: usize,
725    /// Amplitude `a` the move is scaled by.
726    amplitude: f64,
727}
728
729/// Path-integrated Fisher dose
730/// `½ a² ∫ g_k'(t)ᵀ M g_k'(t) dt` along the straight latent segment
731/// `t(τ) = t_from + τ (t_to − t_from)`, `τ ∈ [0, 1]`, by the midpoint rule over
732/// [`STEER_PATH_STEPS`] sub-steps.
733///
734/// The local quadratic `g'(t)ᵀ M g'(t)` is the [`RowMetric::pullback`] of the
735/// per-axis decoder tangents contracted with the latent velocity `Δt`, so this
736/// uses only the criterion-facing pullback (no loss / no solver floor).
737fn path_integrated_dose(
738    ctx: &SteerContext<'_>,
739    t_from: &[f64],
740    t_to: &[f64],
741) -> Result<f64, String> {
742    let d = ctx.d;
743    let p = ctx.p;
744    let steps = STEER_PATH_STEPS;
745    let dtau = 1.0 / steps as f64;
746    // Latent velocity Δt (constant along the straight segment).
747    let mut dt = vec![0.0_f64; d];
748    for a in 0..d {
749        dt[a] = t_to[a] - t_from[a];
750    }
751    let mut acc = 0.0_f64;
752    let amp2 = ctx.amplitude * ctx.amplitude;
753    for s in 0..steps {
754        // Midpoint of sub-step s in τ, mapped to a latent coordinate.
755        let tau_mid = (s as f64 + 0.5) * dtau;
756        let mut t_mid = vec![0.0_f64; d];
757        for a in 0..d {
758            t_mid[a] = t_from[a] + tau_mid * dt[a];
759        }
760        // Decoder tangents at the midpoint: ∂g/∂t_a, columns of a (p × d) matrix.
761        let tang = decode_tangents_at(ctx.evaluator, ctx.decoder, &t_mid, p, d)?;
762        // The pulled-back metric at this point is g_{ab} = (∂g/∂t)ᵀ M (∂g/∂t),
763        // the d × d local inner product of latent motion *in output-Fisher
764        // units*. We form it through the criterion-facing `RowMetric::pullback`
765        // (which never materializes the p × p M and never sees the solver δ),
766        // then contract the latent velocity Δt twice: the squared output-Fisher
767        // speed along the path is Δtᵀ g Δt. The decoder Jacobian is passed flat
768        // row-major (J[i, a] = j_row[i * d + a]) as `pullback` expects.
769        let mut j_row = vec![0.0_f64; p * d];
770        for i in 0..p {
771            for a in 0..d {
772                j_row[i * d + a] = tang[[i, a]];
773            }
774        }
775        let g_ab = ctx.metric.pullback(ctx.row, &j_row, d);
776        let mut speed_sq = 0.0_f64;
777        for a in 0..d {
778            for b in 0..d {
779                speed_sq += dt[a] * g_ab[[a, b]] * dt[b];
780            }
781        }
782        acc += 0.5 * amp2 * speed_sq * dtau;
783    }
784    Ok(acc)
785}
786
787/// The validity radius: the latent step length (Euclidean distance from
788/// `t_from`) at which **local linearization stops being trusted**.
789///
790/// Linearizing the steering move means predicting the output effect of a prefix
791/// step `τ·Δt` from the initial tangent alone: the first-order output move is
792/// `δ_lin(τ) = a · (∂g/∂t|_{t_from}) · (τ Δt)`, whose output-Fisher KL is the
793/// quadratic form `½ ‖δ_lin(τ)‖²_M = τ² · ½ a² ‖∂g/∂t·Δt‖²_M`. The **true**
794/// effect of that prefix is the chord quadratic form of the *actual* curved
795/// output move `½ a² ‖g(t_from + τΔt) − g(t_from)‖²_M`.
796///
797/// The radius is the chord length `τ* · ‖Δt‖` at the first prefix `τ*` where the
798/// true chord KL diverges from the linear prediction by more than
799/// [`VALIDITY_DIVERGENCE_FRACTION`] (relative to the linear prediction). This is
800/// pure surface curvature: on a flat decoder the two agree for every `τ` and the
801/// radius is the whole move. If the metric kills the tangent (no linear effect to
802/// validate), the move is trusted to its full length.
803fn validity_radius(ctx: &SteerContext<'_>, t_from: &[f64], t_to: &[f64]) -> Result<f64, String> {
804    let d = ctx.d;
805    let p = ctx.p;
806    let full_len: f64 = t_from
807        .iter()
808        .zip(t_to.iter())
809        .map(|(&a, &b)| (b - a) * (b - a))
810        .sum::<f64>()
811        .sqrt();
812    if full_len == 0.0 {
813        return Ok(0.0);
814    }
815    let mut dt = vec![0.0_f64; d];
816    for a in 0..d {
817        dt[a] = t_to[a] - t_from[a];
818    }
819    let amp = ctx.amplitude;
820
821    // Initial-tangent linear output move per unit τ: v0 = (∂g/∂t|_{t_from}) Δt.
822    let tang0 = decode_tangents_at(ctx.evaluator, ctx.decoder, t_from, p, d)?;
823    let mut v0 = Array1::<f64>::zeros(p);
824    for i in 0..p {
825        let mut acc = 0.0_f64;
826        for a in 0..d {
827            acc += tang0[[i, a]] * dt[a];
828        }
829        v0[i] = acc;
830    }
831    // ½ a² ‖v0‖²_M — the per-τ² linear KL coefficient.
832    let lin_coeff = 0.5 * amp * amp * ctx.metric.fisher_mass(ctx.row, v0.view());
833    // No linear effect to validate against ⇒ trust the full move.
834    if !(lin_coeff > 0.0) {
835        return Ok(full_len);
836    }
837
838    let g_from = decode_at(ctx.evaluator, ctx.decoder, t_from, p)?;
839    let steps = STEER_PATH_STEPS;
840    for s in 0..steps {
841        let tau = (s as f64 + 1.0) / steps as f64;
842        let mut t_mid = vec![0.0_f64; d];
843        for a in 0..d {
844            t_mid[a] = t_from[a] + tau * dt[a];
845        }
846        let g_tau = decode_at(ctx.evaluator, ctx.decoder, &t_mid, p)?;
847        let mut chord = Array1::<f64>::zeros(p);
848        for i in 0..p {
849            chord[i] = amp * (g_tau[i] - g_from[i]);
850        }
851        // True chord KL of the prefix, and the linear prediction τ²·lin_coeff.
852        let chord_kl = 0.5 * ctx.metric.fisher_mass(ctx.row, chord.view());
853        let lin_kl = tau * tau * lin_coeff;
854        let rel = (chord_kl - lin_kl).abs() / lin_kl;
855        if rel > VALIDITY_DIVERGENCE_FRACTION {
856            return Ok(tau * full_len);
857        }
858    }
859    Ok(full_len)
860}