use super::*;
/// Element-wise sparsity-promoting shapes used by `SparsityPenalty`.
///
/// The smoothing parameter stored in a variant (`eps` / `delta`) is the fixed default; when
/// `SparsityPenalty::eps_rho_index` is set, `SparsityPenalty::resolved` replaces it with a
/// positively clamped `exp(rho[eps_rho_index])`.
#[derive(Debug, Clone, Copy)]
pub enum SparsityKind {
/// `Σ sqrt(x² + eps²)`: a differentiable surrogate for the L1 norm (`eps > 0`).
SmoothedL1 { eps: f64 },
/// `(‖x‖₁ / ‖x‖₂ − 1) / (sqrt(n) − 1)`: 0 for a single non-zero entry, 1 when all entries
/// share the same magnitude; defined as 0 for an all-zero input.
Hoyer,
/// `Σ ln(1 + x² / delta²)`: log sparsifier (`delta > 0`).
Log { delta: f64 },
}
/// Element-wise sparsity penalty `λ · P(x)` on the coefficients of one penalty tier.
///
/// `λ = resolve_learnable_weight(weight, rho[strength_rho_index])`; the shape `P` is chosen by
/// [`SparsityKind`].
#[derive(Debug, Clone)]
pub struct SparsityPenalty {
/// Tier reported by `AnalyticPenalty::tier`.
pub target_tier: PenaltyTier,
/// Functional form and fixed default smoothing parameter.
pub kind: SparsityKind,
/// Base weight combined with `rho[strength_rho_index]` by `resolve_learnable_weight`.
pub weight: f64,
/// Optional schedule for `weight` — presumably consumed via `impl_scalar_apply_schedule!`; confirm.
pub weight_schedule: Option<ScalarWeightSchedule>,
/// Index into this penalty's `rho` slice holding the strength parameter (constructors use 0).
pub strength_rho_index: usize,
/// When `Some`, index into `rho` whose exponential replaces the fixed smoothing parameter.
pub eps_rho_index: Option<usize>,
}
/// Entropy penalty on per-row softmax assignments.
///
/// The target is read as consecutive rows of `k_atoms` logits. Each row is mapped to
/// `a = softmax(logits / temperature)` and the penalty is `λ · Σ_rows (−Σ_k a_k ln a_k)`, so
/// lower-entropy (sparser) assignments cost less. `λ = resolve_learnable_weight(weight, rho[0])`.
#[derive(Debug, Clone)]
pub struct SoftmaxAssignmentSparsityPenalty {
/// Number of logits (atoms) per row.
pub k_atoms: usize,
/// Softmax temperature; logits are divided by it (checked `> 0` in `new`).
pub temperature: f64,
/// Base weight passed to `resolve_learnable_weight` together with `rho[0]`.
pub weight: f64,
/// Optional schedule for `weight` — presumably consumed via `impl_scalar_apply_schedule!`; confirm.
pub weight_schedule: Option<ScalarWeightSchedule>,
}
impl SoftmaxAssignmentSparsityPenalty {
    /// Creates an entropy penalty over rows of `k_atoms` logits with base weight 1.0 and no
    /// weight schedule.
    ///
    /// # Panics
    /// If `k_atoms == 0` or `temperature` is not strictly positive (NaN included).
    #[must_use]
    pub fn new(k_atoms: usize, temperature: f64) -> Self {
        assert!(
            k_atoms > 0,
            "SoftmaxAssignmentSparsityPenalty::new requires k_atoms > 0"
        );
        assert!(
            temperature > 0.0,
            "SoftmaxAssignmentSparsityPenalty::new requires temperature > 0, got {temperature}"
        );
        Self {
            k_atoms,
            temperature,
            weight: 1.0,
            weight_schedule: None,
        }
    }

    impl_with_weight_schedule!(weight);

    /// Numerically stable `softmax(row / temperature)`.
    ///
    /// # Panics
    /// If `row.len() != k_atoms`, if any logit is non-finite, or if the normalizer is not a
    /// finite positive number.
    fn softmax_row(&self, row: &[f64]) -> Vec<f64> {
        assert_eq!(
            row.len(),
            self.k_atoms,
            "SoftmaxAssignmentSparsityPenalty: expected {} logits per row, got {}",
            self.k_atoms,
            row.len()
        );
        let inv_tau = 1.0 / self.temperature;
        let mut max_logit = f64::NEG_INFINITY;
        for (idx, &v) in row.iter().enumerate() {
            assert!(
                v.is_finite(),
                "SoftmaxAssignmentSparsityPenalty: non-finite logit at atom {idx}: {v}"
            );
            max_logit = max_logit.max(v);
        }
        // Subtracting the max keeps every exponent <= 0 (for a positive temperature), so `exp`
        // cannot overflow; the shift cancels in the normalization.
        let mut out: Vec<f64> = row
            .iter()
            .map(|&v| ((v - max_logit) * inv_tau).exp())
            .collect();
        let mut sum = 0.0;
        for &v in &out {
            sum += v;
        }
        assert!(
            sum.is_finite() && sum > 0.0,
            "SoftmaxAssignmentSparsityPenalty: non-finite softmax normalizer"
        );
        for v in &mut out {
            *v /= sum;
        }
        out
    }

    /// Returns `l_i = ln(max(a_i, ENTROPY_LOG_PROBABILITY_FLOOR)) + 1` and `m = Σ_i a_i l_i`,
    /// the quantities shared by every entropy-Hessian formula in this impl.
    fn log_plus_one_stats(a: &[f64]) -> (Vec<f64>, f64) {
        let l: Vec<f64> = a
            .iter()
            .map(|&ai| ai.max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
            .collect();
        let m: f64 = a.iter().zip(&l).map(|(&ai, &li)| ai * li).sum();
        (l, m)
    }

    /// Column `w` of the softmax Jacobian with respect to the raw logits:
    /// `∂a_r/∂z_w = a_r (δ_rw − a_w) / temperature` for every `r`.
    ///
    /// # Panics
    /// If `w >= k_atoms`.
    fn softmax_jacobian_column(&self, a: &[f64], w: usize) -> Vec<f64> {
        assert!(
            w < self.k_atoms,
            "SoftmaxAssignmentSparsityPenalty: logit index {w} out of range for k_atoms = {}",
            self.k_atoms
        );
        let inv_tau = 1.0 / self.temperature;
        a.iter()
            .enumerate()
            .map(|(r, &ar)| {
                let delta = if r == w { 1.0 } else { 0.0 };
                ar * (delta - a[w]) * inv_tau
            })
            .collect()
    }

    /// Absolute row sums `d_k = Σ_j |H_kj|` of the per-row entropy Hessian (the entries also
    /// produced by `row_dense_hessian`). That Hessian is symmetric, so by Gershgorin's theorem
    /// `diag(d) − H` is positive semidefinite and `diag(d)` majorizes it.
    fn psd_majorizer_abs_row_sums(&self, row: &[f64], scale: f64) -> Vec<f64> {
        let a = self.softmax_row(row);
        let (l, m) = Self::log_plus_one_stats(&a);
        let k = self.k_atoms;
        let mut d = vec![0.0_f64; k];
        for kk in 0..k {
            let h_kk =
                scale * a[kk] * ((m - l[kk] - 1.0) + a[kk] * (2.0 * l[kk] + 1.0 - 2.0 * m));
            let mut acc = h_kk.abs();
            for jj in (0..k).filter(|&jj| jj != kk) {
                let h_kj = scale * a[kk] * a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m);
                acc += h_kj.abs();
            }
            d[kk] = acc;
        }
        d
    }

    /// Dense `k_atoms × k_atoms` Hessian of the row entropy with respect to the row's logits:
    ///
    /// `H_kj = scale · a_k · (δ_kj (m − l_k − 1) + a_j (l_k + l_j + 1 − 2m))`
    ///
    /// with `a = softmax(row_logits / temperature)` and `l`, `m` as in `log_plus_one_stats`.
    /// With `scale = λ / temperature²` this is the Hessian of `λ · (−Σ a ln a)` (exact while no
    /// `a_i` falls below the log floor); the caller supplies `scale`.
    #[must_use]
    pub fn row_dense_hessian(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
        let k = self.k_atoms;
        let a = self.softmax_row(row_logits);
        let (l, m) = Self::log_plus_one_stats(&a);
        let mut h = Array2::<f64>::zeros((k, k));
        for kk in 0..k {
            for jj in 0..k {
                let indicator = if kk == jj { 1.0 } else { 0.0 };
                h[[kk, jj]] = scale
                    * a[kk]
                    * (indicator * (m - l[kk] - 1.0) + a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m));
            }
        }
        h
    }

    /// Derivative of `row_dense_hessian(row_logits, scale)` with respect to logit `w`,
    /// obtained by differentiating `a`, `l` and `m` through the softmax.
    ///
    /// # Panics
    /// If `w >= k_atoms` (plus the `softmax_row` preconditions).
    #[must_use]
    pub fn row_dense_hessian_logit_derivative(
        &self,
        row_logits: &[f64],
        scale: f64,
        w: usize,
    ) -> Array2<f64> {
        let k = self.k_atoms;
        let a = self.softmax_row(row_logits);
        let (l, m) = Self::log_plus_one_stats(&a);
        let da = self.softmax_jacobian_column(&a, w);
        // l_r = ln(a_r) + 1  ⇒  ∂l_r = ∂a_r / a_r, with a_r floored exactly as in `l`.
        let dl: Vec<f64> = (0..k)
            .map(|r| da[r] / a[r].max(ENTROPY_LOG_PROBABILITY_FLOOR))
            .collect();
        let dm: f64 = (0..k).map(|r| da[r] * l[r] + a[r] * dl[r]).sum();
        let mut dh = Array2::<f64>::zeros((k, k));
        for kk in 0..k {
            for jj in 0..k {
                let indicator = if kk == jj { 1.0 } else { 0.0 };
                let bracket =
                    indicator * (m - l[kk] - 1.0) + a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m);
                let dbracket = indicator * (dm - dl[kk])
                    + da[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m)
                    + a[jj] * (dl[kk] + dl[jj] - 2.0 * dm);
                // Product rule on `scale · a_k · bracket`.
                dh[[kk, jj]] = scale * (da[kk] * bracket + a[kk] * dbracket);
            }
        }
        dh
    }

    /// Per-row metric `G_kj = scale · a_k (δ_kj − a_j)` with `a = softmax(row_logits / temperature)`.
    #[must_use]
    pub fn row_fisher_metric(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
        let k = self.k_atoms;
        let a = self.softmax_row(row_logits);
        let mut g = Array2::<f64>::zeros((k, k));
        for kk in 0..k {
            for jj in 0..k {
                let indicator = if kk == jj { 1.0 } else { 0.0 };
                g[[kk, jj]] = scale * a[kk] * (indicator - a[jj]);
            }
        }
        g
    }

    /// Derivative of `row_fisher_metric(row_logits, scale)` with respect to logit `w`.
    ///
    /// # Panics
    /// If `w >= k_atoms` (plus the `softmax_row` preconditions).
    #[must_use]
    pub fn row_fisher_metric_logit_derivative(
        &self,
        row_logits: &[f64],
        scale: f64,
        w: usize,
    ) -> Array2<f64> {
        let k = self.k_atoms;
        let a = self.softmax_row(row_logits);
        let da = self.softmax_jacobian_column(&a, w);
        let mut dg = Array2::<f64>::zeros((k, k));
        for kk in 0..k {
            for jj in 0..k {
                let indicator = if kk == jj { 1.0 } else { 0.0 };
                dg[[kk, jj]] = scale * (da[kk] * (indicator - a[jj]) - a[kk] * da[jj]);
            }
        }
        dg
    }
}
/// Input validation shared by every `AnalyticPenalty` entry point of
/// `SoftmaxAssignmentSparsityPenalty`, so no method silently ignores a partial trailing row.
///
/// # Panics
/// If `rho` is not exactly one finite value, or `target_len` is not a multiple of `k_atoms`.
fn assert_softmax_entropy_inputs(k_atoms: usize, target_len: usize, rho: ArrayView1<'_, f64>) {
    assert_eq!(rho.len(), 1, "softmax entropy expects one rho parameter");
    assert!(
        rho.iter().all(|value| value.is_finite()),
        "softmax entropy rho must be finite"
    );
    assert_eq!(
        target_len % k_atoms,
        0,
        "softmax entropy target length must be divisible by k_atoms"
    );
}

impl AnalyticPenalty for SoftmaxAssignmentSparsityPenalty {
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }

    /// `λ · Σ_rows (−Σ_k a_k ln a_k)` with `a = softmax(row / temperature)`.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        assert_softmax_entropy_inputs(self.k_atoms, target.len(), rho);
        let lambda = resolve_learnable_weight(self.weight, rho[0]);
        let values: Vec<f64> = target.iter().copied().collect();
        let mut acc = 0.0;
        for row in values.chunks_exact(self.k_atoms) {
            for p in self.softmax_row(row) {
                // 0 · ln 0 is taken as 0.
                if p > 0.0 {
                    acc += -p * p.ln();
                }
            }
        }
        lambda * acc
    }

    /// Gradient with respect to the logits: `∂(λH)/∂a_i = −λ (ln a_i + 1)` chained through
    /// `∂a_i/∂z_j = a_i (δ_ij − a_j) / τ`.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        assert_softmax_entropy_inputs(self.k_atoms, target.len(), rho);
        let k = self.k_atoms;
        let lambda = resolve_learnable_weight(self.weight, rho[0]);
        let inv_tau = 1.0 / self.temperature;
        let values: Vec<f64> = target.iter().copied().collect();
        let mut out = Array1::<f64>::zeros(target.len());
        for (row_idx, row) in values.chunks_exact(k).enumerate() {
            let start = row_idx * k;
            let a = self.softmax_row(row);
            let d_h_da: Vec<f64> = a
                .iter()
                .map(|&ai| -lambda * (ai.max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0))
                .collect();
            let mut mean = 0.0;
            for (&ai, &gi) in a.iter().zip(&d_h_da) {
                mean += ai * gi;
            }
            for j in 0..k {
                out[start + j] = a[j] * (d_h_da[j] - mean) * inv_tau;
            }
        }
        out
    }

    /// Exact Hessian diagonal `scale · a_k · ((1 − 2a_k)(m − l_k) + a_k − 1)` with
    /// `scale = λ/τ²`, `l_k = ln a_k + 1`, `m = Σ a_i l_i`.
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        assert_softmax_entropy_inputs(self.k_atoms, target.len(), rho);
        let k = self.k_atoms;
        let lambda = resolve_learnable_weight(self.weight, rho[0]);
        let inv_tau = 1.0 / self.temperature;
        let scale = lambda * inv_tau * inv_tau;
        let values: Vec<f64> = target.iter().copied().collect();
        let mut out = Array1::<f64>::zeros(target.len());
        for (row_idx, row) in values.chunks_exact(k).enumerate() {
            let start = row_idx * k;
            let a = self.softmax_row(row);
            let mut mean_log_plus_one = 0.0;
            for &ai in &a {
                mean_log_plus_one += ai * (ai.max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
            }
            for j in 0..k {
                let log_plus_one = a[j].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0;
                // Algebraically (m − l_j − 1) + a_j (2 l_j + 1 − 2m).
                let term = (1.0 - 2.0 * a[j]) * (mean_log_plus_one - log_plus_one) + a[j] - 1.0;
                out[start + j] = scale * a[j] * term;
            }
        }
        Some(out)
    }

    /// Matrix-free product with the per-row dense entropy Hessian:
    /// `(Hv)_k = scale · a_k · ((v_k − v̄)(m − l_k − 1) + Σ_j a_j (v_j − v̄) l_j)`, `v̄ = Σ a_j v_j`.
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_softmax_entropy_inputs(self.k_atoms, target.len(), rho);
        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
        let k = self.k_atoms;
        let lambda = resolve_learnable_weight(self.weight, rho[0]);
        let inv_tau = 1.0 / self.temperature;
        let scale = lambda * inv_tau * inv_tau;
        let values: Vec<f64> = target.iter().copied().collect();
        let mut out = Array1::<f64>::zeros(target.len());
        for (row_idx, row) in values.chunks_exact(k).enumerate() {
            let start = row_idx * k;
            let a = self.softmax_row(row);
            let mut mean_log_plus_one = 0.0;
            let mut mean_v = 0.0;
            for j in 0..k {
                mean_log_plus_one += a[j] * (a[j].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
                mean_v += a[j] * v[start + j];
            }
            let mut mean_centered_v_log_plus_one = 0.0;
            for j in 0..k {
                let centered_v = v[start + j] - mean_v;
                mean_centered_v_log_plus_one +=
                    a[j] * centered_v * (a[j].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
            }
            for j in 0..k {
                let log_plus_one = a[j].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0;
                let centered_v = v[start + j] - mean_v;
                out[start + j] = scale
                    * a[j]
                    * (centered_v * (mean_log_plus_one - log_plus_one - 1.0)
                        + mean_centered_v_log_plus_one);
            }
        }
        out
    }

    /// Diagonal PSD majorizer: per-row absolute Hessian row sums (Gershgorin bound).
    fn psd_majorizer_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        assert_softmax_entropy_inputs(self.k_atoms, target.len(), rho);
        let k = self.k_atoms;
        let lambda = resolve_learnable_weight(self.weight, rho[0]);
        let inv_tau = 1.0 / self.temperature;
        let scale = lambda * inv_tau * inv_tau;
        let values: Vec<f64> = target.iter().copied().collect();
        let mut out = Array1::<f64>::zeros(target.len());
        for (row_idx, row) in values.chunks_exact(k).enumerate() {
            let start = row_idx * k;
            let d = self.psd_majorizer_abs_row_sums(row, scale);
            for (j, dj) in d.into_iter().enumerate() {
                out[start + j] = dj;
            }
        }
        Some(out)
    }

    /// NOTE(review): returns the penalty value itself, i.e. assumes `∂λ/∂ρ = λ`
    /// (log-scale weight in `resolve_learnable_weight`) — confirm.
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        Array1::from_vec(vec![self.value(target, rho)])
    }

    fn rho_count(&self) -> usize {
        1
    }

    fn name(&self) -> &str {
        "softmax_assignment_sparsity"
    }

    impl_scalar_apply_schedule!(weight);
}
impl SparsityPenalty {
#[must_use = "build error must be handled"]
pub fn smoothed_l1(target_tier: PenaltyTier, eps: f64) -> Result<Self, String> {
if !(eps.is_finite() && eps > 0.0) {
return Err(format!(
"SparsityPenalty::smoothed_l1 requires eps > 0 \
(Hessian / gradient have a `1/sqrt(x² + eps²)` factor that needs eps > 0 \
for differentiability at x = 0); got eps = {eps}"
));
}
Ok(Self {
target_tier,
kind: SparsityKind::SmoothedL1 { eps },
weight: 1.0,
weight_schedule: None,
strength_rho_index: 0,
eps_rho_index: None,
})
}
#[must_use = "build error must be handled"]
pub fn log(target_tier: PenaltyTier, delta: f64) -> Result<Self, String> {
if !(delta.is_finite() && delta > 0.0) {
return Err(format!(
"SparsityPenalty::log requires delta > 0 \
(the log-sparsifier is log(1 + x²/δ²), undefined at δ = 0); \
got delta = {delta}"
));
}
Ok(Self {
target_tier,
kind: SparsityKind::Log { delta },
weight: 1.0,
weight_schedule: None,
strength_rho_index: 0,
eps_rho_index: None,
})
}
#[must_use]
pub fn hoyer(target_tier: PenaltyTier) -> Self {
Self {
target_tier,
kind: SparsityKind::Hoyer,
weight: 1.0,
weight_schedule: None,
strength_rho_index: 0,
eps_rho_index: None,
}
}
impl_with_weight_schedule!(weight);
#[must_use]
pub fn with_eps_reml(mut self, eps_rho_index: usize) -> Self {
self.eps_rho_index = Some(eps_rho_index);
self
}
fn resolved(&self, rho: ArrayView1<'_, f64>) -> (f64, f64) {
let strength = resolve_learnable_weight(self.weight, rho[self.strength_rho_index]);
let smoothing = match (self.eps_rho_index, self.kind) {
(Some(idx), _) => rho[idx].exp().max(f64::MIN_POSITIVE),
(None, SparsityKind::SmoothedL1 { eps }) => eps,
(None, SparsityKind::Log { delta }) => delta,
(None, SparsityKind::Hoyer) => 0.0,
};
(strength, smoothing)
}
}
impl AnalyticPenalty for SparsityPenalty {
fn tier(&self) -> PenaltyTier {
self.target_tier
}
fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
let (lam, smooth) = self.resolved(rho);
match self.kind {
SparsityKind::SmoothedL1 { .. } => {
let mut acc = 0.0;
for &x in target.iter() {
acc += (x * x + smooth * smooth).sqrt();
}
lam * acc
}
SparsityKind::Hoyer => {
let n = target.len() as f64;
assert!(n > 1.0, "Hoyer requires n > 1");
let l1: f64 = target.iter().map(|x| x.abs()).sum();
let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
if l2 == 0.0 {
return 0.0;
}
let h = (l1 / l2 - 1.0) / (n.sqrt() - 1.0);
lam * h
}
SparsityKind::Log { .. } => {
let mut acc = 0.0;
let d2 = smooth * smooth;
for &x in target.iter() {
acc += (1.0 + x * x / d2).ln();
}
lam * acc
}
}
}
fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
let (lam, smooth) = self.resolved(rho);
let mut g = Array1::<f64>::zeros(target.len());
match self.kind {
SparsityKind::SmoothedL1 { .. } => {
let eps2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
g[i] = lam * x / (x * x + eps2).sqrt();
}
}
SparsityKind::Hoyer => {
let n = target.len() as f64;
assert!(n > 1.0, "Hoyer requires n > 1");
let l1: f64 = target.iter().map(|x| x.abs()).sum();
let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
if l2 == 0.0 {
return g;
}
let denom = n.sqrt() - 1.0;
let a = lam / denom;
let inv_l2 = 1.0 / l2;
let inv_l2_cubed = inv_l2 * inv_l2 * inv_l2;
for (i, &x) in target.iter().enumerate() {
let sgn = if x > 0.0 {
1.0
} else if x < 0.0 {
-1.0
} else {
0.0
};
g[i] = a * (sgn * inv_l2 - l1 * x * inv_l2_cubed);
}
}
SparsityKind::Log { .. } => {
let d2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
g[i] = lam * 2.0 * x / (d2 + x * x);
}
}
}
g
}
fn hessian_diag(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>> {
let (lam, smooth) = self.resolved(rho);
match self.kind {
SparsityKind::SmoothedL1 { .. } => {
let mut d = Array1::<f64>::zeros(target.len());
let eps2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
let r = (x * x + eps2).sqrt();
d[i] = lam * eps2 / (r * r * r);
}
Some(d)
}
SparsityKind::Log { .. } => {
let mut d = Array1::<f64>::zeros(target.len());
let d2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
let denom = d2 + x * x;
d[i] = lam * 2.0 * (d2 - x * x) / (denom * denom);
}
Some(d)
}
SparsityKind::Hoyer => None,
}
}
fn hvp(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
v: ArrayView1<'_, f64>,
) -> Array1<f64> {
let (lam, smooth) = self.resolved(rho);
let n_target = target.len();
assert_eq!(v.len(), n_target, "hvp dimension mismatch");
match self.kind {
SparsityKind::SmoothedL1 { .. } => {
let mut out = Array1::<f64>::zeros(n_target);
let eps2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
let r = (x * x + eps2).sqrt();
out[i] = lam * eps2 / (r * r * r) * v[i];
}
out
}
SparsityKind::Log { .. } => {
let mut out = Array1::<f64>::zeros(n_target);
let d2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
let denom = d2 + x * x;
out[i] = lam * 2.0 * (d2 - x * x) / (denom * denom) * v[i];
}
out
}
SparsityKind::Hoyer => {
let n = n_target as f64;
assert!(n > 1.0, "Hoyer requires n > 1");
let l1: f64 = target.iter().map(|x| x.abs()).sum();
let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
let mut out = Array1::<f64>::zeros(n_target);
if l2 == 0.0 {
return out;
}
let a = lam / (n.sqrt() - 1.0);
let inv_l2_cubed = 1.0 / (l2 * l2 * l2);
let inv_l2_5 = inv_l2_cubed / (l2 * l2);
let mut x_dot_v = 0.0;
let mut s_dot_v = 0.0;
for i in 0..n_target {
let xi = target[i];
let si = if xi > 0.0 {
1.0
} else if xi < 0.0 {
-1.0
} else {
0.0
};
x_dot_v += xi * v[i];
s_dot_v += si * v[i];
}
for i in 0..n_target {
let xi = target[i];
let si = if xi > 0.0 {
1.0
} else if xi < 0.0 {
-1.0
} else {
0.0
};
out[i] = a
* (-si * x_dot_v * inv_l2_cubed
- xi * s_dot_v * inv_l2_cubed
- l1 * v[i] * inv_l2_cubed
+ 3.0 * l1 * xi * x_dot_v * inv_l2_5);
}
out
}
}
}
fn psd_majorizer_diag(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>> {
let (lam, smooth) = self.resolved(rho);
match self.kind {
SparsityKind::SmoothedL1 { .. } => self.hessian_diag(target, rho),
SparsityKind::Log { .. } => {
let mut d = Array1::<f64>::zeros(target.len());
let d2 = smooth * smooth;
for (i, &x) in target.iter().enumerate() {
d[i] = lam * 2.0 / (d2 + x * x);
}
Some(d)
}
SparsityKind::Hoyer => None,
}
}
fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
let n_rho = self.rho_count();
let mut out = Array1::<f64>::zeros(n_rho);
let p_val = self.value(target, rho);
out[self.strength_rho_index] = p_val;
if let Some(eps_idx) = self.eps_rho_index {
let (lam, smooth) = self.resolved(rho);
let mut dp_deps = 0.0;
match self.kind {
SparsityKind::SmoothedL1 { .. } => {
for &x in target.iter() {
dp_deps += smooth / (x * x + smooth * smooth).sqrt();
}
dp_deps *= lam;
}
SparsityKind::Log { .. } => {
let d2 = smooth * smooth;
for &x in target.iter() {
dp_deps += -2.0 * x * x / (smooth * (d2 + x * x));
}
dp_deps *= lam;
}
SparsityKind::Hoyer => {}
}
out[eps_idx] = smooth * dp_deps;
}
out
}
fn rho_count(&self) -> usize {
1 + if self.eps_rho_index.is_some() { 1 } else { 0 }
}
fn name(&self) -> &str {
"sparsity"
}
impl_scalar_apply_schedule!(weight);
}
/// Quadratic penalty `½ · weight · x²` applied to the `k` largest-magnitude entries of each
/// latent row; the target is read as consecutive rows of `latent_dim` values.
#[derive(Debug, Clone)]
pub struct TopKActivationPenalty {
/// Target slice descriptor; `new` takes `latent_dim` from it.
pub target: PsiSlice,
/// Number of entries per row that are penalized (`new` checks `0 < k <= latent_dim`).
pub k: usize,
/// Entries per row.
pub latent_dim: usize,
/// Coefficient of `½ x²` for each selected entry (`new` checks finite and > 0).
pub weight: f64,
/// Optional schedule for `weight` — presumably consumed via `impl_scalar_apply_schedule!`; confirm.
pub weight_schedule: Option<ScalarWeightSchedule>,
}
impl TopKActivationPenalty {
    /// Builds the penalty, reading `latent_dim` from `target`.
    ///
    /// # Errors
    /// Returns `Err` if `target.latent_dim` is missing or zero, if `k` is not in
    /// `1..=latent_dim`, or if `weight` is not finite and strictly positive.
    #[must_use = "build error must be handled"]
    pub fn new(target: PsiSlice, k: usize, weight: f64) -> Result<Self, String> {
        let latent_dim = target
            .latent_dim
            .ok_or_else(|| "TopKActivationPenalty::new requires target.latent_dim".to_string())?;
        if latent_dim == 0 {
            return Err("TopKActivationPenalty::new requires latent_dim > 0".to_string());
        }
        if k == 0 || k > latent_dim {
            return Err(format!(
                "TopKActivationPenalty::new requires 0 < k <= latent_dim; got k={k}, latent_dim={latent_dim}"
            ));
        }
        if !(weight.is_finite() && weight > 0.0) {
            return Err(format!(
                "TopKActivationPenalty::new requires finite weight > 0, got {weight}"
            ));
        }
        Ok(Self {
            target,
            k,
            latent_dim,
            weight,
            weight_schedule: None,
        })
    }

    impl_with_weight_schedule!(weight);

    /// Sets `mask[axis] = true` for the `min(k, latent_dim)` entries of row `row` with the
    /// largest `|value|` (ties broken by lower axis index) and clears every other flag.
    fn topk_mask_row(&self, target: ArrayView1<'_, f64>, row: usize, mask: &mut [bool]) {
        mask.fill(false);
        let d = self.latent_dim;
        let keep = self.k.min(d);
        if keep == 0 {
            return;
        }
        let base = row * d;
        // Strict total order: descending |value| (`total_cmp` orders NaN too), then ascending axis.
        let by_magnitude = |a: &usize, b: &usize| {
            target[base + *b]
                .abs()
                .total_cmp(&target[base + *a].abs())
                .then_with(|| a.cmp(b))
        };
        let mut order: Vec<usize> = (0..d).collect();
        if keep < d {
            // O(d) partial selection. Because the comparator is a strict total order, the first
            // `keep` slots hold exactly the set a full sort would put there.
            order.select_nth_unstable_by(keep - 1, by_magnitude);
        }
        for &axis in &order[..keep] {
            mask[axis] = true;
        }
    }
}
/// Input validation shared by the `AnalyticPenalty` entry points of `TopKActivationPenalty`;
/// returns the number of latent rows.
///
/// # Panics
/// If `rho` is non-empty or `target_len` is not a multiple of `latent_dim` (a partial trailing
/// row would otherwise be silently ignored).
fn topk_activation_rows(latent_dim: usize, target_len: usize, rho: ArrayView1<'_, f64>) -> usize {
    assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
    assert_eq!(
        target_len % latent_dim,
        0,
        "TopKActivationPenalty target length must be a multiple of latent_dim"
    );
    target_len / latent_dim
}

impl AnalyticPenalty for TopKActivationPenalty {
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }

    /// `Σ_rows Σ_{i ∈ top-k(row)} ½ · weight · x_i²`; entries outside the top k contribute
    /// nothing.
    ///
    /// NOTE(review): this shrinks the entries a TopK activation keeps rather than pushing the
    /// remaining entries toward zero — confirm that is the intended objective.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        let d = self.latent_dim;
        let n_obs = topk_activation_rows(d, target.len(), rho);
        let mut mask = vec![false; d];
        let mut acc = 0.0;
        for row in 0..n_obs {
            self.topk_mask_row(target, row, &mut mask);
            let base = row * d;
            for axis in (0..d).filter(|&axis| mask[axis]) {
                let v = target[base + axis];
                acc += 0.5 * self.weight * v * v;
            }
        }
        acc
    }

    /// `weight · x_i` on the selected entries, 0 elsewhere (the mask is treated as fixed).
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let d = self.latent_dim;
        let n_obs = topk_activation_rows(d, target.len(), rho);
        let mut mask = vec![false; d];
        let mut grad = Array1::<f64>::zeros(target.len());
        for row in 0..n_obs {
            self.topk_mask_row(target, row, &mut mask);
            let base = row * d;
            for axis in (0..d).filter(|&axis| mask[axis]) {
                grad[base + axis] = self.weight * target[base + axis];
            }
        }
        grad
    }

    /// `weight` on the selected entries, 0 elsewhere.
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let d = self.latent_dim;
        let n_obs = topk_activation_rows(d, target.len(), rho);
        let mut mask = vec![false; d];
        let mut diag = Array1::<f64>::zeros(target.len());
        for row in 0..n_obs {
            self.topk_mask_row(target, row, &mut mask);
            let base = row * d;
            for axis in (0..d).filter(|&axis| mask[axis]) {
                diag[base + axis] = self.weight;
            }
        }
        Some(diag)
    }

    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        topk_activation_rows(self.latent_dim, target.len(), rho);
        Array1::<f64>::zeros(0)
    }

    fn rho_count(&self) -> usize {
        0
    }

    fn name(&self) -> &str {
        "topk_activation"
    }

    impl_scalar_apply_schedule!(weight);
}
/// Smoothed gate penalty `weight · Σ_rows Σ_axis τ_axis · σ((x − τ_axis) / smoothing_eps)`,
/// where `σ` is the logistic function and `τ_axis = resolve_learnable_weight(thresholds[axis],
/// rho[axis])`; the target is read as consecutive rows of `latent_dim` values.
#[derive(Debug, Clone)]
pub struct JumpReLUPenalty {
/// Target slice descriptor; `new` takes `latent_dim` from it.
pub target: PsiSlice,
/// Entries per row; also the number of thresholds and of rho parameters.
pub latent_dim: usize,
/// Base threshold per latent axis (`new` checks finite and > 0).
pub thresholds: Array1<f64>,
/// Overall coefficient of the penalty (`new` checks finite and > 0).
pub weight: f64,
/// Width of the logistic gate; smaller values approach a hard step at the threshold.
pub smoothing_eps: f64,
/// Optional schedule for `weight` — presumably consumed via `impl_scalar_apply_schedule!`; confirm.
pub weight_schedule: Option<ScalarWeightSchedule>,
}
impl JumpReLUPenalty {
    /// Validating constructor; `latent_dim` is read from `target`.
    ///
    /// # Errors
    /// Returns `Err` (checked in this order) when `target.latent_dim` is missing or zero, when
    /// `thresholds` does not have `latent_dim` entries, when any threshold is not finite and
    /// positive, or when `weight` / `smoothing_eps` is not finite and positive.
    #[must_use = "build error must be handled"]
    pub fn new(
        target: PsiSlice,
        thresholds: Array1<f64>,
        weight: f64,
        smoothing_eps: f64,
    ) -> Result<Self, String> {
        let latent_dim = match target.latent_dim {
            Some(dim) => dim,
            None => return Err("JumpReLUPenalty::new requires target.latent_dim".to_string()),
        };
        if latent_dim == 0 {
            return Err("JumpReLUPenalty::new requires latent_dim > 0".to_string());
        }
        let n_thresholds = thresholds.len();
        if n_thresholds != latent_dim {
            return Err(format!(
                "JumpReLUPenalty::new thresholds length {} does not match latent_dim {latent_dim}",
                n_thresholds
            ));
        }
        let first_bad_threshold = thresholds
            .iter()
            .enumerate()
            .find(|&(_, &tau)| !(tau.is_finite() && tau > 0.0));
        if let Some((idx, &tau)) = first_bad_threshold {
            return Err(format!(
                "JumpReLUPenalty::new thresholds[{idx}] must be finite and > 0, got {tau}"
            ));
        }
        let is_positive_finite = |value: f64| value.is_finite() && value > 0.0;
        if !is_positive_finite(weight) {
            return Err(format!(
                "JumpReLUPenalty::new requires finite weight > 0, got {weight}"
            ));
        }
        if !is_positive_finite(smoothing_eps) {
            return Err(format!(
                "JumpReLUPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
            ));
        }
        Ok(Self {
            target,
            latent_dim,
            thresholds,
            weight,
            smoothing_eps,
            weight_schedule: None,
        })
    }

    impl_with_weight_schedule!(weight);

    /// Effective threshold of `axis`: the base threshold combined with `rho[axis]`.
    fn threshold(&self, axis: usize, rho: ArrayView1<'_, f64>) -> f64 {
        let base = self.thresholds[axis];
        resolve_learnable_weight(base, rho[axis])
    }

    /// Logistic function `1 / (1 + e^{−x})`, evaluating `exp` only on a non-positive argument
    /// so it cannot overflow.
    pub(crate) fn sigmoid_gate(&self, x: f64) -> f64 {
        if x < 0.0 {
            let e = x.exp();
            return e / (1.0 + e);
        }
        1.0 / (1.0 + (-x).exp())
    }

    /// Exact second derivative in `x` of `weight · τ · σ((x − τ)/ε)` given `gate = σ(·)`:
    /// `weight · τ · g(1 − g)(1 − 2g) / ε²`.
    fn true_hessian_diag_entry(&self, tau: f64, gate: f64) -> f64 {
        let eps_sq = self.smoothing_eps * self.smoothing_eps;
        let numerator = self.weight * tau * gate * (1.0 - gate) * (1.0 - 2.0 * gate);
        numerator / eps_sq
    }

    /// Non-negative curvature bound: `weight · τ · max(s², s·|1 − 2g|) / ε²` with
    /// `s = g(1 − g)`, which is never below the magnitude of the exact entry.
    fn psd_hessian_diag_entry(&self, tau: f64, gate: f64) -> f64 {
        let slope = gate * (1.0 - gate);
        let curvature = (slope * slope).max(slope * (1.0 - 2.0 * gate).abs());
        let eps_sq = self.smoothing_eps * self.smoothing_eps;
        self.weight * tau * curvature / eps_sq
    }
}
/// Hard JumpReLU output together with surrogate derivatives.
///
/// Returns `(value, ∂/∂z, ∂/∂tau)`. `value` is `z` when `z > tau` and `0.0` otherwise. The
/// derivatives are those of the smooth surrogate `z · g` with
/// `g = stable_logistic((z − tau) / smoothing_eps)` (presumably the logistic sigmoid — confirm):
/// `∂/∂z = g + z·g(1 − g)/smoothing_eps` and `∂/∂tau = −z·g(1 − g)/smoothing_eps`.
#[must_use]
pub fn jumprelu_gate_value_grad(z: f64, tau: f64, smoothing_eps: f64) -> (f64, f64, f64) {
    let gate = crate::linalg::utils::stable_logistic((z - tau) / smoothing_eps);
    // Shared term z·g'(·)/ε; it enters the two derivatives with opposite signs.
    let surrogate_slope = z * gate * (1.0 - gate) / smoothing_eps;
    let hard_value = if z > tau { z } else { 0.0 };
    (hard_value, gate + surrogate_slope, -surrogate_slope)
}
/// Input validation shared by the `AnalyticPenalty` entry points of `JumpReLUPenalty`;
/// returns the number of latent rows.
///
/// # Panics
/// If `rho` does not hold exactly one parameter per latent axis, or if `target_len` is not a
/// multiple of `latent_dim` (a partial trailing row would otherwise be silently ignored).
fn jumprelu_rows(latent_dim: usize, target_len: usize, rho: ArrayView1<'_, f64>) -> usize {
    assert_eq!(
        rho.len(),
        latent_dim,
        "JumpReLUPenalty expects one rho parameter per latent axis"
    );
    assert_eq!(
        target_len % latent_dim,
        0,
        "JumpReLUPenalty target length must be a multiple of latent_dim"
    );
    target_len / latent_dim
}

impl AnalyticPenalty for JumpReLUPenalty {
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }

    /// `weight · Σ_rows Σ_axis τ_axis · σ((x − τ_axis) / ε)`.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        let d = self.latent_dim;
        let n_obs = jumprelu_rows(d, target.len(), rho);
        // Thresholds depend only on rho, so resolve them once rather than once per row.
        let taus: Vec<f64> = (0..d).map(|axis| self.threshold(axis, rho)).collect();
        let mut acc = 0.0;
        for row in 0..n_obs {
            let base = row * d;
            for (axis, &tau) in taus.iter().enumerate() {
                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
                acc += self.weight * tau * gate;
            }
        }
        acc
    }

    /// `weight · τ · g(1 − g) / ε` per entry.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let d = self.latent_dim;
        let n_obs = jumprelu_rows(d, target.len(), rho);
        let taus: Vec<f64> = (0..d).map(|axis| self.threshold(axis, rho)).collect();
        let mut grad = Array1::<f64>::zeros(target.len());
        for row in 0..n_obs {
            let base = row * d;
            for (axis, &tau) in taus.iter().enumerate() {
                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
                grad[base + axis] = self.weight * tau * gate * (1.0 - gate) / self.smoothing_eps;
            }
        }
        grad
    }

    /// Exact (sign-indefinite) diagonal Hessian; the penalty is separable per entry.
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let d = self.latent_dim;
        let n_obs = jumprelu_rows(d, target.len(), rho);
        let taus: Vec<f64> = (0..d).map(|axis| self.threshold(axis, rho)).collect();
        let mut diag = Array1::<f64>::zeros(target.len());
        for row in 0..n_obs {
            let base = row * d;
            for (axis, &tau) in taus.iter().enumerate() {
                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
                diag[base + axis] = self.true_hessian_diag_entry(tau, gate);
            }
        }
        Some(diag)
    }

    /// Diagonal Hessian applied to `v` element-wise.
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
        let d = self.latent_dim;
        let n_obs = jumprelu_rows(d, target.len(), rho);
        let taus: Vec<f64> = (0..d).map(|axis| self.threshold(axis, rho)).collect();
        let mut out = Array1::<f64>::zeros(target.len());
        for row in 0..n_obs {
            let base = row * d;
            for (axis, &tau) in taus.iter().enumerate() {
                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
                out[base + axis] = self.true_hessian_diag_entry(tau, gate) * v[base + axis];
            }
        }
        out
    }

    /// Non-negative diagonal bounding the magnitude of the exact Hessian entries.
    fn psd_majorizer_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let d = self.latent_dim;
        let n_obs = jumprelu_rows(d, target.len(), rho);
        let taus: Vec<f64> = (0..d).map(|axis| self.threshold(axis, rho)).collect();
        let mut diag = Array1::<f64>::zeros(target.len());
        for row in 0..n_obs {
            let base = row * d;
            for (axis, &tau) in taus.iter().enumerate() {
                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
                diag[base + axis] = self.psd_hessian_diag_entry(tau, gate);
            }
        }
        Some(diag)
    }

    /// Per-axis derivative with respect to `rho[axis]`.
    ///
    /// `∂/∂τ [τ σ((x − τ)/ε)] = g − τ g(1 − g)/ε`, multiplied by `τ`.
    /// NOTE(review): the final factor `τ` assumes `∂τ/∂ρ = τ` (log-scale parameter in
    /// `resolve_learnable_weight`) — confirm.
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let d = self.latent_dim;
        let n_obs = jumprelu_rows(d, target.len(), rho);
        let mut out = Array1::<f64>::zeros(d);
        for axis in 0..d {
            let tau = self.threshold(axis, rho);
            let mut g_tau = 0.0;
            for row in 0..n_obs {
                let x = target[row * d + axis];
                let gate = self.sigmoid_gate((x - tau) / self.smoothing_eps);
                g_tau += gate - tau * gate * (1.0 - gate) / self.smoothing_eps;
            }
            out[axis] = self.weight * tau * g_tau;
        }
        out
    }

    fn rho_count(&self) -> usize {
        self.latent_dim
    }

    fn name(&self) -> &str {
        "jumprelu"
    }

    impl_scalar_apply_schedule!(weight);
}