gam_sae/inference/steering.rs
1//! `steer_delta` — the **steering primitive with output dosimetry**: the
2//! actionable LLM payload of the SAE-manifold machine.
3//!
4//! # What this computes
5//!
6//! Given a fitted [`SaeManifoldTerm`] and the per-row output-Fisher
7//! [`RowMetric`], a *steering move* is "drive atom `k`'s latent coordinate from
8//! `t_from` to `t_to`". The atom's decoder curve `g_k(t) = Φ_k(t) B_k` maps that
9//! latent move to an **activation-space delta** — the actual vector you add to
10//! the residual stream / reconstruction to realize the move *on the manifold*:
11//!
12//! ```text
13//! δ = a · ( g_k(t_to) − g_k(t_from) ) (the on-manifold move)
14//! ```
15//!
16//! where `a` is the atom's amplitude (how loudly the atom is expressed). This is
17//! the thing a downstream consumer adds to a hidden state.
18//!
19//! # Dosimetry — how big is this push, in nats?
20//!
21//! The headline number is the **predicted output effect**: how much behavioral
22//! change (in nats of KL on the model's output distribution) the move induces.
23//! For a locally-quadratic output readout the KL of a parameter move `Δ` is
24//! `½ Δᵀ F Δ` with `F` the output-Fisher information — exactly the inner product
25//! [`RowMetric`] carries. The dose is the Fisher quadratic form of the move,
26//! **integrated along the decoder curve** rather than read only at the endpoints:
27//!
//! ```text
//! predicted_nats = ½ a² ∫₀¹ Δtᵀ J_k(t(τ))ᵀ M_n J_k(t(τ)) Δt dτ
//!
//!   t(τ) = t_from + τ·Δt,   Δt = t_to − t_from,   J_k = ∂g_k/∂t
//! ```
//!
//! where `M_n` is the output-Fisher metric of the measured row `n`. The integral
//! is evaluated by the midpoint rule over a fixed grid of sub-steps, forming
//! `J_kᵀ M_n J_k` through the per-row pullback method. The path integral is the
//! honest dose: it follows the curved surface, so a long arc that doubles back is
//! not under-counted the way a straight endpoint chord would be.
36//!
37//! # Validity radius — where local linearization stops being trusted
38//!
//! A consumer must know *how far* the move can be trusted as a linear push. The
//! **validity radius** is the latent step length `τ·‖Δt‖` of the first prefix
//! `t_from + τ·Δt` whose true chord KL, `½ a² ‖g_k(t_from + τΔt) − g_k(t_from)‖²_M`,
//! diverges from the local-linear prediction built from the initial tangent,
//! `½ a² τ² ‖(∂g_k/∂t|_{t_from})·Δt‖²_M`, by more than
//! [`VALIDITY_DIVERGENCE_FRACTION`] (relative to the linear prediction). Beyond it
//! the surface has curved enough that the linear push no longer represents the
//! move. We **report** it; we do not silently clip to it.
46//!
47//! # Off-manifold guard
48//!
//! `δ` is, by construction, a chord of the decoder curve, so it should lie in the
//! atom's local tangent frame at the move's midpoint `(t_from + t_to) / 2`, up to
//! the second-order sagitta. The **off-manifold norm** projects `δ` onto the span
//! of the decoder tangents `∂g_k/∂t` at that midpoint and reports the residual
//! norm — a self-check that the steering move stays on the learned surface. It is
//! `≈ 0` for small steps and grows with arc curvature; a large value means the
//! requested move left the manifold and the dose number is not to be trusted.
56//!
57//! # Read-only / no loss contact
58//!
59//! This module is a **pure read** over the fitted term and the metric. It calls
60//! only `g_k(t)` evaluation ([`SaeManifoldAtom`]'s decoder + installed
61//! [`SaeBasisEvaluator`]) and the criterion-facing
62//! [`RowMetric::fisher_mass`] / [`RowMetric::pullback`]. It never mutates the
//! model, never touches a likelihood / criterion / penalty, and the solver floor
//! of [`RowMetric`] (also written `δ` there — not the steering delta `δ` above)
//! never enters any number it reports (the fisher-mass / pullback face is
//! floor-free, #747).
66
67use ndarray::{Array1, Array2, ArrayView1};
68
69use gam_problem::{MetricProvenance, RowMetric};
70use crate::manifold::SaeManifoldTerm;
71
/// Number of equal sub-steps the latent segment `[t_from, t_to]` is split into,
/// both for the midpoint-rule dosimetry path integral and for the prefix scan
/// that locates the validity radius. The decoder curve is smooth, so a modest
/// fixed grid resolves the arc. It is fixed (no clock / no adaptivity), so the
/// reported numbers are deterministic.
const STEER_PATH_STEPS: usize = 64;

/// Relative divergence tolerated between the true chord KL of a prefix of the
/// move, `½ a² ‖g_k(t_from + τΔt) − g_k(t_from)‖²_M`, and its local-linear
/// prediction from the initial tangent, `½ a² τ² ‖(∂g_k/∂t)·Δt‖²_M`, before the
/// move is declared past its validity radius. The divergence is measured
/// relative to the linear prediction: at `0.1` the linearization is trusted
/// while the chord KL stays within 10% of it.
const VALIDITY_DIVERGENCE_FRACTION: f64 = 0.1;

/// Active-mass floor: a row whose assignment mass for an atom does not exceed
/// this value counts as "off" for that atom when steering decides which rows
/// the atom is active on. Mirrors `atom_lens::ACTIVE_MASS_FLOOR`.
const ACTIVE_MASS_FLOOR: f64 = 1e-6;
87
88/// The actionable output of a steering query over one atom.
89#[derive(Clone, Debug, PartialEq)]
90pub struct SteerPlan {
91 /// Which atom was steered (index into [`SaeManifoldTerm::atoms`]).
92 pub atom: usize,
93 /// The atom's name (mirrors [`crate::manifold::SaeManifoldAtom::name`]).
94 pub atom_name: String,
95 /// The source latent coordinate `t_from` (length = atom's `latent_dim`).
96 pub t_from: Vec<f64>,
97 /// The target latent coordinate `t_to` (length = atom's `latent_dim`).
98 pub t_to: Vec<f64>,
99 /// The amplitude `a` the on-manifold move was scaled by (the atom's mean
100 /// active assignment mass; `1.0` if the atom is active on no row).
101 pub amplitude: f64,
102 /// The row whose per-row output-Fisher metric the dose was measured through
103 /// (the atom's most-active row; `0` if active nowhere).
104 pub measured_row: usize,
105 /// **The activation-space delta**: `δ = a · (g_k(t_to) − g_k(t_from))`, a
106 /// length-`p` vector in the reconstruction/output space — the actual move to
107 /// add to a hidden state.
108 pub delta: Array1<f64>,
109 /// **DOSIMETRY**: predicted output effect of the move in **nats** of KL,
110 /// integrated along the decoder curve through the output-Fisher metric.
111 /// `None` when the metric carries no behavioral information (Euclidean
112 /// provenance) — the dose is *not available*, not zero.
113 pub predicted_nats: Option<f64>,
114 /// **VALIDITY RADIUS**: the latent step size (Euclidean norm of the move from
115 /// `t_from`) at which the path-integrated dose first diverges from the
116 /// straight endpoint quadratic form by more than
117 /// [`VALIDITY_DIVERGENCE_FRACTION`]. Equals the full move length when the
118 /// linearization is trusted all the way to `t_to`. `None` under a no-behavior
119 /// metric (there is no dose to validate).
120 pub validity_radius: Option<f64>,
121 /// **OFF-MANIFOLD GUARD**: the norm of `δ`'s component outside the span of
122 /// the atom's local decoder tangents `∂g_k/∂t` at `t_from`. `≈ 0` by
123 /// construction (the move is a chord of the curve); a large value flags a
124 /// move that left the learned surface.
125 pub off_manifold_norm: f64,
126 /// The provenance of the metric the dose was read through, echoed so a
127 /// consumer can certify *why* `predicted_nats` is `None` when it is.
128 pub metric_provenance: MetricProvenance,
129}
130
131/// Build a [`SteerPlan`] for driving atom `atom_k` from `t_from` to `t_to`.
132///
133/// `model` is the fitted term (read only); `metric` is the per-row output-Fisher
134/// inner product the dose is measured through (typically `model.row_metric()`'s
135/// own metric, or any metric whose row/output dims match the term). `t_from` and
136/// `t_to` are latent coordinates of length `atom.latent_dim`.
137///
138/// Errors when the atom index is out of range, the coordinate lengths do not
139/// match the atom's latent dimension, the atom has no installed
140/// [`crate::manifold::SaeBasisEvaluator`] (arbitrary-`t` evaluation
141/// requires one), or the metric dimensions do not match the term. Under a
142/// Euclidean (no-behavior) metric the geometry is still produced but
143/// `predicted_nats` / `validity_radius` degrade to `None`.
144pub fn steer_delta(
145 model: &SaeManifoldTerm,
146 metric: &RowMetric,
147 atom_k: usize,
148 t_from: &[f64],
149 t_to: &[f64],
150) -> Result<SteerPlan, String> {
151 let k = model.k_atoms();
152 if atom_k >= k {
153 return Err(format!(
154 "steer_delta: atom index {atom_k} out of range (term has {k} atoms)"
155 ));
156 }
157 let atom = &model.atoms[atom_k];
158 let d = atom.latent_dim;
159 let p = atom.output_dim();
160 if t_from.len() != d || t_to.len() != d {
161 return Err(format!(
162 "steer_delta: t_from/t_to must have length latent_dim={d}; got {} and {}",
163 t_from.len(),
164 t_to.len()
165 ));
166 }
167 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
168 format!(
169 "steer_delta: atom {atom_k} ('{}') has no installed basis evaluator; \
170 arbitrary-t decoder evaluation requires one",
171 atom.name
172 )
173 })?;
174
175 // --- amplitude & the row the dose is measured through -------------------
176 // The amplitude is the atom's mean active assignment mass (how loudly the
177 // atom is expressed), mirroring the presence weighting in `atom_lens`. The
178 // measured row is the atom's single most-active row: the per-row metric
179 // there is the most representative behavioral inner product for "this atom
180 // is on". Both fall back gracefully when the atom is active nowhere.
181 let assignments = model.assignment.assignments();
182 let n = model.n_obs();
183 let mut mass_sum = 0.0_f64;
184 let mut active_count = 0.0_f64;
185 let mut best_row = 0usize;
186 let mut best_mass = f64::NEG_INFINITY;
187 for row in 0..n {
188 let mass = assignments[[row, atom_k]];
189 if mass > best_mass {
190 best_mass = mass;
191 best_row = row;
192 }
193 if mass > ACTIVE_MASS_FLOOR {
194 mass_sum += mass;
195 active_count += 1.0;
196 }
197 }
198 let amplitude = if active_count > 0.0 {
199 mass_sum / active_count
200 } else {
201 1.0
202 };
203
204 // --- the on-manifold activation-space delta -----------------------------
205 let g_from = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_from, p)?;
206 let g_to = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_to, p)?;
207 let mut delta = Array1::<f64>::zeros(p);
208 for i in 0..p {
209 delta[i] = amplitude * (g_to[i] - g_from[i]);
210 }
211
212 // Whether the metric can/does match this term and carries behavior.
213 let provenance = metric.provenance();
214 let behavior_available =
215 metric_carries_behavior(provenance) && metric.n_rows() == n && metric.p_out() == p;
216
217 // --- off-manifold guard -------------------------------------------------
218 // Project δ onto the span of the local decoder tangents ∂g_k/∂t and report
219 // the residual norm. The tangents are evaluated at the move's MIDPOINT, not
220 // at t_from: the chord of a curve is symmetric about its midpoint, so its
221 // component transverse to the midpoint tangent is the true second-order
222 // sagitta (`O(‖Δt‖²)`), whereas the endpoint tangent differs from the chord
223 // direction already at first order. Measuring against the midpoint frame is
224 // therefore the honest "did the move stay on the surface" self-check: it is
225 // `≈ 0` for an on-manifold move and grows only with genuine arc curvature.
226 let mut t_mid = vec![0.0_f64; d];
227 for a in 0..d {
228 t_mid[a] = 0.5 * (t_from[a] + t_to[a]);
229 }
230 let tangents =
231 decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, &t_mid, p, d)?;
232 let off_manifold_norm = off_manifold_residual_norm(&tangents, delta.view());
233
234 // --- dosimetry: path-integrated Fisher dose -----------------------------
235 let (predicted_nats, validity_radius) = if !behavior_available {
236 (None, None)
237 } else {
238 let ctx = SteerContext {
239 evaluator: evaluator.as_ref(),
240 decoder: &atom.decoder_coefficients,
241 metric,
242 row: best_row,
243 p,
244 d,
245 amplitude,
246 };
247 let dose = path_integrated_dose(&ctx, t_from, t_to)?;
248 let radius = validity_radius(&ctx, t_from, t_to)?;
249 (Some(dose), Some(radius))
250 };
251
252 Ok(SteerPlan {
253 atom: atom_k,
254 atom_name: atom.name.clone(),
255 t_from: t_from.to_vec(),
256 t_to: t_to.to_vec(),
257 amplitude,
258 measured_row: best_row,
259 delta,
260 predicted_nats,
261 validity_radius,
262 off_manifold_norm,
263 metric_provenance: provenance,
264 })
265}
266
267/// The model's predicted output-mean response to an applied activation push
268/// `δ`, under the LOCAL-LINEAR reading of its fitted surface: the projection
269/// of `δ` onto the span of atom `atom_k`'s decoder tangents `∂g_k/∂t` at the
270/// operating point `t_at`. A dictionary "predicts" exactly the component of a
271/// push it can carry along its learned surface; the transverse component is
272/// off-manifold and predicted to die (this is the same local model the
273/// off-manifold guard and the dosimetry chord trust, used in the same radius).
274///
275/// This is `μ(δ)` for the design loop of
276/// [`gam_terms::inference::structure_evidence`]: two structural hypotheses about
277/// the same activations (e.g. "one curved atom" vs "two flat atoms") are two
278/// fitted terms whose tangent spans differ, so they predict DIFFERENT
279/// responses to the same probe — and that disagreement, in the output-Fisher
280/// metric, is what `select_probe_by_expected_evidence` maximizes.
281pub fn predicted_response(
282 model: &SaeManifoldTerm,
283 atom_k: usize,
284 t_at: &[f64],
285 delta: ArrayView1<'_, f64>,
286) -> Result<Array1<f64>, String> {
287 let k = model.k_atoms();
288 if atom_k >= k {
289 return Err(format!(
290 "predicted_response: atom index {atom_k} out of range (term has {k} atoms)"
291 ));
292 }
293 let atom = &model.atoms[atom_k];
294 let d = atom.latent_dim;
295 let p = atom.output_dim();
296 if t_at.len() != d {
297 return Err(format!(
298 "predicted_response: t_at must have length latent_dim={d}; got {}",
299 t_at.len()
300 ));
301 }
302 if delta.len() != p {
303 return Err(format!(
304 "predicted_response: delta must have length output_dim={p}; got {}",
305 delta.len()
306 ));
307 }
308 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
309 format!(
310 "predicted_response: atom {atom_k} ('{}') has no installed basis evaluator",
311 atom.name
312 )
313 })?;
314 let tangents = decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, t_at, p, d)?;
315 Ok(project_onto_tangent_span(&tangents, delta))
316}
317
318/// Does this provenance carry behavioral (output-Fisher) information? Euclidean
319/// is the isotropic activation-only path and carries none; the factored
320/// provenances do. (Mirrors `atom_lens::metric_carries_behavior`.)
321fn metric_carries_behavior(p: MetricProvenance) -> bool {
322 match p {
323 MetricProvenance::Euclidean => false,
324 MetricProvenance::OutputFisher { .. }
325 | MetricProvenance::OutputFisherDownstream { .. }
326 | MetricProvenance::WhitenedStructured { .. } => true,
327 }
328}
329
330/// Evaluate the decoder output `g_k(t) = Φ_k(t) B_k ∈ ℝ^p` at an arbitrary
331/// latent coordinate `t` (length `d`) via the atom's installed evaluator.
332fn decode_at(
333 evaluator: &dyn crate::manifold::SaeBasisEvaluator,
334 decoder: &Array2<f64>,
335 t: &[f64],
336 p: usize,
337) -> Result<Array1<f64>, String> {
338 let d = t.len();
339 let coords = Array2::from_shape_vec((1, d), t.to_vec())
340 .map_err(|e| format!("steer_delta::decode_at: coord shape: {e}"))?;
341 let (phi, _jet) = evaluator.evaluate(coords.view())?;
342 let m = decoder.nrows();
343 if phi.ncols() != m {
344 return Err(format!(
345 "steer_delta::decode_at: evaluator returned {} basis cols but decoder has {m} rows",
346 phi.ncols()
347 ));
348 }
349 let mut g = Array1::<f64>::zeros(p);
350 for basis_col in 0..m {
351 let phi_v = phi[[0, basis_col]];
352 if phi_v == 0.0 {
353 continue;
354 }
355 for out_col in 0..p {
356 g[out_col] += phi_v * decoder[[basis_col, out_col]];
357 }
358 }
359 Ok(g)
360}
361
362/// Evaluate the decoder tangents `∂g_k/∂t_a = Φ_k'(t) B_k ∈ ℝ^p`, one per latent
363/// axis `a ∈ 0..d`, at an arbitrary latent coordinate `t`. Returned as a
364/// `(p × d)` matrix whose column `a` is the tangent along axis `a`.
365fn decode_tangents_at(
366 evaluator: &dyn crate::manifold::SaeBasisEvaluator,
367 decoder: &Array2<f64>,
368 t: &[f64],
369 p: usize,
370 d: usize,
371) -> Result<Array2<f64>, String> {
372 let coords = Array2::from_shape_vec((1, d), t.to_vec())
373 .map_err(|e| format!("steer_delta::decode_tangents_at: coord shape: {e}"))?;
374 let (_phi, jet) = evaluator.evaluate(coords.view())?;
375 let m = decoder.nrows();
376 if jet.dim() != (1, m, d) {
377 return Err(format!(
378 "steer_delta::decode_tangents_at: evaluator jet {:?} != (1, {m}, {d})",
379 jet.dim()
380 ));
381 }
382 let mut tang = Array2::<f64>::zeros((p, d));
383 for axis in 0..d {
384 for basis_col in 0..m {
385 let dphi = jet[[0, basis_col, axis]];
386 if dphi == 0.0 {
387 continue;
388 }
389 for out_col in 0..p {
390 tang[[out_col, axis]] += dphi * decoder[[basis_col, out_col]];
391 }
392 }
393 }
394 Ok(tang)
395}
396
397/// Least-squares projection of `δ` onto the span of the local tangents
398/// (columns of `tangents`, shape `p × d`): `δ̂ = T (TᵀT)⁻¹ Tᵀ δ` via a small
399/// `d × d` Gram solve (with a tiny diagonal jitter to absorb a rank-deficient
400/// tangent frame; the jitter only shrinks the projection, never inflates it).
401fn project_onto_tangent_span(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> Array1<f64> {
402 let p = tangents.nrows();
403 let d = tangents.ncols();
404 if d == 0 {
405 return Array1::<f64>::zeros(p);
406 }
407 // Gram = TᵀT (d × d) and rhs = Tᵀδ (d).
408 let mut gram = Array2::<f64>::zeros((d, d));
409 let mut rhs = Array1::<f64>::zeros(d);
410 for a in 0..d {
411 let mut r = 0.0_f64;
412 for i in 0..p {
413 r += tangents[[i, a]] * delta[i];
414 }
415 rhs[a] = r;
416 for b in a..d {
417 let mut acc = 0.0_f64;
418 for i in 0..p {
419 acc += tangents[[i, a]] * tangents[[i, b]];
420 }
421 gram[[a, b]] = acc;
422 gram[[b, a]] = acc;
423 }
424 }
425 let trace: f64 = (0..d).map(|a| gram[[a, a]]).sum();
426 let jitter = if trace > 0.0 { 1e-12 * trace } else { 1e-12 };
427 for a in 0..d {
428 gram[[a, a]] += jitter;
429 }
430 let coeffs = solve_spd_small(&gram, &rhs);
431 let mut proj = Array1::<f64>::zeros(p);
432 for i in 0..p {
433 for a in 0..d {
434 proj[i] += tangents[[i, a]] * coeffs[a];
435 }
436 }
437 proj
438}
439
440/// Norm of `δ`'s component orthogonal to the span of the local tangents:
441/// `‖δ − δ̂‖` with `δ̂` the [`project_onto_tangent_span`] projection.
442fn off_manifold_residual_norm(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> f64 {
443 let proj = project_onto_tangent_span(tangents, delta);
444 let mut res_sq = 0.0_f64;
445 for i in 0..delta.len() {
446 let r = delta[i] - proj[i];
447 res_sq += r * r;
448 }
449 res_sq.max(0.0).sqrt()
450}
451
452/// Tiny symmetric-positive-definite solve via Cholesky for the `d × d` tangent
453/// Gram (`d` is the atom's latent dim, typically 1–3). Falls back to the bare rhs
454/// if the factorization fails (a fully degenerate frame), which only inflates the
455/// reported off-manifold residual — never deflates it.
456fn solve_spd_small(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
457 let d = gram.nrows();
458 // Cholesky L LᵀT = gram.
459 let mut l = Array2::<f64>::zeros((d, d));
460 for i in 0..d {
461 for j in 0..=i {
462 let mut sum = gram[[i, j]];
463 for k in 0..j {
464 sum -= l[[i, k]] * l[[j, k]];
465 }
466 if i == j {
467 if sum <= 0.0 {
468 return Array1::<f64>::zeros(d);
469 }
470 l[[i, j]] = sum.sqrt();
471 } else {
472 l[[i, j]] = sum / l[[j, j]];
473 }
474 }
475 }
476 // Forward solve L y = rhs.
477 let mut y = Array1::<f64>::zeros(d);
478 for i in 0..d {
479 let mut sum = rhs[i];
480 for k in 0..i {
481 sum -= l[[i, k]] * y[k];
482 }
483 y[i] = sum / l[[i, i]];
484 }
485 // Back solve Lᵀ x = y.
486 let mut x = Array1::<f64>::zeros(d);
487 for i in (0..d).rev() {
488 let mut sum = y[i];
489 for k in (i + 1)..d {
490 sum -= l[[k, i]] * x[k];
491 }
492 x[i] = sum / l[[i, i]];
493 }
494 x
495}
496
497/// The fixed geometry of one steering query, bundled so the dose integrator and
498/// its helpers take a single context rather than a long argument list.
499struct SteerContext<'a> {
500 evaluator: &'a dyn crate::manifold::SaeBasisEvaluator,
501 decoder: &'a Array2<f64>,
502 metric: &'a RowMetric,
503 /// The row whose per-row metric the dose is measured through.
504 row: usize,
505 /// Output dimension `p`.
506 p: usize,
507 /// Latent dimension `d`.
508 d: usize,
509 /// Amplitude `a` the move is scaled by.
510 amplitude: f64,
511}
512
513/// Path-integrated Fisher dose
514/// `½ a² ∫ g_k'(t)ᵀ M g_k'(t) dt` along the straight latent segment
515/// `t(τ) = t_from + τ (t_to − t_from)`, `τ ∈ [0, 1]`, by the midpoint rule over
516/// [`STEER_PATH_STEPS`] sub-steps.
517///
518/// The local quadratic `g'(t)ᵀ M g'(t)` is the [`RowMetric::pullback`] of the
519/// per-axis decoder tangents contracted with the latent velocity `Δt`, so this
520/// uses only the criterion-facing pullback (no loss / no solver floor).
521fn path_integrated_dose(
522 ctx: &SteerContext<'_>,
523 t_from: &[f64],
524 t_to: &[f64],
525) -> Result<f64, String> {
526 let d = ctx.d;
527 let p = ctx.p;
528 let steps = STEER_PATH_STEPS;
529 let dtau = 1.0 / steps as f64;
530 // Latent velocity Δt (constant along the straight segment).
531 let mut dt = vec![0.0_f64; d];
532 for a in 0..d {
533 dt[a] = t_to[a] - t_from[a];
534 }
535 let mut acc = 0.0_f64;
536 let amp2 = ctx.amplitude * ctx.amplitude;
537 for s in 0..steps {
538 // Midpoint of sub-step s in τ, mapped to a latent coordinate.
539 let tau_mid = (s as f64 + 0.5) * dtau;
540 let mut t_mid = vec![0.0_f64; d];
541 for a in 0..d {
542 t_mid[a] = t_from[a] + tau_mid * dt[a];
543 }
544 // Decoder tangents at the midpoint: ∂g/∂t_a, columns of a (p × d) matrix.
545 let tang = decode_tangents_at(ctx.evaluator, ctx.decoder, &t_mid, p, d)?;
546 // The pulled-back metric at this point is g_{ab} = (∂g/∂t)ᵀ M (∂g/∂t),
547 // the d × d local inner product of latent motion *in output-Fisher
548 // units*. We form it through the criterion-facing `RowMetric::pullback`
549 // (which never materializes the p × p M and never sees the solver δ),
550 // then contract the latent velocity Δt twice: the squared output-Fisher
551 // speed along the path is Δtᵀ g Δt. The decoder Jacobian is passed flat
552 // row-major (J[i, a] = j_row[i * d + a]) as `pullback` expects.
553 let mut j_row = vec![0.0_f64; p * d];
554 for i in 0..p {
555 for a in 0..d {
556 j_row[i * d + a] = tang[[i, a]];
557 }
558 }
559 let g_ab = ctx.metric.pullback(ctx.row, &j_row, d);
560 let mut speed_sq = 0.0_f64;
561 for a in 0..d {
562 for b in 0..d {
563 speed_sq += dt[a] * g_ab[[a, b]] * dt[b];
564 }
565 }
566 acc += 0.5 * amp2 * speed_sq * dtau;
567 }
568 Ok(acc)
569}
570
571/// The validity radius: the latent step length (Euclidean distance from
572/// `t_from`) at which **local linearization stops being trusted**.
573///
574/// Linearizing the steering move means predicting the output effect of a prefix
575/// step `τ·Δt` from the initial tangent alone: the first-order output move is
576/// `δ_lin(τ) = a · (∂g/∂t|_{t_from}) · (τ Δt)`, whose output-Fisher KL is the
577/// quadratic form `½ ‖δ_lin(τ)‖²_M = τ² · ½ a² ‖∂g/∂t·Δt‖²_M`. The **true**
578/// effect of that prefix is the chord quadratic form of the *actual* curved
579/// output move `½ a² ‖g(t_from + τΔt) − g(t_from)‖²_M`.
580///
581/// The radius is the chord length `τ* · ‖Δt‖` at the first prefix `τ*` where the
582/// true chord KL diverges from the linear prediction by more than
583/// [`VALIDITY_DIVERGENCE_FRACTION`] (relative to the linear prediction). This is
584/// pure surface curvature: on a flat decoder the two agree for every `τ` and the
585/// radius is the whole move. If the metric kills the tangent (no linear effect to
586/// validate), the move is trusted to its full length.
587fn validity_radius(ctx: &SteerContext<'_>, t_from: &[f64], t_to: &[f64]) -> Result<f64, String> {
588 let d = ctx.d;
589 let p = ctx.p;
590 let full_len: f64 = t_from
591 .iter()
592 .zip(t_to.iter())
593 .map(|(&a, &b)| (b - a) * (b - a))
594 .sum::<f64>()
595 .sqrt();
596 if full_len == 0.0 {
597 return Ok(0.0);
598 }
599 let mut dt = vec![0.0_f64; d];
600 for a in 0..d {
601 dt[a] = t_to[a] - t_from[a];
602 }
603 let amp = ctx.amplitude;
604
605 // Initial-tangent linear output move per unit τ: v0 = (∂g/∂t|_{t_from}) Δt.
606 let tang0 = decode_tangents_at(ctx.evaluator, ctx.decoder, t_from, p, d)?;
607 let mut v0 = Array1::<f64>::zeros(p);
608 for i in 0..p {
609 let mut acc = 0.0_f64;
610 for a in 0..d {
611 acc += tang0[[i, a]] * dt[a];
612 }
613 v0[i] = acc;
614 }
615 // ½ a² ‖v0‖²_M — the per-τ² linear KL coefficient.
616 let lin_coeff = 0.5 * amp * amp * ctx.metric.fisher_mass(ctx.row, v0.view());
617 // No linear effect to validate against ⇒ trust the full move.
618 if !(lin_coeff > 0.0) {
619 return Ok(full_len);
620 }
621
622 let g_from = decode_at(ctx.evaluator, ctx.decoder, t_from, p)?;
623 let steps = STEER_PATH_STEPS;
624 for s in 0..steps {
625 let tau = (s as f64 + 1.0) / steps as f64;
626 let mut t_mid = vec![0.0_f64; d];
627 for a in 0..d {
628 t_mid[a] = t_from[a] + tau * dt[a];
629 }
630 let g_tau = decode_at(ctx.evaluator, ctx.decoder, &t_mid, p)?;
631 let mut chord = Array1::<f64>::zeros(p);
632 for i in 0..p {
633 chord[i] = amp * (g_tau[i] - g_from[i]);
634 }
635 // True chord KL of the prefix, and the linear prediction τ²·lin_coeff.
636 let chord_kl = 0.5 * ctx.metric.fisher_mass(ctx.row, chord.view());
637 let lin_kl = tau * tau * lin_coeff;
638 let rel = (chord_kl - lin_kl).abs() / lin_kl;
639 if rel > VALIDITY_DIVERGENCE_FRACTION {
640 return Ok(tau * full_len);
641 }
642 }
643 Ok(full_len)
644}