use gam_math::jet_scalar::{JetScalar, Order1, Order2};
use gam_math::jet_tower::Tower4;
/// Reciprocal `1 / s`, lifted through the jet by its Taylor coefficients.
///
/// The n-th derivative of `1/u` is `(-1)^n * n! / u^(n+1)`; orders 0 through 4
/// are evaluated at `u = s.value()` and handed to `compose_unary`.
#[inline]
fn recip<const K: usize, S: JetScalar<K>>(s: &S) -> S {
    // (-1)^n * n! for n = 0..=4.
    const SIGNED_FACTORIALS: [f64; 5] = [1.0, -1.0, 2.0, -6.0, 24.0];
    let u = s.value();
    let mut derivs = [0.0_f64; 5];
    // `power` walks u, u^2, ..., u^5 using the same left-to-right products as
    // an explicit `u * u * u ...` chain.
    let mut power = u;
    for (slot, &coeff) in derivs.iter_mut().zip(SIGNED_FACTORIALS.iter()) {
        *slot = coeff / power;
        power *= u;
    }
    s.compose_unary(derivs)
}
/// Sentinel stored in a coordinate-slot table for a latent axis that has no
/// jet primary. Basis towers skip every first- and second-order term that
/// involves an axis mapped to this value.
pub const SAE_FIXED_COORD_SLOT: usize = usize::MAX;
/// How a row's per-atom gates are computed from its logits.
#[derive(Debug, Clone, Copy)]
pub enum RowGate {
    /// Softmax over the atoms' logits scaled by `inv_tau`:
    /// `gate_k = exp(inv_tau * l_k) / sum_j exp(inv_tau * l_j)`.
    Softmax { inv_tau: f64 },
    /// Independent logistic gate per atom:
    /// `gate_k = gate_scale[k] * sigmoid(inv_tau * (l_k - gate_shift[k]))`.
    PerAtomLogistic { inv_tau: f64 },
}
/// Local second-order model of one atom's basis functions in the latent
/// displacement (the jet variables are seeded at 0), plus the decoder that
/// maps basis values to output columns.
#[derive(Debug, Clone)]
pub struct AtomRowBasisJet {
    /// Basis values `phi[basis_col]`; its length is the basis count.
    pub phi: Vec<f64>,
    /// First derivatives `d_phi[basis_col][axis]` along each latent axis.
    pub d_phi: Vec<Vec<f64>>,
    /// Second derivatives `d2_phi[basis_col][axis_a][axis_b]`. The towers add
    /// `0.5 * d2_phi[..][a][b]` for every ordered pair `(a, b)`, so the Hessian
    /// is reproduced only if this table is symmetric — presumably guaranteed
    /// by the producer; confirm.
    pub d2_phi: Vec<Vec<Vec<f64>>>,
    /// Decoder weights `decoder[basis_col][out_col]`; the output width is the
    /// length of the first row.
    pub decoder: Vec<Vec<f64>>,
    /// Number of latent axes read from `d_phi`, `d2_phi` and the slot map.
    pub latent_dim: usize,
}
impl AtomRowBasisJet {
    /// Number of basis functions (the length of `phi`).
    fn n_basis(&self) -> usize {
        Vec::len(&self.phi)
    }

    /// Output width: length of the first decoder row, or 0 with no basis.
    fn out_dim(&self) -> usize {
        match self.decoder.first() {
            Some(row) => row.len(),
            None => 0,
        }
    }

    /// Second-order jet of basis function `basis_col` in the latent
    /// displacement:
    /// `phi + sum_a d_phi[a] du_a + 0.5 * sum_{a,b} d2_phi[a][b] du_a du_b`.
    ///
    /// Latent axis `a` is the jet variable `coord_slots[a]` seeded at 0.
    /// Terms whose coefficient is exactly zero, or that touch an axis mapped
    /// to `SAE_FIXED_COORD_SLOT`, are skipped.
    fn basis_tower<const K: usize, S: JetScalar<K>>(
        &self,
        basis_col: usize,
        coord_slots: &[usize],
    ) -> S {
        let is_fixed = |slot: usize| slot == SAE_FIXED_COORD_SLOT;
        let mut jet = S::constant(self.phi[basis_col]);

        // Linear terms.
        for axis in 0..self.latent_dim {
            let slot = coord_slots[axis];
            let coeff = self.d_phi[basis_col][axis];
            if coeff == 0.0 || is_fixed(slot) {
                continue;
            }
            jet = jet.add(&S::variable(0.0, slot).scale(coeff));
        }

        // Quadratic terms. The full (a, b) square is visited, so each
        // off-diagonal pair is contributed from both orders with weight 1/2.
        for axis_a in 0..self.latent_dim {
            for axis_b in 0..self.latent_dim {
                let curvature = self.d2_phi[basis_col][axis_a][axis_b];
                if curvature == 0.0 {
                    continue;
                }
                let (slot_a, slot_b) = (coord_slots[axis_a], coord_slots[axis_b]);
                if is_fixed(slot_a) || is_fixed(slot_b) {
                    continue;
                }
                let product = S::variable(0.0, slot_a).mul(&S::variable(0.0, slot_b));
                jet = jet.add(&product.scale(0.5 * curvature));
            }
        }
        jet
    }

    /// Jet of output column `out_col`:
    /// `sum_b decoder[b][out_col] * basis_tower(b)`, in basis order, skipping
    /// exactly-zero weights.
    fn decoded_tower<const K: usize, S: JetScalar<K>>(
        &self,
        out_col: usize,
        coord_slots: &[usize],
    ) -> S {
        (0..self.n_basis()).fold(S::constant(0.0), |sum, basis_col| {
            let weight = self.decoder[basis_col][out_col];
            if weight == 0.0 {
                sum
            } else {
                sum.add(&self.basis_tower::<K, S>(basis_col, coord_slots).scale(weight))
            }
        })
    }
}
/// One row's reconstruction as a gated mixture of atoms, expressed over
/// `n_primaries` jet primaries (gate logits and latent coordinates).
///
/// Output column `c` is
/// `sum_atom gate_atom * sum_b atoms[atom].decoder[b][c] * phi_b(atom)`, where
/// the gates come from `logits` via `gate` and each `phi_b` is the local
/// quadratic model stored in `AtomRowBasisJet`.
#[derive(Debug, Clone)]
pub struct SaeReconstructionRowProgram {
    /// Per-atom basis model and decoder.
    pub atoms: Vec<AtomRowBasisJet>,
    /// Gate values of the row. The gate builders here only read its length,
    /// which sets how many logits enter the softmax (and how many gates are
    /// produced). The tests fill it with the softmax of the scaled logits;
    /// presumably production rows do too — confirm.
    pub gate_value: Vec<f64>,
    /// Raw gate logits, indexed by atom.
    pub logits: Vec<f64>,
    /// Output multiplier of each atom's logistic gate (`PerAtomLogistic` only).
    pub gate_scale: Vec<f64>,
    /// Offset subtracted from each logit before scaling (`PerAtomLogistic` only).
    pub gate_shift: Vec<f64>,
    /// Gate family and its `inv_tau`.
    pub gate: RowGate,
    /// Primary slot of each atom's logit, or `None` when the logit is held constant.
    pub logit_slot: Vec<Option<usize>>,
    /// For each atom, the primary slot of each latent axis, or
    /// `SAE_FIXED_COORD_SLOT` for an axis without a primary.
    pub coord_slot: Vec<Vec<usize>>,
    /// Number of jet primaries; the tower builders require it to equal the arity `K`.
    pub n_primaries: usize,
}
impl SaeReconstructionRowProgram {
    /// Panics unless the jet arity `K` equals `n_primaries`.
    ///
    /// Every public tower builder routes through this one check so the
    /// contract and its message live in one place. `#[track_caller]` reports
    /// the panic at the calling builder.
    #[track_caller]
    fn assert_tower_arity<const K: usize>(&self) {
        assert_eq!(
            self.n_primaries, K,
            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
            self.n_primaries
        );
    }

    /// Stabilizing softmax shift `max_j(inv_tau * logits[j])`.
    ///
    /// Subtracting the largest *scaled* logit keeps every exponent at or below
    /// zero (up to rounding) for either sign of `inv_tau`. Scaling the largest
    /// raw logit instead only works for `inv_tau >= 0`: with a negative
    /// `inv_tau` it selects the smallest scaled logit, and the remaining
    /// exponentials can overflow. For `inv_tau > 0` both forms are
    /// bit-identical because IEEE rounding is monotone. The shift cancels in
    /// the softmax ratio.
    fn scaled_logit_shift(&self, inv_tau: f64) -> f64 {
        self.logits
            .iter()
            .map(|&logit| logit * inv_tau)
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Jet of logit `j`. It is a variable seeded at the logit when
    /// `logit_slot[j]` names a primary, and a constant otherwise.
    fn logit_jet<const K: usize, S: JetScalar<K>>(&self, j: usize) -> S {
        match self.logit_slot[j] {
            Some(slot) => S::variable(self.logits[j], slot),
            None => S::constant(self.logits[j]),
        }
    }

    /// Jet of atom `atom`'s gate as a function of the primaries.
    ///
    /// * `Softmax`: `exp(inv_tau*l_atom - shift) / sum_j exp(inv_tau*l_j - shift)`,
    ///   with `j` ranging over `0..gate_value.len()`.
    /// * `PerAtomLogistic`: `gate_scale[atom] * sigmoid(inv_tau * (l_atom - gate_shift[atom]))`,
    ///   using whichever branch exponentiates a non-positive argument.
    fn gate_tower<const K: usize, S: JetScalar<K>>(&self, atom: usize) -> S {
        match self.gate {
            RowGate::Softmax { inv_tau } => {
                let shift = self.scaled_logit_shift(inv_tau);
                let mut denom = S::constant(0.0);
                let mut numer = S::constant(0.0);
                for j in 0..self.gate_value.len() {
                    let ej = self
                        .logit_jet::<K, S>(j)
                        .scale(inv_tau)
                        .sub(&S::constant(shift))
                        .exp();
                    if j == atom {
                        numer = ej;
                    }
                    denom = denom.add(&ej);
                }
                numer.mul(&recip(&denom))
            }
            RowGate::PerAtomLogistic { inv_tau } => {
                let x = self
                    .logit_jet::<K, S>(atom)
                    .sub(&S::constant(self.gate_shift[atom]))
                    .scale(inv_tau);
                let one = S::constant(1.0);
                // Stable logistic: only ever exponentiate -|x|.
                let sigma = if x.value() >= 0.0 {
                    one.mul(&recip(&one.add(&x.scale(-1.0).exp())))
                } else {
                    let ex = x.exp();
                    ex.mul(&recip(&one.add(&ex)))
                };
                sigma.scale(self.gate_scale[atom])
            }
        }
    }

    /// Gate jets for all atoms at once.
    ///
    /// For `Softmax`, the exponentials and the reciprocal of their sum are
    /// built once and shared. The result is bit-identical to calling
    /// `gate_tower` per atom. `PerAtomLogistic` gates are independent and
    /// simply delegate.
    fn all_gates<const K: usize, S: JetScalar<K>>(&self) -> Vec<S> {
        let n = self.gate_value.len();
        match self.gate {
            RowGate::Softmax { inv_tau } => {
                let shift = self.scaled_logit_shift(inv_tau);
                let mut exps: Vec<S> = Vec::with_capacity(n);
                let mut denom = S::constant(0.0);
                for j in 0..n {
                    let ej = self
                        .logit_jet::<K, S>(j)
                        .scale(inv_tau)
                        .sub(&S::constant(shift))
                        .exp();
                    denom = denom.add(&ej);
                    exps.push(ej);
                }
                let inv = recip(&denom);
                exps.iter().map(|e| e.mul(&inv)).collect()
            }
            RowGate::PerAtomLogistic { .. } => {
                (0..n).map(|atom| self.gate_tower::<K, S>(atom)).collect()
            }
        }
    }

    /// Generic jet of reconstruction column `out_col`:
    /// `sum_atom gate_atom * decoded_atom(out_col)`.
    fn reconstruction_column_generic<const K: usize, S: JetScalar<K>>(&self, out_col: usize) -> S {
        self.assert_tower_arity::<K>();
        let mut acc = S::constant(0.0);
        for (atom, atom_jet) in self.atoms.iter().enumerate() {
            let gate = self.gate_tower::<K, S>(atom);
            let decoded = atom_jet.decoded_tower::<K, S>(out_col, &self.coord_slot[atom]);
            acc = acc.add(&gate.mul(&decoded));
        }
        acc
    }

    /// Second-order jet (value, gradient, Hessian) of reconstruction column
    /// `out_col`.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn reconstruction_column_packed<const K: usize>(&self, out_col: usize) -> Order2<K> {
        self.reconstruction_column_generic::<K, Order2<K>>(out_col)
    }

    /// Second-order jets of every reconstruction column.
    ///
    /// Gates and basis towers are built once and reused across columns. The
    /// result is bit-identical to calling `reconstruction_column_packed` per
    /// column.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn reconstruction_all_columns_packed<const K: usize>(&self) -> Vec<Order2<K>> {
        self.assert_tower_arity::<K>();
        let p = self.out_dim();
        let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
        let bases: Vec<Vec<Order2<K>>> = self
            .atoms
            .iter()
            .enumerate()
            .map(|(atom, atom_jet)| {
                (0..atom_jet.n_basis())
                    .map(|b| atom_jet.basis_tower::<K, Order2<K>>(b, &self.coord_slot[atom]))
                    .collect()
            })
            .collect();
        (0..p)
            .map(|c| {
                let mut acc = Order2::<K>::constant(0.0);
                for (atom, atom_jet) in self.atoms.iter().enumerate() {
                    let mut decoded = Order2::<K>::constant(0.0);
                    for basis_col in 0..atom_jet.n_basis() {
                        let coeff = atom_jet.decoder[basis_col][c];
                        if coeff == 0.0 {
                            continue;
                        }
                        decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
                    }
                    acc = acc.add(&gates[atom].mul(&decoded));
                }
                acc
            })
            .collect()
    }

    /// Fourth-order tower of reconstruction column `out_col`.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn reconstruction_column<const K: usize>(&self, out_col: usize) -> Tower4<K> {
        self.reconstruction_column_generic::<K, Tower4<K>>(out_col)
    }

    /// Generic jet of the decoder-border channel `gate_atom * phi_{basis_col}`.
    fn beta_border_generic<const K: usize, S: JetScalar<K>>(
        &self,
        atom: usize,
        basis_col: usize,
    ) -> S {
        self.assert_tower_arity::<K>();
        let gate = self.gate_tower::<K, S>(atom);
        let phi = self.atoms[atom].basis_tower::<K, S>(basis_col, &self.coord_slot[atom]);
        gate.mul(&phi)
    }

    /// Second-order jet of `gate_atom * phi_{basis_col}`.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn beta_border_tower_packed<const K: usize>(
        &self,
        atom: usize,
        basis_col: usize,
    ) -> Order2<K> {
        self.beta_border_generic::<K, Order2<K>>(atom, basis_col)
    }

    /// Fourth-order tower of `gate_atom * phi_{basis_col}`.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn beta_border_tower<const K: usize>(&self, atom: usize, basis_col: usize) -> Tower4<K> {
        self.beta_border_generic::<K, Tower4<K>>(atom, basis_col)
    }

    /// Second-order border jets for each `(atom, basis_col)` in `channels`,
    /// sharing one gate computation.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn beta_border_towers_packed<const K: usize>(
        &self,
        channels: &[(usize, usize)],
    ) -> Vec<Order2<K>> {
        self.assert_tower_arity::<K>();
        let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
        channels
            .iter()
            .map(|&(atom, basis_col)| {
                let phi =
                    self.atoms[atom].basis_tower::<K, Order2<K>>(basis_col, &self.coord_slot[atom]);
                gates[atom].mul(&phi)
            })
            .collect()
    }

    /// First-order border jets for each `(atom, basis_col)` in `channels`,
    /// sharing one gate computation.
    ///
    /// # Panics
    /// If `K != n_primaries`.
    #[must_use]
    pub fn beta_border_order1_packed<const K: usize>(
        &self,
        channels: &[(usize, usize)],
    ) -> Vec<Order1<K>> {
        self.assert_tower_arity::<K>();
        let gates: Vec<Order1<K>> = self.all_gates::<K, Order1<K>>();
        channels
            .iter()
            .map(|&(atom, basis_col)| {
                let phi =
                    self.atoms[atom].basis_tower::<K, Order1<K>>(basis_col, &self.coord_slot[atom]);
                gates[atom].mul(&phi)
            })
            .collect()
    }

    /// Output width, taken from the first atom (0 when there are no atoms).
    #[must_use]
    pub fn out_dim(&self) -> usize {
        self.atoms.first().map_or(0, AtomRowBasisJet::out_dim)
    }
}
/// Number of rows evaluated side by side by the lane-batched jet kernels.
const LANES: usize = 4;

/// Broadcast `x` into every lane.
#[inline]
fn l_splat(x: f64) -> [f64; LANES] {
    [x; LANES]
}

/// Lane-wise `a + b`.
#[inline]
fn l_add(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
    let mut sum = a;
    for (lhs, rhs) in sum.iter_mut().zip(b.iter()) {
        *lhs += *rhs;
    }
    sum
}

/// Lane-wise `a * b`.
#[inline]
fn l_mul(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
    let mut product = a;
    for (lhs, rhs) in product.iter_mut().zip(b.iter()) {
        *lhs *= *rhs;
    }
    product
}
/// Second-order jet in `K` primaries evaluated on `LANES` rows at once:
/// value, gradient and Hessian, each stored lane-wise.
#[derive(Clone, Copy)]
struct O2x4<const K: usize> {
    /// Value, one per lane.
    v: [f64; LANES],
    /// First derivative along primary `i`, per lane.
    g: [[f64; LANES]; K],
    /// Second derivative along primaries `(i, j)`, per lane.
    h: [[[f64; LANES]; K]; K],
}
// Lane-batched second-order jet arithmetic. Each operation applies the scalar
// rule independently per lane.
// NOTE(review): the exact operation order (including the `l_splat(0.0)`
// accumulator starts in `compose`) presumably mirrors the scalar `Order2`
// kernels so that extracted lanes match per-row results bit-for-bit — confirm
// before simplifying any expression here.
impl<const K: usize> O2x4<K> {
    /// Constant jet: lane values `c`, zero gradient and Hessian.
    #[inline]
    fn constant(c: [f64; LANES]) -> Self {
        O2x4 {
            v: c,
            g: [[0.0; LANES]; K],
            h: [[[0.0; LANES]; K]; K],
        }
    }
    /// Independent variable along primary `axis` with lane values `value`
    /// (unit gradient in that slot on every lane). Panics if `axis >= K`.
    #[inline]
    fn variable(value: [f64; LANES], axis: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[axis] = l_splat(1.0);
        out
    }
    /// Component-wise, lane-wise sum.
    #[inline]
    fn add(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = l_add(self.v, o.v);
        for i in 0..K {
            out.g[i] = l_add(self.g[i], o.g[i]);
            for j in 0..K {
                out.h[i][j] = l_add(self.h[i][j], o.h[i][j]);
            }
        }
        out
    }
    /// Multiply every component by the per-lane factor `s`.
    #[inline]
    fn scale(&self, s: [f64; LANES]) -> Self {
        let mut out = *self;
        out.v = l_mul(self.v, s);
        for i in 0..K {
            out.g[i] = l_mul(self.g[i], s);
            for j in 0..K {
                out.h[i][j] = l_mul(self.h[i][j], s);
            }
        }
        out
    }
    /// `self - o`, formed as `self + o * (-1)`.
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        self.add(&o.scale(l_splat(-1.0)))
    }
    /// Product rule to second order:
    /// `g_i = a.v*b.g_i + a.g_i*b.v`,
    /// `h_ij = a.v*b.h_ij + a.g_i*b.g_j + a.g_j*b.g_i + a.h_ij*b.v`.
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::constant(l_mul(a.v, b.v));
        for i in 0..K {
            out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
        }
        for i in 0..K {
            for j in 0..K {
                let t0 = l_mul(a.v, b.h[i][j]);
                let t1 = l_add(t0, l_mul(a.g[i], b.g[j]));
                let t2 = l_add(t1, l_mul(a.g[j], b.g[i]));
                out.h[i][j] = l_add(t2, l_mul(a.h[i][j], b.v));
            }
        }
        out
    }
    /// Chain rule for a unary `f`, given per-lane `d = [f(v), f'(v), f''(v)]`:
    /// value `d[0]`, gradient `d[1]*g_i`, Hessian `d[1]*h_ij + d[2]*g_i*g_j`.
    #[inline]
    fn compose(&self, d: [[f64; LANES]; 3]) -> Self {
        let mut out = Self::constant(d[0]);
        for i in 0..K {
            let mut acc = l_splat(0.0);
            acc = l_add(acc, l_mul(d[1], self.g[i]));
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = l_splat(0.0);
                acc = l_add(acc, l_mul(d[1], self.h[i][j]));
                acc = l_add(acc, l_mul(l_mul(d[2], self.g[i]), self.g[j]));
                out.h[i][j] = acc;
            }
        }
        out
    }
    /// Lane-wise `exp`; every derivative of `exp` is `exp` itself.
    #[inline]
    fn exp(&self) -> Self {
        let mut e = [0.0; LANES];
        for i in 0..LANES {
            e[i] = self.v[i].exp();
        }
        self.compose([e, e, e])
    }
    /// Lane-wise `1/v`, with derivatives `-1/v^2` and `2/v^3`.
    #[inline]
    fn recip(&self) -> Self {
        let mut d0 = [0.0; LANES];
        let mut d1 = [0.0; LANES];
        let mut d2 = [0.0; LANES];
        for i in 0..LANES {
            let u = self.v[i];
            let u2 = u * u;
            let u3 = u2 * u;
            d0[i] = 1.0 / u;
            d1[i] = -1.0 / u2;
            d2[i] = 2.0 / u3;
        }
        self.compose([d0, d1, d2])
    }
    /// Extract lane `i` as a scalar `Order2` jet.
    #[inline]
    fn lane(&self, i: usize) -> Order2<K> {
        let mut t = gam_math::jet_tower::Tower2::<K>::constant(self.v[i]);
        for a in 0..K {
            t.g[a] = self.g[a][i];
            for b in 0..K {
                t.h[a][b] = self.h[a][b][i];
            }
        }
        Order2(t)
    }
}
/// First-order jet in `K` primaries evaluated on `LANES` rows at once:
/// value and gradient, each stored lane-wise.
#[derive(Clone, Copy)]
struct O1x4<const K: usize> {
    /// Value, one per lane.
    v: [f64; LANES],
    /// First derivative along primary `i`, per lane.
    g: [[f64; LANES]; K],
}
// Lane-batched first-order jet arithmetic. Each operation applies the scalar
// rule independently per lane.
// NOTE(review): the operation order (including the `l_splat(0.0)` accumulator
// start in `compose`) presumably mirrors the scalar `Order1` kernels so that
// extracted lanes match per-row results bit-for-bit — confirm before
// simplifying.
impl<const K: usize> O1x4<K> {
    /// Constant jet: lane values `c`, zero gradient.
    #[inline]
    fn constant(c: [f64; LANES]) -> Self {
        O1x4 {
            v: c,
            g: [[0.0; LANES]; K],
        }
    }
    /// Independent variable along primary `axis` with lane values `value`.
    /// Panics if `axis >= K`.
    #[inline]
    fn variable(value: [f64; LANES], axis: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[axis] = l_splat(1.0);
        out
    }
    /// Component-wise, lane-wise sum.
    #[inline]
    fn add(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = l_add(self.v, o.v);
        for i in 0..K {
            out.g[i] = l_add(self.g[i], o.g[i]);
        }
        out
    }
    /// Multiply every component by the per-lane factor `s`.
    #[inline]
    fn scale(&self, s: [f64; LANES]) -> Self {
        let mut out = *self;
        out.v = l_mul(self.v, s);
        for i in 0..K {
            out.g[i] = l_mul(self.g[i], s);
        }
        out
    }
    /// `self - o`, formed as `self + o * (-1)`.
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        self.add(&o.scale(l_splat(-1.0)))
    }
    /// Product rule: `g_i = a.v*b.g_i + a.g_i*b.v`.
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::constant(l_mul(a.v, b.v));
        for i in 0..K {
            out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
        }
        out
    }
    /// Chain rule for a unary `f`, given per-lane `d = [f(v), f'(v)]`.
    #[inline]
    fn compose(&self, d: [[f64; LANES]; 2]) -> Self {
        let mut out = Self::constant(d[0]);
        for i in 0..K {
            let mut acc = l_splat(0.0);
            acc = l_add(acc, l_mul(d[1], self.g[i]));
            out.g[i] = acc;
        }
        out
    }
    /// Lane-wise `exp`.
    #[inline]
    fn exp(&self) -> Self {
        let mut e = [0.0; LANES];
        for i in 0..LANES {
            e[i] = self.v[i].exp();
        }
        self.compose([e, e])
    }
    /// Lane-wise `1/v`, with derivative `-1/v^2`.
    #[inline]
    fn recip(&self) -> Self {
        let mut d0 = [0.0; LANES];
        let mut d1 = [0.0; LANES];
        for i in 0..LANES {
            let u = self.v[i];
            let u2 = u * u;
            d0[i] = 1.0 / u;
            d1[i] = -1.0 / u2;
        }
        self.compose([d0, d1])
    }
    /// Extract lane `i` as a scalar `Order1` jet.
    #[inline]
    fn lane(&self, i: usize) -> Order1<K> {
        let mut g = [0.0; K];
        for a in 0..K {
            g[a] = self.g[a][i];
        }
        Order1 { v: self.v[i], g }
    }
}
impl SaeReconstructionRowProgram {
    /// Whether `other` can share one lane-batched softmax evaluation with
    /// `self`.
    ///
    /// Both rows must meet all of these conditions:
    /// * use `RowGate::Softmax` with bit-identical `inv_tau`;
    /// * have the same primary count and identical logit/coordinate slot maps;
    /// * have the same number of logits and gates;
    /// * agree per atom on latent dimension, basis count and output width.
    ///
    /// Only numeric contents (logits, basis jets, decoder weights) may differ.
    fn batch_aligned_softmax_with(&self, other: &Self) -> bool {
        match (self.gate, other.gate) {
            (RowGate::Softmax { inv_tau: a }, RowGate::Softmax { inv_tau: b }) => {
                if a.to_bits() != b.to_bits() {
                    return false;
                }
            }
            _ => return false,
        }
        // `gate_value.len()` fixes the softmax range `n` in the batched gate
        // kernels, which read it from the head row only. A row with a
        // different count would silently be evaluated over the wrong atoms.
        if self.n_primaries != other.n_primaries
            || self.atoms.len() != other.atoms.len()
            || self.logit_slot != other.logit_slot
            || self.coord_slot != other.coord_slot
            || self.logits.len() != other.logits.len()
            || self.gate_value.len() != other.gate_value.len()
        {
            return false;
        }
        self.atoms.iter().zip(other.atoms.iter()).all(|(a, b)| {
            a.latent_dim == b.latent_dim && a.n_basis() == b.n_basis() && a.out_dim() == b.out_dim()
        })
    }

    /// Per-lane stabilizing softmax shift `max_j(inv_tau * logits[j])`, one
    /// entry per row.
    ///
    /// Taking the max of the *scaled* logits keeps every exponent at or below
    /// zero for either sign of `inv_tau`. For `inv_tau > 0` it is
    /// bit-identical to scaling the largest raw logit.
    fn lane_scaled_logit_shifts(rows: &[&Self; LANES], inv_tau: f64) -> [f64; LANES] {
        let mut shift = [0.0; LANES];
        for (lane, r) in rows.iter().enumerate() {
            shift[lane] = r
                .logits
                .iter()
                .map(|&logit| logit * inv_tau)
                .fold(f64::NEG_INFINITY, f64::max);
        }
        shift
    }

    /// Logit `j` of every row, one per lane.
    fn lane_logit_values(rows: &[&Self; LANES], j: usize) -> [f64; LANES] {
        let mut vals = [0.0; LANES];
        for (lane, r) in rows.iter().enumerate() {
            vals[lane] = r.logits[j];
        }
        vals
    }

    /// Lane-batched second-order softmax gates for all atoms.
    ///
    /// `self` is the head row; it supplies the gate count and logit slots.
    fn all_gates_o2x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O2x4<K>> {
        let n = self.gate_value.len();
        let inv_tau_l = l_splat(inv_tau);
        let shift = Self::lane_scaled_logit_shifts(rows, inv_tau);
        let mut exps: Vec<O2x4<K>> = Vec::with_capacity(n);
        let mut denom = O2x4::<K>::constant(l_splat(0.0));
        for j in 0..n {
            let lj_val = Self::lane_logit_values(rows, j);
            let lj = match self.logit_slot[j] {
                Some(slot) => O2x4::<K>::variable(lj_val, slot),
                None => O2x4::<K>::constant(lj_val),
            };
            let ej = lj.scale(inv_tau_l).sub(&O2x4::<K>::constant(shift)).exp();
            denom = denom.add(&ej);
            exps.push(ej);
        }
        let inv = denom.recip();
        exps.iter().map(|e| e.mul(&inv)).collect()
    }

    /// Lane-batched first-order softmax gates for all atoms.
    ///
    /// `self` is the head row; it supplies the gate count and logit slots.
    fn all_gates_o1x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O1x4<K>> {
        let n = self.gate_value.len();
        let inv_tau_l = l_splat(inv_tau);
        let shift = Self::lane_scaled_logit_shifts(rows, inv_tau);
        let mut exps: Vec<O1x4<K>> = Vec::with_capacity(n);
        let mut denom = O1x4::<K>::constant(l_splat(0.0));
        for j in 0..n {
            let lj_val = Self::lane_logit_values(rows, j);
            let lj = match self.logit_slot[j] {
                Some(slot) => O1x4::<K>::variable(lj_val, slot),
                None => O1x4::<K>::constant(lj_val),
            };
            let ej = lj.scale(inv_tau_l).sub(&O1x4::<K>::constant(shift)).exp();
            denom = denom.add(&ej);
            exps.push(ej);
        }
        let inv = denom.recip();
        exps.iter().map(|e| e.mul(&inv)).collect()
    }

    /// Lane-batched counterpart of `AtomRowBasisJet::basis_tower` (second order).
    ///
    /// Terms are visited in the same order. A term is skipped only when its
    /// coefficient is exactly zero on every lane. `coord_slots` is the slot map
    /// for `atom`, shared by all aligned rows.
    fn basis_tower_o2x4<const K: usize>(
        rows: &[&Self; LANES],
        atom: usize,
        basis_col: usize,
        coord_slots: &[usize],
    ) -> O2x4<K> {
        let latent = rows[0].atoms[atom].latent_dim;
        let mut phi0 = [0.0; LANES];
        for (lane, r) in rows.iter().enumerate() {
            phi0[lane] = r.atoms[atom].phi[basis_col];
        }
        let mut acc = O2x4::<K>::constant(phi0);
        for axis in 0..latent {
            let slot = coord_slots[axis];
            if slot == SAE_FIXED_COORD_SLOT {
                continue;
            }
            let mut d1 = [0.0; LANES];
            let mut any = false;
            for (lane, r) in rows.iter().enumerate() {
                let v = r.atoms[atom].d_phi[basis_col][axis];
                d1[lane] = v;
                any |= v != 0.0;
            }
            if any {
                acc = acc.add(&O2x4::<K>::variable(l_splat(0.0), slot).scale(d1));
            }
        }
        for axis_a in 0..latent {
            for axis_b in 0..latent {
                if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
                    || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
                {
                    continue;
                }
                let mut d2 = [0.0; LANES];
                let mut any = false;
                for (lane, r) in rows.iter().enumerate() {
                    let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
                    d2[lane] = v;
                    any |= v != 0.0;
                }
                if !any {
                    continue;
                }
                let mut half_d2 = [0.0; LANES];
                for lane in 0..LANES {
                    half_d2[lane] = 0.5 * d2[lane];
                }
                let va = O2x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
                let vb = O2x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
                acc = acc.add(&va.mul(&vb).scale(half_d2));
            }
        }
        acc
    }

    /// Lane-batched counterpart of `AtomRowBasisJet::basis_tower` (first order).
    ///
    /// Terms are visited in the same order as `basis_tower_o2x4`. The quadratic
    /// terms, seeded at 0, contribute nothing to the first-order result.
    fn basis_tower_o1x4<const K: usize>(
        rows: &[&Self; LANES],
        atom: usize,
        basis_col: usize,
        coord_slots: &[usize],
    ) -> O1x4<K> {
        let latent = rows[0].atoms[atom].latent_dim;
        let mut phi0 = [0.0; LANES];
        for (lane, r) in rows.iter().enumerate() {
            phi0[lane] = r.atoms[atom].phi[basis_col];
        }
        let mut acc = O1x4::<K>::constant(phi0);
        for axis in 0..latent {
            let slot = coord_slots[axis];
            if slot == SAE_FIXED_COORD_SLOT {
                continue;
            }
            let mut d1 = [0.0; LANES];
            let mut any = false;
            for (lane, r) in rows.iter().enumerate() {
                let v = r.atoms[atom].d_phi[basis_col][axis];
                d1[lane] = v;
                any |= v != 0.0;
            }
            if any {
                acc = acc.add(&O1x4::<K>::variable(l_splat(0.0), slot).scale(d1));
            }
        }
        for axis_a in 0..latent {
            for axis_b in 0..latent {
                if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
                    || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
                {
                    continue;
                }
                let mut d2 = [0.0; LANES];
                let mut any = false;
                for (lane, r) in rows.iter().enumerate() {
                    let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
                    d2[lane] = v;
                    any |= v != 0.0;
                }
                if !any {
                    continue;
                }
                let mut half_d2 = [0.0; LANES];
                for lane in 0..LANES {
                    half_d2[lane] = 0.5 * d2[lane];
                }
                let va = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
                let vb = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
                acc = acc.add(&va.mul(&vb).scale(half_d2));
            }
        }
        acc
    }

    /// Shared prologue of the `*_batch4` entry points.
    ///
    /// Returns the common softmax `inv_tau` when all three conditions hold:
    /// * the head row has arity `K`;
    /// * it uses a softmax gate;
    /// * every other row is aligned with it (see `batch_aligned_softmax_with`).
    ///
    /// Returns `None` otherwise.
    fn batch4_softmax_inv_tau<const K: usize>(rows: &[&Self; LANES]) -> Option<f64> {
        let head = rows[0];
        if head.n_primaries != K {
            return None;
        }
        let inv_tau = match head.gate {
            RowGate::Softmax { inv_tau } => inv_tau,
            RowGate::PerAtomLogistic { .. } => return None,
        };
        if rows[1..].iter().all(|r| head.batch_aligned_softmax_with(r)) {
            Some(inv_tau)
        } else {
            None
        }
    }

    /// Lane-batched `reconstruction_all_columns_packed` for four rows.
    ///
    /// Element `i` of the result holds the column jets of `rows[i]`. Returns
    /// `None` in any of these cases:
    /// * the arity differs from `K`;
    /// * the gate is not softmax;
    /// * the rows are not aligned in slot layout and atom shapes.
    #[must_use]
    pub fn reconstruction_all_columns_batch4<const K: usize>(
        rows: [&Self; 4],
    ) -> Option<[Vec<Order2<K>>; 4]> {
        let inv_tau = Self::batch4_softmax_inv_tau::<K>(&rows)?;
        let head = rows[0];
        let p = head.out_dim();
        let gates: Vec<O2x4<K>> = head.all_gates_o2x4::<K>(&rows, inv_tau);
        let bases: Vec<Vec<O2x4<K>>> = head
            .atoms
            .iter()
            .enumerate()
            .map(|(atom, atom_jet)| {
                (0..atom_jet.n_basis())
                    .map(|b| Self::basis_tower_o2x4::<K>(&rows, atom, b, &head.coord_slot[atom]))
                    .collect()
            })
            .collect();
        let mut cols: [Vec<Order2<K>>; LANES] =
            [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
        for c in 0..p {
            let mut acc = O2x4::<K>::constant(l_splat(0.0));
            for (atom, atom_jet) in head.atoms.iter().enumerate() {
                let mut decoded = O2x4::<K>::constant(l_splat(0.0));
                for basis_col in 0..atom_jet.n_basis() {
                    let mut coeff = [0.0; LANES];
                    let mut any = false;
                    for (lane, r) in rows.iter().enumerate() {
                        let v = r.atoms[atom].decoder[basis_col][c];
                        coeff[lane] = v;
                        any |= v != 0.0;
                    }
                    if any {
                        decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
                    }
                }
                acc = acc.add(&gates[atom].mul(&decoded));
            }
            for lane in 0..LANES {
                cols[lane].push(acc.lane(lane));
            }
        }
        Some(cols)
    }

    /// Lane-batched `beta_border_order1_packed` for four rows.
    ///
    /// Element `i` of the result holds the border jets of `rows[i]`. `None`
    /// is returned under the same conditions as
    /// `reconstruction_all_columns_batch4`.
    #[must_use]
    pub fn beta_border_order1_batch4<const K: usize>(
        rows: [&Self; 4],
        channels: &[(usize, usize)],
    ) -> Option<[Vec<Order1<K>>; 4]> {
        let inv_tau = Self::batch4_softmax_inv_tau::<K>(&rows)?;
        let head = rows[0];
        let gates: Vec<O1x4<K>> = head.all_gates_o1x4::<K>(&rows, inv_tau);
        let mut out: [Vec<Order1<K>>; LANES] =
            [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
        for &(atom, basis_col) in channels {
            let phi = Self::basis_tower_o1x4::<K>(&rows, atom, basis_col, &head.coord_slot[atom]);
            let s = gates[atom].mul(&phi);
            for lane in 0..LANES {
                out[lane].push(s.lane(lane));
            }
        }
        Some(out)
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Hand-derived value, gradient (`first`) and Hessian (`second`) of one
/// reconstruction column, indexed by primary slot.
struct HandChannels {
    first: Vec<f64>,
    second: Vec<Vec<f64>>,
    value: f64,
}
/// Closed-form softmax gate derivatives with respect to the raw logits.
///
/// Returns `(dz, d2z)` where `dz[j][k] = d gate_k / d l_j` and
/// `d2z[j][l][k] = d^2 gate_k / d l_j d l_l`, for gates
/// `gate_k = softmax(inv_tau * l)_k`.
fn softmax_gate_derivs(gate: &[f64], inv_tau: f64) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
    let n = gate.len();
    let kron = |p: usize, q: usize| -> f64 { if p == q { 1.0 } else { 0.0 } };
    let first: Vec<Vec<f64>> = (0..n)
        .map(|j| {
            (0..n)
                .map(|kk| gate[kk] * (kron(kk, j) - gate[j]) * inv_tau)
                .collect()
        })
        .collect();
    let second: Vec<Vec<Vec<f64>>> = (0..n)
        .map(|j| {
            (0..n)
                .map(|l| {
                    (0..n)
                        .map(|kk| {
                            gate[kk]
                                * ((kron(kk, l) - gate[l]) * (kron(kk, j) - gate[j])
                                    - gate[j] * (kron(j, l) - gate[l]))
                                * inv_tau
                                * inv_tau
                        })
                        .collect()
                })
                .collect()
        })
        .collect();
    (first, second)
}
/// Independent hand derivation of one softmax reconstruction column's value,
/// gradient and Hessian over the primaries, used as the reference for the
/// tower.
///
/// The reconstruction is `sum_k gate_k * decoded_k`, where `decoded_k` is the
/// decoder-weighted basis sum of atom `k`. Its derivatives combine the
/// closed-form gate derivatives (`softmax_gate_derivs`) with the basis
/// derivatives, block by block: logit-logit, logit-coordinate and
/// coordinate-coordinate.
/// NOTE(review): assumes no coordinate maps to `SAE_FIXED_COORD_SLOT`
/// (`coord_idx` indexes `first`/`second` directly), and blocks are written in
/// order, so colliding slots would let later blocks overwrite earlier ones.
fn hand_softmax_column(
    prog: &SaeReconstructionRowProgram,
    out_col: usize,
    inv_tau: f64,
) -> HandChannels {
    let k = prog.atoms.len();
    let n = prog.n_primaries;
    // decoded[kk] = sum_b phi_b * decoder[b][out_col]
    let decoded: Vec<f64> = (0..k)
        .map(|kk| {
            (0..prog.atoms[kk].n_basis())
                .map(|b| prog.atoms[kk].phi[b] * prog.atoms[kk].decoder[b][out_col])
                .sum()
        })
        .collect();
    // d1[kk][axis]: first derivative of decoded[kk] along latent axis.
    let d1: Vec<Vec<f64>> = (0..k)
        .map(|kk| {
            (0..prog.atoms[kk].latent_dim)
                .map(|axis| {
                    (0..prog.atoms[kk].n_basis())
                        .map(|b| {
                            prog.atoms[kk].d_phi[b][axis] * prog.atoms[kk].decoder[b][out_col]
                        })
                        .sum()
                })
                .collect()
        })
        .collect();
    // d2[kk][a][b]: second derivative of decoded[kk] along latent axes (a, b).
    let d2: Vec<Vec<Vec<f64>>> = (0..k)
        .map(|kk| {
            (0..prog.atoms[kk].latent_dim)
                .map(|a| {
                    (0..prog.atoms[kk].latent_dim)
                        .map(|b| {
                            (0..prog.atoms[kk].n_basis())
                                .map(|col| {
                                    prog.atoms[kk].d2_phi[col][a][b]
                                        * prog.atoms[kk].decoder[col][out_col]
                                })
                                .sum()
                        })
                        .collect()
                })
                .collect()
        })
        .collect();
    let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
    let logit_idx = |kk: usize| prog.logit_slot[kk];
    let coord_idx = |kk: usize, axis: usize| prog.coord_slot[kk][axis];
    let value: f64 = (0..k).map(|kk| prog.gate_value[kk] * decoded[kk]).sum();
    // Gradient: logit entries go through the gates; coordinate entries through
    // each atom's own decoded basis.
    let mut first = vec![0.0_f64; n];
    for j in 0..k {
        if let Some(p) = logit_idx(j) {
            first[p] = (0..k).map(|kk| dz[j][kk] * decoded[kk]).sum();
        }
    }
    for kk in 0..k {
        for axis in 0..prog.atoms[kk].latent_dim {
            first[coord_idx(kk, axis)] = prog.gate_value[kk] * d1[kk][axis];
        }
    }
    // Hessian, logit-logit block.
    let mut second = vec![vec![0.0_f64; n]; n];
    for j in 0..k {
        for l in 0..k {
            if let (Some(pj), Some(pl)) = (logit_idx(j), logit_idx(l)) {
                second[pj][pl] = (0..k).map(|kk| d2z[j][l][kk] * decoded[kk]).sum();
            }
        }
    }
    // Hessian, logit-coordinate cross block (written symmetrically).
    for j in 0..k {
        for kk in 0..k {
            for axis in 0..prog.atoms[kk].latent_dim {
                if let Some(pj) = logit_idx(j) {
                    let pc = coord_idx(kk, axis);
                    let val = dz[j][kk] * d1[kk][axis];
                    second[pj][pc] = val;
                    second[pc][pj] = val;
                }
            }
        }
    }
    // Hessian, coordinate-coordinate block (within each atom only).
    for kk in 0..k {
        for a in 0..prog.atoms[kk].latent_dim {
            for b in 0..prog.atoms[kk].latent_dim {
                let pa = coord_idx(kk, a);
                let pb = coord_idx(kk, b);
                second[pa][pb] = prog.gate_value[kk] * d2[kk][a][b];
            }
        }
    }
    HandChannels {
        first,
        second,
        value,
    }
}
fn softmax_fixture(inv_tau: f64) -> (SaeReconstructionRowProgram, f64) {
let n_basis = 3;
let out_dim = 4;
let mk_atom = |seed: f64| {
let phi: Vec<f64> = (0..n_basis)
.map(|b| 0.3 + 0.2 * (b as f64 + seed))
.collect();
let d_phi: Vec<Vec<f64>> = (0..n_basis)
.map(|b| {
(0..2)
.map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
.collect()
})
.collect();
let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
.map(|b| {
(0..2)
.map(|a| {
(0..2)
.map(|bb| {
0.02 * (b as f64 + 1.0)
+ 0.01 * (a as f64)
+ 0.01 * (bb as f64)
+ 0.004 * seed
})
.collect()
})
.collect()
})
.collect();
let decoder: Vec<Vec<f64>> = (0..n_basis)
.map(|b| {
(0..out_dim)
.map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
.collect()
})
.collect();
AtomRowBasisJet {
phi,
d_phi,
d2_phi,
decoder,
latent_dim: 2,
}
};
let logits = vec![0.4_f64, -0.7];
let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
let s: f64 = e.iter().sum();
let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
let prog = SaeReconstructionRowProgram {
atoms: vec![mk_atom(0.0), mk_atom(1.0)],
gate_value,
logits,
gate_scale: vec![1.0, 1.0],
gate_shift: vec![0.0, 0.0],
gate: RowGate::Softmax { inv_tau },
logit_slot: vec![Some(0), Some(1)],
coord_slot: vec![vec![2, 3], vec![4, 5]],
n_primaries: 6,
};
(prog, inv_tau)
}
/// Plain-`f64` evaluation of reconstruction column `out_col` at primary
/// displacement `delta`, used as the independent finite-difference witness.
///
/// Logits are `logits[k] + delta[logit_slot[k]]`, fed into an unshifted
/// softmax with `inv_tau`. Each basis function is its quadratic model
/// evaluated at the coordinate displacement `delta[coord_slot[k][a]]`.
/// NOTE(review): the FD stencils divide differences of these values by up to
/// `h^4`, so reordering the accumulations below changes witness rounding.
/// Also assumes no coordinate maps to `SAE_FIXED_COORD_SLOT`; such a slot
/// would index `delta` out of bounds.
fn recon_scalar_softmax(
    prog: &SaeReconstructionRowProgram,
    out_col: usize,
    inv_tau: f64,
    delta: &[f64],
) -> f64 {
    let k = prog.atoms.len();
    let exps: Vec<f64> = (0..k)
        .map(|kk| {
            // Displaced logit; constant logits (no slot) are not perturbed.
            let dl = match prog.logit_slot[kk] {
                Some(slot) => delta[slot],
                None => 0.0,
            };
            ((prog.logits[kk] + dl) * inv_tau).exp()
        })
        .collect();
    let denom: f64 = exps.iter().sum();
    let mut acc = 0.0;
    for kk in 0..k {
        let gate = exps[kk] / denom;
        let atom = &prog.atoms[kk];
        let mut decoded = 0.0;
        for b in 0..atom.n_basis() {
            // phi + sum_a d_phi[a] u_a + 0.5 * sum_{a,a2} d2_phi[a][a2] u_a u_a2
            let mut phi = atom.phi[b];
            for a in 0..atom.latent_dim {
                let ua = delta[prog.coord_slot[kk][a]];
                phi += atom.d_phi[b][a] * ua;
            }
            for a in 0..atom.latent_dim {
                let ua = delta[prog.coord_slot[kk][a]];
                for a2 in 0..atom.latent_dim {
                    let ub = delta[prog.coord_slot[kk][a2]];
                    phi += 0.5 * atom.d2_phi[b][a][a2] * ua * ub;
                }
            }
            decoded += phi * atom.decoder[b][out_col];
        }
        acc += gate * decoded;
    }
    acc
}
/// Central mixed finite difference of order `N` of one reconstruction column.
///
/// Evaluates `recon_scalar_softmax` at the `2^N` corners
/// `sum_i (±h) e_{axes[i]}` and divides the signed sum by `(2h)^N`. This
/// approximates `d^N f / dx_{axes[0]} ... dx_{axes[N-1]}` at zero
/// displacement. Repeated axes are allowed: their steps accumulate in
/// `delta`.
fn fd_mixed<const N: usize>(
    prog: &SaeReconstructionRowProgram,
    out_col: usize,
    inv_tau: f64,
    axes: [usize; N],
    h: f64,
) -> f64 {
    let n = prog.n_primaries;
    let mut acc = 0.0;
    for mask in 0..(1u32 << N) {
        let mut delta = vec![0.0_f64; n];
        // Sign is the product of the per-axis step signs.
        let mut sign = 1.0;
        for (slot, &ax) in axes.iter().enumerate() {
            if (mask >> slot) & 1 == 1 {
                delta[ax] += h;
            } else {
                delta[ax] -= h;
                sign = -sign;
            }
        }
        acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
    }
    // (2h)^N, formed as 2^N * h * h * ... with left-to-right products.
    let mut denom = (1u32 << N) as f64;
    for _ in 0..N {
        denom *= h;
    }
    acc / denom
}

/// Fourth-order central mixed difference along `axes` (see `fd_mixed`).
fn fd_fourth(
    prog: &SaeReconstructionRowProgram,
    out_col: usize,
    inv_tau: f64,
    axes: [usize; 4],
    h: f64,
) -> f64 {
    fd_mixed(prog, out_col, inv_tau, axes, h)
}

/// Third-order central mixed difference along `axes` (see `fd_mixed`).
fn fd_third(
    prog: &SaeReconstructionRowProgram,
    out_col: usize,
    inv_tau: f64,
    axes: [usize; 3],
    h: f64,
) -> f64 {
    fd_mixed(prog, out_col, inv_tau, axes, h)
}
/// Every third- and fourth-order tower entry of every column must match the
/// central finite-difference witness, within a tolerance relative to the
/// largest entry of its block.
#[test]
fn softmax_reconstruction_t3_t4_match_independent_fd_witness() {
    let (prog, inv_tau) = softmax_fixture(1.1);
    // The fourth-order stencil uses a larger step than the third-order one.
    let step3 = 2e-3;
    let step4 = 1e-2;
    for out_col in 0..prog.out_dim() {
        let tower = prog.reconstruction_column::<6>(out_col);
        // Floored so an all-zero block still gets a positive tolerance.
        let t3_scale = tower
            .t3
            .iter()
            .flatten()
            .flatten()
            .map(|x| x.abs())
            .fold(0.0_f64, f64::max)
            .max(1e-9);
        let t4_scale = tower
            .t4
            .iter()
            .flatten()
            .flatten()
            .flatten()
            .map(|x| x.abs())
            .fold(0.0_f64, f64::max)
            .max(1e-9);
        let tol3 = 5e-5 * t3_scale;
        let tol4 = 5e-4 * t4_scale;
        for a in 0..6 {
            for b in 0..6 {
                for c in 0..6 {
                    let got3 = tower.t3[a][b][c];
                    let fd = fd_third(&prog, out_col, inv_tau, [a, b, c], step3);
                    assert!(
                        (got3 - fd).abs() <= tol3,
                        "col {out_col} t3[{a}][{b}][{c}]: tower {:+.10e} vs fd {:+.10e}",
                        got3,
                        fd
                    );
                    for d in 0..6 {
                        let got4 = tower.t4[a][b][c][d];
                        let fd4 = fd_fourth(&prog, out_col, inv_tau, [a, b, c, d], step4);
                        assert!(
                            (got4 - fd4).abs() <= tol4,
                            "col {out_col} t4[{a}][{b}][{c}][{d}]: tower {:+.10e} vs fd {:+.10e}",
                            got4,
                            fd4
                        );
                    }
                }
            }
        }
    }
}
/// The FD witness must be able to catch a planted error.
///
/// For one third-order entry and one fourth-order entry, the honest value must
/// agree with the witness and a corrupted copy must not. Checking the honest
/// agreement first makes the "caught" assertion meaningful: a broken witness
/// would disagree with everything.
#[test]
fn planted_t3_t4_corruption_is_caught_by_fd_witness() {
    let (prog, inv_tau) = softmax_fixture(1.1);
    let out_col = 2;
    let tower = prog.reconstruction_column::<6>(out_col);

    let axes3 = [0usize, 2, 3];
    let fd3 = fd_third(&prog, out_col, inv_tau, axes3, 2e-3);
    let t3_floor = tower
        .t3
        .iter()
        .flatten()
        .flatten()
        .fold(0.0_f64, |m, x| m.max(x.abs()))
        .max(1e-9);
    assert!(
        (tower.t3[0][2][3] - fd3).abs() <= 5e-5 * t3_floor,
        "honest t3 must match witness"
    );
    let corrupt = -tower.t3[0][2][3];
    assert!(
        (corrupt - fd3).abs() > 5e-5 * t3_floor,
        "a sign-flipped t3 block must disagree with the FD witness"
    );

    let axes4 = [0usize, 0, 2, 3];
    let fd4 = fd_fourth(&prog, out_col, inv_tau, axes4, 1e-2);
    let t4_floor = tower
        .t4
        .iter()
        .flatten()
        .flatten()
        .flatten()
        .fold(0.0_f64, |m, x| m.max(x.abs()))
        .max(1e-9);
    assert!(
        (tower.t4[0][0][2][3] - fd4).abs() <= 5e-4 * t4_floor,
        "honest t4 must match witness"
    );
    let corrupt4 = tower.t4[0][0][2][3] + 10.0 * t4_floor;
    assert!(
        (corrupt4 - fd4).abs() > 5e-4 * t4_floor,
        "a corrupted t4 block must disagree with the FD witness"
    );
}
/// Value, gradient and Hessian of every column's tower must match the
/// independent hand derivation.
#[test]
fn softmax_reconstruction_tower_matches_hand_channels_all_columns() {
    let (prog, inv_tau) = softmax_fixture(1.3);
    let n_cols = prog.out_dim();
    for out_col in 0..n_cols {
        let jet = prog.reconstruction_column::<6>(out_col);
        let reference = hand_softmax_column(&prog, out_col, inv_tau);
        // Tolerances are relative to the largest gradient / Hessian entry.
        let grad_scale = jet.g.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
        let hess_scale = jet
            .h
            .iter()
            .flatten()
            .map(|x| x.abs())
            .fold(0.0_f64, f64::max);
        let value_tol = 1e-9 * reference.value.abs().max(1.0);
        let grad_tol = 1e-9 * grad_scale.max(1e-12);
        let hess_tol = 1e-8 * hess_scale.max(1e-12);
        assert!(
            (jet.v - reference.value).abs() <= value_tol,
            "col {out_col} value: tower {} vs hand {}",
            jet.v,
            reference.value
        );
        for a in 0..6 {
            assert!(
                (jet.g[a] - reference.first[a]).abs() <= grad_tol,
                "col {out_col} first[{a}]: tower {} vs hand {}",
                jet.g[a],
                reference.first[a]
            );
            for b in 0..6 {
                assert!(
                    (jet.h[a][b] - reference.second[a][b]).abs() <= hess_tol,
                    "col {out_col} second[{a}][{b}]: tower {} vs hand {}",
                    jet.h[a][b],
                    reference.second[a][b]
                );
            }
        }
    }
}
/// Flipping the sign of a logit-coordinate cross block in the hand reference
/// must be detected against the tower.
///
/// The honest block is checked first so the disagreement is not vacuous.
/// Slot 0 is atom 0's logit and slot 4 is atom 1's first latent coordinate in
/// `softmax_fixture`.
#[test]
fn planted_cross_block_sign_flip_is_caught() {
    let (prog, inv_tau) = softmax_fixture(1.3);
    let out_col = 1;
    let tower = prog.reconstruction_column::<6>(out_col);
    let mut hand = hand_softmax_column(&prog, out_col, inv_tau);
    let h_floor = tower
        .h
        .iter()
        .flatten()
        .fold(0.0_f64, |m, x| m.max(x.abs()));
    let tol = 1e-8 * h_floor.max(1e-12);
    assert!(
        (tower.h[0][4] - hand.second[0][4]).abs() <= tol,
        "honest cross block must agree with the tower truth"
    );
    hand.second[0][4] = -hand.second[0][4];
    hand.second[4][0] = -hand.second[4][0];
    let disagrees = (tower.h[0][4] - hand.second[0][4]).abs() > tol;
    assert!(
        disagrees,
        "a flipped cross block must disagree with the tower truth"
    );
}
/// Each softmax gate tower's value, logit gradient and logit Hessian must
/// match the closed-form derivatives.
#[test]
fn softmax_gate_tower_matches_hand_gate_derivatives() {
    let (prog, inv_tau) = softmax_fixture(0.9);
    let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
    let n_atoms = prog.atoms.len();
    let slot_of = |j: usize| prog.logit_slot[j].unwrap();
    for atom in 0..n_atoms {
        let gate = prog.gate_tower::<6, Tower4<6>>(atom);
        assert!((gate.v - prog.gate_value[atom]).abs() < 1e-12);
        for j in 0..n_atoms {
            let slot = slot_of(j);
            let expected = dz[j][atom];
            assert!(
                (gate.g[slot] - expected).abs() < 1e-9,
                "gate {atom} d/dlogit {j}: tower {} vs hand {}",
                gate.g[slot],
                expected
            );
        }
        for j in 0..n_atoms {
            for l in 0..n_atoms {
                let sj = slot_of(j);
                let sl = slot_of(l);
                let expected = d2z[j][l][atom];
                assert!(
                    (gate.h[sj][sl] - expected).abs() < 1e-8,
                    "gate {atom} d2/dlogit {j}{l}: tower {} vs hand {}",
                    gate.h[sj][sl],
                    expected
                );
            }
        }
    }
}
/// A single-atom logistic gate must reproduce `sigma`, `sigma'` and `sigma''`
/// (chain-ruled through `inv_tau`) from the closed form.
#[test]
fn per_atom_logistic_gate_matches_closed_form() {
    let inv_tau = 1.4;
    let logit = 0.6;
    let shift = 0.2;
    let x: f64 = (logit - shift) * inv_tau;
    let sigma = 1.0 / (1.0 + (-x).exp());
    let slope = sigma * (1.0 - sigma);
    let expected_d1 = slope * inv_tau;
    let expected_d2 = slope * (1.0 - 2.0 * sigma) * inv_tau * inv_tau;

    let single_atom = AtomRowBasisJet {
        phi: vec![1.0],
        d_phi: vec![vec![0.0]],
        d2_phi: vec![vec![vec![0.0]]],
        decoder: vec![vec![1.0]],
        latent_dim: 1,
    };
    let prog = SaeReconstructionRowProgram {
        atoms: vec![single_atom],
        gate_value: vec![sigma],
        logits: vec![logit],
        gate_scale: vec![1.0],
        gate_shift: vec![shift],
        gate: RowGate::PerAtomLogistic { inv_tau },
        logit_slot: vec![Some(0)],
        coord_slot: vec![vec![1]],
        n_primaries: 2,
    };

    let gate = prog.gate_tower::<2, Tower4<2>>(0);
    assert!((gate.v - sigma).abs() < 1e-12);
    assert!(
        (gate.g[0] - expected_d1).abs() < 1e-9,
        "σ': {} vs {}",
        gate.g[0],
        expected_d1
    );
    assert!(
        (gate.h[0][0] - expected_d2).abs() < 1e-9,
        "σ'': {} vs {}",
        gate.h[0][0],
        expected_d2
    );
}
/// The packed second-order column must agree with the low-order part of the
/// fourth-order tower to within ~1e-12.
#[test]
fn order2_reconstruction_matches_tower_value_grad_hessian() {
    // 1e-12 absolute plus 1e-12 relative to the Tower4 reference.
    let band = |x: f64| 1e-12 + 1e-12 * x.abs();
    for inv_tau in [0.9_f64, 1.3, 2.1] {
        let (prog, _) = softmax_fixture(inv_tau);
        for out_col in 0..prog.out_dim() {
            let packed = prog.reconstruction_column_packed::<6>(out_col);
            let tower = prog.reconstruction_column::<6>(out_col);
            let packed_v = packed.value();
            let g = packed.g();
            let h = packed.h();
            assert!(
                (packed_v - tower.v).abs() <= band(tower.v),
                "col {out_col} value: order2 {} vs tower {}",
                packed_v,
                tower.v
            );
            for a in 0..6 {
                assert!(
                    (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
                    "col {out_col} g[{a}]: order2 {} vs tower {}",
                    g[a],
                    tower.g[a]
                );
                for b in 0..6 {
                    assert!(
                        (h[a][b] - tower.h[a][b]).abs() <= band(tower.h[a][b]),
                        "col {out_col} h[{a}][{b}]: order2 {} vs tower {}",
                        h[a][b],
                        tower.h[a][b]
                    );
                }
            }
        }
    }
}
/// The packed second-order border jets must agree with the fourth-order
/// tower's value and gradient for every `(atom, basis_col)` channel.
#[test]
fn order2_beta_border_matches_tower_value_grad() {
    let (prog, _) = softmax_fixture(1.1);
    let band = |x: f64| 1e-12 + 1e-12 * x.abs();
    for (atom, atom_jet) in prog.atoms.iter().enumerate() {
        for basis_col in 0..atom_jet.n_basis() {
            let packed = prog.beta_border_tower_packed::<6>(atom, basis_col);
            let tower = prog.beta_border_tower::<6>(atom, basis_col);
            let packed_v = packed.value();
            let g = packed.g();
            assert!(
                (packed_v - tower.v).abs() <= band(tower.v),
                "atom {atom} b {basis_col} value: order2 {} vs tower {}",
                packed_v,
                tower.v
            );
            for a in 0..6 {
                assert!(
                    (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
                    "atom {atom} b {basis_col} g[{a}]: order2 {} vs tower {}",
                    g[a],
                    tower.g[a]
                );
            }
        }
    }
}
/// The shared-denominator `all_gates` path must be bit-identical to building
/// each gate separately with `gate_tower`.
#[test]
fn shared_all_gates_bit_identical_to_per_atom_gate_tower() {
    for inv_tau in [0.9_f64, 1.3, 2.1] {
        let (prog, _) = softmax_fixture(inv_tau);
        let shared = prog.all_gates::<6, Order2<6>>();
        assert_eq!(shared.len(), prog.gate_value.len());
        for (atom, from_shared) in shared.iter().enumerate() {
            let from_single = prog.gate_tower::<6, Order2<6>>(atom);
            assert_eq!(from_shared.value(), from_single.value(), "atom {atom} value");
            let (shared_g, single_g) = (from_shared.g(), from_single.g());
            let (shared_h, single_h) = (from_shared.h(), from_single.h());
            for a in 0..6 {
                assert_eq!(shared_g[a], single_g[a], "atom {atom} g[{a}]");
                for b in 0..6 {
                    assert_eq!(shared_h[a][b], single_h[a][b], "atom {atom} h[{a}][{b}]");
                }
            }
        }
    }
}
/// Checks that the hoisted `reconstruction_all_columns_packed` produces,
/// column by column, exactly the same packed jet (value, gradient, Hessian)
/// as calling `reconstruction_column_packed` once per column.
///
/// Entries are compared through `f64::to_bits` so that sign-of-zero
/// differences are caught and identical NaNs compare equal. Plain `==` on
/// `f64` does neither, so it would not actually enforce bit identity.
#[test]
fn hoisted_all_columns_bit_identical_to_per_column() {
    for tau in [0.9_f64, 1.3, 2.1] {
        let (prog, _inv_tau) = softmax_fixture(tau);
        let all = prog.reconstruction_all_columns_packed::<6>();
        assert_eq!(all.len(), prog.out_dim());
        for out_col in 0..prog.out_dim() {
            let per = prog.reconstruction_column_packed::<6>(out_col);
            let ah = &all[out_col];
            assert_eq!(
                ah.value().to_bits(),
                per.value().to_bits(),
                "col {out_col} value"
            );
            let (ag, pg) = (ah.g(), per.g());
            let (ahh, phh) = (ah.h(), per.h());
            for a in 0..6 {
                assert_eq!(ag[a].to_bits(), pg[a].to_bits(), "col {out_col} g[{a}]");
                for b in 0..6 {
                    assert_eq!(
                        ahh[a][b].to_bits(),
                        phh[a][b].to_bits(),
                        "col {out_col} h[{a}][{b}]"
                    );
                }
            }
        }
    }
}
/// Builds one softmax-gated row per batch lane for the batch4 oracles.
///
/// Every row has the same layout: two atoms, each with three basis functions
/// over a two-dimensional latent coordinate, and four output columns. Logits
/// sit in primaries 0-1, atom 0's coordinates in primaries 2-3 and atom 1's in
/// primaries 4-5. The rows differ numerically via `row_seed = lane as f64`, so
/// the batch kernels see distinct data in each lane. `inv_tau` multiplies the
/// logits before the softmax and is also stored in the row's gate.
fn softmax_batch_fixture(inv_tau: f64) -> [SaeReconstructionRowProgram; LANES] {
    let n_basis = 3;
    let out_dim = 4;
    let latent_dim = 2;
    let build_atom = |atom_seed: f64, row_seed: f64| {
        let mut phi = Vec::with_capacity(n_basis);
        let mut d_phi = Vec::with_capacity(n_basis);
        let mut d2_phi = Vec::with_capacity(n_basis);
        let mut decoder = Vec::with_capacity(n_basis);
        for b in 0..n_basis {
            phi.push(0.3 + 0.2 * (b as f64 + atom_seed) + 0.11 * row_seed);
            d_phi.push(
                (0..latent_dim)
                    .map(|axis| {
                        0.1 * (b as f64 + 1.0) - 0.05 * axis as f64
                            + 0.03 * atom_seed
                            + 0.017 * row_seed
                    })
                    .collect::<Vec<f64>>(),
            );
            d2_phi.push(
                (0..latent_dim)
                    .map(|a| {
                        (0..latent_dim)
                            .map(|bb| {
                                0.02 * (b as f64 + 1.0)
                                    + 0.01 * (a as f64)
                                    + 0.01 * (bb as f64)
                                    + 0.004 * atom_seed
                                    + 0.003 * row_seed
                            })
                            .collect::<Vec<f64>>()
                    })
                    .collect::<Vec<Vec<f64>>>(),
            );
            decoder.push(
                (0..out_dim)
                    .map(|c| {
                        0.5 - 0.1 * (b as f64) + 0.07 * (c as f64)
                            + 0.02 * atom_seed
                            + 0.009 * row_seed
                    })
                    .collect::<Vec<f64>>(),
            );
        }
        AtomRowBasisJet {
            phi,
            d_phi,
            d2_phi,
            decoder,
            latent_dim,
        }
    };
    let build_row = |row_seed: f64| {
        let logits = vec![0.4 + 0.21 * row_seed, -0.7 + 0.13 * row_seed];
        // Plain softmax without max-subtraction; these logits are small
        // enough that `exp` cannot overflow.
        let weights: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
        let total: f64 = weights.iter().sum();
        let gate_value: Vec<f64> = weights.iter().map(|&w| w / total).collect();
        SaeReconstructionRowProgram {
            atoms: vec![build_atom(0.0, row_seed), build_atom(1.0, row_seed)],
            gate_value,
            logits,
            gate_scale: vec![1.0, 1.0],
            gate_shift: vec![0.0, 0.0],
            gate: RowGate::Softmax { inv_tau },
            logit_slot: vec![Some(0), Some(1)],
            coord_slot: vec![vec![2, 3], vec![4, 5]],
            n_primaries: 6,
        }
    };
    std::array::from_fn(|lane| build_row(lane as f64))
}
/// Oracle test: the 4-lane batched reconstruction must match
/// `reconstruction_all_columns_packed` run on each row separately, bit for bit,
/// across a sweep of gate temperatures. A lower bound on the number of
/// Hessian entries compared guards against the loops silently doing nothing.
#[test]
fn batch4_reconstruction_bit_identical_to_per_row() {
    let mut comparisons = 0usize;
    let sweep = [0.7_f64, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9];
    for tau in sweep {
        let rows = softmax_batch_fixture(tau);
        let batch = SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<6>([
            &rows[0], &rows[1], &rows[2], &rows[3],
        ])
        .expect("softmax-aligned rows must batch");
        for lane in 0..LANES {
            let lane_batch = &batch[lane];
            let per = rows[lane].reconstruction_all_columns_packed::<6>();
            assert_eq!(per.len(), lane_batch.len());
            for (c, (b, p)) in lane_batch.iter().zip(per.iter()).enumerate() {
                assert_eq!(
                    b.value().to_bits(),
                    p.value().to_bits(),
                    "tau {tau} lane {lane} col {c} value"
                );
                let (bg, pg) = (b.g(), p.g());
                let (bh, ph) = (b.h(), p.h());
                for a in 0..6 {
                    assert_eq!(bg[a].to_bits(), pg[a].to_bits(), "lane {lane} col {c} g[{a}]");
                    for d in 0..6 {
                        assert_eq!(
                            bh[a][d].to_bits(),
                            ph[a][d].to_bits(),
                            "lane {lane} col {c} h[{a}][{d}]"
                        );
                        comparisons += 1;
                    }
                }
            }
        }
    }
    assert!(comparisons >= 2000, "oracle ran {comparisons} comparisons");
}
/// Oracle test: the 4-lane batched order-1 beta-border jets must match
/// `beta_border_order1_packed` run on each row separately, bit for bit.
/// The channel list covers every (atom, basis) pair of the first row plus a
/// repeat of the first channel, so a duplicated request is exercised as well.
#[test]
fn batch4_beta_border_bit_identical_to_per_row() {
    let mut comparisons = 0usize;
    let sweep = [0.7_f64, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9];
    for tau in sweep {
        let rows = softmax_batch_fixture(tau);
        let mut chans: Vec<(usize, usize)> = rows[0]
            .atoms
            .iter()
            .enumerate()
            .flat_map(|(atom, jet)| (0..jet.n_basis()).map(move |b| (atom, b)))
            .collect();
        chans.push(chans[0]);
        let batch = SaeReconstructionRowProgram::beta_border_order1_batch4::<6>(
            [&rows[0], &rows[1], &rows[2], &rows[3]],
            &chans,
        )
        .expect("softmax-aligned rows must batch");
        for lane in 0..LANES {
            let lane_batch = &batch[lane];
            let per = rows[lane].beta_border_order1_packed::<6>(&chans);
            assert_eq!(per.len(), lane_batch.len());
            for (i, (b, p)) in lane_batch.iter().zip(per.iter()).enumerate() {
                assert_eq!(b.value().to_bits(), p.value().to_bits(), "lane {lane} chan {i} v");
                let (bg, pg) = (b.g(), p.g());
                for a in 0..6 {
                    assert_eq!(
                        bg[a].to_bits(),
                        pg[a].to_bits(),
                        "lane {lane} chan {i} g[{a}]"
                    );
                    comparisons += 1;
                }
            }
        }
    }
    assert!(comparisons >= 1000, "oracle ran {comparisons} comparisons");
}
/// The batched reconstruction kernel must return `None`, rather than a result,
/// when the rows use a `PerAtomLogistic` gate instead of softmax.
#[test]
fn batch4_declines_non_softmax() {
    let inv_tau = 1.1;
    // Minimal single-atom, single-basis, one-dimensional row with a logistic gate.
    let logistic_row = || SaeReconstructionRowProgram {
        atoms: vec![AtomRowBasisJet {
            phi: vec![1.0],
            d_phi: vec![vec![0.0]],
            d2_phi: vec![vec![vec![0.0]]],
            decoder: vec![vec![1.0]],
            latent_dim: 1,
        }],
        gate_value: vec![0.6],
        logits: vec![0.6],
        gate_scale: vec![1.0],
        gate_shift: vec![0.2],
        gate: RowGate::PerAtomLogistic { inv_tau },
        logit_slot: vec![Some(0)],
        coord_slot: vec![vec![1]],
        n_primaries: 2,
    };
    let rows: [SaeReconstructionRowProgram; 4] = std::array::from_fn(|_| logistic_row());
    let outcome = SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<2>([
        &rows[0], &rows[1], &rows[2], &rows[3],
    ]);
    assert!(outcome.is_none());
}
/// Checks that the hoisted `beta_border_towers_packed` (all channels in one
/// call) yields exactly the same value and gradient as
/// `beta_border_tower_packed` for each channel. The channel list covers every
/// (atom, basis) pair, plus a repeat of the first one so a duplicated request
/// is exercised too.
///
/// Entries are compared through `f64::to_bits` so sign-of-zero differences
/// are caught and identical NaNs compare equal. Plain `==` on `f64` does
/// neither, so it would not enforce the bit identity this test claims.
#[test]
fn hoisted_beta_border_bit_identical_to_per_channel() {
    let (prog, _inv_tau) = softmax_fixture(1.1);
    let mut chans: Vec<(usize, usize)> = Vec::new();
    for atom in 0..prog.atoms.len() {
        for basis_col in 0..prog.atoms[atom].n_basis() {
            chans.push((atom, basis_col));
        }
    }
    if let Some(&first) = chans.first() {
        chans.push(first);
    }
    let batched = prog.beta_border_towers_packed::<6>(&chans);
    assert_eq!(batched.len(), chans.len());
    for (i, (&(atom, basis_col), b)) in chans.iter().zip(batched.iter()).enumerate() {
        let per = prog.beta_border_tower_packed::<6>(atom, basis_col);
        assert_eq!(b.value().to_bits(), per.value().to_bits(), "chan {i} value");
        let (bg, pg) = (b.g(), per.g());
        for a in 0..6 {
            assert_eq!(bg[a].to_bits(), pg[a].to_bits(), "chan {i} g[{a}]");
        }
    }
}
}