Per-feature scalar-gate decoder with masked interchange-swap variant.
This primitive is not specific to any one front-end. It is callable from
the gam Rust library directly, from the CLI (whenever a decoder
interchange-intervention probe is needed), and from PyTorch via the
gam-pyffi bindings. The intended use is Distributed Alignment Search
(DAS, Geiger et al. CLeaR 2024): given two inputs a and b, transplant
the latent atoms hypothesized to encode a causal variable from a into
b, decode with shared reconstruction weights and a shared per-feature
scalar gate, and back-propagate a swap-reconstruction error against a
target. The closed-form forward and analytic gradients live here so the
exact same arithmetic is used by every caller.
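The source does not say which swap-reconstruction error is used. As a minimal sketch, assuming a mean squared error over all B·D entries, the loss and the upstream adjoint Ȳ = ∂L/∂X̂ that a backward pass consumes can be computed as below. `swap_mse` is a hypothetical helper, not part of gam, and buffers are assumed to be flat row-major `f64`.

```rust
/// Hypothetical helper (not a gam API): mean-squared swap-reconstruction
/// error between a decoded batch `x_hat` and a `target`, both flattened
/// row-major B×D. Returns the loss and the upstream adjoint Ȳ = ∂L/∂X̂.
/// Assumption: the loss is a plain mean over all B·D entries.
fn swap_mse(x_hat: &[f64], target: &[f64]) -> (f64, Vec<f64>) {
    assert_eq!(x_hat.len(), target.len(), "X̂ and target must share a shape");
    let n = x_hat.len() as f64;
    let mut loss = 0.0;
    let mut y_bar = Vec::with_capacity(x_hat.len());
    for (&y, &t) in x_hat.iter().zip(target) {
        let r = y - t; // residual of one entry
        loss += r * r;
        y_bar.push(2.0 * r / n); // ∂(mean of r²)/∂X̂ for this entry
    }
    (loss / n, y_bar)
}

fn main() {
    let x_hat = [1.0, 2.0, 3.0, 4.0];
    let target = [1.0, 1.0, 3.0, 2.0];
    let (loss, y_bar) = swap_mse(&x_hat, &target);
    println!("loss = {loss}, y_bar = {y_bar:?}"); // loss = 1.25, y_bar = [0.0, 0.5, 0.0, 1.0]
}
```

Returning Ȳ alongside the loss keeps the loss choice separate from the decoder: any other error function only has to produce its own Ȳ.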
§Forward
With latent Z ∈ ℝ^{B×F}, scalar gate g ∈ ℝ^F, decoder weights
W ∈ ℝ^{D×F}, and optional bias b ∈ ℝ^D (unrelated to the input b above),
the plain decode is

    X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d]

The masked interchange-swap forward first composes the latent,

    Z_eff[i, f] = mask[f] ? Z_a[i, f] : Z_b[i, f],

then runs the plain decode on Z_eff. The gate g and the weights W
are SHARED between the two source decodings; only the latent activations
are interchanged. The scalar gate is decoupled from the reconstruction
matrix on purpose: that decoupling is what gives DAS a parameter to
transplant.
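The two forward formulas can be sketched in plain Rust as follows. This is a minimal illustration, assuming dense row-major `f64` buffers (Z is B×F, W is D×F); `gated_decode` and `compose_swap` are hypothetical names and are not the signatures of `interchange_decode_forward` or `interchange_swap_forward`.

```rust
/// Hypothetical sketch of the plain gated decode (not the gam API).
/// Computes X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d].
fn gated_decode(
    z: &[f64],            // B×F latent, row-major
    gate: &[f64],         // length F scalar gate
    w: &[f64],            // D×F decoder weights, row-major
    bias: Option<&[f64]>, // optional length-D bias
    batch: usize,
    dim: usize,
) -> Vec<f64> {
    let nf = gate.len();
    assert_eq!(z.len(), batch * nf, "Z must be B×F");
    assert_eq!(w.len(), dim * nf, "W must be D×F");
    let mut x_hat = vec![0.0; batch * dim];
    for i in 0..batch {
        for d in 0..dim {
            let mut acc = bias.map_or(0.0, |b| b[d]);
            for f in 0..nf {
                acc += gate[f] * z[i * nf + f] * w[d * nf + f];
            }
            x_hat[i * dim + d] = acc;
        }
    }
    x_hat
}

/// Hypothetical sketch of the masked swap:
/// Z_eff[i, f] = mask[f] ? Z_a[i, f] : Z_b[i, f].
fn compose_swap(z_a: &[f64], z_b: &[f64], mask: &[bool]) -> Vec<f64> {
    let nf = mask.len();
    assert_eq!(z_a.len(), z_b.len(), "Z_a and Z_b must share a shape");
    assert_eq!(z_a.len() % nf, 0, "mask length must equal F");
    z_a.iter()
        .zip(z_b)
        .enumerate()
        .map(|(k, (&a, &b))| if mask[k % nf] { a } else { b })
        .collect()
}

fn main() {
    // B = 1, F = 2, D = 1: feature 0 comes from a, feature 1 from b.
    let z_a = [1.0, 2.0];
    let z_b = [3.0, 4.0];
    let mask = [true, false];
    let gate = [2.0, 0.5];
    let w = [1.0, 1.0];
    let bias = [0.5];
    // The swap forward is the plain decode applied to the composed latent,
    // with the same gate and weights for both sources.
    let z_eff = compose_swap(&z_a, &z_b, &mask);
    let x_hat = gated_decode(&z_eff, &gate, &w, Some(&bias[..]), 1, 1);
    println!("Z_eff = {z_eff:?}, X̂ = {x_hat:?}"); // Z_eff = [1.0, 4.0], X̂ = [4.5]
}
```

Composing first and decoding once is what makes the sharing of g and W explicit: there is only one decode call, so the two sources cannot end up with different gates.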
§Backward
From upstream Ȳ = ∂L/∂X̂ ∈ ℝ^{B×D},

    ∂L/∂Z[i, f] = g[f] · Σ_d Ȳ[i, d] · W[d, f]
    ∂L/∂g[f]    = Σ_i Z[i, f] · Σ_d Ȳ[i, d] · W[d, f]
    ∂L/∂W[d, f] = g[f] · Σ_i Ȳ[i, d] · Z[i, f]
    ∂L/∂b[d]    = Σ_i Ȳ[i, d]

For the masked-swap path, the plain ∂L/∂Z is evaluated at Z_eff and then
routed: ∂L/∂Z_a keeps the columns where mask[f] is true (the rest are
zero), and ∂L/∂Z_b keeps the columns where mask[f] is false (the rest are
zero). All other adjoints (∂L/∂g, ∂L/∂W, ∂L/∂b) are computed from the
composed Z_eff exactly as in the plain case.
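The adjoint formulas and the mask routing can be checked with a small Rust sketch. It assumes dense row-major `f64` buffers; `decode`, `decode_backward`, `split_swap_grad`, and `Grads` are hypothetical names chosen for illustration, not the crate's `interchange_decode_backward` / `interchange_swap_backward` signatures. A forward is repeated so the block stands alone and can be used for a finite-difference check.

```rust
/// Hypothetical container for the plain-decode adjoints (flat, row-major).
#[derive(Debug)]
struct Grads {
    dz: Vec<f64>, // ∂L/∂Z, B×F
    dg: Vec<f64>, // ∂L/∂g, F
    dw: Vec<f64>, // ∂L/∂W, D×F
    db: Vec<f64>, // ∂L/∂b, D
}

/// Forward: X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d].
fn decode(z: &[f64], gate: &[f64], w: &[f64], bias: &[f64], batch: usize, dim: usize) -> Vec<f64> {
    let nf = gate.len();
    let mut x = vec![0.0; batch * dim];
    for i in 0..batch {
        for d in 0..dim {
            x[i * dim + d] =
                bias[d] + (0..nf).map(|f| gate[f] * z[i * nf + f] * w[d * nf + f]).sum::<f64>();
        }
    }
    x
}

/// Backward from the upstream adjoint Ȳ = ∂L/∂X̂ (B×D).
fn decode_backward(z: &[f64], gate: &[f64], w: &[f64], y_bar: &[f64], batch: usize, dim: usize) -> Grads {
    let nf = gate.len();
    let mut out = Grads {
        dz: vec![0.0; batch * nf],
        dg: vec![0.0; nf],
        dw: vec![0.0; dim * nf],
        db: vec![0.0; dim],
    };
    for i in 0..batch {
        for f in 0..nf {
            // Shared inner term s = Σ_d Ȳ[i, d] · W[d, f], used by ∂L/∂Z and ∂L/∂g.
            let s: f64 = (0..dim).map(|d| y_bar[i * dim + d] * w[d * nf + f]).sum();
            out.dz[i * nf + f] = gate[f] * s;
            out.dg[f] += z[i * nf + f] * s;
        }
        for d in 0..dim {
            let yb = y_bar[i * dim + d];
            out.db[d] += yb;
            for f in 0..nf {
                out.dw[d * nf + f] += gate[f] * yb * z[i * nf + f];
            }
        }
    }
    out
}

/// Masked-swap routing of ∂L/∂Z_eff: columns with mask[f] == true go to
/// ∂L/∂Z_a, the others to ∂L/∂Z_b; each side is zero elsewhere.
fn split_swap_grad(dz_eff: &[f64], mask: &[bool]) -> (Vec<f64>, Vec<f64>) {
    let nf = mask.len();
    let pick = |want: bool| -> Vec<f64> {
        dz_eff
            .iter()
            .enumerate()
            .map(|(k, &v)| if mask[k % nf] == want { v } else { 0.0 })
            .collect()
    };
    (pick(true), pick(false))
}

fn main() {
    // B = 1, F = 2, D = 1; Z_eff = [1, 4] as produced by mask = [true, false].
    let (z, gate, w, mask) = ([1.0, 4.0], [2.0, 0.5], [1.0, 1.0], [true, false]);
    let grads = decode_backward(&z, &gate, &w, &[1.0], 1, 1);
    let (dz_a, dz_b) = split_swap_grad(&grads.dz, &mask);
    println!("{grads:?}");
    println!("dZ_a = {dz_a:?}, dZ_b = {dz_b:?}");

    // Finite-difference check of ∂L/∂g[0] with L = X̂[0, 0], i.e. Ȳ = [1].
    let loss = |g: &[f64]| decode(&z, g, &w, &[0.0], 1, 1)[0];
    let h = 1e-6;
    let fd = (loss(&[2.0 + h, 0.5][..]) - loss(&[2.0 - h, 0.5][..])) / (2.0 * h);
    println!("finite difference {fd:.6} vs analytic {}", grads.dg[0]);
}
```

Because X̂ is linear in each of g, Z, and W separately, a central difference reproduces these adjoints up to rounding, which makes it a cheap sanity check for any reimplementation.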
Structs§
- InterchangeDecodeBackward: Adjoints returned by the plain backward.
- InterchangeDecodeForward: Inputs to the plain (non-swap) gated decode forward.
- InterchangeSwapBackward: Adjoints returned by the masked-swap backward.
- InterchangeSwapForward: Inputs to the masked-swap forward.
Functions§
- interchange_decode_backward: Backward for the plain decode. `grad_out` is ∂L/∂X̂.
- interchange_decode_forward: Plain gated decode:
  X̂[i, d] = Σ_f g[f] · Z[i, f] · W[d, f] + b[d].
- interchange_swap_backward: Backward for the masked-swap variant.
- interchange_swap_forward: Masked-swap forward.