1use super::*;
6
/// Undirected edge between two coefficient blocks of a [`BetaCouplingGraph`].
///
/// `BetaCouplingGraph::build` only ever stores edges with `a < b`.
#[derive(Debug, Clone)]
pub(crate) struct BetaEdge {
    /// Lower block index of the edge.
    pub(crate) a: usize,
    /// Higher block index of the edge.
    pub(crate) b: usize,
}
12
/// Graph whose nodes are coefficient blocks and whose edges mark blocks that
/// are coupled through a shared nonzero pattern (see `BetaCouplingGraph::build`).
#[derive(Debug, Clone)]
pub(crate) struct BetaCouplingGraph {
    /// Number of nodes (blocks).
    pub(crate) num_blocks: usize,
    /// Sorted, de-duplicated edge list (`a < b` as produced by `build`).
    pub(crate) edges: Vec<BetaEdge>,
    /// CSR row pointers, length `num_blocks + 1`: the neighbours of node `i`
    /// are `adj_targets[adj_start[i]..adj_start[i + 1]]`.
    pub(crate) adj_start: Vec<usize>,
    /// CSR column indices; every edge appears once in each direction.
    pub(crate) adj_targets: Vec<usize>,
}
20
21impl BetaCouplingGraph {
22 pub(crate) fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
23 let num_blocks = block_offsets.len();
24 if num_blocks == 0 {
25 return Self {
26 num_blocks: 0,
27 edges: Vec::new(),
28 adj_start: vec![0],
29 adj_targets: Vec::new(),
30 };
31 }
32
33 let mut edge_set = Vec::<(usize, usize)>::new();
34 for row in htbeta_rows {
35 let mut active = Vec::<usize>::new();
36 for (block, range) in block_offsets.iter().enumerate() {
37 if range
38 .clone()
39 .any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
40 {
41 active.push(block);
42 }
43 }
44 for i in 0..active.len() {
45 for j in (i + 1)..active.len() {
46 edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
47 }
48 }
49 }
50 edge_set.sort_unstable();
51 edge_set.dedup();
52
53 let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
54 let mut degree = vec![0usize; num_blocks];
55 for &BetaEdge { a, b } in &edges {
56 degree[a] += 1;
57 degree[b] += 1;
58 }
59 let mut adj_start = vec![0usize; num_blocks + 1];
60 for block in 0..num_blocks {
61 adj_start[block + 1] = adj_start[block] + degree[block];
62 }
63 let mut adj_targets = vec![0usize; adj_start[num_blocks]];
64 let mut cursor = adj_start[..num_blocks].to_vec();
65 for &BetaEdge { a, b } in &edges {
66 adj_targets[cursor[a]] = b;
67 cursor[a] += 1;
68 adj_targets[cursor[b]] = a;
69 cursor[b] += 1;
70 }
71 Self {
72 num_blocks,
73 edges,
74 adj_start,
75 adj_targets,
76 }
77 }
78
79 pub(crate) fn neighbours(&self, node: usize) -> &[usize] {
80 &self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
81 }
82
83 pub(crate) fn component_partition(&self) -> Vec<Vec<usize>> {
84 let mut parent: Vec<usize> = (0..self.num_blocks).collect();
85 let mut rank = vec![0u8; self.num_blocks];
86
87 fn find(parent: &mut [usize], mut x: usize) -> usize {
88 while parent[x] != x {
89 parent[x] = parent[parent[x]];
90 x = parent[x];
91 }
92 x
93 }
94
95 for &BetaEdge { a, b } in &self.edges {
96 let lhs = find(&mut parent, a);
97 let rhs = find(&mut parent, b);
98 if lhs != rhs {
99 if rank[lhs] < rank[rhs] {
100 parent[lhs] = rhs;
101 } else if rank[lhs] > rank[rhs] {
102 parent[rhs] = lhs;
103 } else {
104 parent[rhs] = lhs;
105 rank[lhs] += 1;
106 }
107 }
108 }
109
110 let mut label_map = vec![usize::MAX; self.num_blocks];
111 let mut parts = Vec::<Vec<usize>>::new();
112 for block in 0..self.num_blocks {
113 let root = find(&mut parent, block);
114 let label = if label_map[root] == usize::MAX {
115 label_map[root] = parts.len();
116 parts.push(Vec::new());
117 label_map[root]
118 } else {
119 label_map[root]
120 };
121 parts[label].push(block);
122 }
123 parts
124 }
125
126 pub(crate) fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
127 let mut expanded = seed.to_vec();
128 for &block in seed {
129 expanded.extend_from_slice(self.neighbours(block));
130 }
131 expanded.sort_unstable();
132 expanded.dedup();
133 expanded
134 }
135}
/// Index of a coefficient block; operators resolve it as `offsets[id.0]`
/// in [`BetaPenaltyOp::block`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BetaBlockId(pub usize);
147
148pub trait BetaPenaltyOp: Send + Sync {
155 fn dim(&self) -> usize;
157 fn matvec(&self, x: &[f64], y: &mut [f64]);
159 fn gradient(&self, beta: &[f64], out: &mut [f64]);
161 fn diagonal(&self, diag: &mut [f64]);
163 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
167 fn to_dense(&self) -> Array2<f64>;
170 fn row_abs_sums(&self) -> Array1<f64> {
178 let dense = self.to_dense();
179 let k = dense.nrows();
180 let mut out = Array1::<f64>::zeros(k);
181 for r in 0..k {
182 let mut s = 0.0_f64;
183 for c in 0..dense.ncols() {
184 s += dense[[r, c]].abs();
185 }
186 out[r] = s;
187 }
188 out
189 }
190 fn fingerprint(&self, hasher: &mut Fingerprinter);
198
199 fn output_range(&self) -> Option<Range<usize>> {
208 None
209 }
210
211 fn matvec_local(&self, x: &[f64], y_local: &mut [f64]) {
220 panic!(
227 "matvec_local requires output_range() == Some; a None-range \
228 BetaPenaltyOp (input len {}, local output len {}) must be applied \
229 through matvec",
230 x.len(),
231 y_local.len()
232 );
233 }
234}
235
/// Penalty stored as an explicit matrix; its methods treat it as square
/// (`nrows() x nrows()`).
pub struct DensePenaltyOp(pub Array2<f64>);
238
239impl BetaPenaltyOp for DensePenaltyOp {
240 fn dim(&self) -> usize {
241 self.0.nrows()
242 }
243
244 fn matvec(&self, x: &[f64], y: &mut [f64]) {
245 let k = self.0.nrows();
246 for a in 0..k {
247 let mut acc = 0.0_f64;
248 for b in 0..k {
249 acc += self.0[[a, b]] * x[b];
250 }
251 y[a] += acc;
252 }
253 }
254
255 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
256 let k = self.0.nrows();
257 for a in 0..k {
258 let mut acc = 0.0_f64;
259 for b in 0..k {
260 acc += self.0[[a, b]] * beta[b];
261 }
262 out[a] += acc;
263 }
264 }
265
266 fn diagonal(&self, diag: &mut [f64]) {
267 let k = self.0.nrows().min(diag.len());
268 for j in 0..k {
269 diag[j] += self.0[[j, j]];
270 }
271 }
272
273 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
274 let range = &offsets[id.0];
275 let b = range.end - range.start;
276 for bi in 0..b {
277 for bj in 0..b {
278 out[[bi, bj]] += self.0[[range.start + bi, range.start + bj]];
279 }
280 }
281 }
282
283 fn to_dense(&self) -> Array2<f64> {
284 self.0.clone()
285 }
286
287 fn fingerprint(&self, hasher: &mut Fingerprinter) {
288 hasher.write_str("dense-penalty-op-v1");
289 hasher.write_f64_array2(&self.0);
290 }
291}
292
/// Block-diagonal penalty built from square local matrices.
pub struct BlockPenaltyOp {
    /// Total dimension of the operator.
    pub k: usize,
    /// `(off, local)` pairs: `local` occupies rows/columns
    /// `off..off + local.nrows()` of the `k x k` operator.
    pub blocks: Vec<(usize, Array2<f64>)>,
}
305
306impl BetaPenaltyOp for BlockPenaltyOp {
307 fn dim(&self) -> usize {
308 self.k
309 }
310
311 fn matvec(&self, x: &[f64], y: &mut [f64]) {
312 for (off, local) in &self.blocks {
313 let b = local.nrows();
314 for i in 0..b {
315 let gi = off + i;
316 let mut acc = 0.0_f64;
317 for j in 0..b {
318 acc += local[[i, j]] * x[off + j];
319 }
320 y[gi] += acc;
321 }
322 }
323 }
324
325 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
326 for (off, local) in &self.blocks {
327 let b = local.nrows();
328 for i in 0..b {
329 let gi = off + i;
330 let mut acc = 0.0_f64;
331 for j in 0..b {
332 acc += local[[i, j]] * beta[off + j];
333 }
334 out[gi] += acc;
335 }
336 }
337 }
338
339 fn diagonal(&self, diag: &mut [f64]) {
340 for (off, local) in &self.blocks {
341 let b = local.nrows();
342 for j in 0..b {
343 diag[off + j] += local[[j, j]];
344 }
345 }
346 }
347
348 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
349 let range = &offsets[id.0];
350 let b_out = range.end - range.start;
351 for (off, local) in &self.blocks {
352 let b = local.nrows();
353 let block_end = off + b;
354 if block_end <= range.start || *off >= range.end {
355 continue;
356 }
357 for bi in 0..b_out {
358 let gi = range.start + bi;
359 if gi < *off || gi >= block_end {
360 continue;
361 }
362 let li = gi - off;
363 for bj in 0..b_out {
364 let gj = range.start + bj;
365 if gj < *off || gj >= block_end {
366 continue;
367 }
368 let lj = gj - off;
369 out[[bi, bj]] += local[[li, lj]];
370 }
371 }
372 }
373 }
374
375 fn to_dense(&self) -> Array2<f64> {
376 let mut out = Array2::<f64>::zeros((self.k, self.k));
377 for (off, local) in &self.blocks {
378 let b = local.nrows();
379 for i in 0..b {
380 for j in 0..b {
381 out[[off + i, off + j]] += local[[i, j]];
382 }
383 }
384 }
385 out
386 }
387
388 fn fingerprint(&self, hasher: &mut Fingerprinter) {
389 hasher.write_str("block-penalty-op-v1");
390 hasher.write_usize(self.k);
391 hasher.write_usize(self.blocks.len());
392 for (off, local) in &self.blocks {
393 hasher.write_usize(*off);
394 hasher.write_f64_array2(local);
395 }
396 }
397}
398
/// Penalty `A ⊗ B` embedded in a `k x k` operator.
///
/// Global index `global_offset + i_a * p_b + i_b` pairs row `i_a` of `A`
/// with row `i_b` of `B`, where `p_b = factor_b.nrows()`.
pub struct KroneckerPenaltyOp {
    /// Left factor `A` (treated as square, `p_a x p_a`).
    pub factor_a: Array2<f64>,
    /// Right factor `B` (treated as square, `p_b x p_b`).
    pub factor_b: Array2<f64>,
    /// First global row/column covered by `A ⊗ B`.
    pub global_offset: usize,
    /// Total dimension of the enclosing operator.
    pub k: usize,
}
411
412impl BetaPenaltyOp for KroneckerPenaltyOp {
413 fn dim(&self) -> usize {
414 self.k
415 }
416
417 fn matvec(&self, x: &[f64], y: &mut [f64]) {
418 let p_a = self.factor_a.nrows();
419 let p_b = self.factor_b.nrows();
420 let off = self.global_offset;
421 for i_a in 0..p_a {
423 for i_b in 0..p_b {
424 let gi = off + i_a * p_b + i_b;
425 let mut acc = 0.0_f64;
426 for j_a in 0..p_a {
427 let a_ij = self.factor_a[[i_a, j_a]];
428 if a_ij == 0.0 {
429 continue;
430 }
431 for j_b in 0..p_b {
432 acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
433 }
434 }
435 y[gi] += acc;
436 }
437 }
438 }
439
440 fn output_range(&self) -> Option<Range<usize>> {
441 let off = self.global_offset;
442 Some(off..off + self.factor_a.nrows() * self.factor_b.nrows())
443 }
444
445 fn matvec_local(&self, x: &[f64], y_local: &mut [f64]) {
446 let p_a = self.factor_a.nrows();
451 let p_b = self.factor_b.nrows();
452 let off = self.global_offset;
453 for i_a in 0..p_a {
454 for i_b in 0..p_b {
455 let li = i_a * p_b + i_b;
456 let mut acc = 0.0_f64;
457 for j_a in 0..p_a {
458 let a_ij = self.factor_a[[i_a, j_a]];
459 if a_ij == 0.0 {
460 continue;
461 }
462 for j_b in 0..p_b {
463 acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
464 }
465 }
466 y_local[li] += acc;
467 }
468 }
469 }
470
471 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
472 let p_a = self.factor_a.nrows();
473 let p_b = self.factor_b.nrows();
474 let off = self.global_offset;
475 for i_a in 0..p_a {
476 for i_b in 0..p_b {
477 let gi = off + i_a * p_b + i_b;
478 let mut acc = 0.0_f64;
479 for j_a in 0..p_a {
480 let a_ij = self.factor_a[[i_a, j_a]];
481 if a_ij == 0.0 {
482 continue;
483 }
484 for j_b in 0..p_b {
485 acc += a_ij * self.factor_b[[i_b, j_b]] * beta[off + j_a * p_b + j_b];
486 }
487 }
488 out[gi] += acc;
489 }
490 }
491 }
492
493 fn diagonal(&self, diag: &mut [f64]) {
494 let p_a = self.factor_a.nrows();
495 let p_b = self.factor_b.nrows();
496 let off = self.global_offset;
497 for i_a in 0..p_a {
498 for i_b in 0..p_b {
499 diag[off + i_a * p_b + i_b] +=
500 self.factor_a[[i_a, i_a]] * self.factor_b[[i_b, i_b]];
501 }
502 }
503 }
504
505 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
506 let range = &offsets[id.0];
507 let b = range.end - range.start;
508 let p_a = self.factor_a.nrows();
509 let p_b = self.factor_b.nrows();
510 let off = self.global_offset;
511 let block_end = off + p_a * p_b;
512 if block_end <= range.start || off >= range.end {
513 return;
514 }
515 for bi in 0..b {
516 let gi = range.start + bi;
517 if gi < off || gi >= block_end {
518 continue;
519 }
520 let li = gi - off;
521 let i_a = li / p_b;
522 let i_b = li % p_b;
523 for bj in 0..b {
524 let gj = range.start + bj;
525 if gj < off || gj >= block_end {
526 continue;
527 }
528 let lj = gj - off;
529 let j_a = lj / p_b;
530 let j_b = lj % p_b;
531 out[[bi, bj]] += self.factor_a[[i_a, j_a]] * self.factor_b[[i_b, j_b]];
532 }
533 }
534 }
535
536 fn to_dense(&self) -> Array2<f64> {
537 let p_a = self.factor_a.nrows();
538 let p_b = self.factor_b.nrows();
539 let off = self.global_offset;
540 let mut out = Array2::<f64>::zeros((self.k, self.k));
541 for i_a in 0..p_a {
542 for i_b in 0..p_b {
543 let gi = off + i_a * p_b + i_b;
544 for j_a in 0..p_a {
545 let a_ij = self.factor_a[[i_a, j_a]];
546 if a_ij == 0.0 {
547 continue;
548 }
549 for j_b in 0..p_b {
550 let gj = off + j_a * p_b + j_b;
551 out[[gi, gj]] += a_ij * self.factor_b[[i_b, j_b]];
552 }
553 }
554 }
555 }
556 out
557 }
558
559 fn fingerprint(&self, hasher: &mut Fingerprinter) {
560 hasher.write_str("kronecker-penalty-op-v1");
561 hasher.write_usize(self.global_offset);
562 hasher.write_usize(self.k);
563 hasher.write_f64_array2(&self.factor_a);
564 hasher.write_f64_array2(&self.factor_b);
565 }
566}
567
/// Penalty `A ⊗ I_p` embedded in a `k x k` operator.
///
/// Global index `global_offset + i_a * p + i_b` pairs row `i_a` of `A` with
/// identity slot `i_b`.
pub struct IdentityRightKroneckerPenaltyOp {
    /// Left factor `A` (treated as square).
    pub factor_a: Array2<f64>,
    /// Size of the implicit identity factor.
    pub p: usize,
    /// First global row/column covered by `A ⊗ I_p`.
    pub global_offset: usize,
    /// Total dimension of the enclosing operator.
    pub k: usize,
}
585
586impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
587 fn dim(&self) -> usize {
588 self.k
589 }
590
591 fn matvec(&self, x: &[f64], y: &mut [f64]) {
592 let p_a = self.factor_a.nrows();
593 let p = self.p;
594 let off = self.global_offset;
595 for i_a in 0..p_a {
596 for i_b in 0..p {
597 let gi = off + i_a * p + i_b;
598 let mut acc = 0.0_f64;
599 for j_a in 0..p_a {
600 let a_ij = self.factor_a[[i_a, j_a]];
601 if a_ij == 0.0 {
602 continue;
603 }
604 acc += a_ij * x[off + j_a * p + i_b];
605 }
606 y[gi] += acc;
607 }
608 }
609 }
610
611 fn output_range(&self) -> Option<Range<usize>> {
612 let off = self.global_offset;
613 Some(off..off + self.factor_a.nrows() * self.p)
614 }
615
616 fn matvec_local(&self, x: &[f64], y_local: &mut [f64]) {
617 let p_a = self.factor_a.nrows();
623 let p = self.p;
624 let off = self.global_offset;
625 for i_a in 0..p_a {
626 for i_b in 0..p {
627 let li = i_a * p + i_b;
628 let mut acc = 0.0_f64;
629 for j_a in 0..p_a {
630 let a_ij = self.factor_a[[i_a, j_a]];
631 if a_ij == 0.0 {
632 continue;
633 }
634 acc += a_ij * x[off + j_a * p + i_b];
635 }
636 y_local[li] += acc;
637 }
638 }
639 }
640
641 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
642 self.matvec(beta, out);
643 }
644
645 fn diagonal(&self, diag: &mut [f64]) {
646 let p_a = self.factor_a.nrows();
647 let p = self.p;
648 let off = self.global_offset;
649 for i_a in 0..p_a {
650 let a_ii = self.factor_a[[i_a, i_a]];
651 for i_b in 0..p {
652 diag[off + i_a * p + i_b] += a_ii;
653 }
654 }
655 }
656
657 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
658 let range = &offsets[id.0];
659 let b = range.end - range.start;
660 let p_a = self.factor_a.nrows();
661 let p = self.p;
662 let off = self.global_offset;
663 let block_end = off + p_a * p;
664 if block_end <= range.start || off >= range.end {
665 return;
666 }
667 for bi in 0..b {
668 let gi = range.start + bi;
669 if gi < off || gi >= block_end {
670 continue;
671 }
672 let li = gi - off;
673 let i_a = li / p;
674 let i_b = li % p;
675 for bj in 0..b {
676 let gj = range.start + bj;
677 if gj < off || gj >= block_end {
678 continue;
679 }
680 let lj = gj - off;
681 let j_a = lj / p;
682 let j_b = lj % p;
683 if i_b == j_b {
684 out[[bi, bj]] += self.factor_a[[i_a, j_a]];
685 }
686 }
687 }
688 }
689
690 fn to_dense(&self) -> Array2<f64> {
691 let p_a = self.factor_a.nrows();
692 let p = self.p;
693 let off = self.global_offset;
694 let mut out = Array2::<f64>::zeros((self.k, self.k));
695 for i_a in 0..p_a {
696 for j_a in 0..p_a {
697 let a_ij = self.factor_a[[i_a, j_a]];
698 if a_ij == 0.0 {
699 continue;
700 }
701 for i_b in 0..p {
702 let gi = off + i_a * p + i_b;
703 let gj = off + j_a * p + i_b;
704 out[[gi, gj]] += a_ij;
705 }
706 }
707 }
708 out
709 }
710
711 fn fingerprint(&self, hasher: &mut Fingerprinter) {
712 hasher.write_str("identity-right-kronecker-penalty-op-v1");
713 hasher.write_usize(self.global_offset);
714 hasher.write_usize(self.k);
715 hasher.write_usize(self.p);
716 hasher.write_f64_array2(&self.factor_a);
717 }
718}
719
/// Dense tile of a block-sparse matrix `G`.
///
/// In `SparseBlockKroneckerPenaltyOp` the offsets are in un-expanded `G`
/// index units; they are multiplied by `p` to obtain coefficient indices.
#[derive(Debug, Clone)]
pub struct SparseGBlock {
    /// Row of `G` where this tile starts.
    pub row_off: usize,
    /// Column of `G` where this tile starts.
    pub col_off: usize,
    /// Tile values; `data[[li, lj]]` is `G[row_off + li, col_off + lj]`.
    pub data: Array2<f64>,
}
735
/// Penalty `G ⊗ I_p` where `G` is stored as a list of dense tiles.
///
/// Coefficient index `g * p + oc` pairs row `g` of `G` with identity slot `oc`.
pub struct SparseBlockKroneckerPenaltyOp {
    /// Size of the implicit identity factor.
    pub p: usize,
    /// Presumably the dimension of `G`; in this file it is only hashed by
    /// `fingerprint` (TODO confirm against the constructor).
    pub dim_a: usize,
    /// Total dimension of the operator.
    pub k: usize,
    /// Nonzero tiles of `G`.
    pub blocks: Vec<SparseGBlock>,
}
763
/// Smoothing block description shipped with `DeviceSaePcgData`.
///
/// NOTE(review): the field names mirror `IdentityRightKroneckerPenaltyOp`
/// (`global_offset`, `factor_a`); confirm that it is consumed the same way.
#[derive(Debug, Clone)]
pub struct DeviceSaeSmoothBlock {
    /// First global coefficient index of the block.
    pub global_offset: usize,
    /// Left Kronecker factor of the block.
    pub factor_a: Array2<f64>,
}
769
/// Frame-related data shipped with `DeviceSaePcgData`.
///
/// NOTE(review): `ranks`, `basis_sizes` and `frame_blocks` look like the
/// per-atom inputs of `FactoredFrameKroneckerOp::new`; confirm against the
/// producer before relying on that.
#[derive(Debug, Clone)]
pub struct DeviceSaeFrameData {
    /// Per-atom ranks.
    pub ranks: Vec<usize>,
    /// Per-atom basis sizes.
    pub basis_sizes: Vec<usize>,
    /// Offsets whose exact meaning is defined by the producer (TODO document).
    pub border_offsets: Vec<usize>,
    /// Factored `G ⊗ W` blocks.
    pub frame_blocks: Vec<FactoredFrameGBlock>,
    /// Ranks of the smooth part (TODO confirm semantics).
    pub smooth_ranks: Vec<usize>,
    /// One dense vector per row (layout defined by the producer; TODO confirm).
    pub row_htbeta: Vec<Vec<f64>>,
}
804
/// Bundle of operator data for a PCG solve.
///
/// The two large per-row tables are held in `Arc<[...]>` so that they can be
/// shared cheaply (see `a_phi_shared` / `local_jac_shared`).
#[derive(Debug, Clone)]
pub struct DeviceSaePcgData {
    /// Ambient output dimension (TODO confirm).
    pub p: usize,
    /// Length of the coefficient vector.
    pub beta_dim: usize,
    /// Per-row lists of `(index, value)` pairs, presumably a sparse matrix in
    /// row-list form (TODO confirm).
    pub a_phi: Arc<[Vec<(usize, f64)>]>,
    /// Per-row dense vectors (TODO document layout).
    pub local_jac: Arc<[Vec<f64>]>,
    /// Smoothing blocks.
    pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
    /// Tiles of a block-sparse `G`.
    pub sparse_g_blocks: Vec<SparseGBlock>,
    /// Optional frame data.
    pub frame: Option<DeviceSaeFrameData>,
}
826
impl DeviceSaePcgData {
    /// Returns a new handle to `a_phi`; this bumps the `Arc` refcount and
    /// does not copy the rows.
    pub(crate) fn a_phi_shared(&self) -> Arc<[Vec<(usize, f64)>]> {
        Arc::clone(&self.a_phi)
    }

    /// Returns a new handle to `local_jac`; this bumps the `Arc` refcount
    /// and does not copy the rows.
    pub(crate) fn local_jac_shared(&self) -> Arc<[Vec<f64>]> {
        Arc::clone(&self.local_jac)
    }
}
848
849impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
850 fn dim(&self) -> usize {
851 self.k
852 }
853
854 fn matvec(&self, x: &[f64], y: &mut [f64]) {
855 let p = self.p;
856 for blk in &self.blocks {
857 let (m_i, m_j) = blk.data.dim();
858 for li in 0..m_i {
859 let gi_base = (blk.row_off + li) * p;
860 for lj in 0..m_j {
861 let a_ij = blk.data[[li, lj]];
862 if a_ij == 0.0 {
863 continue;
864 }
865 let gj_base = (blk.col_off + lj) * p;
866 for oc in 0..p {
867 y[gi_base + oc] += a_ij * x[gj_base + oc];
868 }
869 }
870 }
871 }
872 }
873
874 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
875 self.matvec(beta, out);
876 }
877
878 fn diagonal(&self, diag: &mut [f64]) {
879 let p = self.p;
880 for blk in &self.blocks {
881 if blk.row_off != blk.col_off {
884 continue;
885 }
886 let (m_i, m_j) = blk.data.dim();
887 let m = m_i.min(m_j);
888 for li in 0..m {
889 let a_ii = blk.data[[li, li]];
890 let gi_base = (blk.row_off + li) * p;
891 for oc in 0..p {
892 diag[gi_base + oc] += a_ii;
893 }
894 }
895 }
896 }
897
898 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
899 let range = &offsets[id.0];
900 let b = range.end - range.start;
901 let p = self.p;
902 for blk in &self.blocks {
903 let (m_i, m_j) = blk.data.dim();
904 let row_start = blk.row_off * p;
905 let row_end = (blk.row_off + m_i) * p;
906 let col_start = blk.col_off * p;
907 let col_end = (blk.col_off + m_j) * p;
908 if row_end <= range.start
909 || row_start >= range.end
910 || col_end <= range.start
911 || col_start >= range.end
912 {
913 continue;
914 }
915 for bi in 0..b {
916 let gi = range.start + bi;
917 if gi < row_start || gi >= row_end {
918 continue;
919 }
920 let li = (gi - row_start) / p;
921 let oc_i = (gi - row_start) % p;
922 for bj in 0..b {
923 let gj = range.start + bj;
924 if gj < col_start || gj >= col_end {
925 continue;
926 }
927 let oc_j = (gj - col_start) % p;
928 if oc_i != oc_j {
929 continue;
930 }
931 let lj = (gj - col_start) / p;
932 out[[bi, bj]] += blk.data[[li, lj]];
933 }
934 }
935 }
936 }
937
938 fn to_dense(&self) -> Array2<f64> {
939 let p = self.p;
940 let mut out = Array2::<f64>::zeros((self.k, self.k));
941 for blk in &self.blocks {
942 let (m_i, m_j) = blk.data.dim();
943 for li in 0..m_i {
944 let gi_base = (blk.row_off + li) * p;
945 for lj in 0..m_j {
946 let a_ij = blk.data[[li, lj]];
947 if a_ij == 0.0 {
948 continue;
949 }
950 let gj_base = (blk.col_off + lj) * p;
951 for oc in 0..p {
952 out[[gi_base + oc, gj_base + oc]] += a_ij;
953 }
954 }
955 }
956 }
957 out
958 }
959
960 fn row_abs_sums(&self) -> Array1<f64> {
961 let p = self.p;
967 let mut out = Array1::<f64>::zeros(self.k);
968 for blk in &self.blocks {
969 let (m_i, m_j) = blk.data.dim();
970 for li in 0..m_i {
971 let gi_base = (blk.row_off + li) * p;
972 let mut row_abs = 0.0_f64;
973 for lj in 0..m_j {
974 row_abs += blk.data[[li, lj]].abs();
975 }
976 for oc in 0..p {
977 out[gi_base + oc] += row_abs;
978 }
979 }
980 }
981 out
982 }
983
984 fn fingerprint(&self, hasher: &mut Fingerprinter) {
985 hasher.write_str("sparse-block-kronecker-penalty-op-v1");
986 hasher.write_usize(self.p);
987 hasher.write_usize(self.dim_a);
988 hasher.write_usize(self.k);
989 hasher.write_usize(self.blocks.len());
990 for blk in &self.blocks {
991 hasher.write_usize(blk.row_off);
992 hasher.write_usize(blk.col_off);
993 hasher.write_f64_array2(&blk.data);
994 }
995 }
996}
997
/// One `(atom_i, atom_j)` block `G_ij ⊗ W_ij` of a `FactoredFrameKroneckerOp`.
#[derive(Debug, Clone)]
pub struct FactoredFrameGBlock {
    /// Row atom.
    pub atom_i: usize,
    /// Column atom.
    pub atom_j: usize,
    /// Basis coupling, shape `(basis_sizes[atom_i], basis_sizes[atom_j])`
    /// (checked by `FactoredFrameKroneckerOp::new`).
    pub g: Array2<f64>,
    /// Frame coupling, shape `(ranks[atom_i], ranks[atom_j])`; built as
    /// `U_i^T U_j` by `FactoredFrameKroneckerOp::from_frames_and_blocks`.
    pub w: Array2<f64>,
}
1017
/// Block operator whose `(i, j)` atom block is `G_ij ⊗ W_ij`.
///
/// Atom `k` occupies coefficients `offsets[k]..offsets[k + 1]`, with basis
/// index major and rank index minor (`offsets[k] + li * ranks[k] + a`).
pub struct FactoredFrameKroneckerOp {
    /// Per-atom frame rank.
    pub ranks: Vec<usize>,
    /// Per-atom basis size.
    pub basis_sizes: Vec<usize>,
    /// Prefix sums of `basis_sizes[k] * ranks[k]`, length `n_atoms + 1`.
    pub offsets: Vec<usize>,
    /// Total dimension (`offsets[n_atoms]`).
    pub dim: usize,
    /// Nonzero atom blocks.
    pub blocks: Vec<FactoredFrameGBlock>,
}
1049
1050pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
1057 let (p_i, r_i) = u_i.dim();
1058 let (p_j, r_j) = u_j.dim();
1059 assert_eq!(
1060 p_i, p_j,
1061 "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
1062 );
1063 let mut w = Array2::<f64>::zeros((r_i, r_j));
1064 for a in 0..r_i {
1065 for b in 0..r_j {
1066 let mut acc = 0.0;
1067 for c in 0..p_i {
1068 acc += u_i[[c, a]] * u_j[[c, b]];
1069 }
1070 w[[a, b]] = acc;
1071 }
1072 }
1073 w
1074}
1075
1076impl FactoredFrameKroneckerOp {
1077 pub fn new(
1081 ranks: Vec<usize>,
1082 basis_sizes: Vec<usize>,
1083 blocks: Vec<FactoredFrameGBlock>,
1084 ) -> Result<Self, String> {
1085 if ranks.len() != basis_sizes.len() {
1086 return Err(format!(
1087 "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
1088 ranks.len(),
1089 basis_sizes.len()
1090 ));
1091 }
1092 let n_atoms = ranks.len();
1093 let mut offsets = Vec::with_capacity(n_atoms + 1);
1094 let mut acc = 0usize;
1095 for k in 0..n_atoms {
1096 offsets.push(acc);
1097 acc += basis_sizes[k] * ranks[k];
1098 }
1099 offsets.push(acc);
1100 let dim = acc;
1101 for blk in &blocks {
1102 if blk.atom_i >= n_atoms || blk.atom_j >= n_atoms {
1103 return Err(format!(
1104 "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
1105 blk.atom_i, blk.atom_j
1106 ));
1107 }
1108 if blk.g.dim() != (basis_sizes[blk.atom_i], basis_sizes[blk.atom_j]) {
1109 return Err(format!(
1110 "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
1111 blk.atom_i,
1112 blk.atom_j,
1113 blk.g.dim(),
1114 basis_sizes[blk.atom_i],
1115 basis_sizes[blk.atom_j]
1116 ));
1117 }
1118 if blk.w.dim() != (ranks[blk.atom_i], ranks[blk.atom_j]) {
1119 return Err(format!(
1120 "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
1121 blk.atom_i,
1122 blk.atom_j,
1123 blk.w.dim(),
1124 ranks[blk.atom_i],
1125 ranks[blk.atom_j]
1126 ));
1127 }
1128 }
1129 Ok(Self {
1130 ranks,
1131 basis_sizes,
1132 offsets,
1133 dim,
1134 blocks,
1135 })
1136 }
1137
1138 pub fn from_frames_and_blocks(
1153 frames: &[Option<Array2<f64>>],
1154 basis_sizes: &[usize],
1155 p: usize,
1156 g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
1157 ) -> Result<Self, String> {
1158 if frames.len() != basis_sizes.len() {
1159 return Err(format!(
1160 "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
1161 frames.len(),
1162 basis_sizes.len()
1163 ));
1164 }
1165 let n_atoms = frames.len();
1166 let mut ranks = Vec::with_capacity(n_atoms);
1168 for (k, frame) in frames.iter().enumerate() {
1169 match frame {
1170 Some(u) => {
1171 let (pr, r) = u.dim();
1172 if pr != p {
1173 return Err(format!(
1174 "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
1175 ));
1176 }
1177 if r > p {
1178 return Err(format!(
1179 "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
1180 ));
1181 }
1182 ranks.push(r);
1183 }
1184 None => ranks.push(p),
1185 }
1186 }
1187 let identity = Array2::<f64>::eye(p);
1190 let frame_or_ident = |k: usize| -> ArrayView2<f64> {
1191 match &frames[k] {
1192 Some(u) => u.view(),
1193 None => identity.view(),
1194 }
1195 };
1196 let mut blocks = Vec::with_capacity(g_blocks.len());
1197 for (&(atom_i, atom_j), g) in g_blocks {
1198 if atom_i >= n_atoms || atom_j >= n_atoms {
1199 return Err(format!(
1200 "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
1201 ));
1202 }
1203 let w = frame_output_gram(frame_or_ident(atom_i), frame_or_ident(atom_j));
1204 blocks.push(FactoredFrameGBlock {
1205 atom_i,
1206 atom_j,
1207 g: g.clone(),
1208 w,
1209 });
1210 }
1211 Self::new(ranks, basis_sizes.to_vec(), blocks)
1212 }
1213}
1214
1215impl BetaPenaltyOp for FactoredFrameKroneckerOp {
1216 fn dim(&self) -> usize {
1217 self.dim
1218 }
1219
1220 fn matvec(&self, x: &[f64], y: &mut [f64]) {
1221 for blk in &self.blocks {
1222 let r_i = self.ranks[blk.atom_i];
1223 let r_j = self.ranks[blk.atom_j];
1224 let off_i = self.offsets[blk.atom_i];
1225 let off_j = self.offsets[blk.atom_j];
1226 let (m_i, m_j) = blk.g.dim();
1227 for li in 0..m_i {
1228 let yi_base = off_i + li * r_i;
1229 for lj in 0..m_j {
1230 let g = blk.g[[li, lj]];
1231 if g == 0.0 {
1232 continue;
1233 }
1234 let xj_base = off_j + lj * r_j;
1235 for a in 0..r_i {
1237 let mut acc = 0.0;
1238 for b in 0..r_j {
1239 acc += blk.w[[a, b]] * x[xj_base + b];
1240 }
1241 y[yi_base + a] += g * acc;
1242 }
1243 }
1244 }
1245 }
1246 }
1247
1248 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1249 self.matvec(beta, out);
1250 }
1251
1252 fn diagonal(&self, diag: &mut [f64]) {
1253 for blk in &self.blocks {
1254 if blk.atom_i != blk.atom_j {
1257 continue;
1258 }
1259 let r = self.ranks[blk.atom_i];
1260 let off = self.offsets[blk.atom_i];
1261 let (m_i, m_j) = blk.g.dim();
1262 let m = m_i.min(m_j);
1263 for li in 0..m {
1264 let gii = blk.g[[li, li]];
1265 let base = off + li * r;
1266 for a in 0..r {
1267 diag[base + a] += gii * blk.w[[a, a]];
1268 }
1269 }
1270 }
1271 }
1272
1273 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1274 let range = &offsets[id.0];
1277 let b_dim = range.end - range.start;
1278 for blk in &self.blocks {
1279 let r_i = self.ranks[blk.atom_i];
1280 let r_j = self.ranks[blk.atom_j];
1281 let off_i = self.offsets[blk.atom_i];
1282 let off_j = self.offsets[blk.atom_j];
1283 let (m_i, m_j) = blk.g.dim();
1284 for li in 0..m_i {
1285 for a in 0..r_i {
1286 let gi = off_i + li * r_i + a;
1287 if gi < range.start || gi >= range.end {
1288 continue;
1289 }
1290 let bi = gi - range.start;
1291 for lj in 0..m_j {
1292 let g = blk.g[[li, lj]];
1293 if g == 0.0 {
1294 continue;
1295 }
1296 for b in 0..r_j {
1297 let gj = off_j + lj * r_j + b;
1298 if gj < range.start || gj >= range.end {
1299 continue;
1300 }
1301 let bj = gj - range.start;
1302 if bi < b_dim && bj < b_dim {
1303 out[[bi, bj]] += g * blk.w[[a, b]];
1304 }
1305 }
1306 }
1307 }
1308 }
1309 }
1310 }
1311
1312 fn to_dense(&self) -> Array2<f64> {
1313 let mut out = Array2::<f64>::zeros((self.dim, self.dim));
1314 for blk in &self.blocks {
1315 let r_i = self.ranks[blk.atom_i];
1316 let r_j = self.ranks[blk.atom_j];
1317 let off_i = self.offsets[blk.atom_i];
1318 let off_j = self.offsets[blk.atom_j];
1319 let (m_i, m_j) = blk.g.dim();
1320 for li in 0..m_i {
1321 for lj in 0..m_j {
1322 let g = blk.g[[li, lj]];
1323 if g == 0.0 {
1324 continue;
1325 }
1326 for a in 0..r_i {
1327 let gi = off_i + li * r_i + a;
1328 for b in 0..r_j {
1329 let gj = off_j + lj * r_j + b;
1330 out[[gi, gj]] += g * blk.w[[a, b]];
1331 }
1332 }
1333 }
1334 }
1335 }
1336 out
1337 }
1338
1339 fn fingerprint(&self, hasher: &mut Fingerprinter) {
1340 hasher.write_str("factored-frame-kronecker-op-v1");
1341 hasher.write_usize(self.dim);
1342 for &r in &self.ranks {
1343 hasher.write_usize(r);
1344 }
1345 for &m in &self.basis_sizes {
1346 hasher.write_usize(m);
1347 }
1348 hasher.write_usize(self.blocks.len());
1349 for blk in &self.blocks {
1350 hasher.write_usize(blk.atom_i);
1351 hasher.write_usize(blk.atom_j);
1352 hasher.write_f64_array2(&blk.g);
1353 hasher.write_f64_array2(&blk.w);
1354 }
1355 }
1356}
1357
/// Sum of several penalty operators of the same dimension.
pub struct CompositePenaltyOp {
    /// Dimension of the composite (and of every member op).
    pub k: usize,
    /// Member operators; their contributions are added together.
    pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
}
1365
impl BetaPenaltyOp for CompositePenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Accumulates `sum_op (op * x)` into `y`.
    ///
    /// The parallel path works on the longest leading run of ops that each
    /// report an `output_range()`, provided those ranges are non-empty,
    /// increasing, non-overlapping and within `y`. When that run has at least
    /// two ops and covers at least `SCHUR_PROLOGUE_PARALLEL_K_MIN` rows, the
    /// run is applied with rayon. Each op writes into its own disjoint
    /// sub-slice via `matvec_local`. The remaining ops are then applied
    /// serially. Otherwise every op is applied serially with `matvec`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let mut prefix_len = 0usize;
        let mut prev_end = 0usize;
        // Only look for a parallel prefix when not already running on a rayon
        // worker thread, to avoid nesting parallel regions.
        if rayon::current_thread_index().is_none() {
            for op in &self.ops {
                match op.output_range() {
                    Some(r) if r.start >= prev_end && r.end > r.start && r.end <= y.len() => {
                        prev_end = r.end;
                        prefix_len += 1;
                    }
                    _ => break,
                }
            }
        }
        if prefix_len >= 2 && prev_end >= SCHUR_PROLOGUE_PARALLEL_K_MIN {
            use rayon::prelude::*;
            let mut subslices: Vec<&mut [f64]> = Vec::with_capacity(prefix_len);
            {
                // Carve `y` into one disjoint mutable slice per prefix op,
                // skipping any gap before each op's range. `consumed` is the
                // global index where `rest` begins; the prefix check above
                // guarantees `r.start >= consumed`.
                let mut consumed = 0usize;
                let mut rest: &mut [f64] = y;
                for op in &self.ops[..prefix_len] {
                    let r = op.output_range().expect("prefix op has an output range");
                    let (_, after_gap) = rest.split_at_mut(r.start - consumed);
                    let (block, tail) = after_gap.split_at_mut(r.end - r.start);
                    subslices.push(block);
                    rest = tail;
                    consumed = r.end;
                }
            }
            self.ops[..prefix_len]
                .par_iter()
                .zip(subslices.par_iter_mut())
                .for_each(|(op, y_local)| op.matvec_local(x, y_local));
            // Ops after the prefix may touch any part of `y`; apply them serially.
            for op in &self.ops[prefix_len..] {
                op.matvec(x, y);
            }
        } else {
            for op in &self.ops {
                op.matvec(x, y);
            }
        }
    }

    /// Accumulates each member's gradient at `beta` into `out`, in order.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        for op in &self.ops {
            op.gradient(beta, out);
        }
    }

    /// Accumulates each member's diagonal into `diag`.
    fn diagonal(&self, diag: &mut [f64]) {
        for op in &self.ops {
            op.diagonal(diag);
        }
    }

    /// Accumulates each member's `(id, id)` sub-block into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        for op in &self.ops {
            op.block(id, offsets, out);
        }
    }

    /// Sums the members' dense forms; assumes each member is `k x k`.
    fn to_dense(&self) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((self.k, self.k));
        for op in &self.ops {
            let dense = op.to_dense();
            out += &dense;
        }
        out
    }

    /// Hashes the tag, `k`, the member count and each member's fingerprint in order.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("composite-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.ops.len());
        for op in &self.ops {
            op.fingerprint(hasher);
        }
    }
}
1475
/// Penalty defined by a matrix-vector closure plus an explicit diagonal.
pub struct MatvecDiagPenaltyOp {
    /// Operator dimension.
    pub(crate) k: usize,
    /// Closure computing the product; it is called with a zero-initialised
    /// output buffer of length `k`.
    pub(crate) matvec: SharedBetaMatvec,
    /// Precomputed diagonal, length `k` (checked in `new`).
    pub(crate) diagonal_vec: Array1<f64>,
}
1486
1487impl MatvecDiagPenaltyOp {
1488 pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
1489 assert_eq!(diagonal_vec.len(), k);
1490 Self {
1491 k,
1492 matvec,
1493 diagonal_vec,
1494 }
1495 }
1496}
1497
1498impl BetaPenaltyOp for MatvecDiagPenaltyOp {
1499 fn dim(&self) -> usize {
1500 self.k
1501 }
1502
1503 fn matvec(&self, x: &[f64], y: &mut [f64]) {
1504 let x_arr = Array1::from_iter(x.iter().copied());
1505 let mut out = Array1::<f64>::zeros(self.k);
1506 (self.matvec)(x_arr.view(), &mut out);
1507 for a in 0..self.k {
1508 y[a] += out[a];
1509 }
1510 }
1511
1512 fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1513 let beta_arr = Array1::from_iter(beta.iter().copied());
1514 let mut hb = Array1::<f64>::zeros(self.k);
1515 (self.matvec)(beta_arr.view(), &mut hb);
1516 for a in 0..self.k {
1517 out[a] += hb[a];
1518 }
1519 }
1520
1521 fn diagonal(&self, diag: &mut [f64]) {
1522 for j in 0..self.k.min(diag.len()) {
1523 diag[j] += self.diagonal_vec[j];
1524 }
1525 }
1526
1527 fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1528 let range = &offsets[id.0];
1530 let b = range.end - range.start;
1531 let mut probe = Array1::<f64>::zeros(self.k);
1532 for bj in 0..b {
1533 probe.fill(0.0);
1534 probe[range.start + bj] = 1.0;
1535 let mut col = Array1::<f64>::zeros(self.k);
1536 (self.matvec)(probe.view(), &mut col);
1537 for bi in 0..b {
1538 out[[bi, bj]] += col[range.start + bi];
1539 }
1540 }
1541 }
1542
1543 fn to_dense(&self) -> Array2<f64> {
1544 let k = self.k;
1545 let mut out = Array2::<f64>::zeros((k, k));
1546 let mut probe = Array1::<f64>::zeros(k);
1547 for j in 0..k {
1548 probe.fill(0.0);
1549 probe[j] = 1.0;
1550 let mut col = Array1::<f64>::zeros(k);
1551 (self.matvec)(probe.view(), &mut col);
1552 for i in 0..k {
1553 out[[i, j]] = col[i];
1554 }
1555 }
1556 out
1557 }
1558
1559 fn fingerprint(&self, hasher: &mut Fingerprinter) {
1560 hasher.write_str("matvec-diag-penalty-op-v1");
1564 hasher.write_usize(self.k);
1565 for &value in self.diagonal_vec.iter() {
1566 hasher.write_f64(value);
1567 }
1568 }
1569}