1use super::*;
2
/// Penalty that discourages correlation between disjoint groups of latent axes.
///
/// The flat `target` is viewed as a row-major `(n_eff, latent_dim)` matrix `T`, and the
/// penalty value is `0.5 · w · Σ_{g<h} ‖T_gᵀ T_h‖²_F`, where `T_g` keeps the columns listed
/// in `groups[g]` and `w` is the (optionally learnable) weight.
#[derive(Debug, Clone)]
pub struct BlockOrthogonalityPenalty {
    /// Slice descriptor for the values this penalty acts on; its length (and optional
    /// `latent_dim`) is validated by `new`.
    pub target: PsiSlice,
    /// Partition of the latent axes `0..latent_dim` into at least two disjoint, non-empty
    /// groups (validated by `new`).
    pub groups: Vec<Vec<usize>>,
    /// Base penalty weight; `new` requires it to be finite and positive.
    pub weight: f64,
    /// Number of rows of the reshaped target matrix; must divide `target.len()`.
    pub n_eff: usize,
    /// When true the effective weight is `resolve_learnable_weight(weight, rho[rho_index])`.
    pub learnable_weight: bool,
    /// Index into `rho` that is read when `learnable_weight` is set (`new` sets it to 0).
    pub rho_index: usize,
    /// Optional weight schedule; `new` leaves it `None`.
    // NOTE(review): consumed by the `impl_*schedule!` macros defined elsewhere — confirm there.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
49
50impl BlockOrthogonalityPenalty {
51 #[must_use = "build error must be handled"]
52 pub fn new(
53 target: PsiSlice,
54 groups: Vec<Vec<usize>>,
55 weight: f64,
56 n_eff: usize,
57 learnable_weight: bool,
58 ) -> Result<Self, String> {
59 if target.is_empty() {
60 return Err("BlockOrthogonalityPenalty::new requires a non-empty target".to_string());
61 }
62 if !(weight.is_finite() && weight > 0.0) {
63 return Err(format!(
64 "BlockOrthogonalityPenalty::new requires finite weight > 0, got {weight}"
65 ));
66 }
67 if n_eff == 0 {
68 return Err("BlockOrthogonalityPenalty::new requires n_eff > 0".to_string());
69 }
70 if !target.len().is_multiple_of(n_eff) {
71 return Err(format!(
72 "BlockOrthogonalityPenalty::new target length {} is not divisible by n_eff {}",
73 target.len(),
74 n_eff
75 ));
76 }
77 let latent_dim = target.len() / n_eff;
78 if let Some(expected_dim) = target.latent_dim {
79 let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
80 "BlockOrthogonalityPenalty::new target shape overflows usize".to_string()
81 })?;
82 if expected != target.len() {
83 return Err(format!(
84 "BlockOrthogonalityPenalty::new target length {} does not match n_eff {} × latent_dim {}",
85 target.len(),
86 n_eff,
87 expected_dim
88 ));
89 }
90 }
91 if groups.len() < 2 {
92 return Err("BlockOrthogonalityPenalty::new requires at least two groups".to_string());
93 }
94 let mut seen = vec![false; latent_dim];
95 for (group_idx, group) in groups.iter().enumerate() {
96 if group.is_empty() {
97 return Err(format!(
98 "BlockOrthogonalityPenalty::new groups[{group_idx}] must not be empty"
99 ));
100 }
101 for &axis in group {
102 if axis >= latent_dim {
103 return Err(format!(
104 "BlockOrthogonalityPenalty::new groups[{group_idx}] axis {axis} exceeds latent_dim {latent_dim}"
105 ));
106 }
107 if seen[axis] {
108 return Err(format!(
109 "BlockOrthogonalityPenalty::new axis {axis} appears in more than one group"
110 ));
111 }
112 seen[axis] = true;
113 }
114 }
115 for (axis, present) in seen.iter().copied().enumerate() {
116 if !present {
117 return Err(format!(
118 "BlockOrthogonalityPenalty::new groups must partition latent axes; missing axis {axis}"
119 ));
120 }
121 }
122 Ok(Self {
123 target,
124 groups,
125 weight,
126 n_eff,
127 learnable_weight,
128 rho_index: 0,
129 weight_schedule: None,
130 })
131 }
132
133 impl_with_weight_schedule!(weight);
134
135 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
136 if self.learnable_weight {
137 resolve_learnable_weight(self.weight, rho[self.rho_index])
138 } else {
139 self.weight
140 }
141 }
142
143 fn latent_dim(&self, target_len: usize) -> Option<usize> {
144 if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
145 assert_eq!(
146 target_len % self.n_eff.max(1),
147 0,
148 "target length must be divisible by n_eff"
149 );
150 return None;
151 }
152 Some(target_len / self.n_eff)
153 }
154
155 fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
156 let d = self.latent_dim(target.len())?;
157 target.into_shape_with_order((self.n_eff, d)).ok()
158 }
159
160 fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
161 let n_obs = m.nrows();
162 let d = m.ncols();
163 let mut out = Array1::<f64>::zeros(n_obs * d);
164 for n in 0..n_obs {
165 for a in 0..d {
166 out[n * d + a] = m[[n, a]];
167 }
168 }
169 out
170 }
171
172 fn cross_gram(t: ArrayView2<'_, f64>, left: &[usize], right: &[usize]) -> Array2<f64> {
173 let mut out = Array2::<f64>::zeros((left.len(), right.len()));
174 for (li, &a) in left.iter().enumerate() {
175 for (ri, &b) in right.iter().enumerate() {
176 let mut s = 0.0;
177 for n in 0..t.nrows() {
178 s += t[[n, a]] * t[[n, b]];
179 }
180 out[[li, ri]] = s;
181 }
182 }
183 out
184 }
185
186 fn mixed_cross_gram(
193 a: ArrayView2<'_, f64>,
194 b: ArrayView2<'_, f64>,
195 left: &[usize],
196 right: &[usize],
197 ) -> Array2<f64> {
198 assert_eq!(a.nrows(), b.nrows(), "mixed_cross_gram row mismatch");
199 let mut out = Array2::<f64>::zeros((left.len(), right.len()));
200 for (li, &al) in left.iter().enumerate() {
201 for (ri, &br) in right.iter().enumerate() {
202 let mut s = 0.0;
203 for n in 0..a.nrows() {
204 s += a[[n, al]] * b[[n, br]];
205 }
206 out[[li, ri]] = s;
207 }
208 }
209 out
210 }
211
212 fn add_right_times_cross(
213 out: &mut Array2<f64>,
214 right: ArrayView2<'_, f64>,
215 left_axes: &[usize],
216 right_axes: &[usize],
217 cross_right_left: ArrayView2<'_, f64>,
218 factor: f64,
219 ) {
220 assert_eq!(cross_right_left.dim(), (right_axes.len(), left_axes.len()));
221 for n in 0..out.nrows() {
222 for (li, &left_axis) in left_axes.iter().enumerate() {
223 let mut s = 0.0;
224 for (ri, &right_axis) in right_axes.iter().enumerate() {
225 s += right[[n, right_axis]] * cross_right_left[[ri, li]];
226 }
227 out[[n, left_axis]] += factor * s;
228 }
229 }
230 }
231
232 fn hvp_with_precomputed_cross(
233 &self,
234 t: ArrayView2<'_, f64>,
235 cross: &[Vec<Option<Array2<f64>>>],
236 v: ArrayView2<'_, f64>,
237 weight: f64,
238 ) -> Array2<f64> {
239 assert_eq!(v.dim(), t.dim(), "hvp matrix dimension mismatch");
240 if v.dim() != t.dim() {
241 return Array2::<f64>::zeros(t.dim());
242 }
243 let mut out = Array2::<f64>::zeros(t.dim());
244 for g in 0..self.groups.len() {
245 let group_g = &self.groups[g];
246 for h in 0..self.groups.len() {
247 if g == h {
248 continue;
249 }
250 let group_h = &self.groups[h];
251 let c_hg = cross[h][g]
252 .as_ref()
253 .expect("between-block cross Gram must be precomputed");
254 Self::add_right_times_cross(&mut out, v, group_g, group_h, c_hg.view(), weight);
257
258 let dv_h_g = Self::mixed_cross_gram(v, t, group_h, group_g);
271 let tv_h_g = Self::mixed_cross_gram(t, v, group_h, group_g);
272 let mut d_c_hg = dv_h_g;
273 d_c_hg += &tv_h_g;
274 Self::add_right_times_cross(&mut out, t, group_g, group_h, d_c_hg.view(), weight);
275 }
276 }
277 out
278 }
279
280 fn precompute_cross(&self, t: ArrayView2<'_, f64>) -> Vec<Vec<Option<Array2<f64>>>> {
281 let mut cross = vec![vec![None; self.groups.len()]; self.groups.len()];
282 for g in 0..self.groups.len() {
283 for h in 0..self.groups.len() {
284 if g != h {
285 cross[g][h] = Some(Self::cross_gram(t, &self.groups[g], &self.groups[h]));
286 }
287 }
288 }
289 cross
290 }
291
292 pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
295 let n = target.len();
296 let Some(t) = self.target_matrix(target) else {
297 return Array2::<f64>::zeros((n, n));
298 };
299 let cross = self.precompute_cross(t.view());
300 let weight = self.resolved_weight(rho);
301 let mut dense = Array2::<f64>::zeros((n, n));
302 let mut e = Array1::<f64>::zeros(n);
303 for j in 0..n {
304 e[j] = 1.0;
305 let Some(e_mat) = self.target_matrix(e.view()) else {
306 return Array2::<f64>::zeros((n, n));
307 };
308 let col = self.hvp_with_precomputed_cross(t.view(), &cross, e_mat, weight);
309 for i in 0..n {
310 dense[[i, j]] = col[[i / t.ncols(), i % t.ncols()]];
311 }
312 e[j] = 0.0;
313 }
314 dense
315 }
316}
317
318impl AnalyticPenalty for BlockOrthogonalityPenalty {
319 fn tier(&self) -> PenaltyTier {
320 PenaltyTier::Psi
321 }
322
323 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
324 let Some(t) = self.target_matrix(target) else {
325 return 0.0;
326 };
327 let mut acc = 0.0;
328 for g in 0..self.groups.len() {
329 for h in (g + 1)..self.groups.len() {
330 let c = Self::cross_gram(t.view(), &self.groups[g], &self.groups[h]);
331 for &v in c.iter() {
332 acc += v * v;
333 }
334 }
335 }
336 0.5 * self.resolved_weight(rho) * acc
337 }
338
339 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
340 let Some(t) = self.target_matrix(target) else {
341 return Array1::<f64>::zeros(target.len());
342 };
343 let cross = self.precompute_cross(t.view());
344 let weight = self.resolved_weight(rho);
345 let mut grad = Array2::<f64>::zeros(t.dim());
346 for g in 0..self.groups.len() {
347 for h in 0..self.groups.len() {
348 if g == h {
349 continue;
350 }
351 let c_hg = cross[h][g]
352 .as_ref()
353 .expect("between-block cross Gram must be precomputed");
354 Self::add_right_times_cross(
355 &mut grad,
356 t.view(),
357 &self.groups[g],
358 &self.groups[h],
359 c_hg.view(),
360 weight,
361 );
362 }
363 }
364 Self::flatten_matrix(&grad)
365 }
366
367 fn hvp(
368 &self,
369 target: ArrayView1<'_, f64>,
370 rho: ArrayView1<'_, f64>,
371 v: ArrayView1<'_, f64>,
372 ) -> Array1<f64> {
373 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
374 if target.len() != v.len() {
375 return Array1::<f64>::zeros(target.len());
376 }
377 let Some(t) = self.target_matrix(target) else {
378 return Array1::<f64>::zeros(target.len());
379 };
380 let Some(v_mat) = self.target_matrix(v) else {
381 return Array1::<f64>::zeros(target.len());
382 };
383 let cross = self.precompute_cross(t.view());
384 let hv = self.hvp_with_precomputed_cross(
385 t.view(),
386 &cross,
387 v_mat.view(),
388 self.resolved_weight(rho),
389 );
390 Self::flatten_matrix(&hv)
391 }
392
393 fn hessian_diag(
394 &self,
395 target: ArrayView1<'_, f64>,
396 rho: ArrayView1<'_, f64>,
397 ) -> Option<Array1<f64>> {
398 let t = self.target_matrix(target)?;
399 let n_obs = t.nrows();
400 let d = t.ncols();
401 let weight = self.resolved_weight(rho);
402 let mut group_of = vec![usize::MAX; d];
403 for (gi, group) in self.groups.iter().enumerate() {
404 for &axis in group {
405 group_of[axis] = gi;
406 }
407 }
408 let mut out = Array1::<f64>::zeros(n_obs * d);
409 for n in 0..n_obs {
410 let mut row_sq = 0.0_f64;
411 let mut group_sq = vec![0.0_f64; self.groups.len()];
412 for b in 0..d {
413 let v = t[[n, b]];
414 let v2 = v * v;
415 row_sq += v2;
416 group_sq[group_of[b]] += v2;
417 }
418 for a in 0..d {
419 let g = group_of[a];
420 out[n * d + a] = weight * (row_sq - group_sq[g]);
421 }
422 }
423 Some(out)
424 }
425
426 impl_learnable_weight_grad_rho!();
427
428 impl_learnable_weight_rho_count!();
429
430 fn name(&self) -> &str {
431 "block_orthogonality"
432 }
433
434 impl_scalar_apply_schedule!(weight);
435}
436
/// Penalty discouraging overlap between decoder atom blocks that are coupled in `pairs`.
///
/// `target` is split into `k_atoms` consecutive blocks; block `k` is a row-major
/// `(block_sizes[k], p_out)` matrix `B_k`. The value is
/// `0.5 · w · Σ_{(j, k, w_jk) ∈ pairs} w_jk · ‖B_j B_kᵀ‖²_F`, with `w` the (optionally
/// learnable) weight.
#[derive(Debug, Clone)]
pub struct DecoderIncoherencePenalty {
    /// Slice descriptor for the values this penalty acts on.
    pub target: PsiSlice,
    /// Number of atoms (rows) per block; every entry is > 0 after construction.
    pub block_sizes: Vec<usize>,
    /// Width of each atom row; the constructors require `Σ_k block_sizes[k] · p_out`
    /// to equal `target.len()`.
    pub p_out: usize,
    /// Number of atom blocks (`block_sizes.len()` as set by the constructors).
    pub k_atoms: usize,
    /// Coupled block pairs `(j, k, w_jk)` with `j < k` and non-negative `w_jk`; the
    /// constructors drop zero-weight pairs.
    pub pairs: Vec<(usize, usize, f64)>,
    /// Base penalty weight; the constructors require it to be finite and positive.
    pub weight: f64,
    /// When true the effective weight is `resolve_learnable_weight(weight, rho[rho_index])`.
    pub learnable_weight: bool,
    /// Index into `rho` that is read when `learnable_weight` is set (constructors set it to 0).
    pub rho_index: usize,
    /// Optional weight schedule; the constructors leave it `None`.
    // NOTE(review): consumed by the `impl_*schedule!` macros defined elsewhere — confirm there.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
509
510impl DecoderIncoherencePenalty {
511 #[must_use = "build error must be handled"]
512 pub fn new(
513 target: PsiSlice,
514 block_sizes: Vec<usize>,
515 p_out: usize,
516 coactivation: Array2<f64>,
517 weight: f64,
518 learnable_weight: bool,
519 ) -> Result<Self, String> {
520 if target.is_empty() {
521 return Err("DecoderIncoherencePenalty::new requires a non-empty target".to_string());
522 }
523 if !(weight.is_finite() && weight > 0.0) {
524 return Err(format!(
525 "DecoderIncoherencePenalty::new requires finite weight > 0, got {weight}"
526 ));
527 }
528 if p_out == 0 {
529 return Err("DecoderIncoherencePenalty::new requires p_out > 0".to_string());
530 }
531 if block_sizes.len() < 2 {
532 return Err(
533 "DecoderIncoherencePenalty::new requires at least two atom blocks".to_string(),
534 );
535 }
536 let k = block_sizes.len();
537 if coactivation.dim() != (k, k) {
538 return Err(format!(
539 "DecoderIncoherencePenalty::new requires (K, K)=({k}, {k}) coactivation; got {:?}",
540 coactivation.dim()
541 ));
542 }
543 if !coactivation
544 .iter()
545 .all(|value| value.is_finite() && *value >= 0.0)
546 {
547 return Err(
548 "DecoderIncoherencePenalty::new requires finite non-negative coactivation entries"
549 .to_string(),
550 );
551 }
552 let mut total = 0usize;
553 for (atom_idx, &m) in block_sizes.iter().enumerate() {
554 if m == 0 {
555 return Err(format!(
556 "DecoderIncoherencePenalty::new block_sizes[{atom_idx}] must be > 0"
557 ));
558 }
559 let span = m.checked_mul(p_out).ok_or_else(|| {
560 "DecoderIncoherencePenalty::new block span overflows usize".to_string()
561 })?;
562 total = total.checked_add(span).ok_or_else(|| {
563 "DecoderIncoherencePenalty::new total span overflows usize".to_string()
564 })?;
565 }
566 if total != target.len() {
567 return Err(format!(
568 "DecoderIncoherencePenalty::new Σ_k M_k·p_out = {total} does not match target length {}",
569 target.len()
570 ));
571 }
572 let mut pairs = Vec::new();
577 for j in 0..k {
578 for kk in (j + 1)..k {
579 let w = 0.5 * (coactivation[[j, kk]] + coactivation[[kk, j]]);
580 if w != 0.0 {
581 pairs.push((j, kk, w));
582 }
583 }
584 }
585 Ok(Self {
586 target,
587 block_sizes,
588 p_out,
589 k_atoms: k,
590 pairs,
591 weight,
592 learnable_weight,
593 rho_index: 0,
594 weight_schedule: None,
595 })
596 }
597
598 #[must_use = "build error must be handled"]
606 pub fn new_sparse(
607 target: PsiSlice,
608 block_sizes: Vec<usize>,
609 p_out: usize,
610 pairs: Vec<(usize, usize, f64)>,
611 weight: f64,
612 learnable_weight: bool,
613 ) -> Result<Self, String> {
614 if target.is_empty() {
615 return Err(
616 "DecoderIncoherencePenalty::new_sparse requires a non-empty target".to_string(),
617 );
618 }
619 if !(weight.is_finite() && weight > 0.0) {
620 return Err(format!(
621 "DecoderIncoherencePenalty::new_sparse requires finite weight > 0, got {weight}"
622 ));
623 }
624 if p_out == 0 {
625 return Err("DecoderIncoherencePenalty::new_sparse requires p_out > 0".to_string());
626 }
627 if block_sizes.len() < 2 {
628 return Err(
629 "DecoderIncoherencePenalty::new_sparse requires at least two atom blocks"
630 .to_string(),
631 );
632 }
633 let k = block_sizes.len();
634 let mut total = 0usize;
635 for (atom_idx, &m) in block_sizes.iter().enumerate() {
636 if m == 0 {
637 return Err(format!(
638 "DecoderIncoherencePenalty::new_sparse block_sizes[{atom_idx}] must be > 0"
639 ));
640 }
641 let span = m.checked_mul(p_out).ok_or_else(|| {
642 "DecoderIncoherencePenalty::new_sparse block span overflows usize".to_string()
643 })?;
644 total = total.checked_add(span).ok_or_else(|| {
645 "DecoderIncoherencePenalty::new_sparse total span overflows usize".to_string()
646 })?;
647 }
648 if total != target.len() {
649 return Err(format!(
650 "DecoderIncoherencePenalty::new_sparse Σ_k M_k·p_out = {total} does not match target length {}",
651 target.len()
652 ));
653 }
654 let mut clean = Vec::with_capacity(pairs.len());
655 for (j, kk, w) in pairs {
656 if j >= k || kk >= k {
657 return Err(format!(
658 "DecoderIncoherencePenalty::new_sparse pair ({j}, {kk}) out of range K={k}"
659 ));
660 }
661 if j >= kk {
662 return Err(format!(
663 "DecoderIncoherencePenalty::new_sparse requires j < k for each pair, got ({j}, {kk})"
664 ));
665 }
666 if !(w.is_finite() && w >= 0.0) {
667 return Err(format!(
668 "DecoderIncoherencePenalty::new_sparse requires finite non-negative pair weight, got {w}"
669 ));
670 }
671 if w != 0.0 {
672 clean.push((j, kk, w));
673 }
674 }
675 Ok(Self {
676 target,
677 block_sizes,
678 p_out,
679 k_atoms: k,
680 pairs: clean,
681 weight,
682 learnable_weight,
683 rho_index: 0,
684 weight_schedule: None,
685 })
686 }
687
688 impl_with_weight_schedule!(weight);
689
690 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
691 if self.learnable_weight {
692 resolve_learnable_weight(self.weight, rho[self.rho_index])
693 } else {
694 self.weight
695 }
696 }
697
698 fn block_offsets(&self) -> Vec<usize> {
702 let mut out = Vec::with_capacity(self.block_sizes.len());
703 let mut cursor = self.target.range.start;
704 for &m in &self.block_sizes {
705 out.push(cursor);
706 cursor += m * self.p_out;
707 }
708 out
709 }
710
711 fn cross_gram(
713 target: ArrayView1<'_, f64>,
714 off_j: usize,
715 m_j: usize,
716 off_k: usize,
717 m_k: usize,
718 p_out: usize,
719 ) -> Array2<f64> {
720 let mut out = Array2::<f64>::zeros((m_j, m_k));
721 for a in 0..m_j {
722 for b in 0..m_k {
723 let mut s = 0.0;
724 for o in 0..p_out {
725 s += target[off_j + a * p_out + o] * target[off_k + b * p_out + o];
726 }
727 out[[a, b]] = s;
728 }
729 }
730 out
731 }
732
733 fn hvp_impl(
741 &self,
742 target: ArrayView1<'_, f64>,
743 rho: ArrayView1<'_, f64>,
744 v: ArrayView1<'_, f64>,
745 include_residual: bool,
746 ) -> Array1<f64> {
747 let mut out = Array1::<f64>::zeros(target.len());
748 if target.len() != self.target.len() {
749 return out;
750 }
751 let offsets = self.block_offsets();
752 let weight = self.resolved_weight(rho);
753 let p_out = self.p_out;
754 for &(j, k, w_sym) in &self.pairs {
755 {
756 let w_pair = w_sym * weight;
757 if w_pair == 0.0 {
758 continue;
759 }
760 let off_j = offsets[j];
761 let off_k = offsets[k];
762 let m_j = self.block_sizes[j];
763 let m_k = self.block_sizes[k];
764 let mut d_c = Array2::<f64>::zeros((m_j, m_k));
767 for a in 0..m_j {
768 for b in 0..m_k {
769 let mut s = 0.0;
770 for o in 0..p_out {
771 s += v[off_j + a * p_out + o] * target[off_k + b * p_out + o]
772 + target[off_j + a * p_out + o] * v[off_k + b * p_out + o];
773 }
774 d_c[[a, b]] = s;
775 }
776 }
777 let c = if include_residual {
780 Some(Self::cross_gram(target, off_j, m_j, off_k, m_k, p_out))
781 } else {
782 None
783 };
784 for a in 0..m_j {
786 for o in 0..p_out {
787 let mut s = 0.0;
788 for b in 0..m_k {
789 s += d_c[[a, b]] * target[off_k + b * p_out + o];
790 if let Some(c) = &c {
791 s += c[[a, b]] * v[off_k + b * p_out + o];
792 }
793 }
794 out[off_j + a * p_out + o] += w_pair * s;
795 }
796 }
797 for b in 0..m_k {
799 for o in 0..p_out {
800 let mut s = 0.0;
801 for a in 0..m_j {
802 s += d_c[[a, b]] * target[off_j + a * p_out + o];
803 if let Some(c) = &c {
804 s += c[[a, b]] * v[off_j + a * p_out + o];
805 }
806 }
807 out[off_k + b * p_out + o] += w_pair * s;
808 }
809 }
810 }
811 }
812 out
813 }
814
815 pub fn accumulate_psd_majorizer_dense(
837 &self,
838 target: ArrayView1<'_, f64>,
839 rho: ArrayView1<'_, f64>,
840 scale: f64,
841 hbb: &mut Array2<f64>,
842 ) {
843 if target.len() != self.target.len() {
844 return;
845 }
846 let offsets = self.block_offsets();
847 let weight = self.resolved_weight(rho);
848 let p = self.p_out;
849 for &(j, k, w_sym) in &self.pairs {
850 let w = w_sym * weight * scale;
851 if w == 0.0 {
852 continue;
853 }
854 let off_j = offsets[j];
855 let off_k = offsets[k];
856 let m_j = self.block_sizes[j];
857 let m_k = self.block_sizes[k];
858 let mut g_j = vec![0.0_f64; p * p];
861 let mut g_k = vec![0.0_f64; p * p];
862 for o in 0..p {
863 for o2 in 0..p {
864 let mut sj = 0.0;
865 for a in 0..m_j {
866 sj += target[off_j + a * p + o] * target[off_j + a * p + o2];
867 }
868 g_j[o * p + o2] = sj;
869 let mut sk = 0.0;
870 for b in 0..m_k {
871 sk += target[off_k + b * p + o] * target[off_k + b * p + o2];
872 }
873 g_k[o * p + o2] = sk;
874 }
875 }
876 for a in 0..m_j {
878 let base = off_j + a * p;
879 for o in 0..p {
880 for o2 in 0..p {
881 hbb[[base + o, base + o2]] += w * g_k[o * p + o2];
882 }
883 }
884 }
885 for b in 0..m_k {
887 let base = off_k + b * p;
888 for o in 0..p {
889 for o2 in 0..p {
890 hbb[[base + o, base + o2]] += w * g_j[o * p + o2];
891 }
892 }
893 }
894 for a in 0..m_j {
897 for b in 0..m_k {
898 for o1 in 0..p {
899 let row_j = off_j + a * p + o1;
900 let bk_b_o1 = target[off_k + b * p + o1];
901 for o2 in 0..p {
902 let col_k = off_k + b * p + o2;
903 let contrib = w * target[off_j + a * p + o2] * bk_b_o1;
904 hbb[[row_j, col_k]] += contrib;
905 hbb[[col_k, row_j]] += contrib;
906 }
907 }
908 }
909 }
910 }
911 }
912}
913
914impl AnalyticPenalty for DecoderIncoherencePenalty {
915 fn tier(&self) -> PenaltyTier {
916 PenaltyTier::Beta
917 }
918
919 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
920 if target.len() != self.target.len() {
921 return 0.0;
922 }
923 let offsets = self.block_offsets();
924 let mut acc = 0.0;
925 for &(j, k, w_pair) in &self.pairs {
926 {
927 if w_pair == 0.0 {
928 continue;
929 }
930 let c = Self::cross_gram(
931 target,
932 offsets[j],
933 self.block_sizes[j],
934 offsets[k],
935 self.block_sizes[k],
936 self.p_out,
937 );
938 let mut frob_sq = 0.0;
939 for &value in c.iter() {
940 frob_sq += value * value;
941 }
942 acc += w_pair * frob_sq;
943 }
944 }
945 0.5 * self.resolved_weight(rho) * acc
946 }
947
948 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
949 let mut grad = Array1::<f64>::zeros(target.len());
950 if target.len() != self.target.len() {
951 return grad;
952 }
953 let offsets = self.block_offsets();
954 let weight = self.resolved_weight(rho);
955 for &(j, k, w_sym) in &self.pairs {
956 {
957 let w_pair = w_sym * weight;
958 if w_pair == 0.0 {
959 continue;
960 }
961 let off_j = offsets[j];
962 let off_k = offsets[k];
963 let m_j = self.block_sizes[j];
964 let m_k = self.block_sizes[k];
965 let c = Self::cross_gram(target, off_j, m_j, off_k, m_k, self.p_out);
966 for a in 0..m_j {
968 for o in 0..self.p_out {
969 let mut s = 0.0;
970 for b in 0..m_k {
971 s += c[[a, b]] * target[off_k + b * self.p_out + o];
972 }
973 grad[off_j + a * self.p_out + o] += w_pair * s;
974 }
975 }
976 for b in 0..m_k {
978 for o in 0..self.p_out {
979 let mut s = 0.0;
980 for a in 0..m_j {
981 s += c[[a, b]] * target[off_j + a * self.p_out + o];
982 }
983 grad[off_k + b * self.p_out + o] += w_pair * s;
984 }
985 }
986 }
987 }
988 grad
989 }
990
991 fn hvp(
1007 &self,
1008 target: ArrayView1<'_, f64>,
1009 rho: ArrayView1<'_, f64>,
1010 v: ArrayView1<'_, f64>,
1011 ) -> Array1<f64> {
1012 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1013 self.hvp_impl(target, rho, v, true)
1014 }
1015
1016 fn psd_majorizer_hvp(
1029 &self,
1030 target: ArrayView1<'_, f64>,
1031 rho: ArrayView1<'_, f64>,
1032 v: ArrayView1<'_, f64>,
1033 ) -> Array1<f64> {
1034 assert_eq!(
1035 target.len(),
1036 v.len(),
1037 "psd_majorizer_hvp dimension mismatch"
1038 );
1039 self.hvp_impl(target, rho, v, false)
1040 }
1041
1042 impl_learnable_weight_grad_rho!();
1048
1049 impl_learnable_weight_rho_count!();
1050
1051 fn name(&self) -> &str {
1052 "decoder_incoherence"
1053 }
1054
1055 impl_scalar_apply_schedule!(weight);
1056}
1057
/// Soft orthonormality (Stiefel) penalty on the columns of the target matrix.
///
/// The flat `target` is viewed as a row-major `(n_obs, latent_dim)` matrix `T`, with
/// `n_obs = n_eff`. The value is `0.5 · (w / n_eff) · ‖TᵀT − I‖²_F`, where `w` is the
/// (optionally learnable) weight.
#[derive(Debug, Clone)]
pub struct OrthogonalityPenalty {
    /// Slice descriptor for the values this penalty acts on.
    pub target: PsiSlice,
    /// Number of columns of the reshaped target; `new` requires it to be > 0 and to divide
    /// the target length.
    pub latent_dim: usize,
    /// Base penalty weight; `new` requires it to be finite and positive.
    pub weight: f64,
    /// Row count of the reshaped target; also the normaliser in `scale`.
    pub n_eff: usize,
    /// When true the effective weight is `resolve_learnable_weight(weight, rho[rho_index])`.
    pub learnable_weight: bool,
    /// Index into `rho` that is read when `learnable_weight` is set (`new` sets it to 0).
    pub rho_index: usize,
    /// Optional weight schedule; `new` leaves it `None`.
    // NOTE(review): consumed by the `impl_*schedule!` macros defined elsewhere — confirm there.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
1081
1082impl OrthogonalityPenalty {
1083 #[must_use = "build error must be handled"]
1084 pub fn new(
1085 target: PsiSlice,
1086 latent_dim: usize,
1087 weight: f64,
1088 n_eff: usize,
1089 learnable_weight: bool,
1090 ) -> Result<Self, String> {
1091 if latent_dim == 0 {
1092 return Err("OrthogonalityPenalty::new requires latent_dim > 0".to_string());
1093 }
1094 if !target.len().is_multiple_of(latent_dim) {
1095 return Err(format!(
1096 "OrthogonalityPenalty::new target length {} is not divisible by latent_dim {}",
1097 target.len(),
1098 latent_dim
1099 ));
1100 }
1101 let n_obs = target.len() / latent_dim;
1102 if n_obs < latent_dim {
1103 return Err(format!(
1104 "OrthogonalityPenalty::new requires n_obs >= latent_dim for a feasible \
1105 Stiefel target, got n_obs {n_obs} and latent_dim {latent_dim}"
1106 ));
1107 }
1108 if !(weight.is_finite() && weight > 0.0) {
1109 return Err(format!(
1110 "OrthogonalityPenalty::new requires finite weight > 0, got {weight}"
1111 ));
1112 }
1113 if n_eff == 0 {
1114 return Err("OrthogonalityPenalty::new requires n_eff > 0".to_string());
1115 }
1116 if n_eff != n_obs {
1117 return Err(format!(
1118 "OrthogonalityPenalty::new requires n_eff to match target rows, got \
1119 n_eff {n_eff} and target rows {n_obs}"
1120 ));
1121 }
1122 Ok(Self {
1123 target,
1124 latent_dim,
1125 weight,
1126 n_eff,
1127 learnable_weight,
1128 rho_index: 0,
1129 weight_schedule: None,
1130 })
1131 }
1132
1133 impl_with_weight_schedule!(weight);
1134
1135 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
1136 if self.learnable_weight {
1137 resolve_learnable_weight(self.weight, rho[self.rho_index])
1138 } else {
1139 self.weight
1140 }
1141 }
1142
1143 pub(crate) fn scale(&self, rho: ArrayView1<'_, f64>) -> f64 {
1144 self.resolved_weight(rho) / self.n_eff as f64
1145 }
1146
1147 pub(crate) fn target_matrix<'a>(
1148 &self,
1149 target: ArrayView1<'a, f64>,
1150 ) -> Option<ArrayView2<'a, f64>> {
1151 let d = self.latent_dim;
1152 if !target.len().is_multiple_of(d) {
1153 assert_eq!(
1154 target.len() % d,
1155 0,
1156 "target length must be divisible by latent_dim"
1157 );
1158 return None;
1159 }
1160 let n_obs = target.len() / d;
1161 target.into_shape_with_order((n_obs, d)).ok()
1162 }
1163
1164 pub(crate) fn gram_minus_identity(t: ArrayView2<'_, f64>) -> Array2<f64> {
1165 let n_obs = t.nrows();
1166 let d = t.ncols();
1167 let mut gram = Array2::<f64>::zeros((d, d));
1168 for a in 0..d {
1169 for b in 0..d {
1170 let mut s = 0.0;
1171 for n in 0..n_obs {
1172 s += t[[n, a]] * t[[n, b]];
1173 }
1174 gram[[a, b]] = s;
1175 }
1176 gram[[a, a]] -= 1.0;
1177 }
1178 gram
1179 }
1180
1181 fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
1182 let n_obs = m.nrows();
1183 let d = m.ncols();
1184 let mut out = Array1::<f64>::zeros(n_obs * d);
1185 for n in 0..n_obs {
1186 for a in 0..d {
1187 out[n * d + a] = m[[n, a]];
1188 }
1189 }
1190 out
1191 }
1192
1193 pub(crate) fn hvp_with_precomputed_m(
1194 &self,
1195 t: ArrayView2<'_, f64>,
1196 m: ArrayView2<'_, f64>,
1197 v: ArrayView2<'_, f64>,
1198 scale: f64,
1199 ) -> Array2<f64> {
1200 let n_obs = t.nrows();
1201 let d = t.ncols();
1202 assert_eq!(v.dim(), t.dim(), "hvp matrix dimension mismatch");
1203 assert_eq!(m.dim(), (d, d), "precomputed gram dimension mismatch");
1204 if v.dim() != t.dim() {
1205 return Array2::<f64>::zeros((n_obs, d));
1206 }
1207
1208 let mut vt_t_plus_tt_v = Array2::<f64>::zeros((d, d));
1209 for c in 0..d {
1210 for b in 0..d {
1211 let mut s = 0.0;
1212 for n in 0..n_obs {
1213 s += v[[n, c]] * t[[n, b]] + t[[n, c]] * v[[n, b]];
1214 }
1215 vt_t_plus_tt_v[[c, b]] = s;
1216 }
1217 }
1218
1219 let mut out = Array2::<f64>::zeros((n_obs, d));
1220 for n in 0..n_obs {
1221 for b in 0..d {
1222 let mut va = 0.0;
1223 let mut tb = 0.0;
1224 for c in 0..d {
1225 va += v[[n, c]] * m[[c, b]];
1226 tb += t[[n, c]] * vt_t_plus_tt_v[[c, b]];
1227 }
1228 out[[n, b]] = 2.0 * scale * (va + tb);
1229 }
1230 }
1231 out
1232 }
1233
1234 pub(crate) fn as_dense_with_precomputed_m(
1235 &self,
1236 t: ArrayView2<'_, f64>,
1237 m: ArrayView2<'_, f64>,
1238 scale: f64,
1239 ) -> Array2<f64> {
1240 let n_obs = t.nrows();
1241 let d = t.ncols();
1242 assert_eq!(m.dim(), (d, d), "precomputed gram dimension mismatch");
1243 if m.dim() != (d, d) {
1244 return Array2::<f64>::zeros((n_obs * d, n_obs * d));
1245 }
1246
1247 let mut dense = Array2::<f64>::zeros((n_obs * d, n_obs * d));
1248 let factor = 2.0 * scale;
1249 for row1 in 0..n_obs {
1250 for row2 in 0..n_obs {
1251 let mut row_dot = 0.0;
1252 for axis in 0..d {
1253 row_dot += t[[row1, axis]] * t[[row2, axis]];
1254 }
1255 for col1 in 0..d {
1256 let i = row1 * d + col1;
1257 for col2 in 0..d {
1258 let j = row2 * d + col2;
1259 let mut entry = t[[row1, col2]] * t[[row2, col1]];
1260 if row1 == row2 {
1261 entry += m[[col2, col1]];
1262 }
1263 if col1 == col2 {
1264 entry += row_dot;
1265 }
1266 dense[[i, j]] = factor * entry;
1267 }
1268 }
1269 }
1270 }
1271 dense
1272 }
1273}
1274
1275impl AnalyticPenalty for OrthogonalityPenalty {
1276 fn tier(&self) -> PenaltyTier {
1277 PenaltyTier::Psi
1278 }
1279
1280 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1281 let Some(t) = self.target_matrix(target) else {
1282 return 0.0;
1283 };
1284 let gram = Self::gram_minus_identity(t.view());
1285 let mut acc = 0.0;
1286 for &v in gram.iter() {
1287 acc += v * v;
1288 }
1289 0.5 * self.scale(rho) * acc
1290 }
1291
1292 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1293 let Some(t) = self.target_matrix(target) else {
1297 return Array1::<f64>::zeros(target.len());
1298 };
1299 let gram = Self::gram_minus_identity(t.view());
1300 let n_obs = t.nrows();
1301 let d = t.ncols();
1302 let factor = 2.0 * self.scale(rho);
1303 let mut grad = Array2::<f64>::zeros((n_obs, d));
1304 for n in 0..n_obs {
1305 for a in 0..d {
1306 let mut s = 0.0;
1307 for b in 0..d {
1308 s += t[[n, b]] * gram[[b, a]];
1309 }
1310 grad[[n, a]] = factor * s;
1311 }
1312 }
1313 Self::flatten_matrix(&grad)
1314 }
1315
1316 fn hvp(
1317 &self,
1318 target: ArrayView1<'_, f64>,
1319 rho: ArrayView1<'_, f64>,
1320 v: ArrayView1<'_, f64>,
1321 ) -> Array1<f64> {
1322 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1323 if target.len() != v.len() {
1324 return Array1::<f64>::zeros(target.len());
1325 }
1326 let Some(t) = self.target_matrix(target) else {
1327 return Array1::<f64>::zeros(target.len());
1328 };
1329 let Some(v_mat) = self.target_matrix(v) else {
1330 return Array1::<f64>::zeros(target.len());
1331 };
1332 let m = Self::gram_minus_identity(t.view());
1333 let hv = self.hvp_with_precomputed_m(t.view(), m.view(), v_mat.view(), self.scale(rho));
1334 Self::flatten_matrix(&hv)
1335 }
1336
1337 impl_learnable_weight_grad_rho!();
1338
1339 impl_learnable_weight_rho_count!();
1340
1341 fn name(&self) -> &str {
1342 "orthogonality"
1343 }
1344
1345 impl_scalar_apply_schedule!(weight);
1346}