gam_problem/row_metric.rs
//! `RowMetric` — the single provenance-carrying per-row inner product shared by
//! the SAE-manifold **likelihood** (residual whitening) and the **gauge**
//! (isometry pullback weight).
//!
//! # Why this exists
//!
//! The SAE-manifold machine historically carried *two* independent inner
//! products:
//!
//! * the **likelihood** measured reconstruction residuals isotropically — a
//!   single scalar dispersion `φ̂ = RSS / residual-dof`, the data-fit loop
//!   summing the bare `½ rᵀr`; there was no per-row metric at all; and
//! * the **gauge** carried its own per-row metric in
//!   `IsometryPenalty.weight: WeightField` — a low-rank `W_n = U_n U_nᵀ`
//!   pullback `g_n = J_nᵀ W_n J_n`, settable independently of anything the
//!   likelihood saw.
//!
//! Nothing structurally forced "the metric the likelihood whitens by" to equal
//! "the metric the gauge pulls back through". That is exactly the
//! objective↔gradient-desync bug class wearing geometry clothing: a
//! likelihood-metric ≠ gauge-metric state was *representable*.
//!
//! `RowMetric` collapses the two into one object. The likelihood whitens
//! through it; the gauge `WeightField` is *constructed from* it. A
//! divergent-metric state is therefore unrepresentable — there is only one
//! per-row factor stack `U_n`, with one [`MetricProvenance`] tag.
//!
//! # Magic-by-default selector
//!
//! There is no flag. The default provenance is chosen by whether per-row
//! Fisher factors exist:
//!
//! * no factors supplied ⇒ [`MetricProvenance::Euclidean`]; `W_n = I_p`;
//!   whitening is the identity, so `φ̂` and the data-fit loop are
//!   **bit-for-bit** the prior isotropic path; and
//! * per-row Fisher factors supplied ⇒ [`MetricProvenance::OutputFisher`]; the
//!   gauge pulls back through `U_n`. Under the #980 contract described in
//!   Rung 1 below, this provenance does not whiten the likelihood: the
//!   residual is weighted by the same `U_n` only when a whitening provenance
//!   ([`MetricProvenance::BehavioralFisher`] or
//!   [`MetricProvenance::WhitenedStructured`]) is elected.
//!
//! # Validation
//!
//! Every factored metric block is validated **through**
//! [`crate::normalize_fisher_rao_blocks`], which
//! broadcasts and eigenvalue-validates PSD-ness. `RowMetric` does not
//! reimplement that validation; it materializes `W_n = U_n U_nᵀ` (which is PSD
//! by construction) and runs it through the shared normalizer as the
//! single point of truth for "is this a valid precision metric". The Euclidean
//! identity stack is the one exception: it is PSD by definition and is
//! constructed directly (see [`RowMetric::euclidean`]).
//!
//! Any rank floor used to make a block invertible for an internal solve is
//! **solver-only** (mirroring `RidgePolicy::solver_only`, #747): it never enters
//! the residual the objective sums, so `δ` cannot bias the criterion.
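/*!
A minimal sketch of the solver-only discipline, offered for illustration
rather than as this module's implementation. It assumes a rank-1 block
`u uᵀ` and uses the Sherman–Morrison identity for the floored solve; the
module does not say which internal solve consumes `δ`, so both helper names
below are hypothetical.

```rust
/// Criterion face: rᵀ (u uᵀ) r = (uᵀ r)². The floor δ is not an argument,
/// so it cannot leak into the objective.
fn criterion_quad_rank1(u: &[f64], r: &[f64]) -> f64 {
    let dot: f64 = u.iter().zip(r).map(|(a, b)| a * b).sum();
    dot * dot
}

/// Solver face: x = (δ I + u uᵀ)⁻¹ b via the Sherman–Morrison identity,
/// x = (b − u · (uᵀ b) / (δ + uᵀ u)) / δ. Requires δ > 0.
fn floored_solve_rank1(u: &[f64], delta: f64, b: &[f64]) -> Vec<f64> {
    assert!(delta > 0.0, "the rank-1 block needs a positive floor to be inverted");
    let utb: f64 = u.iter().zip(b).map(|(a, c)| a * c).sum();
    let utu: f64 = u.iter().map(|a| a * a).sum();
    let scale = utb / (delta + utu);
    u.iter().zip(b).map(|(a, c)| (c - a * scale) / delta).collect()
}
```

Changing `delta` changes the vector returned by `floored_solve_rank1` but
cannot change `criterion_quad_rank1`, which is the separation the invariant
asks for.
*/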
//!
//! # Rung 1 — the behavioral metric *in the reconstruction loss* (nats currency)
//!
//! [`MetricProvenance::OutputFisher`] installs the output-Fisher inner product
//! as a **gauge** metric only: it whitens *nothing* (`whitens_likelihood()` is
//! `false`), by deliberate #980 contract, so reconstruction stays the isotropic
//! `½‖r‖²`. That answers "what coordinate is canonical", not "what does a
//! reconstruction error *cost*".
//!
//! [`MetricProvenance::BehavioralFisher`] is the opposite deliberate choice:
//! the **same** low-rank output-Fisher factors, but installed as the
//! reconstruction *likelihood weight*. Plain MSE prices a reconstruction error
//! `e = x − x̂` by its Euclidean size; the model, however, reads the activation
//! only through the rest of the network, so the behavioral cost of `e` is the
//! KL between the clean and corrupted next-token distributions,
//! `KL ≈ ½ eᵀ G(x) e` with `G = JᵀFJ` the network-Jacobian pullback of the
//! output Fisher `F` (units: **nats**). Minimizing `(x−x̂)ᵀ G (x−x̂)` instead of
//! `‖x−x̂‖²` is **generalized least squares**: for a *fixed* per-row `G` it is
//! still a linear Gaussian model in the coefficients, so the entire
//! REML/evidence/EDF/certificate stack survives verbatim — this is why the
//! metric rides the identical `whitens_likelihood()` plumbing the
//! [`MetricProvenance::WhitenedStructured`] noise model uses, and why the G=I
//! limit reproduces the plain-MSE fit bit-for-bit (see the module tests).
//!
//! This is the principled form of Braun's end-to-end **KL + MSE** objective.
//! Anchoring to the activation keeps it *reconstruction* (it does not collapse
//! to "match the logits by any means" — the decoder still has to reproduce `x`),
//! while pricing the residual in nats through `G`. The payoff is automatic
//! selection for *mattering*: `G`'s null directions — activation structure the
//! rest of the network cannot read — are penalized nothing, because
//! `eᵀ G e = 0` there. MSE in a behaviorally-inert direction goes free, which is
//! the correct behavior, not a bug: nothing downstream changes, so nothing
//! should be paid.
//!
//! **The p×p `G` is never materialized.** `G` is sketched by `s` random probes,
//! `vᵢ = Jᵀ F^{1/2} uᵢ` (`uᵢ` iid, `s ≈ 4…16`), computed by `s` backward passes
//! per token at *harvest* time (the model-interaction boundary) and stored as
//! the columns of the per-row factor `U_n = [v₁ … v_s] ∈ ℝ^{p×s}`. Then
//! `G ≈ Σᵢ vᵢ vᵢᵀ = U_n U_nᵀ` and the criterion-facing
//! `eᵀ G e ≈ Σᵢ (vᵢᵀ e)² = ‖U_nᵀ e‖²` is exactly what
//! [`RowMetric::quad_form`] / [`RowMetric::whiten_residual_row`] already
//! compute — zero train-time model cost, `O(p·s)` per row. See
//! [`RowMetric::behavioral_fisher`] and the probe-packing helper
//! [`pack_probe_factors`].
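/*!
The probe identity above can be checked on plain vectors. This is a hedged,
dependency-free sketch: `sketch_quad_form` and `dense_quad_form` are
hypothetical helpers written for this illustration, and the probes are
assumed to be stored as one `Vec<f64>` per probe rather than in the packed
row layout the module uses.

```rust
/// Criterion-facing quadratic form from the probe sketch:
/// eᵀ G e ≈ Σᵢ (vᵢᵀ e)², costing O(p · s) and never forming G.
fn sketch_quad_form(probes: &[Vec<f64>], e: &[f64]) -> f64 {
    probes
        .iter()
        .map(|v| {
            let dot: f64 = v.iter().zip(e).map(|(a, b)| a * b).sum();
            dot * dot
        })
        .sum()
}

/// Reference only: materializes G = Σᵢ vᵢ vᵢᵀ (p × p) and evaluates eᵀ G e.
fn dense_quad_form(probes: &[Vec<f64>], e: &[f64]) -> f64 {
    let p = e.len();
    let mut g = vec![vec![0.0_f64; p]; p];
    for v in probes {
        for i in 0..p {
            for j in 0..p {
                g[i][j] += v[i] * v[j];
            }
        }
    }
    let mut acc = 0.0;
    for i in 0..p {
        for j in 0..p {
            acc += e[i] * g[i][j] * e[j];
        }
    }
    acc
}
```

With probes `[1, 0, 0]` and `[0, 2, 0]`, an error along the third axis lies in
the null space of the sketched `G`, so both functions price it at zero, while
an error with components on the first two axes is priced identically by the
factored and dense routes.
*/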

use ndarray::{Array2, Array3, ArrayView1};
use std::sync::Arc;

use crate::normalize_fisher_rao_blocks;

/// Per-observation behavioral-metric field `W_n ∈ ℝ^{p × p}`, stored in
/// **low-rank factored form** `W_n = U_n U_n^T` with `U_n ∈ ℝ^{p × r_n}`.
///
/// The canonical coordinate is the one where one unit of motion in `t` is one
/// unit of behavioral change in the output space, so the `W_n` weighting is
/// load-bearing: the pullback metric is `g_n = J_n^T W_n J_n`. Storing as
/// `U_n` lets every contraction in this module run in
/// `(J^T U_n)(U_n^T J)` order, which is `O(p · r · d + r · d²)` per row — we
/// **never** materialize the `p × p` `W_n`, which is essential when `p`
/// (number of observation channels) is large but rank is small (e.g. one or
/// two behavioral dimensions per latent observation).
///
/// `Identity` is the gauge-fix default and corresponds to `U_n = I_p` so the
/// pullback reduces to the standard `J_n^T J_n`. `Factored` stores the
/// per-row `U_n` blocks contiguously: every row's factor is `p × rank`, and
/// rows may share the same rank (uniform-rank case) or vary if the field is
/// data-driven. For the uniform-rank case the storage is
/// `(n_obs, p * rank)` row-major.
#[derive(Clone)]
pub enum WeightField {
    /// `W_n = I_p` for every `n`. Reduces to the bare pullback `J^T J`.
    Identity,
    /// Per-row low-rank factor `U_n ∈ ℝ^{p × rank}`. Storage layout: a
    /// `(n_obs, p * rank)` row-major matrix where row `n` packs `U_n` in
    /// row-major order within the row: `U_n[i, k] = u[n, i * rank + k]`, so
    /// the `rank` entries of output channel `i` are contiguous.
    Factored {
        u: Arc<Array2<f64>>,
        rank: usize,
        p_out: usize,
    },
}

impl std::fmt::Debug for WeightField {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WeightField::Identity => f.write_str("Identity"),
            WeightField::Factored { u, rank, p_out } => f
                .debug_struct("Factored")
                .field("shape", &format_args!("{}×{}", u.nrows(), u.ncols()))
                .field("rank", rank)
                .field("p_out", p_out)
                .finish(),
        }
    }
}

impl WeightField {
    /// Apply `U_n^T J_n` for a specific row, given both the row's `J_n` flat
    /// `(p * d)` slice and the row's `U_n` flat `(p * rank)` slice. Returns
    /// the `(rank × d)` matrix `U_n^T J_n`.
    pub fn project_jac_row_with_u(
        u_row: &[f64],
        jac_row: &[f64],
        p: usize,
        rank: usize,
        d: usize,
    ) -> Array2<f64> {
        // M[k, a] = Σ_i U[i, k] · J[i, a].
        let mut m = Array2::<f64>::zeros((rank, d));
        for k in 0..rank {
            for a in 0..d {
                let mut s = 0.0;
                for i in 0..p {
                    s += u_row[i * rank + k] * jac_row[i * d + a];
                }
                m[[k, a]] = s;
            }
        }
        m
    }
}

/// Where the per-row metric came from — the provenance that makes
/// "likelihood-metric ≠ gauge-metric" diagnosable instead of silent.
///
/// Object 4 (the gauge object) reads this to certify which inner product the
/// fit actually used; #974 fills [`MetricProvenance::WhitenedStructured`] with a
/// factor-analytic residual-covariance whitening.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum MetricProvenance {
    /// `M_n = I_p` for every row. The likelihood is isotropic and the gauge
    /// pullback reduces to the bare `J_nᵀ J_n`. This is the default and is
    /// bit-for-bit the historical isotropic-`φ̂` path.
    Euclidean,
    /// `M_n = U_n U_nᵀ (+ solver-only δI)` from supplied per-row output-Fisher
    /// factors `U_n ∈ ℝ^{p × rank}`. The canonical "one unit of latent motion ↦
    /// one unit of behavioral change" metric for the **gauge**: the gauge pulls
    /// back through these factors, while the likelihood stays isotropic
    /// ([`RowMetric::whitens_likelihood`] is `false`, #980). The `rank` is
    /// carried in the provenance so a consumer (Object 4) can certify the
    /// factor rank that produced the inner product.
    OutputFisher { rank: usize },
    /// `M_n = U_n U_nᵀ` from per-row output-Fisher factors that aggregate the
    /// **downstream** influence of position `n` over future positions through
    /// the KV path, rather than the same-position logits of
    /// [`MetricProvenance::OutputFisher`] (#980, mechanism 2).
    ///
    /// The same-position pullback `∂logits_t/∂x_t` can be ≈ 0 for a feature
    /// whose entire causal effect lands many tokens later (information carried
    /// forward through attention); a gauge built on it is blind to exactly that
    /// content. This provenance is the forward-looking alternative: each row's
    /// factor `U_n` is the top-`rank` factorization of the aggregated output
    /// Fisher `Σ_{t ≥ n} (∂logits_t/∂x_n)ᵀ F_t (∂logits_t/∂x_n)` over future
    /// positions the residual stream at `n` reaches. It is provenance-generic:
    /// it whitens nothing ([`Self::whitens_likelihood`] is `false`, like
    /// [`MetricProvenance::OutputFisher`]) and drives the gauge / lens /
    /// enrichment unchanged ([`Self::is_output_fisher_like`]). The lens/gauge
    /// machinery consumes it identically; only the *scientific* reading
    /// changes — dormant-feature detection becomes forward-looking (a feature
    /// driving far-future tokens now registers behavioral coupling that the
    /// same-position metric reported as ≈ 0).
    OutputFisherDownstream { rank: usize },
    /// **Rung 1** — the output-Fisher metric installed as the reconstruction
    /// **likelihood weight** (generalized least squares in nats), not merely as
    /// a gauge. `M_n = U_n U_nᵀ ≈ G_n = J_nᵀ F_n J_n` is the `s`-probe sketch of
    /// the pulled-back output Fisher, with `U_n = [v₁ … v_s]`,
    /// `vᵢ = J_nᵀ F_n^{1/2} uᵢ`, and `probes = s` the number of random probes
    /// (the factor rank).
    ///
    /// This is the *only* [`RowMetric::is_output_fisher_like`]-adjacent
    /// provenance for which [`RowMetric::whitens_likelihood`] is `true`: the
    /// data-fit sums `½ eᵀ G_n e = ½ ‖U_nᵀ e‖²` (nats) instead of `½‖e‖²`. It is
    /// distinct from [`Self::OutputFisher`] precisely because the choice to let
    /// the metric enter the *loss* (rather than only the gauge) is deliberate and
    /// must not be silently inherited by the #980 gauge / two-tier-harvest
    /// contract — that contract relies on [`Self::OutputFisher`] whitening
    /// nothing. Because `G_n` is a *fixed* per-row metric, the whitened problem
    /// is again linear-Gaussian in the coefficients, so REML/evidence/EDF are
    /// unchanged (the GLS-preserves-REML property, verified in the module tests
    /// against the `G=I` plain-MSE limit).
    BehavioralFisher { probes: usize },
    /// Structured-residual whitening: `M_n = Σ_n^{-1}` from the **estimated**
    /// factor-analytic residual covariance `Σ_n = Λ c(z_n) Λᵀ + D` (#974), with
    /// `factor_rank` the selected factor count. Structured-residual producers
    /// materialize this provenance when they fit a residual-covariance
    /// whitening model. Together with [`MetricProvenance::BehavioralFisher`]
    /// it is one of the two provenances for which
    /// [`whitens_likelihood`](RowMetric::whitens_likelihood) is `true`. It
    /// carries the same low-rank factor layout as
    /// [`MetricProvenance::OutputFisher`].
    WhitenedStructured { factor_rank: usize },
}

/// The single per-row metric object. Holds one low-rank factor stack `U_n` (or
/// none, for Euclidean) plus the validated PSD blocks, tagged with its
/// [`MetricProvenance`].
///
/// `p` is the output dimensionality (residual / Jacobian-column dimension); the
/// per-row factor `U_n ∈ ℝ^{p × rank}` so `W_n = U_n U_nᵀ ∈ ℝ^{p × p}` without
/// ever being materialized as `p × p` in any hot path.
#[derive(Clone, Debug)]
pub struct RowMetric {
    provenance: MetricProvenance,
    n_rows: usize,
    p: usize,
    rank: usize,
    /// `(n_rows, p * rank)` row-major: `U_n[i, k] = u[n, i * rank + k]`. `None`
    /// for [`MetricProvenance::Euclidean`] (the identity factor is implicit).
    factors: Option<Arc<Array2<f64>>>,
    /// **Solver-only** Tikhonov floor `δ` added as `δ I_p` to make a
    /// rank-deficient `U_n U_nᵀ` invertible for an *internal solve only*.
    ///
    /// Invariant (mirrors `RidgePolicy::solver_only`, #747): `δ` **never** enters
    /// any quantity that feeds the evidence criterion. The criterion-facing
    /// quad-form / whitening / fisher-mass methods all use the *un-floored*
    /// `U_n U_nᵀ`; only solver helpers that read [`Self::solver_floor`] see
    /// `δ`. A nonzero floor therefore cannot bias the objective the optimizer
    /// reports.
    solver_delta: f64,
    /// Per-row traces `tr(M_n)` of the criterion-facing (un-floored) metric.
    ///
    /// This is the only dense-block reduction any consumer reads (the #980
    /// Fisher-mass row measure); the `(n_rows, p, p)` block stack itself is
    /// validated **streamingly** at construction through
    /// [`normalize_fisher_rao_blocks`] one row at a time and then dropped.
    /// Retaining it was `n·p²·8` bytes (about 12.8 GB, roughly 12 GiB, at
    /// `(n=2000, p=896)`, and an OOM at LLM-scale `p`) for a record nothing
    /// ever re-read. The solver `δ` is deliberately *not* baked in here, so
    /// this is the criterion-facing trace.
    traces: ndarray::Array1<f64>,
}

impl RowMetric {
    /// Euclidean metric: `W_n = I_p` for all `n`. Whitening is the identity, so
    /// the likelihood residual path is bit-for-bit the prior isotropic `φ̂`.
    ///
    /// Constructed directly: the identity stack is PSD axiomatically, so
    /// routing it through the dense normalizer would materialize and
    /// spectrum-check `n` identity blocks (`n·p²` memory, `n·p³` flops) to
    /// validate a tautology. `tr(I_p) = p` per row.
    pub fn euclidean(n_rows: usize, p: usize) -> Result<Self, String> {
        Ok(Self {
            provenance: MetricProvenance::Euclidean,
            n_rows,
            p,
            rank: p,
            factors: None,
            solver_delta: 0.0,
            traces: ndarray::Array1::<f64>::from_elem(n_rows, p as f64),
        })
    }

    /// Output-Fisher metric: per-row low-rank factors `U_n ∈ ℝ^{p × rank}`
    /// supplied as a `(n_rows, p * rank)` row-major matrix (`U_n[i, k] =
    /// u[n, i * rank + k]`). The induced `M_n = U_n U_nᵀ` is PSD by
    /// construction; it is validated through [`normalize_fisher_rao_blocks`] so
    /// the validation path is shared. No solver floor (`δ = 0`).
    pub fn output_fisher(u: Arc<Array2<f64>>, p: usize, rank: usize) -> Result<Self, String> {
        Self::from_factors(MetricProvenance::OutputFisher { rank }, u, p, rank, 0.0)
    }

    /// Downstream-influence output-Fisher metric: per-row factors `U_n ∈
    /// ℝ^{p × rank}` whose `M_n = U_n U_nᵀ` is the aggregated output Fisher of
    /// position `n` over the **future** positions it reaches through the KV path
    /// ([`MetricProvenance::OutputFisherDownstream`], #980 mechanism 2). The
    /// factor layout is identical to [`Self::output_fisher`]; only the
    /// provenance tag (and hence the scientific reading) differs. Whitens
    /// nothing, drives the gauge / lens / enrichment exactly as the
    /// same-position metric does — the consuming machinery is provenance-generic
    /// (see [`Self::is_output_fisher_like`]).
    pub fn output_fisher_downstream(
        u: Arc<Array2<f64>>,
        p: usize,
        rank: usize,
    ) -> Result<Self, String> {
        Self::from_factors(
            MetricProvenance::OutputFisherDownstream { rank },
            u,
            p,
            rank,
            0.0,
        )
    }

    /// **Rung 1** — the output-Fisher metric as a reconstruction *likelihood
    /// weight* (GLS in nats): per-row `s`-probe factors `U_n ∈ ℝ^{p × probes}`
    /// supplied as a `(n_rows, p * probes)` row-major matrix
    /// (`U_n[i, k] = u[n, i * probes + k]`), so that column `k` is the probe
    /// vector `v_k = J_nᵀ F_n^{1/2} u_k` and `M_n = U_n U_nᵀ ≈ G_n`. Unlike
    /// [`Self::output_fisher`], the resulting metric returns
    /// `whitens_likelihood() == true`: the data-fit prices reconstruction error
    /// as `½ eᵀ G_n e`. Validated through [`normalize_fisher_rao_blocks`] like
    /// every factored metric; no solver floor (`δ = 0`).
    ///
    /// See [`pack_probe_factors`] to build `u` from a natural `(n, p, s)` probe
    /// stack emitted at harvest time.
    pub fn behavioral_fisher(u: Arc<Array2<f64>>, p: usize, probes: usize) -> Result<Self, String> {
        Self::from_factors(
            MetricProvenance::BehavioralFisher { probes },
            u,
            p,
            probes,
            0.0,
        )
    }

    /// Like [`Self::output_fisher`] but with a **solver-only** Tikhonov floor
    /// `δ ≥ 0`. The floor is recorded for solver helpers only; every
    /// criterion-facing method (`quad_form`, `whiten_residual_row`,
    /// `fisher_mass`) ignores it (#747 discipline), so the evidence criterion
    /// is `δ`-free.
    pub fn output_fisher_with_solver_floor(
        u: Arc<Array2<f64>>,
        p: usize,
        rank: usize,
        solver_delta: f64,
    ) -> Result<Self, String> {
        if !(solver_delta.is_finite() && solver_delta >= 0.0) {
            return Err(format!(
                "RowMetric::output_fisher_with_solver_floor: solver_delta must be finite and \
                 non-negative; got {solver_delta}"
            ));
        }
        Self::from_factors(
            MetricProvenance::OutputFisher { rank },
            u,
            p,
            rank,
            solver_delta,
        )
    }

    /// Structured-residual whitening from supplied per-row precision factors.
    ///
    /// `u` carries the per-row factor stack `U_n ∈ ℝ^{p × rank}` (row-major flat)
    /// with `U_n U_nᵀ = M_n = Σ_n^{-1}` — the precision of the **estimated**
    /// residual-covariance noise model. This is the low-level constructor; #974
    /// producers that *fit* `Σ_n` (a low-rank factor + diagonal + smooth
    /// activity-scale) assemble these factors and call through here. Because the
    /// provenance is
    /// [`MetricProvenance::WhitenedStructured`], [`Self::whitens_likelihood`] is
    /// `true`: a metric built this way whitens the likelihood, as does one
    /// built by [`Self::behavioral_fisher`].
    pub fn whitened_structured(u: Arc<Array2<f64>>, p: usize, rank: usize) -> Result<Self, String> {
        Self::from_factors(
            MetricProvenance::WhitenedStructured { factor_rank: rank },
            u,
            p,
            rank,
            0.0,
        )
    }

    fn from_factors(
        provenance: MetricProvenance,
        u: Arc<Array2<f64>>,
        p: usize,
        rank: usize,
        solver_delta: f64,
    ) -> Result<Self, String> {
        let n_rows = u.nrows();
        if u.ncols() != p * rank {
            return Err(format!(
                "RowMetric::from_factors: factor matrix has {} cols; expected p*rank = {}*{} = {}",
                u.ncols(),
                p,
                rank,
                p * rank
            ));
        }
        if !u.iter().all(|v| v.is_finite()) {
            return Err("RowMetric::from_factors: factors must be finite".to_string());
        }
        // Materialize W_n = U_n U_nᵀ one row at a time (PSD by construction),
        // validate each through the single shared normalizer rather than
        // reimplementing the PSD check, record its trace, and drop the block.
        // Streaming keeps construction O(p²) memory; the former whole-stack
        // materialization retained `n·p²` doubles nothing ever re-read.
        let mut traces = ndarray::Array1::<f64>::zeros(n_rows);
        let mut full = Array3::<f64>::zeros((1, p, p));
        for row in 0..n_rows {
            for i in 0..p {
                for j in 0..p {
                    let mut acc = 0.0;
                    for k in 0..rank {
                        acc += u[[row, i * rank + k]] * u[[row, j * rank + k]];
                    }
                    full[[0, i, j]] = acc;
                }
            }
            normalize_fisher_rao_blocks(full.view().into_dyn(), 1, p)
                .map_err(|e| format!("RowMetric::from_factors: row {row}: {e}"))?;
            let mut tr = 0.0_f64;
            for i in 0..p {
                tr += full[[0, i, i]];
            }
            traces[row] = tr;
        }
        Ok(Self {
            provenance,
            n_rows,
            p,
            rank,
            factors: Some(u),
            solver_delta,
            traces,
        })
    }

    /// The provenance tag (consumed by Object 4 to certify the inner product).
    pub fn provenance(&self) -> MetricProvenance {
        self.provenance
    }

    /// Whether this metric is allowed to **whiten the likelihood** (i.e. replace
    /// the isotropic reconstruction data-fit `½ rᵀr` with the whitened
    /// `½ rᵀ M_n r`).
    ///
    /// This is TRUE for two provenances, for two distinct reasons:
    ///
    /// * [`MetricProvenance::WhitenedStructured`] — a genuinely *estimated noise
    ///   model* (a factor-analytic residual covariance, #974), for which
    ///   whitening the likelihood is the statistically correct thing to do; and
    /// * [`MetricProvenance::BehavioralFisher`] — the **Rung 1** deliberate
    ///   choice to price reconstruction error in nats: the output-Fisher metric
    ///   `G_n` installed *as the loss weight* (`½ eᵀ G_n e`), a generalized
    ///   least-squares reconstruction. Because `G_n` is a fixed per-row metric
    ///   the problem stays linear-Gaussian, so REML/evidence/EDF are preserved.
    ///
    /// It is FALSE for [`MetricProvenance::Euclidean`] (nothing to whiten by) and
    /// for the *gauge-only* [`MetricProvenance::OutputFisher`] /
    /// [`MetricProvenance::OutputFisherDownstream`]: there the output-Fisher
    /// inner product is an **output-geometry gauge**, and whitening the
    /// likelihood by it *implicitly* (without the caller electing GLS) would
    /// silently replace the reconstruction loss with a Fisher pullback — the #980
    /// failure mode, and the reason the two-tier harvest can withhold factors
    /// from a row without changing its loss. `BehavioralFisher` is the *explicit*
    /// election of that same arithmetic as the intended objective.
    pub fn whitens_likelihood(&self) -> bool {
        matches!(
            self.provenance,
            MetricProvenance::WhitenedStructured { .. } | MetricProvenance::BehavioralFisher { .. }
        )
    }

    /// Whether this metric **drives the gauge** — i.e. the isometry-penalty
    /// pullback weight is taken from it rather than the identity.
    ///
    /// TRUE for any non-[`MetricProvenance::Euclidean`] provenance:
    /// [`MetricProvenance::OutputFisher`],
    /// [`MetricProvenance::OutputFisherDownstream`],
    /// [`MetricProvenance::BehavioralFisher`] and
    /// [`MetricProvenance::WhitenedStructured`] all supply a non-identity
    /// per-row inner product the gauge pulls back through. Euclidean reduces
    /// the gauge pullback to the bare `J_nᵀ J_n`, so it does not drive the
    /// gauge.
    pub fn drives_gauge(&self) -> bool {
        !matches!(self.provenance, MetricProvenance::Euclidean)
    }

    /// Whether this metric is an **output-Fisher gauge** — either the
    /// same-position [`MetricProvenance::OutputFisher`] or the downstream
    /// [`MetricProvenance::OutputFisherDownstream`] (#980). The two share every
    /// consumer behavior (Sym(F) separation under the gauge, two-lens coupling,
    /// steering geometry, enrichment); they differ only in the *scientific*
    /// reading of what behavioral coupling means (same-position vs
    /// forward-looking). Consumers that gate on "is this an output-Fisher
    /// pullback" should use this predicate rather than matching one variant, so
    /// the downstream metric rides the identical path.
    pub fn is_output_fisher_like(&self) -> bool {
        matches!(
            self.provenance,
            MetricProvenance::OutputFisher { .. } | MetricProvenance::OutputFisherDownstream { .. }
        )
    }

    /// Number of rows the metric is defined over.
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }

    /// Output dimensionality `p` (residual / Jacobian-column dimension).
    pub fn p_out(&self) -> usize {
        self.p
    }

    /// The factor rank: the dimension of the whitened residual
    /// [`Self::whiten_residual_row`] returns (and the column count of the per-row
    /// factor `U_n ∈ ℝ^{p × rank}`). For [`MetricProvenance::Euclidean`] this is
    /// `p` (the implicit identity factor), so a consumer that sizes a whitened
    /// buffer by `metric_rank()` gets the right length in every provenance.
    pub fn metric_rank(&self) -> usize {
        self.rank
    }

    /// Per-row traces `tr(M_n)` of the criterion-facing (un-floored) metric —
    /// the Fisher-mass reduction the #980 row measure consumes. The dense
    /// `(n_rows, p, p)` stack is validated streamingly at construction and
    /// never retained; consumers wanting an explicit `W_n` rebuild it from
    /// [`Self::metric_rank`]-sized factors.
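    /**
As a cross-check on the stored traces, `tr(U Uᵀ)` equals the squared
Frobenius norm of `U`, so it can be read off a flat factor row directly.
The sketch below is an illustration on plain slices, not a description of
how the constructor computes the trace; both function names are
hypothetical.

```rust
/// tr(U Uᵀ) = Σ_{i,k} U[i, k]²: the sum of squares of the flat factor row.
fn trace_from_factor_row(u_row: &[f64]) -> f64 {
    u_row.iter().map(|v| v * v).sum()
}

/// Reference: walk the diagonal W[i, i] = Σ_k U[i, k]² with the packed
/// layout U[i, k] = u_row[i * rank + k].
fn trace_via_diagonal(u_row: &[f64], p: usize, rank: usize) -> f64 {
    assert_eq!(u_row.len(), p * rank, "u_row must hold p * rank entries");
    (0..p)
        .map(|i| (0..rank).map(|k| u_row[i * rank + k].powi(2)).sum::<f64>())
        .sum()
}
```

For the `2 × 2` factor packed as `[1, 2, 3, 4]` both routes give `30`.
    */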
    pub fn row_traces(&self) -> ndarray::ArrayView1<'_, f64> {
        self.traces.view()
    }

    /// Whiten a single `p`-dimensional residual row `r` into the coordinates
    /// whose squared Euclidean norm equals `rᵀ W_n r`.
    ///
    /// * Euclidean: returns `r` unchanged (`‖r‖² = rᵀ I r`), so the likelihood
    ///   reproduces the isotropic `½ rᵀr` data-fit bit-for-bit.
    /// * Factored: returns `U_nᵀ r ∈ ℝ^{rank}`, with
    ///   `‖U_nᵀ r‖² = rᵀ U_n U_nᵀ r = rᵀ W_n r`.
    ///
    /// This is the load-bearing identity that lets the data-fit loop sum
    /// `0.5 * Σ whitened²` and recover exactly `½ rᵀ W_n r` whatever the
    /// provenance.
    pub fn whiten_residual_row(&self, row: usize, r: ArrayView1<'_, f64>) -> Vec<f64> {
        match &self.factors {
            None => r.iter().copied().collect(),
            Some(u) => {
                let mut out = vec![0.0_f64; self.rank];
                for k in 0..self.rank {
                    let mut acc = 0.0;
                    for i in 0..self.p {
                        acc += u[[row, i * self.rank + k]] * r[i];
                    }
                    out[k] = acc;
                }
                out
            }
        }
    }

    /// The factor entry `U_n[i, k]` for one row (`i ∈ [0, p)`, `k ∈ [0, rank)`).
    /// For [`MetricProvenance::Euclidean`] the implicit factor is `I_p`, so this
    /// returns `1.0` when `i == k` and `0.0` otherwise — letting a consumer that
    /// whitens a Jacobian via `factor_entry` produce the identity whitening
    /// without a provenance branch. Reads the **un-floored** factors (criterion
    /// face, #747).
    #[inline]
    pub fn factor_entry(&self, row: usize, i: usize, k: usize) -> f64 {
        match &self.factors {
            None => {
                if i == k {
                    1.0
                } else {
                    0.0
                }
            }
            Some(u) => u[[row, i * self.rank + k]],
        }
    }

    /// Apply the full per-row metric `M_n x = U_n (U_nᵀ x) ∈ ℝ^p` for one
    /// `p`-vector `x`, formed factored (two `O(p · rank)` contractions: `x` into
    /// rank space, then back out to `p`-space) — never materializing `M_n` as
    /// `p × p`. Euclidean returns `x` unchanged
    /// (`M_n = I_p`). This is the p-space metric-applied vector the SAE β-tier
    /// data-fit gradient contracts (β lives in p-output space, so its gradient
    /// needs `M_n r_n`, not the rank-space whitened residual `U_nᵀ r_n`). Uses the
    /// **un-floored** factors (criterion face, `δ`-free, #747 invariant).
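    /**
A minimal standalone sketch of the two-pass product, assuming the packed
layout `U[i, k] = u_row[i * rank + k]` stated above. `apply_metric_flat` is
a hypothetical helper on plain slices, not this method.

```rust
/// y = U (Uᵀ x) for one row, with U[i, k] = u_row[i * rank + k].
/// Two O(p · rank) passes; the p × p product U Uᵀ is never formed.
fn apply_metric_flat(u_row: &[f64], x: &[f64], rank: usize) -> Vec<f64> {
    let p = x.len();
    assert_eq!(u_row.len(), p * rank, "u_row must hold p * rank entries");
    // w = Uᵀ x ∈ ℝ^rank.
    let w: Vec<f64> = (0..rank)
        .map(|k| (0..p).map(|i| u_row[i * rank + k] * x[i]).sum::<f64>())
        .collect();
    // y = U w ∈ ℝ^p.
    (0..p)
        .map(|i| (0..rank).map(|k| u_row[i * rank + k] * w[k]).sum::<f64>())
        .collect()
}
```

With the rank-1 factor `U = [1, 2]ᵀ`, the input `[1, 1]` maps to `[3, 6]`,
and an input orthogonal to the factor, such as `[2, -1]`, maps to the zero
vector.
    */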
    pub fn apply_metric_row(&self, row: usize, x: ArrayView1<'_, f64>) -> Vec<f64> {
        match &self.factors {
            None => x.iter().copied().collect(),
            Some(u) => {
                // w = U_nᵀ x ∈ ℝ^{rank}.
                let mut w = vec![0.0_f64; self.rank];
                for k in 0..self.rank {
                    let mut acc = 0.0;
                    for i in 0..self.p {
                        acc += u[[row, i * self.rank + k]] * x[i];
                    }
                    w[k] = acc;
                }
                // out = U_n w ∈ ℝ^p.
                let mut out = vec![0.0_f64; self.p];
                for i in 0..self.p {
                    let mut acc = 0.0;
                    for k in 0..self.rank {
                        acc += u[[row, i * self.rank + k]] * w[k];
                    }
                    out[i] = acc;
                }
                out
            }
        }
    }

    /// Pullback metric `g_n = J_nᵀ W_n J_n` for one row, formed as
    /// `(J_nᵀ U_n)(U_nᵀ J_n)` — never materializing the `p × p` `W_n`.
    ///
    /// `j_row` is the row's Jacobian `J_n ∈ ℝ^{p × d}` flattened row-major
    /// (`J_n[i, a] = j_row[i * d + a]`). Returns the `d × d` `g_n`.
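    /**
The contraction order can be sketched on plain slices. This is an
illustrative stand-in under the index conventions stated above;
`pullback_flat` is a hypothetical name and returns nested `Vec`s instead of
an `ndarray::Array2` to stay dependency-free.

```rust
/// g = (Uᵀ J)ᵀ (Uᵀ J): the d × d pullback built through the rank × d
/// intermediate M = Uᵀ J, with U[i, k] = u_row[i * rank + k] and
/// J[i, a] = j_row[i * d + a].
fn pullback_flat(u_row: &[f64], j_row: &[f64], p: usize, rank: usize, d: usize) -> Vec<Vec<f64>> {
    assert_eq!(u_row.len(), p * rank, "u_row must hold p * rank entries");
    assert_eq!(j_row.len(), p * d, "j_row must hold p * d entries");
    // M[k][a] = Σ_i U[i, k] · J[i, a].
    let mut m = vec![vec![0.0_f64; d]; rank];
    for k in 0..rank {
        for a in 0..d {
            m[k][a] = (0..p)
                .map(|i| u_row[i * rank + k] * j_row[i * d + a])
                .sum::<f64>();
        }
    }
    // g[a][b] = Σ_k M[k][a] · M[k][b]; fill the upper triangle and mirror it.
    let mut g = vec![vec![0.0_f64; d]; d];
    for a in 0..d {
        for b in a..d {
            let acc: f64 = (0..rank).map(|k| m[k][a] * m[k][b]).sum();
            g[a][b] = acc;
            g[b][a] = acc;
        }
    }
    g
}
```

With `U = [1, 2]ᵀ` and `J = I₂`, the intermediate is `M = [1, 2]` and the
pullback is `[[1, 2], [2, 4]]`, which is `U Uᵀ` itself, as expected when the
Jacobian is the identity.
    */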
    pub fn pullback(&self, row: usize, j_row: &[f64], d: usize) -> Array2<f64> {
        match &self.factors {
            None => {
                // W_n = I_p ⇒ g_n = J_nᵀ J_n.
                let mut g = Array2::<f64>::zeros((d, d));
                for a in 0..d {
                    for b in a..d {
                        let mut acc = 0.0;
                        for i in 0..self.p {
                            acc += j_row[i * d + a] * j_row[i * d + b];
                        }
                        g[[a, b]] = acc;
                        g[[b, a]] = acc;
                    }
                }
                g
            }
            Some(u) => {
                // M_n = U_nᵀ J_n ∈ ℝ^{rank × d}; g_n = M_nᵀ M_n.
                let mut m = Array2::<f64>::zeros((self.rank, d));
                for k in 0..self.rank {
                    for a in 0..d {
                        let mut acc = 0.0;
                        for i in 0..self.p {
                            acc += u[[row, i * self.rank + k]] * j_row[i * d + a];
                        }
                        m[[k, a]] = acc;
                    }
                }
                let mut g = Array2::<f64>::zeros((d, d));
                for a in 0..d {
                    for b in a..d {
                        let mut acc = 0.0;
                        for k in 0..self.rank {
                            acc += m[[k, a]] * m[[k, b]];
                        }
                        g[[a, b]] = acc;
                        g[[b, a]] = acc;
                    }
                }
                g
            }
        }
    }

    /// Quadratic form `r_nᵀ M_n r_n` for one row's residual `r_n ∈ ℝ^p`, formed
    /// **factored** as `‖U_nᵀ r_n‖²` — never materializing the `p × p` `M_n`.
    ///
    /// This is the criterion-facing squared residual the likelihood sums; it uses
    /// the **un-floored** `U_n U_nᵀ`, so the solver `δ` does not enter it
    /// (#747 invariant). Euclidean provenance returns the bit-identical `‖r_n‖²`.
    #[inline]
    pub fn quad_form(&self, row: usize, r: ArrayView1<'_, f64>) -> f64 {
        match &self.factors {
            None => r.iter().map(|&v| v * v).sum(),
            Some(_) => self
                .whiten_residual_row(row, r)
                .iter()
                .map(|&w| w * w)
                .sum(),
        }
    }

    /// Whiten a per-row Jacobian `J_n ∈ ℝ^{p × d}` (row-major flat,
    /// `J_n[i, a] = j_row[i * d + a]`) into `M_n = U_nᵀ J_n ∈ ℝ^{rank × d}` so
    /// that `M_nᵀ M_n = J_nᵀ (U_n U_nᵀ) J_n = J_nᵀ W_n J_n` is the pullback
    /// **without** any `p × p` intermediate. Euclidean returns `J_n` reshaped to
    /// `(p, d)` (the identity whitening). Solver `δ` is not applied (criterion
    /// face).
    pub fn whiten_jacobian(&self, row: usize, j_row: &[f64], d: usize) -> Array2<f64> {
        match &self.factors {
            None => {
                let mut out = Array2::<f64>::zeros((self.p, d));
                for i in 0..self.p {
                    for a in 0..d {
                        out[[i, a]] = j_row[i * d + a];
                    }
                }
                out
            }
            Some(u) => {
                let mut m = Array2::<f64>::zeros((self.rank, d));
                for k in 0..self.rank {
                    for a in 0..d {
                        let mut acc = 0.0;
                        for i in 0..self.p {
                            acc += u[[row, i * self.rank + k]] * j_row[i * d + a];
                        }
                        m[[k, a]] = acc;
                    }
                }
                m
            }
        }
    }
731
    /// Fisher mass of a per-row output vector `x_n ∈ ℝ^p`: the scalar
    /// `x_nᵀ M_n x_n` (alias of [`Self::quad_form`] read as an information mass
    /// rather than a residual square). Factored, never `p × p`, `δ`-free.
    #[inline]
    pub fn fisher_mass(&self, row: usize, x: ArrayView1<'_, f64>) -> f64 {
        self.quad_form(row, x)
    }

    /// The **solver-only** Tikhonov floor `δ` (#747). Returned for internal
    /// solver helpers that need `U_n U_nᵀ + δ I` to be invertible; by contract
    /// no caller may fold this into a criterion-facing quantity. Always `0` for
    /// Euclidean and for factored metrics built without an explicit floor.
    pub fn solver_floor(&self) -> f64 {
        self.solver_delta
    }

    /// The gauge view of this metric: the
    /// [`crate::WeightField`] the isometry penalty pulls back through.
    ///
    /// This is the **single** way an `IsometryPenalty` acquires a non-identity
    /// gauge metric — the independent `WeightField` setter has been removed — so
    /// the gauge metric is, by construction, the same object the likelihood
    /// whitens with.
    pub fn to_weight_field(&self) -> crate::WeightField {
        use crate::WeightField;
        match &self.factors {
            None => WeightField::Identity,
            Some(u) => WeightField::Factored {
                u: Arc::clone(u),
                rank: self.rank,
                p_out: self.p,
            },
        }
    }
}

/// Pack a harvest-emitted probe stack into the row-major factor layout
/// [`RowMetric::behavioral_fisher`] expects.
///
/// The harvest boundary (the model-interaction side) emits, per token, `s`
/// probe vectors `vₖ = J_nᵀ F_n^{1/2} uₖ ∈ ℝ^p` — the natural shape is
/// `probes[n, i, k] = (vₖ)ᵢ`, an `(n_rows, p, probes)` stack. This assembles the
/// `(n_rows, p · probes)` row-major matrix `u[n, i·probes + k] = probes[n, i, k]`
/// that the constructor consumes so that column `k` of the per-row factor `U_n`
/// is exactly probe `vₖ` and `M_n = U_n U_nᵀ = Σₖ vₖ vₖᵀ ≈ G_n`.
///
/// This is a pure repack of the standard C-order flattening; it exists so the
/// harvest → metric seam is a single named, validated Rust surface rather than
/// an ad-hoc reshape at each call site. Errors on non-finite entries so the
/// failure is caught here rather than deep in [`crate::normalize_fisher_rao_blocks`].
pub fn pack_probe_factors(probes: ndarray::ArrayView3<'_, f64>) -> Result<Array2<f64>, String> {
    let (n_rows, p, s) = probes.dim();
    if s == 0 {
        return Err("pack_probe_factors: need at least one probe (s == 0)".to_string());
    }
    if !probes.iter().all(|v| v.is_finite()) {
        return Err("pack_probe_factors: probe entries must be finite".to_string());
    }
    let mut u = Array2::<f64>::zeros((n_rows, p * s));
    for n in 0..n_rows {
        for i in 0..p {
            for k in 0..s {
                u[[n, i * s + k]] = probes[[n, i, k]];
            }
        }
    }
    Ok(u)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ── RowMetric::euclidean ──────────────────────────────────────────────────

    #[test]
    fn euclidean_metric_has_correct_dimensions() {
        let m = RowMetric::euclidean(5, 3).unwrap();
        assert_eq!(m.n_rows(), 5);
        assert_eq!(m.p_out(), 3);
        assert_eq!(m.metric_rank(), 3);
    }

    #[test]
    fn euclidean_metric_traces_equal_p() {
        let p = 4_usize;
        let m = RowMetric::euclidean(3, p).unwrap();
        for tr in m.row_traces().iter() {
            assert!((*tr - p as f64).abs() < 1e-14, "trace {tr} != p={p}");
        }
    }

    #[test]
    fn euclidean_provenance_is_euclidean() {
        let m = RowMetric::euclidean(1, 2).unwrap();
        assert_eq!(m.provenance(), MetricProvenance::Euclidean);
    }

    #[test]
    fn euclidean_does_not_whiten_likelihood() {
        let m = RowMetric::euclidean(1, 2).unwrap();
        assert!(!m.whitens_likelihood());
    }

    #[test]
    fn euclidean_does_not_drive_gauge() {
        let m = RowMetric::euclidean(1, 2).unwrap();
        assert!(!m.drives_gauge());
    }

    #[test]
    fn euclidean_is_not_output_fisher_like() {
        let m = RowMetric::euclidean(1, 2).unwrap();
        assert!(!m.is_output_fisher_like());
    }

    #[test]
    fn euclidean_solver_floor_is_zero() {
        let m = RowMetric::euclidean(1, 2).unwrap();
        assert_eq!(m.solver_floor(), 0.0);
    }

    #[test]
    fn euclidean_to_weight_field_is_identity() {
        let m = RowMetric::euclidean(1, 2).unwrap();
        assert!(matches!(m.to_weight_field(), WeightField::Identity));
    }

    #[test]
    fn euclidean_whiten_residual_is_passthrough() {
        let m = RowMetric::euclidean(1, 3).unwrap();
        let r = array![1.0_f64, 2.0, 3.0];
        let w = m.whiten_residual_row(0, r.view());
        assert_eq!(w, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn euclidean_factor_entry_is_identity() {
        let m = RowMetric::euclidean(1, 3).unwrap();
        assert_eq!(m.factor_entry(0, 0, 0), 1.0);
        assert_eq!(m.factor_entry(0, 1, 1), 1.0);
        assert_eq!(m.factor_entry(0, 2, 2), 1.0);
        assert_eq!(m.factor_entry(0, 0, 1), 0.0);
        assert_eq!(m.factor_entry(0, 1, 0), 0.0);
    }

    #[test]
    fn euclidean_quad_form_is_squared_norm() {
        let m = RowMetric::euclidean(1, 3).unwrap();
        let r = array![1.0_f64, 2.0, 2.0];
        assert!((m.quad_form(0, r.view()) - 9.0).abs() < 1e-14);
    }

    // ── MetricProvenance predicates ───────────────────────────────────────────

    #[test]
    fn output_fisher_drives_gauge_but_not_likelihood() {
        let u = Arc::new(array![[1.0_f64]]);
        let m = RowMetric::output_fisher(u, 1, 1).unwrap();
        assert!(m.drives_gauge());
        assert!(!m.whitens_likelihood());
        assert!(m.is_output_fisher_like());
    }

    #[test]
    fn whitened_structured_whitens_likelihood_and_drives_gauge() {
        let u = Arc::new(array![[1.0_f64]]);
        let m = RowMetric::whitened_structured(u, 1, 1).unwrap();
        assert!(m.whitens_likelihood());
        assert!(m.drives_gauge());
        assert!(!m.is_output_fisher_like());
    }

    #[test]
    fn behavioral_fisher_whitens_likelihood_and_drives_gauge() {
        // The Rung-1 deliberate GLS metric: unlike the gauge-only OutputFisher,
        // it whitens the reconstruction likelihood.
        let u = Arc::new(array![[1.0_f64, 0.5]]); // p=1, probes=2
        let m = RowMetric::behavioral_fisher(u, 1, 2).unwrap();
        assert!(m.whitens_likelihood());
        assert!(m.drives_gauge());
        assert_eq!(
            m.provenance(),
            MetricProvenance::BehavioralFisher { probes: 2 }
        );
        assert_eq!(m.metric_rank(), 2);
    }

    #[test]
    fn behavioral_fisher_quad_form_is_probe_sum() {
        // p=2, s=2 probes v1=(1,0), v2=(0,2) → G = diag(1,4);
        // e=(3,1) → eᵀGe = 9·1 + 1·4 = 13 = Σₖ (vₖᵀe)² = 3² + 2² = 13.
        // Row-major layout within each row, U[i,k]=u[i*probes+k]:
        // U[0,0]=1 U[0,1]=0 U[1,0]=0 U[1,1]=2
        let u = Arc::new(array![[1.0_f64, 0.0, 0.0, 2.0]]);
        let m = RowMetric::behavioral_fisher(u, 2, 2).unwrap();
        let e = array![3.0_f64, 1.0];
        assert!((m.quad_form(0, e.view()) - 13.0).abs() < 1e-12);
    }

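    // ── Factored quadratic form, std-only sketch ─────────────────────────────
    //
    // Illustrative only: `quad_form_sketch` is a hypothetical helper, not part
    // of the `RowMetric` API. Assuming the row-major factor layout
    // `U[i, k] = u[i * rank + k]`, it evaluates `rᵀ (U Uᵀ) r` as `‖Uᵀ r‖²` in
    // O(p · rank) work without forming the `p × p` matrix. The test
    // cross-checks it against the dense form on a non-diagonal factor.
    fn quad_form_sketch(u: &[f64], r: &[f64], rank: usize) -> f64 {
        let p = r.len();
        assert_eq!(u.len(), p * rank, "factor length must be p * rank");
        (0..rank)
            .map(|k| {
                // w_k = Σ_i U[i, k] · r[i], the k-th whitened coordinate.
                let w: f64 = (0..p).map(|i| u[i * rank + k] * r[i]).sum();
                w * w
            })
            .sum()
    }

    #[test]
    fn quad_form_sketch_matches_dense_form() {
        // p=2, rank=2, U = [[1,2],[3,4]], r = (0.5,-1).
        let (u, r, rank) = ([1.0_f64, 2.0, 3.0, 4.0], [0.5_f64, -1.0], 2);
        let p = r.len();
        // Dense reference: Σ_{i,j} r[i] · (U Uᵀ)[i, j] · r[j].
        let mut dense = 0.0;
        for i in 0..p {
            for j in 0..p {
                let m_ij: f64 = (0..rank).map(|k| u[i * rank + k] * u[j * rank + k]).sum();
                dense += r[i] * m_ij * r[j];
            }
        }
        assert!((quad_form_sketch(&u, &r, rank) - dense).abs() < 1e-12);
    }
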
    #[test]
    fn behavioral_fisher_g_identity_reproduces_euclidean_quad_form() {
        // GLS with G=I must reduce to plain MSE. Identity probes (s=p, U=I_p)
        // ⇒ M_n = I ⇒ quad_form == ‖e‖², matching Euclidean bit-for-bit, and
        // metric_rank == p so the whitened residual-dof accounting is unchanged.
        let p = 3;
        let mut u = Array2::<f64>::zeros((1, p * p));
        for i in 0..p {
            u[[0, i * p + i]] = 1.0;
        }
        let bf = RowMetric::behavioral_fisher(Arc::new(u), p, p).unwrap();
        let euc = RowMetric::euclidean(1, p).unwrap();
        let e = array![1.5_f64, -2.0, 0.25];
        assert_eq!(bf.metric_rank(), euc.metric_rank());
        assert!((bf.quad_form(0, e.view()) - euc.quad_form(0, e.view())).abs() < 1e-14);
        // and whitened residual is the residual itself (identity whitening)
        assert_eq!(bf.whiten_residual_row(0, e.view()), vec![1.5, -2.0, 0.25]);
    }

    #[test]
    fn pack_probe_factors_matches_manual_layout() {
        use ndarray::Array3;
        // n=1, p=2, s=2: probes[0,i,k] = v_k[i]; v0=(1,3), v1=(2,4)
        let mut probes = Array3::<f64>::zeros((1, 2, 2));
        probes[[0, 0, 0]] = 1.0; // v0[0]
        probes[[0, 1, 0]] = 3.0; // v0[1]
        probes[[0, 0, 1]] = 2.0; // v1[0]
        probes[[0, 1, 1]] = 4.0; // v1[1]
        let u = pack_probe_factors(probes.view()).unwrap();
        // Layout U[i,k] = u[i*s + k]: [v0[0],v1[0], v0[1],v1[1]] = [1,2,3,4]
        assert_eq!(u.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        // Round-trips into a valid metric whose G = v0 v0ᵀ + v1 v1ᵀ.
        let m = RowMetric::behavioral_fisher(Arc::new(u), 2, 2).unwrap();
        // e=(1,0): eᵀGe = v0[0]²+v1[0]² = 1+4 = 5.
        let e = array![1.0_f64, 0.0];
        assert!((m.quad_form(0, e.view()) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn pack_probe_factors_rejects_zero_probes() {
        use ndarray::Array3;
        let probes = Array3::<f64>::zeros((2, 3, 0));
        assert!(pack_probe_factors(probes.view()).is_err());
    }

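    // ── C-order repack, std-only sketch ──────────────────────────────────────
    //
    // Illustrative only: for a single row, flattening the `(p, s)` probe slab
    // in C order is the same as writing entry `(i, k)` to index `i * s + k`,
    // which is the sense in which `pack_probe_factors` is a pure repack.
    // `flatten_row_sketch` is a hypothetical helper, not crate API, and it
    // assumes every inner `Vec` has the same length `s`.
    fn flatten_row_sketch(slab: &[Vec<f64>]) -> Vec<f64> {
        slab.iter().flatten().copied().collect()
    }

    #[test]
    fn flatten_row_sketch_matches_index_formula() {
        // p=2, s=3: slab[i][k] is output coordinate i of probe k.
        let slab = vec![vec![1.0_f64, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let (p, s) = (2, 3);
        let flat = flatten_row_sketch(&slab);
        for i in 0..p {
            for k in 0..s {
                assert_eq!(flat[i * s + k], slab[i][k]);
            }
        }
    }
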
    #[test]
    fn output_fisher_downstream_is_output_fisher_like() {
        let u = Arc::new(array![[1.0_f64]]);
        let m = RowMetric::output_fisher_downstream(u, 1, 1).unwrap();
        assert!(m.is_output_fisher_like());
        assert!(m.drives_gauge());
    }

    // ── WeightField::project_jac_row_with_u ──────────────────────────────────

    #[test]
    fn project_jac_with_identity_returns_jac() {
        // p=2, rank=2, d=2; U=I_2, J=[[1,2],[3,4]] → M = U^T J = J
        let u_row = [1.0_f64, 0.0, 0.0, 1.0]; // U[i,k]=u[i*rank+k], I_2
        let j_row = [1.0_f64, 2.0, 3.0, 4.0]; // J[i,a]=j[i*d+a]
        let m = WeightField::project_jac_row_with_u(&u_row, &j_row, 2, 2, 2);
        assert!((m[[0, 0]] - 1.0).abs() < 1e-14);
        assert!((m[[0, 1]] - 2.0).abs() < 1e-14);
        assert!((m[[1, 0]] - 3.0).abs() < 1e-14);
        assert!((m[[1, 1]] - 4.0).abs() < 1e-14);
    }

    #[test]
    fn project_jac_with_zeros_returns_zero_matrix() {
        let u_row = [0.0_f64, 0.0];
        let j_row = [1.0_f64, 2.0];
        let m = WeightField::project_jac_row_with_u(&u_row, &j_row, 2, 1, 1);
        assert_eq!(m[[0, 0]], 0.0);
    }
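
    // ── Whitened-Jacobian pullback, std-only sketch ──────────────────────────
    //
    // Illustrative only: `whiten_jacobian_sketch` is a hypothetical helper that
    // mirrors the documented layouts (`U[i, k] = u[i * rank + k]`,
    // `J[i, a] = j[i * d + a]`) with plain slices. It forms `Uᵀ J` as a
    // row-major `rank × d` buffer; the test checks that its Gram matrix equals
    // the pullback `Jᵀ (U Uᵀ) J` computed through the dense `p × p` weight.
    fn whiten_jacobian_sketch(u: &[f64], j: &[f64], p: usize, rank: usize, d: usize) -> Vec<f64> {
        assert_eq!(u.len(), p * rank);
        assert_eq!(j.len(), p * d);
        let mut m = vec![0.0_f64; rank * d];
        for k in 0..rank {
            for a in 0..d {
                let acc: f64 = (0..p).map(|i| u[i * rank + k] * j[i * d + a]).sum();
                m[k * d + a] = acc;
            }
        }
        m
    }

    #[test]
    fn whiten_jacobian_sketch_gram_is_pullback() {
        let (p, rank, d) = (2, 1, 2);
        let u = [1.0_f64, 2.0]; // U = (1, 2)ᵀ, so U Uᵀ = [[1, 2], [2, 4]]
        let j = [1.0_f64, 2.0, 3.0, 4.0]; // J = [[1, 2], [3, 4]]
        let m = whiten_jacobian_sketch(&u, &j, p, rank, d);
        assert_eq!(m, vec![7.0, 10.0]); // Uᵀ J = (1·1 + 2·3, 1·2 + 2·4)
        for a in 0..d {
            for b in 0..d {
                let gram: f64 = (0..rank).map(|k| m[k * d + a] * m[k * d + b]).sum();
                // Dense reference: Σ_{i,l} J[i, a] · (U Uᵀ)[i, l] · J[l, b].
                let mut pullback = 0.0;
                for i in 0..p {
                    for l in 0..p {
                        let w_il: f64 = (0..rank).map(|k| u[i * rank + k] * u[l * rank + k]).sum();
                        pullback += j[i * d + a] * w_il * j[l * d + b];
                    }
                }
                assert!((gram - pullback).abs() < 1e-12);
            }
        }
    }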
}