use super::*;
/// Per-row sparse layout describing which atoms are active in each row and where each
/// active atom's coordinate block lives, both in a compact per-row vector and in the full
/// (dense) row vector.
///
/// Compact row layout (as built by `from_active_atoms`): the first `active_atoms[row].len()`
/// slots hold one value per active atom (in `active_atoms[row]` order), followed by each
/// active atom's coordinate block in the same order. `expand_row` scatters such a compact
/// vector back into a full row vector.
///
/// NOTE(review): all fields are `pub`; `coord_starts` is only consistent with
/// `active_atoms`/`coord_dims` when the value is built through `from_active_atoms` (or one
/// of the constructors that delegate to it).
#[derive(Debug, Clone)]
pub struct SaeRowLayout {
/// For each row, the indices of the atoms treated as active in that row.
pub active_atoms: Vec<Vec<usize>>,
/// For each row, and for each active atom at position `j` of `active_atoms[row]`, the index
/// in the compact row vector at which that atom's coordinate block starts.
pub coord_starts: Vec<Vec<usize>>,
/// Indexed by atom `k`: the index in the full row vector at which atom `k`'s coordinate
/// block starts.
pub coord_offsets_full: Vec<usize>,
/// Indexed by atom `k`: the number of coordinates in atom `k`'s block.
pub coord_dims: Vec<usize>,
}
impl SaeRowLayout {
    /// Builds a layout whose active set for each row is every atom whose logit passes
    /// `crate::assignment::jumprelu_in_optimization_band`.
    ///
    /// For each of the first `n` rows of `logits`, atom `k` (for `k < k_atoms`) is active when
    /// `jumprelu_in_optimization_band(logits[[row, k]], threshold, temperature)` is `true`.
    /// Active indices are therefore listed in ascending order. A row may end up with no
    /// active atoms.
    ///
    /// # Panics
    ///
    /// Panics if `n` exceeds the number of rows of `logits`, if `k_atoms` exceeds its number
    /// of columns, or if an active atom index is out of range for `coord_dims`.
    pub(crate) fn from_jumprelu(
        n: usize,
        k_atoms: usize,
        threshold: f64,
        temperature: f64,
        logits: &Array2<f64>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let mut per_row = Vec::with_capacity(n);
        for row in 0..n {
            let row_logits = logits.row(row);
            let active: Vec<usize> = (0..k_atoms)
                .filter(|&k| {
                    crate::assignment::jumprelu_in_optimization_band(
                        row_logits[k],
                        threshold,
                        temperature,
                    )
                })
                .collect();
            per_row.push(active);
        }
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Builds a layout from dense per-row weight vectors by magnitude-based sparsification.
    ///
    /// For each row vector `a` in `assignments`:
    /// 1. Keep at most `max(k_active_cap, 1)` atoms with the largest `|a[k]|`. Which of
    ///    several equal-magnitude atoms survives at the cap boundary is unspecified.
    /// 2. Drop atoms with `|a[k]| <= relative_cutoff * max_k |a[k]|`.
    /// 3. If nothing survives and the row is non-empty, keep the single atom of largest
    ///    magnitude (the lowest index on ties).
    ///
    /// The surviving indices are sorted ascending. NaN weights are never preferred over a
    /// real weight. They are ranked below every real value, are ignored when computing the
    /// row peak, and never pass the cutoff.
    ///
    /// # Panics
    ///
    /// Panics if an active atom index is out of range for `coord_dims`.
    pub(crate) fn from_dense_weights(
        assignments: &[Array1<f64>],
        k_active_cap: usize,
        relative_cutoff: f64,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let cap = k_active_cap.max(1);
        let mut per_row = Vec::with_capacity(assignments.len());
        for a in assignments {
            let k = a.len();
            // Ranking key: |a[i]|, with NaN mapped to -inf so it never outranks a real weight
            // and so the comparator below is a genuine total order. A comparator that returns
            // `Equal` for NaN, as `partial_cmp(..).unwrap_or(Equal)` does, is not a total order.
            // `select_nth_unstable_by` may panic on such a comparator.
            let magnitude = |i: usize| -> f64 {
                let m = a[i].abs();
                if m.is_nan() {
                    f64::NEG_INFINITY
                } else {
                    m
                }
            };
            // `f64::max` ignores NaN operands, so the peak only reflects non-NaN entries.
            let row_peak = a.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
            let cutoff = relative_cutoff * row_peak;

            let mut idx: Vec<usize> = (0..k).collect();
            if cap < k {
                // Partition so the `cap` largest-magnitude atoms occupy `idx[..cap]`.
                idx.select_nth_unstable_by(cap - 1, |&i, &j| {
                    magnitude(j).total_cmp(&magnitude(i))
                });
                idx.truncate(cap);
            }

            let mut active: Vec<usize> = idx
                .into_iter()
                .filter(|&atom| a[atom].abs() > cutoff)
                .collect();
            if active.is_empty() {
                // Keep every non-empty row represented by at least one atom. Use the first
                // atom of maximal magnitude.
                let top = (0..k).fold(None::<usize>, |best, i| match best {
                    Some(b) if magnitude(b) >= magnitude(i) => Some(b),
                    _ => Some(i),
                });
                active.extend(top);
            }
            active.sort_unstable();
            per_row.push(active);
        }
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Builds a layout from explicit per-row active atom lists, computing `coord_starts`.
    ///
    /// For each row, the compact vector reserves one slot per active atom, at positions
    /// `0..active.len()`. After those slots come the active atoms' coordinate blocks, in
    /// `active` order. `coord_starts[row][j]` is where the block for `active[j]` begins.
    ///
    /// # Panics
    ///
    /// Panics if any active atom index is out of range for `coord_dims`.
    pub(crate) fn from_active_atoms(
        active_atoms: Vec<Vec<usize>>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let mut coord_starts_all = Vec::with_capacity(active_atoms.len());
        for active in &active_atoms {
            let mut starts = Vec::with_capacity(active.len());
            // Coordinate blocks start right after the per-atom value slots.
            let mut cursor = active.len();
            for &k in active {
                starts.push(cursor);
                cursor += coord_dims[k];
            }
            coord_starts_all.push(starts);
        }
        Self {
            active_atoms,
            coord_starts: coord_starts_all,
            coord_offsets_full,
            coord_dims,
        }
    }

    /// Length of the compact vector for `row`.
    ///
    /// The length is one slot per active atom plus the sum of the active atoms'
    /// coordinate dimensions.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range or an active atom index is out of range for
    /// `coord_dims`.
    pub fn row_q_active(&self, row: usize) -> usize {
        let active = &self.active_atoms[row];
        let coord_sum: usize = active.iter().map(|&k| self.coord_dims[k]).sum();
        active.len() + coord_sum
    }

    /// Scatters the compact vector `delta_t_row` for `row` into the full row vector `out`.
    ///
    /// Every entry of `out` is first set to `0.0`. Then, for each active atom `k` at
    /// position `j`, two writes happen:
    /// - its slot value `delta_t_row[j]` goes to `out[k]`;
    /// - its coordinate block `delta_t_row[coord_starts[row][j]..][..coord_dims[k]]` goes to
    ///   `out[coord_offsets_full[k]..][..coord_dims[k]]`.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range, if `delta_t_row` is shorter than
    /// `row_q_active(row)`, or if `out` is too short for an index being written.
    pub fn expand_row(&self, row: usize, delta_t_row: &[f64], out: &mut [f64]) {
        out.fill(0.0);
        let active = &self.active_atoms[row];
        let starts = &self.coord_starts[row];
        for (j, &k) in active.iter().enumerate() {
            out[k] = delta_t_row[j];
            let d = self.coord_dims[k];
            let full_off = self.coord_offsets_full[k];
            for axis in 0..d {
                out[full_off + axis] = delta_t_row[starts[j] + axis];
            }
        }
    }
}