gam_sae/inference/steering.rs
1//! `steer_delta` — the **steering primitive with output dosimetry**: the
2//! actionable LLM payload of the SAE-manifold machine.
3//!
4//! # What this computes
5//!
6//! Given a fitted [`SaeManifoldTerm`] and the per-row output-Fisher
7//! [`RowMetric`], a *steering move* is "drive atom `k`'s latent coordinate from
8//! `t_from` to `t_to`". The atom's decoder curve `g_k(t) = Φ_k(t) B_k` maps that
9//! latent move to an **activation-space delta** — the actual vector you add to
10//! the residual stream / reconstruction to realize the move *on the manifold*:
11//!
12//! ```text
13//! δ = a · ( g_k(t_to) − g_k(t_from) ) (the on-manifold move)
14//! ```
15//!
16//! where `a` is the atom's amplitude (how loudly the atom is expressed). This is
17//! the thing a downstream consumer adds to a hidden state.
18//!
19//! # Dosimetry — how big is this push, in nats?
20//!
21//! The headline number is the **predicted output effect**: how much behavioral
22//! change (in nats of KL on the model's output distribution) the move induces.
23//! For a locally-quadratic output readout the KL of a parameter move `Δ` is
24//! `½ Δᵀ F Δ` with `F` the output-Fisher information — exactly the inner product
25//! [`RowMetric`] carries. The dose is the Fisher quadratic form of the move,
26//! **integrated along the decoder curve** rather than read only at the endpoints:
27//!
28//! ```text
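//! predicted_nats = ½ a² ∫_0^1 (J_k(t(τ)) Δt)ᵀ M_n (J_k(t(τ)) Δt) dτ
//!
//!     t(τ) = t_from + τ Δt,   Δt = t_to − t_from,   J_k = ∂g_k/∂t
//! ```
//!
//! The integral is evaluated with a midpoint rule over small steps of the straight
//! latent segment. It uses the per-row pullback method, through the metric `M_n`
//! of a single measured row `n`. On a flat decoder it reduces to the endpoint form
//! `½ δᵀ M_n δ`. On a curved decoder the path integral is the honest dose: it
//! follows the curved surface, so a long arc that doubles back is not
//! under-counted the way a straight endpoint chord would be.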
36//!
37//! # Validity radius — where local linearization stops being trusted
38//!
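//! A consumer must know *how far* the move can be trusted as a linear push. Write
//! `Δt = t_to − t_from`. The **validity radius** is the latent step length
//! (distance from `t_from` along the straight segment toward `t_to`) at which the
//! true KL of a prefix of the move first diverges from the local-linear prediction
//! by more than [`VALIDITY_DIVERGENCE_FRACTION`] (relative):
//!
//! - the true KL is the Fisher quadratic form of the curved output chord
//!   `a · (g_k(t_from + τΔt) − g_k(t_from))`;
//! - the linear prediction is made from the initial tangent alone,
//!   `τ² · ½ a² ‖(∂g_k/∂t)(t_from) Δt‖²_M`.
//!
//! Beyond the radius, the surface has curved enough that the initial tangent no
//! longer represents the move. We **report** it; we do not silently clip to it.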
46//!
47//! # Off-manifold guard
48//!
//! `δ` is, by construction, a chord of the decoder curve. A chord is parallel to
//! the curve's tangent at the move's latent midpoint `(t_from + t_to) / 2` up to a
//! third-order term in `‖Δt‖`; against the endpoint tangent it already deviates at
//! second order. The **off-manifold norm** therefore projects `δ` onto the span of
//! the decoder tangents `∂g_k/∂t` evaluated at that midpoint and reports the
//! residual norm. This is a self-check that the steering move stays on the learned
//! surface. The norm is `≈ 0` for small steps and grows as the surface bends over a
//! larger step. A large value means the requested move left the manifold and the
//! dose number is not to be trusted.
56//!
57//! # Read-only / no loss contact
58//!
//! This module is a **pure read** over the fitted term and the metric. The
//! steering computations call only `g_k(t)` evaluation ([`SaeManifoldAtom`]'s
//! decoder + installed [`SaeBasisEvaluator`]) and the criterion-facing
//! [`RowMetric::fisher_mass`] / [`RowMetric::pullback`]. The coordinate-write
//! helpers additionally read coordinates through the [`EncodeAtlas`]. Nothing here
//! mutates the model or touches a likelihood / criterion / penalty. The solver
//! floor of [`RowMetric`] (also written `δ` there, unrelated to the steering delta
//! above) never enters any number reported here: the fisher-mass / pullback face is
//! floor-free (#747).
66
67use ndarray::{Array1, Array2, ArrayView1};
68
69use crate::encode::EncodeAtlas;
70use crate::manifold::SaeManifoldTerm;
71use gam_problem::{MetricProvenance, RowMetric};
72use gam_terms::inference::structure_evidence::log_e_from_p_calibrator;
73
/// Number of sub-steps used to discretize the latent segment `[t_from, t_to]`.
///
/// The dosimetry path integral uses it as a midpoint-rule grid. The
/// validity-radius scan uses it as the grid of prefix lengths `τ = s / N`, so
/// the reported radius is resolved to `‖t_to − t_from‖ / STEER_PATH_STEPS`.
///
/// The decoder curve is smooth, so a modest grid resolves the arc. The grid is
/// fixed (no clock / no adaptivity) so every reported number is deterministic.
const STEER_PATH_STEPS: usize = 64;

/// Relative divergence allowed before a move is declared past its validity radius.
///
/// The divergence is measured between the true (curved-chord) KL of a prefix of
/// the move and its local-linear, initial-tangent prediction. At `0.1` the
/// linearization is trusted while the curved-chord KL stays within 10% of the
/// linear prediction.
const VALIDITY_DIVERGENCE_FRACTION: f64 = 0.1;

/// Active-mass floor below which a row is "off" for an atom.
///
/// Rows at or below it are excluded from the default steering amplitude (the
/// mean active assignment mass). The measured row is a plain argmax over all
/// rows and does not consult this floor. Mirrors `atom_lens::ACTIVE_MASS_FLOOR`.
const ACTIVE_MASS_FLOOR: f64 = 1e-6;
89
90/// The actionable output of a steering query over one atom.
91#[derive(Clone, Debug, PartialEq)]
92pub struct SteerPlan {
93 /// Which atom was steered (index into [`SaeManifoldTerm::atoms`]).
94 pub atom: usize,
95 /// The atom's name (mirrors [`crate::manifold::SaeManifoldAtom::name`]).
96 pub atom_name: String,
97 /// The source latent coordinate `t_from` (length = atom's `latent_dim`).
98 pub t_from: Vec<f64>,
99 /// The target latent coordinate `t_to` (length = atom's `latent_dim`).
100 pub t_to: Vec<f64>,
101 /// The amplitude `a` the on-manifold move was scaled by (the atom's mean
102 /// active assignment mass; `1.0` if the atom is active on no row).
103 pub amplitude: f64,
104 /// The row whose per-row output-Fisher metric the dose was measured through
105 /// (the atom's most-active row; `0` if active nowhere).
106 pub measured_row: usize,
107 /// **The activation-space delta**: `δ = a · (g_k(t_to) − g_k(t_from))`, a
108 /// length-`p` vector in the reconstruction/output space — the actual move to
109 /// add to a hidden state.
110 pub delta: Array1<f64>,
111 /// **DOSIMETRY**: predicted output effect of the move in **nats** of KL,
112 /// integrated along the decoder curve through the output-Fisher metric.
113 /// `None` when the metric carries no behavioral information (Euclidean
114 /// provenance) — the dose is *not available*, not zero.
115 pub predicted_nats: Option<f64>,
116 /// **VALIDITY RADIUS**: the latent step size (Euclidean norm of the move from
117 /// `t_from`) at which the path-integrated dose first diverges from the
118 /// straight endpoint quadratic form by more than
119 /// [`VALIDITY_DIVERGENCE_FRACTION`]. Equals the full move length when the
120 /// linearization is trusted all the way to `t_to`. `None` under a no-behavior
121 /// metric (there is no dose to validate).
122 pub validity_radius: Option<f64>,
123 /// **OFF-MANIFOLD GUARD**: the norm of `δ`'s component outside the span of
124 /// the atom's local decoder tangents `∂g_k/∂t` at `t_from`. `≈ 0` by
125 /// construction (the move is a chord of the curve); a large value flags a
126 /// move that left the learned surface.
127 pub off_manifold_norm: f64,
128 /// The provenance of the metric the dose was read through, echoed so a
129 /// consumer can certify *why* `predicted_nats` is `None` when it is.
130 pub metric_provenance: MetricProvenance,
131}
132
133/// Result of writing one certified chart coordinate into an activation row.
134///
135/// The edited row is always `x + δ`, where `δ` is the delta returned by
136/// [`steer_delta`] for the atom's current encoded coordinate and the requested
137/// target coordinate. Because only the on-manifold atom chord is added, every
138/// component of `x` outside this atom's chart residual is preserved exactly; this
139/// is the locality guarantee missing from whole-residual linear-steering
140/// baselines.
141#[derive(Clone, Debug)]
142pub struct CoordinateSetResult {
143 /// The edited activation/reconstruction row.
144 pub edited: Array1<f64>,
145 /// Certified coordinate read from the input row before the write.
146 pub t_from_certified: Array1<f64>,
147 /// Certificate attached to `t_from_certified`.
148 pub encode_certificate: crate::encode::RowCertificate,
149 /// Steering plan whose `delta` was added to the row.
150 pub steer: SteerPlan,
151}
152
153/// Write atom `atom_k`'s chart coordinate in row `x` to `t_to` by delta
154/// steering, preserving the row's off-atom/off-subspace residual exactly.
155///
156/// `amplitude` is the assignment/intensity with which the row expresses this
157/// atom; callers that have already separated existence/intensity/position should
158/// pass the intensity and only swap the position coordinate. The certified read
159/// uses [`EncodeAtlas::certified_encode_row`]; the write uses [`steer_delta`].
160pub fn set_coordinate(
161 model: &SaeManifoldTerm,
162 metric: &RowMetric,
163 atlas: &EncodeAtlas,
164 x: ArrayView1<'_, f64>,
165 atom_k: usize,
166 amplitude: f64,
167 t_to: &[f64],
168) -> Result<CoordinateSetResult, String> {
169 let atom = model.atoms.get(atom_k).ok_or_else(|| {
170 format!(
171 "set_coordinate: atom index {atom_k} out of range (term has {} atoms)",
172 model.k_atoms()
173 )
174 })?;
175 if x.len() != atom.output_dim() {
176 return Err(format!(
177 "set_coordinate: input row has length {} but atom {atom_k} output_dim is {}",
178 x.len(),
179 atom.output_dim()
180 ));
181 }
182 let (t_from, cert) = atlas.certified_encode_row(atom, atom_k, x, amplitude)?;
183 let steer = steer_delta_with_amplitude(
184 model,
185 metric,
186 atom_k,
187 t_from.as_slice().unwrap_or(&[]),
188 t_to,
189 amplitude,
190 )?;
191 let mut edited = x.to_owned();
192 if edited.len() != steer.delta.len() {
193 return Err(format!(
194 "set_coordinate: steering delta length {} does not match row length {}",
195 steer.delta.len(),
196 edited.len()
197 ));
198 }
199 for i in 0..edited.len() {
200 edited[i] += steer.delta[i];
201 }
202 Ok(CoordinateSetResult {
203 edited,
204 t_from_certified: t_from,
205 encode_certificate: cert,
206 steer,
207 })
208}
209
210/// Result of a coordinate interchange: donor position read from `x_source`, then
211/// written into `x_target` while preserving the target residual and intensity.
212#[derive(Clone, Debug)]
213pub struct InterchangeResult {
214 /// Target row after the donor coordinate has been delta-written into it.
215 pub edited_target: Array1<f64>,
216 /// Donor/source coordinate that was transplanted.
217 pub donor_t: Array1<f64>,
218 /// Target coordinate before the transplant.
219 pub target_t_before: Array1<f64>,
220 /// Target behavior coordinate after the transplant, re-read from the edit.
221 pub target_t_after: Array1<f64>,
222 /// Steering dose in nats, when a behavioral metric is available.
223 pub predicted_nats: Option<f64>,
224 /// Norm of the steering delta outside the local atom tangent frame.
225 pub off_manifold_norm: f64,
226 /// Reported steering validity radius.
227 pub validity_radius: Option<f64>,
228 /// Calibrated log e-value for counterfactual consistency: larger means the
229 /// post-edit target coordinate landed closer to the donor coordinate.
230 pub counterfactual_consistency_log_e: f64,
231 /// Underlying coordinate-write plan.
232 pub set_result: CoordinateSetResult,
233}
234
235/// Interchange atom `atom_k`'s chart coordinate from `x_source` into `x_target`.
236///
237/// The source coordinate is certified with `source_amplitude`; the target write
238/// is performed with `target_amplitude`, so swapping a position coordinate cannot
239/// silently smuggle donor intensity into the target. The returned consistency
240/// e-value is computed by re-encoding the edited target and calibrating the
241/// coordinate landing error into the existing structure-evidence e-currency.
242pub fn interchange(
243 model: &SaeManifoldTerm,
244 metric: &RowMetric,
245 atlas: &EncodeAtlas,
246 x_target: ArrayView1<'_, f64>,
247 target_amplitude: f64,
248 x_source: ArrayView1<'_, f64>,
249 source_amplitude: f64,
250 atom_k: usize,
251) -> Result<InterchangeResult, String> {
252 let atom = model.atoms.get(atom_k).ok_or_else(|| {
253 format!(
254 "interchange: atom index {atom_k} out of range (term has {} atoms)",
255 model.k_atoms()
256 )
257 })?;
258 let (donor_t, _donor_cert) =
259 atlas.certified_encode_row(atom, atom_k, x_source, source_amplitude)?;
260 let set = set_coordinate(
261 model,
262 metric,
263 atlas,
264 x_target,
265 atom_k,
266 target_amplitude,
267 donor_t.as_slice().unwrap_or(&[]),
268 )?;
269 let (target_t_after, _after_cert) =
270 atlas.certified_encode_row(atom, atom_k, set.edited.view(), target_amplitude)?;
271 let landing_error = l2_distance(donor_t.view(), target_t_after.view())?;
272 let scale = set
273 .steer
274 .validity_radius
275 .unwrap_or_else(|| {
276 l2_distance(set.t_from_certified.view(), donor_t.view())
277 .unwrap_or(1.0)
278 .max(1e-12)
279 })
280 .max(1e-12);
281 // Convert closeness into a superuniform-shaped p-value and then into the
282 // repository's standard e-value currency. Exact hits approach machine-small
283 // p-values; errors at/above the validity radius produce e-values near or
284 // below one, so shuffled-chart negative controls do not accumulate evidence.
285 let z = (scale / landing_error.max(1e-12)).min(1.0e6);
286 let p_value = (-0.5 * z * z).exp().clamp(f64::MIN_POSITIVE, 1.0);
287 let log_e = log_e_from_p_calibrator(p_value)?;
288 Ok(InterchangeResult {
289 edited_target: set.edited.clone(),
290 donor_t,
291 target_t_before: set.t_from_certified.clone(),
292 target_t_after,
293 predicted_nats: set.steer.predicted_nats,
294 off_manifold_norm: set.steer.off_manifold_norm,
295 validity_radius: set.steer.validity_radius,
296 counterfactual_consistency_log_e: log_e,
297 set_result: set,
298 })
299}
300
301fn l2_distance(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> Result<f64, String> {
302 if a.len() != b.len() {
303 return Err(format!(
304 "coordinate distance length mismatch: {} vs {}",
305 a.len(),
306 b.len()
307 ));
308 }
309 let mut ss = 0.0;
310 for i in 0..a.len() {
311 let r = a[i] - b[i];
312 ss += r * r;
313 }
314 Ok(ss.sqrt())
315}
316
317/// Build a [`SteerPlan`] for driving atom `atom_k` from `t_from` to `t_to`.
318///
319/// `model` is the fitted term (read only); `metric` is the per-row output-Fisher
320/// inner product the dose is measured through (typically `model.row_metric()`'s
321/// own metric, or any metric whose row/output dims match the term). `t_from` and
322/// `t_to` are latent coordinates of length `atom.latent_dim`.
323///
324/// Errors when the atom index is out of range, the coordinate lengths do not
325/// match the atom's latent dimension, the atom has no installed
326/// [`crate::manifold::SaeBasisEvaluator`] (arbitrary-`t` evaluation
327/// requires one), or the metric dimensions do not match the term. Under a
328/// Euclidean (no-behavior) metric the geometry is still produced but
329/// `predicted_nats` / `validity_radius` degrade to `None`.
330pub fn steer_delta(
331 model: &SaeManifoldTerm,
332 metric: &RowMetric,
333 atom_k: usize,
334 t_from: &[f64],
335 t_to: &[f64],
336) -> Result<SteerPlan, String> {
337 steer_delta_impl(model, metric, atom_k, t_from, t_to, None)
338}
339
340fn steer_delta_with_amplitude(
341 model: &SaeManifoldTerm,
342 metric: &RowMetric,
343 atom_k: usize,
344 t_from: &[f64],
345 t_to: &[f64],
346 amplitude: f64,
347) -> Result<SteerPlan, String> {
348 if !(amplitude.is_finite() && amplitude > 0.0) {
349 return Err(format!(
350 "steer_delta_with_amplitude: amplitude must be finite and positive, got {amplitude}"
351 ));
352 }
353 steer_delta_impl(model, metric, atom_k, t_from, t_to, Some(amplitude))
354}
355
356fn steer_delta_impl(
357 model: &SaeManifoldTerm,
358 metric: &RowMetric,
359 atom_k: usize,
360 t_from: &[f64],
361 t_to: &[f64],
362 amplitude_override: Option<f64>,
363) -> Result<SteerPlan, String> {
364 let k = model.k_atoms();
365 if atom_k >= k {
366 return Err(format!(
367 "steer_delta: atom index {atom_k} out of range (term has {k} atoms)"
368 ));
369 }
370 let atom = &model.atoms[atom_k];
371 let d = atom.latent_dim;
372 let p = atom.output_dim();
373 if t_from.len() != d || t_to.len() != d {
374 return Err(format!(
375 "steer_delta: t_from/t_to must have length latent_dim={d}; got {} and {}",
376 t_from.len(),
377 t_to.len()
378 ));
379 }
380 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
381 format!(
382 "steer_delta: atom {atom_k} ('{}') has no installed basis evaluator; \
383 arbitrary-t decoder evaluation requires one",
384 atom.name
385 )
386 })?;
387
388 // --- amplitude & the row the dose is measured through -------------------
389 // The amplitude is the atom's mean active assignment mass (how loudly the
390 // atom is expressed), mirroring the presence weighting in `atom_lens`. The
391 // measured row is the atom's single most-active row: the per-row metric
392 // there is the most representative behavioral inner product for "this atom
393 // is on". Both fall back gracefully when the atom is active nowhere.
394 let assignments = model.assignment.assignments();
395 let n = model.n_obs();
396 let mut mass_sum = 0.0_f64;
397 let mut active_count = 0.0_f64;
398 let mut best_row = 0usize;
399 let mut best_mass = f64::NEG_INFINITY;
400 for row in 0..n {
401 let mass = assignments[[row, atom_k]];
402 if mass > best_mass {
403 best_mass = mass;
404 best_row = row;
405 }
406 if mass > ACTIVE_MASS_FLOOR {
407 mass_sum += mass;
408 active_count += 1.0;
409 }
410 }
411 let amplitude = amplitude_override.unwrap_or_else(|| {
412 if active_count > 0.0 {
413 mass_sum / active_count
414 } else {
415 1.0
416 }
417 });
418
419 // --- the on-manifold activation-space delta -----------------------------
420 let g_from = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_from, p)?;
421 let g_to = decode_at(evaluator.as_ref(), &atom.decoder_coefficients, t_to, p)?;
422 let mut delta = Array1::<f64>::zeros(p);
423 for i in 0..p {
424 delta[i] = amplitude * (g_to[i] - g_from[i]);
425 }
426
427 // Whether the metric can/does match this term and carries behavior.
428 let provenance = metric.provenance();
429 let behavior_available =
430 metric_carries_behavior(provenance) && metric.n_rows() == n && metric.p_out() == p;
431
432 // --- off-manifold guard -------------------------------------------------
433 // Project δ onto the span of the local decoder tangents ∂g_k/∂t and report
434 // the residual norm. The tangents are evaluated at the move's MIDPOINT, not
435 // at t_from: the chord of a curve is symmetric about its midpoint, so its
436 // component transverse to the midpoint tangent is the true second-order
437 // sagitta (`O(‖Δt‖²)`), whereas the endpoint tangent differs from the chord
438 // direction already at first order. Measuring against the midpoint frame is
439 // therefore the honest "did the move stay on the surface" self-check: it is
440 // `≈ 0` for an on-manifold move and grows only with genuine arc curvature.
441 let mut t_mid = vec![0.0_f64; d];
442 for a in 0..d {
443 t_mid[a] = 0.5 * (t_from[a] + t_to[a]);
444 }
445 let tangents =
446 decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, &t_mid, p, d)?;
447 let off_manifold_norm = off_manifold_residual_norm(&tangents, delta.view());
448
449 // --- dosimetry: path-integrated Fisher dose -----------------------------
450 let (predicted_nats, validity_radius) = if !behavior_available {
451 (None, None)
452 } else {
453 let ctx = SteerContext {
454 evaluator: evaluator.as_ref(),
455 decoder: &atom.decoder_coefficients,
456 metric,
457 row: best_row,
458 p,
459 d,
460 amplitude,
461 };
462 let dose = path_integrated_dose(&ctx, t_from, t_to)?;
463 let radius = validity_radius(&ctx, t_from, t_to)?;
464 (Some(dose), Some(radius))
465 };
466
467 Ok(SteerPlan {
468 atom: atom_k,
469 atom_name: atom.name.clone(),
470 t_from: t_from.to_vec(),
471 t_to: t_to.to_vec(),
472 amplitude,
473 measured_row: best_row,
474 delta,
475 predicted_nats,
476 validity_radius,
477 off_manifold_norm,
478 metric_provenance: provenance,
479 })
480}
481
482/// The model's predicted output-mean response to an applied activation push
483/// `δ`, under the LOCAL-LINEAR reading of its fitted surface: the projection
484/// of `δ` onto the span of atom `atom_k`'s decoder tangents `∂g_k/∂t` at the
485/// operating point `t_at`. A dictionary "predicts" exactly the component of a
486/// push it can carry along its learned surface; the transverse component is
487/// off-manifold and predicted to die (this is the same local model the
488/// off-manifold guard and the dosimetry chord trust, used in the same radius).
489///
490/// This is `μ(δ)` for the design loop of
491/// [`gam_terms::inference::structure_evidence`]: two structural hypotheses about
492/// the same activations (e.g. "one curved atom" vs "two flat atoms") are two
493/// fitted terms whose tangent spans differ, so they predict DIFFERENT
494/// responses to the same probe — and that disagreement, in the output-Fisher
495/// metric, is what `select_probe_by_expected_evidence` maximizes.
496pub fn predicted_response(
497 model: &SaeManifoldTerm,
498 atom_k: usize,
499 t_at: &[f64],
500 delta: ArrayView1<'_, f64>,
501) -> Result<Array1<f64>, String> {
502 let k = model.k_atoms();
503 if atom_k >= k {
504 return Err(format!(
505 "predicted_response: atom index {atom_k} out of range (term has {k} atoms)"
506 ));
507 }
508 let atom = &model.atoms[atom_k];
509 let d = atom.latent_dim;
510 let p = atom.output_dim();
511 if t_at.len() != d {
512 return Err(format!(
513 "predicted_response: t_at must have length latent_dim={d}; got {}",
514 t_at.len()
515 ));
516 }
517 if delta.len() != p {
518 return Err(format!(
519 "predicted_response: delta must have length output_dim={p}; got {}",
520 delta.len()
521 ));
522 }
523 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
524 format!(
525 "predicted_response: atom {atom_k} ('{}') has no installed basis evaluator",
526 atom.name
527 )
528 })?;
529 let tangents = decode_tangents_at(evaluator.as_ref(), &atom.decoder_coefficients, t_at, p, d)?;
530 Ok(project_onto_tangent_span(&tangents, delta))
531}
532
533/// Does this provenance carry behavioral (output-Fisher) information? Euclidean
534/// is the isotropic activation-only path and carries none; the factored
535/// provenances do. (Mirrors `atom_lens::metric_carries_behavior`.)
536fn metric_carries_behavior(p: MetricProvenance) -> bool {
537 match p {
538 MetricProvenance::Euclidean => false,
539 MetricProvenance::OutputFisher { .. }
540 | MetricProvenance::OutputFisherDownstream { .. }
541 | MetricProvenance::BehavioralFisher { .. }
542 | MetricProvenance::WhitenedStructured { .. } => true,
543 }
544}
545
546/// Evaluate the decoder output `g_k(t) = Φ_k(t) B_k ∈ ℝ^p` at an arbitrary
547/// latent coordinate `t` (length `d`) via the atom's installed evaluator.
548fn decode_at(
549 evaluator: &dyn crate::manifold::SaeBasisEvaluator,
550 decoder: &Array2<f64>,
551 t: &[f64],
552 p: usize,
553) -> Result<Array1<f64>, String> {
554 let d = t.len();
555 let coords = Array2::from_shape_vec((1, d), t.to_vec())
556 .map_err(|e| format!("steer_delta::decode_at: coord shape: {e}"))?;
557 let (phi, _jet) = evaluator.evaluate(coords.view())?;
558 let m = decoder.nrows();
559 if phi.ncols() != m {
560 return Err(format!(
561 "steer_delta::decode_at: evaluator returned {} basis cols but decoder has {m} rows",
562 phi.ncols()
563 ));
564 }
565 let mut g = Array1::<f64>::zeros(p);
566 for basis_col in 0..m {
567 let phi_v = phi[[0, basis_col]];
568 if phi_v == 0.0 {
569 continue;
570 }
571 for out_col in 0..p {
572 g[out_col] += phi_v * decoder[[basis_col, out_col]];
573 }
574 }
575 Ok(g)
576}
577
578/// Evaluate the decoder tangents `∂g_k/∂t_a = Φ_k'(t) B_k ∈ ℝ^p`, one per latent
579/// axis `a ∈ 0..d`, at an arbitrary latent coordinate `t`. Returned as a
580/// `(p × d)` matrix whose column `a` is the tangent along axis `a`.
581fn decode_tangents_at(
582 evaluator: &dyn crate::manifold::SaeBasisEvaluator,
583 decoder: &Array2<f64>,
584 t: &[f64],
585 p: usize,
586 d: usize,
587) -> Result<Array2<f64>, String> {
588 let coords = Array2::from_shape_vec((1, d), t.to_vec())
589 .map_err(|e| format!("steer_delta::decode_tangents_at: coord shape: {e}"))?;
590 let (_phi, jet) = evaluator.evaluate(coords.view())?;
591 let m = decoder.nrows();
592 if jet.dim() != (1, m, d) {
593 return Err(format!(
594 "steer_delta::decode_tangents_at: evaluator jet {:?} != (1, {m}, {d})",
595 jet.dim()
596 ));
597 }
598 let mut tang = Array2::<f64>::zeros((p, d));
599 for axis in 0..d {
600 for basis_col in 0..m {
601 let dphi = jet[[0, basis_col, axis]];
602 if dphi == 0.0 {
603 continue;
604 }
605 for out_col in 0..p {
606 tang[[out_col, axis]] += dphi * decoder[[basis_col, out_col]];
607 }
608 }
609 }
610 Ok(tang)
611}
612
613/// Least-squares projection of `δ` onto the span of the local tangents
614/// (columns of `tangents`, shape `p × d`): `δ̂ = T (TᵀT)⁻¹ Tᵀ δ` via a small
615/// `d × d` Gram solve (with a tiny diagonal jitter to absorb a rank-deficient
616/// tangent frame; the jitter only shrinks the projection, never inflates it).
617fn project_onto_tangent_span(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> Array1<f64> {
618 let p = tangents.nrows();
619 let d = tangents.ncols();
620 if d == 0 {
621 return Array1::<f64>::zeros(p);
622 }
623 // Gram = TᵀT (d × d) and rhs = Tᵀδ (d).
624 let mut gram = Array2::<f64>::zeros((d, d));
625 let mut rhs = Array1::<f64>::zeros(d);
626 for a in 0..d {
627 let mut r = 0.0_f64;
628 for i in 0..p {
629 r += tangents[[i, a]] * delta[i];
630 }
631 rhs[a] = r;
632 for b in a..d {
633 let mut acc = 0.0_f64;
634 for i in 0..p {
635 acc += tangents[[i, a]] * tangents[[i, b]];
636 }
637 gram[[a, b]] = acc;
638 gram[[b, a]] = acc;
639 }
640 }
641 let trace: f64 = (0..d).map(|a| gram[[a, a]]).sum();
642 let jitter = if trace > 0.0 { 1e-12 * trace } else { 1e-12 };
643 for a in 0..d {
644 gram[[a, a]] += jitter;
645 }
646 let coeffs = solve_spd_small(&gram, &rhs);
647 let mut proj = Array1::<f64>::zeros(p);
648 for i in 0..p {
649 for a in 0..d {
650 proj[i] += tangents[[i, a]] * coeffs[a];
651 }
652 }
653 proj
654}
655
656/// Norm of `δ`'s component orthogonal to the span of the local tangents:
657/// `‖δ − δ̂‖` with `δ̂` the [`project_onto_tangent_span`] projection.
658fn off_manifold_residual_norm(tangents: &Array2<f64>, delta: ArrayView1<'_, f64>) -> f64 {
659 let proj = project_onto_tangent_span(tangents, delta);
660 let mut res_sq = 0.0_f64;
661 for i in 0..delta.len() {
662 let r = delta[i] - proj[i];
663 res_sq += r * r;
664 }
665 res_sq.max(0.0).sqrt()
666}
667
668/// Tiny symmetric-positive-definite solve via Cholesky for the `d × d` tangent
669/// Gram (`d` is the atom's latent dim, typically 1–3). Falls back to the bare rhs
670/// if the factorization fails (a fully degenerate frame), which only inflates the
671/// reported off-manifold residual — never deflates it.
672fn solve_spd_small(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
673 let d = gram.nrows();
674 // Cholesky L LᵀT = gram.
675 let mut l = Array2::<f64>::zeros((d, d));
676 for i in 0..d {
677 for j in 0..=i {
678 let mut sum = gram[[i, j]];
679 for k in 0..j {
680 sum -= l[[i, k]] * l[[j, k]];
681 }
682 if i == j {
683 if sum <= 0.0 {
684 return Array1::<f64>::zeros(d);
685 }
686 l[[i, j]] = sum.sqrt();
687 } else {
688 l[[i, j]] = sum / l[[j, j]];
689 }
690 }
691 }
692 // Forward solve L y = rhs.
693 let mut y = Array1::<f64>::zeros(d);
694 for i in 0..d {
695 let mut sum = rhs[i];
696 for k in 0..i {
697 sum -= l[[i, k]] * y[k];
698 }
699 y[i] = sum / l[[i, i]];
700 }
701 // Back solve Lᵀ x = y.
702 let mut x = Array1::<f64>::zeros(d);
703 for i in (0..d).rev() {
704 let mut sum = y[i];
705 for k in (i + 1)..d {
706 sum -= l[[k, i]] * x[k];
707 }
708 x[i] = sum / l[[i, i]];
709 }
710 x
711}
712
713/// The fixed geometry of one steering query, bundled so the dose integrator and
714/// its helpers take a single context rather than a long argument list.
715struct SteerContext<'a> {
716 evaluator: &'a dyn crate::manifold::SaeBasisEvaluator,
717 decoder: &'a Array2<f64>,
718 metric: &'a RowMetric,
719 /// The row whose per-row metric the dose is measured through.
720 row: usize,
721 /// Output dimension `p`.
722 p: usize,
723 /// Latent dimension `d`.
724 d: usize,
725 /// Amplitude `a` the move is scaled by.
726 amplitude: f64,
727}
728
729/// Path-integrated Fisher dose
730/// `½ a² ∫ g_k'(t)ᵀ M g_k'(t) dt` along the straight latent segment
731/// `t(τ) = t_from + τ (t_to − t_from)`, `τ ∈ [0, 1]`, by the midpoint rule over
732/// [`STEER_PATH_STEPS`] sub-steps.
733///
734/// The local quadratic `g'(t)ᵀ M g'(t)` is the [`RowMetric::pullback`] of the
735/// per-axis decoder tangents contracted with the latent velocity `Δt`, so this
736/// uses only the criterion-facing pullback (no loss / no solver floor).
737fn path_integrated_dose(
738 ctx: &SteerContext<'_>,
739 t_from: &[f64],
740 t_to: &[f64],
741) -> Result<f64, String> {
742 let d = ctx.d;
743 let p = ctx.p;
744 let steps = STEER_PATH_STEPS;
745 let dtau = 1.0 / steps as f64;
746 // Latent velocity Δt (constant along the straight segment).
747 let mut dt = vec![0.0_f64; d];
748 for a in 0..d {
749 dt[a] = t_to[a] - t_from[a];
750 }
751 let mut acc = 0.0_f64;
752 let amp2 = ctx.amplitude * ctx.amplitude;
753 for s in 0..steps {
754 // Midpoint of sub-step s in τ, mapped to a latent coordinate.
755 let tau_mid = (s as f64 + 0.5) * dtau;
756 let mut t_mid = vec![0.0_f64; d];
757 for a in 0..d {
758 t_mid[a] = t_from[a] + tau_mid * dt[a];
759 }
760 // Decoder tangents at the midpoint: ∂g/∂t_a, columns of a (p × d) matrix.
761 let tang = decode_tangents_at(ctx.evaluator, ctx.decoder, &t_mid, p, d)?;
762 // The pulled-back metric at this point is g_{ab} = (∂g/∂t)ᵀ M (∂g/∂t),
763 // the d × d local inner product of latent motion *in output-Fisher
764 // units*. We form it through the criterion-facing `RowMetric::pullback`
765 // (which never materializes the p × p M and never sees the solver δ),
766 // then contract the latent velocity Δt twice: the squared output-Fisher
767 // speed along the path is Δtᵀ g Δt. The decoder Jacobian is passed flat
768 // row-major (J[i, a] = j_row[i * d + a]) as `pullback` expects.
769 let mut j_row = vec![0.0_f64; p * d];
770 for i in 0..p {
771 for a in 0..d {
772 j_row[i * d + a] = tang[[i, a]];
773 }
774 }
775 let g_ab = ctx.metric.pullback(ctx.row, &j_row, d);
776 let mut speed_sq = 0.0_f64;
777 for a in 0..d {
778 for b in 0..d {
779 speed_sq += dt[a] * g_ab[[a, b]] * dt[b];
780 }
781 }
782 acc += 0.5 * amp2 * speed_sq * dtau;
783 }
784 Ok(acc)
785}
786
787/// The validity radius: the latent step length (Euclidean distance from
788/// `t_from`) at which **local linearization stops being trusted**.
789///
790/// Linearizing the steering move means predicting the output effect of a prefix
791/// step `τ·Δt` from the initial tangent alone: the first-order output move is
792/// `δ_lin(τ) = a · (∂g/∂t|_{t_from}) · (τ Δt)`, whose output-Fisher KL is the
793/// quadratic form `½ ‖δ_lin(τ)‖²_M = τ² · ½ a² ‖∂g/∂t·Δt‖²_M`. The **true**
794/// effect of that prefix is the chord quadratic form of the *actual* curved
795/// output move `½ a² ‖g(t_from + τΔt) − g(t_from)‖²_M`.
796///
797/// The radius is the chord length `τ* · ‖Δt‖` at the first prefix `τ*` where the
798/// true chord KL diverges from the linear prediction by more than
799/// [`VALIDITY_DIVERGENCE_FRACTION`] (relative to the linear prediction). This is
800/// pure surface curvature: on a flat decoder the two agree for every `τ` and the
801/// radius is the whole move. If the metric kills the tangent (no linear effect to
802/// validate), the move is trusted to its full length.
803fn validity_radius(ctx: &SteerContext<'_>, t_from: &[f64], t_to: &[f64]) -> Result<f64, String> {
804 let d = ctx.d;
805 let p = ctx.p;
806 let full_len: f64 = t_from
807 .iter()
808 .zip(t_to.iter())
809 .map(|(&a, &b)| (b - a) * (b - a))
810 .sum::<f64>()
811 .sqrt();
812 if full_len == 0.0 {
813 return Ok(0.0);
814 }
815 let mut dt = vec![0.0_f64; d];
816 for a in 0..d {
817 dt[a] = t_to[a] - t_from[a];
818 }
819 let amp = ctx.amplitude;
820
821 // Initial-tangent linear output move per unit τ: v0 = (∂g/∂t|_{t_from}) Δt.
822 let tang0 = decode_tangents_at(ctx.evaluator, ctx.decoder, t_from, p, d)?;
823 let mut v0 = Array1::<f64>::zeros(p);
824 for i in 0..p {
825 let mut acc = 0.0_f64;
826 for a in 0..d {
827 acc += tang0[[i, a]] * dt[a];
828 }
829 v0[i] = acc;
830 }
831 // ½ a² ‖v0‖²_M — the per-τ² linear KL coefficient.
832 let lin_coeff = 0.5 * amp * amp * ctx.metric.fisher_mass(ctx.row, v0.view());
833 // No linear effect to validate against ⇒ trust the full move.
834 if !(lin_coeff > 0.0) {
835 return Ok(full_len);
836 }
837
838 let g_from = decode_at(ctx.evaluator, ctx.decoder, t_from, p)?;
839 let steps = STEER_PATH_STEPS;
840 for s in 0..steps {
841 let tau = (s as f64 + 1.0) / steps as f64;
842 let mut t_mid = vec![0.0_f64; d];
843 for a in 0..d {
844 t_mid[a] = t_from[a] + tau * dt[a];
845 }
846 let g_tau = decode_at(ctx.evaluator, ctx.decoder, &t_mid, p)?;
847 let mut chord = Array1::<f64>::zeros(p);
848 for i in 0..p {
849 chord[i] = amp * (g_tau[i] - g_from[i]);
850 }
851 // True chord KL of the prefix, and the linear prediction τ²·lin_coeff.
852 let chord_kl = 0.5 * ctx.metric.fisher_mass(ctx.row, chord.view());
853 let lin_kl = tau * tau * lin_coeff;
854 let rel = (chord_kl - lin_kl).abs() / lin_kl;
855 if rel > VALIDITY_DIVERGENCE_FRACTION {
856 return Ok(tau * full_len);
857 }
858 }
859 Ok(full_len)
860}