gam_terms/decoders/
gated_decoder.rs1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2
3#[derive(Debug, Clone)]
4pub struct GatedSAEDecoder {
5 pub w_gate: Array2<f64>,
6 pub w_amp: Array2<f64>,
7}
8
9impl GatedSAEDecoder {
10 #[must_use = "build error must be handled"]
11 pub fn new(w_gate: Array2<f64>, w_amp: Array2<f64>) -> Result<Self, String> {
12 if w_gate.nrows() != w_gate.ncols() {
13 return Err(format!(
14 "GatedSAEDecoder::new requires square W_gate; got {:?}",
15 w_gate.dim()
16 ));
17 }
18 if w_amp.ncols() != w_gate.ncols() {
19 return Err(format!(
20 "GatedSAEDecoder::new requires W_amp columns {} to match W_gate input {}",
21 w_amp.ncols(),
22 w_gate.ncols()
23 ));
24 }
25 if !w_gate.iter().all(|v| v.is_finite()) || !w_amp.iter().all(|v| v.is_finite()) {
26 return Err("GatedSAEDecoder::new requires finite weights".to_string());
27 }
28 Ok(Self { w_gate, w_amp })
29 }
30
31 pub fn input_dim(&self) -> usize {
32 self.w_gate.ncols()
33 }
34
35 pub fn output_dim(&self) -> usize {
36 self.w_amp.nrows()
37 }
38
39 pub fn decode_row(&self, x: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
40 if x.len() != self.input_dim() {
41 return Err(format!(
42 "GatedSAEDecoder::decode_row expected x len {}, got {}",
43 self.input_dim(),
44 x.len()
45 ));
46 }
47 let mut gated = Array1::<f64>::zeros(x.len());
54 for gate_row in 0..self.w_gate.nrows() {
55 let mut logit = 0.0;
56 for col in 0..self.w_gate.ncols() {
57 logit += self.w_gate[[gate_row, col]] * x[col];
58 }
59 if logit > 0.0 {
60 gated[gate_row] = x[gate_row];
61 }
62 }
63 Ok(self.w_amp.dot(&gated))
64 }
65
66 pub fn decode_batch(&self, x: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
67 if x.ncols() != self.input_dim() {
68 return Err(format!(
69 "GatedSAEDecoder::decode_batch expected {} columns, got {}",
70 self.input_dim(),
71 x.ncols()
72 ));
73 }
74 let mut out = Array2::<f64>::zeros((x.nrows(), self.output_dim()));
75 for row in 0..x.nrows() {
76 let decoded = self.decode_row(x.row(row))?;
77 out.row_mut(row).assign(&decoded);
78 }
79 Ok(out)
80 }
81}