Skip to main content

gam_terms/decoders/
interchange_decoder.rs

1//! Per-feature scalar-gate decoder with masked interchange-swap variant.
2//!
3//! This primitive is not specific to any one front-end. It is callable from
4//! the `gam` Rust library directly, from the CLI (whenever a decoder
5//! interchange-intervention probe is needed), and from PyTorch via the
6//! `gam-pyffi` bindings. The intended use is *Distributed Alignment Search*
7//! (DAS, Geiger et al. CLeaR 2024): given two inputs `a` and `b`, transplant
8//! the latent atoms hypothesized to encode a causal variable from `a` into
9//! `b`, decode with shared reconstruction weights and a shared per-feature
10//! scalar gate, and back-propagate a swap-reconstruction error against a
11//! target. The closed-form forward and analytic gradients live here so the
12//! exact same arithmetic is used by every caller.
13//!
14//! Forward
15//! -------
16//! With latent `Z ∈ ℝ^{B×F}`, scalar gate `g ∈ ℝ^F`, decoder weights
17//! `W ∈ ℝ^{D×F}`, and optional bias `b ∈ ℝ^D`,
18//!
19//!     X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d]
20//!
21//! Masked interchange-swap forward composes the latent first,
22//!
23//!     Z_eff[i, f] = mask[f] ? Z_a[i, f] : Z_b[i, f],
24//!
25//! then runs the plain decode on `Z_eff`. The gate `g` and the weights `W`
26//! are SHARED between the two source decodings — only the latent activations
27//! are interchanged. The scalar gate is decoupled from the reconstruction
28//! matrix on purpose: that decoupling is what gives DAS a parameter to
29//! transplant.
30//!
31//! Backward
32//! --------
33//! From upstream `Ȳ = ∂L/∂X̂ ∈ ℝ^{B×D}`,
34//!
35//!     ∂L/∂Z[i, f] = g[f] · Σ_d Ȳ[i, d] · W[d, f]
36//!     ∂L/∂g[f]   = Σ_i Z[i, f] · Σ_d Ȳ[i, d] · W[d, f]
37//!     ∂L/∂W[d, f] = g[f] · Σ_i Ȳ[i, d] · Z[i, f]
38//!     ∂L/∂b[d]   = Σ_i Ȳ[i, d]
39//!
40//! For the masked-swap path, `∂L/∂Z_a` keeps the columns where `mask[f]`
41//! is true (the rest are zero) and `∂L/∂Z_b` keeps the columns where
42//! `mask[f]` is false. All other adjoints (`∂L/∂g`, `∂L/∂W`, `∂L/∂b`)
43//! are computed from the composed `Z_eff` exactly as in the plain case.
44
45use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
46
47/// Inputs to the plain (non-swap) gated decode forward.
48#[derive(Debug, Clone, Copy)]
49pub struct InterchangeDecodeForward<'a> {
50    pub z: ArrayView2<'a, f64>,
51    pub weights: ArrayView2<'a, f64>,
52    pub gate: ArrayView1<'a, f64>,
53    pub bias: Option<ArrayView1<'a, f64>>,
54}
55
56/// Inputs to the masked-swap forward.
57#[derive(Debug, Clone, Copy)]
58pub struct InterchangeSwapForward<'a> {
59    pub z_a: ArrayView2<'a, f64>,
60    pub z_b: ArrayView2<'a, f64>,
61    pub mask: ArrayView1<'a, bool>,
62    pub weights: ArrayView2<'a, f64>,
63    pub gate: ArrayView1<'a, f64>,
64    pub bias: Option<ArrayView1<'a, f64>>,
65}
66
67/// Adjoints returned by the plain backward.
68#[derive(Debug, Clone)]
69pub struct InterchangeDecodeBackward {
70    pub grad_z: Array2<f64>,
71    pub grad_weights: Array2<f64>,
72    pub grad_gate: Array1<f64>,
73    pub grad_bias: Option<Array1<f64>>,
74}
75
76/// Adjoints returned by the masked-swap backward.
77#[derive(Debug, Clone)]
78pub struct InterchangeSwapBackward {
79    pub grad_z_a: Array2<f64>,
80    pub grad_z_b: Array2<f64>,
81    pub grad_weights: Array2<f64>,
82    pub grad_gate: Array1<f64>,
83    pub grad_bias: Option<Array1<f64>>,
84}
85
86fn check_shapes_forward(
87    z_rows: usize,
88    z_cols: usize,
89    weights: ArrayView2<'_, f64>,
90    gate: ArrayView1<'_, f64>,
91    bias: Option<ArrayView1<'_, f64>>,
92) -> Result<(), String> {
93    let (d, f_weights) = weights.dim();
94    if f_weights != z_cols {
95        return Err(format!(
96            "interchange_decode: weights has F={f_weights}, expected {z_cols}"
97        ));
98    }
99    if gate.len() != z_cols {
100        return Err(format!(
101            "interchange_decode: gate has length {}, expected {z_cols}",
102            gate.len()
103        ));
104    }
105    if let Some(b) = bias
106        && b.len() != d
107    {
108        return Err(format!(
109            "interchange_decode: bias has length {}, expected D={d}",
110            b.len()
111        ));
112    }
113    if z_rows == 0 || z_cols == 0 {
114        return Err("interchange_decode: latent must be non-empty".to_string());
115    }
116    if !weights.iter().all(|v| v.is_finite()) {
117        return Err("interchange_decode: weights must be finite".to_string());
118    }
119    if !gate.iter().all(|v| v.is_finite()) {
120        return Err("interchange_decode: gate must be finite".to_string());
121    }
122    if let Some(b) = bias
123        && !b.iter().all(|v| v.is_finite())
124    {
125        return Err("interchange_decode: bias must be finite".to_string());
126    }
127    Ok(())
128}
129
130/// Plain gated decode: `X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d]`.
131pub fn interchange_decode_forward(
132    inputs: InterchangeDecodeForward<'_>,
133) -> Result<Array2<f64>, String> {
134    let (b_rows, f) = inputs.z.dim();
135    check_shapes_forward(b_rows, f, inputs.weights, inputs.gate, inputs.bias)?;
136    if !inputs.z.iter().all(|v| v.is_finite()) {
137        return Err("interchange_decode: latent must be finite".to_string());
138    }
139
140    let d = inputs.weights.nrows();
141    let mut z_gated = Array2::<f64>::zeros((b_rows, f));
142    for i in 0..b_rows {
143        for j in 0..f {
144            z_gated[[i, j]] = inputs.z[[i, j]] * inputs.gate[j];
145        }
146    }
147    // out = z_gated · Wᵀ
148    let mut out = z_gated.dot(&inputs.weights.t());
149    if let Some(bias) = inputs.bias {
150        for i in 0..b_rows {
151            for k in 0..d {
152                out[[i, k]] += bias[k];
153            }
154        }
155    }
156    Ok(out)
157}
158
159/// Masked-swap forward.
160pub fn interchange_swap_forward(inputs: InterchangeSwapForward<'_>) -> Result<Array2<f64>, String> {
161    if inputs.z_a.dim() != inputs.z_b.dim() {
162        return Err(format!(
163            "interchange_swap: z_a {:?} and z_b {:?} must have the same shape",
164            inputs.z_a.dim(),
165            inputs.z_b.dim()
166        ));
167    }
168    let (b_rows, f) = inputs.z_a.dim();
169    if inputs.mask.len() != f {
170        return Err(format!(
171            "interchange_swap: mask length {} must equal F={f}",
172            inputs.mask.len()
173        ));
174    }
175    if !inputs.z_a.iter().all(|v| v.is_finite()) || !inputs.z_b.iter().all(|v| v.is_finite()) {
176        return Err("interchange_swap: latents must be finite".to_string());
177    }
178    let mut z_eff = Array2::<f64>::zeros((b_rows, f));
179    for j in 0..f {
180        let take_a = inputs.mask[j];
181        if take_a {
182            for i in 0..b_rows {
183                z_eff[[i, j]] = inputs.z_a[[i, j]];
184            }
185        } else {
186            for i in 0..b_rows {
187                z_eff[[i, j]] = inputs.z_b[[i, j]];
188            }
189        }
190    }
191    interchange_decode_forward(InterchangeDecodeForward {
192        z: z_eff.view(),
193        weights: inputs.weights,
194        gate: inputs.gate,
195        bias: inputs.bias,
196    })
197}
198
199/// Backward for the plain decode. `grad_out` is `∂L/∂X̂`.
200pub fn interchange_decode_backward(
201    z: ArrayView2<'_, f64>,
202    weights: ArrayView2<'_, f64>,
203    gate: ArrayView1<'_, f64>,
204    grad_out: ArrayView2<'_, f64>,
205    with_bias: bool,
206) -> Result<InterchangeDecodeBackward, String> {
207    let (b_rows, f) = z.dim();
208    let (d, f_w) = weights.dim();
209    if f_w != f {
210        return Err(format!(
211            "interchange_decode_backward: weights has F={f_w}, expected {f}"
212        ));
213    }
214    if gate.len() != f {
215        return Err(format!(
216            "interchange_decode_backward: gate length {} != F={f}",
217            gate.len()
218        ));
219    }
220    if grad_out.dim() != (b_rows, d) {
221        return Err(format!(
222            "interchange_decode_backward: grad_out shape {:?} != ({b_rows}, {d})",
223            grad_out.dim()
224        ));
225    }
226
227    // Working term: G[i, f] = Σ_d grad_out[i, d] · W[d, f]   ( = grad_out · W )
228    let g_mat = grad_out.dot(&weights); // (B, F)
229
230    // ∂L/∂Z[i, f] = g[f] · G[i, f]
231    let mut grad_z = Array2::<f64>::zeros((b_rows, f));
232    for i in 0..b_rows {
233        for j in 0..f {
234            grad_z[[i, j]] = gate[j] * g_mat[[i, j]];
235        }
236    }
237
238    // ∂L/∂g[f] = Σ_i Z[i, f] · G[i, f]
239    let mut grad_gate = Array1::<f64>::zeros(f);
240    for j in 0..f {
241        let mut acc = 0.0;
242        for i in 0..b_rows {
243            acc += z[[i, j]] * g_mat[[i, j]];
244        }
245        grad_gate[j] = acc;
246    }
247
248    // ∂L/∂W[d, f] = g[f] · Σ_i grad_out[i, d] · Z[i, f]
249    //             = g[f] · (grad_outᵀ · Z)[d, f]
250    let mut grad_weights = grad_out.t().dot(&z); // (D, F)
251    for j in 0..f {
252        let scale = gate[j];
253        for k in 0..d {
254            grad_weights[[k, j]] *= scale;
255        }
256    }
257
258    let grad_bias = if with_bias {
259        let mut gb = Array1::<f64>::zeros(d);
260        for i in 0..b_rows {
261            for k in 0..d {
262                gb[k] += grad_out[[i, k]];
263            }
264        }
265        Some(gb)
266    } else {
267        None
268    };
269
270    Ok(InterchangeDecodeBackward {
271        grad_z,
272        grad_weights,
273        grad_gate,
274        grad_bias,
275    })
276}
277
278/// Backward for the masked-swap variant.
279pub fn interchange_swap_backward(
280    z_a: ArrayView2<'_, f64>,
281    z_b: ArrayView2<'_, f64>,
282    mask: ArrayView1<'_, bool>,
283    weights: ArrayView2<'_, f64>,
284    gate: ArrayView1<'_, f64>,
285    grad_out: ArrayView2<'_, f64>,
286    with_bias: bool,
287) -> Result<InterchangeSwapBackward, String> {
288    if z_a.dim() != z_b.dim() {
289        return Err(format!(
290            "interchange_swap_backward: z_a {:?} and z_b {:?} must have the same shape",
291            z_a.dim(),
292            z_b.dim()
293        ));
294    }
295    let (b_rows, f) = z_a.dim();
296    if mask.len() != f {
297        return Err(format!(
298            "interchange_swap_backward: mask length {} != F={f}",
299            mask.len()
300        ));
301    }
302
303    // Build z_eff and reuse the plain backward.
304    let mut z_eff = Array2::<f64>::zeros((b_rows, f));
305    for j in 0..f {
306        let take_a = mask[j];
307        if take_a {
308            for i in 0..b_rows {
309                z_eff[[i, j]] = z_a[[i, j]];
310            }
311        } else {
312            for i in 0..b_rows {
313                z_eff[[i, j]] = z_b[[i, j]];
314            }
315        }
316    }
317    let inner = interchange_decode_backward(z_eff.view(), weights, gate, grad_out, with_bias)?;
318
319    // Distribute ∂L/∂Z_eff to ∂L/∂Z_a / ∂L/∂Z_b along the mask.
320    let mut grad_z_a = Array2::<f64>::zeros((b_rows, f));
321    let mut grad_z_b = Array2::<f64>::zeros((b_rows, f));
322    for j in 0..f {
323        let take_a = mask[j];
324        if take_a {
325            for i in 0..b_rows {
326                grad_z_a[[i, j]] = inner.grad_z[[i, j]];
327            }
328        } else {
329            for i in 0..b_rows {
330                grad_z_b[[i, j]] = inner.grad_z[[i, j]];
331            }
332        }
333    }
334
335    Ok(InterchangeSwapBackward {
336        grad_z_a,
337        grad_z_b,
338        grad_weights: inner.grad_weights,
339        grad_gate: inner.grad_gate,
340        grad_bias: inner.grad_bias,
341    })
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Tolerance for comparisons against values recomputed by hand.
    const EXACT_TOL: f64 = 1e-12;
    /// Step used by the central finite differences.
    const FD_STEP: f64 = 1e-6;
    /// Tolerance for analytic-vs-finite-difference comparisons.
    const FD_TOL: f64 = 1e-7;

    /// True when both arrays have the same shape and every pair of entries
    /// differs by less than `tol`.
    fn approx_eq(a: &Array2<f64>, b: &Array2<f64>, tol: f64) -> bool {
        a.dim() == b.dim() && a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < tol)
    }

    /// Scalar probe loss `L = Σ grad_out ⊙ forward(z, w, g, bias)`, whose
    /// gradient with respect to the forward output is exactly `grad_out`.
    fn probe_loss(
        z: &Array2<f64>,
        w: &Array2<f64>,
        g: &Array1<f64>,
        bias: &Array1<f64>,
        grad_out: &Array2<f64>,
    ) -> f64 {
        let out = interchange_decode_forward(InterchangeDecodeForward {
            z: z.view(),
            weights: w.view(),
            gate: g.view(),
            bias: Some(bias.view()),
        })
        .unwrap();
        out.iter().zip(grad_out.iter()).map(|(a, b)| a * b).sum()
    }

    /// Central difference quotient for losses evaluated at `±FD_STEP`.
    fn central_difference(loss_plus: f64, loss_minus: f64) -> f64 {
        (loss_plus - loss_minus) / (2.0 * FD_STEP)
    }

    #[test]
    fn forward_matches_hand_recomputation() {
        let z = array![[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]];
        let w = array![[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6]];
        let g = array![1.0, 0.5, -1.0];
        let bias = array![0.01, -0.02];
        let out = interchange_decode_forward(InterchangeDecodeForward {
            z: z.view(),
            weights: w.view(),
            gate: g.view(),
            bias: Some(bias.view()),
        })
        .unwrap();
        // expected row i, col k: Σ_f g[f] z[i,f] w[k,f] + bias[k]
        let mut expected = Array2::<f64>::zeros((2, 2));
        for i in 0..2 {
            for k in 0..2 {
                let mut acc = bias[k];
                for j in 0..3 {
                    acc += g[j] * z[[i, j]] * w[[k, j]];
                }
                expected[[i, k]] = acc;
            }
        }
        assert!(approx_eq(&out, &expected, EXACT_TOL));
    }

    #[test]
    fn swap_all_true_matches_z_a_forward() {
        let z_a = array![[1.0, -2.0], [3.0, 0.5]];
        let z_b = array![[10.0, 20.0], [-30.0, 40.0]];
        let w = array![[0.1, 0.2], [0.3, -0.4], [0.5, 0.6]];
        let g = array![0.7, -0.3];
        let mask = Array1::from(vec![true, true]);
        let swapped = interchange_swap_forward(InterchangeSwapForward {
            z_a: z_a.view(),
            z_b: z_b.view(),
            mask: mask.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap();
        let plain = interchange_decode_forward(InterchangeDecodeForward {
            z: z_a.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap();
        assert!(approx_eq(&swapped, &plain, EXACT_TOL));
    }

    #[test]
    fn swap_all_false_matches_z_b_forward() {
        let z_a = array![[1.0, -2.0], [3.0, 0.5]];
        let z_b = array![[10.0, 20.0], [-30.0, 40.0]];
        let w = array![[0.1, 0.2], [0.3, -0.4]];
        let g = array![0.7, -0.3];
        let mask = Array1::from(vec![false, false]);
        let swapped = interchange_swap_forward(InterchangeSwapForward {
            z_a: z_a.view(),
            z_b: z_b.view(),
            mask: mask.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap();
        let plain = interchange_decode_forward(InterchangeDecodeForward {
            z: z_b.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap();
        assert!(approx_eq(&swapped, &plain, EXACT_TOL));
    }

    #[test]
    fn backward_matches_finite_differences() {
        let z = array![[0.4, -0.7, 1.1], [0.2, 0.8, -0.3]];
        let w = array![[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6]];
        let g = array![0.6, -0.2, 1.3];
        let bias = array![0.05, -0.01];
        let grad_out = array![[1.0, -0.5], [0.3, 0.8]];

        let an = interchange_decode_backward(z.view(), w.view(), g.view(), grad_out.view(), true)
            .unwrap();

        // The loss is linear in each operand taken alone, so the central
        // difference is exact up to round-off.
        let loss = |z: &Array2<f64>, w: &Array2<f64>, g: &Array1<f64>, bias: &Array1<f64>| -> f64 {
            probe_loss(z, w, g, bias, &grad_out)
        };

        // ∂L/∂z[i, j]
        for i in 0..z.nrows() {
            for j in 0..z.ncols() {
                let (mut zp, mut zm) = (z.clone(), z.clone());
                zp[[i, j]] += FD_STEP;
                zm[[i, j]] -= FD_STEP;
                let fd = central_difference(loss(&zp, &w, &g, &bias), loss(&zm, &w, &g, &bias));
                assert!(
                    (an.grad_z[[i, j]] - fd).abs() < FD_TOL,
                    "grad_z mismatch at ({i},{j}): analytic {} vs fd {}",
                    an.grad_z[[i, j]],
                    fd
                );
            }
        }

        // ∂L/∂g[j]
        for j in 0..g.len() {
            let (mut gp, mut gm) = (g.clone(), g.clone());
            gp[j] += FD_STEP;
            gm[j] -= FD_STEP;
            let fd = central_difference(loss(&z, &w, &gp, &bias), loss(&z, &w, &gm, &bias));
            assert!(
                (an.grad_gate[j] - fd).abs() < FD_TOL,
                "grad_gate mismatch at {j}: analytic {} vs fd {}",
                an.grad_gate[j],
                fd
            );
        }

        // ∂L/∂W[d, j]
        for d in 0..w.nrows() {
            for j in 0..w.ncols() {
                let (mut wp, mut wm) = (w.clone(), w.clone());
                wp[[d, j]] += FD_STEP;
                wm[[d, j]] -= FD_STEP;
                let fd = central_difference(loss(&z, &wp, &g, &bias), loss(&z, &wm, &g, &bias));
                assert!(
                    (an.grad_weights[[d, j]] - fd).abs() < FD_TOL,
                    "grad_W mismatch at ({d},{j}): analytic {} vs fd {}",
                    an.grad_weights[[d, j]],
                    fd
                );
            }
        }

        // ∂L/∂bias[d]
        let bias_grad = an.grad_bias.as_ref().unwrap();
        for d in 0..bias.len() {
            let (mut bp, mut bm) = (bias.clone(), bias.clone());
            bp[d] += FD_STEP;
            bm[d] -= FD_STEP;
            let fd = central_difference(loss(&z, &w, &g, &bp), loss(&z, &w, &g, &bm));
            assert!(
                (bias_grad[d] - fd).abs() < FD_TOL,
                "grad_bias mismatch at {d}: analytic {} vs fd {}",
                bias_grad[d],
                fd
            );
        }
    }

    #[test]
    fn swap_backward_routes_grad_through_mask() {
        let z_a = array![[1.0, 2.0, 3.0]];
        let z_b = array![[-1.0, -2.0, -3.0]];
        let w = array![[0.5, 0.25, -0.1]];
        let g = array![1.0, 0.5, -1.0];
        let mask = Array1::from(vec![true, false, true]);
        let grad_out = array![[1.0]];
        let bk = interchange_swap_backward(
            z_a.view(),
            z_b.view(),
            mask.view(),
            w.view(),
            g.view(),
            grad_out.view(),
            false,
        )
        .unwrap();
        // For j in {0, 2} (mask true): grad_z_a[0, j] = g[j] * w[0, j]; grad_z_b[0, j] = 0
        // For j=1 (mask false): grad_z_b[0, 1] = g[1] * w[0, 1]; grad_z_a[0, 1] = 0
        assert!((bk.grad_z_a[[0, 0]] - 1.0 * 0.5).abs() < EXACT_TOL);
        assert!((bk.grad_z_a[[0, 1]] - 0.0).abs() < EXACT_TOL);
        assert!((bk.grad_z_a[[0, 2]] - (-1.0) * (-0.1)).abs() < EXACT_TOL);
        assert!((bk.grad_z_b[[0, 0]] - 0.0).abs() < EXACT_TOL);
        assert!((bk.grad_z_b[[0, 1]] - 0.5 * 0.25).abs() < EXACT_TOL);
        assert!((bk.grad_z_b[[0, 2]] - 0.0).abs() < EXACT_TOL);

        // The shared-parameter adjoints are computed from the composed latent
        // z_eff = [z_a[0], z_b[1], z_a[2]]; with a single sample and
        // grad_out = 1 they reduce to
        //   grad_gate[j]       = z_eff[j] * w[0, j]
        //   grad_weights[0, j] = g[j] * z_eff[j]
        let z_eff = [1.0, -2.0, 3.0];
        for j in 0..3 {
            assert!((bk.grad_gate[j] - z_eff[j] * w[[0, j]]).abs() < EXACT_TOL);
            assert!((bk.grad_weights[[0, j]] - g[j] * z_eff[j]).abs() < EXACT_TOL);
        }
        // No bias gradient was requested.
        assert!(bk.grad_bias.is_none());
    }
}