1use super::*;
6
/// Undirected coupling between two coefficient blocks.
///
/// [`BetaCouplingGraph::build`] always stores the pair with `a < b`.
#[derive(Clone, Debug)]
pub(crate) struct BetaEdge {
    /// Smaller block index of the pair.
    pub(crate) a: usize,
    /// Larger block index of the pair.
    pub(crate) b: usize,
}

/// Block-level coupling graph stored as compressed adjacency lists.
///
/// `adj_start` holds `num_blocks + 1` prefix offsets. The neighbours of block
/// `v` are `adj_targets[adj_start[v]..adj_start[v + 1]]`.
#[derive(Clone, Debug)]
pub(crate) struct BetaCouplingGraph {
    /// Number of nodes (coefficient blocks).
    pub(crate) num_blocks: usize,
    /// Distinct undirected edges, sorted by `(a, b)`.
    pub(crate) edges: Vec<BetaEdge>,
    /// Prefix offsets into `adj_targets` (length `num_blocks + 1`).
    pub(crate) adj_start: Vec<usize>,
    /// Concatenated neighbour lists of all blocks.
    pub(crate) adj_targets: Vec<usize>,
}
20
21impl BetaCouplingGraph {
22 pub(crate) fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
23 let num_blocks = block_offsets.len();
24 if num_blocks == 0 {
25 return Self {
26 num_blocks: 0,
27 edges: Vec::new(),
28 adj_start: vec![0],
29 adj_targets: Vec::new(),
30 };
31 }
32
33 let mut edge_set = Vec::<(usize, usize)>::new();
34 for row in htbeta_rows {
35 let mut active = Vec::<usize>::new();
36 for (block, range) in block_offsets.iter().enumerate() {
37 if range
38 .clone()
39 .any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
40 {
41 active.push(block);
42 }
43 }
44 for i in 0..active.len() {
45 for j in (i + 1)..active.len() {
46 edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
47 }
48 }
49 }
50 edge_set.sort_unstable();
51 edge_set.dedup();
52
53 let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
54 let mut degree = vec![0usize; num_blocks];
55 for &BetaEdge { a, b } in &edges {
56 degree[a] += 1;
57 degree[b] += 1;
58 }
59 let mut adj_start = vec![0usize; num_blocks + 1];
60 for block in 0..num_blocks {
61 adj_start[block + 1] = adj_start[block] + degree[block];
62 }
63 let mut adj_targets = vec![0usize; adj_start[num_blocks]];
64 let mut cursor = adj_start[..num_blocks].to_vec();
65 for &BetaEdge { a, b } in &edges {
66 adj_targets[cursor[a]] = b;
67 cursor[a] += 1;
68 adj_targets[cursor[b]] = a;
69 cursor[b] += 1;
70 }
71 Self {
72 num_blocks,
73 edges,
74 adj_start,
75 adj_targets,
76 }
77 }
78
79 pub(crate) fn neighbours(&self, node: usize) -> &[usize] {
80 &self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
81 }
82
83 pub(crate) fn component_partition(&self) -> Vec<Vec<usize>> {
84 let mut parent: Vec<usize> = (0..self.num_blocks).collect();
85 let mut rank = vec![0u8; self.num_blocks];
86
87 fn find(parent: &mut [usize], mut x: usize) -> usize {
88 while parent[x] != x {
89 parent[x] = parent[parent[x]];
90 x = parent[x];
91 }
92 x
93 }
94
95 for &BetaEdge { a, b } in &self.edges {
96 let lhs = find(&mut parent, a);
97 let rhs = find(&mut parent, b);
98 if lhs != rhs {
99 if rank[lhs] < rank[rhs] {
100 parent[lhs] = rhs;
101 } else if rank[lhs] > rank[rhs] {
102 parent[rhs] = lhs;
103 } else {
104 parent[rhs] = lhs;
105 rank[lhs] += 1;
106 }
107 }
108 }
109
110 let mut label_map = vec![usize::MAX; self.num_blocks];
111 let mut parts = Vec::<Vec<usize>>::new();
112 for block in 0..self.num_blocks {
113 let root = find(&mut parent, block);
114 let label = if label_map[root] == usize::MAX {
115 label_map[root] = parts.len();
116 parts.push(Vec::new());
117 label_map[root]
118 } else {
119 label_map[root]
120 };
121 parts[label].push(block);
122 }
123 parts
124 }
125
126 pub(crate) fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
127 let mut expanded = seed.to_vec();
128 for &block in seed {
129 expanded.extend_from_slice(self.neighbours(block));
130 }
131 expanded.sort_unstable();
132 expanded.dedup();
133 expanded
134 }
135}
/// Index of a coefficient block.
///
/// It selects one entry of the `offsets: &[Range<usize>]` slice handed to
/// `BetaPenaltyOp::block`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BetaBlockId(pub usize);
147
148pub trait BetaPenaltyOp: Send + Sync {
155 fn dim(&self) -> usize;
157 fn matvec(&self, x: &[f64], y: &mut [f64]);
159 fn gradient(&self, beta: &[f64], out: &mut [f64]);
161 fn diagonal(&self, diag: &mut [f64]);
163 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
167 fn to_dense(&self) -> Array2<f64>;
170 fn row_abs_sums(&self) -> Array1<f64> {
178 let dense = self.to_dense();
179 let k = dense.nrows();
180 let mut out = Array1::<f64>::zeros(k);
181 for r in 0..k {
182 let mut s = 0.0_f64;
183 for c in 0..dense.ncols() {
184 s += dense[[r, c]].abs();
185 }
186 out[r] = s;
187 }
188 out
189 }
190 fn fingerprint(&self, hasher: &mut Fingerprinter);
198}
199
200pub struct DensePenaltyOp(pub Array2<f64>);
202
203impl BetaPenaltyOp for DensePenaltyOp {
204 fn dim(&self) -> usize {
205 self.0.nrows()
206 }
207
208 fn matvec(&self, x: &[f64], y: &mut [f64]) {
209 let k = self.0.nrows();
210 for a in 0..k {
211 let mut acc = 0.0_f64;
212 for b in 0..k {
213 acc += self.0[[a, b]] * x[b];
214 }
215 y[a] += acc;
216 }
217 }
218
219 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
220 let k = self.0.nrows();
221 for a in 0..k {
222 let mut acc = 0.0_f64;
223 for b in 0..k {
224 acc += self.0[[a, b]] * beta[b];
225 }
226 out[a] += acc;
227 }
228 }
229
230 fn diagonal(&self, diag: &mut [f64]) {
231 let k = self.0.nrows().min(diag.len());
232 for j in 0..k {
233 diag[j] += self.0[[j, j]];
234 }
235 }
236
237 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
238 let range = &offsets[id.0];
239 let b = range.end - range.start;
240 for bi in 0..b {
241 for bj in 0..b {
242 out[[bi, bj]] += self.0[[range.start + bi, range.start + bj]];
243 }
244 }
245 }
246
247 fn to_dense(&self) -> Array2<f64> {
248 self.0.clone()
249 }
250
251 fn fingerprint(&self, hasher: &mut Fingerprinter) {
252 hasher.write_str("dense-penalty-op-v1");
253 hasher.write_f64_array2(&self.0);
254 }
255}
256
257pub struct BlockPenaltyOp {
264 pub k: usize,
266 pub blocks: Vec<(usize, Array2<f64>)>,
268}
269
270impl BetaPenaltyOp for BlockPenaltyOp {
271 fn dim(&self) -> usize {
272 self.k
273 }
274
275 fn matvec(&self, x: &[f64], y: &mut [f64]) {
276 for (off, local) in &self.blocks {
277 let b = local.nrows();
278 for i in 0..b {
279 let gi = off + i;
280 let mut acc = 0.0_f64;
281 for j in 0..b {
282 acc += local[[i, j]] * x[off + j];
283 }
284 y[gi] += acc;
285 }
286 }
287 }
288
289 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
290 for (off, local) in &self.blocks {
291 let b = local.nrows();
292 for i in 0..b {
293 let gi = off + i;
294 let mut acc = 0.0_f64;
295 for j in 0..b {
296 acc += local[[i, j]] * beta[off + j];
297 }
298 out[gi] += acc;
299 }
300 }
301 }
302
303 fn diagonal(&self, diag: &mut [f64]) {
304 for (off, local) in &self.blocks {
305 let b = local.nrows();
306 for j in 0..b {
307 diag[off + j] += local[[j, j]];
308 }
309 }
310 }
311
312 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
313 let range = &offsets[id.0];
314 let b_out = range.end - range.start;
315 for (off, local) in &self.blocks {
316 let b = local.nrows();
317 let block_end = off + b;
318 if block_end <= range.start || *off >= range.end {
319 continue;
320 }
321 for bi in 0..b_out {
322 let gi = range.start + bi;
323 if gi < *off || gi >= block_end {
324 continue;
325 }
326 let li = gi - off;
327 for bj in 0..b_out {
328 let gj = range.start + bj;
329 if gj < *off || gj >= block_end {
330 continue;
331 }
332 let lj = gj - off;
333 out[[bi, bj]] += local[[li, lj]];
334 }
335 }
336 }
337 }
338
339 fn to_dense(&self) -> Array2<f64> {
340 let mut out = Array2::<f64>::zeros((self.k, self.k));
341 for (off, local) in &self.blocks {
342 let b = local.nrows();
343 for i in 0..b {
344 for j in 0..b {
345 out[[off + i, off + j]] += local[[i, j]];
346 }
347 }
348 }
349 out
350 }
351
352 fn fingerprint(&self, hasher: &mut Fingerprinter) {
353 hasher.write_str("block-penalty-op-v1");
354 hasher.write_usize(self.k);
355 hasher.write_usize(self.blocks.len());
356 for (off, local) in &self.blocks {
357 hasher.write_usize(*off);
358 hasher.write_f64_array2(local);
359 }
360 }
361}
362
363pub struct KroneckerPenaltyOp {
366 pub factor_a: Array2<f64>,
368 pub factor_b: Array2<f64>,
370 pub global_offset: usize,
372 pub k: usize,
374}
375
376impl BetaPenaltyOp for KroneckerPenaltyOp {
377 fn dim(&self) -> usize {
378 self.k
379 }
380
381 fn matvec(&self, x: &[f64], y: &mut [f64]) {
382 let p_a = self.factor_a.nrows();
383 let p_b = self.factor_b.nrows();
384 let off = self.global_offset;
385 for i_a in 0..p_a {
387 for i_b in 0..p_b {
388 let gi = off + i_a * p_b + i_b;
389 let mut acc = 0.0_f64;
390 for j_a in 0..p_a {
391 let a_ij = self.factor_a[[i_a, j_a]];
392 if a_ij == 0.0 {
393 continue;
394 }
395 for j_b in 0..p_b {
396 acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
397 }
398 }
399 y[gi] += acc;
400 }
401 }
402 }
403
404 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
405 let p_a = self.factor_a.nrows();
406 let p_b = self.factor_b.nrows();
407 let off = self.global_offset;
408 for i_a in 0..p_a {
409 for i_b in 0..p_b {
410 let gi = off + i_a * p_b + i_b;
411 let mut acc = 0.0_f64;
412 for j_a in 0..p_a {
413 let a_ij = self.factor_a[[i_a, j_a]];
414 if a_ij == 0.0 {
415 continue;
416 }
417 for j_b in 0..p_b {
418 acc += a_ij * self.factor_b[[i_b, j_b]] * beta[off + j_a * p_b + j_b];
419 }
420 }
421 out[gi] += acc;
422 }
423 }
424 }
425
426 fn diagonal(&self, diag: &mut [f64]) {
427 let p_a = self.factor_a.nrows();
428 let p_b = self.factor_b.nrows();
429 let off = self.global_offset;
430 for i_a in 0..p_a {
431 for i_b in 0..p_b {
432 diag[off + i_a * p_b + i_b] +=
433 self.factor_a[[i_a, i_a]] * self.factor_b[[i_b, i_b]];
434 }
435 }
436 }
437
438 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
439 let range = &offsets[id.0];
440 let b = range.end - range.start;
441 let p_a = self.factor_a.nrows();
442 let p_b = self.factor_b.nrows();
443 let off = self.global_offset;
444 let block_end = off + p_a * p_b;
445 if block_end <= range.start || off >= range.end {
446 return;
447 }
448 for bi in 0..b {
449 let gi = range.start + bi;
450 if gi < off || gi >= block_end {
451 continue;
452 }
453 let li = gi - off;
454 let i_a = li / p_b;
455 let i_b = li % p_b;
456 for bj in 0..b {
457 let gj = range.start + bj;
458 if gj < off || gj >= block_end {
459 continue;
460 }
461 let lj = gj - off;
462 let j_a = lj / p_b;
463 let j_b = lj % p_b;
464 out[[bi, bj]] += self.factor_a[[i_a, j_a]] * self.factor_b[[i_b, j_b]];
465 }
466 }
467 }
468
469 fn to_dense(&self) -> Array2<f64> {
470 let p_a = self.factor_a.nrows();
471 let p_b = self.factor_b.nrows();
472 let off = self.global_offset;
473 let mut out = Array2::<f64>::zeros((self.k, self.k));
474 for i_a in 0..p_a {
475 for i_b in 0..p_b {
476 let gi = off + i_a * p_b + i_b;
477 for j_a in 0..p_a {
478 let a_ij = self.factor_a[[i_a, j_a]];
479 if a_ij == 0.0 {
480 continue;
481 }
482 for j_b in 0..p_b {
483 let gj = off + j_a * p_b + j_b;
484 out[[gi, gj]] += a_ij * self.factor_b[[i_b, j_b]];
485 }
486 }
487 }
488 }
489 out
490 }
491
492 fn fingerprint(&self, hasher: &mut Fingerprinter) {
493 hasher.write_str("kronecker-penalty-op-v1");
494 hasher.write_usize(self.global_offset);
495 hasher.write_usize(self.k);
496 hasher.write_f64_array2(&self.factor_a);
497 hasher.write_f64_array2(&self.factor_b);
498 }
499}
500
501pub struct IdentityRightKroneckerPenaltyOp {
509 pub factor_a: Array2<f64>,
511 pub p: usize,
513 pub global_offset: usize,
515 pub k: usize,
517}
518
519impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
520 fn dim(&self) -> usize {
521 self.k
522 }
523
524 fn matvec(&self, x: &[f64], y: &mut [f64]) {
525 let p_a = self.factor_a.nrows();
526 let p = self.p;
527 let off = self.global_offset;
528 for i_a in 0..p_a {
529 for i_b in 0..p {
530 let gi = off + i_a * p + i_b;
531 let mut acc = 0.0_f64;
532 for j_a in 0..p_a {
533 let a_ij = self.factor_a[[i_a, j_a]];
534 if a_ij == 0.0 {
535 continue;
536 }
537 acc += a_ij * x[off + j_a * p + i_b];
538 }
539 y[gi] += acc;
540 }
541 }
542 }
543
544 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
545 self.matvec(beta, out);
546 }
547
548 fn diagonal(&self, diag: &mut [f64]) {
549 let p_a = self.factor_a.nrows();
550 let p = self.p;
551 let off = self.global_offset;
552 for i_a in 0..p_a {
553 let a_ii = self.factor_a[[i_a, i_a]];
554 for i_b in 0..p {
555 diag[off + i_a * p + i_b] += a_ii;
556 }
557 }
558 }
559
560 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
561 let range = &offsets[id.0];
562 let b = range.end - range.start;
563 let p_a = self.factor_a.nrows();
564 let p = self.p;
565 let off = self.global_offset;
566 let block_end = off + p_a * p;
567 if block_end <= range.start || off >= range.end {
568 return;
569 }
570 for bi in 0..b {
571 let gi = range.start + bi;
572 if gi < off || gi >= block_end {
573 continue;
574 }
575 let li = gi - off;
576 let i_a = li / p;
577 let i_b = li % p;
578 for bj in 0..b {
579 let gj = range.start + bj;
580 if gj < off || gj >= block_end {
581 continue;
582 }
583 let lj = gj - off;
584 let j_a = lj / p;
585 let j_b = lj % p;
586 if i_b == j_b {
587 out[[bi, bj]] += self.factor_a[[i_a, j_a]];
588 }
589 }
590 }
591 }
592
593 fn to_dense(&self) -> Array2<f64> {
594 let p_a = self.factor_a.nrows();
595 let p = self.p;
596 let off = self.global_offset;
597 let mut out = Array2::<f64>::zeros((self.k, self.k));
598 for i_a in 0..p_a {
599 for j_a in 0..p_a {
600 let a_ij = self.factor_a[[i_a, j_a]];
601 if a_ij == 0.0 {
602 continue;
603 }
604 for i_b in 0..p {
605 let gi = off + i_a * p + i_b;
606 let gj = off + j_a * p + i_b;
607 out[[gi, gj]] += a_ij;
608 }
609 }
610 }
611 out
612 }
613
614 fn fingerprint(&self, hasher: &mut Fingerprinter) {
615 hasher.write_str("identity-right-kronecker-penalty-op-v1");
616 hasher.write_usize(self.global_offset);
617 hasher.write_usize(self.k);
618 hasher.write_usize(self.p);
619 hasher.write_f64_array2(&self.factor_a);
620 }
621}
622
623#[derive(Debug, Clone)]
630pub struct SparseGBlock {
631 pub row_off: usize,
633 pub col_off: usize,
635 pub data: Array2<f64>,
637}
638
639pub struct SparseBlockKroneckerPenaltyOp {
657 pub p: usize,
659 pub dim_a: usize,
661 pub k: usize,
663 pub blocks: Vec<SparseGBlock>,
665}
666
667#[derive(Debug, Clone)]
668pub struct DeviceSaeSmoothBlock {
669 pub global_offset: usize,
670 pub factor_a: Array2<f64>,
671}
672
673#[derive(Debug, Clone)]
686pub struct DeviceSaeFrameData {
687 pub ranks: Vec<usize>,
690 pub basis_sizes: Vec<usize>,
692 pub border_offsets: Vec<usize>,
695 pub frame_blocks: Vec<FactoredFrameGBlock>,
697 pub smooth_ranks: Vec<usize>,
703 pub row_htbeta: Vec<Vec<f64>>,
706}
707
708#[derive(Debug, Clone)]
709pub struct DeviceSaePcgData {
710 pub p: usize,
711 pub beta_dim: usize,
712 pub a_phi: Arc<[Vec<(usize, f64)>]>,
720 pub local_jac: Arc<[Vec<f64>]>,
721 pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
722 pub sparse_g_blocks: Vec<SparseGBlock>,
723 pub frame: Option<DeviceSaeFrameData>,
728}
729
730impl DeviceSaePcgData {
731 pub(crate) fn a_phi_shared(&self) -> Arc<[Vec<(usize, f64)>]> {
737 Arc::clone(&self.a_phi)
740 }
741
742 pub(crate) fn local_jac_shared(&self) -> Arc<[Vec<f64>]> {
748 Arc::clone(&self.local_jac)
749 }
750}
751
752impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
753 fn dim(&self) -> usize {
754 self.k
755 }
756
757 fn matvec(&self, x: &[f64], y: &mut [f64]) {
758 let p = self.p;
759 for blk in &self.blocks {
760 let (m_i, m_j) = blk.data.dim();
761 for li in 0..m_i {
762 let gi_base = (blk.row_off + li) * p;
763 for lj in 0..m_j {
764 let a_ij = blk.data[[li, lj]];
765 if a_ij == 0.0 {
766 continue;
767 }
768 let gj_base = (blk.col_off + lj) * p;
769 for oc in 0..p {
770 y[gi_base + oc] += a_ij * x[gj_base + oc];
771 }
772 }
773 }
774 }
775 }
776
777 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
778 self.matvec(beta, out);
779 }
780
781 fn diagonal(&self, diag: &mut [f64]) {
782 let p = self.p;
783 for blk in &self.blocks {
784 if blk.row_off != blk.col_off {
787 continue;
788 }
789 let (m_i, m_j) = blk.data.dim();
790 let m = m_i.min(m_j);
791 for li in 0..m {
792 let a_ii = blk.data[[li, li]];
793 let gi_base = (blk.row_off + li) * p;
794 for oc in 0..p {
795 diag[gi_base + oc] += a_ii;
796 }
797 }
798 }
799 }
800
801 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
802 let range = &offsets[id.0];
803 let b = range.end - range.start;
804 let p = self.p;
805 for blk in &self.blocks {
806 let (m_i, m_j) = blk.data.dim();
807 let row_start = blk.row_off * p;
808 let row_end = (blk.row_off + m_i) * p;
809 let col_start = blk.col_off * p;
810 let col_end = (blk.col_off + m_j) * p;
811 if row_end <= range.start
812 || row_start >= range.end
813 || col_end <= range.start
814 || col_start >= range.end
815 {
816 continue;
817 }
818 for bi in 0..b {
819 let gi = range.start + bi;
820 if gi < row_start || gi >= row_end {
821 continue;
822 }
823 let li = (gi - row_start) / p;
824 let oc_i = (gi - row_start) % p;
825 for bj in 0..b {
826 let gj = range.start + bj;
827 if gj < col_start || gj >= col_end {
828 continue;
829 }
830 let oc_j = (gj - col_start) % p;
831 if oc_i != oc_j {
832 continue;
833 }
834 let lj = (gj - col_start) / p;
835 out[[bi, bj]] += blk.data[[li, lj]];
836 }
837 }
838 }
839 }
840
841 fn to_dense(&self) -> Array2<f64> {
842 let p = self.p;
843 let mut out = Array2::<f64>::zeros((self.k, self.k));
844 for blk in &self.blocks {
845 let (m_i, m_j) = blk.data.dim();
846 for li in 0..m_i {
847 let gi_base = (blk.row_off + li) * p;
848 for lj in 0..m_j {
849 let a_ij = blk.data[[li, lj]];
850 if a_ij == 0.0 {
851 continue;
852 }
853 let gj_base = (blk.col_off + lj) * p;
854 for oc in 0..p {
855 out[[gi_base + oc, gj_base + oc]] += a_ij;
856 }
857 }
858 }
859 }
860 out
861 }
862
863 fn row_abs_sums(&self) -> Array1<f64> {
864 let p = self.p;
870 let mut out = Array1::<f64>::zeros(self.k);
871 for blk in &self.blocks {
872 let (m_i, m_j) = blk.data.dim();
873 for li in 0..m_i {
874 let gi_base = (blk.row_off + li) * p;
875 let mut row_abs = 0.0_f64;
876 for lj in 0..m_j {
877 row_abs += blk.data[[li, lj]].abs();
878 }
879 for oc in 0..p {
880 out[gi_base + oc] += row_abs;
881 }
882 }
883 }
884 out
885 }
886
887 fn fingerprint(&self, hasher: &mut Fingerprinter) {
888 hasher.write_str("sparse-block-kronecker-penalty-op-v1");
889 hasher.write_usize(self.p);
890 hasher.write_usize(self.dim_a);
891 hasher.write_usize(self.k);
892 hasher.write_usize(self.blocks.len());
893 for blk in &self.blocks {
894 hasher.write_usize(blk.row_off);
895 hasher.write_usize(blk.col_off);
896 hasher.write_f64_array2(&blk.data);
897 }
898 }
899}
900
901#[derive(Debug, Clone)]
907pub struct FactoredFrameGBlock {
908 pub atom_i: usize,
910 pub atom_j: usize,
912 pub g: Array2<f64>,
914 pub w: Array2<f64>,
919}
920
921pub struct FactoredFrameKroneckerOp {
940 pub ranks: Vec<usize>,
942 pub basis_sizes: Vec<usize>,
944 pub offsets: Vec<usize>,
947 pub dim: usize,
949 pub blocks: Vec<FactoredFrameGBlock>,
951}
952
953pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
960 let (p_i, r_i) = u_i.dim();
961 let (p_j, r_j) = u_j.dim();
962 assert_eq!(
963 p_i, p_j,
964 "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
965 );
966 let mut w = Array2::<f64>::zeros((r_i, r_j));
967 for a in 0..r_i {
968 for b in 0..r_j {
969 let mut acc = 0.0;
970 for c in 0..p_i {
971 acc += u_i[[c, a]] * u_j[[c, b]];
972 }
973 w[[a, b]] = acc;
974 }
975 }
976 w
977}
978
979impl FactoredFrameKroneckerOp {
980 pub fn new(
984 ranks: Vec<usize>,
985 basis_sizes: Vec<usize>,
986 blocks: Vec<FactoredFrameGBlock>,
987 ) -> Result<Self, String> {
988 if ranks.len() != basis_sizes.len() {
989 return Err(format!(
990 "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
991 ranks.len(),
992 basis_sizes.len()
993 ));
994 }
995 let n_atoms = ranks.len();
996 let mut offsets = Vec::with_capacity(n_atoms + 1);
997 let mut acc = 0usize;
998 for k in 0..n_atoms {
999 offsets.push(acc);
1000 acc += basis_sizes[k] * ranks[k];
1001 }
1002 offsets.push(acc);
1003 let dim = acc;
1004 for blk in &blocks {
1005 if blk.atom_i >= n_atoms || blk.atom_j >= n_atoms {
1006 return Err(format!(
1007 "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
1008 blk.atom_i, blk.atom_j
1009 ));
1010 }
1011 if blk.g.dim() != (basis_sizes[blk.atom_i], basis_sizes[blk.atom_j]) {
1012 return Err(format!(
1013 "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
1014 blk.atom_i,
1015 blk.atom_j,
1016 blk.g.dim(),
1017 basis_sizes[blk.atom_i],
1018 basis_sizes[blk.atom_j]
1019 ));
1020 }
1021 if blk.w.dim() != (ranks[blk.atom_i], ranks[blk.atom_j]) {
1022 return Err(format!(
1023 "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
1024 blk.atom_i,
1025 blk.atom_j,
1026 blk.w.dim(),
1027 ranks[blk.atom_i],
1028 ranks[blk.atom_j]
1029 ));
1030 }
1031 }
1032 Ok(Self {
1033 ranks,
1034 basis_sizes,
1035 offsets,
1036 dim,
1037 blocks,
1038 })
1039 }
1040
1041 pub fn from_frames_and_blocks(
1056 frames: &[Option<Array2<f64>>],
1057 basis_sizes: &[usize],
1058 p: usize,
1059 g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
1060 ) -> Result<Self, String> {
1061 if frames.len() != basis_sizes.len() {
1062 return Err(format!(
1063 "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
1064 frames.len(),
1065 basis_sizes.len()
1066 ));
1067 }
1068 let n_atoms = frames.len();
1069 let mut ranks = Vec::with_capacity(n_atoms);
1071 for (k, frame) in frames.iter().enumerate() {
1072 match frame {
1073 Some(u) => {
1074 let (pr, r) = u.dim();
1075 if pr != p {
1076 return Err(format!(
1077 "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
1078 ));
1079 }
1080 if r > p {
1081 return Err(format!(
1082 "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
1083 ));
1084 }
1085 ranks.push(r);
1086 }
1087 None => ranks.push(p),
1088 }
1089 }
1090 let identity = Array2::<f64>::eye(p);
1093 let frame_or_ident = |k: usize| -> ArrayView2<f64> {
1094 match &frames[k] {
1095 Some(u) => u.view(),
1096 None => identity.view(),
1097 }
1098 };
1099 let mut blocks = Vec::with_capacity(g_blocks.len());
1100 for (&(atom_i, atom_j), g) in g_blocks {
1101 if atom_i >= n_atoms || atom_j >= n_atoms {
1102 return Err(format!(
1103 "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
1104 ));
1105 }
1106 let w = frame_output_gram(frame_or_ident(atom_i), frame_or_ident(atom_j));
1107 blocks.push(FactoredFrameGBlock {
1108 atom_i,
1109 atom_j,
1110 g: g.clone(),
1111 w,
1112 });
1113 }
1114 Self::new(ranks, basis_sizes.to_vec(), blocks)
1115 }
1116}
1117
1118impl BetaPenaltyOp for FactoredFrameKroneckerOp {
1119 fn dim(&self) -> usize {
1120 self.dim
1121 }
1122
1123 fn matvec(&self, x: &[f64], y: &mut [f64]) {
1124 for blk in &self.blocks {
1125 let r_i = self.ranks[blk.atom_i];
1126 let r_j = self.ranks[blk.atom_j];
1127 let off_i = self.offsets[blk.atom_i];
1128 let off_j = self.offsets[blk.atom_j];
1129 let (m_i, m_j) = blk.g.dim();
1130 for li in 0..m_i {
1131 let yi_base = off_i + li * r_i;
1132 for lj in 0..m_j {
1133 let g = blk.g[[li, lj]];
1134 if g == 0.0 {
1135 continue;
1136 }
1137 let xj_base = off_j + lj * r_j;
1138 for a in 0..r_i {
1140 let mut acc = 0.0;
1141 for b in 0..r_j {
1142 acc += blk.w[[a, b]] * x[xj_base + b];
1143 }
1144 y[yi_base + a] += g * acc;
1145 }
1146 }
1147 }
1148 }
1149 }
1150
1151 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1152 self.matvec(beta, out);
1153 }
1154
1155 fn diagonal(&self, diag: &mut [f64]) {
1156 for blk in &self.blocks {
1157 if blk.atom_i != blk.atom_j {
1160 continue;
1161 }
1162 let r = self.ranks[blk.atom_i];
1163 let off = self.offsets[blk.atom_i];
1164 let (m_i, m_j) = blk.g.dim();
1165 let m = m_i.min(m_j);
1166 for li in 0..m {
1167 let gii = blk.g[[li, li]];
1168 let base = off + li * r;
1169 for a in 0..r {
1170 diag[base + a] += gii * blk.w[[a, a]];
1171 }
1172 }
1173 }
1174 }
1175
1176 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1177 let range = &offsets[id.0];
1180 let b_dim = range.end - range.start;
1181 for blk in &self.blocks {
1182 let r_i = self.ranks[blk.atom_i];
1183 let r_j = self.ranks[blk.atom_j];
1184 let off_i = self.offsets[blk.atom_i];
1185 let off_j = self.offsets[blk.atom_j];
1186 let (m_i, m_j) = blk.g.dim();
1187 for li in 0..m_i {
1188 for a in 0..r_i {
1189 let gi = off_i + li * r_i + a;
1190 if gi < range.start || gi >= range.end {
1191 continue;
1192 }
1193 let bi = gi - range.start;
1194 for lj in 0..m_j {
1195 let g = blk.g[[li, lj]];
1196 if g == 0.0 {
1197 continue;
1198 }
1199 for b in 0..r_j {
1200 let gj = off_j + lj * r_j + b;
1201 if gj < range.start || gj >= range.end {
1202 continue;
1203 }
1204 let bj = gj - range.start;
1205 if bi < b_dim && bj < b_dim {
1206 out[[bi, bj]] += g * blk.w[[a, b]];
1207 }
1208 }
1209 }
1210 }
1211 }
1212 }
1213 }
1214
1215 fn to_dense(&self) -> Array2<f64> {
1216 let mut out = Array2::<f64>::zeros((self.dim, self.dim));
1217 for blk in &self.blocks {
1218 let r_i = self.ranks[blk.atom_i];
1219 let r_j = self.ranks[blk.atom_j];
1220 let off_i = self.offsets[blk.atom_i];
1221 let off_j = self.offsets[blk.atom_j];
1222 let (m_i, m_j) = blk.g.dim();
1223 for li in 0..m_i {
1224 for lj in 0..m_j {
1225 let g = blk.g[[li, lj]];
1226 if g == 0.0 {
1227 continue;
1228 }
1229 for a in 0..r_i {
1230 let gi = off_i + li * r_i + a;
1231 for b in 0..r_j {
1232 let gj = off_j + lj * r_j + b;
1233 out[[gi, gj]] += g * blk.w[[a, b]];
1234 }
1235 }
1236 }
1237 }
1238 }
1239 out
1240 }
1241
1242 fn fingerprint(&self, hasher: &mut Fingerprinter) {
1243 hasher.write_str("factored-frame-kronecker-op-v1");
1244 hasher.write_usize(self.dim);
1245 for &r in &self.ranks {
1246 hasher.write_usize(r);
1247 }
1248 for &m in &self.basis_sizes {
1249 hasher.write_usize(m);
1250 }
1251 hasher.write_usize(self.blocks.len());
1252 for blk in &self.blocks {
1253 hasher.write_usize(blk.atom_i);
1254 hasher.write_usize(blk.atom_j);
1255 hasher.write_f64_array2(&blk.g);
1256 hasher.write_f64_array2(&blk.w);
1257 }
1258 }
1259}
1260
1261pub struct CompositePenaltyOp {
1263 pub k: usize,
1265 pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
1267}
1268
1269impl BetaPenaltyOp for CompositePenaltyOp {
1270 fn dim(&self) -> usize {
1271 self.k
1272 }
1273
1274 fn matvec(&self, x: &[f64], y: &mut [f64]) {
1275 for op in &self.ops {
1276 op.matvec(x, y);
1277 }
1278 }
1279
1280 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1281 for op in &self.ops {
1282 op.gradient(beta, out);
1283 }
1284 }
1285
1286 fn diagonal(&self, diag: &mut [f64]) {
1287 for op in &self.ops {
1288 op.diagonal(diag);
1289 }
1290 }
1291
1292 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1293 for op in &self.ops {
1294 op.block(id, offsets, out);
1295 }
1296 }
1297
1298 fn to_dense(&self) -> Array2<f64> {
1299 let mut out = Array2::<f64>::zeros((self.k, self.k));
1300 for op in &self.ops {
1301 let dense = op.to_dense();
1302 out += &dense;
1303 }
1304 out
1305 }
1306
1307 fn fingerprint(&self, hasher: &mut Fingerprinter) {
1308 hasher.write_str("composite-penalty-op-v1");
1309 hasher.write_usize(self.k);
1310 hasher.write_usize(self.ops.len());
1311 for op in &self.ops {
1312 op.fingerprint(hasher);
1313 }
1314 }
1315}
1316
1317pub struct MatvecDiagPenaltyOp {
1323 pub(crate) k: usize,
1324 pub(crate) matvec: SharedBetaMatvec,
1325 pub(crate) diagonal_vec: Array1<f64>,
1326}
1327
1328impl MatvecDiagPenaltyOp {
1329 pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
1330 assert_eq!(diagonal_vec.len(), k);
1331 Self {
1332 k,
1333 matvec,
1334 diagonal_vec,
1335 }
1336 }
1337}
1338
1339impl BetaPenaltyOp for MatvecDiagPenaltyOp {
1340 fn dim(&self) -> usize {
1341 self.k
1342 }
1343
1344 fn matvec(&self, x: &[f64], y: &mut [f64]) {
1345 let x_arr = Array1::from_iter(x.iter().copied());
1346 let mut out = Array1::<f64>::zeros(self.k);
1347 (self.matvec)(x_arr.view(), &mut out);
1348 for a in 0..self.k {
1349 y[a] += out[a];
1350 }
1351 }
1352
1353 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1354 let beta_arr = Array1::from_iter(beta.iter().copied());
1355 let mut hb = Array1::<f64>::zeros(self.k);
1356 (self.matvec)(beta_arr.view(), &mut hb);
1357 for a in 0..self.k {
1358 out[a] += hb[a];
1359 }
1360 }
1361
1362 fn diagonal(&self, diag: &mut [f64]) {
1363 for j in 0..self.k.min(diag.len()) {
1364 diag[j] += self.diagonal_vec[j];
1365 }
1366 }
1367
1368 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1369 let range = &offsets[id.0];
1371 let b = range.end - range.start;
1372 let mut probe = Array1::<f64>::zeros(self.k);
1373 for bj in 0..b {
1374 probe.fill(0.0);
1375 probe[range.start + bj] = 1.0;
1376 let mut col = Array1::<f64>::zeros(self.k);
1377 (self.matvec)(probe.view(), &mut col);
1378 for bi in 0..b {
1379 out[[bi, bj]] += col[range.start + bi];
1380 }
1381 }
1382 }
1383
1384 fn to_dense(&self) -> Array2<f64> {
1385 let k = self.k;
1386 let mut out = Array2::<f64>::zeros((k, k));
1387 let mut probe = Array1::<f64>::zeros(k);
1388 for j in 0..k {
1389 probe.fill(0.0);
1390 probe[j] = 1.0;
1391 let mut col = Array1::<f64>::zeros(k);
1392 (self.matvec)(probe.view(), &mut col);
1393 for i in 0..k {
1394 out[[i, j]] = col[i];
1395 }
1396 }
1397 out
1398 }
1399
1400 fn fingerprint(&self, hasher: &mut Fingerprinter) {
1401 hasher.write_str("matvec-diag-penalty-op-v1");
1405 hasher.write_usize(self.k);
1406 for &value in self.diagonal_vec.iter() {
1407 hasher.write_f64(value);
1408 }
1409 }
1410}