use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::solver::evidence::{HybridAtomCandidate, HybridAtomChoice, select_hybrid_atom};
use crate::terms::analytic_penalties::{
AnalyticPenalty, IBPAssignmentPenalty, IbpHessianDiagThirdChannels,
SoftmaxAssignmentSparsityPenalty, resolve_learnable_weight,
};
use crate::terms::latent::{LatentCoordValues, LatentIdMode, LatentManifold};
use crate::terms::sae::manifold::SaeManifoldRho;
// NOTE(review): none of the constants in this group is referenced in this section of the
// file; the descriptions below are inferred from the names only — confirm at the use sites.

/// Presumably the largest assignment-logit step allowed per update, in multiples of the
/// routing temperature ("taus") — TODO confirm.
pub(crate) const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0;
/// Presumably the minimum active mass below which an atom counts as inactive — TODO confirm.
pub(crate) const SAE_ATOM_ACTIVE_MASS_FLOOR: f64 = 1.0e-3;
/// Presumably the number of reseed attempts granted to a collapsed atom — TODO confirm.
pub(crate) const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1;
/// Presumably the decoder-norm ratio under which an atom is treated as collapsed — TODO confirm.
pub(crate) const SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO: f64 = 1.0e-3;
/// Presumably an explained-variance floor used to detect dictionary collapse — TODO confirm.
pub(crate) const SAE_DICTIONARY_COLLAPSE_EV_FLOOR: f64 = 0.28;
/// Presumably the number of reseed attempts granted when atoms co-collapse — TODO confirm.
pub(crate) const SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET: usize = 3;
/// Exclusive lower bound on the temperature-scaled margin `(logit - threshold) / temperature`
/// for a JumpReLU logit to be counted by [`jumprelu_in_optimization_band`].
pub(crate) const JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF: f64 = -36.0;

/// Reports whether a JumpReLU logit lies inside the optimization band.
///
/// The logit is in the band when its scaled margin `(logit - threshold) / temperature` is
/// strictly greater than [`JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF`]. A NaN margin compares false,
/// so it is reported as outside the band.
#[inline]
pub(crate) fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
    let scaled_margin = (logit - threshold) / temperature;
    scaled_margin > JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF
}
/// How a row of assignment logits is turned into per-atom gate weights.
///
/// Every variant carries a `temperature`; `AssignmentMode::validate` requires it to be
/// finite and strictly positive.
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
/// Gates are a temperature-scaled softmax over the row (see `softmax_row`).
/// `sparsity` must be finite and positive; it enters the sparsity prior as an additive
/// `ln(sparsity)` offset on the log strength.
Softmax { temperature: f64, sparsity: f64 },
/// Gates are `logistic(logit / temperature)` times an ordered geometric prior built from
/// `alpha` (see `ibp_map_row`). `alpha` must be finite and positive.
IBPMap {
temperature: f64,
alpha: f64,
// When true, the effective alpha is obtained via `resolve_learnable_weight(alpha, ..)`
// from the rho log strength instead of being used verbatim.
learnable_alpha: bool,
},
/// Gates are zero at or below `threshold` and `logistic((logit - threshold) / temperature)`
/// above it (see `jumprelu_row`). `threshold` must be finite.
JumpReLU { temperature: f64, threshold: f64 },
}
/// Selector for the source of predicted routing.
///
/// NOTE(review): this enum is not referenced anywhere in this section of the file; the
/// variant descriptions below are inferred from the names only — confirm at the use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingPredictor {
// Presumably: reuse a stored snapshot of routing logits — TODO confirm.
Snapshot,
// Presumably: predict routing from chart geometry — TODO confirm.
ChartGeometry,
}
impl AssignmentMode {
    /// Softmax mode with the given `temperature` and a unit `sparsity` multiplier.
    #[must_use]
    pub fn softmax(temperature: f64) -> Self {
        let sparsity = 1.0;
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        }
    }

    /// IBP-MAP mode with the given `temperature`, `alpha` and learnability flag.
    #[must_use]
    pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        }
    }

    /// JumpReLU mode with the given `temperature` and `threshold`.
    #[must_use]
    pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        }
    }

    /// Returns the temperature stored in whichever variant is active.
    pub fn temperature(&self) -> f64 {
        let (AssignmentMode::Softmax { temperature, .. }
        | AssignmentMode::IBPMap { temperature, .. }
        | AssignmentMode::JumpReLU { temperature, .. }) = *self;
        temperature
    }

    /// Overwrites the temperature of the active variant.
    ///
    /// # Errors
    /// Returns an error, leaving `self` untouched, when `new_temperature` is not finite and
    /// strictly positive.
    pub(crate) fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
        let acceptable = new_temperature.is_finite() && new_temperature > 0.0;
        if !acceptable {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
            ));
        }
        let (AssignmentMode::Softmax { temperature, .. }
        | AssignmentMode::IBPMap { temperature, .. }
        | AssignmentMode::JumpReLU { temperature, .. }) = self;
        *temperature = new_temperature;
        Ok(())
    }

    /// Checks the numeric parameters of the mode.
    ///
    /// # Errors
    /// The temperature is checked first (finite and positive), then the variant-specific
    /// parameter: `sparsity` and `alpha` must be finite and positive, `threshold` finite.
    pub(crate) fn validate(&self) -> Result<(), String> {
        let positive_finite = |x: f64| x.is_finite() && x > 0.0;
        let temperature = self.temperature();
        if !positive_finite(temperature) {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {temperature}"
            ));
        }
        match *self {
            AssignmentMode::Softmax { sparsity, .. } if !positive_finite(sparsity) => Err(format!(
                "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
            )),
            AssignmentMode::IBPMap { alpha, .. } if !positive_finite(alpha) => Err(format!(
                "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
            )),
            AssignmentMode::JumpReLU { threshold, .. } if !threshold.is_finite() => Err(format!(
                "AssignmentMode::JumpReLU: threshold must be finite; got {threshold}"
            )),
            AssignmentMode::Softmax { .. }
            | AssignmentMode::IBPMap { .. }
            | AssignmentMode::JumpReLU { .. } => Ok(()),
        }
    }

    /// Effective IBP `alpha` for the given `rho`, or `None` outside IBP-MAP mode.
    ///
    /// A learnable alpha is resolved through `resolve_learnable_weight` with
    /// `rho.log_lambda_sparse`; a fixed alpha is returned as stored.
    pub(crate) fn resolved_ibp_alpha(&self, rho: &SaeManifoldRho) -> Option<f64> {
        let AssignmentMode::IBPMap {
            alpha,
            learnable_alpha,
            ..
        } = *self
        else {
            return None;
        };
        if learnable_alpha {
            Some(resolve_learnable_weight(alpha, rho.log_lambda_sparse))
        } else {
            Some(alpha)
        }
    }
}
/// Per-observation routing state: assignment logits, per-atom latent coordinates and the
/// routing mode that maps logits to gates.
#[derive(Debug, Clone)]
pub struct SaeAssignment {
/// Assignment logits; rows are observations (`n_obs`) and columns are atoms (`k_atoms`).
pub logits: Array2<f64>,
/// One latent-coordinate block per atom; `with_mode` requires each to report
/// `n_obs()` equal to the number of logit rows.
pub coords: Vec<LatentCoordValues>,
/// Rule used to turn a logit row into gate weights.
pub mode: AssignmentMode,
/// Per-atom flags; a flagged atom has its gate forced to 1.0 and its logit treated as fixed.
pub ungated: Vec<bool>,
/// When `Some`, gates are computed from these logits instead of `logits`, and every
/// logit is reported as fixed.
pub frozen_logits: Option<Array2<f64>>,
}
impl SaeAssignment {
/// Builds a softmax-routed assignment with the given temperature.
///
/// # Errors
/// Propagates every error of `SaeAssignment::with_mode`.
#[must_use = "build error must be handled"]
pub fn new(
    logits: Array2<f64>,
    coords: Vec<LatentCoordValues>,
    temperature: f64,
) -> Result<Self, String> {
    let mode = AssignmentMode::softmax(temperature);
    Self::with_mode(logits, coords, mode)
}
/// Builds an assignment for an arbitrary routing `mode`.
///
/// Under Softmax the logits are canonicalized so the last column is the zero reference.
/// The result starts with no ungated atoms and no frozen routing.
///
/// # Errors
/// In order: an invalid `mode`; `coords.len()` different from the number of logit columns;
/// a coordinate block whose `n_obs()` differs from the number of logit rows; a non-finite
/// logit.
#[must_use = "build error must be handled"]
pub fn with_mode(
    mut logits: Array2<f64>,
    coords: Vec<LatentCoordValues>,
    mode: AssignmentMode,
) -> Result<Self, String> {
    mode.validate()?;
    let (n, k) = logits.dim();
    if coords.len() != k {
        return Err(format!(
            "SaeAssignment::new: coords length {} must equal K={k}",
            coords.len()
        ));
    }
    let mismatched = coords
        .iter()
        .enumerate()
        .find(|(_, coord)| coord.n_obs() != n);
    if let Some((atom, coord)) = mismatched {
        return Err(format!(
            "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
            coord.n_obs()
        ));
    }
    (0..n).try_for_each(|row| validate_finite_logits(logits.row(row), row))?;
    if let AssignmentMode::Softmax { .. } = mode {
        canonicalize_softmax_logits(&mut logits);
    }
    let ungated = vec![false; k];
    Ok(Self {
        logits,
        coords,
        mode,
        ungated,
        frozen_logits: None,
    })
}
/// Installs (`Some`) or clears (`None`) a frozen routing and returns the assignment.
///
/// # Errors
/// For a `Some` routing, in order: shape different from `(n_obs, k_atoms)`; Softmax mode;
/// a non-finite entry. `None` is always accepted.
#[must_use = "build error must be handled"]
pub fn with_frozen_routing(
    mut self,
    predicted: Option<Array2<f64>>,
) -> Result<Self, String> {
    if let Some(p) = predicted.as_ref() {
        let expected_shape = (self.n_obs(), self.k_atoms());
        if p.dim() != expected_shape {
            return Err(format!(
                "SaeAssignment::with_frozen_routing: predicted shape {:?} must be ({}, {})",
                p.dim(),
                self.n_obs(),
                self.k_atoms()
            ));
        }
        if let AssignmentMode::Softmax { .. } = self.mode {
            return Err(String::from(
                "SaeAssignment::with_frozen_routing: frozen routing under Softmax is rejected \
                 — the coupled simplex's entropy majorizer is assembled over the logits, which \
                 a frozen (non-optimized) routing would leave inconsistent; this separable-mode \
                 contract supports IBP-MAP and JumpReLU, whose per-atom gates have no \
                 simplex-coupled curvature to skip",
            ));
        }
        (0..p.nrows()).try_for_each(|row| validate_finite_logits(p.row(row), row))?;
    }
    self.frozen_logits = predicted;
    Ok(self)
}
/// True when a frozen routing is installed.
pub fn routing_is_frozen(&self) -> bool {
    let frozen = self.frozen_logits.as_ref();
    frozen.is_some()
}
pub(crate) fn routing_logits_row(&self, row: usize) -> ArrayView1<'_, f64> {
match self.frozen_logits {
Some(ref f) => f.row(row),
None => self.logits.row(row),
}
}
/// True when atom `k`'s logit is fixed: routing is frozen, or the atom is flagged ungated.
/// An index past the end of `ungated` counts as not ungated.
pub(crate) fn logit_is_fixed(&self, k: usize) -> bool {
    if self.routing_is_frozen() {
        return true;
    }
    matches!(self.ungated.get(k), Some(&true))
}
/// Per-atom mask of fixed logits: all `true` (length `k_atoms`) under frozen routing,
/// otherwise a copy of the `ungated` flags.
pub(crate) fn fixed_logit_mask(&self) -> Vec<bool> {
    match self.frozen_logits {
        Some(_) => vec![true; self.k_atoms()],
        None => self.ungated.clone(),
    }
}
/// Freezes routing at a copy of the current logits.
///
/// # Errors
/// Propagates every error of `with_frozen_routing` (for example Softmax mode).
#[must_use = "build error must be handled"]
pub fn freeze_routing_from_current_logits(self) -> Result<Self, String> {
    let frozen = Some(self.logits.clone());
    self.with_frozen_routing(frozen)
}
/// In-place variant of freezing routing at a copy of the current logits.
///
/// # Errors
/// Rejects Softmax mode, then any non-finite logit; `self` is unchanged on error.
pub fn freeze_routing_in_place(&mut self) -> Result<(), String> {
    if let AssignmentMode::Softmax { .. } = self.mode {
        return Err(String::from(
            "SaeAssignment::freeze_routing_in_place: frozen routing under Softmax is rejected \
             (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU",
        ));
    }
    let snapshot = self.logits.clone();
    (0..snapshot.nrows()).try_for_each(|row| validate_finite_logits(snapshot.row(row), row))?;
    self.frozen_logits = Some(snapshot);
    Ok(())
}
/// Installs `predicted` as the frozen routing, in place.
///
/// # Errors
/// In order: shape different from `(n_obs, k_atoms)`; Softmax mode; a non-finite entry.
/// `self` is unchanged on error.
pub fn set_frozen_routing_in_place(&mut self, predicted: Array2<f64>) -> Result<(), String> {
    let (n, k) = (self.n_obs(), self.k_atoms());
    if predicted.dim() != (n, k) {
        return Err(format!(
            "SaeAssignment::set_frozen_routing_in_place: predicted shape {:?} must be ({}, {})",
            predicted.dim(),
            n,
            k
        ));
    }
    if let AssignmentMode::Softmax { .. } = self.mode {
        return Err(String::from(
            "SaeAssignment::set_frozen_routing_in_place: frozen routing under Softmax is \
             rejected (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU",
        ));
    }
    (0..predicted.nrows()).try_for_each(|row| validate_finite_logits(predicted.row(row), row))?;
    self.frozen_logits = Some(predicted);
    Ok(())
}
/// Removes any frozen routing so gates follow the free logits again.
pub fn thaw_routing(&mut self) {
    drop(self.frozen_logits.take());
}
/// Replaces the per-atom ungated flags and returns the assignment.
///
/// # Errors
/// `flags.len()` different from `k_atoms`, or any `true` flag under Softmax mode.
#[must_use = "build error must be handled"]
pub fn with_ungated(mut self, flags: Vec<bool>) -> Result<Self, String> {
    let k = self.k_atoms();
    if flags.len() != k {
        return Err(format!(
            "SaeAssignment::with_ungated: flags length {} must equal K={}",
            flags.len(),
            k
        ));
    }
    let requests_ungated = flags.contains(&true);
    if requests_ungated && matches!(self.mode, AssignmentMode::Softmax { .. }) {
        return Err(String::from(
            "SaeAssignment::with_ungated: an ungated atom under Softmax routing is \
             rejected — the coupled simplex requires a gated-subset renormalization \
             reflected in the logit-JVP and entropy majorizer, which this separable-mode \
             contract does not perform; route a dense background tier as IBP-MAP or JumpReLU",
        ));
    }
    self.ungated = flags;
    Ok(self)
}
/// True when at least one atom is flagged ungated.
pub fn has_ungated(&self) -> bool {
    self.ungated.contains(&true)
}
/// Number of observations, i.e. rows of `logits`.
pub fn n_obs(&self) -> usize {
    let (n_rows, _) = self.logits.dim();
    n_rows
}
/// Number of atoms, i.e. columns of `logits`.
pub fn k_atoms(&self) -> usize {
    let (_, n_cols) = self.logits.dim();
    n_cols
}
/// Sum of `latent_dim()` over all coordinate blocks.
pub fn total_coord_dim(&self) -> usize {
    self.coords
        .iter()
        .fold(0, |total, coord| total + coord.latent_dim())
}
/// Number of logit coordinates per row: `K - 1` (saturating at zero) under Softmax, and
/// `K` under IBP-MAP and JumpReLU.
pub fn assignment_coord_dim(&self) -> usize {
    let k = self.k_atoms();
    match self.mode {
        AssignmentMode::IBPMap { .. } | AssignmentMode::JumpReLU { .. } => k,
        AssignmentMode::Softmax { .. } => k.saturating_sub(1),
    }
}
/// Width of one row block: logit coordinates followed by all latent coordinates.
pub fn row_block_dim(&self) -> usize {
    let logit_dim = self.assignment_coord_dim();
    let latent_dim = self.total_coord_dim();
    logit_dim + latent_dim
}
/// Start offset of each atom's latent coordinates inside a row block.
///
/// The first atom starts right after the logit coordinates; each later atom starts where
/// the previous atom's `latent_dim()` entries end.
pub fn coord_offsets(&self) -> Vec<usize> {
    let mut cursor = self.assignment_coord_dim();
    self.coords
        .iter()
        .map(|coord| {
            let start = cursor;
            cursor += coord.latent_dim();
            start
        })
        .collect()
}
/// Gate weights for every observation as an `n_obs × k_atoms` matrix.
///
/// # Panics
/// Panics through `assignments_row` when a routing logit is not finite.
pub fn assignments(&self) -> Array2<f64> {
    let (n, k) = (self.n_obs(), self.k_atoms());
    let mut gates = Array2::<f64>::zeros((n, k));
    for row in 0..n {
        let row_gates = self.assignments_row(row);
        (0..k).for_each(|atom| gates[[row, atom]] = row_gates[atom]);
    }
    gates
}
/// Gate weights for one observation.
///
/// # Panics
/// Panics when `try_assignments_row` reports an error (a non-finite routing logit).
pub fn assignments_row(&self, row: usize) -> Array1<f64> {
    let outcome = self.try_assignments_row(row);
    outcome.expect("assignment logits must be finite")
}
/// Gate weights for one observation, using the stored IBP alpha as-is (no rho resolution).
///
/// # Errors
/// Returns an error when a routing logit in the row is not finite.
pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
    let no_alpha_override = None;
    self.try_assignments_row_with_alpha(row, no_alpha_override)
}
/// Gate weights for one observation, with the IBP alpha resolved against `rho` when the
/// mode is IBP-MAP.
///
/// # Errors
/// Returns an error when a routing logit in the row is not finite.
pub(crate) fn try_assignments_row_for_rho(
    &self,
    row: usize,
    rho: &SaeManifoldRho,
) -> Result<Array1<f64>, String> {
    let resolved_alpha = self.mode.resolved_ibp_alpha(rho);
    self.try_assignments_row_with_alpha(row, resolved_alpha)
}
/// Shared gate computation for one observation.
///
/// Reads the routing logits (frozen when installed), validates them, applies the mode's
/// gate function, then forces the gate of every ungated atom to 1.0. Single-atom Softmax
/// short-circuits to `[1.0]`. For IBP-MAP, `resolved_ibp_alpha` overrides the stored alpha
/// when it is `Some`.
///
/// # Errors
/// Returns an error when a routing logit in the row is not finite.
fn try_assignments_row_with_alpha(
    &self,
    row: usize,
    resolved_ibp_alpha: Option<f64>,
) -> Result<Array1<f64>, String> {
    let routing = self.routing_logits_row(row);
    validate_finite_logits(routing, row)?;
    let mut gates = match self.mode {
        AssignmentMode::Softmax { .. } if self.k_atoms() == 1 => {
            return Ok(Array1::from_vec(vec![1.0]));
        }
        AssignmentMode::Softmax { temperature, .. } => softmax_row(routing, temperature),
        AssignmentMode::IBPMap {
            temperature, alpha, ..
        } => {
            let effective_alpha = resolved_ibp_alpha.unwrap_or(alpha);
            ibp_map_row(routing, temperature, effective_alpha)
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => jumprelu_row(routing, temperature, threshold),
    };
    if self.has_ungated() {
        gates
            .iter_mut()
            .enumerate()
            .filter(|(k, _)| self.ungated[*k])
            .for_each(|(_, gate)| *gate = 1.0);
    }
    Ok(gates)
}
/// Bakes the rho-resolved alpha into an IBP-MAP mode with a learnable alpha.
///
/// The mode is rewritten with the resolved alpha and `learnable_alpha: false`. Returns
/// `true` when that rewrite happened and `false` for every other mode, which is left
/// unchanged.
pub(crate) fn persist_resolved_ibp_alpha(&mut self, rho: &SaeManifoldRho) -> bool {
    match self.mode {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha: true,
        } => {
            let resolved_alpha = resolve_learnable_weight(alpha, rho.log_lambda_sparse);
            self.mode = AssignmentMode::IBPMap {
                temperature,
                alpha: resolved_alpha,
                learnable_alpha: false,
            };
            true
        }
        _ => false,
    }
}
/// Gate weights for every observation, with the IBP alpha resolved against `rho`.
///
/// # Errors
/// Returns the first row error from `try_assignments_row_for_rho`.
pub(crate) fn assignments_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
    let (n, k) = (self.n_obs(), self.k_atoms());
    let mut gates = Array2::<f64>::zeros((n, k));
    for row in 0..n {
        let row_gates = self.try_assignments_row_for_rho(row, rho)?;
        (0..k).for_each(|atom| gates[[row, atom]] = row_gates[atom]);
    }
    Ok(gates)
}
/// Flattens logits and latent coordinates into one vector of `n_obs * row_block_dim()`.
///
/// Each row block holds the first `assignment_coord_dim()` free logits of that row, then
/// every atom's latent coordinates at the offsets given by `coord_offsets()`.
pub fn flatten_ext_coords(&self) -> Array1<f64> {
    let n = self.n_obs();
    let k = self.k_atoms();
    let block_width = self.row_block_dim();
    let logit_width = self.assignment_coord_dim();
    let offsets = self.coord_offsets();
    let mut flat = Array1::<f64>::zeros(n * block_width);
    for row in 0..n {
        let base = row * block_width;
        for col in 0..logit_width {
            flat[base + col] = self.logits[[row, col]];
        }
        for atom in 0..k {
            let coord = &self.coords[atom];
            let d = coord.latent_dim();
            let values = coord.row(row);
            let start = base + offsets[atom];
            for axis in 0..d {
                flat[start + axis] = values[axis];
            }
        }
    }
    flat
}
/// Builds an assignment from raw coordinate matrices, one per atom.
///
/// Each matrix is wrapped with `LatentCoordValues::from_matrix` using `LatentIdMode::None`.
///
/// # Errors
/// Propagates every error of `SaeAssignment::with_mode`.
#[must_use = "build error must be handled"]
pub fn from_blocks_with_mode(
    logits: Array2<f64>,
    coord_blocks: Vec<Array2<f64>>,
    mode: AssignmentMode,
) -> Result<Self, String> {
    let mut coords = Vec::with_capacity(coord_blocks.len());
    for block in &coord_blocks {
        coords.push(LatentCoordValues::from_matrix(
            block.view(),
            LatentIdMode::None,
        ));
    }
    Self::with_mode(logits, coords, mode)
}
/// Builds an assignment from raw coordinate matrices paired with one manifold per atom.
///
/// # Errors
/// Fails when `coord_blocks` and `manifolds` differ in length; otherwise propagates every
/// error of `SaeAssignment::with_mode`.
#[must_use = "build error must be handled"]
pub fn from_blocks_with_mode_and_manifolds(
    logits: Array2<f64>,
    coord_blocks: Vec<Array2<f64>>,
    manifolds: Vec<LatentManifold>,
    mode: AssignmentMode,
) -> Result<Self, String> {
    let (n_blocks, n_manifolds) = (coord_blocks.len(), manifolds.len());
    if n_blocks != n_manifolds {
        return Err(format!(
            "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
            n_blocks, n_manifolds
        ));
    }
    let mut coords = Vec::with_capacity(n_blocks);
    for (block, manifold) in coord_blocks.iter().zip(manifolds) {
        coords.push(LatentCoordValues::from_matrix_with_manifold(
            block.view(),
            LatentIdMode::None,
            manifold,
        ));
    }
    Self::with_mode(logits, coords, mode)
}
}
/// Gate weights of an uninformative row for `k_atoms` atoms.
///
/// Softmax gives the uniform weight `1 / max(k_atoms, 1)`, IBP-MAP gives the gates of an
/// all-zero logit row under the stored `alpha`, and JumpReLU gives 0.5 everywhere.
pub(crate) fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
    match mode {
        AssignmentMode::JumpReLU { .. } => Array1::from_elem(k_atoms, 0.5),
        AssignmentMode::Softmax { .. } => {
            let denom = k_atoms.max(1) as f64;
            Array1::from_elem(k_atoms, 1.0 / denom)
        }
        AssignmentMode::IBPMap {
            temperature, alpha, ..
        } => {
            let zero_logits = Array1::<f64>::zeros(k_atoms);
            ibp_map_row(zero_logits.view(), temperature, alpha)
        }
    }
}
/// Temperature-scaled softmax of one logit row.
///
/// The row maximum is subtracted before exponentiation so the largest term is `exp(0)`.
/// An empty row yields an empty result rather than tripping the normalizer assertion,
/// in line with the other K = 0 handling in this file.
///
/// # Panics
/// Panics when the normalizer of a non-empty row is not finite and strictly positive
/// (for example with non-finite logits or a non-positive/non-finite temperature).
pub(crate) fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
    let k = logits.len();
    if k == 0 {
        // No atoms: the distribution over an empty set is the empty vector.
        return Array1::<f64>::zeros(0);
    }
    let inv_tau = 1.0 / temperature;
    let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out = Array1::<f64>::zeros(k);
    let mut sum = 0.0;
    for (slot, &logit) in out.iter_mut().zip(logits.iter()) {
        let v = ((logit - max_logit) * inv_tau).exp();
        *slot = v;
        sum += v;
    }
    assert!(
        sum.is_finite() && sum > 0.0,
        "softmax_row: normalizer must be finite and positive; got {sum} (temperature={temperature})"
    );
    for v in out.iter_mut() {
        *v /= sum;
    }
    out
}
/// Checks that every logit in a row is finite.
///
/// # Errors
/// Reports the first non-finite entry together with `row` and its column index.
pub(crate) fn validate_finite_logits(
    logits: ArrayView1<'_, f64>,
    row: usize,
) -> Result<(), String> {
    let offender = logits
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, v)| !v.is_finite());
    match offender {
        Some((col, v)) => Err(format!(
            "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
        )),
        None => Ok(()),
    }
}
/// Shifts each row so the last column becomes the zero reference logit.
///
/// With no columns nothing happens; with one column the matrix is zero-filled; otherwise
/// the last column's value is subtracted from the other columns of the row and the last
/// column is set to exactly 0.0.
pub(crate) fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
    let (n_rows, k) = logits.dim();
    match k {
        0 => {}
        1 => logits.fill(0.0),
        _ => {
            let last = k - 1;
            for row in 0..n_rows {
                let reference = logits[[row, last]];
                (0..last).for_each(|col| logits[[row, col]] -= reference);
                logits[[row, last]] = 0.0;
            }
        }
    }
}
/// Ordered geometric prior `π_k = (alpha / (alpha + 1))^(k + 1)` for `k = 0..k_atoms`.
///
/// Each entry is evaluated in log space and clamped from below at `f64::MIN_POSITIVE`,
/// so it never underflows to zero.
pub(crate) fn ordered_geometric_shrinkage_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
    let log_ratio = (alpha / (alpha + 1.0)).ln();
    let mut prior = Array1::<f64>::zeros(k_atoms);
    for (k, slot) in prior.iter_mut().enumerate() {
        let exponent = (k + 1) as f64;
        *slot = (exponent * log_ratio).exp().max(f64::MIN_POSITIVE);
    }
    prior
}
/// IBP-MAP gates for one logit row: `logistic(logit / temperature) * π_k`, with `π` the
/// ordered geometric prior for `alpha`.
pub fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
    let n_atoms = logits.len();
    let prior = ordered_geometric_shrinkage_prior(n_atoms, alpha);
    let mut gates = Array1::<f64>::zeros(n_atoms);
    for ((gate, &logit), &pi) in gates.iter_mut().zip(logits.iter()).zip(prior.iter()) {
        *gate = crate::linalg::utils::stable_logistic(logit / temperature) * pi;
    }
    gates
}
/// IBP-MAP gates and their derivatives with respect to each logit.
///
/// Returns `(value, grad)` where `value[k] = σ_k π_k` and
/// `grad[k] = σ_k (1 - σ_k) π_k / temperature`, with `σ_k = logistic(logit_k / temperature)`.
#[must_use]
pub fn ibp_map_row_value_grad(
    logits: ArrayView1<'_, f64>,
    temperature: f64,
    alpha: f64,
) -> (Array1<f64>, Array1<f64>) {
    let n_atoms = logits.len();
    let prior = ordered_geometric_shrinkage_prior(n_atoms, alpha);
    let inv_tau = 1.0 / temperature;
    let mut value = Array1::<f64>::zeros(n_atoms);
    let mut grad = Array1::<f64>::zeros(n_atoms);
    let outputs = value.iter_mut().zip(grad.iter_mut());
    let inputs = logits.iter().zip(prior.iter());
    for ((v, g), (&logit, &pi)) in outputs.zip(inputs) {
        let sig = crate::linalg::utils::stable_logistic(logit * inv_tau);
        *v = sig * pi;
        *g = sig * (1.0 - sig) * inv_tau * pi;
    }
    (value, grad)
}
/// JumpReLU gates for one logit row.
///
/// A logit at or below `threshold` gates to zero; a logit above it gates to
/// `logistic((logit - threshold) / temperature)`.
pub fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
    let mut gates = Array1::<f64>::zeros(logits.len());
    for (gate, &logit) in gates.iter_mut().zip(logits.iter()) {
        if logit > threshold {
            let margin = (logit - threshold) / temperature;
            *gate = crate::linalg::utils::stable_logistic(margin);
        }
    }
    gates
}
/// Inputs for `fill_active_atom_logit_jvp`: everything needed to write one atom's
/// logit-derivative row into a compact Jacobian.
pub(crate) struct ActiveAtomLogitJvp<'a> {
/// Routing mode that selects the derivative formula.
pub(crate) mode: AssignmentMode,
/// Atom index; used to look up the atom's entry in `ibp_prior`.
pub(crate) k: usize,
/// The atom's logit; compared against the threshold in JumpReLU mode.
pub(crate) logit_k: f64,
/// The atom's gate value.
pub(crate) a_k: f64,
/// The atom's decoded output, indexed by output column.
pub(crate) decoded_k: ArrayView1<'a, f64>,
/// Fitted output; its length sets how many output columns are written.
pub(crate) fitted: ArrayView1<'a, f64>,
/// Precomputed IBP prior; required (the filler panics without it) in IBP-MAP mode.
pub(crate) ibp_prior: Option<&'a [f64]>,
/// Row of the compact Jacobian that receives the result.
pub(crate) compact_index: usize,
/// When true the filler writes nothing for this atom.
pub(crate) ungated: bool,
}
/// Writes one atom's logit derivative of the fitted output into row
/// `input.compact_index` of `jac_compact`.
///
/// Nothing is written for an ungated atom, nor for a JumpReLU atom whose logit is at or
/// below the threshold. Softmax writes `a_k (decoded_k - fitted) / τ`; IBP-MAP and JumpReLU
/// write a scalar gate slope times `decoded_k`.
///
/// # Panics
/// Panics in IBP-MAP mode when `input.ibp_prior` is `None`.
pub(crate) fn fill_active_atom_logit_jvp(
    input: ActiveAtomLogitJvp<'_>,
    jac_compact: &mut Array2<f64>,
) {
    if input.ungated {
        return;
    }
    let n_out = input.fitted.len();
    let target_row = input.compact_index;
    // Separable modes reduce to a scalar slope; Softmax writes its row directly and returns.
    let slope = match input.mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            for out_col in 0..n_out {
                jac_compact[[target_row, out_col]] =
                    input.a_k * (input.decoded_k[out_col] - input.fitted[out_col]) * inv_tau;
            }
            return;
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior = input
                .ibp_prior
                .expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
            let pi_k = prior[input.k];
            // Recover the logistic value from the gate a_k = σ π_k.
            let sig = if pi_k > 0.0 { input.a_k / pi_k } else { 0.0 };
            sig * (1.0 - sig) * inv_tau * pi_k
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            if input.logit_k <= threshold {
                return;
            }
            let inv_tau = 1.0 / temperature;
            let activation =
                crate::linalg::utils::stable_logistic((input.logit_k - threshold) * inv_tau);
            activation * (1.0 - activation) * inv_tau
        }
    };
    for out_col in 0..n_out {
        jac_compact[[target_row, out_col]] = slope * input.decoded_k[out_col];
    }
}
/// Writes the logit derivative of the fitted output for every free logit of one row into
/// `local_jac` (one Jacobian row per logit column, one column per output).
///
/// * Softmax: only the first `K - 1` logits are free (the last is the reference), and each
///   row is `a_k (decoded_k - fitted) / τ`. With zero or one atom nothing is written.
/// * IBP-MAP: each row is `σ (1 - σ) π_k / τ · decoded_k`, with `σ = a_k / π_k`.
/// * JumpReLU: rows at or below the threshold are skipped; the others are
///   `σ (1 - σ) / τ · decoded_k` with `σ = logistic((logit - threshold) / τ)`.
///
/// Rows of atoms flagged in `ungated` are skipped; an index past the end of `ungated`
/// counts as gated. Skipped rows are left as the caller provided them.
///
/// # Panics
/// Panics in IBP-MAP mode when `ibp_prior` is `None`.
pub(crate) fn fill_assignment_logit_jvp_rows(
    mode: AssignmentMode,
    logits: ArrayView1<'_, f64>,
    assignments: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    fitted: ArrayView1<'_, f64>,
    ibp_prior: Option<&[f64]>,
    ungated: &[bool],
    local_jac: &mut Array2<f64>,
) {
    let is_ungated = |k: usize| ungated.get(k).copied().unwrap_or(false);
    let n_atoms = assignments.len();
    let n_out = fitted.len();
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            // `saturating_sub` makes both K = 0 and K = 1 an empty range; a plain `- 1`
            // would underflow for an empty row.
            for logit_col in 0..n_atoms.saturating_sub(1) {
                if is_ungated(logit_col) {
                    continue;
                }
                for out_col in 0..n_out {
                    local_jac[[logit_col, out_col]] = assignments[logit_col]
                        * (decoded[[logit_col, out_col]] - fitted[out_col])
                        * inv_tau;
                }
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior = ibp_prior
                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
            for logit_col in 0..n_atoms {
                if is_ungated(logit_col) {
                    continue;
                }
                let pi_k = prior[logit_col];
                let a_k = assignments[logit_col];
                // Recover the logistic value from the gate a_k = σ π_k.
                let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
                let dz = sig * (1.0 - sig) * inv_tau * pi_k;
                for out_col in 0..n_out {
                    local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
                }
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..n_atoms {
                if is_ungated(logit_col) || logits[logit_col] <= threshold {
                    continue;
                }
                let activation = crate::linalg::utils::stable_logistic(
                    (logits[logit_col] - threshold) * inv_tau,
                );
                let da = activation * (1.0 - activation) * inv_tau;
                for out_col in 0..n_out {
                    local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
                }
            }
        }
    }
}
/// Copies a logit matrix into a flat vector in row-major order (entry `(row, col)` lands at
/// index `row * ncols + col`).
pub(crate) fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
    // `iter()` walks the elements in logical (row-major) order regardless of memory layout.
    let values: Vec<f64> = logits.iter().copied().collect();
    Array1::from_vec(values)
}
/// Value of the sparsity prior on the free assignment logits.
///
/// * Softmax: `SoftmaxAssignmentSparsityPenalty` at log strength
///   `rho.log_lambda_sparse + ln(sparsity)`; single-atom Softmax is 0.
/// * IBP-MAP: `IBPAssignmentPenalty`, with the log strength passed as rho when alpha is
///   learnable and with `weight = rho.lambda_sparse()` otherwise.
/// * JumpReLU: `rho.lambda_sparse()` times the sum of `logistic((logit - threshold) / τ)`
///   over logits inside the optimization band.
///
/// # Panics
/// Panics when any logit is not finite.
pub(crate) fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
    (0..assignment.n_obs()).for_each(|row| {
        validate_finite_logits(assignment.logits.row(row), row)
            .expect("assignment logits must be finite");
    });
    let target = flat_logits(assignment.logits.view());
    let k_atoms = assignment.k_atoms();
    match assignment.mode {
        AssignmentMode::Softmax { .. } if k_atoms == 1 => 0.0,
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(k_atoms, temperature);
            let log_strength = rho.log_lambda_sparse + sparsity.ln();
            let rho_view = Array1::from_vec(vec![log_strength]);
            penalty.value(target.view(), rho_view.view())
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty =
                IBPAssignmentPenalty::new(k_atoms, alpha, temperature, learnable_alpha);
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                penalty.weight = rho.lambda_sparse();
                Array1::zeros(0)
            };
            penalty.value(target.view(), rho_view.view())
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            // Explicit fold from +0.0 keeps the accumulation order and the zero seed.
            let acc = target
                .iter()
                .copied()
                .filter(|&logit| jumprelu_in_optimization_band(logit, threshold, temperature))
                .fold(0.0_f64, |acc, logit| {
                    acc + crate::linalg::utils::stable_logistic((logit - threshold) / temperature)
                });
            sparsity_strength * acc
        }
    }
}
/// Derivative of the sparsity prior value with respect to the log strength.
///
/// * Softmax and JumpReLU: returns `assignment_prior_value` itself. That call performs the
///   finiteness check and the single-atom Softmax shortcut, so neither is repeated here and
///   no throw-away flattened copy of the logits is built on this path.
/// * IBP-MAP with learnable alpha: first component of `IBPAssignmentPenalty::grad_rho`.
/// * IBP-MAP with fixed alpha: the penalty value at `weight = rho.lambda_sparse()`.
///
/// # Panics
/// Panics when any logit is not finite.
pub(crate) fn assignment_prior_log_strength_derivative(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> f64 {
    match assignment.mode {
        AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => {
            assignment_prior_value(assignment, rho)
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            for row in 0..assignment.n_obs() {
                validate_finite_logits(assignment.logits.row(row), row)
                    .expect("assignment logits must be finite");
            }
            let target = flat_logits(assignment.logits.view());
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            if learnable_alpha {
                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
                penalty.grad_rho(target.view(), rho_view.view())[0]
            } else {
                penalty.weight = rho.lambda_sparse();
                penalty.value(target.view(), Array1::<f64>::zeros(0).view())
            }
        }
    }
}
/// Log-strength derivative of the prior's diagonal Hessian over the flattened logits.
///
/// * Softmax: the penalty's `hessian_diag` at log strength
///   `rho.log_lambda_sparse + ln(sparsity)`; single-atom Softmax is all zeros.
/// * JumpReLU: `λ σ (1 - σ) (1 - 2σ) / τ²` for logits inside the optimization band, zero
///   elsewhere.
/// * IBP-MAP: `hessian_diag_log_alpha_derivative` when alpha is learnable, otherwise
///   `hessian_diag` at `weight = rho.lambda_sparse()`.
///
/// NOTE(review): no masking of ungated/frozen logits is applied here — confirm callers
/// expect the unmasked diagonal.
///
/// # Errors
/// Fails on a non-finite logit or when a penalty reports no diagonal Hessian.
pub(crate) fn assignment_prior_log_strength_hdiag(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Array1<f64>, String> {
    (0..assignment.n_obs())
        .try_for_each(|row| validate_finite_logits(assignment.logits.row(row), row))?;
    let target = flat_logits(assignment.logits.view());
    let k_atoms = assignment.k_atoms();
    match assignment.mode {
        AssignmentMode::Softmax { .. } if k_atoms == 1 => {
            Ok(Array1::<f64>::zeros(target.len()))
        }
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(k_atoms, temperature);
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            match penalty.hessian_diag(target.view(), rho_view.view()) {
                Some(diag) => Ok(diag),
                None => Err("softmax assignment log-strength hessian diag unavailable".to_string()),
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            let inv_tau = 1.0 / temperature;
            let inv_tau2 = inv_tau * inv_tau;
            let mut d = Array1::<f64>::zeros(target.len());
            for (slot, &logit) in d.iter_mut().zip(target.iter()) {
                if jumprelu_in_optimization_band(logit, threshold, temperature) {
                    let activation =
                        crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
                    let slope = activation * (1.0 - activation);
                    *slot = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
                }
            }
            Ok(d)
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty =
                IBPAssignmentPenalty::new(k_atoms, alpha, temperature, learnable_alpha);
            if learnable_alpha {
                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
                return Ok(
                    penalty.hessian_diag_log_alpha_derivative(target.view(), rho_view.view())
                );
            }
            penalty.weight = rho.lambda_sparse();
            let empty_rho = Array1::<f64>::zeros(0);
            match penalty.hessian_diag(target.view(), empty_rho.view()) {
                Some(diag) => Ok(diag),
                None => Err("IBP assignment log-strength hessian diag unavailable".to_string()),
            }
        }
    }
}
/// Mixed derivative of the prior with respect to the log strength and each flattened logit.
///
/// Single-atom Softmax yields zeros. IBP-MAP with a learnable alpha uses the penalty's
/// `log_alpha_target_mixed_derivative`. Every other case returns the gradient component of
/// `assignment_prior_grad_hdiag`.
///
/// # Errors
/// Fails on a non-finite logit, or with whatever `assignment_prior_grad_hdiag` reports.
pub(crate) fn assignment_prior_log_strength_target_mixed(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Array1<f64>, String> {
    (0..assignment.n_obs())
        .try_for_each(|row| validate_finite_logits(assignment.logits.row(row), row))?;
    let target = flat_logits(assignment.logits.view());
    let k_atoms = assignment.k_atoms();
    if k_atoms == 1 && matches!(assignment.mode, AssignmentMode::Softmax { .. }) {
        return Ok(Array1::<f64>::zeros(target.len()));
    }
    if let AssignmentMode::IBPMap {
        temperature,
        alpha,
        learnable_alpha: true,
    } = assignment.mode
    {
        let penalty = IBPAssignmentPenalty::new(k_atoms, alpha, temperature, true);
        let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
        return Ok(penalty.log_alpha_target_mixed_derivative(target.view(), rho_view.view()));
    }
    let (grad, _hdiag) = assignment_prior_grad_hdiag(assignment, rho)?;
    Ok(grad)
}
/// Gradient and diagonal Hessian of the sparsity prior over the flattened logits.
///
/// Returns `(grad, diag)`, both of length `n_obs * k_atoms` in the row-major layout of
/// `flat_logits`. Entries of fixed logits (ungated atoms, or all atoms under frozen
/// routing) are zeroed. Single-atom Softmax yields all zeros.
///
/// # Errors
/// Fails on a non-finite logit or when a penalty reports no diagonal Hessian.
pub(crate) fn assignment_prior_grad_hdiag(
assignment: &SaeAssignment,
rho: &SaeManifoldRho,
) -> Result<(Array1<f64>, Array1<f64>), String> {
for row in 0..assignment.n_obs() {
validate_finite_logits(assignment.logits.row(row), row)?;
}
let target = flat_logits(assignment.logits.view());
// Zero-initialized accumulators; the mode's contribution is added below.
let mut grad = Array1::<f64>::zeros(target.len());
let mut diag = Array1::<f64>::zeros(target.len());
// A single Softmax atom has no free logit, so the prior contributes nothing.
if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
return Ok((grad, diag));
}
let (sparsity_grad, sparsity_diag) = match assignment.mode {
AssignmentMode::Softmax {
temperature,
sparsity,
} => {
let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
// The mode's sparsity multiplier enters as an additive offset on the log strength.
let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
let g = penalty.grad_target(target.view(), rho_view.view());
let d = penalty
.hessian_diag(target.view(), rho_view.view())
.ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
(g, d)
}
AssignmentMode::IBPMap {
temperature,
alpha,
learnable_alpha,
} => {
let mut penalty = IBPAssignmentPenalty::new(
assignment.k_atoms(),
alpha,
temperature,
learnable_alpha,
);
// Learnable alpha: pass the log strength as rho. Fixed alpha: set the weight directly
// and pass an empty rho.
let rho_view = if learnable_alpha {
Array1::from_vec(vec![rho.log_lambda_sparse])
} else {
penalty.weight = rho.lambda_sparse();
Array1::zeros(0)
};
let g = penalty.grad_target(target.view(), rho_view.view());
let d = penalty
.hessian_diag(target.view(), rho_view.view())
.ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
(g, d)
}
AssignmentMode::JumpReLU {
temperature,
threshold,
} => {
let sparsity_strength = rho.lambda_sparse();
let inv_tau = 1.0 / temperature;
let inv_tau2 = inv_tau * inv_tau;
let mut g = Array1::<f64>::zeros(target.len());
let mut d = Array1::<f64>::zeros(target.len());
for idx in 0..target.len() {
let logit = target[idx];
// Logits outside the optimization band keep zero gradient and curvature.
if !jumprelu_in_optimization_band(logit, threshold, temperature) {
continue;
}
let activation =
crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
let slope = activation * (1.0 - activation);
// First and second derivative of λ·logistic((z - θ)/τ) with respect to z.
g[idx] = sparsity_strength * slope * inv_tau;
d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
}
(g, d)
}
};
grad += &sparsity_grad;
diag += &sparsity_diag;
// Fixed logits are not optimized: zero their entries. `idx % k` recovers the atom index
// from the row-major flat index.
if assignment.has_ungated() || assignment.routing_is_frozen() {
let k = assignment.k_atoms();
for idx in 0..grad.len() {
if assignment.logit_is_fixed(idx % k) {
grad[idx] = 0.0;
diag[idx] = 0.0;
}
}
}
Ok((grad, diag))
}
/// Third-order Hessian-diagonal channels of the IBP assignment penalty.
///
/// Returns `Ok(None)` outside IBP-MAP mode. Otherwise the channels come from
/// `IBPAssignmentPenalty::hessian_diag_logit_third_channels`, and every entry belonging to
/// a fixed logit (ungated atom, or any atom under frozen routing) is zeroed: the
/// per-element channels by `idx % k_max`, the cross-row channels by atom.
///
/// # Errors
/// Fails on a non-finite logit.
pub(crate) fn ibp_assignment_third_channels(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
    let (temperature, alpha, learnable_alpha) = match assignment.mode {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => (temperature, alpha, learnable_alpha),
        AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => return Ok(None),
    };
    (0..assignment.n_obs())
        .try_for_each(|row| validate_finite_logits(assignment.logits.row(row), row))?;
    let target = flat_logits(assignment.logits.view());
    let mut penalty =
        IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
    let rho_view = if learnable_alpha {
        Array1::from_vec(vec![rho.log_lambda_sparse])
    } else {
        penalty.weight = rho.lambda_sparse();
        Array1::zeros(0)
    };
    let mut channels =
        penalty.hessian_diag_logit_third_channels(target.view(), rho_view.view());
    let any_fixed = assignment.has_ungated() || assignment.routing_is_frozen();
    if !any_fixed {
        return Ok(Some(channels));
    }
    let k = channels.k_max;
    let fixed: Vec<bool> = (0..k).map(|atom| assignment.logit_is_fixed(atom)).collect();
    for idx in 0..channels.z_jac.len() {
        if fixed[idx % k] {
            channels.z_jac[idx] = 0.0;
            channels.local_logit_third[idx] = 0.0;
            channels.m_channel[idx] = 0.0;
            channels.logit_curvature[idx] = 0.0;
        }
    }
    for atom in 0..k {
        if fixed[atom] {
            channels.cross_row_d[atom] = 0.0;
            channels.cross_row_dd[atom] = 0.0;
        }
    }
    Ok(Some(channels))
}
/// Chooses between the linear and the curved parameterization of one atom.
///
/// The linear candidate always competes. The curved candidate competes only when one is
/// supplied and `manifold` is not Euclidean. The winner is decided by `select_hybrid_atom`.
///
/// # Panics
/// Panics if `select_hybrid_atom` returns `None`, which the always-present linear candidate
/// is expected to rule out.
pub fn select_hybrid_atom_parameterization(
    manifold: &LatentManifold,
    curved: Option<HybridAtomCandidate>,
    linear: HybridAtomCandidate,
) -> HybridAtomChoice {
    let flat_chart = manifold.is_euclidean();
    let mut candidates: Vec<HybridAtomCandidate> = vec![linear];
    if !flat_chart {
        if let Some(c) = curved {
            candidates.push(c);
        }
    }
    select_hybrid_atom(&candidates).expect("hybrid atom slot always has the linear candidate")
}
#[cfg(test)]
mod ibp_prior_614_tests {
    use super::*;

    /// Closed-form geometric ratio `α / (α + 1)` the prior is checked against.
    fn ratio(alpha: f64) -> f64 {
        let denom = alpha + 1.0;
        alpha / denom
    }

    /// The first prior entry equals the ratio itself and is strictly below one.
    #[test]
    fn first_atom_is_shrunk_not_unity() {
        let alphas = [0.1_f64, 0.5, 1.0, 2.0, 5.0];
        for &alpha in alphas.iter() {
            let prior = ordered_geometric_shrinkage_prior(8, alpha);
            let r = ratio(alpha);
            let first = prior[0];
            assert!(
                (first - r).abs() < 1e-12,
                "π_0 must be the single-stick mean α/(α+1)={r} (was the unshrunk 1.0 in #614); got {}",
                first
            );
            assert!(
                first < 1.0,
                "first atom must be shrunk (π_0<1) for alpha={alpha}; got {}",
                first
            );
        }
    }

    /// Every entry equals `ratio^(j + 1)` and the sequence strictly decreases.
    #[test]
    fn prior_is_consistent_geometric_product_mean() {
        let k = 12;
        for &alpha in [0.3_f64, 1.0, 4.0].iter() {
            let prior = ordered_geometric_shrinkage_prior(k, alpha);
            let r = ratio(alpha);
            for j in 0..k {
                let expected = r.powi((j + 1) as i32);
                let tolerance = 1e-12 * expected.max(1.0);
                assert!(
                    (prior[j] - expected).abs() < tolerance,
                    "alpha={alpha} π_{j}: expected {expected}, got {}",
                    prior[j]
                );
            }
            for j in 1..k {
                let (previous, current) = (prior[j - 1], prior[j]);
                assert!(
                    current < previous,
                    "alpha={alpha}: prior must strictly decrease at index {j}"
                );
            }
        }
    }

    /// A larger alpha raises both the head and the tail of the prior.
    #[test]
    fn alpha_behaves_as_concentration() {
        let (lo, hi) = (
            ordered_geometric_shrinkage_prior(8, 0.5),
            ordered_geometric_shrinkage_prior(8, 5.0),
        );
        assert!(
            hi[0] > lo[0],
            "larger alpha must raise π_0 (concentration): {} vs {}",
            hi[0],
            lo[0]
        );
        assert!(
            hi[4] > lo[4],
            "larger alpha must put more mass in the tail: {} vs {}",
            hi[4],
            lo[4]
        );
    }
}
#[cfg(test)]
mod hybrid_split_tests {
    use super::*;
    use crate::solver::evidence::HybridAtomParam;

    /// Circle chart with period 2π, shared by the non-Euclidean cases.
    fn circle_chart() -> LatentManifold {
        LatentManifold::Circle {
            period: 2.0 * std::f64::consts::PI,
        }
    }

    /// On a Euclidean chart the curved candidate is dropped, so the choice is linear.
    #[test]
    fn flat_chart_drops_curved_candidate_and_keeps_linear() {
        let curved = HybridAtomCandidate::curved(1, 1.0, 5, Some(2.0));
        let linear = HybridAtomCandidate::linear(100.0, 2);
        let flat = LatentManifold::Euclidean;
        let choice = select_hybrid_atom_parameterization(&flat, Some(curved), linear);
        assert!(choice.param.is_linear());
    }

    /// On a circle chart the supplied curved candidate is selected for these inputs.
    #[test]
    fn curveable_chart_selects_curved_when_turning_pays() {
        let curved = HybridAtomCandidate::curved(1, 70.0, 5, Some(2.0 * std::f64::consts::PI));
        let linear = HybridAtomCandidate::linear(100.0, 2);
        let chart = circle_chart();
        let choice = select_hybrid_atom_parameterization(&chart, Some(curved), linear);
        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
    }

    /// Without a curved candidate the linear one is returned with its parameter count.
    #[test]
    fn curveable_chart_falls_back_to_linear_when_no_curved_candidate() {
        let linear = HybridAtomCandidate::linear(33.0, 2);
        let chart = circle_chart();
        let choice = select_hybrid_atom_parameterization(&chart, None, linear);
        assert!(choice.param.is_linear());
        assert_eq!(choice.num_parameters, 2);
    }
}
#[cfg(test)]
mod frozen_routing_1033_tests {
    use super::*;

    /// One-dimensional coordinate blocks, one per atom, with entries `0.1 * row`.
    fn unit_coord_blocks(n: usize, k: usize) -> Vec<Array2<f64>> {
        let mut blocks = Vec::with_capacity(k);
        for _ in 0..k {
            blocks.push(Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1));
        }
        blocks
    }

    /// IBP-MAP assignment (temperature 0.5, alpha 1.0, fixed alpha) on an `n × k` logit grid.
    fn ibp_assignment(n: usize, k: usize) -> SaeAssignment {
        let logits =
            Array2::from_shape_fn((n, k), |(i, kk)| 0.3 + 0.05 * (i as f64) - 0.1 * (kk as f64));
        let mode = AssignmentMode::ibp_map(0.5, 1.0, false);
        SaeAssignment::from_blocks_with_mode(logits, unit_coord_blocks(n, k), mode).unwrap()
    }

    /// Gates computed from a frozen routing do not move when the free logits change.
    #[test]
    fn frozen_routing_decouples_gates_from_logit_updates_1033() {
        let (n, k) = (6usize, 3usize);
        let mut a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        assert!(a.routing_is_frozen());
        let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1); k]);
        let gate_rows = |assignment: &SaeAssignment| -> Vec<Array1<f64>> {
            (0..n)
                .map(|r| assignment.try_assignments_row_for_rho(r, &rho).unwrap())
                .collect()
        };
        let before = gate_rows(&a);
        a.logits.mapv_inplace(|v| v + 5.0);
        let after = gate_rows(&a);
        for r in 0..n {
            for kk in 0..k {
                assert_eq!(
                    before[r][kk], after[r][kk],
                    "row {r} atom {kk}: frozen-routing gate must be UNCHANGED by a free-logit \
                     update (decoupled from inner-fit drift); {} vs {}",
                    before[r][kk], after[r][kk]
                );
            }
        }
    }

    /// Gates computed from a frozen routing are identical at two very different rho values.
    #[test]
    fn frozen_routing_gates_are_rho_invariant_1033() {
        let (n, k) = (5usize, 2usize);
        let a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        let rho_a = SaeManifoldRho::new(
            (1e-3_f64).ln(),
            (1e-2_f64).ln(),
            vec![Array1::<f64>::zeros(1); k],
        );
        let rho_b = SaeManifoldRho::new(
            (1e3_f64).ln(),
            (1e1_f64).ln(),
            vec![Array1::<f64>::zeros(1); k],
        );
        for r in 0..n {
            let ga = a.try_assignments_row_for_rho(r, &rho_a).unwrap();
            let gb = a.try_assignments_row_for_rho(r, &rho_b).unwrap();
            for kk in 0..k {
                assert_eq!(
                    ga[kk], gb[kk],
                    "row {r} atom {kk}: frozen-routing gate must be ρ-INVARIANT (the n-independence \
                     lever); {} at ρ_a vs {} at ρ_b",
                    ga[kk], gb[kk]
                );
            }
        }
    }

    /// Freezing fixes every logit; thawing makes every logit free again.
    #[test]
    fn frozen_routing_fixes_all_logits_and_thaw_restores_free_path_1033() {
        let (n, k) = (4usize, 3usize);
        let mut a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        let mask = a.fixed_logit_mask();
        assert_eq!(mask.len(), k);
        assert!(mask.iter().all(|&f| f), "frozen routing must fix ALL logits");
        for kk in 0..k {
            assert!(
                a.logit_is_fixed(kk),
                "atom {kk} logit must be fixed under frozen routing"
            );
        }
        a.thaw_routing();
        assert!(!a.routing_is_frozen());
        let thawed_mask = a.fixed_logit_mask();
        assert!(
            thawed_mask.iter().all(|&f| !f),
            "thaw must restore the free-logit path"
        );
    }

    /// Freezing is refused for a Softmax-routed assignment.
    #[test]
    fn frozen_routing_rejects_softmax_1033() {
        let (n, k) = (4usize, 3usize);
        let logits =
            Array2::from_shape_fn((n, k), |(i, kk)| 0.1 * (i as f64) - 0.05 * (kk as f64));
        let mode = AssignmentMode::softmax(1.0);
        let a =
            SaeAssignment::from_blocks_with_mode(logits, unit_coord_blocks(n, k), mode).unwrap();
        let frozen = a.freeze_routing_from_current_logits();
        assert!(
            frozen.is_err(),
            "frozen routing under Softmax must be rejected (simplex entropy-majorizer coupling)"
        );
    }
}