/// Construction and numerical kernels for `DecoderIncoherencePenalty`.
///
/// The penalised slice is a concatenation of `K = block_sizes.len()` atom blocks. Block `j`
/// occupies `block_sizes[j] * p_out` consecutive entries and is read row-major as a
/// `(block_sizes[j], p_out)` matrix `W_j`. For every pair `j < k` the cross-Gram
/// `C_jk = W_j W_kᵀ` is formed, and its squared Frobenius norm is weighted by the symmetrised
/// coactivation `0.5 * (A[j, k] + A[k, j])`.
impl DecoderIncoherencePenalty {
    /// Builds a decoder-incoherence penalty over a flattened stack of atom blocks.
    ///
    /// * `target` – parameter slice holding all blocks back to back.
    /// * `block_sizes` – number of rows `M_k` of each block (at least two blocks, all > 0).
    /// * `p_out` – number of columns shared by every block.
    /// * `coactivation` – `(K, K)` matrix of finite, non-negative pair weights.
    /// * `weight` – finite, strictly positive base weight.
    /// * `learnable_weight` – when true, the weight is resolved from `rho[rho_index]`.
    ///
    /// # Errors
    /// Returns `Err` when `target` is empty, `weight` is not finite and positive, `p_out == 0`,
    /// fewer than two blocks are given, `coactivation` is not `(K, K)` or has a non-finite or
    /// negative entry, a block size is zero, the total span overflows `usize`, or
    /// `Σ_k block_sizes[k] * p_out` differs from `target.len()`.
    #[must_use = "build error must be handled"]
    pub fn new(
        target: PsiSlice,
        block_sizes: Vec<usize>,
        p_out: usize,
        coactivation: Array2<f64>,
        weight: f64,
        learnable_weight: bool,
    ) -> Result<Self, String> {
        if target.is_empty() {
            return Err("DecoderIncoherencePenalty::new requires a non-empty target".to_string());
        }
        if !(weight.is_finite() && weight > 0.0) {
            return Err(format!(
                "DecoderIncoherencePenalty::new requires finite weight > 0, got {weight}"
            ));
        }
        if p_out == 0 {
            return Err("DecoderIncoherencePenalty::new requires p_out > 0".to_string());
        }
        if block_sizes.len() < 2 {
            return Err(
                "DecoderIncoherencePenalty::new requires at least two atom blocks".to_string(),
            );
        }
        let k = block_sizes.len();
        if coactivation.dim() != (k, k) {
            return Err(format!(
                "DecoderIncoherencePenalty::new requires (K, K)=({k}, {k}) coactivation; got {:?}",
                coactivation.dim()
            ));
        }
        if !coactivation
            .iter()
            .all(|value| value.is_finite() && *value >= 0.0)
        {
            return Err(
                "DecoderIncoherencePenalty::new requires finite non-negative coactivation entries"
                    .to_string(),
            );
        }
        // Checked arithmetic: the kernels below compute offsets with unchecked `m * p_out`
        // sums, so the whole layout must be proven representable here.
        let mut total = 0usize;
        for (atom_idx, &m) in block_sizes.iter().enumerate() {
            if m == 0 {
                return Err(format!(
                    "DecoderIncoherencePenalty::new block_sizes[{atom_idx}] must be > 0"
                ));
            }
            let span = m.checked_mul(p_out).ok_or_else(|| {
                "DecoderIncoherencePenalty::new block span overflows usize".to_string()
            })?;
            total = total.checked_add(span).ok_or_else(|| {
                "DecoderIncoherencePenalty::new total span overflows usize".to_string()
            })?;
        }
        if total != target.len() {
            return Err(format!(
                "DecoderIncoherencePenalty::new Σ_k M_k·p_out = {total} does not match target length {}",
                target.len()
            ));
        }
        Ok(Self {
            target,
            block_sizes,
            p_out,
            coactivation,
            weight,
            learnable_weight,
            rho_index: 0,
            weight_schedule: None,
        })
    }
    impl_with_weight_schedule!(weight);
    /// Effective weight at `rho`.
    ///
    /// If the weight is learnable, it is the base weight combined with `rho[self.rho_index]`
    /// through `resolve_learnable_weight`. Otherwise it is the fixed base weight.
    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
        if self.learnable_weight {
            resolve_learnable_weight(self.weight, rho[self.rho_index])
        } else {
            self.weight
        }
    }
    /// Start offset of each atom block inside the slice-local `target` view.
    ///
    /// Offsets are relative to this penalty's own slice, so the first block starts at 0.
    /// Every kernel indexes an `ArrayView1` whose length equals `self.target.len()` and writes
    /// into outputs of that length. Starting from the absolute `self.target.range.start`
    /// would index past the end of the view whenever the slice does not begin at 0.
    fn block_offsets(&self) -> Vec<usize> {
        let mut out = Vec::with_capacity(self.block_sizes.len());
        let mut cursor = 0usize;
        for &m in &self.block_sizes {
            out.push(cursor);
            cursor += m * self.p_out;
        }
        out
    }
    /// Symmetrised coactivation weight `0.5 * (A[j, k] + A[k, j])` of the block pair `(j, k)`.
    fn pair_weight(&self, j: usize, k: usize) -> f64 {
        0.5 * (self.coactivation[[j, k]] + self.coactivation[[k, j]])
    }
    /// Cross-Gram `C = W_j W_kᵀ` of two blocks, shape `(m_j, m_k)`.
    ///
    /// `C[a, b] = Σ_o target[off_j + a·p_out + o] · target[off_k + b·p_out + o]`.
    fn cross_gram(
        target: ArrayView1<'_, f64>,
        off_j: usize,
        m_j: usize,
        off_k: usize,
        m_k: usize,
        p_out: usize,
    ) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((m_j, m_k));
        for a in 0..m_j {
            for b in 0..m_k {
                let mut s = 0.0;
                for o in 0..p_out {
                    s += target[off_j + a * p_out + o] * target[off_k + b * p_out + o];
                }
                out[[a, b]] = s;
            }
        }
        out
    }
    /// Hessian-vector product shared by `hvp` and `psd_majorizer_hvp`.
    ///
    /// For each weighted pair `(j, k)` it forms the directional derivative
    /// `dC = V_j W_kᵀ + W_j V_kᵀ`. It then accumulates `w·dC·W_k` into block `j` and
    /// `w·dCᵀ·W_j` into block `k`.
    ///
    /// When `include_residual` is true, it also adds `w·C·V_k` and `w·Cᵀ·V_j`, which gives the
    /// exact Hessian product. When false, only the `dC`-driven terms are kept.
    ///
    /// Returns zeros when `target` does not have the configured slice length. `v` is assumed
    /// to match `target` in length; the public callers assert this.
    fn hvp_impl(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
        include_residual: bool,
    ) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(target.len());
        if target.len() != self.target.len() {
            return out;
        }
        let offsets = self.block_offsets();
        let k_atoms = self.block_sizes.len();
        let weight = self.resolved_weight(rho);
        let p_out = self.p_out;
        for j in 0..k_atoms {
            for k in (j + 1)..k_atoms {
                let w_pair = self.pair_weight(j, k) * weight;
                if w_pair == 0.0 {
                    continue;
                }
                let off_j = offsets[j];
                let off_k = offsets[k];
                let m_j = self.block_sizes[j];
                let m_k = self.block_sizes[k];
                // dC[a, b] = Σ_o (V_j[a,o]·W_k[b,o] + W_j[a,o]·V_k[b,o])
                let mut d_c = Array2::<f64>::zeros((m_j, m_k));
                for a in 0..m_j {
                    for b in 0..m_k {
                        let mut s = 0.0;
                        for o in 0..p_out {
                            s += v[off_j + a * p_out + o] * target[off_k + b * p_out + o]
                                + target[off_j + a * p_out + o] * v[off_k + b * p_out + o];
                        }
                        d_c[[a, b]] = s;
                    }
                }
                let c = if include_residual {
                    Some(Self::cross_gram(target, off_j, m_j, off_k, m_k, p_out))
                } else {
                    None
                };
                // Block j: (dC·W_k [+ C·V_k])[a, o]
                for a in 0..m_j {
                    for o in 0..p_out {
                        let mut s = 0.0;
                        for b in 0..m_k {
                            s += d_c[[a, b]] * target[off_k + b * p_out + o];
                            if let Some(c) = &c {
                                s += c[[a, b]] * v[off_k + b * p_out + o];
                            }
                        }
                        out[off_j + a * p_out + o] += w_pair * s;
                    }
                }
                // Block k: (dCᵀ·W_j [+ Cᵀ·V_j])[b, o]
                for b in 0..m_k {
                    for o in 0..p_out {
                        let mut s = 0.0;
                        for a in 0..m_j {
                            s += d_c[[a, b]] * target[off_j + a * p_out + o];
                            if let Some(c) = &c {
                                s += c[[a, b]] * v[off_j + a * p_out + o];
                            }
                        }
                        out[off_k + b * p_out + o] += w_pair * s;
                    }
                }
            }
        }
        out
    }
}
/// `AnalyticPenalty` implementation for `DecoderIncoherencePenalty`.
///
/// The value is `0.5 · w · Σ_{j<k} s_jk · ‖C_jk‖_F²`. Here `w` is the resolved weight,
/// `s_jk = pair_weight(j, k)`, and `C_jk = cross_gram(...)` of blocks `j` and `k`.
impl AnalyticPenalty for DecoderIncoherencePenalty {
    // This penalty is reported under the Beta tier.
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Beta
    }
    /// Penalty value at `target`; returns 0.0 when `target` has the wrong length.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        if target.len() != self.target.len() {
            return 0.0;
        }
        let offsets = self.block_offsets();
        let k_atoms = self.block_sizes.len();
        let mut acc = 0.0;
        // Only unordered pairs j < k are visited. Pairs whose symmetrised coactivation is
        // zero contribute nothing and are skipped.
        for j in 0..k_atoms {
            for k in (j + 1)..k_atoms {
                let w_pair = self.pair_weight(j, k);
                if w_pair == 0.0 {
                    continue;
                }
                let c = Self::cross_gram(
                    target,
                    offsets[j],
                    self.block_sizes[j],
                    offsets[k],
                    self.block_sizes[k],
                    self.p_out,
                );
                let mut frob_sq = 0.0;
                for &value in c.iter() {
                    frob_sq += value * value;
                }
                acc += w_pair * frob_sq;
            }
        }
        // The resolved weight is applied once, outside the pair loop.
        0.5 * self.resolved_weight(rho) * acc
    }
    /// Gradient with respect to `target`; zeros when `target` has the wrong length.
    ///
    /// For each pair the gradient is `w·s_jk·C·W_k` for block `j` and `w·s_jk·Cᵀ·W_j` for
    /// block `k`.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut grad = Array1::<f64>::zeros(target.len());
        if target.len() != self.target.len() {
            return grad;
        }
        let offsets = self.block_offsets();
        let k_atoms = self.block_sizes.len();
        let weight = self.resolved_weight(rho);
        for j in 0..k_atoms {
            for k in (j + 1)..k_atoms {
                let w_pair = self.pair_weight(j, k) * weight;
                if w_pair == 0.0 {
                    continue;
                }
                let off_j = offsets[j];
                let off_k = offsets[k];
                let m_j = self.block_sizes[j];
                let m_k = self.block_sizes[k];
                let c = Self::cross_gram(target, off_j, m_j, off_k, m_k, self.p_out);
                // Block j receives (C · W_k)[a, o].
                for a in 0..m_j {
                    for o in 0..self.p_out {
                        let mut s = 0.0;
                        for b in 0..m_k {
                            s += c[[a, b]] * target[off_k + b * self.p_out + o];
                        }
                        grad[off_j + a * self.p_out + o] += w_pair * s;
                    }
                }
                // Block k receives (Cᵀ · W_j)[b, o].
                for b in 0..m_k {
                    for o in 0..self.p_out {
                        let mut s = 0.0;
                        for a in 0..m_j {
                            s += c[[a, b]] * target[off_j + a * self.p_out + o];
                        }
                        grad[off_k + b * self.p_out + o] += w_pair * s;
                    }
                }
            }
        }
        grad
    }
    /// Exact Hessian-vector product (includes the terms proportional to `C`).
    ///
    /// # Panics
    /// Panics if `target` and `v` differ in length.
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
        self.hvp_impl(target, rho, v, true)
    }
    /// Majorizer product: the Hessian product without the `C`-proportional terms, keeping
    /// only the contributions driven by the directional derivative `dC`.
    ///
    /// # Panics
    /// Panics if `target` and `v` differ in length.
    fn psd_majorizer_hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(
            target.len(),
            v.len(),
            "psd_majorizer_hvp dimension mismatch"
        );
        self.hvp_impl(target, rho, v, false)
    }
    // Macro-generated learnable-weight helpers; the macro definitions live outside this chunk.
    impl_learnable_weight_grad_rho!();
    impl_learnable_weight_rho_count!();
    fn name(&self) -> &str {
        "decoder_incoherence"
    }
    impl_scalar_apply_schedule!(weight);
}
/// Soft orthogonality penalty `0.5 · (weight / n_eff) · ‖TᵀT − I‖_F²`.
///
/// `T` is the target slice viewed row-major as an `(n_obs, latent_dim)` matrix.
#[derive(Debug, Clone)]
pub struct OrthogonalityPenalty {
    // Parameter slice the penalty acts on; its length must be `n_obs * latent_dim`.
    pub target: PsiSlice,
    // Number of columns of `T` (must be > 0).
    pub latent_dim: usize,
    // Base weight; `new` requires it to be finite and > 0.
    pub weight: f64,
    // Normaliser in `scale = weight / n_eff`; `new` requires it to equal `n_obs`.
    pub n_eff: usize,
    // When true, the weight is resolved from `rho[rho_index]`.
    pub learnable_weight: bool,
    // Index into `rho` of the learnable weight (set to 0 by `new`).
    pub rho_index: usize,
    // Optional weight schedule; starts as `None` in `new`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
/// Construction and matrix kernels for `OrthogonalityPenalty`.
impl OrthogonalityPenalty {
    /// Builds an orthogonality penalty on a target slice viewed as a row-major
    /// `(n_obs, latent_dim)` matrix `T`.
    ///
    /// # Errors
    /// Returns `Err` when `latent_dim == 0`, the target length is not a multiple of
    /// `latent_dim`, `n_obs < latent_dim`, `weight` is not finite and positive, `n_eff == 0`,
    /// or `n_eff` differs from `n_obs`.
    #[must_use = "build error must be handled"]
    pub fn new(
        target: PsiSlice,
        latent_dim: usize,
        weight: f64,
        n_eff: usize,
        learnable_weight: bool,
    ) -> Result<Self, String> {
        if latent_dim == 0 {
            return Err("OrthogonalityPenalty::new requires latent_dim > 0".to_string());
        }
        if !target.len().is_multiple_of(latent_dim) {
            return Err(format!(
                "OrthogonalityPenalty::new target length {} is not divisible by latent_dim {}",
                target.len(),
                latent_dim
            ));
        }
        let n_obs = target.len() / latent_dim;
        // TᵀT = I is only attainable with at least `latent_dim` rows.
        if n_obs < latent_dim {
            return Err(format!(
                "OrthogonalityPenalty::new requires n_obs >= latent_dim for a feasible \
Stiefel target, got n_obs {n_obs} and latent_dim {latent_dim}"
            ));
        }
        if !(weight.is_finite() && weight > 0.0) {
            return Err(format!(
                "OrthogonalityPenalty::new requires finite weight > 0, got {weight}"
            ));
        }
        if n_eff == 0 {
            return Err("OrthogonalityPenalty::new requires n_eff > 0".to_string());
        }
        if n_eff != n_obs {
            return Err(format!(
                "OrthogonalityPenalty::new requires n_eff to match target rows, got \
n_eff {n_eff} and target rows {n_obs}"
            ));
        }
        Ok(Self {
            target,
            latent_dim,
            weight,
            n_eff,
            learnable_weight,
            rho_index: 0,
            weight_schedule: None,
        })
    }
    impl_with_weight_schedule!(weight);
    /// Effective weight: the base weight combined with `rho[self.rho_index]` when learnable,
    /// otherwise the fixed base weight.
    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
        if self.learnable_weight {
            resolve_learnable_weight(self.weight, rho[self.rho_index])
        } else {
            self.weight
        }
    }
    /// Overall scale `s = resolved_weight / n_eff` shared by value, gradient and Hessian.
    fn scale(&self, rho: ArrayView1<'_, f64>) -> f64 {
        self.resolved_weight(rho) / self.n_eff as f64
    }
    /// Reshapes a flat vector into the row-major `(len / latent_dim, latent_dim)` matrix view.
    ///
    /// Returns `None` if `into_shape_with_order` rejects the view (e.g. an incompatible
    /// memory layout).
    ///
    /// NOTE(review): inside the non-divisible branch the `assert_eq!` always fails. A length
    /// that is not a multiple of `latent_dim` therefore panics, and the `return None` after
    /// it is unreachable.
    fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
        let d = self.latent_dim;
        if !target.len().is_multiple_of(d) {
            assert_eq!(
                target.len() % d,
                0,
                "target length must be divisible by latent_dim"
            );
            return None;
        }
        let n_obs = target.len() / d;
        target.into_shape_with_order((n_obs, d)).ok()
    }
    /// Computes the `d × d` residual `M = TᵀT − I`.
    fn gram_minus_identity(t: ArrayView2<'_, f64>) -> Array2<f64> {
        let n_obs = t.nrows();
        let d = t.ncols();
        let mut gram = Array2::<f64>::zeros((d, d));
        for a in 0..d {
            for b in 0..d {
                let mut s = 0.0;
                for n in 0..n_obs {
                    s += t[[n, a]] * t[[n, b]];
                }
                gram[[a, b]] = s;
            }
            gram[[a, a]] -= 1.0;
        }
        gram
    }
    /// Flattens a matrix row-major (index `n * ncols + a`), matching `target_matrix`'s layout.
    fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
        let n_obs = m.nrows();
        let d = m.ncols();
        let mut out = Array1::<f64>::zeros(n_obs * d);
        for n in 0..n_obs {
            for a in 0..d {
                out[n * d + a] = m[[n, a]];
            }
        }
        out
    }
    /// Exact Hessian-vector product of `0.5·s·‖TᵀT − I‖_F²` in matrix form.
    ///
    /// The product is `2·s·(V·M + T·(VᵀT + TᵀV))`, where `M = TᵀT − I` is precomputed by
    /// the caller.
    ///
    /// NOTE(review): the `if v.dim() != t.dim()` fallback is unreachable because the
    /// preceding `assert_eq!` already panics on a mismatch.
    fn hvp_with_precomputed_m(
        &self,
        t: ArrayView2<'_, f64>,
        m: ArrayView2<'_, f64>,
        v: ArrayView2<'_, f64>,
        scale: f64,
    ) -> Array2<f64> {
        let n_obs = t.nrows();
        let d = t.ncols();
        assert_eq!(v.dim(), t.dim(), "hvp matrix dimension mismatch");
        assert_eq!(m.dim(), (d, d), "precomputed gram dimension mismatch");
        if v.dim() != t.dim() {
            return Array2::<f64>::zeros((n_obs, d));
        }
        // VᵀT + TᵀV (d × d)
        let mut vt_t_plus_tt_v = Array2::<f64>::zeros((d, d));
        for c in 0..d {
            for b in 0..d {
                let mut s = 0.0;
                for n in 0..n_obs {
                    s += v[[n, c]] * t[[n, b]] + t[[n, c]] * v[[n, b]];
                }
                vt_t_plus_tt_v[[c, b]] = s;
            }
        }
        let mut out = Array2::<f64>::zeros((n_obs, d));
        for n in 0..n_obs {
            for b in 0..d {
                let mut va = 0.0;
                let mut tb = 0.0;
                for c in 0..d {
                    va += v[[n, c]] * m[[c, b]];
                    tb += t[[n, c]] * vt_t_plus_tt_v[[c, b]];
                }
                out[[n, b]] = 2.0 * scale * (va + tb);
            }
        }
        out
    }
    /// Materialises the Hessian used by `hvp_with_precomputed_m` as a dense matrix.
    ///
    /// The matrix has shape `(n_obs·d) × (n_obs·d)`, and flat index `row * d + col` addresses
    /// `T[row, col]`. Entry `((r1,c1),(r2,c2))` is
    /// `2·s·(T[r1,c2]·T[r2,c1] + δ_{r1r2}·M[c2,c1] + δ_{c1c2}·⟨T[r1,:], T[r2,:]⟩)`.
    /// Memory is quadratic in `n_obs·d`.
    ///
    /// NOTE(review): the `if m.dim() != (d, d)` fallback is unreachable after the
    /// `assert_eq!`.
    fn as_dense_with_precomputed_m(
        &self,
        t: ArrayView2<'_, f64>,
        m: ArrayView2<'_, f64>,
        scale: f64,
    ) -> Array2<f64> {
        let n_obs = t.nrows();
        let d = t.ncols();
        assert_eq!(m.dim(), (d, d), "precomputed gram dimension mismatch");
        if m.dim() != (d, d) {
            return Array2::<f64>::zeros((n_obs * d, n_obs * d));
        }
        let mut dense = Array2::<f64>::zeros((n_obs * d, n_obs * d));
        let factor = 2.0 * scale;
        for row1 in 0..n_obs {
            for row2 in 0..n_obs {
                // ⟨T[row1, :], T[row2, :]⟩, reused for every column pair with col1 == col2.
                let mut row_dot = 0.0;
                for axis in 0..d {
                    row_dot += t[[row1, axis]] * t[[row2, axis]];
                }
                for col1 in 0..d {
                    let i = row1 * d + col1;
                    for col2 in 0..d {
                        let j = row2 * d + col2;
                        let mut entry = t[[row1, col2]] * t[[row2, col1]];
                        if row1 == row2 {
                            entry += m[[col2, col1]];
                        }
                        if col1 == col2 {
                            entry += row_dot;
                        }
                        dense[[i, j]] = factor * entry;
                    }
                }
            }
        }
        dense
    }
}
/// `AnalyticPenalty` implementation for `OrthogonalityPenalty`.
///
/// NOTE(review): no `psd_majorizer_hvp` override is visible here, although the exact Hessian
/// is indefinite. Confirm what the trait default provides for this penalty.
impl AnalyticPenalty for OrthogonalityPenalty {
    // This penalty is reported under the Psi tier.
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }
    /// `0.5 · s · ‖TᵀT − I‖_F²` with `s = scale(rho)`; 0.0 if `target_matrix` yields `None`.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        let Some(t) = self.target_matrix(target) else {
            return 0.0;
        };
        let gram = Self::gram_minus_identity(t.view());
        let mut acc = 0.0;
        for &v in gram.iter() {
            acc += v * v;
        }
        0.5 * self.scale(rho) * acc
    }
    /// Gradient `2·s·T·M` (with `M = TᵀT − I`), flattened row-major.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let Some(t) = self.target_matrix(target) else {
            return Array1::<f64>::zeros(target.len());
        };
        let gram = Self::gram_minus_identity(t.view());
        let n_obs = t.nrows();
        let d = t.ncols();
        let factor = 2.0 * self.scale(rho);
        let mut grad = Array2::<f64>::zeros((n_obs, d));
        for n in 0..n_obs {
            for a in 0..d {
                let mut s = 0.0;
                for b in 0..d {
                    s += t[[n, b]] * gram[[b, a]];
                }
                grad[[n, a]] = factor * s;
            }
        }
        Self::flatten_matrix(&grad)
    }
    /// Exact Hessian-vector product, flattened row-major.
    ///
    /// # Panics
    /// Panics if `target` and `v` differ in length. The `if` guard after the assert is
    /// therefore unreachable.
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
        if target.len() != v.len() {
            return Array1::<f64>::zeros(target.len());
        }
        let Some(t) = self.target_matrix(target) else {
            return Array1::<f64>::zeros(target.len());
        };
        let Some(v_mat) = self.target_matrix(v) else {
            return Array1::<f64>::zeros(target.len());
        };
        let m = Self::gram_minus_identity(t.view());
        let hv = self.hvp_with_precomputed_m(t.view(), m.view(), v_mat.view(), self.scale(rho));
        Self::flatten_matrix(&hv)
    }
    // Macro-generated learnable-weight helpers; the macro definitions live outside this chunk.
    impl_learnable_weight_grad_rho!();
    impl_learnable_weight_rho_count!();
    fn name(&self) -> &str {
        "orthogonality"
    }
    impl_scalar_apply_schedule!(weight);
}
/// Owner of a type-erased analytic penalty shared through an `Arc`.
pub struct AnalyticPenaltyOp {
    pub penalty: Arc<dyn AnalyticPenalty>,
}
impl AnalyticPenaltyOp {
    /// Wraps `penalty`; the `Arc` is stored as-is, without cloning the penalty itself.
    #[must_use]
    pub fn new(penalty: Arc<dyn AnalyticPenalty>) -> Self {
        Self { penalty }
    }
}
/// Generates the closed `AnalyticPenaltyKind` enum and its static-dispatch wrappers.
///
/// It accepts a list of `register!(Variant, Type);` entries. The actual list is supplied by
/// `crate::analytic_penalty_registry!`, which is defined outside this chunk.
macro_rules! define_analytic_penalty_kind {
    ($(register!($variant:ident, $ty:ty);)*) => {
        // One variant per registered penalty type. Each holds its penalty behind an `Arc`, so
        // cloning the enum only bumps a reference count.
        #[derive(Clone, Debug)]
        pub enum AnalyticPenaltyKind {
            $($variant(Arc<$ty>),)*
        }
        impl AnalyticPenaltyKind {
            // `Arc::make_mut` clones the penalty first if the `Arc` is shared, so other
            // holders keep their pre-schedule copy.
            pub fn apply_schedule(&mut self, iter: usize) {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => Arc::make_mut(p).apply_schedule(iter),)*
                }
            }
            pub fn tier(&self) -> PenaltyTier {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => p.dispatch_tier(),)*
                }
            }
            pub fn rho_count(&self) -> usize {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => p.rho_count(),)*
                }
            }
            pub fn name(&self) -> &str {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => p.name(),)*
                }
            }
            // Static metadata read from each type's `PenaltyManifest` constants.
            pub fn kind_tag(&self) -> &'static str {
                match self {
                    $(AnalyticPenaltyKind::$variant(_) => <$ty as PenaltyManifest>::KIND_TAG,)*
                }
            }
            pub fn python_wrapper_name(&self) -> &'static str {
                match self {
                    $(AnalyticPenaltyKind::$variant(_) => <$ty as PenaltyManifest>::PYTHON_WRAPPER,)*
                }
            }
            pub fn is_row_block_diagonal(&self) -> bool {
                match self {
                    $(AnalyticPenaltyKind::$variant(_) => <$ty as PenaltyManifest>::ROW_BLOCK_DIAGONAL,)*
                }
            }
            // The numerical methods below forward to the `AnalyticPenalty` impl using fully
            // qualified syntax, so the concrete type's method is selected explicitly.
            pub fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::value(p, target, rho),)*
                }
            }
            pub fn grad_target(
                &self,
                target: ArrayView1<'_, f64>,
                rho: ArrayView1<'_, f64>,
            ) -> Array1<f64> {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::grad_target(p, target, rho),)*
                }
            }
            pub fn grad_rho(
                &self,
                target: ArrayView1<'_, f64>,
                rho: ArrayView1<'_, f64>,
            ) -> Array1<f64> {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::grad_rho(p, target, rho),)*
                }
            }
            pub fn hessian_diag(
                &self,
                target: ArrayView1<'_, f64>,
                rho: ArrayView1<'_, f64>,
            ) -> Option<Array1<f64>> {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::hessian_diag(p, target, rho),)*
                }
            }
            pub fn hvp(
                &self,
                target: ArrayView1<'_, f64>,
                rho: ArrayView1<'_, f64>,
                v: ArrayView1<'_, f64>,
            ) -> Array1<f64> {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::hvp(p, target, rho, v),)*
                }
            }
            pub fn psd_majorizer_diag(
                &self,
                target: ArrayView1<'_, f64>,
                rho: ArrayView1<'_, f64>,
            ) -> Option<Array1<f64>> {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::psd_majorizer_diag(p, target, rho),)*
                }
            }
            pub fn psd_majorizer_hvp(
                &self,
                target: ArrayView1<'_, f64>,
                rho: ArrayView1<'_, f64>,
                v: ArrayView1<'_, f64>,
            ) -> Array1<f64> {
                match self {
                    $(AnalyticPenaltyKind::$variant(p) => <$ty as AnalyticPenalty>::psd_majorizer_hvp(p, target, rho, v),)*
                }
            }
        }
    };
}
// Expands `define_analytic_penalty_kind!` with the crate-wide list of registered penalties.
crate::analytic_penalty_registry!(define_analytic_penalty_kind);
/// Accessors for the scalar weight carried by `Isometry` penalties.
impl AnalyticPenaltyKind {
    /// Returns the scalar weight of an `Isometry` penalty, or `None` for any other variant.
    pub(crate) fn isometry_scalar_weight(&self) -> Option<f64> {
        if let AnalyticPenaltyKind::Isometry(isometry) = self {
            Some(isometry.scalar_weight)
        } else {
            None
        }
    }
    /// Overwrites the scalar weight of an `Isometry` penalty; other variants are ignored.
    ///
    /// Uses `Arc::make_mut`, so a shared penalty is cloned before it is modified.
    pub(crate) fn set_isometry_scalar_weight(&mut self, weight: f64) {
        let AnalyticPenaltyKind::Isometry(isometry) = self else {
            return;
        };
        Arc::make_mut(isometry).scalar_weight = weight;
    }
}
/// Ordered collection of analytic penalties.
///
/// Registration order also fixes where each penalty's rho parameters sit in the
/// concatenated rho vector (see `rho_layout`).
#[derive(Clone, Default)]
pub struct AnalyticPenaltyRegistry {
    pub penalties: Vec<AnalyticPenaltyKind>,
}
/// Compact `Debug` output: only the number of registered penalties is printed.
impl std::fmt::Debug for AnalyticPenaltyRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let penalty_count = self.penalties.len();
        let mut builder = f.debug_struct("AnalyticPenaltyRegistry");
        builder.field("penalty_count", &penalty_count);
        builder.finish()
    }
}
impl AnalyticPenaltyRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Appends a penalty. Its rho parameters follow those of every earlier penalty.
    pub fn push(&mut self, p: AnalyticPenaltyKind) {
        self.penalties.push(p);
    }
    /// Total number of rho parameters across all registered penalties.
    pub fn total_rho_count(&self) -> usize {
        self.penalties.iter().map(|p| p.rho_count()).sum()
    }
    /// Applies every penalty's weight schedule for iteration `iter`.
    pub fn apply_weight_schedules(&mut self, iter: usize) {
        for penalty in &mut self.penalties {
            penalty.apply_schedule(iter);
        }
    }
    /// Scalar weights of all `Isometry` penalties, in registration order.
    pub fn isometry_scalar_weights(&self) -> Vec<f64> {
        self.penalties
            .iter()
            .filter_map(AnalyticPenaltyKind::isometry_scalar_weight)
            .collect()
    }
    /// Assigns `weights[i]` to the `i`-th registered `Isometry` penalty.
    ///
    /// The weight count is validated before anything is modified. A mismatched call
    /// therefore leaves the registry untouched, even if the panic is caught by the caller.
    ///
    /// # Panics
    /// Panics if `weights` is shorter or longer than the number of `Isometry` penalties.
    pub fn set_isometry_scalar_weights(&mut self, weights: &[f64]) {
        let expected = self
            .penalties
            .iter()
            .filter(|p| p.isometry_scalar_weight().is_some())
            .count();
        assert!(
            weights.len() >= expected,
            "set_isometry_scalar_weights received fewer weights than registered isometry penalties"
        );
        assert_eq!(
            expected,
            weights.len(),
            "set_isometry_scalar_weights received extra weights"
        );
        let isometries = self
            .penalties
            .iter_mut()
            .filter(|p| p.isometry_scalar_weight().is_some());
        for (penalty, &weight) in isometries.zip(weights) {
            penalty.set_isometry_scalar_weight(weight);
        }
    }
    /// For each penalty in registration order, returns the half-open range of its parameters
    /// within the concatenated rho vector, together with its tier and name.
    pub fn rho_layout(&self) -> Vec<(std::ops::Range<usize>, PenaltyTier, &str)> {
        let mut out = Vec::with_capacity(self.penalties.len());
        let mut offset = 0usize;
        for p in &self.penalties {
            let n = p.rho_count();
            out.push((offset..offset + n, p.tier(), p.name()));
            offset += n;
        }
        out
    }
}
/// A penalty frozen at fixed `target` and `rho` vectors and exposed as a `PenaltyOp`.
///
/// `matvec` applies the penalty's `psd_majorizer_hvp` at the frozen point. `diag`,
/// `as_dense` and `log_det_plus_lambda_i` choose a strategy per variant in the `PenaltyOp`
/// impl.
pub struct FrozenAnalyticPenaltyOp {
    penalty: AnalyticPenaltyKind,
    target: Array1<f64>,
    rho: Array1<f64>,
}
// Above this operator dimension, guarded variants switch from exact column probing / dense
// assembly to the stochastic Hutchinson (diagonal) and SLQ (log-determinant) paths.
const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
// Number of Rademacher probes in the Hutchinson diagonal estimate.
const HUTCHINSON_DIAG_SAMPLES: usize = 32;
// Probe count and Lanczos step cap for stochastic Lanczos quadrature. Despite the names,
// `stochastic_log_det_plus_lambda_i` uses them for every variant routed there, not only
// for orthogonality.
const ORTHOGONALITY_LOGDET_SLQ_PROBES: usize = 16;
const ORTHOGONALITY_LOGDET_LANCZOS_STEPS: usize = 32;
impl FrozenAnalyticPenaltyOp {
    /// Freezes `penalty` at the given `target` and `rho` vectors, which are stored as-is.
    #[must_use]
    pub fn new(penalty: AnalyticPenaltyKind, target: Array1<f64>, rho: Array1<f64>) -> Self {
        Self { rho, target, penalty }
    }
    /// Borrows the wrapped penalty.
    pub fn penalty(&self) -> &AnalyticPenaltyKind {
        let Self { penalty, .. } = self;
        penalty
    }
}
/// `PenaltyOp` view of a frozen analytic penalty.
impl PenaltyOp for FrozenAnalyticPenaltyOp {
    // The operator acts on vectors with the frozen target's length.
    fn dim(&self) -> usize {
        self.target.len()
    }
    // Writes `psd_majorizer_hvp(target, rho, w)` into `out`.
    // NOTE(review): assumes the majorizer returns exactly `dim()` entries. A shorter result
    // would leave the tail of `out` untouched.
    fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        let h = self
            .penalty
            .psd_majorizer_hvp(self.target.view(), self.rho.view(), w);
        for i in 0..h.len() {
            out[i] = h[i];
        }
    }
    // Per-variant diagonal of the frozen operator.
    //
    // Variants with an analytic diagonal use it directly (`psd_majorizer_diag` or the
    // penalty's `diag_target`). The `expect`s panic if such a variant returns `None`. Other
    // variants either probe unit vectors (`diag_via_matvec`: `dim()` operator products) or,
    // above `ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD`, use the Hutchinson estimate in
    // `stochastic_diag_via_matvec`. A guarded arm must precede its unguarded counterpart for
    // the threshold switch to take effect.
    fn diag(&self) -> Array1<f64> {
        match &self.penalty {
            AnalyticPenaltyKind::Ard(p) => p
                .psd_majorizer_diag(self.target.view(), self.rho.view())
                .expect("ARD diag"),
            AnalyticPenaltyKind::TopKActivation(p) => p
                .psd_majorizer_diag(self.target.view(), self.rho.view())
                .expect("TopK activation diag"),
            AnalyticPenaltyKind::JumpReLU(p) => p
                .psd_majorizer_diag(self.target.view(), self.rho.view())
                .expect("JumpReLU majorizer diag"),
            AnalyticPenaltyKind::TotalVariation(p) => {
                p.diag_target(self.target.view(), self.rho.view())
            }
            AnalyticPenaltyKind::BlockOrthogonality(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::BlockOrthogonality(_) => self.diag_via_matvec(),
            AnalyticPenaltyKind::DecoderIncoherence(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::DecoderIncoherence(_) => self.diag_via_matvec(),
            // Orthogonality has a closed-form branch inside `diag_via_matvec`, so it never
            // needs the stochastic path here.
            AnalyticPenaltyKind::Orthogonality(_) => self.diag_via_matvec(),
            AnalyticPenaltyKind::NuclearNorm(_) => self.diag_via_matvec(),
            AnalyticPenaltyKind::BlockSparsity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::BlockSparsity(p) => {
                p.diag_target(self.target.view(), self.rho.view())
            }
            AnalyticPenaltyKind::MechanismSparsity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::MechanismSparsity(p) => {
                p.diag_target(self.target.view(), self.rho.view())
            }
            AnalyticPenaltyKind::RowPrecisionPrior(p) => {
                p.diag_target(self.target.view(), self.rho.view())
            }
            AnalyticPenaltyKind::IvaeRidgeMeanGauge(p) => {
                p.diag_target(self.target.view(), self.rho.view())
            }
            AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
                p.diag_target(self.target.view(), self.rho.view())
            }
            AnalyticPenaltyKind::ScadMcp(p) => p.diag_target(self.target.view(), self.rho.view()),
            AnalyticPenaltyKind::IBPAssignment(p) => p
                .psd_majorizer_diag(self.target.view(), self.rho.view())
                .expect("IBP assignment diag"),
            AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_) => self.diag_via_matvec(),
            // Sparsity may or may not provide an analytic diagonal; fall back to probing.
            AnalyticPenaltyKind::Sparsity(p) => {
                if let Some(d) = p.psd_majorizer_diag(self.target.view(), self.rho.view()) {
                    d
                } else {
                    self.diag_via_matvec()
                }
            }
            AnalyticPenaltyKind::Isometry(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::Isometry(_) => self.diag_via_matvec(),
            AnalyticPenaltyKind::NestedPrefix(p) => p
                .psd_majorizer_diag(self.target.view(), self.rho.view())
                .expect("NestedPrefix diag"),
            AnalyticPenaltyKind::SheafConsistency(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::SheafConsistency(_) => self.diag_via_matvec(),
            AnalyticPenaltyKind::Monotonicity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_diag_via_matvec()
            }
            AnalyticPenaltyKind::Monotonicity(_) => self.diag_via_matvec(),
        }
    }
    // Computes `log det(S + λ I)` for the frozen operator `S`, with a strategy per variant:
    // * diagonal-only variants sum `ln(diag_i + λ)`. This is exact only if their operator is
    //   diagonal — NOTE(review): confirm for each variant in that arm;
    // * some variants delegate to closed forms on the penalty type itself;
    // * above `ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD` most remaining variants use stochastic
    //   Lanczos quadrature (`stochastic_log_det_plus_lambda_i`);
    // * everything else materialises `as_dense()` and uses the dense `Array2` implementation.
    // Orthogonality is rejected because its exact Hessian is indefinite.
    // Errors: non-finite or non-positive `lambda`, a non-positive or non-finite shifted
    // diagonal entry, or any error from the delegated implementations.
    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
        if !(lambda.is_finite() && lambda > 0.0) {
            return Err(format!(
                "FrozenAnalyticPenaltyOp::log_det_plus_lambda_i requires finite λ > 0; got {lambda}"
            ));
        }
        match &self.penalty {
            AnalyticPenaltyKind::Ard(_)
            | AnalyticPenaltyKind::TopKActivation(_)
            | AnalyticPenaltyKind::JumpReLU(_)
            | AnalyticPenaltyKind::Sparsity(_)
            | AnalyticPenaltyKind::IBPAssignment(_)
            | AnalyticPenaltyKind::NestedPrefix(_) => {
                let d = self.diag();
                let mut s = 0.0;
                for &v in d.iter() {
                    let r = v + lambda;
                    if !r.is_finite() || r <= 0.0 {
                        return Err(format!(
                            "FrozenAnalyticPenaltyOp::log_det_plus_lambda_i: \
non-positive entry {r:.3e} after λ shift"
                        ));
                    }
                    s += r.ln();
                }
                Ok(s)
            }
            AnalyticPenaltyKind::TotalVariation(p) => match &p.difference_op {
                DifferenceOpKind::ForwardDiff1D => {
                    p.log_det_plus_lambda_i_forward_1d(self.target.view(), self.rho.view(), lambda)
                }
                DifferenceOpKind::GraphEdges(_)
                    if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
                {
                    self.stochastic_log_det_plus_lambda_i(lambda)
                }
                DifferenceOpKind::GraphEdges(_) => {
                    let dense = p.as_dense(self.target.view(), self.rho.view());
                    <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&dense, lambda)
                }
            },
            AnalyticPenaltyKind::Orthogonality(_) => Err(
                "FrozenAnalyticPenaltyOp::log_det_plus_lambda_i cannot treat \
OrthogonalityPenalty as PSD; its exact Hessian is indefinite"
                    .to_string(),
            ),
            AnalyticPenaltyKind::NuclearNorm(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::IvaeRidgeMeanGauge(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::RowPrecisionPrior(p) => {
                p.log_det_plus_lambda_i(self.rho.view(), lambda)
            }
            AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
                p.log_det_plus_lambda_i(self.rho.view(), lambda)
            }
            AnalyticPenaltyKind::ScadMcp(p) => {
                p.log_det_plus_lambda_i(self.target.view(), self.rho.view(), lambda)
            }
            AnalyticPenaltyKind::BlockSparsity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::MechanismSparsity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::BlockOrthogonality(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::DecoderIncoherence(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            // NOTE(review): Isometry always goes dense here regardless of `dim()`, unlike
            // `diag`, which switches to Hutchinson above the threshold. Dense assembly is
            // quadratic in memory; confirm this is intended for large slices.
            AnalyticPenaltyKind::Isometry(_) => {
                let dense = self.as_dense();
                <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&dense, lambda)
            }
            AnalyticPenaltyKind::SheafConsistency(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            AnalyticPenaltyKind::Monotonicity(_)
                if self.dim() > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD =>
            {
                self.stochastic_log_det_plus_lambda_i(lambda)
            }
            // Small-dimension fallthrough for every guarded variant above.
            AnalyticPenaltyKind::NuclearNorm(_)
            | AnalyticPenaltyKind::BlockSparsity(_)
            | AnalyticPenaltyKind::MechanismSparsity(_)
            | AnalyticPenaltyKind::IvaeRidgeMeanGauge(_)
            | AnalyticPenaltyKind::BlockOrthogonality(_)
            | AnalyticPenaltyKind::DecoderIncoherence(_)
            | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
            | AnalyticPenaltyKind::SheafConsistency(_)
            | AnalyticPenaltyKind::Monotonicity(_) => {
                let dense = self.as_dense();
                <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&dense, lambda)
            }
        }
    }
    // Dense `dim() × dim()` matrix of the frozen operator.
    //
    // Several variants provide `as_dense` directly. Orthogonality and Isometry are assembled
    // from their own Hessian kernels (`as_dense_with_precomputed_m` /
    // `hvp_with_precomputed_state`) instead of `psd_majorizer_hvp`. NOTE(review): for these
    // two the dense matrix therefore need not coincide with `matvec`. All remaining variants
    // are built column by column from `psd_majorizer_hvp` on unit vectors, which costs
    // `dim()` products and quadratic memory.
    fn as_dense(&self) -> Array2<f64> {
        match &self.penalty {
            AnalyticPenaltyKind::TotalVariation(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::BlockSparsity(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::MechanismSparsity(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::BlockOrthogonality(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::RowPrecisionPrior(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::IvaeRidgeMeanGauge(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
                return p.as_dense(self.target.view(), self.rho.view());
            }
            AnalyticPenaltyKind::Orthogonality(p) => {
                let n = self.target.len();
                let Some(t) = p.target_matrix(self.target.view()) else {
                    return Array2::<f64>::zeros((n, n));
                };
                let gram = OrthogonalityPenalty::gram_minus_identity(t.view());
                return p.as_dense_with_precomputed_m(
                    t.view(),
                    gram.view(),
                    p.scale(self.rho.view()),
                );
            }
            AnalyticPenaltyKind::Isometry(p) => {
                let n = self.target.len();
                let Some(state) = p.hvp_state(self.target.view()) else {
                    return Array2::<f64>::zeros((n, n));
                };
                // The precomputed state is reused for every unit-vector column.
                let mut dense = Array2::<f64>::zeros((n, n));
                let mut e = Array1::<f64>::zeros(n);
                for j in 0..n {
                    e[j] = 1.0;
                    let col = p.hvp_with_precomputed_state(&state, self.rho.view(), e.view());
                    for i in 0..n {
                        dense[[i, j]] = col[i];
                    }
                    e[j] = 0.0;
                }
                return dense;
            }
            _ => {}
        }
        // Generic fallback: column j = majorizer product with the j-th unit vector.
        let n = self.target.len();
        let mut m = Array2::<f64>::zeros((n, n));
        let mut e = Array1::<f64>::zeros(n);
        for j in 0..n {
            e[j] = 1.0;
            let col = self
                .penalty
                .psd_majorizer_hvp(self.target.view(), self.rho.view(), e.view());
            for i in 0..n {
                m[[i, j]] = col[i];
            }
            e[j] = 0.0;
        }
        m
    }
}
/// Diagonal and log-determinant helpers for `FrozenAnalyticPenaltyOp`.
impl FrozenAnalyticPenaltyOp {
    /// Exact diagonal of the operator.
    ///
    /// Orthogonality uses the closed form `2·s·(M[c,c] + T[r,c]² + ‖T[r,:]‖²)` of its exact
    /// Hessian. Isometry probes its precomputed-state HVP. Every other variant probes
    /// `psd_majorizer_hvp` with each unit vector, which costs `dim()` products.
    fn diag_via_matvec(&self) -> Array1<f64> {
        match &self.penalty {
            AnalyticPenaltyKind::Orthogonality(p) => {
                let n = self.target.len();
                let Some(t) = p.target_matrix(self.target.view()) else {
                    return Array1::<f64>::zeros(n);
                };
                let latent_dim = t.ncols();
                let gram = OrthogonalityPenalty::gram_minus_identity(t.view());
                let scale = p.scale(self.rho.view());
                let factor = 2.0 * scale;
                let mut diag = Array1::<f64>::zeros(n);
                for row in 0..t.nrows() {
                    let mut row_norm_sq = 0.0;
                    for col in 0..latent_dim {
                        row_norm_sq += t[[row, col]] * t[[row, col]];
                    }
                    for col in 0..latent_dim {
                        let i = row * latent_dim + col;
                        diag[i] = factor
                            * (gram[[col, col]] + t[[row, col]] * t[[row, col]] + row_norm_sq);
                    }
                }
                return diag;
            }
            AnalyticPenaltyKind::Isometry(p) => {
                let n = self.target.len();
                let Some(state) = p.hvp_state(self.target.view()) else {
                    return Array1::<f64>::zeros(n);
                };
                let mut d = Array1::<f64>::zeros(n);
                let mut e = Array1::<f64>::zeros(n);
                for i in 0..n {
                    e[i] = 1.0;
                    let h = p.hvp_with_precomputed_state(&state, self.rho.view(), e.view());
                    d[i] = h[i];
                    e[i] = 0.0;
                }
                return d;
            }
            _ => {}
        }
        let n = self.target.len();
        let mut d = Array1::<f64>::zeros(n);
        let mut e = Array1::<f64>::zeros(n);
        for i in 0..n {
            e[i] = 1.0;
            let h = self
                .penalty
                .psd_majorizer_hvp(self.target.view(), self.rho.view(), e.view());
            d[i] = h[i];
            e[i] = 0.0;
        }
        d
    }
    /// Hutchinson estimate of the diagonal: `mean_p z_p ⊙ (H z_p)` over
    /// `HUTCHINSON_DIAG_SAMPLES` deterministic ±1 Rademacher probes (seeded by probe index).
    ///
    /// Orthogonality and Isometry use their own HVP kernels. Every other variant uses
    /// `matvec`. NOTE(review): within this impl, `diag` never routes Orthogonality here.
    fn stochastic_diag_via_matvec(&self) -> Array1<f64> {
        match &self.penalty {
            AnalyticPenaltyKind::Orthogonality(p) => {
                let n = self.target.len();
                let Some(t) = p.target_matrix(self.target.view()) else {
                    return Array1::<f64>::zeros(n);
                };
                let gram = OrthogonalityPenalty::gram_minus_identity(t.view());
                let scale = p.scale(self.rho.view());
                let samples = HUTCHINSON_DIAG_SAMPLES.max(1);
                let mut diag = Array1::<f64>::zeros(n);
                let mut z = Array1::<f64>::zeros(n);
                for probe in 0..samples {
                    rademacher_unit_probe_into(z.view_mut(), probe as u64, 1.0);
                    let Some(z_mat) = p.target_matrix(z.view()) else {
                        return diag;
                    };
                    let hz = p.hvp_with_precomputed_m(t.view(), gram.view(), z_mat, scale);
                    // Map flat index i back to (row, col) of the row-major matrix.
                    for i in 0..n {
                        diag[i] += z[i] * hz[[i / t.ncols(), i % t.ncols()]];
                    }
                }
                let inv_samples = 1.0 / samples as f64;
                for i in 0..n {
                    diag[i] *= inv_samples;
                }
                return diag;
            }
            AnalyticPenaltyKind::Isometry(p) => {
                let n = self.target.len();
                let Some(state) = p.hvp_state(self.target.view()) else {
                    return Array1::<f64>::zeros(n);
                };
                let samples = HUTCHINSON_DIAG_SAMPLES.max(1);
                let mut diag = Array1::<f64>::zeros(n);
                let mut z = Array1::<f64>::zeros(n);
                for probe in 0..samples {
                    rademacher_unit_probe_into(z.view_mut(), probe as u64, 1.0);
                    let hz = p.hvp_with_precomputed_state(&state, self.rho.view(), z.view());
                    for i in 0..n {
                        diag[i] += z[i] * hz[i];
                    }
                }
                let inv_samples = 1.0 / samples as f64;
                for i in 0..n {
                    diag[i] *= inv_samples;
                }
                return diag;
            }
            _ => {}
        }
        let n = self.target.len();
        let samples = HUTCHINSON_DIAG_SAMPLES.max(1);
        let mut diag = Array1::<f64>::zeros(n);
        let mut z = Array1::<f64>::zeros(n);
        let mut hz = Array1::<f64>::zeros(n);
        for probe in 0..samples {
            rademacher_unit_probe_into(z.view_mut(), probe as u64, 1.0);
            self.matvec(z.view(), hz.view_mut());
            for i in 0..n {
                diag[i] += z[i] * hz[i];
            }
        }
        let inv_samples = 1.0 / samples as f64;
        for i in 0..n {
            diag[i] *= inv_samples;
        }
        diag
    }
    /// Stochastic Lanczos quadrature estimate of `log det(S + λ I)`.
    ///
    /// Each probe is a Rademacher vector scaled by `1/√n`, so it has unit norm. The estimate is
    /// `(n / probes) · Σ_p quad_p`, where `quad_p` is the Lanczos quadrature for probe `p`
    /// (presumably `q_pᵀ log(S+λI) q_p`). Returns `Ok(0.0)` for an empty operator.
    fn stochastic_log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
        let n = self.dim();
        if n == 0 {
            return Ok(0.0);
        }
        let probes = ORTHOGONALITY_LOGDET_SLQ_PROBES.max(1);
        let steps = ORTHOGONALITY_LOGDET_LANCZOS_STEPS.min(n).max(1);
        let inv_norm = 1.0 / (n as f64).sqrt();
        let mut estimate = 0.0;
        for probe in 0..probes {
            let mut q0 = Array1::<f64>::zeros(n);
            rademacher_unit_probe_into(q0.view_mut(), probe as u64, inv_norm);
            let quad = self.lanczos_log_quadrature(lambda, q0, steps)?;
            estimate += n as f64 * quad;
        }
        Ok(estimate / probes as f64)
    }
    /// Runs at most `max_steps` Lanczos steps on `S + λ I` from start vector `q`, then
    /// evaluates the log quadrature of the resulting eigenpairs.
    ///
    /// The operator is applied as `matvec` plus `λ·q`. No reorthogonalisation is performed.
    /// Errors propagate from a non-contiguous start vector, a Lanczos failure, or a
    /// non-SPD shifted operator.
    fn lanczos_log_quadrature(
        &self,
        lambda: f64,
        q: Array1<f64>,
        max_steps: usize,
    ) -> Result<f64, String> {
        let n = self.dim();
        let eigen = symmetric_lanczos_eigenpairs(
            n,
            q.as_slice().ok_or_else(|| {
                "FrozenAnalyticPenaltyOp::log_det_plus_lambda_i SLQ start vector is not contiguous"
                    .to_string()
            })?,
            SymmetricLanczosOptions {
                max_steps,
                residual_tol: 1e-12,
                local_reorthogonalize: false,
                full_reorthogonalize: false,
            },
            |q, out| {
                self.matvec(ArrayView1::from(q), ArrayViewMut1::from(&mut *out));
                for i in 0..n {
                    out[i] += lambda * q[i];
                }
                Ok(())
            },
        )
        .map_err(|e| {
            format!("FrozenAnalyticPenaltyOp::log_det_plus_lambda_i SLQ Lanczos failed: {e}")
        })?;
        symmetric_lanczos_log_quadrature(
            &eigen,
            "FrozenAnalyticPenaltyOp::log_det_plus_lambda_i expected SPD S+λI",
        )
    }
}
/// Fills `z` with a deterministic Rademacher vector whose entries are `±scale`.
///
/// The sign stream is seeded from `probe`, so a given probe index always reproduces the same
/// vector. Each `splitmix64` draw supplies the signs of 64 consecutive entries, least
/// significant bit first. A clear bit yields `+scale` and a set bit yields `-scale`.
fn rademacher_unit_probe_into(mut z: ArrayViewMut1<'_, f64>, probe: u64, scale: f64) {
    const BITS_PER_DRAW: usize = 64;
    let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
    let mut word = 0_u64;
    for (idx, entry) in z.iter_mut().enumerate() {
        let bit = idx % BITS_PER_DRAW;
        if bit == 0 {
            word = splitmix64(&mut state);
        }
        *entry = if (word >> bit) & 1 == 0 { scale } else { -scale };
    }
}
// Local forwarder to `crate::linalg::utils::splitmix64`. NOTE(review): presumably it
// advances `state` and returns the next SplitMix64 output — confirm in `linalg::utils`.
#[inline]
const fn splitmix64(state: &mut u64) -> u64 {
    crate::linalg::utils::splitmix64(state)
}
impl AnalyticPenaltyKind {
    /// Snapshots this penalty at `(target, rho)` as a shareable linear operator.
    ///
    /// The enum is cloned, which only bumps the `Arc` of the wrapped penalty. The clone is
    /// then wrapped in a `FrozenAnalyticPenaltyOp`.
    #[must_use]
    pub fn freeze(&self, target: Array1<f64>, rho: Array1<f64>) -> Arc<dyn PenaltyOp> {
        let frozen = FrozenAnalyticPenaltyOp::new(self.clone(), target, rho);
        Arc::new(frozen)
    }
}
/// Nested-prefix penalty: each entry of `prefix_sizes` selects a prefix of the latent axes
/// and carries its own shell weight and rho index.
#[derive(Debug, Clone)]
pub struct NestedPrefixPenalty {
    // Parameter slice the penalty acts on. When its `latent_dim` is set, it bounds the
    // largest prefix.
    pub target: PsiSlice,
    // Tier supplied at construction. NOTE(review): presumably what `tier()` reports; the
    // trait impl is outside this chunk.
    pub target_tier: PenaltyTier,
    // Strictly increasing, positive prefix lengths (validated in `new`).
    pub prefix_sizes: Vec<usize>,
    // One finite, non-negative base weight per prefix. `lambdas` combines each one with rho.
    pub shell_weights: Vec<f64>,
    // Positive smoothing constant. Per `new`'s error text, it regularises a
    // 1/sqrt(x²+ε²) term at 0.
    pub eps: f64,
    // Index into `rho` for each shell; `new` assigns 0..prefix_sizes.len().
    pub rho_indices: Vec<usize>,
    // Optional weight schedule; `None` after `new`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
impl NestedPrefixPenalty {
    /// Validate the configuration and build a nested-prefix penalty. Each shell gets
    /// its own `rho` slot (`rho_indices = 0..prefix_sizes.len()`), and no weight
    /// schedule is attached.
    ///
    /// # Errors
    /// Returns `Err` when any of the following holds:
    /// - `prefix_sizes` is empty;
    /// - `shell_weights` and `prefix_sizes` differ in length;
    /// - a shell weight is negative or non-finite;
    /// - a prefix is zero, or the prefixes are not strictly increasing (checked
    ///   position by position, positivity first);
    /// - the largest prefix exceeds `target.latent_dim`;
    /// - `eps` is not finite and positive.
    #[must_use = "build error must be handled"]
    pub fn new(
        target: PsiSlice,
        target_tier: PenaltyTier,
        prefix_sizes: Vec<usize>,
        shell_weights: Vec<f64>,
        eps: f64,
    ) -> Result<Self, String> {
        if prefix_sizes.is_empty() {
            return Err(String::from("NestedPrefixPenalty requires at least one prefix"));
        }
        let shell_count = prefix_sizes.len();
        if shell_weights.len() != shell_count {
            return Err(format!(
                "NestedPrefixPenalty requires shell_weights.len() == prefix_sizes.len(); \
                 got {} weights for {} prefixes",
                shell_weights.len(),
                shell_count
            ));
        }
        if let Some(bad) = shell_weights
            .iter()
            .find(|w| !(w.is_finite() && **w >= 0.0))
        {
            return Err(format!(
                "NestedPrefixPenalty shell weights must be finite and ≥ 0; got {bad}"
            ));
        }
        // One pass over the prefixes, testing positivity before monotonicity at each
        // position. `largest` ends holding the final prefix, which is the maximum.
        let mut largest: Option<usize> = None;
        for &size in &prefix_sizes {
            if size == 0 {
                return Err(String::from("NestedPrefixPenalty prefixes must be > 0"));
            }
            if largest.is_some_and(|prev| size <= prev) {
                return Err(format!(
                    "NestedPrefixPenalty prefixes must be strictly increasing; got {prefix_sizes:?}"
                ));
            }
            largest = Some(size);
        }
        if let (Some(d), Some(max_prefix)) = (target.latent_dim, largest) {
            if max_prefix > d {
                return Err(format!(
                    "NestedPrefixPenalty largest prefix {max_prefix} exceeds latent_dim {d}"
                ));
            }
        }
        if !(eps.is_finite() && eps > 0.0) {
            return Err(format!(
                "NestedPrefixPenalty requires eps > 0 (1/sqrt(x²+ε²) singularity at 0); got {eps}"
            ));
        }
        Ok(Self {
            target,
            target_tier,
            rho_indices: (0..shell_count).collect(),
            prefix_sizes,
            shell_weights,
            eps,
            weight_schedule: None,
        })
    }

    /// Attach the scalar weight schedule that `apply_schedule` uses to rescale the
    /// shell weights.
    #[must_use]
    pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self {
        Self {
            weight_schedule: Some(schedule),
            ..self
        }
    }

    /// Row width `F`: `target.latent_dim` when set, otherwise the last (largest) prefix.
    fn latent_dim(&self) -> usize {
        match self.target.latent_dim {
            Some(d) => d,
            None => *self.prefix_sizes.last().expect("non-empty"),
        }
    }

    /// Resolved shell weights
    /// `λ_k = resolve_learnable_weight(shell_weights[k], rho[rho_indices[k]])`,
    /// one per prefix.
    fn lambdas(&self, rho: ArrayView1<'_, f64>) -> Vec<f64> {
        (0..self.prefix_sizes.len())
            .map(|k| {
                let base = self.shell_weights[k];
                let rho_k = rho[self.rho_indices[k]];
                resolve_learnable_weight(base, rho_k)
            })
            .collect()
    }

    /// Per-axis weight `w_i = Σ_{k : i < m_k} λ_k`, accumulated shell by shell in
    /// order. Shells whose `λ_k` is exactly zero are skipped.
    fn per_axis_weights(&self, lambdas: &[f64]) -> Vec<f64> {
        let f = self.latent_dim();
        let mut weights = vec![0.0_f64; f];
        for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
            let lam = lambdas[k];
            if lam != 0.0 {
                for entry in &mut weights[..m_k.min(f)] {
                    *entry += lam;
                }
            }
        }
        weights
    }
}
impl AnalyticPenalty for NestedPrefixPenalty {
    fn tier(&self) -> PenaltyTier {
        self.target_tier
    }

    /// Penalty value `Σ_k λ_k · Σ_{i < m_k} Σ_n sqrt(x_{n,i}² + ε²)` over the
    /// row-major `n_rows × F` target.
    ///
    /// # Panics
    /// If `target.len()` is not a multiple of `F`, or `target` is not contiguous.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        let f = self.latent_dim();
        check_nested_prefix_target_len(target.len(), f);
        let lambdas = self.lambdas(rho);
        // Borrow the contiguous buffer once for all rows.
        let src = target.as_slice().expect("contiguous");
        let s_axis = nested_prefix_axis_sums(src, f, self.eps * self.eps);
        let mut total = 0.0;
        for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
            total += lambdas[k] * nested_prefix_shell_sum(&s_axis, m_k);
        }
        total
    }

    /// Gradient `w_i · x_{n,i} / sqrt(x_{n,i}² + ε²)`, where `w_i` sums `λ_k` over
    /// every shell whose prefix covers axis `i`.
    ///
    /// # Panics
    /// Under the same conditions as [`Self::value`], so value and gradient agree on
    /// which targets are valid.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let f = self.latent_dim();
        check_nested_prefix_target_len(target.len(), f);
        let lambdas = self.lambdas(rho);
        let w_per_axis = self.per_axis_weights(&lambdas);
        let eps2 = self.eps * self.eps;
        let src = target.as_slice().expect("contiguous");
        let mut g = Array1::<f64>::zeros(target.len());
        let g_slice = g.as_slice_mut().expect("contiguous");
        for (g_row, x_row) in g_slice.chunks_exact_mut(f).zip(src.chunks_exact(f)) {
            for ((g_i, &x), &w) in g_row.iter_mut().zip(x_row).zip(&w_per_axis) {
                // Axes outside every active shell keep an exact zero gradient.
                if w != 0.0 {
                    *g_i = w * x / (x * x + eps2).sqrt();
                }
            }
        }
        g
    }

    /// Exact diagonal Hessian `w_i · ε² / (x_{n,i}² + ε²)^{3/2}`. It is non-negative
    /// whenever `w_i ≥ 0`.
    ///
    /// # Panics
    /// Under the same conditions as [`Self::value`].
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let f = self.latent_dim();
        check_nested_prefix_target_len(target.len(), f);
        let lambdas = self.lambdas(rho);
        let w_per_axis = self.per_axis_weights(&lambdas);
        let eps2 = self.eps * self.eps;
        let src = target.as_slice().expect("contiguous");
        let mut d = Array1::<f64>::zeros(target.len());
        let d_slice = d.as_slice_mut().expect("contiguous");
        for (d_row, x_row) in d_slice.chunks_exact_mut(f).zip(src.chunks_exact(f)) {
            for ((d_i, &x), &w) in d_row.iter_mut().zip(x_row).zip(&w_per_axis) {
                if w != 0.0 {
                    let r = (x * x + eps2).sqrt();
                    *d_i = w * eps2 / (r * r * r);
                }
            }
        }
        Some(d)
    }

    /// Gradient with respect to `rho`. Each shell contributes `λ_k · S_k` to slot
    /// `rho_indices[k]`, where `S_k = Σ_{i < m_k} Σ_n sqrt(x_{n,i}² + ε²)`.
    ///
    /// Contributions are accumulated, so shells that share a `rho` index sum their
    /// terms instead of overwriting each other.
    ///
    /// # Panics
    /// Under the same conditions as [`Self::value`], or if a `rho_indices` entry is
    /// `≥ rho_count()`.
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let f = self.latent_dim();
        check_nested_prefix_target_len(target.len(), f);
        let lambdas = self.lambdas(rho);
        let src = target.as_slice().expect("contiguous");
        let s_axis = nested_prefix_axis_sums(src, f, self.eps * self.eps);
        let mut out = Array1::<f64>::zeros(self.rho_count());
        for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
            // Uses ∂λ_k/∂ρ = λ_k (log-scale weight, as for λ = w·exp(ρ)).
            // NOTE(review): confirm this matches `resolve_learnable_weight`, including
            // any saturation it applies at extreme ρ.
            out[self.rho_indices[k]] += lambdas[k] * nested_prefix_shell_sum(&s_axis, m_k);
        }
        out
    }

    fn rho_count(&self) -> usize {
        self.prefix_sizes.len()
    }

    fn name(&self) -> &str {
        "nested_prefix"
    }

    /// Rescale every shell weight by `current_weight(iter) / current_weight(iter_count)`.
    /// The rescale is skipped when the previous schedule weight is not positive.
    fn apply_schedule(&mut self, iter: usize) {
        if let Some(schedule) = self.weight_schedule.as_mut() {
            let prev = schedule.current_weight(schedule.iter_count);
            let next = schedule.current_weight(iter);
            if prev > 0.0 {
                let ratio = next / prev;
                for w in &mut self.shell_weights {
                    *w *= ratio;
                }
            }
            // NOTE(review): storing `iter + 1` means a following call with `iter + 1`
            // computes prev == next (ratio 1), assuming `current_weight` is a pure
            // function of its argument. Confirm this matches the convention used by
            // `impl_with_weight_schedule!` before changing it.
            schedule.iter_count = iter + 1;
        }
    }
}

/// Panic unless `len` splits into whole rows of width `f`. Every
/// `NestedPrefixPenalty` method reads its target as a row-major `n_rows × F` buffer.
#[track_caller]
fn check_nested_prefix_target_len(len: usize, f: usize) {
    assert!(len.is_multiple_of(f), "target length must be n_rows · F");
}

/// Column sums `S_i = Σ_n sqrt(x_{n,i}² + eps2)` of a row-major buffer whose rows
/// have width `f`. Rows are accumulated in order.
fn nested_prefix_axis_sums(src: &[f64], f: usize, eps2: f64) -> Vec<f64> {
    let mut s_axis = vec![0.0_f64; f];
    for row in src.chunks_exact(f) {
        for (acc, &x) in s_axis.iter_mut().zip(row) {
            *acc += (x * x + eps2).sqrt();
        }
    }
    s_axis
}

/// Left-to-right sum of the first `m_k` column sums, clamped to the available axes.
fn nested_prefix_shell_sum(s_axis: &[f64], m_k: usize) -> f64 {
    s_axis[..m_k.min(s_axis.len())]
        .iter()
        .fold(0.0, |acc, &v| acc + v)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::{array, s};
#[test]
fn isometry_value_is_decoder_scale_invariant() {
let n_obs = 2;
let d = 2;
let p = 3;
let target = PsiSlice::full(n_obs * d, Some(d));
let pen = IsometryPenalty::new_euclidean(target, p);
let mut j = Array2::<f64>::zeros((n_obs, p * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
j[[n, i * d + a]] = 0.4 + 0.2 * n as f64 - 0.1 * i as f64 + 0.3 * a as f64;
}
}
}
let mut j_scaled = j.clone();
for value in j_scaled.iter_mut() {
*value *= 17.0;
}
let rho = array![0.0_f64];
let t = Array1::<f64>::zeros(n_obs * d);
pen.set_jacobian_cache(Some(Arc::new(j)));
let value = pen.value(t.view(), rho.view());
pen.set_jacobian_cache(Some(Arc::new(j_scaled)));
let scaled_value = pen.value(t.view(), rho.view());
assert_abs_diff_eq!(value, scaled_value, epsilon = 1e-10);
}
#[test]
fn ard_value_matches_quadratic_form() {
let d = 2;
let t = array![0.5_f64, 1.0, 2.0, -1.0, 0.0, 3.0];
let target = PsiSlice::full(t.len(), Some(d));
let ard = ARDPenalty::new(target, d);
let rho = array![0.0_f64, 0.0]; let v = ard.value(t.view(), rho.view());
assert!((v - 0.5 * (4.25 + 11.0)).abs() < 1e-12);
}
#[test]
fn smoothed_l1_grad_smoothes_signum_at_zero() {
let p = SparsityPenalty::smoothed_l1(PenaltyTier::Beta, 1e-3)
.expect("positive eps builds smoothed L1 penalty");
let t = array![0.0_f64, 1.0, -2.0];
let rho = array![0.0_f64];
let g = p.grad_target(t.view(), rho.view());
assert!(g[0].abs() < 1e-9);
assert!((g[1] - 1.0).abs() < 1e-3);
assert!((g[2] - (-1.0)).abs() < 1e-3);
}
#[test]
fn softmax_assignment_hvp_matches_gradient_directional_derivative() {
let pen = SoftmaxAssignmentSparsityPenalty::new(3, 0.7);
let t = array![0.4_f64, -0.8, 1.3, -0.2, 0.9, 0.1];
let rho = array![1.4_f64.ln()];
let v = array![0.2_f64, -0.5, 0.7, -0.3, 0.4, 0.6];
let h_diag = pen
.hessian_diag(t.view(), rho.view())
.expect("softmax entropy diagonal is analytic via row-dense HVP at e_k");
for i in 0..t.len() {
let mut e_i = Array1::<f64>::zeros(t.len());
e_i[i] = 1.0;
let hv_i = pen.hvp(t.view(), rho.view(), e_i.view());
assert_abs_diff_eq!(h_diag[i], hv_i[i], epsilon = 1e-10);
}
let hv = pen.hvp(t.view(), rho.view(), v.view());
let eps = 1e-6;
let mut tp = t.clone();
let mut tm = t.clone();
for i in 0..t.len() {
tp[i] += eps * v[i];
tm[i] -= eps * v[i];
}
let gp = pen.grad_target(tp.view(), rho.view());
let gm = pen.grad_target(tm.view(), rho.view());
for i in 0..t.len() {
let fd = (gp[i] - gm[i]) / (2.0 * eps);
assert_abs_diff_eq!(hv[i], fd, epsilon = 1e-6);
}
}
#[test]
fn softmax_row_dense_hessian_matches_hvp_and_diagonal() {
let pen = SoftmaxAssignmentSparsityPenalty::new(4, 0.7);
let row = [0.4_f64, -0.8, 1.3, -0.2];
let lambda = 1.4_f64;
let rho = array![lambda.ln()];
let inv_tau2 = (1.0 / 0.7_f64) * (1.0 / 0.7_f64);
let scale = lambda * inv_tau2;
let h = pen.row_dense_hessian(&row, scale);
let full: Vec<f64> = row.to_vec();
let diag = pen
.hessian_diag(Array1::from_vec(full.clone()).view(), rho.view())
.expect("diag");
for k in 0..4 {
assert_abs_diff_eq!(h[[k, k]], diag[k], epsilon = 1e-10);
}
for i in 0..4 {
for j in 0..4 {
assert_abs_diff_eq!(h[[i, j]], h[[j, i]], epsilon = 1e-12);
}
}
let v = array![0.2_f64, -0.5, 0.7, -0.3];
let hv = pen.hvp(Array1::from_vec(full.clone()).view(), rho.view(), v.view());
for i in 0..4 {
let acc: f64 = (0..4).map(|j| h[[i, j]] * v[j]).sum();
assert_abs_diff_eq!(acc, hv[i], epsilon = 1e-9);
}
for i in 0..4 {
let row_sum: f64 = (0..4).map(|j| h[[i, j]]).sum();
assert_abs_diff_eq!(row_sum, 0.0, epsilon = 1e-10);
}
}
#[test]
fn softmax_row_dense_hessian_logit_derivative_matches_finite_difference() {
let pen = SoftmaxAssignmentSparsityPenalty::new(4, 0.8);
let row = [0.3_f64, -0.6, 0.9, 0.2];
let scale = 1.1_f64 * (1.0 / 0.8_f64) * (1.0 / 0.8_f64);
let eps = 1e-6;
for w in 0..4 {
let dh = pen.row_dense_hessian_logit_derivative(&row, scale, w);
let mut rp = row;
let mut rm = row;
rp[w] += eps;
rm[w] -= eps;
let hp = pen.row_dense_hessian(&rp, scale);
let hm = pen.row_dense_hessian(&rm, scale);
for i in 0..4 {
for j in 0..4 {
let fd = (hp[[i, j]] - hm[[i, j]]) / (2.0 * eps);
assert_abs_diff_eq!(dh[[i, j]], fd, epsilon = 1e-6);
}
}
}
}
#[test]
fn ibp_assignment_grad_target_matches_value_finite_difference() {
let pen = IBPAssignmentPenalty::new(4, 6.0, 0.8, false);
let t = array![
0.2_f64, -0.3, 0.7, -0.5, 0.9, 0.4, -0.2, 0.1, -0.4, 0.8, 0.3, -0.1
];
let rho = Array1::<f64>::zeros(0);
let g = pen.grad_target(t.view(), rho.view());
let eps = 1.0e-6;
let mut max_err = 0.0_f64;
for i in 0..t.len() {
let mut tp = t.clone();
let mut tm = t.clone();
tp[i] += eps;
tm[i] -= eps;
let fd =
(pen.value(tp.view(), rho.view()) - pen.value(tm.view(), rho.view())) / (2.0 * eps);
let err = (g[i] - fd).abs();
if err > max_err {
max_err = err;
}
assert_abs_diff_eq!(g[i], fd, epsilon = 1.0e-7);
}
assert!(
max_err < 1.0e-7,
"IBP grad-FD max abs error = {max_err:.3e}"
);
}
#[test]
fn ibp_cross_row_woodbury_d_matches_full_off_diagonal_hessian() {
let pen = IBPAssignmentPenalty::new(3, 5.0, 0.85, false);
let t = array![
0.3_f64, -0.2, 0.6, 0.5, 0.1, -0.4, -0.1, 0.7, 0.2, 0.4, -0.3, 0.8
];
let rho = Array1::<f64>::zeros(0);
let k = pen.k_max;
let n = t.len() / k;
let ch = pen.hessian_diag_logit_third_channels(t.view(), rho.view());
let eps = 1.0e-5;
let mut max_err = 0.0_f64;
let mixed_fd = |a: usize, b: usize| -> f64 {
let bump = |sa: f64, sb: f64| -> Array1<f64> {
let mut tt = t.clone();
tt[a] += sa * eps;
tt[b] += sb * eps;
tt
};
(pen.value(bump(1.0, 1.0).view(), rho.view())
- pen.value(bump(1.0, -1.0).view(), rho.view())
- pen.value(bump(-1.0, 1.0).view(), rho.view())
+ pen.value(bump(-1.0, -1.0).view(), rho.view()))
/ (4.0 * eps * eps)
};
for col in 0..k {
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let analytic =
ch.cross_row_d[col] * ch.z_jac[i * k + col] * ch.z_jac[j * k + col];
let fd = mixed_fd(i * k + col, j * k + col);
let err = (analytic - fd).abs();
if err > max_err {
max_err = err;
}
assert_abs_diff_eq!(analytic, fd, epsilon = 5.0e-5);
}
}
}
let mixed_distinct = mixed_fd(0, k + 1);
assert!(
mixed_distinct.abs() < 5.0e-5,
"distinct-column cross-row coupling must vanish; got {mixed_distinct:.3e}"
);
assert!(
max_err < 5.0e-5,
"IBP cross-row Woodbury d·z'·z' vs FD max abs error = {max_err:.3e}"
);
}
#[test]
fn ibp_assignment_learnable_alpha_grad_rho_matches_value_finite_difference() {
let pen = IBPAssignmentPenalty::new(3, 6.0, 0.8, true);
let t = array![
0.2_f64, -0.3, 0.7, -0.1, 0.4, 0.5, 0.6, -0.2, 0.3, 0.1, 0.8, -0.4
];
let rho = array![0.2_f64];
let grad = pen.grad_rho(t.view(), rho.view());
let step = 1.0e-6_f64;
let rho_plus = array![rho[0] + step];
let rho_minus = array![rho[0] - step];
let fd = (pen.value(t.view(), rho_plus.view()) - pen.value(t.view(), rho_minus.view()))
/ (2.0 * step);
assert_abs_diff_eq!(grad[0], fd, epsilon = 2.0e-7);
}
#[test]
fn ibp_assignment_learnable_alpha_mixed_log_alpha_target_matches_fd() {
let pen = IBPAssignmentPenalty::new(2, 2.0, 0.9, true);
let t = array![0.2_f64, -0.3, 0.7, -0.1, 0.4, 0.5];
let rho = array![0.15_f64];
let analytic = pen.log_alpha_target_mixed_derivative(t.view(), rho.view());
let step = 1.0e-6_f64;
for i in 0..t.len() {
let mut tp = t.clone();
let mut tm = t.clone();
tp[i] += step;
tm[i] -= step;
let gp = pen.grad_rho(tp.view(), rho.view())[0];
let gm = pen.grad_rho(tm.view(), rho.view())[0];
let fd = (gp - gm) / (2.0 * step);
assert_abs_diff_eq!(analytic[i], fd, epsilon = 2.0e-7);
}
}
#[test]
fn ibp_assignment_learnable_alpha_hdiag_log_alpha_derivative_matches_fd() {
let pen = IBPAssignmentPenalty::new(2, 2.0, 0.9, true);
let t = array![0.2_f64, -0.3, 0.7, -0.1, 0.4, 0.5];
let rho = array![0.15_f64];
let analytic = pen.hessian_diag_log_alpha_derivative(t.view(), rho.view());
let step = 1.0e-6_f64;
let rho_plus = array![rho[0] + step];
let rho_minus = array![rho[0] - step];
let hp = pen
.hessian_diag(t.view(), rho_plus.view())
.expect("IBP hessian diag exists");
let hm = pen
.hessian_diag(t.view(), rho_minus.view())
.expect("IBP hessian diag exists");
for i in 0..t.len() {
let fd = (hp[i] - hm[i]) / (2.0 * step);
assert_abs_diff_eq!(analytic[i], fd, epsilon = 2.0e-7);
}
}
#[test]
fn ibp_assignment_extreme_logits_remain_finite() {
let pen = IBPAssignmentPenalty::new(3, 1.5, 1.0e-3, false);
let t = array![
1000.0_f64, -1000.0, 500.0, -500.0, 750.0, -750.0, 250.0, -250.0, 0.0
];
let rho = Array1::<f64>::zeros(0);
let value = pen.value(t.view(), rho.view());
assert!(
value.is_finite(),
"IBP value must remain finite for saturated concrete logits"
);
let grad = pen.grad_target(t.view(), rho.view());
assert!(
grad.iter().all(|entry| entry.is_finite()),
"IBP gradient must remain finite for saturated concrete logits: {grad:?}"
);
let diag = pen
.hessian_diag(t.view(), rho.view())
.expect("IBP assignment exposes a diagonal Hessian");
assert!(
diag.iter().all(|entry| entry.is_finite()),
"IBP Hessian diagonal must remain finite for saturated concrete logits: {diag:?}"
);
}
#[test]
fn ibp_assignment_high_k_prior_keeps_positive_gradient_path() {
let k = 400usize;
let pen = IBPAssignmentPenalty::new(k, 0.1, 1.0, false);
let t = Array1::<f64>::zeros(k);
let rho = Array1::<f64>::zeros(0);
let value = pen.value(t.view(), rho.view());
assert!(value.is_finite(), "high-K IBP value must stay finite");
let grad = pen.grad_target(t.view(), rho.view());
assert_eq!(grad.len(), k);
assert!(
grad.iter().all(|entry| entry.is_finite()),
"high-K IBP gradient must stay finite: {grad:?}"
);
assert!(
grad.slice(s![320..]).iter().any(|entry| entry.abs() > 0.0),
"late high-K atoms must retain a strictly positive gradient path"
);
}
#[test]
fn learnable_weights_stay_finite_at_extreme_rho() {
for rho in [1000.0_f64, -1000.0] {
let resolved = resolve_learnable_weight(0.7, rho);
assert!(
resolved.is_finite() && resolved > 0.0,
"resolved learnable weight must be finite-positive at rho={rho}: {resolved}"
);
}
let softmax = SoftmaxAssignmentSparsityPenalty::new(3, 0.8);
let logits = array![0.2_f64, -0.1, 0.4];
for rho in [array![1000.0_f64], array![-1000.0_f64]] {
let value = softmax.value(logits.view(), rho.view());
let grad = softmax.grad_target(logits.view(), rho.view());
let diag = softmax
.hessian_diag(logits.view(), rho.view())
.expect("softmax entropy exposes a diagonal Hessian");
assert!(value.is_finite(), "softmax value non-finite at rho={rho:?}");
assert!(grad.iter().all(|entry| entry.is_finite()));
assert!(diag.iter().all(|entry| entry.is_finite()));
}
let jump =
JumpReLUPenalty::new(PsiSlice::full(2, Some(1)), array![1.0_f64], 0.5, 0.1).unwrap();
let jump_target = array![0.0_f64, 0.2];
for rho in [array![1000.0_f64], array![-1000.0_f64]] {
let value = jump.value(jump_target.view(), rho.view());
let grad = jump.grad_target(jump_target.view(), rho.view());
let diag = jump
.hessian_diag(jump_target.view(), rho.view())
.expect("JumpReLU exposes a diagonal Hessian");
assert!(
value.is_finite(),
"JumpReLU value non-finite at rho={rho:?}"
);
assert!(grad.iter().all(|entry| entry.is_finite()));
assert!(diag.iter().all(|entry| entry.is_finite()));
}
let target = PsiSlice {
range: 0..4,
latent_dim: Some(2),
};
let block_sizes = vec![1usize, 1usize];
let p = 2usize;
let coact = Array2::<f64>::ones((2, 2));
let decoder =
DecoderIncoherencePenalty::new(target, block_sizes, p, coact, 0.7, true).unwrap();
let beta = Array1::<f64>::zeros(4);
for rho in [array![1000.0_f64], array![-1000.0_f64]] {
let value = decoder.value(beta.view(), rho.view());
let grad = decoder.grad_target(beta.view(), rho.view());
let hv = decoder.hvp(beta.view(), rho.view(), beta.view());
assert!(
value.is_finite(),
"DecoderIncoherence value non-finite at rho={rho:?}"
);
assert!(grad.iter().all(|entry| entry.is_finite()));
assert!(hv.iter().all(|entry| entry.is_finite()));
}
}
#[test]
fn ard_grad_target_matches_lambda_t() {
let d = 2;
let t = array![0.5_f64, 1.0, 2.0, -1.0];
let target = PsiSlice::full(t.len(), Some(d));
let ard = ARDPenalty::new(target, d);
let rho = array![2.0_f64.ln(), 3.0_f64.ln()];
let g = ard.grad_target(t.view(), rho.view());
assert!((g[0] - 2.0 * 0.5).abs() < 1e-12);
assert!((g[2] - 2.0 * 2.0).abs() < 1e-12);
assert!((g[1] - 3.0 * 1.0).abs() < 1e-12);
assert!((g[3] - -3.0).abs() < 1e-12);
}
#[test]
fn ard_hessian_diag_matches_lambda() {
let d = 2;
let t = array![0.5_f64, 1.0, 2.0, -1.0];
let target = PsiSlice::full(t.len(), Some(d));
let ard = ARDPenalty::new(target, d);
let rho = array![2.0_f64.ln(), 3.0_f64.ln()];
let h = ard
.hessian_diag(t.view(), rho.view())
.expect("ARD has a diagonal Hessian");
assert!((h[0] - 2.0).abs() < 1e-12);
assert!((h[2] - 2.0).abs() < 1e-12);
assert!((h[1] - 3.0).abs() < 1e-12);
assert!((h[3] - 3.0).abs() < 1e-12);
}
fn isometry_gn_fixture() -> (usize, usize, usize, Arc<Array2<f64>>, Arc<Array2<f64>>) {
let (n_obs, p, d) = (3usize, 4usize, 2usize);
let mut j = Array2::<f64>::zeros((n_obs, p * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
j[[n, i * d + a]] = 0.7 + 0.31 * (n as f64) - 0.23 * (i as f64)
+ 0.17 * (a as f64)
+ 0.05 * ((n * p + i) as f64);
}
}
}
let mut h = Array2::<f64>::zeros((n_obs, p * d * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
for c in 0..d {
let s = (a + c) as f64;
let pr = (a * c) as f64;
h[[n, (i * d + a) * d + c]] =
0.13 * (n as f64 + 1.0) + 0.09 * (i as f64) + 0.21 * s - 0.04 * pr;
}
}
}
}
(n_obs, p, d, Arc::new(j), Arc::new(h))
}
#[test]
fn isometry_gn_majorizer_is_psd_and_symmetric() {
let (n_obs, p, d, j, h) = isometry_gn_fixture();
let n = n_obs * d;
let target = PsiSlice::full(n, Some(d));
let pen = IsometryPenalty::new_euclidean(target, p);
pen.refresh_caches(Some(j), Some(h));
let t = Array1::<f64>::zeros(n);
let rho = array![0.0_f64];
let mut bmat = Array2::<f64>::zeros((n, n));
for k in 0..n {
let mut e = Array1::<f64>::zeros(n);
e[k] = 1.0;
let col = pen.psd_majorizer_hvp(t.view(), rho.view(), e.view());
for r in 0..n {
bmat[[r, k]] = col[r];
}
}
for r in 0..n {
for c in 0..n {
assert_abs_diff_eq!(bmat[[r, c]], bmat[[c, r]], epsilon = 1e-12);
}
}
let probes = [
array![0.4_f64, -1.1, 0.7, 0.3, -0.5, 0.9],
array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0],
array![-2.3_f64, 0.6, -0.1, 1.4, 0.8, -1.7],
array![0.0_f64, 0.0, 3.2, -0.4, 0.0, 0.5],
];
for v in &probes {
let bv = pen.psd_majorizer_hvp(t.view(), rho.view(), v.view());
let quad = v.dot(&bv);
assert!(
quad >= -1e-9,
"isometry GN majorizer must be PSD; got vᵀBv = {quad:.3e}"
);
}
}
#[test]
fn isometry_gn_majorizer_matches_exact_hvp_at_zero_residual() {
let (n_obs, p, d, j, h) = isometry_gn_fixture();
let n = n_obs * d;
let target = PsiSlice::full(n, Some(d));
let scratch = IsometryPenalty::new_euclidean(target.clone(), p);
scratch.refresh_caches(Some(j.clone()), Some(h.clone()));
let mut g = scratch
.pullback_metric(d)
.expect("pullback metric available once J is cached");
let mut trace_sum = 0.0_f64;
for row in 0..g.nrows() {
for axis in 0..d {
trace_sum += g[[row, axis * d + axis]];
}
}
let normalizer = trace_sum / (g.nrows() * d) as f64;
for value in g.iter_mut() {
*value /= normalizer;
}
let k_zero = Arc::new(ndarray::Array3::<f64>::zeros((n_obs, p, d * d * d)));
let pen = IsometryPenalty::new_euclidean(target, p)
.with_reference(IsometryReference::UserSupplied(Arc::new(g)))
.with_third_decoder_derivative(k_zero);
pen.refresh_caches(Some(j), Some(h));
let t = Array1::<f64>::zeros(n);
let rho = array![0.0_f64];
let probes = [
array![0.4_f64, -1.1, 0.7, 0.3, -0.5, 0.9],
array![-2.3_f64, 0.6, -0.1, 1.4, 0.8, -1.7],
];
for v in &probes {
let exact = pen.hvp(t.view(), rho.view(), v.view());
let gn = pen.psd_majorizer_hvp(t.view(), rho.view(), v.view());
for i in 0..n {
assert_abs_diff_eq!(exact[i], gn[i], epsilon = 1e-12);
}
assert!(gn.iter().any(|x| x.abs() > 1e-9));
}
}
fn jumprelu_sweep_fixture() -> (
JumpReLUPenalty,
Array1<f64>,
Array1<f64>,
[f64; 2],
f64,
f64,
) {
let thresholds = array![0.25_f64, 0.8];
let rho = array![0.0_f64, 1.5_f64.ln()];
let eps = 0.04_f64;
let weight = 1.3_f64;
let scaled_thresholds = [thresholds[0] * rho[0].exp(), thresholds[1] * rho[1].exp()];
let latent_dim = thresholds.len();
let offsets = [-5.0_f64, -2.0, -0.5, -0.05, 0.0, 0.05, 0.5, 2.0, 5.0];
let mut values = Vec::with_capacity(offsets.len() * latent_dim);
for &offset in &offsets {
values.push(scaled_thresholds[0] + offset);
values.push(scaled_thresholds[1] + offset);
}
let target_values = Array1::from_vec(values);
let slice = PsiSlice::full(target_values.len(), Some(latent_dim));
let pen =
JumpReLUPenalty::new(slice, thresholds, weight, eps).expect("valid JumpReLU penalty");
(pen, target_values, rho, scaled_thresholds, eps, weight)
}
#[test]
fn jumprelu_hessian_diag_is_exact_true_second_derivative() {
let (pen, target_values, rho, scaled_thresholds, eps, weight) = jumprelu_sweep_fixture();
let latent_dim = scaled_thresholds.len();
let diag = pen
.hessian_diag(target_values.view(), rho.view())
.expect("JumpReLU exposes an analytic diagonal Hessian");
let mut saw_negative = false;
for (idx, &entry) in diag.iter().enumerate() {
let axis = idx % latent_dim;
let gate = pen.sigmoid_gate((target_values[idx] - scaled_thresholds[axis]) / eps);
let expected =
weight * scaled_thresholds[axis] * gate * (1.0 - gate) * (1.0 - 2.0 * gate)
/ (eps * eps);
assert!(
entry.is_finite(),
"JumpReLU hessian_diag must be finite at index {idx}; entry={entry}"
);
assert_abs_diff_eq!(entry, expected, epsilon = 1e-12);
if entry < 0.0 {
saw_negative = true;
}
}
assert!(
saw_negative,
"true JumpReLU hessian_diag must go negative once the gate passes g = ½"
);
}
#[test]
fn jumprelu_hvp_diagonal_matches_hessian_diag() {
let (pen, target_values, rho, _scaled_thresholds, _eps, _weight) = jumprelu_sweep_fixture();
let diag = pen
.hessian_diag(target_values.view(), rho.view())
.expect("JumpReLU exposes an analytic diagonal Hessian");
for i in 0..target_values.len() {
let mut e_i = Array1::<f64>::zeros(target_values.len());
e_i[i] = 1.0;
let hv_i = pen.hvp(target_values.view(), rho.view(), e_i.view());
assert_abs_diff_eq!(diag[i], hv_i[i], epsilon = 1e-12);
}
}
#[test]
fn jumprelu_psd_majorizer_diag_is_psd_over_logit_sweep() {
let (pen, target_values, rho, scaled_thresholds, eps, weight) = jumprelu_sweep_fixture();
let latent_dim = scaled_thresholds.len();
let diag = pen
.psd_majorizer_diag(target_values.view(), rho.view())
.expect("JumpReLU exposes a PSD diagonal majorizer");
let exact = pen
.hessian_diag(target_values.view(), rho.view())
.expect("JumpReLU exposes a closed-form diagonal Hessian");
for (idx, &entry) in diag.iter().enumerate() {
let axis = idx % latent_dim;
let gate = pen.sigmoid_gate((target_values[idx] - scaled_thresholds[axis]) / eps);
let slope = gate * (1.0 - gate);
let reweighted_l2 = slope * slope;
let abs_exact = slope * (1.0 - 2.0 * gate).abs();
let expected =
weight * scaled_thresholds[axis] * reweighted_l2.max(abs_exact) / (eps * eps);
assert!(
entry.is_finite() && entry >= 0.0,
"JumpReLU psd_majorizer_diag must be finite and PSD at index {idx}; entry={entry}"
);
assert_abs_diff_eq!(entry, expected, epsilon = 1e-12);
assert!(
entry + 1e-12 >= exact[idx],
"majorizer {entry} must dominate exact Hessian {} at index {idx}",
exact[idx]
);
}
}
#[test]
fn log_sparsity_hessian_is_exact_true_second_derivative() {
let delta = 0.5_f64;
let weight = 1.3_f64;
let log_lambda = 0.2_f64;
let lambda = weight * log_lambda.exp();
let d2 = delta * delta;
let pen = {
let mut p = SparsityPenalty::log(PenaltyTier::Psi, delta).expect("valid log sparsity");
p.weight = weight;
p
};
let target = array![0.0_f64, 0.25, 0.5, 1.0, 2.0, -2.0, -0.1];
let rho = array![log_lambda];
let diag = pen
.hessian_diag(target.view(), rho.view())
.expect("log sparsity exposes an analytic diagonal Hessian");
let mut saw_negative = false;
for (i, &x) in target.iter().enumerate() {
let denom = d2 + x * x;
let expected = 2.0 * lambda * (d2 - x * x) / (denom * denom);
assert_abs_diff_eq!(diag[i], expected, epsilon = 1e-12);
let mut e_i = Array1::<f64>::zeros(target.len());
e_i[i] = 1.0;
let hv_i = pen.hvp(target.view(), rho.view(), e_i.view());
assert_abs_diff_eq!(hv_i[i], expected, epsilon = 1e-12);
if diag[i] < 0.0 {
saw_negative = true;
}
}
assert!(
saw_negative,
"true log-sparsity Hessian must go negative once |x| > δ"
);
}
#[test]
fn log_sparsity_hessian_diag_matches_central_difference_of_gradient() {
let delta = 0.7_f64;
let weight = 0.9_f64;
let log_lambda = -0.3_f64;
let pen = {
let mut p = SparsityPenalty::log(PenaltyTier::Psi, delta).expect("valid log sparsity");
p.weight = weight;
p
};
let target = array![0.0_f64, 0.3, 0.7, 1.5, -1.8];
let rho = array![log_lambda];
let diag = pen
.hessian_diag(target.view(), rho.view())
.expect("log sparsity exposes an analytic diagonal Hessian");
let h = 1e-6_f64;
for i in 0..target.len() {
let mut tp = target.clone();
let mut tm = target.clone();
tp[i] += h;
tm[i] -= h;
let gp = pen.grad_target(tp.view(), rho.view());
let gm = pen.grad_target(tm.view(), rho.view());
let fd = (gp[i] - gm[i]) / (2.0 * h);
assert_abs_diff_eq!(diag[i], fd, epsilon = 1e-5);
}
}
#[test]
fn log_sparsity_psd_majorizer_diag_is_distinct_positive_operator() {
let delta = 0.5_f64;
let weight = 1.3_f64;
let log_lambda = 0.2_f64;
let lambda = weight * log_lambda.exp();
let d2 = delta * delta;
let pen = {
let mut p = SparsityPenalty::log(PenaltyTier::Psi, delta).expect("valid log sparsity");
p.weight = weight;
p
};
let target = array![0.0_f64, 0.25, 0.5, 1.0, 2.0, -2.0, -0.1];
let rho = array![log_lambda];
let maj = pen
.psd_majorizer_diag(target.view(), rho.view())
.expect("log sparsity exposes a PSD diagonal majorizer");
let exact = pen
.hessian_diag(target.view(), rho.view())
.expect("log sparsity exposes an analytic diagonal Hessian");
for (i, &x) in target.iter().enumerate() {
let expected = 2.0 * lambda / (d2 + x * x);
assert_abs_diff_eq!(maj[i], expected, epsilon = 1e-12);
assert!(
maj[i] >= 0.0,
"log-sparsity majorizer must be PSD at index {i}; entry={}",
maj[i]
);
assert!(maj[i] + 1e-12 >= exact[i]);
if x == 0.0 {
assert_abs_diff_eq!(maj[i], exact[i], epsilon = 1e-12);
}
}
}
#[test]
fn ard_rho_grad_includes_occam_log_det_term() {
let d = 2;
let t = array![1.0_f64, 0.0, 0.0, 2.0];
let n_obs = t.len() / d; let target = PsiSlice::full(t.len(), Some(d));
let ard = ARDPenalty::new(target, d);
assert!((ard.n_eff - n_obs as f64).abs() < 1e-12);
let rho = array![0.0_f64, 0.0];
let dr = ard.grad_rho(t.view(), rho.view());
assert!((dr[0] - (-0.5)).abs() < 1e-12);
assert!((dr[1] - 1.0).abs() < 1e-12);
}
fn block_ortho_test_target() -> Array1<f64> {
array![
1.0_f64, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 1.0, 0.0
]
}
#[test]
fn block_orthogonality_value_matches_offdiag_gram_frobenius() {
let t = block_ortho_test_target();
let target = PsiSlice::full(t.len(), Some(4));
let pen = BlockOrthogonalityPenalty::new(
target,
vec![vec![0_usize, 1], vec![2, 3]],
2.5,
4,
false,
)
.expect("valid block orthogonality penalty");
let rho = array![0.0_f64];
let v = pen.value(t.view(), rho.view());
assert!(v.is_finite(), "block-orthogonality value must be finite");
assert_abs_diff_eq!(v, 3.75, epsilon = 1e-12);
}
#[test]
fn block_orthogonality_grad_matches_finite_difference() {
let t = block_ortho_test_target();
let n = t.len();
let target = PsiSlice::full(n, Some(4));
let pen = BlockOrthogonalityPenalty::new(
target,
vec![vec![0_usize, 1], vec![2, 3]],
1.25,
4,
false,
)
.expect("valid block orthogonality penalty");
let rho = array![0.0_f64];
let g = pen.grad_target(t.view(), rho.view());
let eps = 1e-6;
let mut max_err = 0.0_f64;
for i in 0..n {
let mut tp = t.clone();
let mut tm = t.clone();
tp[i] += eps;
tm[i] -= eps;
let fd =
(pen.value(tp.view(), rho.view()) - pen.value(tm.view(), rho.view())) / (2.0 * eps);
let err = (g[i] - fd).abs();
if err > max_err {
max_err = err;
}
assert_abs_diff_eq!(g[i], fd, epsilon = 1e-6);
}
assert!(max_err < 1e-6, "grad-FD max abs error = {max_err:.3e}");
}
#[test]
fn block_orthogonality_hvp_matches_gradient_directional_derivative() {
let t = block_ortho_test_target();
let n = t.len();
let target = PsiSlice::full(n, Some(4));
let pen = BlockOrthogonalityPenalty::new(
target,
vec![vec![0_usize, 1], vec![2, 3]],
0.75,
4,
false,
)
.expect("valid block orthogonality penalty");
let rho = array![0.0_f64];
let v: Array1<f64> =
Array1::from_vec((0..n).map(|i| 0.3 * ((i as f64) + 1.0).sin()).collect());
let hv = pen.hvp(t.view(), rho.view(), v.view());
let eps = 1e-5;
let mut tp = t.clone();
let mut tm = t.clone();
for i in 0..n {
tp[i] += eps * v[i];
tm[i] -= eps * v[i];
}
let gp = pen.grad_target(tp.view(), rho.view());
let gm = pen.grad_target(tm.view(), rho.view());
let mut max_err = 0.0_f64;
for i in 0..n {
let fd = (gp[i] - gm[i]) / (2.0 * eps);
let err = (hv[i] - fd).abs();
if err > max_err {
max_err = err;
}
assert_abs_diff_eq!(hv[i], fd, epsilon = 1e-5);
}
assert!(max_err < 1e-5, "hvp-FD max abs error = {max_err:.3e}");
}
#[test]
fn block_orthogonality_hessian_diag_matches_finite_difference() {
let t = block_ortho_test_target();
let n = t.len();
let target = PsiSlice::full(n, Some(4));
let pen = BlockOrthogonalityPenalty::new(
target,
vec![vec![0_usize, 1], vec![2, 3]],
0.9,
4,
false,
)
.expect("valid block orthogonality penalty");
let rho = array![0.0_f64];
let diag = pen
.hessian_diag(t.view(), rho.view())
.expect("hessian_diag must be available");
assert_eq!(diag.len(), n);
let eps = 1e-5;
for i in 0..n {
let mut tp = t.clone();
let mut tm = t.clone();
tp[i] += eps;
tm[i] -= eps;
let gp = pen.grad_target(tp.view(), rho.view())[i];
let gm = pen.grad_target(tm.view(), rho.view())[i];
let fd = (gp - gm) / (2.0 * eps);
assert_abs_diff_eq!(diag[i], fd, epsilon = 1e-5);
assert!(
diag[i] >= 0.0,
"hessian_diag entry must be PSD; got {}",
diag[i]
);
}
}
#[test]
fn block_orthogonality_rejects_groups_missing_an_axis() {
let t = block_ortho_test_target();
let target = PsiSlice::full(t.len(), Some(4));
let err =
BlockOrthogonalityPenalty::new(target, vec![vec![0_usize, 1], vec![2]], 1.0, 4, false)
.expect_err("groups missing axis 3 must error");
assert!(
err.contains("must partition latent axes") && err.contains("missing axis 3"),
"unexpected error message: {err}"
);
}
fn mech_sparsity_test_target() -> Array1<f64> {
array![0.4_f64, -0.3, 0.2, -0.1, 0.6, 0.5]
}
fn build_mech_sparsity(weight: f64) -> MechanismSparsityPenalty {
let t = mech_sparsity_test_target();
let target = PsiSlice::full(t.len(), Some(2));
MechanismSparsityPenalty::new(
target,
vec![vec![0_usize, 1], vec![2]],
weight,
1e-2,
4.0,
false,
)
.expect("valid mechanism sparsity penalty")
}
/// Checks the mechanism-sparsity value (weight 1.5) against a hand-computed
/// sum of smoothed group norms.
///
/// Entries 0..3 and 3..6 of the target each contribute two norms:
/// - the norm of the two-feature group `{0, 1}`, scaled by `sqrt(2)`;
/// - the norm of the single-feature group `{2}`.
#[test]
fn mechanism_sparsity_value_matches_group_norm_sum() {
    let pen = build_mech_sparsity(1.5);
    let x = mech_sparsity_test_target();
    let rho = array![0.0_f64];
    let got = pen.value(x.view(), rho.view());
    let eps2 = 1e-2_f64 * 1e-2_f64;
    let pair_scale = 2.0_f64.sqrt();
    let first_half = pair_scale * (0.16_f64 + 0.09 + eps2).sqrt() + (0.04_f64 + eps2).sqrt();
    let second_half = pair_scale * (0.01_f64 + 0.36 + eps2).sqrt() + (0.25_f64 + eps2).sqrt();
    assert!(got.is_finite(), "mechanism-sparsity value must be finite");
    assert_abs_diff_eq!(got, 1.5 * (first_half + second_half), epsilon = 1e-12);
}
/// Checks `grad_target` of the mechanism-sparsity penalty against a central
/// finite difference of `value` at every coordinate.
///
/// Each coordinate is checked on its own, so a failure reports the index,
/// the analytic value, the finite-difference value and the absolute error.
#[test]
fn mechanism_sparsity_grad_matches_finite_difference() {
    let pen = build_mech_sparsity(0.8);
    let t = mech_sparsity_test_target();
    let n = t.len();
    let rho = array![0.0_f64];
    let g = pen.grad_target(t.view(), rho.view());
    let eps = 1e-6;
    let tol = 1e-6;
    for i in 0..n {
        let mut tp = t.clone();
        let mut tm = t.clone();
        tp[i] += eps;
        tm[i] -= eps;
        let fd =
            (pen.value(tp.view(), rho.view()) - pen.value(tm.view(), rho.view())) / (2.0 * eps);
        let analytic = g[i];
        let err = (analytic - fd).abs();
        // `<` also rejects NaN errors.
        assert!(
            err < tol,
            "grad-FD mismatch at index {i}: analytic={analytic:.6e} fd={fd:.6e} abs error={err:.3e}"
        );
    }
}
/// Checks the mechanism-sparsity Hessian-vector product `H·v` against a
/// central finite difference of `grad_target` along the direction `v`.
///
/// Each component is checked on its own, so a failure names the index.
#[test]
fn mechanism_sparsity_hvp_matches_gradient_directional_derivative() {
    let pen = build_mech_sparsity(0.5);
    let t = mech_sparsity_test_target();
    let n = t.len();
    let rho = array![0.0_f64];
    let v: Array1<f64> =
        Array1::from_vec((0..n).map(|i| 0.2 * ((i as f64) + 1.3).cos()).collect());
    let hv = pen.hvp(t.view(), rho.view(), v.view());
    let eps = 1e-5;
    let tol = 1e-5;
    let mut tp = t.clone();
    let mut tm = t.clone();
    for i in 0..n {
        tp[i] += eps * v[i];
        tm[i] -= eps * v[i];
    }
    let gp = pen.grad_target(tp.view(), rho.view());
    let gm = pen.grad_target(tm.view(), rho.view());
    for i in 0..n {
        let analytic = hv[i];
        let fd = (gp[i] - gm[i]) / (2.0 * eps);
        let err = (analytic - fd).abs();
        // `<` also rejects NaN errors.
        assert!(
            err < tol,
            "hvp-FD mismatch at index {i}: analytic={analytic:.6e} fd={fd:.6e} abs error={err:.3e}"
        );
    }
}
/// Groups `{0}` and `{2}` leave feature 1 uncovered. Construction must fail
/// with a partition error that names feature 1.
#[test]
fn mechanism_sparsity_rejects_groups_missing_a_feature() {
    let len = mech_sparsity_test_target().len();
    let outcome = MechanismSparsityPenalty::new(
        PsiSlice::full(len, Some(2)),
        vec![vec![0_usize], vec![2]],
        1.0,
        1e-2,
        4.0,
        false,
    );
    let err = outcome.expect_err("groups missing feature 1 must error");
    let expected_fragments = ["must partition features", "missing feature 1"];
    let matches = expected_fragments
        .iter()
        .all(|&fragment| err.contains(fragment));
    assert!(matches, "unexpected error message: {err}");
}
/// Feature 1 appears in both `{0, 1}` and `{1, 2}`. Construction must fail
/// and report the duplicated feature.
#[test]
fn mechanism_sparsity_rejects_overlapping_groups() {
    let len = mech_sparsity_test_target().len();
    let overlapping = vec![vec![0_usize, 1], vec![1, 2]];
    let outcome = MechanismSparsityPenalty::new(
        PsiSlice::full(len, Some(2)),
        overlapping,
        1.0,
        1e-2,
        4.0,
        false,
    );
    let err = outcome.expect_err("overlapping feature must error");
    let reports_duplicate = err.contains("feature 1 appears in more than one group");
    assert!(reports_duplicate, "unexpected error message: {err}");
}
fn nested_prefix_test_target() -> (Array1<f64>, usize, usize) {
let t = array![
1.0_f64, 2.0, 3.0, 4.0, -1.0, 0.5, 0.0, 2.0, 0.1, -0.2, 0.3, -0.4, ];
(t, 3, 4)
}
/// Checks `grad_target` of a nested-prefix penalty (prefixes `[1, 2, 4]`,
/// weights `[0.7, 0.5, 0.3]`) against a central finite difference of `value`.
///
/// Each coordinate is checked on its own, so a failure names the index.
#[test]
fn nested_prefix_grad_matches_finite_difference() {
    let (t, _n, f) = nested_prefix_test_target();
    let target = PsiSlice::full(t.len(), Some(f));
    let pen = NestedPrefixPenalty::new(
        target,
        PenaltyTier::Psi,
        vec![1_usize, 2, 4],
        vec![0.7, 0.5, 0.3],
        1e-3,
    )
    .expect("valid nested-prefix penalty");
    let rho = array![0.0_f64, 0.0, 0.0];
    let g = pen.grad_target(t.view(), rho.view());
    let eps = 1e-6;
    let tol = 1e-5;
    for i in 0..t.len() {
        let mut tp = t.clone();
        let mut tm = t.clone();
        tp[i] += eps;
        tm[i] -= eps;
        let fd =
            (pen.value(tp.view(), rho.view()) - pen.value(tm.view(), rho.view())) / (2.0 * eps);
        let analytic = g[i];
        let err = (analytic - fd).abs();
        // `<` also rejects NaN errors.
        assert!(
            err < tol,
            "grad-FD mismatch at index {i}: analytic={analytic:.6e} fd={fd:.6e} abs error={err:.3e}"
        );
    }
}
/// The nested-prefix `hessian_diag` must be available, and every entry must be
/// finite and non-negative. The first entry must be strictly positive.
#[test]
fn nested_prefix_hessian_diag_is_psd() {
    let (theta, _rows, features) = nested_prefix_test_target();
    let pen = NestedPrefixPenalty::new(
        PsiSlice::full(theta.len(), Some(features)),
        PenaltyTier::Psi,
        vec![2_usize, 3, 4],
        vec![1.0, 0.5, 0.25],
        1e-3,
    )
    .expect("valid nested-prefix penalty");
    let rho = Array1::<f64>::zeros(3);
    let h = pen
        .hessian_diag(theta.view(), rho.view())
        .expect("nested-prefix Hessian is diagonal");
    for v in h.iter().copied() {
        let acceptable = v.is_finite() && v >= 0.0;
        assert!(acceptable, "Hessian diag must be finite and PSD; got {v}");
    }
    assert!(h[0] > 0.0);
}
/// Checks the nested-prefix value against a hand-built oracle.
///
/// The inputs are prefixes `[1, 3, 4]`, weights `[2.0, 1.0, 0.5]` and
/// smoothing `eps = 0.5`. The oracle gives feature `i` the sum of the weights
/// of every prefix whose length exceeds `i`. Each target entry `x` in feature
/// column `i` then contributes `w_axis[i] * sqrt(x² + eps²)`.
#[test]
fn nested_prefix_mask_is_correct() {
    let (t, n_rows, f) = nested_prefix_test_target();
    let target = PsiSlice::full(t.len(), Some(f));
    let prefixes = vec![1_usize, 3, 4];
    let weights = vec![2.0_f64, 1.0, 0.5];
    let eps = 0.5;
    let pen = NestedPrefixPenalty::new(target, PenaltyTier::Psi, prefixes, weights, eps)
        .expect("valid");
    let rho = Array1::<f64>::zeros(3);
    let v = pen.value(t.view(), rho.view());
    // Feature 0 is inside all three prefixes (2.0 + 1.0 + 0.5).
    // Features 1 and 2 are inside prefixes 3 and 4 (1.0 + 0.5).
    // Feature 3 is inside only prefix 4 (0.5).
    let w_axis = [3.5_f64, 1.5, 1.5, 0.5];
    let eps2 = eps * eps;
    let mut expected = 0.0;
    for n in 0..n_rows {
        for i in 0..f {
            let x = t[n * f + i];
            expected += w_axis[i] * (x * x + eps2).sqrt();
        }
    }
    assert_abs_diff_eq!(v, expected, epsilon = 1e-10);
}
/// Checks `grad_rho` of a nested-prefix penalty against a central finite
/// difference of `value` in each `rho` coordinate, at a non-zero `rho`.
///
/// The loop bound comes from `rho.len()`. NOTE(review): `rho` presumably holds
/// one entry per prefix weight; confirm this against the penalty.
#[test]
fn nested_prefix_grad_rho_matches_finite_difference() {
    let (t, _n, f) = nested_prefix_test_target();
    let target = PsiSlice::full(t.len(), Some(f));
    let pen = NestedPrefixPenalty::new(
        target,
        PenaltyTier::Psi,
        vec![1_usize, 2, 4],
        vec![0.7, 0.5, 0.3],
        1e-3,
    )
    .expect("valid");
    let rho = array![0.1_f64, -0.2, 0.3];
    let dr = pen.grad_rho(t.view(), rho.view());
    let eps = 1e-6;
    for k in 0..rho.len() {
        let mut rp = rho.clone();
        let mut rm = rho.clone();
        rp[k] += eps;
        rm[k] -= eps;
        let fd =
            (pen.value(t.view(), rp.view()) - pen.value(t.view(), rm.view())) / (2.0 * eps);
        assert_abs_diff_eq!(dr[k], fd, epsilon = 1e-5);
    }
}
/// Returns the largest absolute gap between `pen.grad_target` and a central
/// finite difference of `pen.value` with step `epsilon`, over all
/// coordinates of `target`.
///
/// Two scratch copies of `target` are reused. Each coordinate is nudged by
/// `±epsilon`, evaluated, and then restored to its exact original value, so
/// only one coordinate is ever perturbed at a time.
fn value_grad_fd_max_abs_error(
    pen: &dyn AnalyticPenalty,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    epsilon: f64,
) -> f64 {
    let analytic = pen.grad_target(target, rho);
    let mut plus = target.to_owned();
    let mut minus = target.to_owned();
    (0..target.len()).fold(0.0_f64, |worst, idx| {
        let centre = target[idx];
        plus[idx] = centre + epsilon;
        minus[idx] = centre - epsilon;
        let numeric =
            (pen.value(plus.view(), rho) - pen.value(minus.view(), rho)) / (2.0 * epsilon);
        plus[idx] = centre;
        minus[idx] = centre;
        // `f64::max` ignores a NaN argument, matching a `>`-guarded update.
        worst.max((analytic[idx] - numeric).abs())
    })
}
/// ARD penalty over an 8-entry target with latent dimension 2: `value` and
/// `grad_target` must agree within 1e-7 under a 1e-6 central difference,
/// evaluated at a non-zero `rho`.
#[test]
fn ard_value_grad_self_consistent_fd() {
    let d = 2;
    let ard = ARDPenalty::new(PsiSlice::full(8, Some(d)), d);
    let theta = array![0.5_f64, 1.0, 2.0, -1.0, 0.3, -0.7, 1.4, -0.2];
    let rho = array![0.4_f64, -0.6];
    let worst = value_grad_fd_max_abs_error(&ard, theta.view(), rho.view(), 1.0e-6);
    let within_tolerance = worst <= 1.0e-7;
    assert!(within_tolerance, "ARD value↔grad FD max abs error = {worst:.3e}");
}
/// MCP-concavity `ScadMcpPenalty` on six scalars with an empty `rho`:
/// `value` and `grad_target` must agree within 1e-5 under a 1e-3 central
/// difference.
#[test]
fn scadmcp_value_grad_self_consistent_fd() {
    let n_eff = 6usize;
    let pen = ScadMcpPenalty::new(
        PsiSlice::full(n_eff, Some(1)),
        0.5,
        n_eff,
        3.0,
        1.0e-4,
        PenaltyConcavity::Mcp,
        false,
    )
    .unwrap();
    let theta = array![0.02_f64, 0.4, 0.9, 1.6, -1.1, -0.05];
    let empty_rho = Array1::<f64>::zeros(0);
    let worst = value_grad_fd_max_abs_error(&pen, theta.view(), empty_rho.view(), 1.0e-3);
    assert!(worst <= 1.0e-5, "ScadMcp value↔grad FD max abs error = {worst:.3e}");
}
/// Nuclear-norm penalty on a 4×3 block with no rank cap: `value` and
/// `grad_target` must agree within 1e-5 under a 1e-6 central difference.
#[test]
fn nuclear_norm_value_grad_self_consistent_fd() {
    let (n_eff, p) = (4usize, 3usize);
    let slice = PsiSlice {
        range: 0..n_eff * p,
        latent_dim: Some(p),
    };
    let pen = NuclearNormPenalty::new(slice, 0.8, n_eff, 1.0e-4, None, false).unwrap();
    let theta = array![
        1.2_f64, -0.4, 0.3, //
        0.1, 0.9, -0.7, //
        -0.5, 0.2, 1.1, //
        0.6, -0.3, 0.8
    ];
    let empty_rho = Array1::<f64>::zeros(0);
    let worst = value_grad_fd_max_abs_error(&pen, theta.view(), empty_rho.view(), 1.0e-6);
    assert!(worst <= 1.0e-5, "NuclearNorm value↔grad FD max abs error = {worst:.3e}");
}
/// For a wide 2×4 block, `max_rank = Some(3)` is above the thin rank of 2, so
/// the capped HVP must be finite and must match the uncapped HVP within 1e-10.
#[test]
fn nuclear_norm_hvp_wide_matrix_max_rank_above_thin_rank_is_uncapped() {
    let (n_eff, p) = (2usize, 4usize);
    let slice = PsiSlice {
        range: 0..n_eff * p,
        latent_dim: Some(p),
    };
    let capped =
        NuclearNormPenalty::new(slice.clone(), 0.7, n_eff, 1.0e-3, Some(3), false).unwrap();
    let uncapped = NuclearNormPenalty::new(slice, 0.7, n_eff, 1.0e-3, None, false).unwrap();
    let t = array![2.0_f64, 0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0];
    let v = array![0.2_f64, -0.4, 0.6, -0.8, 0.3, -0.5, 0.7, -0.9];
    let empty_rho = Array1::<f64>::zeros(0);
    let with_cap = capped.hvp(t.view(), empty_rho.view(), v.view());
    let without_cap = uncapped.hvp(t.view(), empty_rho.view(), v.view());
    for i in 0..t.len() {
        let lhs = with_cap[i];
        assert!(lhs.is_finite(), "wide NuclearNorm HVP must stay finite at index {i}");
        assert_abs_diff_eq!(lhs, without_cap[i], epsilon = 1.0e-10);
    }
}
/// A wide 3×10 block whose only non-zeros are `2.0, 1.5, 1.2` at positions
/// `(r, r)`, with `max_rank = Some(4)` above the true rank of 3. Value,
/// gradient and HVP must all be finite.
#[test]
fn nuclear_norm_wide_block_max_rank_above_true_rank_value_grad_hvp_are_finite() {
    let (n_eff, p) = (3usize, 10usize);
    let slice = PsiSlice {
        range: 0..n_eff * p,
        latent_dim: Some(p),
    };
    let pen = NuclearNormPenalty::new(slice, 0.9, n_eff, 1.0e-3, Some(4), false).unwrap();
    let mut t = Array1::<f64>::zeros(n_eff * p);
    for (r, &diag) in [2.0_f64, 1.5, 1.2].iter().enumerate() {
        t[r * p + r] = diag;
    }
    let v: Array1<f64> = (0..n_eff * p)
        .map(|i| 0.2 * ((i as f64) + 0.3).cos())
        .collect::<Vec<f64>>()
        .into();
    let empty_rho = Array1::<f64>::zeros(0);
    let value = pen.value(t.view(), empty_rho.view());
    let grad = pen.grad_target(t.view(), empty_rho.view());
    let hv = pen.hvp(t.view(), empty_rho.view(), v.view());
    assert!(value.is_finite(), "wide NuclearNorm value must be finite");
    for i in 0..t.len() {
        let (g, h) = (grad[i], hv[i]);
        assert!(g.is_finite(), "wide NuclearNorm gradient must be finite at index {i}");
        assert!(h.is_finite(), "wide NuclearNorm HVP must be finite at index {i}");
    }
}
/// With two tiny, distinct singular values (`1e-10` and `2e-7`) and
/// smoothing `1e-10`, the right-Gram filter derivative at `[0, 1]` must equal
/// the divided difference of `λ^{-1/2}`. Here each `λ = s² + eps²` is clamped
/// below by `max(eps², 1e-15)`. The check uses a relative tolerance of 1e-12.
#[test]
fn nuclear_norm_right_gram_divided_difference_uses_eigen_floor() {
    let (n_eff, p) = (2usize, 2usize);
    let smoothing_eps = 1.0e-10_f64;
    let pen = NuclearNormPenalty::new(
        PsiSlice {
            range: 0..n_eff * p,
            latent_dim: Some(p),
        },
        0.9,
        n_eff,
        smoothing_eps,
        None,
        false,
    )
    .unwrap();
    let (a, b) = (1.0e-10_f64, 2.0e-7_f64);
    let t = array![[a, 0.0_f64], [0.0, b]];
    let v = array![[0.0_f64, 1.0], [0.0, 0.0]];
    let (_, filter_derivative) = pen
        .right_spectral_inverse_sqrt_derivative(t.view(), v.view())
        .expect("right-Gram derivative");
    let eps2 = smoothing_eps * smoothing_eps;
    let floor = eps2.max(1.0e-15);
    let floored_eig = |s: f64| (s * s + eps2).max(floor);
    let inv_sqrt = |lam: f64| lam.powf(-0.5);
    let (lam0, lam1) = (floored_eig(a), floored_eig(b));
    let expected = ((inv_sqrt(lam0) - inv_sqrt(lam1)) / (lam0 - lam1)) * a;
    assert_abs_diff_eq!(
        filter_derivative[[0, 1]],
        expected,
        epsilon = expected.abs() * 1.0e-12
    );
}
/// For `max_rank` in `{None, Some(2)}`, checks the wide-block fast path
/// `right_spectral_filters_applied` against the dense oracle
/// `right_spectral_inverse_sqrt_derivative`. The fast path's `V·R` and `T·dR`
/// must match `v.dot(&R)` and `t.dot(&dR)` built from the oracle. The
/// tolerance is 1e-9 times the largest absolute oracle entry, floored at 1.
#[test]
fn nuclear_norm_wide_block_fast_path_matches_dense_oracle() {
    let n_eff = 3usize;
    let p = 40usize;
    let target = PsiSlice {
        range: 0..n_eff * p,
        latent_dim: Some(p),
    };
    // The inputs do not depend on `max_rank`, so they are built once.
    let t_flat = Array1::from_vec(
        (0..n_eff * p)
            .map(|i| (0.3 * (i as f64) + 0.11).sin() + 0.05 * (i as f64 % 7.0))
            .collect(),
    );
    let v_flat = Array1::from_vec(
        (0..n_eff * p)
            .map(|i| (0.17 * (i as f64) - 0.4).cos())
            .collect(),
    );
    let t = t_flat.view().into_shape_with_order((n_eff, p)).unwrap();
    let v = v_flat.view().into_shape_with_order((n_eff, p)).unwrap();
    for max_rank in [None, Some(2)] {
        let pen = NuclearNormPenalty::new(target.clone(), 0.8, n_eff, 1.0e-3, max_rank, false)
            .unwrap();
        let (fast_vr, fast_tdr) = pen
            .right_spectral_filters_applied(t.view(), v.view())
            .expect("fast path");
        let (rf, rfd) = pen
            .right_spectral_inverse_sqrt_derivative(t.view(), v.view())
            .expect("dense oracle");
        let dense_vr = v.dot(&rf);
        let dense_tdr = t.dot(&rfd);
        let scale = dense_vr
            .iter()
            .chain(dense_tdr.iter())
            .fold(0.0_f64, |acc, &x| acc.max(x.abs()))
            .max(1.0);
        for n in 0..n_eff {
            for a in 0..p {
                assert!(
                    (fast_vr[[n, a]] - dense_vr[[n, a]]).abs() <= 1.0e-9 * scale,
                    "V·R mismatch at ({n},{a}) max_rank={max_rank:?}: \
                     fast={} dense={}",
                    fast_vr[[n, a]],
                    dense_vr[[n, a]]
                );
                assert!(
                    (fast_tdr[[n, a]] - dense_tdr[[n, a]]).abs() <= 1.0e-9 * scale,
                    "T·dR mismatch at ({n},{a}) max_rank={max_rank:?}: \
                     fast={} dense={}",
                    fast_tdr[[n, a]],
                    dense_tdr[[n, a]]
                );
            }
        }
    }
}
/// For an all-zero wide block with `max_rank = Some(2)`, both the fast path
/// and the dense oracle must return an error. Both error messages must
/// contain "splits a tied".
#[test]
fn nuclear_norm_wide_zero_joint_rowspace_rejects_biting_zero_tie() {
    let (n_eff, p) = (3usize, 40usize);
    let pen = NuclearNormPenalty::new(
        PsiSlice {
            range: 0..n_eff * p,
            latent_dim: Some(p),
        },
        0.8,
        n_eff,
        1.0e-3,
        Some(2),
        false,
    )
    .unwrap();
    // `t` and `v` are both the zero matrix, so one array serves as both.
    let zero = Array2::<f64>::zeros((n_eff, p));
    let fast_err = pen
        .right_spectral_filters_applied(zero.view(), zero.view())
        .expect_err("fast path must reject a biting all-zero tied spectrum");
    let dense_err = pen
        .right_spectral_inverse_sqrt_derivative(zero.view(), zero.view())
        .expect_err("dense oracle rejects the same tied cutoff");
    let both_guarded = [&fast_err, &dense_err]
        .iter()
        .all(|msg| msg.contains("splits a tied"));
    assert!(
        both_guarded,
        "fast path error must preserve dense tie-guard semantics; \
         fast={fast_err}, dense={dense_err}"
    );
}
/// Checks the rank-truncated (`max_rank = Some(2)`) nuclear-norm HVP on a 4×3
/// block against a central finite difference of `grad_target` along `v`.
///
/// Each component is checked on its own, so a failure names the index.
#[test]
fn nuclear_norm_hvp_truncated_rank_matches_gradient_directional_derivative() {
    let n_eff = 4usize;
    let p = 3usize;
    let target = PsiSlice {
        range: 0..n_eff * p,
        latent_dim: Some(p),
    };
    let pen = NuclearNormPenalty::new(target, 1.1, n_eff, 0.2, Some(2), false).unwrap();
    let t = array![
        2.0_f64, 0.1, -0.2, 0.3, 1.5, 0.4, -0.1, 0.2, 0.9, 0.5, -0.4, 0.7
    ];
    let v = Array1::from_vec(
        (0..t.len())
            .map(|i| 0.25 * ((i as f64) + 0.7).sin())
            .collect(),
    );
    let rho = Array1::<f64>::zeros(0);
    let hv = pen.hvp(t.view(), rho.view(), v.view());
    let eps = 1.0e-6;
    let tol = 1.0e-5;
    let mut tp = t.clone();
    let mut tm = t.clone();
    for i in 0..t.len() {
        tp[i] += eps * v[i];
        tm[i] -= eps * v[i];
    }
    let gp = pen.grad_target(tp.view(), rho.view());
    let gm = pen.grad_target(tm.view(), rho.view());
    for i in 0..t.len() {
        let analytic = hv[i];
        let fd = (gp[i] - gm[i]) / (2.0 * eps);
        let err = (analytic - fd).abs();
        // `<=` on a NaN is false, so NaN errors are rejected too.
        assert!(
            err <= tol,
            "truncated NuclearNorm HVP-FD mismatch at index {i}: \
             analytic={analytic:.6e} fd={fd:.6e} abs error={err:.3e}"
        );
    }
}
/// Decoder-incoherence penalty with two 2-row blocks (`p_out = 3`) and
/// co-activation `[[1, 0.6], [0.6, 1]]`: `value` and `grad_target` must agree
/// within 1e-5 under a 1e-6 central difference.
#[test]
fn decoder_incoherence_value_grad_self_consistent_fd() {
    let p = 3usize;
    let block_sizes = vec![2usize, 2usize];
    let total = block_sizes.iter().sum::<usize>() * p;
    let slice = PsiSlice {
        range: 0..total,
        latent_dim: Some(total / p),
    };
    let coact = array![[1.0_f64, 0.6], [0.6, 1.0]];
    let pen = DecoderIncoherencePenalty::new(slice, block_sizes, p, coact, 0.7, false).unwrap();
    let theta = array![
        0.5_f64, -0.3, 0.2, 0.8, -0.1, 0.4, //
        -0.6, 0.7, 0.1, -0.2, 0.9, 0.3
    ];
    let empty_rho = Array1::<f64>::zeros(0);
    let worst = value_grad_fd_max_abs_error(&pen, theta.view(), empty_rho.view(), 1.0e-6);
    assert!(
        worst <= 1.0e-5,
        "DecoderIncoherence value↔grad FD max abs error = {worst:.3e}"
    );
}
/// Blocks of different sizes: block 0 is 2×3 with rows `[1,0,0]` and `[0,2,0]`,
/// block 1 is 1×3 with row `[3,0,4]`. Co-activation is asymmetric (0.2 / 0.6)
/// and the weight is 2.0. The value must be exactly 3.6.
///
/// NOTE(review): 3.6 is consistent with
/// `weight · ½ · mean(0.2, 0.6) · ‖D₀ D₁ᵀ‖²_F = 2.0 · 0.5 · 0.4 · 9`.
/// Confirm this normalization against the implementation.
#[test]
fn decoder_incoherence_heterogeneous_blocks_use_output_space_cross_gram() {
    let p = 3usize;
    let block_sizes = vec![2usize, 1usize];
    let total = block_sizes.iter().sum::<usize>() * p;
    let slice = PsiSlice {
        range: 0..total,
        latent_dim: Some(total / p),
    };
    let coact = array![[0.0_f64, 0.2], [0.6, 0.0]];
    let pen = DecoderIncoherencePenalty::new(slice, block_sizes, p, coact, 2.0, false).unwrap();
    let beta = array![
        1.0_f64, 0.0, 0.0, //
        0.0, 2.0, 0.0, //
        3.0, 0.0, 4.0
    ];
    let empty_rho = Array1::<f64>::zeros(0);
    assert_abs_diff_eq!(pen.value(beta.view(), empty_rho.view()), 3.6, epsilon = 1.0e-12);
}
/// A negative co-activation entry must make
/// `DecoderIncoherencePenalty::new` fail with the exact
/// non-negativity error message.
#[test]
fn decoder_incoherence_rejects_negative_coactivation() {
    let coact = array![[0.0_f64, -0.1], [0.0, 0.0]];
    let outcome = DecoderIncoherencePenalty::new(
        PsiSlice {
            range: 0..4,
            latent_dim: Some(2),
        },
        vec![1usize, 1usize],
        2usize,
        coact,
        1.0,
        false,
    );
    let err = outcome.expect_err("negative coactivation must be rejected");
    assert_eq!(
        err,
        "DecoderIncoherencePenalty::new requires finite non-negative coactivation entries"
    );
}
/// Checks the separability semantics of the decoder-incoherence penalty on two
/// single-row blocks with `p_out = 2`:
/// - with full co-activation, orthogonal decoders give `P ≈ 0` and coincident
///   decoders give `P = 0.5`;
/// - with zero co-activation, coincident decoders give zero value and zero
///   gradient.
///
/// The orthogonal and coincident cases use identical constructor arguments, so
/// a single penalty serves both.
#[test]
fn decoder_incoherence_separability_semantics() {
    let p = 2usize;
    let block_sizes = vec![1usize, 1usize];
    let total: usize = block_sizes.iter().map(|m| m * p).sum();
    let target = PsiSlice {
        range: 0..total,
        latent_dim: Some(total / p),
    };
    let rho = Array1::<f64>::zeros(0);
    let mut full_coact = Array2::<f64>::zeros((2, 2));
    full_coact[[0, 1]] = 1.0;
    full_coact[[1, 0]] = 1.0;
    let pen_full = DecoderIncoherencePenalty::new(
        target.clone(),
        block_sizes.clone(),
        p,
        full_coact,
        1.0,
        false,
    )
    .unwrap();

    let t_ortho = array![1.0_f64, 0.0, 0.0, 1.0];
    let p_ortho = pen_full.value(t_ortho.view(), rho.view());
    assert!(
        p_ortho.abs() <= 1.0e-12,
        "orthogonal decoder blocks must give P≈0, got {p_ortho:.3e}"
    );

    let t_coinc = array![1.0_f64, 0.0, 1.0, 0.0];
    let p_coinc = pen_full.value(t_coinc.view(), rho.view());
    assert!(
        (p_coinc - 0.5).abs() <= 1.0e-12,
        "coincident decoder blocks must give large P (=0.5 here), got {p_coinc:.3e}"
    );
    assert!(
        p_coinc > p_ortho + 1.0e-3,
        "coincident P must exceed orthogonal P"
    );

    // With zero co-activation the pair term must vanish entirely.
    let pen_zero = DecoderIncoherencePenalty::new(
        target,
        block_sizes,
        p,
        Array2::<f64>::zeros((2, 2)),
        1.0,
        false,
    )
    .unwrap();
    let p_zero = pen_zero.value(t_coinc.view(), rho.view());
    assert!(
        p_zero.abs() <= 1.0e-12,
        "zero co-activation must zero the pair contribution, got {p_zero:.3e}"
    );
    let g_zero = pen_zero.grad_target(t_coinc.view(), rho.view());
    assert!(
        g_zero.iter().all(|v| v.abs() <= 1.0e-12),
        "zero co-activation must zero the pair gradient"
    );
}
/// Prefixes `[2, 2, 4]` repeat a length. Construction must fail with an error
/// mentioning "strictly increasing".
#[test]
fn nested_prefix_rejects_non_monotone_prefixes() {
    let repeated_prefixes = vec![2_usize, 2, 4];
    let outcome = NestedPrefixPenalty::new(
        PsiSlice::full(12, Some(4)),
        PenaltyTier::Psi,
        repeated_prefixes,
        vec![1.0, 1.0, 1.0],
        1e-3,
    );
    let err = outcome.expect_err("non-strictly-increasing prefixes must error");
    let mentions_order = err.contains("strictly increasing");
    assert!(mentions_order, "got: {err}");
}
}