use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::solver::evidence::{HybridAtomCandidate, HybridAtomChoice, select_hybrid_atom};
use crate::terms::analytic_penalties::{
AnalyticPenalty, IBPAssignmentPenalty, IbpHessianDiagThirdChannels,
SoftmaxAssignmentSparsityPenalty, resolve_learnable_weight,
};
use crate::terms::latent_coord::{LatentCoordValues, LatentIdMode, LatentManifold};
use crate::terms::sae_manifold::SaeManifoldRho;
/// Step cap for assignment logits, expressed in multiples of the assignment temperature.
/// NOTE(review): not referenced in this file; meaning inferred from the name — confirm at use site.
pub(crate) const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0;
/// Floor on an atom's active assignment mass.
/// NOTE(review): not referenced in this file; meaning inferred from the name — confirm at use site.
pub(crate) const SAE_ATOM_ACTIVE_MASS_FLOOR: f64 = 1.0e-3;
/// Number of reseed attempts allowed for a collapsed atom.
/// NOTE(review): not referenced in this file; meaning inferred from the name — confirm at use site.
pub(crate) const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1;
/// Lower cutoff on the scaled JumpReLU margin `(logit - threshold) / temperature`.
/// `jumprelu_in_optimization_band` treats margins at or below this value as outside the
/// optimisation band (the logistic surrogate there is below `e^-36`, about 2.3e-16).
pub(crate) const JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF: f64 = -36.0;
/// Returns `true` when a JumpReLU logit sits inside the band the prior terms evaluate.
///
/// The scaled margin `(logit - threshold) / temperature` must exceed
/// `JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF`; below it the logistic surrogate is negligibly
/// small and the prior value/gradient/Hessian helpers skip the entry.
#[inline]
pub(crate) fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
    let scaled_margin = (logit - threshold) / temperature;
    scaled_margin > JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF
}
/// Rule that maps a row of per-atom assignment logits to assignment weights.
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
    /// Temperature-scaled softmax over atoms. `ln(sparsity)` is added to the sparsity
    /// prior's log-strength, so `sparsity` rescales that prior.
    Softmax { temperature: f64, sparsity: f64 },
    /// Independent logistic gates, each multiplied by an IBP stick-breaking prior weight
    /// with concentration `alpha`. With `learnable_alpha`, the effective alpha is
    /// resolved from the smoothing parameters instead of being used verbatim.
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    /// Per-atom logistic gates that are exactly zero at or below `threshold`.
    JumpReLU { temperature: f64, threshold: f64 },
}

impl AssignmentMode {
    /// Softmax assignment at `temperature` with neutral (unit) sparsity scaling.
    #[must_use]
    pub fn softmax(temperature: f64) -> Self {
        let sparsity = 1.0;
        Self::Softmax {
            temperature,
            sparsity,
        }
    }

    /// IBP-MAP assignment at `temperature` with stick-breaking concentration `alpha`.
    #[must_use]
    pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
        Self::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        }
    }

    /// JumpReLU assignment at `temperature`, gated at `threshold`.
    #[must_use]
    pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
        Self::JumpReLU {
            temperature,
            threshold,
        }
    }

    /// The temperature carried by every variant.
    pub fn temperature(&self) -> f64 {
        let (Self::Softmax { temperature, .. }
        | Self::IBPMap { temperature, .. }
        | Self::JumpReLU { temperature, .. }) = *self;
        temperature
    }

    /// Overwrites the temperature in place.
    ///
    /// # Errors
    /// Returns `Err`, leaving `self` untouched, unless `new_temperature` is finite and positive.
    pub(crate) fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
        let acceptable = new_temperature.is_finite() && new_temperature > 0.0;
        if !acceptable {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
            ));
        }
        let (Self::Softmax { temperature, .. }
        | Self::IBPMap { temperature, .. }
        | Self::JumpReLU { temperature, .. }) = self;
        *temperature = new_temperature;
        Ok(())
    }

    /// Checks the numeric invariants of the mode.
    ///
    /// # Errors
    /// Returns `Err` if the temperature is not finite and positive, if a softmax
    /// `sparsity` or IBP `alpha` is not finite and positive, or if a JumpReLU
    /// `threshold` is not finite.
    pub(crate) fn validate(&self) -> Result<(), String> {
        let temperature = self.temperature();
        if !(temperature.is_finite() && temperature > 0.0) {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {temperature}"
            ));
        }
        let problem = match *self {
            Self::Softmax { sparsity, .. } if !(sparsity.is_finite() && sparsity > 0.0) => {
                Some(format!(
                    "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
                ))
            }
            Self::IBPMap { alpha, .. } if !(alpha.is_finite() && alpha > 0.0) => Some(format!(
                "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
            )),
            Self::JumpReLU { threshold, .. } if !threshold.is_finite() => Some(format!(
                "AssignmentMode::JumpReLU: threshold must be finite; got {threshold}"
            )),
            _ => None,
        };
        problem.map_or(Ok(()), Err)
    }

    /// Effective IBP concentration under `rho`, or `None` for non-IBP modes.
    ///
    /// A learnable alpha is passed through `resolve_learnable_weight` together with
    /// `rho.log_lambda_sparse`; a fixed alpha is returned as stored.
    pub(crate) fn resolved_ibp_alpha(&self, rho: &SaeManifoldRho) -> Option<f64> {
        let Self::IBPMap {
            alpha,
            learnable_alpha,
            ..
        } = *self
        else {
            return None;
        };
        let effective = if learnable_alpha {
            resolve_learnable_weight(alpha, rho.log_lambda_sparse)
        } else {
            alpha
        };
        Some(effective)
    }
}
/// Assignment state for the SAE: per-observation assignment logits plus one latent
/// coordinate block per atom, interpreted under an [`AssignmentMode`].
#[derive(Debug, Clone)]
pub struct SaeAssignment {
    /// Assignment logits with one row per observation and one column per atom
    /// (`n_obs()` rows, `k_atoms()` columns).
    pub logits: Array2<f64>,
    /// Latent coordinates of each atom; the constructors require `coords.len()` to equal
    /// the number of logit columns and each block's `n_obs()` to equal the number of rows.
    pub coords: Vec<LatentCoordValues>,
    /// Rule used to turn logits into assignment weights.
    pub mode: AssignmentMode,
}
impl SaeAssignment {
    /// Builds a softmax-mode assignment at `temperature`.
    ///
    /// # Errors
    /// Propagates every validation error reported by [`SaeAssignment::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn new(
        logits: Array2<f64>,
        coords: Vec<LatentCoordValues>,
        temperature: f64,
    ) -> Result<Self, String> {
        Self::with_mode(logits, coords, AssignmentMode::softmax(temperature))
    }

    /// Builds an assignment for an explicit [`AssignmentMode`].
    ///
    /// In softmax mode the logits are canonicalised in place so that each row's last
    /// column becomes the zero reference. Softmax is invariant to a per-row shift, so
    /// the resulting weights are unchanged.
    ///
    /// # Errors
    /// Returns `Err` if `mode` fails validation, `coords.len()` differs from the number
    /// of logit columns, a coordinate block's observation count differs from the number
    /// of logit rows, or any logit is non-finite.
    #[must_use = "build error must be handled"]
    pub fn with_mode(
        mut logits: Array2<f64>,
        coords: Vec<LatentCoordValues>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        mode.validate()?;
        let n = logits.nrows();
        let k = logits.ncols();
        if coords.len() != k {
            return Err(format!(
                "SaeAssignment::new: coords length {} must equal K={k}",
                coords.len()
            ));
        }
        for (atom, coord) in coords.iter().enumerate() {
            if coord.n_obs() != n {
                return Err(format!(
                    "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
                    coord.n_obs()
                ));
            }
        }
        for row in 0..n {
            validate_finite_logits(logits.row(row), row)?;
        }
        if matches!(mode, AssignmentMode::Softmax { .. }) {
            canonicalize_softmax_logits(&mut logits);
        }
        Ok(Self {
            logits,
            coords,
            mode,
        })
    }

    /// Number of observations (rows of `logits`).
    pub fn n_obs(&self) -> usize {
        self.logits.nrows()
    }

    /// Number of atoms (columns of `logits`).
    pub fn k_atoms(&self) -> usize {
        self.logits.ncols()
    }

    /// Sum of the latent dimensions of every atom's coordinate block.
    pub fn total_coord_dim(&self) -> usize {
        self.coords.iter().map(|c| c.latent_dim()).sum()
    }

    /// Number of free assignment parameters per row.
    ///
    /// Softmax drops the last logit, which is the canonical zero reference. The
    /// gate-style modes keep one logit per atom.
    pub fn assignment_coord_dim(&self) -> usize {
        match self.mode {
            AssignmentMode::Softmax { .. } => self.k_atoms().saturating_sub(1),
            AssignmentMode::IBPMap { .. } | AssignmentMode::JumpReLU { .. } => self.k_atoms(),
        }
    }

    /// Width of one observation's block in [`SaeAssignment::flatten_ext_coords`].
    pub fn row_block_dim(&self) -> usize {
        self.assignment_coord_dim() + self.total_coord_dim()
    }

    /// Start offset of each atom's coordinates inside a row block. The offsets begin
    /// right after the assignment parameters.
    pub fn coord_offsets(&self) -> Vec<usize> {
        let mut out = Vec::with_capacity(self.k_atoms());
        let mut cursor = self.assignment_coord_dim();
        for coord in &self.coords {
            out.push(cursor);
            cursor += coord.latent_dim();
        }
        out
    }

    /// Assignment weights for every row, using the stored (unresolved) IBP alpha.
    ///
    /// # Panics
    /// Panics if any row contains a non-finite logit.
    pub fn assignments(&self) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((self.n_obs(), self.k_atoms()));
        for row in 0..self.n_obs() {
            out.row_mut(row).assign(&self.assignments_row(row));
        }
        out
    }

    /// Assignment weights of one row, using the stored (unresolved) IBP alpha.
    ///
    /// # Panics
    /// Panics if `row >= n_obs()` or the row contains a non-finite logit.
    pub fn assignments_row(&self, row: usize) -> Array1<f64> {
        self.try_assignments_row(row)
            .expect("assignment logits must be finite")
    }

    /// Fallible form of [`SaeAssignment::assignments_row`].
    ///
    /// # Errors
    /// Returns `Err` if `row >= n_obs()` or the row contains a non-finite logit.
    pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
        self.try_assignments_row_with_alpha(row, None)
    }

    /// Assignment weights of one row with the IBP alpha resolved under `rho`.
    ///
    /// # Errors
    /// Returns `Err` if `row >= n_obs()` or the row contains a non-finite logit.
    pub(crate) fn try_assignments_row_for_rho(
        &self,
        row: usize,
        rho: &SaeManifoldRho,
    ) -> Result<Array1<f64>, String> {
        self.try_assignments_row_with_alpha(row, self.mode.resolved_ibp_alpha(rho))
    }

    /// Shared row evaluator. `resolved_ibp_alpha` overrides the stored IBP alpha when given.
    fn try_assignments_row_with_alpha(
        &self,
        row: usize,
        resolved_ibp_alpha: Option<f64>,
    ) -> Result<Array1<f64>, String> {
        let n = self.n_obs();
        // Report an out-of-range row as an error instead of letting ndarray's
        // `row()` panic inside a `Result`-returning API.
        if row >= n {
            return Err(format!(
                "SaeAssignment: row {row} is out of range for {n} observations"
            ));
        }
        let logit_row = self.logits.row(row);
        validate_finite_logits(logit_row, row)?;
        // A single softmax atom always receives the full weight.
        if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
            return Ok(Array1::from_vec(vec![1.0]));
        }
        match self.mode {
            AssignmentMode::Softmax { temperature, .. } => Ok(softmax_row(logit_row, temperature)),
            AssignmentMode::IBPMap {
                temperature, alpha, ..
            } => Ok(ibp_map_row(
                logit_row,
                temperature,
                resolved_ibp_alpha.unwrap_or(alpha),
            )),
            AssignmentMode::JumpReLU {
                temperature,
                threshold,
            } => Ok(jumprelu_row(logit_row, temperature, threshold)),
        }
    }

    /// Freezes a learnable IBP alpha at its value resolved under `rho`.
    ///
    /// Returns `true` if the mode was a learnable-alpha IBP and has been replaced by a
    /// fixed-alpha IBP; otherwise leaves the mode untouched and returns `false`.
    pub(crate) fn persist_resolved_ibp_alpha(&mut self, rho: &SaeManifoldRho) -> bool {
        let AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha: true,
        } = self.mode
        else {
            return false;
        };
        let resolved_alpha = resolve_learnable_weight(alpha, rho.log_lambda_sparse);
        self.mode = AssignmentMode::IBPMap {
            temperature,
            alpha: resolved_alpha,
            learnable_alpha: false,
        };
        true
    }

    /// Assignment weights for every row with the IBP alpha resolved under `rho`.
    ///
    /// # Errors
    /// Returns `Err` on the first row containing a non-finite logit.
    pub(crate) fn assignments_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
        let mut out = Array2::<f64>::zeros((self.n_obs(), self.k_atoms()));
        for row in 0..self.n_obs() {
            out.row_mut(row)
                .assign(&self.try_assignments_row_for_rho(row, rho)?);
        }
        Ok(out)
    }

    /// Flattens all per-row parameters into one vector of length `n_obs() * row_block_dim()`.
    ///
    /// Each row block holds the first `assignment_coord_dim()` logits, followed by every
    /// atom's coordinates at the offsets from [`SaeAssignment::coord_offsets`].
    pub fn flatten_ext_coords(&self) -> Array1<f64> {
        let n = self.n_obs();
        let q = self.row_block_dim();
        let k = self.k_atoms();
        let assignment_dim = self.assignment_coord_dim();
        let offsets = self.coord_offsets();
        let mut out = Array1::<f64>::zeros(n * q);
        for row in 0..n {
            let base = row * q;
            for atom in 0..assignment_dim {
                out[base + atom] = self.logits[[row, atom]];
            }
            for atom in 0..k {
                let d = self.coords[atom].latent_dim();
                let t_row = self.coords[atom].row(row);
                for axis in 0..d {
                    out[base + offsets[atom] + axis] = t_row[axis];
                }
            }
        }
        out
    }

    /// Builds an assignment from raw coordinate matrices with `LatentIdMode::None`.
    ///
    /// # Errors
    /// Propagates every validation error reported by [`SaeAssignment::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_mode(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        let coords = coord_blocks
            .iter()
            .map(|c| LatentCoordValues::from_matrix(c.view(), LatentIdMode::None))
            .collect();
        Self::with_mode(logits, coords, mode)
    }

    /// Builds an assignment from raw coordinate matrices, each paired with its manifold.
    ///
    /// # Errors
    /// Returns `Err` if `coord_blocks` and `manifolds` differ in length. Otherwise it
    /// propagates every validation error from [`SaeAssignment::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_mode_and_manifolds(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        manifolds: Vec<LatentManifold>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        if coord_blocks.len() != manifolds.len() {
            return Err(format!(
                "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
                coord_blocks.len(),
                manifolds.len()
            ));
        }
        let coords = coord_blocks
            .iter()
            .zip(manifolds)
            .map(|(c, manifold)| {
                LatentCoordValues::from_matrix_with_manifold(c.view(), LatentIdMode::None, manifold)
            })
            .collect();
        Self::with_mode(logits, coords, mode)
    }
}
/// Recovers the first and second logit-derivatives of a scaled logistic gate from its value.
///
/// Assumes `value = scale * sigmoid(z * inv_tau)`, e.g. an IBP gate whose `scale` is the
/// prior weight. Returns `(value, dvalue/dz, d²value/dz²)`, using `s' = s(1-s)` and
/// `s'' = s(1-s)(1-2s)`. When `scale` is not positive, `s` is taken as 0, which makes
/// both derivatives zero.
pub(crate) fn sae_sigmoid_derivatives_from_value(
    value: f64,
    inv_tau: f64,
    scale: f64,
) -> (f64, f64, f64) {
    let s = match scale > 0.0 {
        true => value / scale,
        false => 0.0,
    };
    let scaled_slope = scale * s * (1.0 - s);
    let first = scaled_slope * inv_tau;
    let second = scaled_slope * (1.0 - 2.0 * s) * inv_tau * inv_tau;
    (value, first, second)
}
/// Reference ("no evidence") assignment weights for `k_atoms` atoms.
///
/// - Softmax: uniform `1 / max(k_atoms, 1)`.
/// - IBPMap: the IBP-MAP weights of an all-zero logit row.
/// - JumpReLU: `0.5` per atom, the logistic midpoint.
pub(crate) fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
    let per_atom = match mode {
        AssignmentMode::IBPMap {
            temperature, alpha, ..
        } => {
            let zero_logits = Array1::<f64>::zeros(k_atoms);
            return ibp_map_row(zero_logits.view(), temperature, alpha);
        }
        AssignmentMode::Softmax { .. } => (k_atoms.max(1) as f64).recip(),
        AssignmentMode::JumpReLU { .. } => 0.5,
    };
    Array1::from_elem(k_atoms, per_atom)
}
/// Numerically stable, temperature-scaled softmax of one logit row.
///
/// Computes `exp((z_i - max z) / temperature) / Σ_j exp((z_j - max z) / temperature)`.
/// Subtracting the row maximum keeps every exponent `<= 0`, so no term overflows.
/// An empty row yields an empty output instead of failing the normaliser check.
///
/// # Panics
/// Panics if the normaliser is not finite and positive. This happens with non-finite
/// logits or a zero/non-finite `temperature`; callers validate the logits first.
pub(crate) fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
    if logits.is_empty() {
        return Array1::zeros(0);
    }
    let inv_tau = 1.0 / temperature;
    let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out = logits.mapv(|z| ((z - max_logit) * inv_tau).exp());
    let sum = out.iter().fold(0.0, |acc, &v| acc + v);
    assert!(
        sum.is_finite() && sum > 0.0,
        "softmax_row: normaliser must be finite and positive; got {sum}"
    );
    out.mapv_inplace(|v| v / sum);
    out
}
/// Checks that every logit in one row is finite.
///
/// # Errors
/// Returns `Err` describing the first non-finite entry (its row, atom column and value).
pub(crate) fn validate_finite_logits(
    logits: ArrayView1<'_, f64>,
    row: usize,
) -> Result<(), String> {
    let first_bad = logits
        .iter()
        .enumerate()
        .find(|&(_, value)| !value.is_finite());
    match first_bad {
        None => Ok(()),
        Some((col, &v)) => Err(format!(
            "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
        )),
    }
}
/// Shifts each softmax logit row so that its last column is exactly zero.
///
/// Softmax is invariant to adding a constant to a row, so the resulting weights are
/// unchanged. A single-column matrix is zeroed outright, and an empty one is left as is.
pub(crate) fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
    match logits.ncols() {
        0 => {}
        1 => logits.fill(0.0),
        k => {
            let last = k - 1;
            for mut logit_row in logits.outer_iter_mut() {
                let reference = logit_row[last];
                for col in 0..last {
                    logit_row[col] -= reference;
                }
                logit_row[last] = 0.0;
            }
        }
    }
}
/// Geometric stick-breaking prior weights `π_k = (alpha / (alpha + 1))^k` for `k = 0..k_atoms`.
///
/// Each weight is floored at `f64::MIN_POSITIVE` so it never underflows to zero.
pub(crate) fn ibp_stick_breaking_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
    // Working in log space: π_k = exp(k * ln(alpha / (alpha + 1))).
    let log_ratio = (alpha / (alpha + 1.0)).ln();
    let weights: Vec<f64> = (0..k_atoms)
        .map(|atom| {
            let log_pi = (atom as f64) * log_ratio;
            log_pi.exp().max(f64::MIN_POSITIVE)
        })
        .collect();
    Array1::from_vec(weights)
}
/// IBP-MAP assignment weights for one row: `sigmoid(z_k / temperature) * π_k`, where
/// `π` is the stick-breaking prior for concentration `alpha`.
pub fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
    let prior = ibp_stick_breaking_prior(logits.len(), alpha);
    let gated: Vec<f64> = logits
        .iter()
        .zip(prior.iter())
        .map(|(&z, &pi)| crate::linalg::utils::stable_logistic(z / temperature) * pi)
        .collect();
    Array1::from_vec(gated)
}
/// IBP-MAP weights of one row together with their elementwise logit derivatives.
///
/// Returns `(value, grad)` where `value_k = s_k π_k` and `grad_k = s_k (1 - s_k) π_k / temperature`,
/// with `s_k = sigmoid(z_k / temperature)` and `π` the stick-breaking prior.
#[must_use]
pub fn ibp_map_row_value_grad(
    logits: ArrayView1<'_, f64>,
    temperature: f64,
    alpha: f64,
) -> (Array1<f64>, Array1<f64>) {
    let prior = ibp_stick_breaking_prior(logits.len(), alpha);
    let inv_tau = 1.0 / temperature;
    let (values, grads): (Vec<f64>, Vec<f64>) = logits
        .iter()
        .zip(prior.iter())
        .map(|(&z, &pi)| {
            let s = crate::linalg::utils::stable_logistic(z * inv_tau);
            (s * pi, s * (1.0 - s) * inv_tau * pi)
        })
        .unzip();
    (Array1::from_vec(values), Array1::from_vec(grads))
}
/// JumpReLU assignment weights for one row.
///
/// Logits at or below `threshold` (and NaN logits) map to `0`. Logits above it map to
/// `sigmoid((z - threshold) / temperature)`, so the weight jumps to just above `0.5`
/// at the threshold.
pub fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
    logits.mapv(|z| {
        if z > threshold {
            crate::linalg::utils::stable_logistic((z - threshold) / temperature)
        } else {
            0.0
        }
    })
}
/// Inputs for writing one active atom's logit Jacobian row via `fill_active_atom_logit_jvp`.
pub(crate) struct ActiveAtomLogitJvp<'a> {
    /// Assignment mode that determines the derivative formula.
    pub(crate) mode: AssignmentMode,
    /// Atom index; used to look up `ibp_prior[k]` in IBP mode.
    pub(crate) k: usize,
    /// The atom's logit; used for the JumpReLU threshold test and activation.
    pub(crate) logit_k: f64,
    /// The atom's current assignment weight.
    pub(crate) a_k: f64,
    /// The atom's decoded output; indexed over `0..fitted.len()`.
    pub(crate) decoded_k: ArrayView1<'a, f64>,
    /// Fitted output row; its length sets the number of Jacobian columns written.
    pub(crate) fitted: ArrayView1<'a, f64>,
    /// Stick-breaking prior weights; required (panics if `None`) in IBP mode.
    pub(crate) ibp_prior: Option<&'a [f64]>,
    /// Row of the compact Jacobian that receives this atom's derivatives.
    pub(crate) compact_index: usize,
}
/// Writes `∂ fitted / ∂ logit_k` for one active atom into row `compact_index` of `jac_compact`.
///
/// - Softmax: `a_k (decoded_k - fitted) / τ`.
/// - IBPMap: `s (1 - s) π_k / τ · decoded_k`, with `s = a_k / π_k` (0 when `π_k` is not positive).
/// - JumpReLU: `σ(1 - σ) / τ · decoded_k` when `logit_k > threshold`; otherwise nothing is written.
///
/// # Panics
/// Panics if the mode is IBPMap and `ibp_prior` is `None`, or if any index is out of bounds.
pub(crate) fn fill_active_atom_logit_jvp(
    input: ActiveAtomLogitJvp<'_>,
    jac_compact: &mut Array2<f64>,
) {
    let ActiveAtomLogitJvp {
        mode,
        k: atom,
        logit_k: logit,
        a_k: weight,
        decoded_k: decoded,
        fitted,
        ibp_prior,
        compact_index: jac_row,
    } = input;
    let width = fitted.len();
    // Gate-style modes share the same shape: a scalar slope times the atom's decoded output.
    let write_scaled_decoded = |jac: &mut Array2<f64>, coeff: f64| {
        for col in 0..width {
            jac[[jac_row, col]] = coeff * decoded[col];
        }
    };
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            for col in 0..width {
                let centred = decoded[col] - fitted[col];
                jac_compact[[jac_row, col]] = weight * centred * inv_tau;
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior =
                ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
            let pi = prior[atom];
            let sig = if pi > 0.0 { weight / pi } else { 0.0 };
            write_scaled_decoded(jac_compact, sig * (1.0 - sig) * inv_tau * pi);
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            if logit <= threshold {
                return;
            }
            let inv_tau = 1.0 / temperature;
            let activation = crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
            write_scaled_decoded(jac_compact, activation * (1.0 - activation) * inv_tau);
        }
    }
}
/// Writes the Jacobian of the fitted output row with respect to each assignment logit.
///
/// Row `j` of `local_jac` receives `∂ fitted / ∂ logits[j]` over `fitted.len()` columns,
/// assuming `fitted = Σ_j a_j decoded_j`:
/// - Softmax: `a_j (decoded_j - fitted) / τ` for every logit except the last, which is
///   the fixed zero reference. With fewer than two atoms there is no free logit and
///   nothing is written.
/// - IBPMap: `s (1 - s) π_j / τ · decoded_j`, recovering `s = a_j / π_j` from the weight.
/// - JumpReLU: `σ(1 - σ) / τ · decoded_j` for logits above the threshold; rows at or
///   below it are left untouched.
///
/// # Panics
/// Panics if the mode is IBPMap and `ibp_prior` is `None`, or if the array shapes are
/// too small for `assignments.len()` rows and `fitted.len()` columns.
pub(crate) fn fill_assignment_logit_jvp_rows(
    mode: AssignmentMode,
    logits: ArrayView1<'_, f64>,
    assignments: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    fitted: ArrayView1<'_, f64>,
    ibp_prior: Option<&[f64]>,
    local_jac: &mut Array2<f64>,
) {
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            // `<= 1` rather than `== 1`: with K == 0, `len() - 1` would underflow.
            if assignments.len() <= 1 {
                return;
            }
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..assignments.len() - 1 {
                for out_col in 0..fitted.len() {
                    local_jac[[logit_col, out_col]] = assignments[logit_col]
                        * (decoded[[logit_col, out_col]] - fitted[out_col])
                        * inv_tau;
                }
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior = ibp_prior
                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
            for logit_col in 0..assignments.len() {
                let pi_k = prior[logit_col];
                let a_k = assignments[logit_col];
                let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
                let dz = sig * (1.0 - sig) * inv_tau * pi_k;
                for out_col in 0..fitted.len() {
                    local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
                }
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..assignments.len() {
                // Gates at or below the threshold are hard zeros with no logit derivative.
                if logits[logit_col] <= threshold {
                    continue;
                }
                let activation = crate::linalg::utils::stable_logistic(
                    (logits[logit_col] - threshold) * inv_tau,
                );
                let da = activation * (1.0 - activation) * inv_tau;
                for out_col in 0..fitted.len() {
                    local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
                }
            }
        }
    }
}
/// Copies a 2-D logit view into a flat vector in row-major order (row 0 first),
/// independent of the view's memory layout.
pub(crate) fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
    // ndarray's `iter()` visits elements in logical order, with the last index fastest.
    let row_major: Vec<f64> = logits.iter().copied().collect();
    Array1::from_vec(row_major)
}
/// Value of the assignment sparsity prior at the current logits and smoothing parameters.
///
/// - Softmax: `SoftmaxAssignmentSparsityPenalty` at log-strength
///   `log_lambda_sparse + ln(sparsity)`.
/// - IBPMap: `IBPAssignmentPenalty`. A learnable alpha reads `log_lambda_sparse` as its
///   rho; otherwise the penalty weight is set to `lambda_sparse()`.
/// - JumpReLU: `lambda_sparse() · Σ sigmoid((z - threshold) / τ)` over logits in the
///   optimisation band.
///
/// A single-atom softmax has no free logit, so its prior is `0.0`.
///
/// # Panics
/// Panics if any assignment logit is non-finite.
pub(crate) fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
    for (row, logit_row) in assignment.logits.outer_iter().enumerate() {
        validate_finite_logits(logit_row, row).expect("assignment logits must be finite");
    }
    let k_atoms = assignment.k_atoms();
    let single_softmax_atom =
        k_atoms == 1 && matches!(assignment.mode, AssignmentMode::Softmax { .. });
    if single_softmax_atom {
        return 0.0;
    }
    let flat = flat_logits(assignment.logits.view());
    match assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(k_atoms, temperature);
            let log_strength = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            penalty.value(flat.view(), log_strength.view())
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty =
                IBPAssignmentPenalty::new(k_atoms, alpha, temperature, learnable_alpha);
            let log_strength = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                penalty.weight = rho.lambda_sparse();
                Array1::zeros(0)
            };
            penalty.value(flat.view(), log_strength.view())
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let band_mass = flat
                .iter()
                .copied()
                .filter(|&z| jumprelu_in_optimization_band(z, threshold, temperature))
                .map(|z| crate::linalg::utils::stable_logistic((z - threshold) / temperature))
                .fold(0.0, |acc, gate| acc + gate);
            rho.lambda_sparse() * band_mass
        }
    }
}
/// Derivative of the assignment prior with respect to `rho.log_lambda_sparse`.
///
/// The JumpReLU prior is `lambda_sparse() · (...)`. The softmax prior presumably scales
/// the same way with its log-strength input. For both, the log-strength derivative
/// equals the prior value itself. A fixed-alpha IBP prior is likewise returned as its
/// weighted value. A learnable-alpha IBP uses the penalty's `grad_rho`. A single-atom
/// softmax returns `0.0`.
///
/// # Panics
/// Panics if any assignment logit is non-finite.
pub(crate) fn assignment_prior_log_strength_derivative(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> f64 {
    for (row, logit_row) in assignment.logits.outer_iter().enumerate() {
        validate_finite_logits(logit_row, row).expect("assignment logits must be finite");
    }
    let AssignmentMode::IBPMap {
        temperature,
        alpha,
        learnable_alpha,
    } = assignment.mode
    else {
        if assignment.k_atoms() == 1 && matches!(assignment.mode, AssignmentMode::Softmax { .. })
        {
            return 0.0;
        }
        return assignment_prior_value(assignment, rho);
    };
    let target = flat_logits(assignment.logits.view());
    let mut penalty =
        IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
    if !learnable_alpha {
        penalty.weight = rho.lambda_sparse();
        return penalty.value(target.view(), Array1::<f64>::zeros(0).view());
    }
    let log_alpha = Array1::from_vec(vec![rho.log_lambda_sparse]);
    penalty.grad_rho(target.view(), log_alpha.view())[0]
}
/// Derivative, with respect to `rho.log_lambda_sparse`, of the assignment prior's
/// Hessian diagonal over the flattened (row-major, see `flat_logits`) logits.
///
/// A single-atom softmax has no free logit and yields zeros. For JumpReLU, the result
/// equals the prior's own Hessian diagonal, `λ σ(1-σ)(1-2σ)/τ²`, because the prior is
/// `λ · (...)`. Softmax and fixed-alpha IBP likewise return their penalty's
/// `hessian_diag`.
/// NOTE(review): that relies on those penalties scaling linearly with their strength;
/// confirm in `analytic_penalties`. A learnable-alpha IBP delegates to
/// `hessian_diag_log_alpha_derivative`.
///
/// # Errors
/// Returns `Err` if any logit is non-finite or a penalty cannot supply its Hessian diagonal.
pub(crate) fn assignment_prior_log_strength_hdiag(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Array1<f64>, String> {
    for row in 0..assignment.n_obs() {
        validate_finite_logits(assignment.logits.row(row), row)?;
    }
    let target = flat_logits(assignment.logits.view());
    // One softmax atom: its logit is the fixed reference, so nothing depends on it.
    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
        return Ok(Array1::<f64>::zeros(target.len()));
    }
    match assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
            // `sparsity` enters as an additive shift of the log-strength.
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            penalty
                .hessian_diag(target.view(), rho_view.view())
                .ok_or_else(|| {
                    "softmax assignment log-strength hessian diag unavailable".to_string()
                })
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            let inv_tau = 1.0 / temperature;
            let inv_tau2 = inv_tau * inv_tau;
            let mut d = Array1::<f64>::zeros(target.len());
            for idx in 0..target.len() {
                let logit = target[idx];
                // Entries outside the band stay zero (negligible surrogate mass).
                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
                    continue;
                }
                let activation =
                    crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
                let slope = activation * (1.0 - activation);
                // Second logit-derivative of λ·σ((z - threshold)/τ).
                d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
            }
            Ok(d)
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            if learnable_alpha {
                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
                Ok(penalty.hessian_diag_log_alpha_derivative(target.view(), rho_view.view()))
            } else {
                // Fixed alpha: the strength is carried by the penalty weight, not by rho.
                penalty.weight = rho.lambda_sparse();
                penalty
                    .hessian_diag(target.view(), Array1::<f64>::zeros(0).view())
                    .ok_or_else(|| {
                        "IBP assignment log-strength hessian diag unavailable".to_string()
                    })
            }
        }
    }
}
/// Mixed derivative `∂² prior / ∂ log_lambda_sparse ∂ logits` over the flattened logits.
///
/// A learnable-alpha IBP uses the penalty's `log_alpha_target_mixed_derivative`. For
/// every other mode this is the prior's logit gradient from `assignment_prior_grad_hdiag`.
/// A single-atom softmax returns zeros.
///
/// # Errors
/// Returns `Err` if any logit is non-finite or the gradient computation fails.
pub(crate) fn assignment_prior_log_strength_target_mixed(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Array1<f64>, String> {
    for (row, logit_row) in assignment.logits.outer_iter().enumerate() {
        validate_finite_logits(logit_row, row)?;
    }
    let k_atoms = assignment.k_atoms();
    let target = flat_logits(assignment.logits.view());
    match assignment.mode {
        AssignmentMode::Softmax { .. } if k_atoms == 1 => Ok(Array1::<f64>::zeros(target.len())),
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha: true,
        } => {
            let penalty = IBPAssignmentPenalty::new(k_atoms, alpha, temperature, true);
            let log_alpha = Array1::from_vec(vec![rho.log_lambda_sparse]);
            Ok(penalty.log_alpha_target_mixed_derivative(target.view(), log_alpha.view()))
        }
        AssignmentMode::Softmax { .. }
        | AssignmentMode::IBPMap {
            learnable_alpha: false,
            ..
        }
        | AssignmentMode::JumpReLU { .. } => {
            assignment_prior_grad_hdiag(assignment, rho).map(|(grad, _)| grad)
        }
    }
}
/// Gradient and Hessian diagonal of the assignment prior with respect to the flattened
/// (row-major, see `flat_logits`) assignment logits.
///
/// Returns `(grad, diag)`, each of length `n_obs * k_atoms`. A single-atom softmax has no
/// free logit and returns zeros. Softmax and IBP delegate to their analytic penalty
/// objects. JumpReLU is computed inline from the surrogate
/// `λ · sigmoid((z - threshold) / τ)`, restricted to the optimisation band.
///
/// # Errors
/// Returns `Err` if any logit is non-finite or a penalty cannot supply its Hessian diagonal.
pub(crate) fn assignment_prior_grad_hdiag(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<(Array1<f64>, Array1<f64>), String> {
    for row in 0..assignment.n_obs() {
        validate_finite_logits(assignment.logits.row(row), row)?;
    }
    let target = flat_logits(assignment.logits.view());
    let mut grad = Array1::<f64>::zeros(target.len());
    let mut diag = Array1::<f64>::zeros(target.len());
    // One softmax atom: its logit is the fixed reference, so the prior has no gradient.
    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
        return Ok((grad, diag));
    }
    let (sparsity_grad, sparsity_diag) = match assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
            // `sparsity` enters as an additive shift of the log-strength.
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            let g = penalty.grad_target(target.view(), rho_view.view());
            let d = penalty
                .hessian_diag(target.view(), rho_view.view())
                .ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
            (g, d)
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            // Learnable alpha reads log_lambda_sparse as its rho; fixed alpha carries
            // the strength in the penalty weight and takes an empty rho.
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                penalty.weight = rho.lambda_sparse();
                Array1::zeros(0)
            };
            let g = penalty.grad_target(target.view(), rho_view.view());
            let d = penalty
                .hessian_diag(target.view(), rho_view.view())
                .ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
            (g, d)
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            let inv_tau = 1.0 / temperature;
            let inv_tau2 = inv_tau * inv_tau;
            let mut g = Array1::<f64>::zeros(target.len());
            let mut d = Array1::<f64>::zeros(target.len());
            for idx in 0..target.len() {
                let logit = target[idx];
                // Entries outside the band stay zero (negligible surrogate mass).
                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
                    continue;
                }
                let activation =
                    crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
                let slope = activation * (1.0 - activation);
                // First and second logit-derivatives of λ·σ((z - threshold)/τ).
                g[idx] = sparsity_strength * slope * inv_tau;
                d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
            }
            (g, d)
        }
    };
    // Accumulate into the zero-initialised buffers allocated above.
    grad += &sparsity_grad;
    diag += &sparsity_diag;
    Ok((grad, diag))
}
/// Third-derivative channels of the IBP prior's logit-Hessian diagonal.
///
/// Returns `Ok(None)` for non-IBP modes, which are checked before any logit validation.
/// For IBP, a learnable alpha reads `log_lambda_sparse` as its rho. Otherwise the
/// penalty weight is set to `lambda_sparse()` and an empty rho is passed.
///
/// # Errors
/// Returns `Err` if the mode is IBPMap and any logit is non-finite.
pub(crate) fn ibp_assignment_third_channels(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
    let (temperature, alpha, learnable_alpha) = match assignment.mode {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => (temperature, alpha, learnable_alpha),
        AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => return Ok(None),
    };
    for (row, logit_row) in assignment.logits.outer_iter().enumerate() {
        validate_finite_logits(logit_row, row)?;
    }
    let target = flat_logits(assignment.logits.view());
    let mut penalty =
        IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
    let log_strength = if learnable_alpha {
        Array1::from_vec(vec![rho.log_lambda_sparse])
    } else {
        penalty.weight = rho.lambda_sparse();
        Array1::zeros(0)
    };
    let channels = penalty.hessian_diag_logit_third_channels(target.view(), log_strength.view());
    Ok(Some(channels))
}
/// Chooses between a linear and an optional curved parameterisation for one atom.
///
/// The linear candidate is always considered and is listed first. The curved candidate
/// is considered only when the atom's manifold is not Euclidean; on a flat chart it is
/// discarded.
///
/// # Panics
/// Never in practice: the candidate list always contains the linear candidate, which is
/// what the `expect` message asserts about `select_hybrid_atom`.
pub fn select_hybrid_atom_parameterization(
    manifold: &LatentManifold,
    curved: Option<HybridAtomCandidate>,
    linear: HybridAtomCandidate,
) -> HybridAtomChoice {
    let flat_chart = manifold.is_euclidean();
    let mut candidates: Vec<HybridAtomCandidate> = vec![linear];
    if !flat_chart {
        candidates.extend(curved);
    }
    select_hybrid_atom(&candidates).expect("hybrid atom slot always has the linear candidate")
}
#[cfg(test)]
mod hybrid_split_tests {
    use super::*;
    use crate::solver::evidence::HybridAtomParam;

    /// Circle chart with a full-turn (`2π`) period: a non-Euclidean chart that keeps
    /// the curved candidate in play.
    fn full_turn_circle() -> LatentManifold {
        LatentManifold::Circle {
            period: 2.0 * std::f64::consts::PI,
        }
    }

    /// A Euclidean chart must ignore a supplied curved candidate.
    #[test]
    fn flat_chart_drops_curved_candidate_and_keeps_linear() {
        let linear = HybridAtomCandidate::linear(100.0, 2);
        let curved = HybridAtomCandidate::curved(1, 1.0, 5, Some(2.0));
        let choice =
            select_hybrid_atom_parameterization(&LatentManifold::Euclidean, Some(curved), linear);
        assert!(choice.param.is_linear());
    }

    /// On a curveable chart, a curved candidate with a much better fit wins.
    #[test]
    fn curveable_chart_selects_curved_when_turning_pays() {
        let linear = HybridAtomCandidate::linear(100.0, 2);
        let curved = HybridAtomCandidate::curved(1, 70.0, 5, Some(2.0 * std::f64::consts::PI));
        let choice = select_hybrid_atom_parameterization(&full_turn_circle(), Some(curved), linear);
        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
    }

    /// Without a curved candidate, even a curveable chart keeps the linear one.
    #[test]
    fn curveable_chart_falls_back_to_linear_when_no_curved_candidate() {
        let linear = HybridAtomCandidate::linear(33.0, 2);
        let choice = select_hybrid_atom_parameterization(&full_turn_circle(), None, linear);
        assert!(choice.param.is_linear());
        assert_eq!(choice.num_parameters, 2);
    }
}