Skip to main content

gam_terms/decoders/
gated_decoder.rs

1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2
3#[derive(Debug, Clone)]
4pub struct GatedSAEDecoder {
5    pub w_gate: Array2<f64>,
6    pub w_amp: Array2<f64>,
7}
8
9impl GatedSAEDecoder {
10    #[must_use = "build error must be handled"]
11    pub fn new(w_gate: Array2<f64>, w_amp: Array2<f64>) -> Result<Self, String> {
12        if w_gate.nrows() != w_gate.ncols() {
13            return Err(format!(
14                "GatedSAEDecoder::new requires square W_gate; got {:?}",
15                w_gate.dim()
16            ));
17        }
18        if w_amp.ncols() != w_gate.ncols() {
19            return Err(format!(
20                "GatedSAEDecoder::new requires W_amp columns {} to match W_gate input {}",
21                w_amp.ncols(),
22                w_gate.ncols()
23            ));
24        }
25        if !w_gate.iter().all(|v| v.is_finite()) || !w_amp.iter().all(|v| v.is_finite()) {
26            return Err("GatedSAEDecoder::new requires finite weights".to_string());
27        }
28        Ok(Self { w_gate, w_amp })
29    }
30
31    pub fn input_dim(&self) -> usize {
32        self.w_gate.ncols()
33    }
34
35    pub fn output_dim(&self) -> usize {
36        self.w_amp.nrows()
37    }
38
39    pub fn decode_row(&self, x: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
40        if x.len() != self.input_dim() {
41            return Err(format!(
42                "GatedSAEDecoder::decode_row expected x len {}, got {}",
43                self.input_dim(),
44                x.len()
45            ));
46        }
47        // Canonical Gated-SAE (Rajamanoharan et al., 2024): the gate is the
48        // Heaviside of the gating pre-activation — `gate_i = 1` iff
49        // `(W_gate x)_i > 0`, and 0 otherwise (including exact zero and any
50        // negative logit). Magnitudes flow through `W_amp` only for active
51        // coordinates; sign of the active coordinate is preserved because we
52        // multiply by `x[i]` itself, not by its magnitude.
53        let mut gated = Array1::<f64>::zeros(x.len());
54        for gate_row in 0..self.w_gate.nrows() {
55            let mut logit = 0.0;
56            for col in 0..self.w_gate.ncols() {
57                logit += self.w_gate[[gate_row, col]] * x[col];
58            }
59            if logit > 0.0 {
60                gated[gate_row] = x[gate_row];
61            }
62        }
63        Ok(self.w_amp.dot(&gated))
64    }
65
66    pub fn decode_batch(&self, x: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
67        if x.ncols() != self.input_dim() {
68            return Err(format!(
69                "GatedSAEDecoder::decode_batch expected {} columns, got {}",
70                self.input_dim(),
71                x.ncols()
72            ));
73        }
74        let mut out = Array2::<f64>::zeros((x.nrows(), self.output_dim()));
75        for row in 0..x.nrows() {
76            let decoded = self.decode_row(x.row(row))?;
77            out.row_mut(row).assign(&decoded);
78        }
79        Ok(out)
80    }
81}