gam_sae/inference/atom_lens.rs
1//! Two-score per-atom **lens** (#980, amended): an *additive* per-atom report on
2//! a fitted [`SaeManifoldTerm`](crate::manifold::SaeManifoldTerm).
3//!
4//! # The amendment this file encodes
5//!
6//! The original #980 framing folded the output-Fisher metric into the SAE
7//! *loss* — "replace the Euclidean reconstruction loss by a Fisher-pulled-back
8//! loss". That is wrong: it makes the gauge drive the fit, which silently
9//! suppresses any structure that is *represented but not currently used*, and it
10//! couples the criterion to a quantity (the output-Fisher factors) that is
11//! optional and may be absent. The corrected paradigm:
12//!
13//! * **The SAE fit stays on activations.** The reconstruction likelihood
14//! whitens through the [`RowMetric`](gam_problem::RowMetric)
15//! exactly as before; with the default Euclidean provenance that is the
16//! bit-for-bit isotropic path. The Fisher metric **never** replaces the loss.
17//! * **The lens is an additive report.** It reads the *already-fitted* model and
18//! the (optional) `RowMetric`, and emits, per atom, two orthogonal scores plus
19//! their discrepancy. Nothing it computes feeds back into any loss, criterion,
20//! penalty, or optimizer state.
21//!
22//! # The two scores
23//!
24//! For each atom `k`:
25//!
//! * **presence** (representational, activation-side, *Fisher-free*): how
//!   strongly the atom is encoded *in the activations*. It is the mean active
//!   mass over the rows where the atom is truly active (mass above
//!   `SAE_TRUST_ACTIVE_MASS_FLOOR`), times the amplitude-weighted decoder norm
//!   (the Frobenius norm of the atom's decoder coefficients). This is a pure
//!   reconstruction-side quantity: it does not touch the `RowMetric` at all, so
//!   it is identical whether or not output-Fisher factors were supplied.
//!   *Everything represented survives* — a loud-but-inert atom is just as
//!   present as a quiet load-bearing one.
//! * **coupling** (behavioral, *the only place Fisher enters*): the output-Fisher
//!   mass along the atom's decoder tangent `dg_k/dt`, summed over the atom's
//!   latent axes, weighted by each row's active mass, and averaged over the
//!   atom's active rows. It is computed through
//!   [`RowMetric::fisher_mass`](gam_problem::RowMetric::fisher_mass)
//!   — a *reported* score, never folded into a loss or criterion. Under a
//!   Euclidean / no-Fisher provenance the coupling is **not available** (`None`);
//!   the lens degrades gracefully, just as harvesting the Fisher factors is
//!   itself optional. The same holds when the metric's row or output dimensions
//!   do not match the term. It is never an error.
41//!
42//! # The headline: discrepancy
43//!
44//! `discrepancy = normalized_presence − normalized_coupling`. A high value means
45//! **high presence + low coupling**: the atom is strongly *represented* in the
46//! activations yet carries almost no behavioral mass — "represented but not
47//! currently used", i.e. *thinking it, not saying it*. That is the headline
48//! safety number this lens exists to surface. The lens *reports* it; it does not
49//! suppress the atom, because suppression would be the loss-replacement mistake
50//! the amendment removes.
51
52use ndarray::{ArrayView1, ArrayView2};
53
54use gam_problem::{MetricProvenance, RowMetric};
55use crate::manifold::SaeManifoldTerm;
56
/// Activity floor for the per-atom averages.
///
/// A row counts as "truly active" for an atom only when its assignment mass is
/// strictly greater than this value; rows at or below the floor (and rows whose
/// mass is NaN) are left out of both the presence average and the coupling
/// average. The assignment masses are convex weights in `[0, 1]`, so this floor
/// only discards rows where the atom is essentially off (numerical dust). That
/// keeps a globally-near-zero atom from reporting a spuriously large
/// amplitude-per-active-row.
pub const SAE_TRUST_ACTIVE_MASS_FLOOR: f64 = 1.0e-6;
64
/// The lens scores reported for a single atom.
#[derive(Debug, Clone, PartialEq)]
pub struct AtomLensEntry {
    /// Name of the atom (mirrors [`crate::manifold::SaeManifoldAtom::name`]).
    pub name: String,
    /// **presence** — representational, activation-side and Fisher-free: the
    /// mean active mass over the atom's truly-active rows multiplied by the
    /// amplitude-weighted decoder norm. Always populated, because it reads only
    /// the activation-side fit (`0` for an atom that is active on no row).
    pub presence: f64,
    /// **coupling** — behavioral: mean output-Fisher mass of the decoder tangent
    /// `dg_k/dt` over the atom's active rows. `None` means *not available* (not
    /// zero, not an error): the metric is Euclidean / carries no behavioral
    /// information, its dimensions do not match the term, or the atom has no
    /// active row to average over.
    pub coupling: Option<f64>,
    /// `presence` rescaled into `[0, 1]` by dividing by the largest presence in
    /// the report; `0` when every atom has zero presence.
    pub presence_normalized: f64,
    /// `coupling` rescaled into `[0, 1]` by dividing by the largest coupling in
    /// the report (`0` when that largest coupling is zero). `None` exactly when
    /// `coupling` is `None`.
    pub coupling_normalized: Option<f64>,
    /// The headline number: `presence_normalized − coupling_normalized`, the
    /// "represented but not currently used" discrepancy. A high value reads as
    /// *thinking it, not saying it*. `None` when coupling is unavailable, since
    /// there is then no behavioral axis to compare presence against.
    pub discrepancy: Option<f64>,
}

/// Smallest discrepancy that flags an atom as "represented but not currently
/// used". Presence and coupling are each normalized to `[0, 1]`, so the
/// discrepancy lives in `[-1, 1]`; reaching this value means presence outruns
/// coupling by a wide, normalized margin.
const REPRESENTED_NOT_USED_THRESHOLD: f64 = 0.5;

/// Largest discrepancy that still flags an atom as "used" (coupling matches or
/// exceeds presence). Discrepancies strictly between this and
/// [`REPRESENTED_NOT_USED_THRESHOLD`] are classified as neither.
const USED_THRESHOLD: f64 = 0.0;

impl AtomLensEntry {
    /// Reports whether the atom reads as **represented but not currently used**:
    /// strong activation presence paired with weak behavioral coupling, i.e. a
    /// discrepancy at or above the flagging threshold.
    ///
    /// This only classifies the already-computed scores; it suppresses nothing.
    /// Returns `false` when the discrepancy is `None` (no behavioral axis exists
    /// to declare a discrepancy against).
    pub fn is_represented_not_used(&self) -> bool {
        self.discrepancy
            .is_some_and(|gap| gap >= REPRESENTED_NOT_USED_THRESHOLD)
    }

    /// Reports whether the atom reads as **used**: its behavioral coupling is at
    /// least as strong as its representational presence, i.e. the discrepancy is
    /// non-positive.
    ///
    /// Returns `false` when the discrepancy is `None`.
    pub fn is_used(&self) -> bool {
        self.discrepancy.is_some_and(|gap| gap <= USED_THRESHOLD)
    }
}
126
127/// The full two-score lens over every atom of a fitted SAE-manifold term.
128#[derive(Clone, Debug, PartialEq)]
129pub struct AtomTwoLensReport {
130 /// One entry per atom, in atom order.
131 pub atoms: Vec<AtomLensEntry>,
132 /// The provenance of the metric the coupling was read through (or would have
133 /// been): `OutputFisher` / `WhitenedStructured` ⇒ coupling available;
134 /// `Euclidean` (or no metric installed) ⇒ coupling unavailable. Echoed so a
135 /// consumer can certify *why* a coupling is `None`.
136 pub coupling_provenance: Option<MetricProvenance>,
137}
138
139impl AtomTwoLensReport {
140 /// Whether the behavioral coupling axis is available at all (i.e. an
141 /// output-Fisher / structured metric was installed). When `false`, every
142 /// entry's `coupling`, `coupling_normalized`, and `discrepancy` are `None`.
143 pub fn coupling_available(&self) -> bool {
144 self.coupling_provenance
145 .is_some_and(metric_carries_behavior)
146 }
147}
148
149/// Does this provenance carry behavioral (output-Fisher) information? Euclidean
150/// does not (it is the isotropic activation-only path); the factored
151/// provenances do.
152fn metric_carries_behavior(p: MetricProvenance) -> bool {
153 match p {
154 MetricProvenance::Euclidean => false,
155 MetricProvenance::OutputFisher { .. }
156 | MetricProvenance::OutputFisherDownstream { .. }
157 | MetricProvenance::WhitenedStructured { .. } => true,
158 }
159}
160
161/// Build the two-score per-atom lens over a fitted [`SaeManifoldTerm`].
162///
163/// `model` is the fitted term (read only). `metric` is the per-row inner product
164/// the coupling is measured through; pass the model's own installed metric
165/// ([`SaeManifoldTerm::row_metric`]) or any metric whose row/output dimensions
166/// match the term. When the metric's provenance is Euclidean (no behavioral
167/// information), the coupling degrades to `None` for every atom — the lens stays
168/// available, only its behavioral axis is absent.
169///
170/// This function is a *pure read*: it never mutates the model, never touches a
171/// loss / criterion / penalty, and the only place the Fisher metric enters is the
172/// [`RowMetric::fisher_mass`] call that produces the (reported) coupling score.
173pub fn atom_two_lens(
174 model: &SaeManifoldTerm,
175 metric: &RowMetric,
176 assignments_override: Option<ArrayView2<'_, f64>>,
177) -> AtomTwoLensReport {
178 let n = model.n_obs();
179 let k = model.k_atoms();
180 let provenance = metric.provenance();
181 // Coupling is only meaningful when the metric carries behavioral
182 // information *and* its dimensions match the term. A mismatched metric (or a
183 // Euclidean one) degrades the behavioral axis to "not available" rather than
184 // erroring — the lens is optional, mirroring the harvest being optional.
185 let coupling_axis_available = metric_carries_behavior(provenance)
186 && metric.n_rows() == n
187 && metric.p_out() == model.output_dim();
188
189 // Per-row assignment masses, computed once. When a hard top-k projection has
190 // been applied (#1232), the caller supplies the projected matrix so the lens
191 // matches the returned payload rather than the smooth optimization assignments.
192 let assignments_owned;
193 let assignments = match assignments_override {
194 Some(view) => view,
195 None => {
196 assignments_owned = model.assignment.assignments();
197 assignments_owned.view()
198 }
199 };
200
201 let mut presence = vec![0.0_f64; k];
202 let mut coupling_raw = vec![0.0_f64; k];
203 let mut any_coupling = vec![false; k];
204
205 for (atom_idx, atom) in model.atoms.iter().enumerate() {
206 // Amplitude-weighted decoder norm: ‖B_k‖_F. The decoder coefficients
207 // B_k ∈ ℝ^{M_k × p} are the linear map from basis activations to the
208 // reconstruction output, so their Frobenius norm is the per-atom output
209 // amplitude per unit of basis activation — the "how loud is this atom in
210 // the reconstruction" factor of presence. Pure activation-side: no
211 // metric is consulted.
212 let decoder_norm = atom
213 .decoder_coefficients
214 .iter()
215 .map(|&b| b * b)
216 .sum::<f64>()
217 .sqrt();
218
219 let latent_dim = atom.latent_dim;
220
221 let mut active_mass_sum = 0.0_f64;
222 let mut active_row_count = 0.0_f64;
223 let mut coupling_sum = 0.0_f64;
224
225 for row in 0..n {
226 let mass = assignments[[row, atom_idx]];
227 if !(mass > SAE_TRUST_ACTIVE_MASS_FLOOR) {
228 continue;
229 }
230 active_mass_sum += mass;
231 active_row_count += 1.0;
232
233 if coupling_axis_available {
234 // Behavioral coupling on this active row: the output-Fisher mass
235 // of the decoder tangent dg_k/dt summed over the atom's latent
236 // axes, weighted by the active mass (so a barely-active row
237 // contributes proportionally less behavioral evidence, matching
238 // the presence weighting). This is the ONLY place the Fisher
239 // metric enters; `fisher_mass` reads no loss / criterion.
240 let mut row_tangent_mass = 0.0_f64;
241 for axis in 0..latent_dim {
242 let dg = atom.decoded_derivative_row(row, axis);
243 let dg_view: ArrayView1<'_, f64> = dg.view();
244 row_tangent_mass += metric.fisher_mass(row, dg_view);
245 }
246 coupling_sum += mass * row_tangent_mass;
247 any_coupling[atom_idx] = true;
248 }
249 }
250
251 // Mean active mass on truly-active rows (0 if the atom is active nowhere).
252 let mean_active_mass = if active_row_count > 0.0 {
253 active_mass_sum / active_row_count
254 } else {
255 0.0
256 };
257 presence[atom_idx] = mean_active_mass * decoder_norm;
258
259 // Mean behavioral coupling over the atom's active rows.
260 if coupling_axis_available && active_row_count > 0.0 {
261 coupling_raw[atom_idx] = coupling_sum / active_row_count;
262 }
263 }
264
265 // Normalize presence across atoms (divide by the max; 0 when all zero).
266 let presence_max = presence.iter().copied().fold(0.0_f64, f64::max);
267 // Normalize coupling across atoms, only over atoms with an available score.
268 let coupling_max = coupling_raw
269 .iter()
270 .zip(any_coupling.iter())
271 .filter(|&(_, &has)| has)
272 .map(|(&c, _)| c)
273 .fold(0.0_f64, f64::max);
274
275 let mut entries = Vec::with_capacity(k);
276 for (atom_idx, atom) in model.atoms.iter().enumerate() {
277 let p = presence[atom_idx];
278 let presence_normalized = if presence_max > 0.0 {
279 p / presence_max
280 } else {
281 0.0
282 };
283
284 let (coupling, coupling_normalized, discrepancy) =
285 if coupling_axis_available && any_coupling[atom_idx] {
286 let c = coupling_raw[atom_idx];
287 let c_norm = if coupling_max > 0.0 {
288 c / coupling_max
289 } else {
290 0.0
291 };
292 (Some(c), Some(c_norm), Some(presence_normalized - c_norm))
293 } else {
294 (None, None, None)
295 };
296
297 entries.push(AtomLensEntry {
298 name: atom.name.clone(),
299 presence: p,
300 coupling,
301 presence_normalized,
302 coupling_normalized,
303 discrepancy,
304 });
305 }
306
307 AtomTwoLensReport {
308 atoms: entries,
309 coupling_provenance: Some(provenance),
310 }
311}