
Module interchange_decoder



Per-feature scalar-gate decoder with masked interchange-swap variant.

This primitive is not specific to any one front-end. It is callable from the gam Rust library directly, from the CLI (whenever a decoder interchange-intervention probe is needed), and from PyTorch via the gam-pyffi bindings. The intended use is Distributed Alignment Search (DAS; Geiger et al., CLeaR 2024): given two inputs a and b, transplant the latent atoms hypothesized to encode a causal variable from a into b, decode with shared reconstruction weights and a shared per-feature scalar gate, and back-propagate a swap-reconstruction error against a target. The closed-form forward and the analytic gradients live here so that every caller uses exactly the same arithmetic.

§Forward

With latent Z ∈ ℝ^{B×F}, scalar gate g ∈ ℝ^F, decoder weights W ∈ ℝ^{D×F}, and optional bias b ∈ ℝ^D (the bias vector, unrelated to the source input b of the swap variant),

X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d]
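
The formula above can be sketched directly as a triple loop. This is only an illustration of the arithmetic, assuming row-major `f64` slices; the function name `gated_decode` and its signature are hypothetical and are not the signature of `interchange_decode_forward`.

```rust
/// Hypothetical helper (not this module's API): plain gated decode.
/// `z` is B×F row-major, `gate` has length F, `w` is D×F row-major,
/// `bias` is an optional length-D vector. Returns X̂ as B×D row-major.
pub fn gated_decode(
    z: &[f64],
    gate: &[f64],
    w: &[f64],
    bias: Option<&[f64]>,
    batch: usize,
    feats: usize,
    dims: usize,
) -> Vec<f64> {
    assert_eq!(z.len(), batch * feats);
    assert_eq!(gate.len(), feats);
    assert_eq!(w.len(), dims * feats);
    if let Some(bv) = bias {
        assert_eq!(bv.len(), dims);
    }
    let mut out = vec![0.0; batch * dims];
    for i in 0..batch {
        for d in 0..dims {
            // Start from b[d] when a bias is supplied, else from zero.
            let mut acc = bias.map_or(0.0, |bv| bv[d]);
            for f in 0..feats {
                // g[f] · Z[i, f] · W[d, f]
                acc += gate[f] * z[i * feats + f] * w[d * feats + f];
            }
            out[i * dims + d] = acc;
        }
    }
    out
}
```

Keeping the gate as a separate factor inside the inner product, rather than folding it into `w` ahead of time, mirrors the decoupling of gate and reconstruction matrix described in this module.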

Masked interchange-swap forward composes the latent first,

Z_eff[i, f] = mask[f] ? Z_a[i, f] : Z_b[i, f],

then runs the plain decode on Z_eff. The gate g and the weights W are SHARED between the two source decodings — only the latent activations are interchanged. The scalar gate is decoupled from the reconstruction matrix on purpose: that decoupling is what gives DAS a parameter to transplant.
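
The latent composition step can be sketched as a column-wise select. This is a minimal illustration assuming row-major B×F slices and a boolean mask of length F; `compose_swap` is a hypothetical name, not a function exported by this module.

```rust
/// Hypothetical helper (not this module's API): build Z_eff by taking
/// column f from `z_a` where `mask[f]` is true and from `z_b` otherwise.
/// Both latents are B×F row-major; the result has the same layout.
pub fn compose_swap(z_a: &[f64], z_b: &[f64], mask: &[bool], feats: usize) -> Vec<f64> {
    assert!(feats > 0);
    assert_eq!(mask.len(), feats);
    assert_eq!(z_a.len(), z_b.len());
    assert_eq!(z_a.len() % feats, 0);
    z_a.iter()
        .zip(z_b)
        .enumerate()
        // idx % feats recovers the feature (column) index in row-major order.
        .map(|(idx, (&a, &b))| if mask[idx % feats] { a } else { b })
        .collect()
}
```

The composed latent is then passed through the same decode as in the plain case, with one gate and one weight matrix for both sources.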

§Backward

From upstream Ȳ = ∂L/∂X̂ ∈ ℝ^{B×D},

∂L/∂Z[i, f] = g[f] · Σ_d Ȳ[i, d] · W[d, f]
∂L/∂g[f]   = Σ_i Z[i, f] · Σ_d Ȳ[i, d] · W[d, f]
∂L/∂W[d, f] = g[f] · Σ_i Ȳ[i, d] · Z[i, f]
∂L/∂b[d]   = Σ_i Ȳ[i, d]
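
The four adjoints share the inner sum Σ_d Ȳ[i, d] · W[d, f], so a sketch can compute it once per (i, f) pair. The code below is an illustration under the same row-major layout assumption; the `Adjoints` struct and `gated_decode_backward` are hypothetical names and do not reflect the actual fields of `InterchangeDecodeBackward`.

```rust
/// Hypothetical container for the four adjoints (not this module's struct).
#[derive(Debug)]
pub struct Adjoints {
    pub dz: Vec<f64>,    // ∂L/∂Z, B×F row-major
    pub dgate: Vec<f64>, // ∂L/∂g, length F
    pub dw: Vec<f64>,    // ∂L/∂W, D×F row-major
    pub dbias: Vec<f64>, // ∂L/∂b, length D
}

/// Hypothetical helper: analytic backward of the plain gated decode.
/// `grad_out` is Ȳ = ∂L/∂X̂, B×D row-major.
pub fn gated_decode_backward(
    grad_out: &[f64],
    z: &[f64],
    gate: &[f64],
    w: &[f64],
    batch: usize,
    feats: usize,
    dims: usize,
) -> Adjoints {
    assert_eq!(grad_out.len(), batch * dims);
    assert_eq!(z.len(), batch * feats);
    assert_eq!(gate.len(), feats);
    assert_eq!(w.len(), dims * feats);
    let mut dz = vec![0.0; batch * feats];
    let mut dgate = vec![0.0; feats];
    let mut dw = vec![0.0; dims * feats];
    let mut dbias = vec![0.0; dims];
    for i in 0..batch {
        for f in 0..feats {
            let zif = z[i * feats + f];
            // s = Σ_d Ȳ[i, d] · W[d, f], reused by ∂L/∂Z and ∂L/∂g.
            let mut s = 0.0;
            for d in 0..dims {
                let gy = grad_out[i * dims + d];
                s += gy * w[d * feats + f];
                // ∂L/∂W[d, f] accumulates g[f] · Ȳ[i, d] · Z[i, f] over i.
                dw[d * feats + f] += gate[f] * gy * zif;
            }
            dz[i * feats + f] = gate[f] * s;
            dgate[f] += zif * s;
        }
        for d in 0..dims {
            dbias[d] += grad_out[i * dims + d];
        }
    }
    Adjoints { dz, dgate, dw, dbias }
}
```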

For the masked-swap path, ∂L/∂Z_a keeps the columns where mask[f] is true (the rest are zero) and ∂L/∂Z_b keeps the columns where mask[f] is false. All other adjoints (∂L/∂g, ∂L/∂W, ∂L/∂b) are computed from the composed Z_eff exactly as in the plain case.
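
The routing of the latent adjoint back to the two sources can be sketched as the mirror image of the forward select. This is a minimal illustration, assuming ∂L/∂Z_eff has already been computed by the plain backward; `split_swap_grad` is a hypothetical name, not part of this module.

```rust
/// Hypothetical helper (not this module's API): route ∂L/∂Z_eff (B×F,
/// row-major) to the two sources. Columns with `mask[f] == true` go to
/// ∂L/∂Z_a, the remaining columns go to ∂L/∂Z_b; everything else is zero.
pub fn split_swap_grad(dz_eff: &[f64], mask: &[bool], feats: usize) -> (Vec<f64>, Vec<f64>) {
    assert!(feats > 0);
    assert_eq!(mask.len(), feats);
    assert_eq!(dz_eff.len() % feats, 0);
    let mut dz_a = vec![0.0; dz_eff.len()];
    let mut dz_b = vec![0.0; dz_eff.len()];
    for (idx, &g) in dz_eff.iter().enumerate() {
        if mask[idx % feats] {
            dz_a[idx] = g;
        } else {
            dz_b[idx] = g;
        }
    }
    (dz_a, dz_b)
}
```

Because each column is owned by exactly one source, the two returned adjoints sum element-wise to ∂L/∂Z_eff.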

Structs§

InterchangeDecodeBackward
Adjoints returned by the plain backward.
InterchangeDecodeForward
Inputs to the plain (non-swap) gated decode forward.
InterchangeSwapBackward
Adjoints returned by the masked-swap backward.
InterchangeSwapForward
Inputs to the masked-swap forward.

Functions§

interchange_decode_backward
Backward for the plain decode. grad_out is ∂L/∂X̂.
interchange_decode_forward
Plain gated decode: X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d].
interchange_swap_backward
Backward for the masked-swap variant.
interchange_swap_forward
Masked-swap forward.