use super::*;
/// Penalty on a flattened, row-major `n × k_max` matrix of feature-assignment
/// targets, modelled on a finite Indian Buffet Process.
///
/// Each target entry is squashed to `z = sigmoid(target / tau)`. Every column `k`
/// gets a clamped closed-form usage probability `pi_k = (m_k + a - 1) / (n + a - 1)`,
/// where `m_k` is the column sum of `z` and `a = alpha / k_max`. The penalty is
/// `weight` times the Bernoulli negative log-likelihood of `z` under `pi`, plus
/// `-ln a - (a - 1) ln pi_k` per column (the negative log-density of Beta(a, 1)).
#[derive(Debug, Clone)]
pub struct IBPAssignmentPenalty {
/// Number of columns per row; entries past the last complete row are ignored.
pub k_max: usize,
/// Base concentration; each column uses `a = alpha / k_max`.
/// When `learnable_alpha` is set it is combined with `rho[0]`.
pub alpha: f64,
/// Sigmoid temperature; refreshed from `temperature_schedule` when one is attached.
pub tau: f64,
/// Optional temperature schedule, read by `with_temperature_schedule` and
/// advanced by `apply_schedule`.
pub temperature_schedule: Option<GumbelTemperatureSchedule>,
/// Whether `alpha` is a trainable hyper-parameter driven by `rho[0]`.
pub learnable_alpha: bool,
/// Multiplier applied to the value and to every derivative.
pub weight: f64,
/// Optional schedule for `weight`, advanced in `apply_schedule`.
pub weight_schedule: Option<ScalarWeightSchedule>,
}
impl IBPAssignmentPenalty {
#[must_use]
pub fn new(k_max: usize, alpha: f64, tau: f64, learnable_alpha: bool) -> Self {
assert!(k_max > 0);
assert!(alpha.is_finite() && alpha > 0.0);
assert!(tau.is_finite() && tau > 0.0);
Self {
k_max,
alpha,
tau,
temperature_schedule: None,
learnable_alpha,
weight: 1.0,
weight_schedule: None,
}
}
#[must_use]
pub fn with_temperature_schedule(mut self, schedule: GumbelTemperatureSchedule) -> Self {
self.tau = schedule.current_tau(schedule.iter_count);
self.temperature_schedule = Some(schedule);
self
}
impl_with_weight_schedule!(weight);
fn resolved_alpha(&self, rho: ArrayView1<'_, f64>) -> f64 {
if self.learnable_alpha {
resolve_learnable_weight(self.alpha, rho[0])
} else {
self.alpha
}
}
fn concrete_temperature(&self) -> f64 {
self.tau
}
fn concrete_logits(&self, target: ArrayView1<'_, f64>) -> Array1<f64> {
let tau = self.concrete_temperature();
let mut out = Array1::<f64>::zeros(target.len());
for i in 0..target.len() {
let x = target[i] / tau;
out[i] = if x >= 0.0 {
1.0 / (1.0 + (-x).exp())
} else {
let ex = x.exp();
ex / (1.0 + ex)
};
}
out
}
fn pi_map(&self, z: ArrayView1<'_, f64>, alpha: f64) -> Array1<f64> {
let n = z.len() / self.k_max;
let a = alpha / self.k_max as f64;
let mut pi = Array1::<f64>::zeros(self.k_max);
for k in 0..self.k_max {
let mut active_mass = 0.0;
for row in 0..n {
active_mass += z[row * self.k_max + k];
}
let denom = (n as f64 + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
let raw = (active_mass + a - 1.0) / denom;
pi[k] = raw.clamp(IBP_INTERIOR_TOL, 1.0 - IBP_INTERIOR_TOL);
}
pi
}
#[must_use]
pub fn hessian_diag_logit_third_channels(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> IbpHessianDiagThirdChannels {
let alpha = self.resolved_alpha(rho);
let a = alpha / self.k_max as f64;
let tau = self.concrete_temperature();
let inv_tau = 1.0 / tau;
let inv_tau2 = inv_tau * inv_tau;
let z = self.concrete_logits(target);
let pi = self.pi_map(z.view(), alpha);
let n = z.len() / self.k_max;
let denom = (n as f64 + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
let mut active_mass = Array1::<f64>::zeros(self.k_max);
for row in 0..n {
let start = row * self.k_max;
for k in 0..self.k_max {
active_mass[k] += z[start + k];
}
}
let mut score = Array1::<f64>::zeros(self.k_max);
let mut score_derivative = Array1::<f64>::zeros(self.k_max);
let mut score_second_derivative = Array1::<f64>::zeros(self.k_max);
for k in 0..self.k_max {
let pk = pi[k].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
let mass = active_mass[k];
let raw = (mass + a - 1.0) / denom;
let pi_jac = if raw > IBP_INTERIOR_TOL && raw < 1.0 - IBP_INTERIOR_TOL {
1.0 / denom
} else {
0.0
};
let bce_pi_score = -mass / pk + (n as f64 - mass) / (1.0 - pk);
let beta_pi_score = -(a - 1.0) / pk;
let pi_score = bce_pi_score + beta_pi_score;
let pi_score_derivative = -1.0 / pk + (mass + a - 1.0) * pi_jac / (pk * pk)
- 1.0 / (1.0 - pk)
+ (n as f64 - mass) * pi_jac / ((1.0 - pk) * (1.0 - pk));
let direct_z_score = ((1.0 - pk) / pk).ln();
let implicit_pi_score = pi_score * pi_jac;
score[k] = direct_z_score + implicit_pi_score;
let direct_z_score_derivative = pi_jac * (-1.0 / pk - 1.0 / (1.0 - pk));
score_derivative[k] = direct_z_score_derivative + pi_score_derivative * pi_jac;
let one_minus = 1.0 - pk;
let ddzd = pi_jac * pi_jac * (1.0 / (pk * pk) - 1.0 / (one_minus * one_minus));
let dpisd = 2.0 / (pk * pk)
- 2.0 * (mass + a - 1.0) * pi_jac / (pk * pk * pk)
- 2.0 / (one_minus * one_minus)
+ 2.0 * (n as f64 - mass) * pi_jac / (one_minus * one_minus * one_minus);
score_second_derivative[k] = ddzd + dpisd * pi_jac;
}
let len = target.len();
let mut z_jac = Array1::<f64>::zeros(len);
let mut local_logit_third = Array1::<f64>::zeros(len);
let mut m_channel = Array1::<f64>::zeros(len);
for row in 0..n {
let start = row * self.k_max;
for k in 0..self.k_max {
let zk = z[start + k];
let jac = zk * (1.0 - zk) * inv_tau;
let c_ik = zk * (1.0 - zk) * (1.0 - 2.0 * zk) * inv_tau2;
let dz_j = (1.0 - 2.0 * zk) * inv_tau;
let dz_c = (1.0 - 6.0 * zk + 6.0 * zk * zk) * inv_tau2;
let dz_h = score_derivative[k] * 2.0 * jac * dz_j + score[k] * dz_c;
z_jac[start + k] = jac;
local_logit_third[start + k] = self.weight * jac * dz_h;
m_channel[start + k] = self.weight
* (score_second_derivative[k] * jac * jac + score_derivative[k] * c_ik);
}
}
let mut cross_row_d = Array1::<f64>::zeros(self.k_max);
for k in 0..self.k_max {
cross_row_d[k] = self.weight * score_derivative[k];
}
IbpHessianDiagThirdChannels {
k_max: self.k_max,
z_jac,
local_logit_third,
m_channel,
cross_row_d,
}
}
#[must_use]
pub fn log_alpha_target_mixed_derivative(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(target.len());
if !self.learnable_alpha {
return out;
}
let alpha = self.resolved_alpha(rho);
let a = alpha / self.k_max as f64;
let tau = self.concrete_temperature();
let z = self.concrete_logits(target);
let pi = self.pi_map(z.view(), alpha);
let n = z.len() / self.k_max;
let denom = (n as f64 + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
let mut active_mass = Array1::<f64>::zeros(self.k_max);
for row in 0..n {
let start = row * self.k_max;
for k in 0..self.k_max {
active_mass[k] += z[start + k];
}
}
let mut pi_jac = Array1::<f64>::zeros(self.k_max);
for k in 0..self.k_max {
let raw = (active_mass[k] + a - 1.0) / denom;
if raw > IBP_INTERIOR_TOL && raw < 1.0 - IBP_INTERIOR_TOL {
pi_jac[k] = 1.0 / denom;
}
}
for row in 0..n {
let start = row * self.k_max;
for k in 0..self.k_max {
let zk = z[start + k];
let z_jac = zk * (1.0 - zk) / tau;
let pk = pi[k].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
out[start + k] = -self.weight * a * pi_jac[k] * z_jac / pk;
}
}
out
}
#[must_use]
pub fn hessian_diag_log_alpha_derivative(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(target.len());
if !self.learnable_alpha {
return out;
}
let alpha = self.resolved_alpha(rho);
let a = alpha / self.k_max as f64;
let tau = self.concrete_temperature();
let inv_tau = 1.0 / tau;
let inv_tau2 = inv_tau * inv_tau;
let z = self.concrete_logits(target);
let pi = self.pi_map(z.view(), alpha);
let n = z.len() / self.k_max;
let denom = (n as f64 + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
let mut active_mass = Array1::<f64>::zeros(self.k_max);
for row in 0..n {
let start = row * self.k_max;
for k in 0..self.k_max {
active_mass[k] += z[start + k];
}
}
let mut d_score = Array1::<f64>::zeros(self.k_max);
let mut d_score_derivative = Array1::<f64>::zeros(self.k_max);
for k in 0..self.k_max {
let pk = pi[k].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
let mass = active_mass[k];
let raw = (mass + a - 1.0) / denom;
if raw <= IBP_INTERIOR_TOL || raw >= 1.0 - IBP_INTERIOR_TOL {
continue;
}
let one_minus = 1.0 - pk;
let dpi_da = (n as f64 - mass) / (denom * denom);
let dpi_drho = a * dpi_da;
let d_score_dpi = -1.0 / pk - 1.0 / one_minus;
d_score[k] = d_score_dpi * dpi_drho;
let inv_p = 1.0 / pk;
let inv_q = 1.0 / one_minus;
let a_channel = inv_p + inv_q;
let d_a_channel_da = dpi_da * (-inv_p * inv_p + inv_q * inv_q);
let d_score_derivative_da = a_channel / (denom * denom) - d_a_channel_da / denom;
d_score_derivative[k] = a * d_score_derivative_da;
}
for row in 0..n {
let start = row * self.k_max;
for k in 0..self.k_max {
let zk = z[start + k];
let z_jac = zk * (1.0 - zk) * inv_tau;
let z_second = zk * (1.0 - zk) * (1.0 - 2.0 * zk) * inv_tau2;
out[start + k] =
self.weight * (d_score_derivative[k] * z_jac * z_jac + d_score[k] * z_second);
}
}
out
}
}
/// Per-entry channels returned by
/// `IBPAssignmentPenalty::hessian_diag_logit_third_channels`.
///
/// Notation:
/// * `s_k` is `∂value/∂z` for any entry of column `k`; it depends on `z` only
///   through that column.
/// * `s'_k` and `s''_k` are its first and second derivatives w.r.t. an entry of
///   that column.
/// * `z'`, `z''` and `z'''` are derivatives of `z = sigmoid(target / tau)`
///   w.r.t. the target entry.
///
/// The derivative channels already include the penalty `weight`; `z_jac` does not.
#[derive(Debug, Clone)]
pub struct IbpHessianDiagThirdChannels {
/// Number of columns per row (copied from the penalty).
pub k_max: usize,
/// `z'` for every entry (unweighted).
pub z_jac: Array1<f64>,
/// `weight * (2 s'_k z' z'' + s_k z''')` for every entry.
pub local_logit_third: Array1<f64>,
/// `weight * (s''_k z'^2 + s'_k z'')` for every entry.
///
/// By construction, `local_logit_third + m_channel * z_jac` is the derivative of
/// the weighted Hessian diagonal w.r.t. the same entry.
/// NOTE(review): confirm this is how consumers combine the channels.
pub m_channel: Array1<f64>,
/// `weight * s'_k`, one value per column.
pub cross_row_d: Array1<f64>,
}
/// `AnalyticPenalty` implementation for the IBP assignment penalty.
///
/// All target derivatives are chained through `z = sigmoid(target / tau)` and
/// the clamped per-column estimate `pi`. Entries beyond the last complete row of
/// `k_max` columns receive zero.
impl AnalyticPenalty for IBPAssignmentPenalty {
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }

    /// Weighted Bernoulli negative log-likelihood of `z` under `pi`, plus
    /// `-ln a - (a - 1) ln pi_k` for each column.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        let alpha = self.resolved_alpha(rho);
        let cols = self.k_max;
        let a = alpha / cols as f64;
        let z = self.concrete_logits(target);
        let pi = self.pi_map(z.view(), alpha);
        let covered = (z.len() / cols) * cols;
        // Per-column log-probabilities, clamped away from 0 and 1 before `ln`.
        let (log_on, log_off): (Vec<f64>, Vec<f64>) = pi
            .iter()
            .map(|&p| {
                let p = p.clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
                (p.ln(), (1.0 - p).ln())
            })
            .unzip();
        let mut total = 0.0_f64;
        for flat in 0..covered {
            let col = flat % cols;
            let zc = z[flat].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
            total -= zc * log_on[col] + (1.0 - zc) * log_off[col];
        }
        let log_a = a.ln();
        for &p in pi.iter() {
            total -= log_a;
            total -= (a - 1.0) * p.ln();
        }
        self.weight * total
    }

    /// Gradient w.r.t. the targets: the per-column score `∂value/∂z` times `z'`.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let alpha = self.resolved_alpha(rho);
        let cols = self.k_max;
        let a = alpha / cols as f64;
        let tau = self.concrete_temperature();
        let z = self.concrete_logits(target);
        let pi = self.pi_map(z.view(), alpha);
        let rows = z.len() / cols;
        let rows_f = rows as f64;
        let denom = (rows_f + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
        let column_score: Vec<f64> = (0..cols)
            .map(|col| {
                let mass = (0..rows).fold(0.0_f64, |acc, row| acc + z[row * cols + col]);
                let p = pi[col].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
                let raw = (mass + a - 1.0) / denom;
                let jac = if raw > IBP_INTERIOR_TOL && raw < 1.0 - IBP_INTERIOR_TOL {
                    1.0 / denom
                } else {
                    0.0
                };
                let likelihood_term = -mass / p + (rows_f - mass) / (1.0 - p);
                let prior_term = -(a - 1.0) / p;
                ((1.0 - p) / p).ln() + (likelihood_term + prior_term) * jac
            })
            .collect();
        let covered = rows * cols;
        Array1::from_shape_fn(target.len(), |flat| {
            if flat < covered {
                let zk = z[flat];
                self.weight * column_score[flat % cols] * zk * (1.0 - zk) / tau
            } else {
                0.0
            }
        })
    }

    /// Exact diagonal of the target Hessian: `s'_k z'^2 + s_k z''`, weighted.
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let alpha = self.resolved_alpha(rho);
        let cols = self.k_max;
        let a = alpha / cols as f64;
        let tau = self.concrete_temperature();
        let z = self.concrete_logits(target);
        let pi = self.pi_map(z.view(), alpha);
        let rows = z.len() / cols;
        let rows_f = rows as f64;
        let inv_tau2 = 1.0 / (tau * tau);
        let denom = (rows_f + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
        // Per column: (score s_k, its derivative s'_k w.r.t. one entry of the column).
        let curvature: Vec<(f64, f64)> = (0..cols)
            .map(|col| {
                let mass = (0..rows).fold(0.0_f64, |acc, row| acc + z[row * cols + col]);
                let p = pi[col].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
                let raw = (mass + a - 1.0) / denom;
                let jac = if raw > IBP_INTERIOR_TOL && raw < 1.0 - IBP_INTERIOR_TOL {
                    1.0 / denom
                } else {
                    0.0
                };
                let likelihood_term = -mass / p + (rows_f - mass) / (1.0 - p);
                let prior_term = -(a - 1.0) / p;
                let pi_score = likelihood_term + prior_term;
                let pi_score_slope = -1.0 / p + (mass + a - 1.0) * jac / (p * p)
                    - 1.0 / (1.0 - p)
                    + (rows_f - mass) * jac / ((1.0 - p) * (1.0 - p));
                let slope = ((1.0 - p) / p).ln() + pi_score * jac;
                let slope_change = jac * (-1.0 / p - 1.0 / (1.0 - p)) + pi_score_slope * jac;
                (slope, slope_change)
            })
            .collect();
        let covered = rows * cols;
        Some(Array1::from_shape_fn(target.len(), |flat| {
            if flat < covered {
                let (slope, slope_change) = curvature[flat % cols];
                let zk = z[flat];
                let dz = zk * (1.0 - zk) / tau;
                self.weight
                    * (slope_change * dz * dz
                        + slope * zk * (1.0 - zk) * (1.0 - 2.0 * zk) * inv_tau2)
            } else {
                0.0
            }
        }))
    }

    /// Hessian-vector product.
    ///
    /// Within column `k` every pair of entries couples through `s'_k` (a rank-one
    /// block `s'_k z' z'^T`). On top of that there is the diagonal term `s_k z''`.
    ///
    /// # Panics
    ///
    /// Panics if `v` and `target` have different lengths.
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(
            v.len(),
            target.len(),
            "IBPAssignmentPenalty::hvp dimension mismatch"
        );
        let alpha = self.resolved_alpha(rho);
        let cols = self.k_max;
        let a = alpha / cols as f64;
        let inv_tau = 1.0 / self.concrete_temperature();
        let inv_tau2 = inv_tau * inv_tau;
        let z = self.concrete_logits(target);
        let pi = self.pi_map(z.view(), alpha);
        let rows = z.len() / cols;
        let rows_f = rows as f64;
        let denom = (rows_f + a - 1.0).max(IBP_COUNT_DENOM_FLOOR);
        // Per column: (s_k, s'_k, projection Σ_i z'_i v_i of `v` onto the column).
        let per_col: Vec<(f64, f64, f64)> = (0..cols)
            .map(|col| {
                let mut mass = 0.0_f64;
                let mut projection = 0.0_f64;
                for row in 0..rows {
                    let idx = row * cols + col;
                    let zk = z[idx];
                    mass += zk;
                    projection += zk * (1.0 - zk) * inv_tau * v[idx];
                }
                let p = pi[col].clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP);
                let raw = (mass + a - 1.0) / denom;
                let jac = if raw > IBP_INTERIOR_TOL && raw < 1.0 - IBP_INTERIOR_TOL {
                    1.0 / denom
                } else {
                    0.0
                };
                let likelihood_term = -mass / p + (rows_f - mass) / (1.0 - p);
                let prior_term = -(a - 1.0) / p;
                let pi_score = likelihood_term + prior_term;
                let pi_score_slope = -1.0 / p + (mass + a - 1.0) * jac / (p * p)
                    - 1.0 / (1.0 - p)
                    + (rows_f - mass) * jac / ((1.0 - p) * (1.0 - p));
                let slope = ((1.0 - p) / p).ln() + pi_score * jac;
                let slope_change = jac * (-1.0 / p - 1.0 / (1.0 - p)) + pi_score_slope * jac;
                (slope, slope_change, projection)
            })
            .collect();
        let covered = rows * cols;
        Array1::from_shape_fn(target.len(), |flat| {
            if flat < covered {
                let (slope, slope_change, projection) = per_col[flat % cols];
                let zk = z[flat];
                let zjac = zk * (1.0 - zk) * inv_tau;
                let rank1 = slope_change * zjac * projection;
                let c_diag = slope * zk * (1.0 - zk) * (1.0 - 2.0 * zk) * inv_tau2;
                self.weight * (rank1 + c_diag * v[flat])
            } else {
                0.0
            }
        })
    }

    /// Gradient w.r.t. `rho`.
    ///
    /// Returns an empty array when `alpha` is not learnable. Otherwise it returns
    /// the single entry `-weight * (alpha * Σ_k ln pi_k / k_max + k_max)`.
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        if !self.learnable_alpha {
            return Array1::<f64>::zeros(0);
        }
        let alpha = self.resolved_alpha(rho);
        let z = self.concrete_logits(target);
        let sum_log_pi = self
            .pi_map(z.view(), alpha)
            .iter()
            .fold(0.0_f64, |acc, &p| {
                acc + p
                    .clamp(IBP_PROBABILITY_CLAMP, 1.0 - IBP_PROBABILITY_CLAMP)
                    .ln()
            });
        let cols = self.k_max as f64;
        Array1::from_elem(1, -self.weight * (alpha * sum_log_pi / cols + cols))
    }

    /// One hyper-parameter when `alpha` is learnable, none otherwise.
    fn rho_count(&self) -> usize {
        self.learnable_alpha.into()
    }

    fn name(&self) -> &str {
        "ibp_assignment_map"
    }

    /// Refreshes `tau` from the temperature schedule (if any) and records the
    /// next iteration on it, then advances the weight schedule.
    fn apply_schedule(&mut self, iter: usize) {
        if let Some(schedule) = &mut self.temperature_schedule {
            self.tau = schedule.current_tau(iter);
            schedule.iter_count = iter + 1;
        }
        advance_scalar_weight(&mut self.weight, &mut self.weight_schedule, iter);
    }
}