use crate::families::jet_tower::Tower4;
/// Gate nonlinearity applied to an atom's logit when a row's reconstruction is assembled.
///
/// In both variants `inv_tau` is the factor the (possibly shifted) logit is multiplied by
/// before the nonlinearity is applied.
/// NOTE(review): the name suggests a reciprocal temperature — confirm with the caller.
#[derive(Debug, Clone, Copy)]
pub enum RowGate {
/// Softmax across gate entries: `exp(inv_tau * logit[atom])` divided by the sum of
/// `exp(inv_tau * logit[j])` over all entries `j`.
Softmax { inv_tau: f64 },
/// Independent logistic gate per atom: `1 / (1 + exp(-x))` with
/// `x = inv_tau * (logit[atom] - gate_shift[atom])`.
PerAtomLogistic { inv_tau: f64 },
}
/// Value and first/second-order coefficients of one atom's basis columns, together with
/// the decoder weights that map basis columns to output columns.
///
/// Indexing used by the methods of this type: `phi[basis]`, `d_phi[basis][axis]`,
/// `d2_phi[basis][axis_a][axis_b]` and `decoder[basis][out_col]`.
#[derive(Debug, Clone)]
pub struct AtomRowBasisJet {
/// Constant term of each basis column; its length is the number of basis columns.
pub phi: Vec<f64>,
/// First-order coefficient of each basis column along each latent axis.
pub d_phi: Vec<Vec<f64>>,
/// Second-order coefficients of each basis column; entry `[basis][a][b]` multiplies
/// `0.5 * x_a * x_b` when the basis tower is built.
pub d2_phi: Vec<Vec<Vec<f64>>>,
/// Decoder weights; the length of the first row is reported as the output width.
pub decoder: Vec<Vec<f64>>,
/// Number of latent axes read from `d_phi`, `d2_phi` and the coordinate-slot list.
pub latent_dim: usize,
}
impl AtomRowBasisJet {
    /// Number of basis columns, i.e. the length of `phi`.
    fn n_basis(&self) -> usize {
        self.phi.len()
    }

    /// Output width, read from the first decoder row; `0` when there are no decoder rows.
    fn out_dim(&self) -> usize {
        match self.decoder.first() {
            Some(row) => row.len(),
            None => 0,
        }
    }

    /// Expands basis column `basis_col` as a tower over the slots listed in `coord_slots`.
    ///
    /// With `x_axis` a tower variable of value `0.0` in slot `coord_slots[axis]`, the result
    /// is `phi + sum_axis d_phi[axis] * x_axis + 0.5 * sum_{a,b} d2_phi[a][b] * x_a * x_b`.
    /// Terms whose coefficient compares equal to zero are not added.
    ///
    /// # Panics
    /// Panics when `basis_col` is out of range, or when `coord_slots`, `d_phi` or `d2_phi`
    /// hold fewer than `latent_dim` entries for this column.
    fn basis_tower<const K: usize>(&self, basis_col: usize, coord_slots: &[usize]) -> Tower4<K> {
        let dim = self.latent_dim;

        // Constant term followed by the first-order terms, one latent axis at a time.
        let affine = (0..dim).fold(
            Tower4::<K>::constant(self.phi[basis_col]),
            |jet, axis| {
                let slot = coord_slots[axis];
                let slope = self.d_phi[basis_col][axis];
                if slope != 0.0 {
                    jet + Tower4::<K>::variable(0.0, slot).scale(slope)
                } else {
                    jet
                }
            },
        );

        // Second-order terms over every ordered axis pair, visited row-major.
        (0..dim)
            .flat_map(|row| (0..dim).map(move |col| (row, col)))
            .fold(affine, |jet, (row, col)| {
                let curvature = self.d2_phi[basis_col][row][col];
                if curvature != 0.0 {
                    let x_row = Tower4::<K>::variable(0.0, coord_slots[row]);
                    let x_col = Tower4::<K>::variable(0.0, coord_slots[col]);
                    jet + x_row.mul(&x_col).scale(0.5 * curvature)
                } else {
                    jet
                }
            })
    }

    /// Output column `out_col` of this atom: the basis towers weighted by
    /// `decoder[basis][out_col]` and summed in basis order. Basis columns whose decoder
    /// weight compares equal to zero are skipped without building their tower.
    fn decoded_tower<const K: usize>(&self, out_col: usize, coord_slots: &[usize]) -> Tower4<K> {
        (0..self.n_basis()).fold(Tower4::<K>::zero(), |sum, basis_col| {
            let weight = self.decoder[basis_col][out_col];
            if weight != 0.0 {
                sum + self.basis_tower::<K>(basis_col, coord_slots).scale(weight)
            } else {
                sum
            }
        })
    }
}
/// Inputs for building one row's reconstruction as a tower: a gated sum of decoded atom
/// towers over `n_primaries` tower slots.
#[derive(Debug, Clone)]
pub struct SaeReconstructionRowProgram {
/// Atoms whose gated, decoded towers are summed by `reconstruction_column`.
pub atoms: Vec<AtomRowBasisJet>,
/// Per-atom gate values. The tower code reads only its length (the number of softmax
/// terms); the values themselves are not used when the gate tower is built.
pub gate_value: Vec<f64>,
/// Per-atom logits fed to the gate.
pub logits: Vec<f64>,
/// Per-atom offset subtracted from the logit by `RowGate::PerAtomLogistic`; not read by
/// the softmax gate.
pub gate_shift: Vec<f64>,
/// Which gate nonlinearity to apply.
pub gate: RowGate,
/// Tower slot of each atom's logit: `Some(slot)` makes the logit a tower variable in that
/// slot, `None` makes it a constant.
pub logit_slot: Vec<Option<usize>>,
/// `coord_slot[atom][axis]` is the tower slot of that atom's latent coordinate.
pub coord_slot: Vec<Vec<usize>>,
/// Number of tower slots; `reconstruction_column::<K>` asserts that `K` equals this.
pub n_primaries: usize,
}
impl SaeReconstructionRowProgram {
    /// Builds the gate of `atom` as a tower over the row's primaries.
    ///
    /// A logit whose `logit_slot` entry is `Some(slot)` enters as a tower variable in that
    /// slot; a `None` entry enters as a constant.
    ///
    /// * `RowGate::Softmax`: `exp(inv_tau * logit[atom]) / sum_j exp(inv_tau * logit[j])`,
    ///   with `j` running over `0..self.gate_value.len()`.
    /// * `RowGate::PerAtomLogistic`: the logistic function of
    ///   `inv_tau * (logit[atom] - gate_shift[atom])`.
    ///
    /// Both gates are evaluated in an overflow-safe but mathematically identical form, so
    /// large scaled logits give finite channels instead of `inf / inf = NaN`.
    ///
    /// # Panics
    /// Panics when an index is out of range for `logits`, `logit_slot` or (logistic gate)
    /// `gate_shift`.
    fn gate_tower<const K: usize>(&self, atom: usize) -> Tower4<K> {
        match self.gate {
            RowGate::Softmax { inv_tau } => {
                let n_gates = self.gate_value.len();
                // Softmax is unchanged when one constant is subtracted from every scaled
                // logit. Subtracting the largest one makes every exponent value <= 0, so
                // `exp` cannot overflow. The shift is a plain number, so it contributes
                // nothing to any derivative channel.
                let largest = (0..n_gates)
                    .map(|j| self.logits[j] * inv_tau)
                    .fold(f64::NEG_INFINITY, f64::max);
                // No entries or a non-finite maximum: fall back to the unshifted formula.
                let shift = if largest.is_finite() { largest } else { 0.0 };

                let mut denom = Tower4::<K>::zero();
                let mut numer = Tower4::<K>::zero();
                for j in 0..n_gates {
                    let lj = match self.logit_slot[j] {
                        Some(slot) => Tower4::<K>::variable(self.logits[j], slot),
                        None => Tower4::<K>::constant(self.logits[j]),
                    };
                    let ej = (lj.scale(inv_tau) - shift).exp();
                    if j == atom {
                        numer = ej;
                    }
                    denom = denom + ej;
                }
                numer / denom
            }
            RowGate::PerAtomLogistic { inv_tau } => {
                let l = match self.logit_slot[atom] {
                    Some(slot) => Tower4::<K>::variable(self.logits[atom], slot),
                    None => Tower4::<K>::constant(self.logits[atom]),
                };
                let x = (l - self.gate_shift[atom]).scale(inv_tau);
                let one = Tower4::<K>::constant(1.0);
                // Always exponentiate a non-positive value: `1 / (1 + exp(-x))` for
                // x >= 0 and the equal expression `exp(x) / (1 + exp(x))` for x < 0.
                if x.v >= 0.0 {
                    one / (one + x.scale(-1.0).exp())
                } else {
                    let e = x.exp();
                    e / (one + e)
                }
            }
        }
    }

    /// Tower of output column `out_col`: the sum over atoms of `gate(atom)` times the atom's
    /// decoded tower for that column.
    ///
    /// # Panics
    /// Panics when the tower arity `K` differs from `n_primaries`, and on out-of-range
    /// indices while the gate or decoded towers are built.
    #[must_use]
    pub fn reconstruction_column<const K: usize>(&self, out_col: usize) -> Tower4<K> {
        assert_eq!(
            self.n_primaries, K,
            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
            self.n_primaries
        );
        let mut acc = Tower4::<K>::zero();
        for (atom, atom_jet) in self.atoms.iter().enumerate() {
            let gate = self.gate_tower::<K>(atom);
            let decoded = atom_jet.decoded_tower::<K>(out_col, &self.coord_slot[atom]);
            acc = acc + gate.mul(&decoded);
        }
        acc
    }

    /// Output width, taken from the first atom; `0` when there are no atoms.
    #[must_use]
    pub fn out_dim(&self) -> usize {
        self.atoms.first().map_or(0, AtomRowBasisJet::out_dim)
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Hand-computed reference channels for one output column: `value` is the reconstruction
/// value, `first[p]` the first derivative with respect to primary slot `p`, and
/// `second[p][q]` the second derivative with respect to slots `p` and `q`.
struct HandChannels {
first: Vec<f64>, second: Vec<Vec<f64>>, value: f64,
}
/// Closed-form first and second derivatives of a softmax gate with respect to its logits.
///
/// `gate` holds the softmax outputs and `inv_tau` the factor applied to each logit.
/// Returns `(dz, d2z)` where `dz[j][k]` is the derivative of `gate[k]` with respect to
/// logit `j`, and `d2z[j][l][k]` the second derivative of `gate[k]` with respect to
/// logits `j` and `l`.
fn softmax_gate_derivs(gate: &[f64], inv_tau: f64) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
    // Kronecker delta as a float.
    let delta = |a: usize, b: usize| -> f64 {
        if a == b {
            1.0
        } else {
            0.0
        }
    };
    let n = gate.len();

    let first: Vec<Vec<f64>> = (0..n)
        .map(|wrt| {
            (0..n)
                .map(|out| gate[out] * (delta(out, wrt) - gate[wrt]) * inv_tau)
                .collect()
        })
        .collect();

    let second: Vec<Vec<Vec<f64>>> = (0..n)
        .map(|wrt_a| {
            (0..n)
                .map(|wrt_b| {
                    (0..n)
                        .map(|out| {
                            gate[out]
                                * ((delta(out, wrt_b) - gate[wrt_b])
                                    * (delta(out, wrt_a) - gate[wrt_a])
                                    - gate[wrt_a] * (delta(wrt_a, wrt_b) - gate[wrt_b]))
                                * inv_tau
                                * inv_tau
                        })
                        .collect()
                })
                .collect()
        })
        .collect();

    (first, second)
}
fn hand_softmax_column(
prog: &SaeReconstructionRowProgram,
out_col: usize,
inv_tau: f64,
) -> HandChannels {
let k = prog.atoms.len();
let n = prog.n_primaries;
let decoded: Vec<f64> = (0..k)
.map(|kk| {
(0..prog.atoms[kk].n_basis())
.map(|b| prog.atoms[kk].phi[b] * prog.atoms[kk].decoder[b][out_col])
.sum()
})
.collect();
let d1: Vec<Vec<f64>> = (0..k)
.map(|kk| {
(0..prog.atoms[kk].latent_dim)
.map(|axis| {
(0..prog.atoms[kk].n_basis())
.map(|b| {
prog.atoms[kk].d_phi[b][axis] * prog.atoms[kk].decoder[b][out_col]
})
.sum()
})
.collect()
})
.collect();
let d2: Vec<Vec<Vec<f64>>> = (0..k)
.map(|kk| {
(0..prog.atoms[kk].latent_dim)
.map(|a| {
(0..prog.atoms[kk].latent_dim)
.map(|b| {
(0..prog.atoms[kk].n_basis())
.map(|col| {
prog.atoms[kk].d2_phi[col][a][b]
* prog.atoms[kk].decoder[col][out_col]
})
.sum()
})
.collect()
})
.collect()
})
.collect();
let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
let logit_idx = |kk: usize| prog.logit_slot[kk];
let coord_idx = |kk: usize, axis: usize| prog.coord_slot[kk][axis];
let value: f64 = (0..k).map(|kk| prog.gate_value[kk] * decoded[kk]).sum();
let mut first = vec![0.0_f64; n];
for j in 0..k {
if let Some(p) = logit_idx(j) {
first[p] = (0..k).map(|kk| dz[j][kk] * decoded[kk]).sum();
}
}
for kk in 0..k {
for axis in 0..prog.atoms[kk].latent_dim {
first[coord_idx(kk, axis)] = prog.gate_value[kk] * d1[kk][axis];
}
}
let mut second = vec![vec![0.0_f64; n]; n];
for j in 0..k {
for l in 0..k {
if let (Some(pj), Some(pl)) = (logit_idx(j), logit_idx(l)) {
second[pj][pl] = (0..k).map(|kk| d2z[j][l][kk] * decoded[kk]).sum();
}
}
}
for j in 0..k {
for kk in 0..k {
for axis in 0..prog.atoms[kk].latent_dim {
if let Some(pj) = logit_idx(j) {
let pc = coord_idx(kk, axis);
let val = dz[j][kk] * d1[kk][axis];
second[pj][pc] = val;
second[pc][pj] = val;
}
}
}
}
for kk in 0..k {
for a in 0..prog.atoms[kk].latent_dim {
for b in 0..prog.atoms[kk].latent_dim {
let pa = coord_idx(kk, a);
let pb = coord_idx(kk, b);
second[pa][pb] = prog.gate_value[kk] * d2[kk][a][b];
}
}
}
HandChannels {
first,
second,
value,
}
}
fn softmax_fixture(inv_tau: f64) -> (SaeReconstructionRowProgram, f64) {
let n_basis = 3;
let out_dim = 4;
let mk_atom = |seed: f64| {
let phi: Vec<f64> = (0..n_basis)
.map(|b| 0.3 + 0.2 * (b as f64 + seed))
.collect();
let d_phi: Vec<Vec<f64>> = (0..n_basis)
.map(|b| {
(0..2)
.map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
.collect()
})
.collect();
let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
.map(|b| {
(0..2)
.map(|a| {
(0..2)
.map(|bb| {
0.02 * (b as f64 + 1.0)
+ 0.01 * (a as f64)
+ 0.01 * (bb as f64)
+ 0.004 * seed
})
.collect()
})
.collect()
})
.collect();
let decoder: Vec<Vec<f64>> = (0..n_basis)
.map(|b| {
(0..out_dim)
.map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
.collect()
})
.collect();
AtomRowBasisJet {
phi,
d_phi,
d2_phi,
decoder,
latent_dim: 2,
}
};
let logits = vec![0.4_f64, -0.7];
let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
let s: f64 = e.iter().sum();
let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
let prog = SaeReconstructionRowProgram {
atoms: vec![mk_atom(0.0), mk_atom(1.0)],
gate_value,
logits,
gate_shift: vec![0.0, 0.0],
gate: RowGate::Softmax { inv_tau },
logit_slot: vec![Some(0), Some(1)],
coord_slot: vec![vec![2, 3], vec![4, 5]],
n_primaries: 6,
};
(prog, inv_tau)
}
/// For every output column, the tower's value, gradient and Hessian channels must agree
/// with the hand-derived channels, within tolerances scaled by the largest tower entry.
#[test]
fn softmax_reconstruction_tower_matches_hand_channels_all_columns() {
    let (prog, inv_tau) = softmax_fixture(1.3);
    for out_col in 0..prog.out_dim() {
        let tower = prog.reconstruction_column::<6>(out_col);
        let hand = hand_softmax_column(&prog, out_col, inv_tau);

        // Largest magnitudes in the gradient and Hessian set the comparison scale.
        let g_floor = tower.g.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
        let h_floor = tower
            .h
            .iter()
            .flatten()
            .map(|x| x.abs())
            .fold(0.0_f64, f64::max);
        let g_tol = 1e-9 * g_floor.max(1e-12);
        let h_tol = 1e-8 * h_floor.max(1e-12);

        let value_gap = (tower.v - hand.value).abs();
        assert!(
            value_gap <= 1e-9 * hand.value.abs().max(1.0),
            "col {out_col} value: tower {} vs hand {}",
            tower.v,
            hand.value
        );
        for a in 0..6 {
            let first_gap = (tower.g[a] - hand.first[a]).abs();
            assert!(
                first_gap <= g_tol,
                "col {out_col} first[{a}]: tower {} vs hand {}",
                tower.g[a],
                hand.first[a]
            );
            for b in 0..6 {
                let second_gap = (tower.h[a][b] - hand.second[a][b]).abs();
                assert!(
                    second_gap <= h_tol,
                    "col {out_col} second[{a}][{b}]: tower {} vs hand {}",
                    tower.h[a][b],
                    hand.second[a][b]
                );
            }
        }
    }
}
/// Sanity check on the comparison itself: flipping the sign of one logit/coordinate cross
/// entry in the hand channels must push it outside the Hessian tolerance.
#[test]
fn planted_cross_block_sign_flip_is_caught() {
    let (prog, inv_tau) = softmax_fixture(1.3);
    let out_col = 1;
    let tower = prog.reconstruction_column::<6>(out_col);
    let mut hand = hand_softmax_column(&prog, out_col, inv_tau);

    // Slot 0 is the first logit and slot 4 a coordinate of the second atom in the fixture.
    for (row, col) in [(0_usize, 4_usize), (4, 0)] {
        hand.second[row][col] = -hand.second[row][col];
    }

    let h_floor = tower
        .h
        .iter()
        .flatten()
        .map(|x| x.abs())
        .fold(0.0_f64, f64::max);
    let tolerance = 1e-8 * h_floor.max(1e-12);
    let disagrees = (tower.h[0][4] - hand.second[0][4]).abs() > tolerance;
    assert!(
        disagrees,
        "a flipped cross block must disagree with the tower truth"
    );
}
/// The softmax gate tower alone must reproduce the stored gate value and the closed-form
/// first and second derivatives with respect to every logit slot.
#[test]
fn softmax_gate_tower_matches_hand_gate_derivatives() {
    let (prog, inv_tau) = softmax_fixture(0.9);
    let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
    let n_atoms = prog.atoms.len();
    // Every logit in the fixture has a slot, so the unwrap cannot fail here.
    let slot_of = |logit: usize| prog.logit_slot[logit].unwrap();

    for atom in 0..n_atoms {
        let gate = prog.gate_tower::<6>(atom);
        assert!((gate.v - prog.gate_value[atom]).abs() < 1e-12);

        for j in 0..n_atoms {
            let slot = slot_of(j);
            let gap = (gate.g[slot] - dz[j][atom]).abs();
            assert!(
                gap < 1e-9,
                "gate {atom} d/dlogit {j}: tower {} vs hand {}",
                gate.g[slot],
                dz[j][atom]
            );
        }

        for j in 0..n_atoms {
            for l in 0..n_atoms {
                let (sj, sl) = (slot_of(j), slot_of(l));
                let gap = (gate.h[sj][sl] - d2z[j][l][atom]).abs();
                assert!(
                    gap < 1e-8,
                    "gate {atom} d2/dlogit {j}{l}: tower {} vs hand {}",
                    gate.h[sj][sl],
                    d2z[j][l][atom]
                );
            }
        }
    }
}
/// A single logistic gate must match the closed-form sigmoid value and its first and
/// second derivatives with respect to the logit.
#[test]
fn per_atom_logistic_gate_matches_closed_form() {
    let inv_tau = 1.4;
    let logit = 0.6;
    let shift = 0.2;

    // Closed-form sigmoid at the scaled, shifted logit.
    let x: f64 = (logit - shift) * inv_tau;
    let sigma = 1.0 / (1.0 + (-x).exp());

    // One trivial atom: the test only inspects the gate tower.
    let unit_atom = AtomRowBasisJet {
        phi: vec![1.0],
        d_phi: vec![vec![0.0]],
        d2_phi: vec![vec![vec![0.0]]],
        decoder: vec![vec![1.0]],
        latent_dim: 1,
    };
    let prog = SaeReconstructionRowProgram {
        atoms: vec![unit_atom],
        gate_value: vec![sigma],
        logits: vec![logit],
        gate_shift: vec![shift],
        gate: RowGate::PerAtomLogistic { inv_tau },
        logit_slot: vec![Some(0)],
        coord_slot: vec![vec![1]],
        n_primaries: 2,
    };

    let gate = prog.gate_tower::<2>(0);
    assert!((gate.v - sigma).abs() < 1e-12);

    let bell = sigma * (1.0 - sigma);
    let d1 = bell * inv_tau;
    let d2 = bell * (1.0 - 2.0 * sigma) * inv_tau * inv_tau;
    assert!((gate.g[0] - d1).abs() < 1e-9, "σ': {} vs {}", gate.g[0], d1);
    assert!(
        (gate.h[0][0] - d2).abs() < 1e-9,
        "σ'': {} vs {}",
        gate.h[0][0],
        d2
    );
}
}