use super::*;
/// Per-row compact ("sparse") layout of the row-local block: which atoms are active in each
/// row, where their coordinates sit inside the compact row vector, and where they sit inside
/// the full (uncompacted) row vector.
///
/// A compact row vector for active atoms `[k_0, k_1, ...]` stores one assignment slot per
/// active atom first (slot `j` belongs to `k_j`), followed by each active atom's coordinates.
#[derive(Debug, Clone)]
pub struct SaeRowLayout {
/// Active atom indices for each row (sorted ascending when built by `from_jumprelu` /
/// `from_dense_weights`; `from_active_atoms` keeps the caller's order).
pub active_atoms: Vec<Vec<usize>>,
/// `coord_starts[row][j]` is the offset, inside the compact row vector, at which the
/// coordinates of `active_atoms[row][j]` begin (always after the assignment slots).
pub coord_starts: Vec<Vec<usize>>,
/// `coord_offsets_full[k]` is the offset of atom `k`'s first coordinate in the full row vector.
pub coord_offsets_full: Vec<usize>,
/// `coord_dims[k]` is the number of coordinate components of atom `k`.
pub coord_dims: Vec<usize>,
}
/// Inputs for [`SaeRowLayout::from_jumprelu`].
pub(crate) struct JumpReluLayoutParams<'a> {
/// Number of rows to lay out (rows `0..n` of `logits` / `contribution` are read).
pub n: usize,
/// Number of atoms (columns `0..k_atoms` of `logits` / `contribution` are read).
pub k_atoms: usize,
/// Hard-gate threshold: an atom whose logit is strictly greater is always active.
pub threshold: f64,
/// JumpReLU temperature, forwarded to `crate::assignment::jumprelu_in_optimization_band`.
pub temperature: f64,
/// Per-row atom logits, indexed as `[row, atom]`.
pub logits: &'a Array2<f64>,
/// Per-row atom contribution scores, indexed as `[row, atom]`; only `|value|` is used, to
/// rank and cut below-threshold in-band atoms.
pub contribution: &'a Array2<f64>,
/// Maximum number of active atoms per row (treated as at least 1). Hard-gated atoms are
/// kept even if they alone exceed it; the cap only limits the below-threshold extras.
pub k_active_cap: usize,
/// Below-threshold in-band atoms are kept only if `|contribution|` exceeds
/// `relative_cutoff` times the row's largest in-band `|contribution|`.
pub relative_cutoff: f64,
/// Per-atom coordinate dimensions, forwarded to the resulting layout.
pub coord_dims: Vec<usize>,
/// Per-atom offsets of the coordinates in the full row vector, forwarded to the layout.
pub coord_offsets_full: Vec<usize>,
}
impl SaeRowLayout {
    /// Total-order sort key used for every "largest `|value|` first" selection.
    ///
    /// NaN magnitudes map to `-inf` so they rank below every real value. This keeps the
    /// comparator a genuine total order, which `select_nth_unstable_by` requires (it may
    /// panic when given a comparator that is not a total order).
    fn magnitude_key(value: f64) -> f64 {
        let magnitude = value.abs();
        if magnitude.is_nan() {
            f64::NEG_INFINITY
        } else {
            magnitude
        }
    }

    /// Keeps, in unspecified order, the `keep` entries of `indices` with the largest
    /// `|value(i)|` and discards the rest. No-op when `indices.len() <= keep`.
    fn retain_largest(indices: &mut Vec<usize>, keep: usize, value: impl Fn(usize) -> f64) {
        if indices.len() <= keep {
            return;
        }
        if keep == 0 {
            indices.clear();
            return;
        }
        // Descending by magnitude: element `keep - 1` becomes the smallest kept entry.
        indices.select_nth_unstable_by(keep - 1, |&i, &j| {
            Self::magnitude_key(value(j)).total_cmp(&Self::magnitude_key(value(i)))
        });
        indices.truncate(keep);
    }

    /// Builds the compact per-row layout for a JumpReLU assignment.
    ///
    /// For each row `r` in `0..n`, reading `logits.row(r)` / `contribution.row(r)` over atoms
    /// `0..k_atoms`, the active set is the union of:
    ///
    /// 1. **Hard-gated atoms**: every atom with `logit > threshold`. These are always kept,
    ///    regardless of `k_active_cap` or `relative_cutoff`.
    /// 2. **Band extras**: atoms with `logit <= threshold` that are inside the optimization
    ///    band (`crate::assignment::jumprelu_in_optimization_band`) and whose
    ///    `|contribution|` is strictly greater than `relative_cutoff` times the largest
    ///    in-band `|contribution|` of the row. At most `k_active_cap - #hard` of them are kept,
    ///    largest `|contribution|` first. None are kept when the hard atoms already fill the cap.
    ///
    /// If both sets are empty, the in-band atom with the largest `|contribution|` is used.
    /// When no atom is in band, the atom with the largest logit is used instead. A row is left
    /// empty only when `k_atoms == 0`. NaN contributions never qualify as band extras and
    /// rank lowest in the fallback. Active indices are sorted ascending. `k_active_cap` is
    /// treated as at least 1.
    ///
    /// # Panics
    /// Panics if `logits` / `contribution` do not cover `n` rows and `k_atoms` columns, or if
    /// an active atom index is out of range for `coord_dims`.
    pub(crate) fn from_jumprelu(params: JumpReluLayoutParams<'_>) -> Self {
        let JumpReluLayoutParams {
            n,
            k_atoms,
            threshold,
            temperature,
            logits,
            contribution,
            k_active_cap,
            relative_cutoff,
            coord_dims,
            coord_offsets_full,
        } = params;
        use std::cmp::Ordering::Equal;
        let cap = k_active_cap.max(1);
        let mut per_row = Vec::with_capacity(n);
        for row in 0..n {
            let row_logits = logits.row(row);
            let row_contrib = contribution.row(row);
            let in_band = |k: usize| {
                crate::assignment::jumprelu_in_optimization_band(
                    row_logits[k],
                    threshold,
                    temperature,
                )
            };
            // Hard-gated atoms: always active, exempt from both the cap and the cutoff.
            let hard: Vec<usize> = (0..k_atoms).filter(|&k| row_logits[k] > threshold).collect();
            // The cutoff is relative to the strongest in-band contribution; `f64::max` skips
            // NaN, so a NaN contribution can never set the peak.
            let peak = (0..k_atoms)
                .filter(|&k| in_band(k))
                .fold(0.0_f64, |m, k| m.max(row_contrib[k].abs()));
            let cutoff = relative_cutoff * peak;
            let mut extra: Vec<usize> = (0..k_atoms)
                .filter(|&k| {
                    row_logits[k] <= threshold && in_band(k) && row_contrib[k].abs() > cutoff
                })
                .collect();
            // Extras only fill whatever room the hard-gated atoms leave under the cap.
            let budget = cap.saturating_sub(hard.len());
            Self::retain_largest(&mut extra, budget, |k| row_contrib[k]);
            let mut active = hard;
            active.extend(extra);
            if active.is_empty() {
                // Never leave a row without an atom: prefer the strongest in-band atom,
                // otherwise the atom closest to opening the gate (largest logit).
                let best = (0..k_atoms)
                    .filter(|&k| in_band(k))
                    .max_by(|&i, &j| {
                        Self::magnitude_key(row_contrib[i])
                            .total_cmp(&Self::magnitude_key(row_contrib[j]))
                    })
                    .or_else(|| {
                        (0..k_atoms).max_by(|&i, &j| {
                            row_logits[i].partial_cmp(&row_logits[j]).unwrap_or(Equal)
                        })
                    });
                if let Some(b) = best {
                    active.push(b);
                }
            }
            active.sort_unstable();
            per_row.push(active);
        }
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Builds the compact per-row layout from dense assignment weight vectors, one per row.
    ///
    /// For each row, the `k_active_cap` atoms with the largest `|weight|` are kept (all of
    /// them if the row is shorter). Any of those whose `|weight|` is not strictly above
    /// `relative_cutoff` times the row's largest `|weight|` is then dropped. If nothing
    /// survives, the first atom with the largest `|weight|` is kept; zero-length rows stay
    /// empty. NaN weights rank below every real weight and never pass the cutoff. Active
    /// indices are sorted ascending. `k_active_cap` is treated as at least 1.
    ///
    /// # Panics
    /// Panics if an active atom index is out of range for `coord_dims`.
    pub(crate) fn from_dense_weights(
        assignments: &[Array1<f64>],
        k_active_cap: usize,
        relative_cutoff: f64,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let cap = k_active_cap.max(1);
        let per_row: Vec<Vec<usize>> = assignments
            .iter()
            .map(|a| {
                let k = a.len();
                let row_peak = a.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
                let cutoff = relative_cutoff * row_peak;
                let mut idx: Vec<usize> = (0..k).collect();
                Self::retain_largest(&mut idx, cap, |i| a[i]);
                let mut active: Vec<usize> =
                    idx.into_iter().filter(|&i| a[i].abs() > cutoff).collect();
                if active.is_empty() {
                    // First index attaining the largest magnitude (ties keep the earlier one).
                    let top = (0..k).fold(None::<usize>, |best, i| match best {
                        Some(b) if Self::magnitude_key(a[b]) >= Self::magnitude_key(a[i]) => {
                            Some(b)
                        }
                        _ => Some(i),
                    });
                    if let Some(top) = top {
                        active.push(top);
                    }
                }
                active.sort_unstable();
                active
            })
            .collect();
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Builds a layout from explicit per-row active atom lists, in the order given.
    ///
    /// For active atoms `[k_0, k_1, ...]`, the compact row vector holds the `active.len()`
    /// assignment slots first (slot `j` belongs to `k_j`). Each atom's `coord_dims[k_j]`
    /// coordinates follow in the same order; `coord_starts[row][j]` records where `k_j`'s
    /// coordinates begin.
    ///
    /// # Panics
    /// Panics if an active atom index is out of range for `coord_dims`.
    pub(crate) fn from_active_atoms(
        active_atoms: Vec<Vec<usize>>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let coord_starts: Vec<Vec<usize>> = active_atoms
            .iter()
            .map(|active| {
                // Coordinates begin right after the row's assignment slots.
                let mut cursor = active.len();
                active
                    .iter()
                    .map(|&k| {
                        let start = cursor;
                        cursor += coord_dims[k];
                        start
                    })
                    .collect()
            })
            .collect();
        Self {
            active_atoms,
            coord_starts,
            coord_offsets_full,
            coord_dims,
        }
    }

    /// Length of `row`'s compact vector: one assignment slot per active atom plus the
    /// coordinates of every active atom.
    pub fn row_q_active(&self, row: usize) -> usize {
        let active = &self.active_atoms[row];
        let coord_sum: usize = active.iter().map(|&k| self.coord_dims[k]).sum();
        active.len() + coord_sum
    }

    /// Scatters the compact row vector `delta_t_row` of `row` into the full vector `out`.
    ///
    /// `out` is zeroed first. Then, for each active atom `k` at compact position `j`:
    /// - `delta_t_row[j]` is written to `out[k]`;
    /// - the atom's `coord_dims[k]` coordinates, starting at `coord_starts[row][j]`, are
    ///   copied to `out[coord_offsets_full[k]..]`.
    ///
    /// Entries belonging to inactive atoms stay zero.
    ///
    /// # Panics
    /// Panics if `out` or `delta_t_row` is too short for the layout of `row`.
    pub fn expand_row(&self, row: usize, delta_t_row: &[f64], out: &mut [f64]) {
        out.fill(0.0);
        let active = &self.active_atoms[row];
        let starts = &self.coord_starts[row];
        for (j, &k) in active.iter().enumerate() {
            out[k] = delta_t_row[j];
            let d = self.coord_dims[k];
            let full_off = self.coord_offsets_full[k];
            for axis in 0..d {
                out[full_off + axis] = delta_t_row[starts[j] + axis];
            }
        }
    }
}
// Tests for `SaeRowLayout::from_jumprelu`:
// - the compact block size tracks the hard gate plus capped band extras, not the whole band;
// - a compact gradient expands back to the dense full-q gradient.
#[cfg(test)]
mod jumprelu_hard_gate_tests {
use super::{JumpReluLayoutParams, SaeRowLayout};
use crate::assignment::{assignment_prior_grad_hdiag, AssignmentMode, SaeAssignment};
use crate::manifold::{
SaeAtomBasisKind, SaeManifoldAtom, SaeManifoldRho, SaeManifoldTerm,
SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
};
use gam_terms::latent::LatentManifold;
use ndarray::{Array1, Array2, Array3};
/// Derivative of the logistic sigmoid at `logit`, sigma(x) * (1 - sigma(x)).
/// Used as a synthetic per-atom contribution that decays for very negative logits.
fn logit_slope(logit: f64) -> f64 {
let a = 1.0 / (1.0 + (-logit).exp());
a * (1.0 - a)
}
#[test]
fn from_jumprelu_block_size_tracks_hard_gate_not_band() {
let n = 2usize;
let k = 6usize;
let (threshold, temperature) = (0.0_f64, 1.0_f64);
// Per row: atoms 0,1 above the threshold (hard gate), atom 2 just below it,
// atoms 3..=5 deeply negative.
let logits = Array2::from_shape_vec(
(n, k),
vec![
1.0, 0.5, -0.3, -20.0, -25.0, -30.0, 0.8, 0.2, -0.4, -22.0, -26.0, -31.0, ],
)
.unwrap();
// Precondition: every atom, including the deeply negative ones, is inside the
// optimization band, so any pruning must come from the cutoff/cap, not the band test.
for &l in logits.iter() {
assert!(crate::assignment::jumprelu_in_optimization_band(l, threshold, temperature));
}
let contribution = logits.mapv(logit_slope);
let coord_dims = vec![1usize; k];
let coord_offsets_full: Vec<usize> = (0..k).map(|i| k + i).collect();
let layout = SaeRowLayout::from_jumprelu(JumpReluLayoutParams {
n,
k_atoms: k,
threshold,
temperature,
logits: &logits,
contribution: &contribution,
k_active_cap: k,
relative_cutoff: 1.0e-3,
coord_dims: coord_dims.clone(),
coord_offsets_full: coord_offsets_full.clone(),
});
// Atoms 3..=5 have negligible logit_slope relative to the peak, so they fall under the
// 1e-3 relative cutoff; each kept atom costs 1 assignment slot + 1 coordinate.
for row in 0..n {
assert_eq!(layout.active_atoms[row], vec![0, 1, 2], "row {row}");
assert_eq!(layout.row_q_active(row), 3 * (1 + 1));
assert!(layout.row_q_active(row) < k * (1 + 1));
}
// With a cap of 2 the two hard-gated atoms use the whole budget, so band atom 2 is dropped.
let capped = SaeRowLayout::from_jumprelu(JumpReluLayoutParams {
n,
k_atoms: k,
threshold,
temperature,
logits: &logits,
contribution: &contribution,
k_active_cap: 2,
relative_cutoff: 1.0e-3,
coord_dims,
coord_offsets_full,
});
for row in 0..n {
assert_eq!(capped.active_atoms[row], vec![0, 1], "capped row {row}");
}
}
#[test]
fn jumprelu_compact_gradient_matches_dense_full_band() {
let n = 3usize;
let k = 5usize;
let p = 2usize;
let (threshold, temperature) = (0.0_f64, 1.0_f64);
// Per row: atoms 0,1 hard-gated, atom 2 just below threshold, atoms 3,4 deeply negative.
let logits = Array2::from_shape_vec(
(n, k),
vec![
1.0, 0.3, -0.4, -30.0, -34.0, 0.8, 0.1, -0.6, -28.0, -33.0, 1.2, 0.5, -0.2, -31.0, -35.0, ],
)
.unwrap();
// One Euclidean-patch atom per column, each with a distinct (scaled) decoder.
let atoms: Vec<SaeManifoldAtom> = (0..k)
.map(|i| {
let f = (i as f64) + 1.0;
SaeManifoldAtom::new(
format!("atom{i}"),
SaeAtomBasisKind::EuclideanPatch,
1,
Array2::<f64>::from_elem((n, 2), 1.0),
Array3::<f64>::zeros((n, 2, 1)),
Array2::<f64>::from_shape_vec((2, p), vec![0.1 * f, -0.2 * f, 0.15 * f, 0.3 * f])
.unwrap(),
Array2::<f64>::eye(2),
)
.unwrap()
})
.collect();
let coords: Vec<Array2<f64>> = (0..k).map(|_| Array2::<f64>::zeros((n, 1))).collect();
let manifolds = vec![LatentManifold::Euclidean; k];
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
logits.clone(),
coords,
manifolds,
AssignmentMode::jumprelu(temperature, threshold),
)
.unwrap();
let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
let target =
Array2::<f64>::from_shape_fn((n, p), |(r, c)| 0.1 * (r as f64) - 0.05 * (c as f64));
let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::from_elem(1, 0.0); k]);
// Rank atoms by the magnitude of the assignment prior gradient, read as a flat
// row-major (n, k) vector.
let (assignment_grad, _) =
assignment_prior_grad_hdiag(&term.assignment, &rho).unwrap();
let contribution =
Array2::from_shape_fn((n, k), |(r, c)| assignment_grad[r * k + c].abs());
let coord_dims = vec![1usize; k];
let coord_offsets = term.assignment.coord_offsets();
let layout = SaeRowLayout::from_jumprelu(JumpReluLayoutParams {
n,
k_atoms: k,
threshold,
temperature,
logits: &logits,
contribution: &contribution,
k_active_cap: k,
relative_cutoff: 1.0e-3,
coord_dims,
coord_offsets_full: coord_offsets,
});
for row in 0..n {
assert_eq!(layout.active_atoms[row], vec![0, 1, 2], "layout row {row}");
}
// Assemble the same system three ways: `Some(None)` gives the dense full-q block
// (asserted below), `Some(Some(layout))` the compact block, and `None` the default path.
let probe = SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM;
let dense = term
.assemble_arrow_schur_inner(target.view(), &rho, None, 1.0, probe, Some(None))
.unwrap();
let compact = term
.assemble_arrow_schur_inner(
target.view(),
&rho,
None,
1.0,
probe,
Some(Some(layout.clone())),
)
.unwrap();
let default = term
.assemble_arrow_schur_inner(target.view(), &rho, None, 1.0, probe, None)
.unwrap();
let q = term.assignment.row_block_dim();
assert_eq!(dense.rows.len(), n);
assert_eq!(compact.rows.len(), n);
assert_eq!(default.rows.len(), n);
let mut max_diff = 0.0_f64;
let mut saw_drop = false;
for row in 0..n {
let dgt = &dense.rows[row].gt;
assert_eq!(dgt.len(), q, "dense row {row} must be full-q");
assert!(
compact.rows[row].gt.len() < dgt.len(),
"row {row}: compact block ({}) must be smaller than dense ({})",
compact.rows[row].gt.len(),
dgt.len()
);
// NOTE(review): set unconditionally once the size assertion above passes, so the
// final `assert!(saw_drop)` only guards against a zero-row fixture.
saw_drop = true;
assert_eq!(
default.rows[row].gt.len(),
compact.rows[row].gt.len(),
"row {row}: default (production) path must match the compact layout"
);
// Scatter the compact gradient back to full-q; dropped atoms must contribute zero,
// so every entry must match the dense gradient.
let compact_gt: Vec<f64> = compact.rows[row].gt.iter().copied().collect();
let mut expanded = vec![0.0_f64; q];
layout.expand_row(row, &compact_gt, &mut expanded);
for i in 0..q {
let diff = (expanded[i] - dgt[i]).abs();
max_diff = max_diff.max(diff);
assert!(
diff < 1.0e-8,
"row {row} coord {i}: compact gt {} vs dense {} (diff {diff:e})",
expanded[i],
dgt[i]
);
}
// The below-threshold band atom 2 carries a non-negligible gradient, so keeping it in
// the compact block is required, not optional.
assert!(
dgt[2].abs() > 1.0e-2,
"row {row}: near-threshold band-only prior gradient must be O(0.1), got {}",
dgt[2]
);
}
assert!(saw_drop, "the compact layout must actually drop deep-band atoms");
assert!(max_diff < 1.0e-8, "max full-q gradient diff {max_diff:e}");
}
}