Skip to main content

gam_sae/inference/
atom_lens.rs

1//! Two-score per-atom **lens** (#980, amended): an *additive* per-atom report on
2//! a fitted [`SaeManifoldTerm`](crate::manifold::SaeManifoldTerm).
3//!
4//! # The amendment this file encodes
5//!
6//! The original #980 framing folded the output-Fisher metric into the SAE
7//! *loss* — "replace the Euclidean reconstruction loss by a Fisher-pulled-back
8//! loss". That is wrong: it makes the gauge drive the fit, which silently
9//! suppresses any structure that is *represented but not currently used*, and it
10//! couples the criterion to a quantity (the output-Fisher factors) that is
11//! optional and may be absent. The corrected paradigm:
12//!
13//! * **The SAE fit stays on activations.** The reconstruction likelihood
14//!   whitens through the [`RowMetric`](gam_problem::RowMetric)
15//!   exactly as before; with the default Euclidean provenance that is the
16//!   bit-for-bit isotropic path. The Fisher metric **never** replaces the loss.
17//! * **The lens is an additive report.** It reads the *already-fitted* model and
18//!   the (optional) `RowMetric`, and emits, per atom, two orthogonal scores plus
19//!   their discrepancy. Nothing it computes feeds back into any loss, criterion,
20//!   penalty, or optimizer state.
21//!
22//! # The two scores
23//!
24//! For each atom `k`:
25//!
26//! * **presence** (representational, activation-side, *Fisher-free*): how
27//!   strongly the atom is encoded *in the activations*. Mean active mass on the
28//!   rows where the atom is truly active, times an amplitude-weighted decoder
29//!   norm. This is a pure reconstruction-side quantity: it does not touch the
30//!   `RowMetric` at all, so it is identical whether or not output-Fisher factors
31//!   were supplied. *Everything represented survives* — a loud-but-inert atom is
32//!   just as present as a quiet load-bearing one.
//! * **coupling** (behavioral, *the only place Fisher enters*): the output-Fisher
//!   mass along the atom's decoder tangent `dg_k/dt`, summed over the atom's
//!   latent axes, weighted by the row's active mass, and averaged over the
//!   atom's active rows. It is computed through
//!   [`RowMetric::fisher_mass`](gam_problem::RowMetric::fisher_mass)
//!   and is a *reported* score, never folded into a loss or criterion. Under a
//!   Euclidean / no-Fisher provenance the coupling is **not available** (`None`):
//!   the lens degrades gracefully, just as harvesting the Fisher factors is
//!   itself optional. A missing coupling is never an error.
41//!
42//! # The headline: discrepancy
43//!
44//! `discrepancy = normalized_presence − normalized_coupling`. A high value means
45//! **high presence + low coupling**: the atom is strongly *represented* in the
46//! activations yet carries almost no behavioral mass — "represented but not
47//! currently used", i.e. *thinking it, not saying it*. That is the headline
48//! safety number this lens exists to surface. The lens *reports* it; it does not
49//! suppress the atom, because suppression would be the loss-replacement mistake
50//! the amendment removes.
51
52use ndarray::{ArrayView1, ArrayView2};
53
54use gam_problem::{MetricProvenance, RowMetric};
55use crate::manifold::SaeManifoldTerm;
56
/// Active-mass floor for the per-atom averages.
///
/// A row counts as "truly active" for an atom only when its assignment mass is
/// *strictly greater* than this value; a row at or below the floor contributes
/// to neither the presence average nor the coupling average. The assignment
/// masses are convex weights in `[0, 1]`; this floor excludes rows where the
/// atom is essentially off (numerical dust) from the per-atom averages, so a
/// globally-near-zero atom does not get a spuriously large
/// amplitude-per-active-row.
pub const SAE_TRUST_ACTIVE_MASS_FLOOR: f64 = 1e-6;
64
/// One atom's lens entry: the two scores, their normalized forms, and the
/// discrepancy between them.
#[derive(Clone, Debug, PartialEq)]
pub struct AtomLensEntry {
    /// The atom's name (mirrors [`crate::manifold::SaeManifoldAtom::name`]).
    pub name: String,
    /// **presence** (representational, activation-side, Fisher-free): mean active
    /// mass on truly-active rows × the Frobenius norm of the atom's decoder
    /// coefficients. Always available — it reads only the activation-side fit.
    /// It is `0` for an atom that is active on no row.
    pub presence: f64,
    /// **coupling** (behavioral): the mass-weighted output-Fisher mass of the
    /// decoder tangent `dg_k/dt`, averaged over the atom's active rows.
    ///
    /// `None` means *not available* — not zero, not an error. As built by
    /// [`atom_two_lens`] that happens when the metric's provenance is Euclidean
    /// (it carries no behavioral information), when the metric's row/output
    /// dimensions do not match the term, or when the atom has no active row.
    pub coupling: Option<f64>,
    /// **presence** normalized to `[0, 1]` across the report's atoms (divided by
    /// the max presence; `0` if every atom has zero presence).
    pub presence_normalized: f64,
    /// **coupling** normalized across the report's atoms (divided by the max
    /// available coupling; `0` if that max is not positive). `None` whenever
    /// coupling itself is unavailable.
    pub coupling_normalized: Option<f64>,
    /// The headline: `presence_normalized − coupling_normalized`, the
    /// "represented but not currently used" discrepancy. High ⇒ thinking it, not
    /// saying it. `None` when coupling is unavailable (no behavioral axis to
    /// compare presence against).
    pub discrepancy: Option<f64>,
}
91
92impl AtomLensEntry {
93    /// Whether this atom reads as **represented but not currently used** —
94    /// strong activation presence, weak behavioral coupling. Pure classification
95    /// of the already-computed scores; it suppresses nothing.
96    ///
97    /// Returns `false` when coupling is unavailable (no behavioral axis exists to
98    /// declare a discrepancy against).
99    pub fn is_represented_not_used(&self) -> bool {
100        match self.discrepancy {
101            Some(d) => d >= REPRESENTED_NOT_USED_THRESHOLD,
102            None => false,
103        }
104    }
105
106    /// Whether this atom reads as **used** — its behavioral coupling is at least
107    /// as strong as its representational presence (non-positive discrepancy).
108    /// Returns `false` when coupling is unavailable.
109    pub fn is_used(&self) -> bool {
110        match self.discrepancy {
111            Some(d) => d <= USED_THRESHOLD,
112            None => false,
113        }
114    }
115}
116
/// Discrepancy at or above this flags "represented but not currently used"
/// (the comparison in `AtomLensEntry::is_represented_not_used` is inclusive).
/// Presence and coupling are each normalized to `[0, 1]`, so the discrepancy
/// lives in `[-1, 1]`; a value this large means presence outruns coupling by a
/// wide, normalized margin.
const REPRESENTED_NOT_USED_THRESHOLD: f64 = 0.5;
122
/// Discrepancy at or below this flags "used": normalized coupling matches or
/// exceeds normalized presence (the comparison in `AtomLensEntry::is_used` is
/// inclusive).
const USED_THRESHOLD: f64 = 0.0;
126
/// The full two-score lens over every atom of a fitted SAE-manifold term.
#[derive(Clone, Debug, PartialEq)]
pub struct AtomTwoLensReport {
    /// One entry per atom, in atom order.
    pub atoms: Vec<AtomLensEntry>,
    /// The provenance of the metric the coupling was read through (or would have
    /// been): `OutputFisher` / `OutputFisherDownstream` / `WhitenedStructured`
    /// carry behavioral information; `Euclidean` (or no metric installed) ⇒
    /// coupling unavailable. Echoed so a consumer can certify *why* a coupling
    /// is `None`.
    ///
    /// Note: [`atom_two_lens`] always stores `Some(metric.provenance())` here,
    /// including when it disabled the coupling axis because the metric's
    /// dimensions did not match the term — so a behavior-carrying provenance
    /// alone does not guarantee that the entries carry a coupling.
    pub coupling_provenance: Option<MetricProvenance>,
}
138
139impl AtomTwoLensReport {
140    /// Whether the behavioral coupling axis is available at all (i.e. an
141    /// output-Fisher / structured metric was installed). When `false`, every
142    /// entry's `coupling`, `coupling_normalized`, and `discrepancy` are `None`.
143    pub fn coupling_available(&self) -> bool {
144        self.coupling_provenance
145            .is_some_and(metric_carries_behavior)
146    }
147}
148
149/// Does this provenance carry behavioral (output-Fisher) information? Euclidean
150/// does not (it is the isotropic activation-only path); the factored
151/// provenances do.
152fn metric_carries_behavior(p: MetricProvenance) -> bool {
153    match p {
154        MetricProvenance::Euclidean => false,
155        MetricProvenance::OutputFisher { .. }
156        | MetricProvenance::OutputFisherDownstream { .. }
157        | MetricProvenance::WhitenedStructured { .. } => true,
158    }
159}
160
161/// Build the two-score per-atom lens over a fitted [`SaeManifoldTerm`].
162///
163/// `model` is the fitted term (read only). `metric` is the per-row inner product
164/// the coupling is measured through; pass the model's own installed metric
165/// ([`SaeManifoldTerm::row_metric`]) or any metric whose row/output dimensions
166/// match the term. When the metric's provenance is Euclidean (no behavioral
167/// information), the coupling degrades to `None` for every atom — the lens stays
168/// available, only its behavioral axis is absent.
169///
170/// This function is a *pure read*: it never mutates the model, never touches a
171/// loss / criterion / penalty, and the only place the Fisher metric enters is the
172/// [`RowMetric::fisher_mass`] call that produces the (reported) coupling score.
173pub fn atom_two_lens(
174    model: &SaeManifoldTerm,
175    metric: &RowMetric,
176    assignments_override: Option<ArrayView2<'_, f64>>,
177) -> AtomTwoLensReport {
178    let n = model.n_obs();
179    let k = model.k_atoms();
180    let provenance = metric.provenance();
181    // Coupling is only meaningful when the metric carries behavioral
182    // information *and* its dimensions match the term. A mismatched metric (or a
183    // Euclidean one) degrades the behavioral axis to "not available" rather than
184    // erroring — the lens is optional, mirroring the harvest being optional.
185    let coupling_axis_available = metric_carries_behavior(provenance)
186        && metric.n_rows() == n
187        && metric.p_out() == model.output_dim();
188
189    // Per-row assignment masses, computed once. When a hard top-k projection has
190    // been applied (#1232), the caller supplies the projected matrix so the lens
191    // matches the returned payload rather than the smooth optimization assignments.
192    let assignments_owned;
193    let assignments = match assignments_override {
194        Some(view) => view,
195        None => {
196            assignments_owned = model.assignment.assignments();
197            assignments_owned.view()
198        }
199    };
200
201    let mut presence = vec![0.0_f64; k];
202    let mut coupling_raw = vec![0.0_f64; k];
203    let mut any_coupling = vec![false; k];
204
205    for (atom_idx, atom) in model.atoms.iter().enumerate() {
206        // Amplitude-weighted decoder norm: ‖B_k‖_F. The decoder coefficients
207        // B_k ∈ ℝ^{M_k × p} are the linear map from basis activations to the
208        // reconstruction output, so their Frobenius norm is the per-atom output
209        // amplitude per unit of basis activation — the "how loud is this atom in
210        // the reconstruction" factor of presence. Pure activation-side: no
211        // metric is consulted.
212        let decoder_norm = atom
213            .decoder_coefficients
214            .iter()
215            .map(|&b| b * b)
216            .sum::<f64>()
217            .sqrt();
218
219        let latent_dim = atom.latent_dim;
220
221        let mut active_mass_sum = 0.0_f64;
222        let mut active_row_count = 0.0_f64;
223        let mut coupling_sum = 0.0_f64;
224
225        for row in 0..n {
226            let mass = assignments[[row, atom_idx]];
227            if !(mass > SAE_TRUST_ACTIVE_MASS_FLOOR) {
228                continue;
229            }
230            active_mass_sum += mass;
231            active_row_count += 1.0;
232
233            if coupling_axis_available {
234                // Behavioral coupling on this active row: the output-Fisher mass
235                // of the decoder tangent dg_k/dt summed over the atom's latent
236                // axes, weighted by the active mass (so a barely-active row
237                // contributes proportionally less behavioral evidence, matching
238                // the presence weighting). This is the ONLY place the Fisher
239                // metric enters; `fisher_mass` reads no loss / criterion.
240                let mut row_tangent_mass = 0.0_f64;
241                for axis in 0..latent_dim {
242                    let dg = atom.decoded_derivative_row(row, axis);
243                    let dg_view: ArrayView1<'_, f64> = dg.view();
244                    row_tangent_mass += metric.fisher_mass(row, dg_view);
245                }
246                coupling_sum += mass * row_tangent_mass;
247                any_coupling[atom_idx] = true;
248            }
249        }
250
251        // Mean active mass on truly-active rows (0 if the atom is active nowhere).
252        let mean_active_mass = if active_row_count > 0.0 {
253            active_mass_sum / active_row_count
254        } else {
255            0.0
256        };
257        presence[atom_idx] = mean_active_mass * decoder_norm;
258
259        // Mean behavioral coupling over the atom's active rows.
260        if coupling_axis_available && active_row_count > 0.0 {
261            coupling_raw[atom_idx] = coupling_sum / active_row_count;
262        }
263    }
264
265    // Normalize presence across atoms (divide by the max; 0 when all zero).
266    let presence_max = presence.iter().copied().fold(0.0_f64, f64::max);
267    // Normalize coupling across atoms, only over atoms with an available score.
268    let coupling_max = coupling_raw
269        .iter()
270        .zip(any_coupling.iter())
271        .filter(|&(_, &has)| has)
272        .map(|(&c, _)| c)
273        .fold(0.0_f64, f64::max);
274
275    let mut entries = Vec::with_capacity(k);
276    for (atom_idx, atom) in model.atoms.iter().enumerate() {
277        let p = presence[atom_idx];
278        let presence_normalized = if presence_max > 0.0 {
279            p / presence_max
280        } else {
281            0.0
282        };
283
284        let (coupling, coupling_normalized, discrepancy) =
285            if coupling_axis_available && any_coupling[atom_idx] {
286                let c = coupling_raw[atom_idx];
287                let c_norm = if coupling_max > 0.0 {
288                    c / coupling_max
289                } else {
290                    0.0
291                };
292                (Some(c), Some(c_norm), Some(presence_normalized - c_norm))
293            } else {
294                (None, None, None)
295            };
296
297        entries.push(AtomLensEntry {
298            name: atom.name.clone(),
299            presence: p,
300            coupling,
301            presence_normalized,
302            coupling_normalized,
303            discrepancy,
304        });
305    }
306
307    AtomTwoLensReport {
308        atoms: entries,
309        coupling_provenance: Some(provenance),
310    }
311}