1use super::*;
2
/// Functional form used by [`SparsityPenalty`].
///
/// The formulas below are the ones evaluated by `SparsityPenalty`'s
/// `AnalyticPenalty` implementation, before multiplication by the resolved
/// strength; sums run over every entry `x_i` of the target vector.
#[derive(Debug, Clone, Copy)]
pub enum SparsityKind {
    /// `Σ_i sqrt(x_i² + eps²)`, a smooth surrogate for `‖x‖₁`. `eps` is the
    /// fixed smoothing used when no REML index for it is configured.
    SmoothedL1 { eps: f64 },
    /// `(‖x‖₁ / ‖x‖₂ − 1) / (√n − 1)` over the whole target: 0 for a 1-sparse
    /// vector, 1 for a vector of equal magnitudes, and defined as 0 when `x = 0`.
    Hoyer,
    /// `Σ_i ln(1 + x_i² / delta²)`. `delta` is the fixed scale used when no
    /// REML index for it is configured.
    Log { delta: f64 },
}
23
/// Sparsity-inducing penalty of kind [`SparsityKind`] attached to one tier.
///
/// The penalty value is `strength · f(x)` where `f` is selected by `kind` and
/// `strength = resolve_learnable_weight(weight, rho[strength_rho_index])`.
#[derive(Debug, Clone)]
pub struct SparsityPenalty {
    /// Tier reported by `tier()`. Presumably selects the coefficient block the
    /// penalty acts on — confirm against `PenaltyTier`.
    pub target_tier: PenaltyTier,
    /// Functional form and its fixed smoothing parameter.
    pub kind: SparsityKind,
    /// Base strength, combined with `rho[strength_rho_index]`.
    pub weight: f64,
    /// Optional schedule for `weight`, applied through `impl_scalar_apply_schedule!`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
    /// Position of the strength parameter in `rho`; every constructor sets `0`.
    pub strength_rho_index: usize,
    /// When `Some(idx)`, the smoothing (`eps` / `delta`) is learned as
    /// `exp(rho[idx])` (floored at `f64::MIN_POSITIVE`) instead of taken from
    /// `kind`. `Hoyer` does not use a smoothing value, so its gradient entry is 0.
    pub eps_rho_index: Option<usize>,
}
50
/// Entropy penalty on row-wise softmax assignments; minimising it favours
/// one-hot rows.
///
/// The target is read as consecutive rows of `k_atoms` logits `z`. Each row is
/// mapped to `a = softmax(z / temperature)` and the penalty is
/// `λ · Σ_rows (−Σ_k a_k ln a_k)` with `λ = resolve_learnable_weight(weight, rho[0])`.
#[derive(Debug, Clone)]
pub struct SoftmaxAssignmentSparsityPenalty {
    /// Number of logits (atoms) per row.
    pub k_atoms: usize,
    /// Softmax temperature `τ`; logits are divided by it before normalising.
    pub temperature: f64,
    /// Base weight combined with `rho[0]` to give `λ`.
    pub weight: f64,
    /// Optional schedule for `weight`, applied through `impl_scalar_apply_schedule!`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
73
74impl SoftmaxAssignmentSparsityPenalty {
75 #[must_use]
76 pub fn new(k_atoms: usize, temperature: f64) -> Self {
77 assert!(k_atoms > 0);
78 assert!(temperature > 0.0);
79 Self {
80 k_atoms,
81 temperature,
82 weight: 1.0,
83 weight_schedule: None,
84 }
85 }
86
87 impl_with_weight_schedule!(weight);
88
89 fn softmax_row(&self, row: &[f64]) -> Vec<f64> {
90 let inv_tau = 1.0 / self.temperature;
91 let mut max_logit = f64::NEG_INFINITY;
92 for (idx, &v) in row.iter().enumerate() {
93 assert!(
94 v.is_finite(),
95 "SoftmaxAssignmentSparsityPenalty: non-finite logit at atom {idx}: {v}"
96 );
97 max_logit = max_logit.max(v);
98 }
99 let mut out = vec![0.0; self.k_atoms];
100 let mut sum = 0.0;
101 for i in 0..self.k_atoms {
102 let v = ((row[i] - max_logit) * inv_tau).exp();
103 out[i] = v;
104 sum += v;
105 }
106 assert!(
107 sum.is_finite() && sum > 0.0,
108 "SoftmaxAssignmentSparsityPenalty: non-finite softmax normalizer"
109 );
110 for v in out.iter_mut() {
111 *v /= sum;
112 }
113 out
114 }
115
116 pub fn psd_majorizer_abs_row_sums(&self, row: &[f64], scale: f64) -> Vec<f64> {
136 let a = self.softmax_row(row);
137 let k = self.k_atoms;
138 let l: Vec<f64> = (0..k)
139 .map(|i| a[i].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
140 .collect();
141 let m: f64 = (0..k).map(|i| a[i] * l[i]).sum();
142 let mut d = vec![0.0_f64; k];
143 for kk in 0..k {
144 let h_kk = scale * a[kk] * ((m - l[kk] - 1.0) + a[kk] * (2.0 * l[kk] + 1.0 - 2.0 * m));
146 let mut acc = h_kk.abs();
147 for jj in 0..k {
149 if jj == kk {
150 continue;
151 }
152 let h_kj = scale * a[kk] * a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m);
153 acc += h_kj.abs();
154 }
155 d[kk] = acc;
156 }
157 d
158 }
159
160 #[must_use]
176 pub fn row_dense_hessian(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
177 let k = self.k_atoms;
178 let a = self.softmax_row(row_logits);
179 let l: Vec<f64> = (0..k)
180 .map(|i| a[i].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
181 .collect();
182 let m: f64 = (0..k).map(|i| a[i] * l[i]).sum();
183 let mut h = Array2::<f64>::zeros((k, k));
184 for kk in 0..k {
185 for jj in 0..k {
186 let indicator = if kk == jj { 1.0 } else { 0.0 };
187 h[[kk, jj]] = scale
188 * a[kk]
189 * (indicator * (m - l[kk] - 1.0) + a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m));
190 }
191 }
192 h
193 }
194
195 #[must_use]
203 pub fn row_dense_hessian_logit_derivative(
204 &self,
205 row_logits: &[f64],
206 scale: f64,
207 w: usize,
208 ) -> Array2<f64> {
209 let k = self.k_atoms;
210 let inv_tau = 1.0 / self.temperature;
211 let a = self.softmax_row(row_logits);
212 let l: Vec<f64> = (0..k)
213 .map(|i| a[i].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
214 .collect();
215 let m: f64 = (0..k).map(|i| a[i] * l[i]).sum();
216 let da: Vec<f64> = (0..k)
218 .map(|r| a[r] * (if r == w { 1.0 } else { 0.0 } - a[w]) * inv_tau)
219 .collect();
220 let dl: Vec<f64> = (0..k)
221 .map(|r| da[r] / a[r].max(ENTROPY_LOG_PROBABILITY_FLOOR))
222 .collect();
223 let dm: f64 = (0..k).map(|r| da[r] * l[r] + a[r] * dl[r]).sum();
224 let mut dh = Array2::<f64>::zeros((k, k));
225 for kk in 0..k {
226 for jj in 0..k {
227 let indicator = if kk == jj { 1.0 } else { 0.0 };
228 let bracket =
230 indicator * (m - l[kk] - 1.0) + a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m);
231 let dbracket = indicator * (dm - dl[kk])
232 + da[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m)
233 + a[jj] * (dl[kk] + dl[jj] - 2.0 * dm);
234 dh[[kk, jj]] = scale * (da[kk] * bracket + a[kk] * dbracket);
235 }
236 }
237 dh
238 }
239
240 #[must_use]
258 pub fn row_psd_majorizer(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
259 let k = self.k_atoms;
260 let d = self.psd_majorizer_abs_row_sums(row_logits, scale);
261 let mut out = Array2::<f64>::zeros((k, k));
262 for kk in 0..k {
263 out[[kk, kk]] = d[kk];
264 }
265 out
266 }
267
268 #[must_use]
278 pub fn row_psd_majorizer_logit_derivative(
279 &self,
280 row_logits: &[f64],
281 scale: f64,
282 w: usize,
283 ) -> Array2<f64> {
284 let k = self.k_atoms;
285 let h = self.row_dense_hessian(row_logits, scale);
286 let dh = self.row_dense_hessian_logit_derivative(row_logits, scale, w);
287 let mut out = Array2::<f64>::zeros((k, k));
288 for kk in 0..k {
289 let mut acc = 0.0_f64;
290 for jj in 0..k {
291 let s = h[[kk, jj]].signum();
292 if h[[kk, jj]] != 0.0 {
293 acc += s * dh[[kk, jj]];
294 }
295 }
296 out[[kk, kk]] = acc;
297 }
298 out
299 }
300
301 #[must_use]
320 pub fn row_fisher_metric(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
321 let k = self.k_atoms;
322 let a = self.softmax_row(row_logits);
323 let mut g = Array2::<f64>::zeros((k, k));
324 for kk in 0..k {
325 for jj in 0..k {
326 let indicator = if kk == jj { 1.0 } else { 0.0 };
327 g[[kk, jj]] = scale * a[kk] * (indicator - a[jj]);
328 }
329 }
330 g
331 }
332
333 #[must_use]
345 pub fn row_fisher_metric_logit_derivative(
346 &self,
347 row_logits: &[f64],
348 scale: f64,
349 w: usize,
350 ) -> Array2<f64> {
351 let k = self.k_atoms;
352 let inv_tau = 1.0 / self.temperature;
353 let a = self.softmax_row(row_logits);
354 let da: Vec<f64> = (0..k)
357 .map(|r| a[r] * (if r == w { 1.0 } else { 0.0 } - a[w]) * inv_tau)
358 .collect();
359 let mut dg = Array2::<f64>::zeros((k, k));
360 for kk in 0..k {
361 for jj in 0..k {
362 let indicator = if kk == jj { 1.0 } else { 0.0 };
363 dg[[kk, jj]] = scale * (da[kk] * (indicator - a[jj]) - a[kk] * da[jj]);
364 }
365 }
366 dg
367 }
368}
369
370impl AnalyticPenalty for SoftmaxAssignmentSparsityPenalty {
371 fn tier(&self) -> PenaltyTier {
372 PenaltyTier::Psi
373 }
374
375 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
376 let lambda = resolve_learnable_weight(self.weight, rho[0]);
377 let n = target.len() / self.k_atoms;
378 let values: Vec<f64> = target.iter().copied().collect();
379 let mut acc = 0.0;
380 for row in 0..n {
381 let start = row * self.k_atoms;
382 let a = self.softmax_row(&values[start..start + self.k_atoms]);
383 for v in a {
384 if v > 0.0 {
385 acc += -v * v.ln();
386 }
387 }
388 }
389 lambda * acc
390 }
391
392 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
393 let lambda = resolve_learnable_weight(self.weight, rho[0]);
394 let n = target.len() / self.k_atoms;
395 let values: Vec<f64> = target.iter().copied().collect();
396 let mut out = Array1::<f64>::zeros(target.len());
397 let inv_tau = 1.0 / self.temperature;
398 for row in 0..n {
399 let start = row * self.k_atoms;
400 let a = self.softmax_row(&values[start..start + self.k_atoms]);
401 let mut d_h_da = vec![0.0; self.k_atoms];
402 let mut mean = 0.0;
403 for k in 0..self.k_atoms {
404 let ak = a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR);
405 d_h_da[k] = -lambda * (ak.ln() + 1.0);
406 mean += a[k] * d_h_da[k];
407 }
408 for k in 0..self.k_atoms {
409 out[start + k] = a[k] * (d_h_da[k] - mean) * inv_tau;
410 }
411 }
412 out
413 }
414
415 fn hessian_diag(
416 &self,
417 target: ArrayView1<'_, f64>,
418 rho: ArrayView1<'_, f64>,
419 ) -> Option<Array1<f64>> {
420 assert_eq!(rho.len(), 1, "softmax entropy expects one rho parameter");
421 assert!(
422 rho.iter().all(|value| value.is_finite()),
423 "softmax entropy rho must be finite"
424 );
425 assert_eq!(
426 target.len() % self.k_atoms,
427 0,
428 "softmax entropy target length must be divisible by k_atoms"
429 );
430 let lambda = resolve_learnable_weight(self.weight, rho[0]);
439 let inv_tau = 1.0 / self.temperature;
440 let scale = lambda * inv_tau * inv_tau;
441 let n = target.len() / self.k_atoms;
442 let values: Vec<f64> = target.iter().copied().collect();
443 let mut out = Array1::<f64>::zeros(target.len());
444 for row in 0..n {
445 let start = row * self.k_atoms;
446 let a = self.softmax_row(&values[start..start + self.k_atoms]);
447 let mut mean_log_plus_one = 0.0;
448 for k in 0..self.k_atoms {
449 mean_log_plus_one += a[k] * (a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
450 }
451 for k in 0..self.k_atoms {
452 let log_plus_one = a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0;
453 let term = (1.0 - 2.0 * a[k]) * (mean_log_plus_one - log_plus_one) + a[k] - 1.0;
454 out[start + k] = scale * a[k] * term;
455 }
456 }
457 Some(out)
458 }
459
460 fn hvp(
461 &self,
462 target: ArrayView1<'_, f64>,
463 rho: ArrayView1<'_, f64>,
464 v: ArrayView1<'_, f64>,
465 ) -> Array1<f64> {
466 let lambda = resolve_learnable_weight(self.weight, rho[0]);
477 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
478 let n = target.len() / self.k_atoms;
479 let values: Vec<f64> = target.iter().copied().collect();
480 let mut out = Array1::<f64>::zeros(target.len());
481 let inv_tau = 1.0 / self.temperature;
482 let scale = lambda * inv_tau * inv_tau;
483 for row in 0..n {
484 let start = row * self.k_atoms;
485 let a = self.softmax_row(&values[start..start + self.k_atoms]);
486 let mut mean_log_plus_one = 0.0;
487 let mut mean_v = 0.0;
488 for k in 0..self.k_atoms {
489 mean_log_plus_one += a[k] * (a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
490 mean_v += a[k] * v[start + k];
491 }
492 let mut mean_centered_v_log_plus_one = 0.0;
493 for k in 0..self.k_atoms {
494 let centered_v = v[start + k] - mean_v;
495 mean_centered_v_log_plus_one +=
496 a[k] * centered_v * (a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
497 }
498 for k in 0..self.k_atoms {
499 let log_plus_one = a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0;
500 let centered_v = v[start + k] - mean_v;
501 out[start + k] = scale
502 * a[k]
503 * (centered_v * (mean_log_plus_one - log_plus_one - 1.0)
504 + mean_centered_v_log_plus_one);
505 }
506 }
507 out
508 }
509
510 fn psd_majorizer_diag(
511 &self,
512 target: ArrayView1<'_, f64>,
513 rho: ArrayView1<'_, f64>,
514 ) -> Option<Array1<f64>> {
515 assert_eq!(rho.len(), 1, "softmax entropy expects one rho parameter");
516 assert_eq!(
517 target.len() % self.k_atoms,
518 0,
519 "softmax entropy target length must be divisible by k_atoms"
520 );
521 let lambda = resolve_learnable_weight(self.weight, rho[0]);
530 let inv_tau = 1.0 / self.temperature;
531 let scale = lambda * inv_tau * inv_tau;
532 let n = target.len() / self.k_atoms;
533 let values: Vec<f64> = target.iter().copied().collect();
534 let mut out = Array1::<f64>::zeros(target.len());
535 for row in 0..n {
536 let start = row * self.k_atoms;
537 let d = self.psd_majorizer_abs_row_sums(&values[start..start + self.k_atoms], scale);
538 for k in 0..self.k_atoms {
539 out[start + k] = d[k];
540 }
541 }
542 Some(out)
543 }
544
545 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
546 Array1::from_vec(vec![self.value(target, rho)])
547 }
548
549 fn rho_count(&self) -> usize {
550 1
551 }
552
553 fn name(&self) -> &str {
554 "softmax_assignment_sparsity"
555 }
556
557 impl_scalar_apply_schedule!(weight);
558}
559
560impl SparsityPenalty {
561 #[must_use = "build error must be handled"]
562 pub fn smoothed_l1(target_tier: PenaltyTier, eps: f64) -> Result<Self, String> {
563 if !(eps.is_finite() && eps > 0.0) {
564 return Err(format!(
565 "SparsityPenalty::smoothed_l1 requires eps > 0 \
566 (Hessian / gradient have a `1/sqrt(x² + eps²)` factor that needs eps > 0 \
567 for differentiability at x = 0); got eps = {eps}"
568 ));
569 }
570 Ok(Self {
571 target_tier,
572 kind: SparsityKind::SmoothedL1 { eps },
573 weight: 1.0,
574 weight_schedule: None,
575 strength_rho_index: 0,
576 eps_rho_index: None,
577 })
578 }
579
580 #[must_use = "build error must be handled"]
581 pub fn log(target_tier: PenaltyTier, delta: f64) -> Result<Self, String> {
582 if !(delta.is_finite() && delta > 0.0) {
583 return Err(format!(
584 "SparsityPenalty::log requires delta > 0 \
585 (the log-sparsifier is log(1 + x²/δ²), undefined at δ = 0); \
586 got delta = {delta}"
587 ));
588 }
589 Ok(Self {
590 target_tier,
591 kind: SparsityKind::Log { delta },
592 weight: 1.0,
593 weight_schedule: None,
594 strength_rho_index: 0,
595 eps_rho_index: None,
596 })
597 }
598
599 #[must_use]
602 pub fn hoyer(target_tier: PenaltyTier) -> Self {
603 Self {
604 target_tier,
605 kind: SparsityKind::Hoyer,
606 weight: 1.0,
607 weight_schedule: None,
608 strength_rho_index: 0,
609 eps_rho_index: None,
610 }
611 }
612
613 impl_with_weight_schedule!(weight);
614
615 #[must_use]
616 pub fn with_eps_reml(mut self, eps_rho_index: usize) -> Self {
617 self.eps_rho_index = Some(eps_rho_index);
618 self
619 }
620
621 fn resolved(&self, rho: ArrayView1<'_, f64>) -> (f64, f64) {
623 let strength = resolve_learnable_weight(self.weight, rho[self.strength_rho_index]);
624 let smoothing = match (self.eps_rho_index, self.kind) {
625 (Some(idx), _) => rho[idx].exp().max(f64::MIN_POSITIVE),
631 (None, SparsityKind::SmoothedL1 { eps }) => eps,
632 (None, SparsityKind::Log { delta }) => delta,
633 (None, SparsityKind::Hoyer) => 0.0,
634 };
635 (strength, smoothing)
636 }
637}
638
639impl AnalyticPenalty for SparsityPenalty {
640 fn tier(&self) -> PenaltyTier {
641 self.target_tier
642 }
643
644 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
645 let (lam, smooth) = self.resolved(rho);
646 match self.kind {
647 SparsityKind::SmoothedL1 { .. } => {
648 let mut acc = 0.0;
649 for &x in target.iter() {
650 acc += (x * x + smooth * smooth).sqrt();
651 }
652 lam * acc
653 }
654 SparsityKind::Hoyer => {
655 let n = target.len() as f64;
662 assert!(n > 1.0, "Hoyer requires n > 1");
663 let l1: f64 = target.iter().map(|x| x.abs()).sum();
664 let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
665 if l2 == 0.0 {
666 return 0.0;
667 }
668 let h = (l1 / l2 - 1.0) / (n.sqrt() - 1.0);
669 lam * h
670 }
671 SparsityKind::Log { .. } => {
672 let mut acc = 0.0;
673 let d2 = smooth * smooth;
674 for &x in target.iter() {
675 acc += (1.0 + x * x / d2).ln();
676 }
677 lam * acc
678 }
679 }
680 }
681
682 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
683 let (lam, smooth) = self.resolved(rho);
684 let mut g = Array1::<f64>::zeros(target.len());
685 match self.kind {
686 SparsityKind::SmoothedL1 { .. } => {
687 let eps2 = smooth * smooth;
688 for (i, &x) in target.iter().enumerate() {
689 g[i] = lam * x / (x * x + eps2).sqrt();
690 }
691 }
692 SparsityKind::Hoyer => {
693 let n = target.len() as f64;
696 assert!(n > 1.0, "Hoyer requires n > 1");
697 let l1: f64 = target.iter().map(|x| x.abs()).sum();
698 let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
699 if l2 == 0.0 {
700 return g;
701 }
702 let denom = n.sqrt() - 1.0;
703 let a = lam / denom;
704 let inv_l2 = 1.0 / l2;
705 let inv_l2_cubed = inv_l2 * inv_l2 * inv_l2;
706 for (i, &x) in target.iter().enumerate() {
707 let sgn = if x > 0.0 {
708 1.0
709 } else if x < 0.0 {
710 -1.0
711 } else {
712 0.0
713 };
714 g[i] = a * (sgn * inv_l2 - l1 * x * inv_l2_cubed);
715 }
716 }
717 SparsityKind::Log { .. } => {
718 let d2 = smooth * smooth;
719 for (i, &x) in target.iter().enumerate() {
720 g[i] = lam * 2.0 * x / (d2 + x * x);
721 }
722 }
723 }
724 g
725 }
726
727 fn hessian_diag(
728 &self,
729 target: ArrayView1<'_, f64>,
730 rho: ArrayView1<'_, f64>,
731 ) -> Option<Array1<f64>> {
732 let (lam, smooth) = self.resolved(rho);
733 match self.kind {
734 SparsityKind::SmoothedL1 { .. } => {
735 let mut d = Array1::<f64>::zeros(target.len());
736 let eps2 = smooth * smooth;
737 for (i, &x) in target.iter().enumerate() {
738 let r = (x * x + eps2).sqrt();
739 d[i] = lam * eps2 / (r * r * r);
740 }
741 Some(d)
742 }
743 SparsityKind::Log { .. } => {
744 let mut d = Array1::<f64>::zeros(target.len());
745 let d2 = smooth * smooth;
754 for (i, &x) in target.iter().enumerate() {
755 let denom = d2 + x * x;
756 d[i] = lam * 2.0 * (d2 - x * x) / (denom * denom);
757 }
758 Some(d)
759 }
760 SparsityKind::Hoyer => None,
767 }
768 }
769
770 fn hvp(
771 &self,
772 target: ArrayView1<'_, f64>,
773 rho: ArrayView1<'_, f64>,
774 v: ArrayView1<'_, f64>,
775 ) -> Array1<f64> {
776 let (lam, smooth) = self.resolved(rho);
782 let n_target = target.len();
783 assert_eq!(v.len(), n_target, "hvp dimension mismatch");
784 match self.kind {
785 SparsityKind::SmoothedL1 { .. } => {
786 let mut out = Array1::<f64>::zeros(n_target);
787 let eps2 = smooth * smooth;
788 for (i, &x) in target.iter().enumerate() {
789 let r = (x * x + eps2).sqrt();
790 out[i] = lam * eps2 / (r * r * r) * v[i];
791 }
792 out
793 }
794 SparsityKind::Log { .. } => {
795 let mut out = Array1::<f64>::zeros(n_target);
801 let d2 = smooth * smooth;
802 for (i, &x) in target.iter().enumerate() {
803 let denom = d2 + x * x;
804 out[i] = lam * 2.0 * (d2 - x * x) / (denom * denom) * v[i];
805 }
806 out
807 }
808 SparsityKind::Hoyer => {
809 let n = n_target as f64;
815 assert!(n > 1.0, "Hoyer requires n > 1");
816 let l1: f64 = target.iter().map(|x| x.abs()).sum();
817 let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
818 let mut out = Array1::<f64>::zeros(n_target);
819 if l2 == 0.0 {
820 return out;
821 }
822 let a = lam / (n.sqrt() - 1.0);
823 let inv_l2_cubed = 1.0 / (l2 * l2 * l2);
824 let inv_l2_5 = inv_l2_cubed / (l2 * l2);
825 let mut x_dot_v = 0.0;
826 let mut s_dot_v = 0.0;
827 for i in 0..n_target {
828 let xi = target[i];
829 let si = if xi > 0.0 {
830 1.0
831 } else if xi < 0.0 {
832 -1.0
833 } else {
834 0.0
835 };
836 x_dot_v += xi * v[i];
837 s_dot_v += si * v[i];
838 }
839 for i in 0..n_target {
840 let xi = target[i];
841 let si = if xi > 0.0 {
842 1.0
843 } else if xi < 0.0 {
844 -1.0
845 } else {
846 0.0
847 };
848 out[i] = a
849 * (-si * x_dot_v * inv_l2_cubed
850 - xi * s_dot_v * inv_l2_cubed
851 - l1 * v[i] * inv_l2_cubed
852 + 3.0 * l1 * xi * x_dot_v * inv_l2_5);
853 }
854 out
855 }
856 }
857 }
858
859 fn psd_majorizer_diag(
860 &self,
861 target: ArrayView1<'_, f64>,
862 rho: ArrayView1<'_, f64>,
863 ) -> Option<Array1<f64>> {
864 let (lam, smooth) = self.resolved(rho);
865 match self.kind {
866 SparsityKind::SmoothedL1 { .. } => self.hessian_diag(target, rho),
868 SparsityKind::Log { .. } => {
872 let mut d = Array1::<f64>::zeros(target.len());
873 let d2 = smooth * smooth;
874 for (i, &x) in target.iter().enumerate() {
875 d[i] = lam * 2.0 / (d2 + x * x);
876 }
877 Some(d)
878 }
879 SparsityKind::Hoyer => None,
882 }
883 }
884
885 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
886 let n_rho = self.rho_count();
889 let mut out = Array1::<f64>::zeros(n_rho);
890 let p_val = self.value(target, rho);
891 out[self.strength_rho_index] = p_val;
892 if let Some(eps_idx) = self.eps_rho_index {
893 let (lam, smooth) = self.resolved(rho);
894 let mut dp_deps = 0.0;
895 match self.kind {
896 SparsityKind::SmoothedL1 { .. } => {
897 for &x in target.iter() {
898 dp_deps += smooth / (x * x + smooth * smooth).sqrt();
899 }
900 dp_deps *= lam;
901 }
902 SparsityKind::Log { .. } => {
903 let d2 = smooth * smooth;
905 for &x in target.iter() {
906 dp_deps += -2.0 * x * x / (smooth * (d2 + x * x));
907 }
908 dp_deps *= lam;
909 }
910 SparsityKind::Hoyer => {}
911 }
912 out[eps_idx] = smooth * dp_deps;
914 }
915 out
916 }
917
918 fn rho_count(&self) -> usize {
919 1 + if self.eps_rho_index.is_some() { 1 } else { 0 }
920 }
921
922 fn name(&self) -> &str {
923 "sparsity"
924 }
925
926 impl_scalar_apply_schedule!(weight);
927}
928
/// Quadratic penalty on the `k` largest-magnitude activations of each row.
///
/// The target is read as rows of `latent_dim` activations. In each row the `k`
/// entries with the largest `|x|` (ties broken toward the lower axis index)
/// contribute `0.5 · weight · x²`; all other entries contribute nothing.
#[derive(Debug, Clone)]
pub struct TopKActivationPenalty {
    /// Slice of Psi this penalty targets; `latent_dim` is taken from it.
    pub target: PsiSlice,
    /// Number of penalised activations per row (`0 < k <= latent_dim` at construction).
    pub k: usize,
    /// Number of activations per observation row.
    pub latent_dim: usize,
    /// Quadratic coefficient applied to the selected activations.
    pub weight: f64,
    /// Optional schedule for `weight`, applied through `impl_scalar_apply_schedule!`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
941
942impl TopKActivationPenalty {
943 #[must_use = "build error must be handled"]
944 pub fn new(target: PsiSlice, k: usize, weight: f64) -> Result<Self, String> {
945 let latent_dim = target
946 .latent_dim
947 .ok_or_else(|| "TopKActivationPenalty::new requires target.latent_dim".to_string())?;
948 if latent_dim == 0 {
949 return Err("TopKActivationPenalty::new requires latent_dim > 0".to_string());
950 }
951 if k == 0 || k > latent_dim {
952 return Err(format!(
953 "TopKActivationPenalty::new requires 0 < k <= latent_dim; got k={k}, latent_dim={latent_dim}"
954 ));
955 }
956 if !(weight.is_finite() && weight > 0.0) {
957 return Err(format!(
958 "TopKActivationPenalty::new requires finite weight > 0, got {weight}"
959 ));
960 }
961 Ok(Self {
962 target,
963 k,
964 latent_dim,
965 weight,
966 weight_schedule: None,
967 })
968 }
969
970 impl_with_weight_schedule!(weight);
971
972 fn topk_mask_row(&self, target: ArrayView1<'_, f64>, row: usize, mask: &mut [bool]) {
973 mask.fill(false);
974 let d = self.latent_dim;
975 let base = row * d;
976 let mut order = (0..d).collect::<Vec<_>>();
977 order.sort_by(|&a, &b| {
978 target[base + b]
979 .abs()
980 .total_cmp(&target[base + a].abs())
981 .then_with(|| a.cmp(&b))
982 });
983 for &axis in order.iter().take(self.k) {
984 mask[axis] = true;
985 }
986 }
987}
988
989impl AnalyticPenalty for TopKActivationPenalty {
990 fn tier(&self) -> PenaltyTier {
991 PenaltyTier::Psi
992 }
993
994 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
995 assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
996 let d = self.latent_dim;
997 let n_obs = target.len() / d;
998 let mut mask = vec![false; d];
999 let mut acc = 0.0;
1000 for row in 0..n_obs {
1001 self.topk_mask_row(target, row, &mut mask);
1002 let base = row * d;
1003 for axis in 0..d {
1004 if mask[axis] {
1005 let v = target[base + axis];
1006 acc += 0.5 * self.weight * v * v;
1007 }
1008 }
1009 }
1010 acc
1011 }
1012
1013 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1014 assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
1015 let d = self.latent_dim;
1016 let n_obs = target.len() / d;
1017 let mut mask = vec![false; d];
1018 let mut grad = Array1::<f64>::zeros(target.len());
1019 for row in 0..n_obs {
1020 self.topk_mask_row(target, row, &mut mask);
1021 let base = row * d;
1022 for axis in 0..d {
1023 if mask[axis] {
1024 grad[base + axis] = self.weight * target[base + axis];
1025 }
1026 }
1027 }
1028 grad
1029 }
1030
1031 fn hessian_diag(
1032 &self,
1033 target: ArrayView1<'_, f64>,
1034 rho: ArrayView1<'_, f64>,
1035 ) -> Option<Array1<f64>> {
1036 assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
1037 let d = self.latent_dim;
1038 let n_obs = target.len() / d;
1039 let mut mask = vec![false; d];
1040 let mut diag = Array1::<f64>::zeros(target.len());
1041 for row in 0..n_obs {
1042 self.topk_mask_row(target, row, &mut mask);
1043 let base = row * d;
1044 for axis in 0..d {
1045 if mask[axis] {
1046 diag[base + axis] = self.weight;
1047 }
1048 }
1049 }
1050 Some(diag)
1051 }
1052
1053 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1054 assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
1055 assert_eq!(
1056 target.len() % self.latent_dim,
1057 0,
1058 "TopKActivationPenalty target length must be a multiple of latent_dim"
1059 );
1060 Array1::<f64>::zeros(0)
1061 }
1062
1063 fn rho_count(&self) -> usize {
1064 0
1065 }
1066
1067 fn name(&self) -> &str {
1068 "topk_activation"
1069 }
1070
1071 impl_scalar_apply_schedule!(weight);
1072}
1073
/// Smoothed JumpReLU activation-cost penalty on a latent block of Psi.
///
/// For every observation row and latent axis the cost is
/// `weight · τ_axis · σ((x − τ_axis) / smoothing_eps)`, a sigmoid-smoothed
/// version of charging `τ_axis` whenever the activation exceeds its threshold.
#[derive(Debug, Clone)]
pub struct JumpReLUPenalty {
    /// Slice of Psi this penalty targets; `latent_dim` is taken from it.
    pub target: PsiSlice,
    /// Number of latent axes per observation row.
    pub latent_dim: usize,
    /// Base thresholds, one per axis; each is resolved against `rho[axis]`
    /// through `resolve_learnable_weight` before use.
    pub thresholds: Array1<f64>,
    /// Multiplier on the whole penalty.
    pub weight: f64,
    /// Width of the sigmoid that smooths the step at the threshold.
    pub smoothing_eps: f64,
    /// Optional schedule for `weight`, applied through `impl_scalar_apply_schedule!`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
1087
1088impl JumpReLUPenalty {
1089 #[must_use = "build error must be handled"]
1090 pub fn new(
1091 target: PsiSlice,
1092 thresholds: Array1<f64>,
1093 weight: f64,
1094 smoothing_eps: f64,
1095 ) -> Result<Self, String> {
1096 let latent_dim = target
1097 .latent_dim
1098 .ok_or_else(|| "JumpReLUPenalty::new requires target.latent_dim".to_string())?;
1099 if latent_dim == 0 {
1100 return Err("JumpReLUPenalty::new requires latent_dim > 0".to_string());
1101 }
1102 if thresholds.len() != latent_dim {
1103 return Err(format!(
1104 "JumpReLUPenalty::new thresholds length {} does not match latent_dim {latent_dim}",
1105 thresholds.len()
1106 ));
1107 }
1108 for (idx, &tau) in thresholds.iter().enumerate() {
1109 if !(tau.is_finite() && tau > 0.0) {
1110 return Err(format!(
1111 "JumpReLUPenalty::new thresholds[{idx}] must be finite and > 0, got {tau}"
1112 ));
1113 }
1114 }
1115 if !(weight.is_finite() && weight > 0.0) {
1116 return Err(format!(
1117 "JumpReLUPenalty::new requires finite weight > 0, got {weight}"
1118 ));
1119 }
1120 if !(smoothing_eps.is_finite() && smoothing_eps > 0.0) {
1121 return Err(format!(
1122 "JumpReLUPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
1123 ));
1124 }
1125 Ok(Self {
1126 target,
1127 latent_dim,
1128 thresholds,
1129 weight,
1130 smoothing_eps,
1131 weight_schedule: None,
1132 })
1133 }
1134
1135 impl_with_weight_schedule!(weight);
1136
1137 fn threshold(&self, axis: usize, rho: ArrayView1<'_, f64>) -> f64 {
1138 resolve_learnable_weight(self.thresholds[axis], rho[axis])
1142 }
1143
1144 pub(crate) fn sigmoid_gate(&self, x: f64) -> f64 {
1145 if x >= 0.0 {
1146 1.0 / (1.0 + (-x).exp())
1147 } else {
1148 let ex = x.exp();
1149 ex / (1.0 + ex)
1150 }
1151 }
1152
1153 fn true_hessian_diag_entry(&self, tau: f64, gate: f64) -> f64 {
1154 self.weight * tau * gate * (1.0 - gate) * (1.0 - 2.0 * gate)
1155 / (self.smoothing_eps * self.smoothing_eps)
1156 }
1157
1158 fn psd_hessian_diag_entry(&self, tau: f64, gate: f64) -> f64 {
1159 let slope = gate * (1.0 - gate);
1175 let reweighted_l2 = slope * slope;
1176 let abs_exact = slope * (1.0 - 2.0 * gate).abs();
1177 self.weight * tau * reweighted_l2.max(abs_exact) / (self.smoothing_eps * self.smoothing_eps)
1178 }
1179}
1180
1181#[must_use]
1196pub fn jumprelu_gate_value_grad(z: f64, tau: f64, smoothing_eps: f64) -> (f64, f64, f64) {
1197 let g = gam_linalg::utils::stable_logistic((z - tau) / smoothing_eps);
1198 let value = if z > tau { z } else { 0.0 };
1199 let slope = z * g * (1.0 - g) / smoothing_eps;
1200 let dphi_dz = g + slope;
1201 let dphi_dtau = -slope;
1202 (value, dphi_dz, dphi_dtau)
1203}
1204
1205impl AnalyticPenalty for JumpReLUPenalty {
1206 fn tier(&self) -> PenaltyTier {
1207 PenaltyTier::Psi
1208 }
1209
1210 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1211 let d = self.latent_dim;
1212 let n_obs = target.len() / d;
1213 let mut acc = 0.0;
1214 for row in 0..n_obs {
1215 let base = row * d;
1216 for axis in 0..d {
1217 let tau = self.threshold(axis, rho);
1218 let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1219 acc += self.weight * tau * gate;
1220 }
1221 }
1222 acc
1223 }
1224
1225 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1226 let d = self.latent_dim;
1227 let n_obs = target.len() / d;
1228 let mut grad = Array1::<f64>::zeros(target.len());
1229 for row in 0..n_obs {
1230 let base = row * d;
1231 for axis in 0..d {
1232 let tau = self.threshold(axis, rho);
1233 let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1234 grad[base + axis] = self.weight * tau * gate * (1.0 - gate) / self.smoothing_eps;
1235 }
1236 }
1237 grad
1238 }
1239
1240 fn hessian_diag(
1241 &self,
1242 target: ArrayView1<'_, f64>,
1243 rho: ArrayView1<'_, f64>,
1244 ) -> Option<Array1<f64>> {
1245 let d = self.latent_dim;
1246 let n_obs = target.len() / d;
1247 let mut diag = Array1::<f64>::zeros(target.len());
1248 for row in 0..n_obs {
1249 let base = row * d;
1250 for axis in 0..d {
1251 let tau = self.threshold(axis, rho);
1252 let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1253 diag[base + axis] = self.true_hessian_diag_entry(tau, gate);
1254 }
1255 }
1256 Some(diag)
1257 }
1258
1259 fn hvp(
1260 &self,
1261 target: ArrayView1<'_, f64>,
1262 rho: ArrayView1<'_, f64>,
1263 v: ArrayView1<'_, f64>,
1264 ) -> Array1<f64> {
1265 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1266 let d = self.latent_dim;
1267 let n_obs = target.len() / d;
1268 let mut out = Array1::<f64>::zeros(target.len());
1269 for row in 0..n_obs {
1270 let base = row * d;
1271 for axis in 0..d {
1272 let tau = self.threshold(axis, rho);
1273 let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1274 out[base + axis] = self.true_hessian_diag_entry(tau, gate) * v[base + axis];
1275 }
1276 }
1277 out
1278 }
1279
1280 fn psd_majorizer_diag(
1281 &self,
1282 target: ArrayView1<'_, f64>,
1283 rho: ArrayView1<'_, f64>,
1284 ) -> Option<Array1<f64>> {
1285 let d = self.latent_dim;
1293 let n_obs = target.len() / d;
1294 let mut diag = Array1::<f64>::zeros(target.len());
1295 for row in 0..n_obs {
1296 let base = row * d;
1297 for axis in 0..d {
1298 let tau = self.threshold(axis, rho);
1299 let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1300 diag[base + axis] = self.psd_hessian_diag_entry(tau, gate);
1301 }
1302 }
1303 Some(diag)
1304 }
1305
1306 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1307 let d = self.latent_dim;
1308 let n_obs = target.len() / d;
1309 let mut out = Array1::<f64>::zeros(d);
1310 for axis in 0..d {
1311 let tau = self.threshold(axis, rho);
1312 let mut g_tau = 0.0;
1313 for row in 0..n_obs {
1314 let x = target[row * d + axis];
1315 let gate = self.sigmoid_gate((x - tau) / self.smoothing_eps);
1316 g_tau += gate - tau * gate * (1.0 - gate) / self.smoothing_eps;
1317 }
1318 out[axis] = self.weight * tau * g_tau;
1319 }
1320 out
1321 }
1322
1323 fn rho_count(&self) -> usize {
1324 self.latent_dim
1325 }
1326
1327 fn name(&self) -> &str {
1328 "jumprelu"
1329 }
1330
1331 impl_scalar_apply_schedule!(weight);
1332}
1333
#[cfg(test)]
mod fisher_majorizer_1419_tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use gam_linalg::faer_ndarray::FaerEigh;
    use ndarray::Array2;

    /// Issue #1419: for a confident two-atom row, the softmax Fisher metric does
    /// not dominate the entropy Hessian, while the Gershgorin diagonal does.
    #[test]
    fn gershgorin_majorizes_entropy_where_fisher_does_not_1419() {
        let temperature = 1.0_f64;
        let scale = 1.0_f64;
        let pen = SoftmaxAssignmentSparsityPenalty::new(2, temperature);
        // A logit gap of ln(0.95 / 0.05) gives a = (0.95, 0.05) at τ = 1.
        let z1 = 0.0_f64;
        let z0 = z1 + (0.95_f64 / 0.05_f64).ln();
        let row = [z0, z1];

        let probs = pen.softmax_row(&row);
        assert_abs_diff_eq!(probs[0], 0.95, epsilon = 1e-12);
        assert_abs_diff_eq!(probs[1], 0.05, epsilon = 1e-12);

        let hess = pen.row_dense_hessian(&row, scale);
        let fisher = pen.row_fisher_metric(&row, scale);
        let major = pen.row_psd_majorizer(&row, scale);

        assert_abs_diff_eq!(hess[[0, 0]], 0.0783747664, epsilon = 1e-9);
        assert_abs_diff_eq!(fisher[[0, 0]], 0.95 * 0.05, epsilon = 1e-12);

        // The majorizer is diagonal, carrying the absolute row sums of H.
        for kk in 0..2 {
            let abs_row_sum: f64 = (0..2).map(|jj| hess[[kk, jj]].abs()).sum();
            assert_abs_diff_eq!(major[[kk, kk]], abs_row_sum, epsilon = 1e-12);
        }
        assert_abs_diff_eq!(major[[0, 1]], 0.0, epsilon = 1e-15);
        assert_abs_diff_eq!(major[[1, 0]], 0.0, epsilon = 1e-15);
        assert!(major[[0, 0]] >= 0.0 && major[[1, 1]] >= 0.0);

        let fisher_free = fisher[[0, 0]] - hess[[0, 0]];
        let major_free = major[[0, 0]] - hess[[0, 0]];
        assert!(
            fisher_free < -1e-3,
            "Fisher must FAIL the majorizer bound in the free direction (#1419); \
             G_11 − H_11 = {fisher_free}"
        );
        assert!(
            major_free >= -1e-12,
            "Gershgorin majorizer must SATISFY the bound in the free direction (#1419); \
             D_11 − H_11 = {major_free}"
        );

        // Loewner check: the smallest eigenvalue of (candidate − H) must be ≥ 0.
        let m_minus_h: Array2<f64> = &major - &hess;
        let g_minus_h: Array2<f64> = &fisher - &hess;
        let (m_evals, _) = m_minus_h.eigh(faer::Side::Lower).expect("eigh(M−H)");
        let (g_evals, _) = g_minus_h.eigh(faer::Side::Lower).expect("eigh(G−H)");
        let m_min = m_evals.iter().cloned().fold(f64::INFINITY, f64::min);
        let g_min = g_evals.iter().cloned().fold(f64::INFINITY, f64::min);
        assert!(
            m_min >= -1e-12,
            "Gershgorin majorizer must be a Loewner majorizer (M − H ⪰ 0, #1419); \
             smallest eigenvalue of M−H = {m_min}"
        );
        assert!(
            g_min < -1e-9,
            "the OLD Fisher metric must FAIL the Loewner majorizer test (#1419); \
             smallest eigenvalue of G−H = {g_min} (expected strictly negative)"
        );
    }

    /// Central finite differences confirm the analytic logit derivative of the
    /// Gershgorin majorizer, and that it stays diagonal (#1419).
    #[test]
    fn gershgorin_majorizer_logit_derivative_matches_fd_1419() {
        let temperature = 0.8_f64;
        let pen = SoftmaxAssignmentSparsityPenalty::new(4, temperature);
        let row = [0.3_f64, -0.6, 0.9, 0.2];
        let inv_tau = 1.0 / temperature;
        let scale = 1.1_f64 * inv_tau * inv_tau;
        let step = 1e-6;
        let k = row.len();
        for w in 0..k {
            let analytic = pen.row_psd_majorizer_logit_derivative(&row, scale, w);
            let mut plus = row;
            let mut minus = row;
            plus[w] += step;
            minus[w] -= step;
            let m_plus = pen.row_psd_majorizer(&plus, scale);
            let m_minus = pen.row_psd_majorizer(&minus, scale);
            for kk in 0..k {
                let fd = (m_plus[[kk, kk]] - m_minus[[kk, kk]]) / (2.0 * step);
                assert_abs_diff_eq!(analytic[[kk, kk]], fd, epsilon = 1e-6);
            }
            for i in 0..k {
                for j in (0..k).filter(|&j| j != i) {
                    assert_abs_diff_eq!(analytic[[i, j]], 0.0, epsilon = 1e-15);
                }
            }
        }
    }
}