Skip to main content

gam_solve/arrow_schur/
penalty_ops.rs

1//! Matrix-free penalty-side `H_ββ` operators: the [`BetaPenaltyOp`] trait and
2//! every concrete operator (dense, block, Kronecker, factored-frame, composite,
3//! matvec-diagonal) plus the β-coupling graph used for block-Jacobi clustering.
4
5use super::*;
6
/// Undirected coupling between two β blocks, identified by their block indices.
///
/// `BetaCouplingGraph::build` only ever emits edges with `a < b`, so each
/// unordered pair has exactly one representation in the edge list.
#[derive(Debug, Clone)]
pub(crate) struct BetaEdge {
    /// First endpoint (the smaller block index when produced by `build`).
    pub(crate) a: usize,
    /// Second endpoint (the larger block index when produced by `build`).
    pub(crate) b: usize,
}
12
13#[derive(Debug, Clone)]
14pub(crate) struct BetaCouplingGraph {
15    pub(crate) num_blocks: usize,
16    pub(crate) edges: Vec<BetaEdge>,
17    pub(crate) adj_start: Vec<usize>,
18    pub(crate) adj_targets: Vec<usize>,
19}
20
21impl BetaCouplingGraph {
22    pub(crate) fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
23        let num_blocks = block_offsets.len();
24        if num_blocks == 0 {
25            return Self {
26                num_blocks: 0,
27                edges: Vec::new(),
28                adj_start: vec![0],
29                adj_targets: Vec::new(),
30            };
31        }
32
33        let mut edge_set = Vec::<(usize, usize)>::new();
34        for row in htbeta_rows {
35            let mut active = Vec::<usize>::new();
36            for (block, range) in block_offsets.iter().enumerate() {
37                if range
38                    .clone()
39                    .any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
40                {
41                    active.push(block);
42                }
43            }
44            for i in 0..active.len() {
45                for j in (i + 1)..active.len() {
46                    edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
47                }
48            }
49        }
50        edge_set.sort_unstable();
51        edge_set.dedup();
52
53        let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
54        let mut degree = vec![0usize; num_blocks];
55        for &BetaEdge { a, b } in &edges {
56            degree[a] += 1;
57            degree[b] += 1;
58        }
59        let mut adj_start = vec![0usize; num_blocks + 1];
60        for block in 0..num_blocks {
61            adj_start[block + 1] = adj_start[block] + degree[block];
62        }
63        let mut adj_targets = vec![0usize; adj_start[num_blocks]];
64        let mut cursor = adj_start[..num_blocks].to_vec();
65        for &BetaEdge { a, b } in &edges {
66            adj_targets[cursor[a]] = b;
67            cursor[a] += 1;
68            adj_targets[cursor[b]] = a;
69            cursor[b] += 1;
70        }
71        Self {
72            num_blocks,
73            edges,
74            adj_start,
75            adj_targets,
76        }
77    }
78
79    pub(crate) fn neighbours(&self, node: usize) -> &[usize] {
80        &self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
81    }
82
83    pub(crate) fn component_partition(&self) -> Vec<Vec<usize>> {
84        let mut parent: Vec<usize> = (0..self.num_blocks).collect();
85        let mut rank = vec![0u8; self.num_blocks];
86
87        fn find(parent: &mut [usize], mut x: usize) -> usize {
88            while parent[x] != x {
89                parent[x] = parent[parent[x]];
90                x = parent[x];
91            }
92            x
93        }
94
95        for &BetaEdge { a, b } in &self.edges {
96            let lhs = find(&mut parent, a);
97            let rhs = find(&mut parent, b);
98            if lhs != rhs {
99                if rank[lhs] < rank[rhs] {
100                    parent[lhs] = rhs;
101                } else if rank[lhs] > rank[rhs] {
102                    parent[rhs] = lhs;
103                } else {
104                    parent[rhs] = lhs;
105                    rank[lhs] += 1;
106                }
107            }
108        }
109
110        let mut label_map = vec![usize::MAX; self.num_blocks];
111        let mut parts = Vec::<Vec<usize>>::new();
112        for block in 0..self.num_blocks {
113            let root = find(&mut parent, block);
114            let label = if label_map[root] == usize::MAX {
115                label_map[root] = parts.len();
116                parts.push(Vec::new());
117                label_map[root]
118            } else {
119                label_map[root]
120            };
121            parts[label].push(block);
122        }
123        parts
124    }
125
126    pub(crate) fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
127        let mut expanded = seed.to_vec();
128        for &block in seed {
129            expanded.extend_from_slice(self.neighbours(block));
130        }
131        expanded.sort_unstable();
132        expanded.dedup();
133        expanded
134    }
135}
136// ---------------------------------------------------------------------------
137// BetaPenaltyOp — matrix-free penalty-side H_ββ abstraction (#296)
138// ---------------------------------------------------------------------------
139
/// Index of one contiguous column block of the shared β vector, as used by
/// the block-Jacobi Schur preconditioner (#287).
///
/// `BetaBlockId(i)` selects the `i`-th range of
/// [`ArrowSchurSystem::block_offsets`]; the wrapped `usize` is that position.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BetaBlockId(pub usize);
147
148/// Matrix-free operator for the penalty side of `H_ββ`.
149///
150/// Callers must satisfy the additive convention: every method **adds** its
151/// contribution to the output buffer (i.e. `y += P x`, not `y = P x`).
152/// This matches the assembly pattern where multiple penalty terms are
153/// accumulated into the same gradient / Hessian buffers.
154pub trait BetaPenaltyOp: Send + Sync {
155    /// Full dimension `K` of the β vector.
156    fn dim(&self) -> usize;
157    /// `y += P x` — penalty Hessian-vector product (length `K`).
158    fn matvec(&self, x: &[f64], y: &mut [f64]);
159    /// Penalty gradient: `out += P β`.
160    fn gradient(&self, beta: &[f64], out: &mut [f64]);
161    /// `diag += diag(P)` — diagonal entries used by Jacobi preconditioner.
162    fn diagonal(&self, diag: &mut [f64]);
163    /// Add the `b×b` dense penalty sub-block for block `id` into `out`
164    /// (row-major, block size `b = offsets[id.0].len()`).
165    /// Used by the block-Jacobi Schur preconditioner (#287).
166    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
167    /// Materialize the full `K×K` dense penalty matrix (needed by
168    /// Direct / SqrtBA modes that form the Schur complement explicitly).
169    fn to_dense(&self) -> Array2<f64>;
170    /// Per-row absolute-value sums `out[r] = Σ_c |P[r,c]|`, the row contribution
171    /// to the operator's `∞`-norm. The default folds `to_dense()`, which costs an
172    /// `O(K²)` materialization; structured operators override this to fold their
173    /// compact factors directly so the backward-error certificate's
174    /// `arrow_operator_infinity_norm` never builds a dense `K×K` matrix on the
175    /// SAE LLM-border critical path (#1017). Overrides MUST agree bit-for-bit with
176    /// the `to_dense()` row sums (verified by the cross-check tests).
177    fn row_abs_sums(&self) -> Array1<f64> {
178        let dense = self.to_dense();
179        let k = dense.nrows();
180        let mut out = Array1::<f64>::zeros(k);
181        for r in 0..k {
182            let mut s = 0.0_f64;
183            for c in 0..dense.ncols() {
184                s += dense[[r, c]].abs();
185            }
186            out[r] = s;
187        }
188        out
189    }
190    /// Mix the operator's defining state into `hasher` for cache-validity
191    /// fingerprinting. Must change whenever `matvec` / `to_dense` would change,
192    /// so the factorization / evidence cache (`cache_matches_system`) is
193    /// invalidated when the β-block content changes. Implementations hash their
194    /// own compact defining data (e.g. Kronecker factors, block matrices)
195    /// rather than the full `K×K` dense form, which would defeat the structured
196    /// operator's storage savings.
197    fn fingerprint(&self, hasher: &mut Fingerprinter);
198}
199
200/// Dense fallback: wraps the existing `K×K` `H_ββ` accumulator.
201pub struct DensePenaltyOp(pub Array2<f64>);
202
203impl BetaPenaltyOp for DensePenaltyOp {
204    fn dim(&self) -> usize {
205        self.0.nrows()
206    }
207
208    fn matvec(&self, x: &[f64], y: &mut [f64]) {
209        let k = self.0.nrows();
210        for a in 0..k {
211            let mut acc = 0.0_f64;
212            for b in 0..k {
213                acc += self.0[[a, b]] * x[b];
214            }
215            y[a] += acc;
216        }
217    }
218
219    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
220        let k = self.0.nrows();
221        for a in 0..k {
222            let mut acc = 0.0_f64;
223            for b in 0..k {
224                acc += self.0[[a, b]] * beta[b];
225            }
226            out[a] += acc;
227        }
228    }
229
230    fn diagonal(&self, diag: &mut [f64]) {
231        let k = self.0.nrows().min(diag.len());
232        for j in 0..k {
233            diag[j] += self.0[[j, j]];
234        }
235    }
236
237    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
238        let range = &offsets[id.0];
239        let b = range.end - range.start;
240        for bi in 0..b {
241            for bj in 0..b {
242                out[[bi, bj]] += self.0[[range.start + bi, range.start + bj]];
243            }
244        }
245    }
246
247    fn to_dense(&self) -> Array2<f64> {
248        self.0.clone()
249    }
250
251    fn fingerprint(&self, hasher: &mut Fingerprinter) {
252        hasher.write_str("dense-penalty-op-v1");
253        hasher.write_f64_array2(&self.0);
254    }
255}
256
257/// Block-local penalty operator: applies per-block penalty matrices
258/// (matching `ParameterBlockSpec` boundaries) without materialising a
259/// full `K×K` dense matrix.
260///
261/// Each entry is `(global_offset, local_matrix)` where `global_offset`
262/// is the start of that block in the full β vector.
263pub struct BlockPenaltyOp {
264    /// Full β dimension `K`.
265    pub k: usize,
266    /// `(global_start, local_matrix)` for each atom/block.
267    pub blocks: Vec<(usize, Array2<f64>)>,
268}
269
270impl BetaPenaltyOp for BlockPenaltyOp {
271    fn dim(&self) -> usize {
272        self.k
273    }
274
275    fn matvec(&self, x: &[f64], y: &mut [f64]) {
276        for (off, local) in &self.blocks {
277            let b = local.nrows();
278            for i in 0..b {
279                let gi = off + i;
280                let mut acc = 0.0_f64;
281                for j in 0..b {
282                    acc += local[[i, j]] * x[off + j];
283                }
284                y[gi] += acc;
285            }
286        }
287    }
288
289    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
290        for (off, local) in &self.blocks {
291            let b = local.nrows();
292            for i in 0..b {
293                let gi = off + i;
294                let mut acc = 0.0_f64;
295                for j in 0..b {
296                    acc += local[[i, j]] * beta[off + j];
297                }
298                out[gi] += acc;
299            }
300        }
301    }
302
303    fn diagonal(&self, diag: &mut [f64]) {
304        for (off, local) in &self.blocks {
305            let b = local.nrows();
306            for j in 0..b {
307                diag[off + j] += local[[j, j]];
308            }
309        }
310    }
311
312    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
313        let range = &offsets[id.0];
314        let b_out = range.end - range.start;
315        for (off, local) in &self.blocks {
316            let b = local.nrows();
317            let block_end = off + b;
318            if block_end <= range.start || *off >= range.end {
319                continue;
320            }
321            for bi in 0..b_out {
322                let gi = range.start + bi;
323                if gi < *off || gi >= block_end {
324                    continue;
325                }
326                let li = gi - off;
327                for bj in 0..b_out {
328                    let gj = range.start + bj;
329                    if gj < *off || gj >= block_end {
330                        continue;
331                    }
332                    let lj = gj - off;
333                    out[[bi, bj]] += local[[li, lj]];
334                }
335            }
336        }
337    }
338
339    fn to_dense(&self) -> Array2<f64> {
340        let mut out = Array2::<f64>::zeros((self.k, self.k));
341        for (off, local) in &self.blocks {
342            let b = local.nrows();
343            for i in 0..b {
344                for j in 0..b {
345                    out[[off + i, off + j]] += local[[i, j]];
346                }
347            }
348        }
349        out
350    }
351
352    fn fingerprint(&self, hasher: &mut Fingerprinter) {
353        hasher.write_str("block-penalty-op-v1");
354        hasher.write_usize(self.k);
355        hasher.write_usize(self.blocks.len());
356        for (off, local) in &self.blocks {
357            hasher.write_usize(*off);
358            hasher.write_f64_array2(local);
359        }
360    }
361}
362
363/// Kronecker-product penalty: `P = A ⊗ B` applied without materialising
364/// the full `(p_a·p_b)×(p_a·p_b)` matrix.
365pub struct KroneckerPenaltyOp {
366    /// Left factor `A`, shape `(p_a, p_a)`.
367    pub factor_a: Array2<f64>,
368    /// Right factor `B`, shape `(p_b, p_b)`.
369    pub factor_b: Array2<f64>,
370    /// Global offset into the β vector where this block starts.
371    pub global_offset: usize,
372    /// Full β dimension `K`.
373    pub k: usize,
374}
375
376impl BetaPenaltyOp for KroneckerPenaltyOp {
377    fn dim(&self) -> usize {
378        self.k
379    }
380
381    fn matvec(&self, x: &[f64], y: &mut [f64]) {
382        let p_a = self.factor_a.nrows();
383        let p_b = self.factor_b.nrows();
384        let off = self.global_offset;
385        // (A ⊗ B) vec(V) where V is (p_b, p_a) with Fortran/vec ordering.
386        for i_a in 0..p_a {
387            for i_b in 0..p_b {
388                let gi = off + i_a * p_b + i_b;
389                let mut acc = 0.0_f64;
390                for j_a in 0..p_a {
391                    let a_ij = self.factor_a[[i_a, j_a]];
392                    if a_ij == 0.0 {
393                        continue;
394                    }
395                    for j_b in 0..p_b {
396                        acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
397                    }
398                }
399                y[gi] += acc;
400            }
401        }
402    }
403
404    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
405        let p_a = self.factor_a.nrows();
406        let p_b = self.factor_b.nrows();
407        let off = self.global_offset;
408        for i_a in 0..p_a {
409            for i_b in 0..p_b {
410                let gi = off + i_a * p_b + i_b;
411                let mut acc = 0.0_f64;
412                for j_a in 0..p_a {
413                    let a_ij = self.factor_a[[i_a, j_a]];
414                    if a_ij == 0.0 {
415                        continue;
416                    }
417                    for j_b in 0..p_b {
418                        acc += a_ij * self.factor_b[[i_b, j_b]] * beta[off + j_a * p_b + j_b];
419                    }
420                }
421                out[gi] += acc;
422            }
423        }
424    }
425
426    fn diagonal(&self, diag: &mut [f64]) {
427        let p_a = self.factor_a.nrows();
428        let p_b = self.factor_b.nrows();
429        let off = self.global_offset;
430        for i_a in 0..p_a {
431            for i_b in 0..p_b {
432                diag[off + i_a * p_b + i_b] +=
433                    self.factor_a[[i_a, i_a]] * self.factor_b[[i_b, i_b]];
434            }
435        }
436    }
437
438    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
439        let range = &offsets[id.0];
440        let b = range.end - range.start;
441        let p_a = self.factor_a.nrows();
442        let p_b = self.factor_b.nrows();
443        let off = self.global_offset;
444        let block_end = off + p_a * p_b;
445        if block_end <= range.start || off >= range.end {
446            return;
447        }
448        for bi in 0..b {
449            let gi = range.start + bi;
450            if gi < off || gi >= block_end {
451                continue;
452            }
453            let li = gi - off;
454            let i_a = li / p_b;
455            let i_b = li % p_b;
456            for bj in 0..b {
457                let gj = range.start + bj;
458                if gj < off || gj >= block_end {
459                    continue;
460                }
461                let lj = gj - off;
462                let j_a = lj / p_b;
463                let j_b = lj % p_b;
464                out[[bi, bj]] += self.factor_a[[i_a, j_a]] * self.factor_b[[i_b, j_b]];
465            }
466        }
467    }
468
469    fn to_dense(&self) -> Array2<f64> {
470        let p_a = self.factor_a.nrows();
471        let p_b = self.factor_b.nrows();
472        let off = self.global_offset;
473        let mut out = Array2::<f64>::zeros((self.k, self.k));
474        for i_a in 0..p_a {
475            for i_b in 0..p_b {
476                let gi = off + i_a * p_b + i_b;
477                for j_a in 0..p_a {
478                    let a_ij = self.factor_a[[i_a, j_a]];
479                    if a_ij == 0.0 {
480                        continue;
481                    }
482                    for j_b in 0..p_b {
483                        let gj = off + j_a * p_b + j_b;
484                        out[[gi, gj]] += a_ij * self.factor_b[[i_b, j_b]];
485                    }
486                }
487            }
488        }
489        out
490    }
491
492    fn fingerprint(&self, hasher: &mut Fingerprinter) {
493        hasher.write_str("kronecker-penalty-op-v1");
494        hasher.write_usize(self.global_offset);
495        hasher.write_usize(self.k);
496        hasher.write_f64_array2(&self.factor_a);
497        hasher.write_f64_array2(&self.factor_b);
498    }
499}
500
501/// Kronecker-product penalty with an identity right factor:
502/// `P = A ⊗ I_p`.
503///
504/// This is the hot SAE smoothness case. Storing `I_p` as a dense matrix costs
505/// `O(p²)` memory per atom and makes every matvec pay an unnecessary right-factor
506/// loop. This operator stores only the identity dimension and keeps the same
507/// layout as [`KroneckerPenaltyOp`]: local index `i_a * p + i_b`.
508pub struct IdentityRightKroneckerPenaltyOp {
509    /// Left factor `A`, shape `(p_a, p_a)`.
510    pub factor_a: Array2<f64>,
511    /// Identity right-factor dimension `p`.
512    pub p: usize,
513    /// Global offset into the β vector where this block starts.
514    pub global_offset: usize,
515    /// Full β dimension `K`.
516    pub k: usize,
517}
518
519impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
520    fn dim(&self) -> usize {
521        self.k
522    }
523
524    fn matvec(&self, x: &[f64], y: &mut [f64]) {
525        let p_a = self.factor_a.nrows();
526        let p = self.p;
527        let off = self.global_offset;
528        for i_a in 0..p_a {
529            for i_b in 0..p {
530                let gi = off + i_a * p + i_b;
531                let mut acc = 0.0_f64;
532                for j_a in 0..p_a {
533                    let a_ij = self.factor_a[[i_a, j_a]];
534                    if a_ij == 0.0 {
535                        continue;
536                    }
537                    acc += a_ij * x[off + j_a * p + i_b];
538                }
539                y[gi] += acc;
540            }
541        }
542    }
543
544    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
545        self.matvec(beta, out);
546    }
547
548    fn diagonal(&self, diag: &mut [f64]) {
549        let p_a = self.factor_a.nrows();
550        let p = self.p;
551        let off = self.global_offset;
552        for i_a in 0..p_a {
553            let a_ii = self.factor_a[[i_a, i_a]];
554            for i_b in 0..p {
555                diag[off + i_a * p + i_b] += a_ii;
556            }
557        }
558    }
559
560    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
561        let range = &offsets[id.0];
562        let b = range.end - range.start;
563        let p_a = self.factor_a.nrows();
564        let p = self.p;
565        let off = self.global_offset;
566        let block_end = off + p_a * p;
567        if block_end <= range.start || off >= range.end {
568            return;
569        }
570        for bi in 0..b {
571            let gi = range.start + bi;
572            if gi < off || gi >= block_end {
573                continue;
574            }
575            let li = gi - off;
576            let i_a = li / p;
577            let i_b = li % p;
578            for bj in 0..b {
579                let gj = range.start + bj;
580                if gj < off || gj >= block_end {
581                    continue;
582                }
583                let lj = gj - off;
584                let j_a = lj / p;
585                let j_b = lj % p;
586                if i_b == j_b {
587                    out[[bi, bj]] += self.factor_a[[i_a, j_a]];
588                }
589            }
590        }
591    }
592
593    fn to_dense(&self) -> Array2<f64> {
594        let p_a = self.factor_a.nrows();
595        let p = self.p;
596        let off = self.global_offset;
597        let mut out = Array2::<f64>::zeros((self.k, self.k));
598        for i_a in 0..p_a {
599            for j_a in 0..p_a {
600                let a_ij = self.factor_a[[i_a, j_a]];
601                if a_ij == 0.0 {
602                    continue;
603                }
604                for i_b in 0..p {
605                    let gi = off + i_a * p + i_b;
606                    let gj = off + j_a * p + i_b;
607                    out[[gi, gj]] += a_ij;
608                }
609            }
610        }
611        out
612    }
613
614    fn fingerprint(&self, hasher: &mut Fingerprinter) {
615        hasher.write_str("identity-right-kronecker-penalty-op-v1");
616        hasher.write_usize(self.global_offset);
617        hasher.write_usize(self.k);
618        hasher.write_usize(self.p);
619        hasher.write_f64_array2(&self.factor_a);
620    }
621}
622
623/// One co-occurring atom-pair block of a block-sparse left factor `A`.
624///
625/// `data` is the dense `(m_i × m_j)` coupling between the basis columns of
626/// atom `i` (rows, starting at left-factor offset `row_off`) and atom `j`
627/// (columns, starting at `col_off`). Both offsets are in *left-factor* (`A`)
628/// coordinates, i.e. `μ`-space, not β-space.
629#[derive(Debug, Clone)]
630pub struct SparseGBlock {
631    /// Left-factor (`μ`-space) row offset = `beta_offset[atom_i] / p`.
632    pub row_off: usize,
633    /// Left-factor (`μ`-space) column offset = `beta_offset[atom_j] / p`.
634    pub col_off: usize,
635    /// Dense `(m_i × m_j)` coupling block.
636    pub data: Array2<f64>,
637}
638
639/// Block-sparse Kronecker penalty `P = A ⊗ I_p` where the left factor `A`
640/// (dimension `dim_a × dim_a` in `μ`-space) is stored only on its non-empty
641/// co-occurring atom-pair blocks rather than as a dense `(dim_a × dim_a)`
642/// matrix.
643///
644/// This is the sparse-atom (`K = 100K`) replacement for wrapping the dense
645/// data-fit Gauss-Newton Gram `G` (`m_total × m_total`) in a
646/// [`KroneckerPenaltyOp`]: with per-row active sets of size `k_active ≪ K`,
647/// only the `(atom, atom')` pairs that co-occur in some row contribute a
648/// non-zero `(m_i × m_j)` block, so the storage and every matvec/diagonal
649/// pass cost `O(Σ_pairs m_i m_j · p)` instead of `O((m_total · p)²)`.
650///
651/// The β index of left-factor coordinate `μ` and output channel `oc` is
652/// `μ · p + oc` (the same `μ`-major / `oc`-minor layout the dense
653/// `KroneckerPenaltyOp { factor_b: I_p }` uses), so this op is a drop-in
654/// structured replacement: with the full dense pair set it reproduces the
655/// dense operator exactly.
656pub struct SparseBlockKroneckerPenaltyOp {
657    /// Right-factor identity dimension `p` (number of decoder output channels).
658    pub p: usize,
659    /// Left-factor dimension `dim_a` in `μ`-space (= `m_total`).
660    pub dim_a: usize,
661    /// Full β dimension `K = dim_a · p`.
662    pub k: usize,
663    /// Non-empty `(atom_i, atom_j)` coupling blocks of `A`.
664    pub blocks: Vec<SparseGBlock>,
665}
666
667#[derive(Debug, Clone)]
668pub struct DeviceSaeSmoothBlock {
669    pub global_offset: usize,
670    pub factor_a: Array2<f64>,
671}
672
673/// Frame-factored extension of [`DeviceSaePcgData`] (issue #1017/#1026,
674/// frames-engaged device PCG). Present only when at least one atom is genuinely
675/// frame-reduced (`ranks[k] < p`); absent (`None`) on the full-`B` path, where
676/// the legacy `G ⊗ I_p` channel-identical kernel applies byte-for-byte.
677///
678/// On the frames path the β border is the FACTORED coordinate space `C` of width
679/// `Σ_k M_k·r_k`, the data-fit β-Hessian is `G_{ij} ⊗ W_{ij}` (`W_{ij}=U_iᵀU_j`,
680/// carried on `frame_blocks`), the smooth penalty is `λ S_k ⊗ I_{r_k}`
681/// (`smooth_blocks`, reused — width `r_k` instead of `p`), and the per-row
682/// reduced-Schur cross-block `H_tβ^(i)` is the DENSE `(q_i × border_dim)` slab
683/// `row_htbeta[i]` (row-major) rather than the full-`B` factored `L_i · J_β`
684/// gather (so `a_phi`/`local_jac` are unused on this path).
685#[derive(Debug, Clone)]
686pub struct DeviceSaeFrameData {
687    /// Per-atom frame rank `r_k` (factored output width); `r_k == p` for an
688    /// un-framed atom riding the identity special case.
689    pub ranks: Vec<usize>,
690    /// Per-atom basis size `M_k`.
691    pub basis_sizes: Vec<usize>,
692    /// Per-atom factored-border offset `off_C[k]` (prefix sum of `M_k·r_k`),
693    /// length `n_atoms`. Atom `k`'s `C_k` block is `[off_C[k] .. +M_k·r_k)`.
694    pub border_offsets: Vec<usize>,
695    /// Co-occurring `(atom_i, atom_j)` data-fit blocks `g ⊗ w` (`w = U_iᵀU_j`).
696    pub frame_blocks: Vec<FactoredFrameGBlock>,
697    /// Right-factor width (`r_k`) of each entry of the top-level
698    /// `DeviceSaePcgData::smooth_blocks`, in the SAME order. On the frames path
699    /// the smooth penalty is `λ S_k ⊗ I_{r_k}` so the block at
700    /// `smooth_blocks[i].global_offset` has identity width `smooth_ranks[i]`
701    /// (which equals `ranks[atom]`), NOT the ambient `p`.
702    pub smooth_ranks: Vec<usize>,
703    /// Per-row dense cross-block `H_tβ^(i)` as a row-major `q_i × border_dim`
704    /// buffer (`q_i = row_dims[i]`). Empty inner `Vec` for a 0-dim row.
705    pub row_htbeta: Vec<Vec<f64>>,
706}
707
708#[derive(Debug, Clone)]
709pub struct DeviceSaePcgData {
710    pub p: usize,
711    pub beta_dim: usize,
712    // #1033 large-n: the per-row support `a_phi` and local Jacobians `local_jac`
713    // are ALSO held by the host matrix-free row operator (`SaeKroneckerRows`) for
714    // the lifetime of the inner solve. Storing them as `Arc<[…]>` lets the
715    // assembler hand BOTH consumers the SAME backing allocation instead of a
716    // second full `O(n·q·p)` clone (`device_rows = (a_phi.clone(), kron_jac.clone())`
717    // was the dominant always-resident duplication on the CPU non-frames path at
718    // the LLM shape p≈5120). Indexing/`.len()`/iteration are identical to `Vec`.
719    pub a_phi: Arc<[Vec<(usize, f64)>]>,
720    pub local_jac: Arc<[Vec<f64>]>,
721    pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
722    pub sparse_g_blocks: Vec<SparseGBlock>,
723    /// Frame-factored metadata. `None` ⇒ legacy full-`B` `G ⊗ I_p` path
724    /// (byte-identical to before this field existed). `Some` ⇒ frames-engaged
725    /// path: the kernel consumes `frame.frame_blocks`/`smooth_blocks` (now
726    /// rank-`r_k` wide) and `frame.row_htbeta` instead of the `⊗ I_p` gather.
727    pub frame: Option<DeviceSaeFrameData>,
728}
729
730impl DeviceSaePcgData {
731    /// Snapshot the per-row active-atom support `a_phi` into a shared `Arc<[…]>`
732    /// for the CPU residency operator ([`SaeResidentReducedSchur`]). Cloned once
733    /// per CG-solve build (cost `O(Σ_i m_i)`, dwarfed by the per-row factor solves
734    /// in the same build), so the resident matvec borrows the index lists without
735    /// re-cloning them on every CG iteration.
736    pub(crate) fn a_phi_shared(&self) -> Arc<[Vec<(usize, f64)>]> {
737        // #1033: `a_phi` is already an `Arc<[…]>`; hand back a refcount bump
738        // (`O(1)`) rather than re-cloning every `(idx, weight)` pair per CG build.
739        Arc::clone(&self.a_phi)
740    }
741
742    /// Share the per-row local Jacobians `local_jac` with the CPU residency
743    /// operator ([`SaeResidentReducedSchur`]) as an `O(1)` refcount bump. The
744    /// staged row factor used to hold a verbatim row-major copy of each
745    /// `local_jac[row]`; sharing the slab removes that second full `O(n·di·p)`
746    /// copy with byte-for-byte identical reads (#1033).
747    pub(crate) fn local_jac_shared(&self) -> Arc<[Vec<f64>]> {
748        Arc::clone(&self.local_jac)
749    }
750}
751
752impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
753    fn dim(&self) -> usize {
754        self.k
755    }
756
757    fn matvec(&self, x: &[f64], y: &mut [f64]) {
758        let p = self.p;
759        for blk in &self.blocks {
760            let (m_i, m_j) = blk.data.dim();
761            for li in 0..m_i {
762                let gi_base = (blk.row_off + li) * p;
763                for lj in 0..m_j {
764                    let a_ij = blk.data[[li, lj]];
765                    if a_ij == 0.0 {
766                        continue;
767                    }
768                    let gj_base = (blk.col_off + lj) * p;
769                    for oc in 0..p {
770                        y[gi_base + oc] += a_ij * x[gj_base + oc];
771                    }
772                }
773            }
774        }
775    }
776
777    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
778        self.matvec(beta, out);
779    }
780
781    fn diagonal(&self, diag: &mut [f64]) {
782        let p = self.p;
783        for blk in &self.blocks {
784            // Only on-diagonal `A` blocks (row_off == col_off) carry diagonal
785            // mass; their `(li, li)` entries map to `(row_off+li)·p + oc`.
786            if blk.row_off != blk.col_off {
787                continue;
788            }
789            let (m_i, m_j) = blk.data.dim();
790            let m = m_i.min(m_j);
791            for li in 0..m {
792                let a_ii = blk.data[[li, li]];
793                let gi_base = (blk.row_off + li) * p;
794                for oc in 0..p {
795                    diag[gi_base + oc] += a_ii;
796                }
797            }
798        }
799    }
800
801    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
802        let range = &offsets[id.0];
803        let b = range.end - range.start;
804        let p = self.p;
805        for blk in &self.blocks {
806            let (m_i, m_j) = blk.data.dim();
807            let row_start = blk.row_off * p;
808            let row_end = (blk.row_off + m_i) * p;
809            let col_start = blk.col_off * p;
810            let col_end = (blk.col_off + m_j) * p;
811            if row_end <= range.start
812                || row_start >= range.end
813                || col_end <= range.start
814                || col_start >= range.end
815            {
816                continue;
817            }
818            for bi in 0..b {
819                let gi = range.start + bi;
820                if gi < row_start || gi >= row_end {
821                    continue;
822                }
823                let li = (gi - row_start) / p;
824                let oc_i = (gi - row_start) % p;
825                for bj in 0..b {
826                    let gj = range.start + bj;
827                    if gj < col_start || gj >= col_end {
828                        continue;
829                    }
830                    let oc_j = (gj - col_start) % p;
831                    if oc_i != oc_j {
832                        continue;
833                    }
834                    let lj = (gj - col_start) / p;
835                    out[[bi, bj]] += blk.data[[li, lj]];
836                }
837            }
838        }
839    }
840
841    fn to_dense(&self) -> Array2<f64> {
842        let p = self.p;
843        let mut out = Array2::<f64>::zeros((self.k, self.k));
844        for blk in &self.blocks {
845            let (m_i, m_j) = blk.data.dim();
846            for li in 0..m_i {
847                let gi_base = (blk.row_off + li) * p;
848                for lj in 0..m_j {
849                    let a_ij = blk.data[[li, lj]];
850                    if a_ij == 0.0 {
851                        continue;
852                    }
853                    let gj_base = (blk.col_off + lj) * p;
854                    for oc in 0..p {
855                        out[[gi_base + oc, gj_base + oc]] += a_ij;
856                    }
857                }
858            }
859        }
860        out
861    }
862
863    fn row_abs_sums(&self) -> Array1<f64> {
864        // Mirror `to_dense`: entry `(gi_base+oc, gj_base+oc) += a_ij`. Each
865        // `(li, lj, oc)` lands in a DISTINCT column (`gj_base+oc` is injective in
866        // `(lj, oc)` for a fixed block, and blocks with the same `row_off` have
867        // disjoint `col_off`), so the row's `Σ_c|P[r,c]|` is just the sum of
868        // `|a_ij|` over the contributing `(block, lj)` pairs — no dense matrix.
869        let p = self.p;
870        let mut out = Array1::<f64>::zeros(self.k);
871        for blk in &self.blocks {
872            let (m_i, m_j) = blk.data.dim();
873            for li in 0..m_i {
874                let gi_base = (blk.row_off + li) * p;
875                let mut row_abs = 0.0_f64;
876                for lj in 0..m_j {
877                    row_abs += blk.data[[li, lj]].abs();
878                }
879                for oc in 0..p {
880                    out[gi_base + oc] += row_abs;
881                }
882            }
883        }
884        out
885    }
886
887    fn fingerprint(&self, hasher: &mut Fingerprinter) {
888        hasher.write_str("sparse-block-kronecker-penalty-op-v1");
889        hasher.write_usize(self.p);
890        hasher.write_usize(self.dim_a);
891        hasher.write_usize(self.k);
892        hasher.write_usize(self.blocks.len());
893        for blk in &self.blocks {
894            hasher.write_usize(blk.row_off);
895            hasher.write_usize(blk.col_off);
896            hasher.write_f64_array2(&blk.data);
897        }
898    }
899}
900
/// One co-occurring `(atom_i, atom_j)` block of the **frame-factored** data-fit
/// Gauss–Newton β-Hessian (issue #972 / #977 T1). Carries the basis-space Gram
/// `g` (`m_i × m_j`) AND the per-pair frame output factor `w = U_iᵀ U_j`
/// (`r_i × r_j`); the contributed Hessian sub-block is the Kronecker product
/// `g ⊗ w`.
///
/// [`FactoredFrameKroneckerOp::new`] rejects blocks whose atom indices are out
/// of range or whose `g` / `w` shapes disagree with the per-atom basis sizes
/// and ranks it is given.
#[derive(Debug, Clone)]
pub struct FactoredFrameGBlock {
    /// Atom index of the row factor (selects rank `r_i` and β offset).
    pub atom_i: usize,
    /// Atom index of the column factor (selects rank `r_j` and β offset).
    pub atom_j: usize,
    /// Basis-space coupling `G_{ij}` (`m_i × m_j`).
    pub g: Array2<f64>,
    /// Frame output factor `U_iᵀ U_j` (`r_i × r_j`). For `i == j` with an
    /// orthonormal frame this is `I_{r_i}` (the clean within-atom `g ⊗ I_r`
    /// collapse); across atoms it is the dense principal-angle cosine matrix
    /// between the two frames.
    pub w: Array2<f64>,
}
920
/// Frame-factored data-fit Gauss–Newton β-Hessian operator (#972 / #977 T1):
/// the `Σ_k M_k·r_k` reduced-border analogue of [`SparseBlockKroneckerPenaltyOp`].
///
/// When every atom's decoder `B_k = C_k U_kᵀ` is profiled onto a Grassmann
/// frame `U_k ∈ St(p, r_k)`, the border carries only the shape coefficients
/// `C_k` (`M_k · r_k` entries) instead of the full `B_k` (`M_k · p`). The data
/// Gram in this reduced space is, for the isotropic likelihood,
/// `H[(i,li,a),(j,lj,b)] = G_{ij}[li,lj] · (U_iᵀ U_j)[a,b]` — within an atom the
/// orthonormal frame gives `U_iᵀU_i = I_{r_i}` and the block is the clean
/// `G ⊗ I_r` collapse; across co-active atoms the frames do not share a basis
/// so the output factor is the dense `U_iᵀU_j`.
///
/// The β layout is `μ`-major / frame-minor with a **variable** per-atom width
/// `r_k`: the index of (atom `k`, basis `li`, frame coord `a`) is
/// `offset[k] + li·r_k + a`, where `offset` is the prefix sum of `M_k · r_k`.
/// With every `r_k = p` and `U_k = I_p` this reproduces
/// [`SparseBlockKroneckerPenaltyOp`] exactly (a unit test pins the reduction),
/// so it is a strict generalization, not a separate code path.
///
/// All fields are public, but `offsets` and `dim` are derived quantities:
/// [`FactoredFrameKroneckerOp::new`] computes them from `ranks` and
/// `basis_sizes` and validates the block shapes, so prefer it over a struct
/// literal.
pub struct FactoredFrameKroneckerOp {
    /// Per-atom frame rank `r_k` (the factored output width).
    pub ranks: Vec<usize>,
    /// Per-atom basis size `M_k`.
    pub basis_sizes: Vec<usize>,
    /// Per-atom β offset (prefix sum of `M_k · r_k`); `offsets[k]` is the start
    /// of atom `k`'s `C_k` block, `offsets[n_atoms]` the total dim.
    pub offsets: Vec<usize>,
    /// Total reduced β dimension `Σ_k M_k · r_k`.
    pub dim: usize,
    /// Non-empty co-occurring `(atom_i, atom_j)` blocks.
    pub blocks: Vec<FactoredFrameGBlock>,
}
952
953/// Frame output Gram `U_iᵀ U_j` (`r_i × r_j`) between two per-atom output
954/// frames (each `p × r`). This is the dense principal-angle cosine matrix that
955/// becomes the `w` factor of a [`FactoredFrameGBlock`]; for `i == j` with an
956/// orthonormal frame it is `I_{r_i}`. Shared with
957/// [`gam_terms::sae::manifold`], which builds the same factors when
958/// profiling decoders onto Grassmann frames.
959pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
960    let (p_i, r_i) = u_i.dim();
961    let (p_j, r_j) = u_j.dim();
962    assert_eq!(
963        p_i, p_j,
964        "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
965    );
966    let mut w = Array2::<f64>::zeros((r_i, r_j));
967    for a in 0..r_i {
968        for b in 0..r_j {
969            let mut acc = 0.0;
970            for c in 0..p_i {
971                acc += u_i[[c, a]] * u_j[[c, b]];
972            }
973            w[[a, b]] = acc;
974        }
975    }
976    w
977}
978
979impl FactoredFrameKroneckerOp {
980    /// Build from per-atom ranks + basis sizes and the co-occurring blocks.
981    /// Computes the β offsets (prefix sum of `M_k·r_k`) and validates that each
982    /// block's `g`/`w` shapes match the atoms' `(M, r)`.
983    pub fn new(
984        ranks: Vec<usize>,
985        basis_sizes: Vec<usize>,
986        blocks: Vec<FactoredFrameGBlock>,
987    ) -> Result<Self, String> {
988        if ranks.len() != basis_sizes.len() {
989            return Err(format!(
990                "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
991                ranks.len(),
992                basis_sizes.len()
993            ));
994        }
995        let n_atoms = ranks.len();
996        let mut offsets = Vec::with_capacity(n_atoms + 1);
997        let mut acc = 0usize;
998        for k in 0..n_atoms {
999            offsets.push(acc);
1000            acc += basis_sizes[k] * ranks[k];
1001        }
1002        offsets.push(acc);
1003        let dim = acc;
1004        for blk in &blocks {
1005            if blk.atom_i >= n_atoms || blk.atom_j >= n_atoms {
1006                return Err(format!(
1007                    "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
1008                    blk.atom_i, blk.atom_j
1009                ));
1010            }
1011            if blk.g.dim() != (basis_sizes[blk.atom_i], basis_sizes[blk.atom_j]) {
1012                return Err(format!(
1013                    "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
1014                    blk.atom_i,
1015                    blk.atom_j,
1016                    blk.g.dim(),
1017                    basis_sizes[blk.atom_i],
1018                    basis_sizes[blk.atom_j]
1019                ));
1020            }
1021            if blk.w.dim() != (ranks[blk.atom_i], ranks[blk.atom_j]) {
1022                return Err(format!(
1023                    "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
1024                    blk.atom_i,
1025                    blk.atom_j,
1026                    blk.w.dim(),
1027                    ranks[blk.atom_i],
1028                    ranks[blk.atom_j]
1029                ));
1030            }
1031        }
1032        Ok(Self {
1033            ranks,
1034            basis_sizes,
1035            offsets,
1036            dim,
1037            blocks,
1038        })
1039    }
1040
1041    /// Convenience constructor that builds the operator directly from per-atom
1042    /// output frames + the basis-space Gram block map, computing the per-pair
1043    /// frame factors `W_ij = U_iᵀ U_j` itself.
1044    ///
1045    /// `frames[k]` is either `Some(U_k)` — a `p × r_k` (`r_k ≤ p`) output frame
1046    /// (a Grassmann representative `St(p, r_k)` need not be orthonormal here; the
1047    /// `W` factor carries whatever frame is supplied) — or `None`, meaning atom
1048    /// `k` keeps the full ambient output (`U_k = I_p`, so `r_k = p`). For each
1049    /// non-empty Gram block `(atom_i, atom_j)` the factor `W` is
1050    /// `U_iᵀ U_j` (`r_i × r_j`), with the `None` frame standing in for `I_p`:
1051    /// a framed×unframed cross gives `W = U_iᵀ` (`r_i × p`) and an unframed
1052    /// diagonal gives `W = I_p` — exactly reproducing the `g ⊗ I_p` full-`B`
1053    /// block. The resulting blocks are handed to [`Self::new`], which validates
1054    /// the `(M, r)` shapes and computes the β offsets.
1055    pub fn from_frames_and_blocks(
1056        frames: &[Option<Array2<f64>>],
1057        basis_sizes: &[usize],
1058        p: usize,
1059        g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
1060    ) -> Result<Self, String> {
1061        if frames.len() != basis_sizes.len() {
1062            return Err(format!(
1063                "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
1064                frames.len(),
1065                basis_sizes.len()
1066            ));
1067        }
1068        let n_atoms = frames.len();
1069        // Per-atom rank: ncols of a supplied frame, else the ambient dim p.
1070        let mut ranks = Vec::with_capacity(n_atoms);
1071        for (k, frame) in frames.iter().enumerate() {
1072            match frame {
1073                Some(u) => {
1074                    let (pr, r) = u.dim();
1075                    if pr != p {
1076                        return Err(format!(
1077                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
1078                        ));
1079                    }
1080                    if r > p {
1081                        return Err(format!(
1082                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
1083                        ));
1084                    }
1085                    ranks.push(r);
1086                }
1087                None => ranks.push(p),
1088            }
1089        }
1090        // Materialize each atom's frame as a `p × r_k` view source: the supplied
1091        // `U_k`, or `I_p` for the unframed atoms.
1092        let identity = Array2::<f64>::eye(p);
1093        let frame_or_ident = |k: usize| -> ArrayView2<f64> {
1094            match &frames[k] {
1095                Some(u) => u.view(),
1096                None => identity.view(),
1097            }
1098        };
1099        let mut blocks = Vec::with_capacity(g_blocks.len());
1100        for (&(atom_i, atom_j), g) in g_blocks {
1101            if atom_i >= n_atoms || atom_j >= n_atoms {
1102                return Err(format!(
1103                    "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
1104                ));
1105            }
1106            let w = frame_output_gram(frame_or_ident(atom_i), frame_or_ident(atom_j));
1107            blocks.push(FactoredFrameGBlock {
1108                atom_i,
1109                atom_j,
1110                g: g.clone(),
1111                w,
1112            });
1113        }
1114        Self::new(ranks, basis_sizes.to_vec(), blocks)
1115    }
1116}
1117
1118impl BetaPenaltyOp for FactoredFrameKroneckerOp {
1119    fn dim(&self) -> usize {
1120        self.dim
1121    }
1122
1123    fn matvec(&self, x: &[f64], y: &mut [f64]) {
1124        for blk in &self.blocks {
1125            let r_i = self.ranks[blk.atom_i];
1126            let r_j = self.ranks[blk.atom_j];
1127            let off_i = self.offsets[blk.atom_i];
1128            let off_j = self.offsets[blk.atom_j];
1129            let (m_i, m_j) = blk.g.dim();
1130            for li in 0..m_i {
1131                let yi_base = off_i + li * r_i;
1132                for lj in 0..m_j {
1133                    let g = blk.g[[li, lj]];
1134                    if g == 0.0 {
1135                        continue;
1136                    }
1137                    let xj_base = off_j + lj * r_j;
1138                    // y_block[li, a] += g · Σ_b w[a, b] · x_block[lj, b]
1139                    for a in 0..r_i {
1140                        let mut acc = 0.0;
1141                        for b in 0..r_j {
1142                            acc += blk.w[[a, b]] * x[xj_base + b];
1143                        }
1144                        y[yi_base + a] += g * acc;
1145                    }
1146                }
1147            }
1148        }
1149    }
1150
1151    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1152        self.matvec(beta, out);
1153    }
1154
1155    fn diagonal(&self, diag: &mut [f64]) {
1156        for blk in &self.blocks {
1157            // Only on-diagonal atom blocks carry diagonal mass; the entry at
1158            // (atom k, basis li, coord a) is g[li,li]·w[a,a].
1159            if blk.atom_i != blk.atom_j {
1160                continue;
1161            }
1162            let r = self.ranks[blk.atom_i];
1163            let off = self.offsets[blk.atom_i];
1164            let (m_i, m_j) = blk.g.dim();
1165            let m = m_i.min(m_j);
1166            for li in 0..m {
1167                let gii = blk.g[[li, li]];
1168                let base = off + li * r;
1169                for a in 0..r {
1170                    diag[base + a] += gii * blk.w[[a, a]];
1171                }
1172            }
1173        }
1174    }
1175
1176    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1177        // Dense sub-block over the β index range `offsets[id.0]`. Mirror the
1178        // global (i,a) ↔ (j,b) coupling, keeping only indices inside the range.
1179        let range = &offsets[id.0];
1180        let b_dim = range.end - range.start;
1181        for blk in &self.blocks {
1182            let r_i = self.ranks[blk.atom_i];
1183            let r_j = self.ranks[blk.atom_j];
1184            let off_i = self.offsets[blk.atom_i];
1185            let off_j = self.offsets[blk.atom_j];
1186            let (m_i, m_j) = blk.g.dim();
1187            for li in 0..m_i {
1188                for a in 0..r_i {
1189                    let gi = off_i + li * r_i + a;
1190                    if gi < range.start || gi >= range.end {
1191                        continue;
1192                    }
1193                    let bi = gi - range.start;
1194                    for lj in 0..m_j {
1195                        let g = blk.g[[li, lj]];
1196                        if g == 0.0 {
1197                            continue;
1198                        }
1199                        for b in 0..r_j {
1200                            let gj = off_j + lj * r_j + b;
1201                            if gj < range.start || gj >= range.end {
1202                                continue;
1203                            }
1204                            let bj = gj - range.start;
1205                            if bi < b_dim && bj < b_dim {
1206                                out[[bi, bj]] += g * blk.w[[a, b]];
1207                            }
1208                        }
1209                    }
1210                }
1211            }
1212        }
1213    }
1214
1215    fn to_dense(&self) -> Array2<f64> {
1216        let mut out = Array2::<f64>::zeros((self.dim, self.dim));
1217        for blk in &self.blocks {
1218            let r_i = self.ranks[blk.atom_i];
1219            let r_j = self.ranks[blk.atom_j];
1220            let off_i = self.offsets[blk.atom_i];
1221            let off_j = self.offsets[blk.atom_j];
1222            let (m_i, m_j) = blk.g.dim();
1223            for li in 0..m_i {
1224                for lj in 0..m_j {
1225                    let g = blk.g[[li, lj]];
1226                    if g == 0.0 {
1227                        continue;
1228                    }
1229                    for a in 0..r_i {
1230                        let gi = off_i + li * r_i + a;
1231                        for b in 0..r_j {
1232                            let gj = off_j + lj * r_j + b;
1233                            out[[gi, gj]] += g * blk.w[[a, b]];
1234                        }
1235                    }
1236                }
1237            }
1238        }
1239        out
1240    }
1241
1242    fn fingerprint(&self, hasher: &mut Fingerprinter) {
1243        hasher.write_str("factored-frame-kronecker-op-v1");
1244        hasher.write_usize(self.dim);
1245        for &r in &self.ranks {
1246            hasher.write_usize(r);
1247        }
1248        for &m in &self.basis_sizes {
1249            hasher.write_usize(m);
1250        }
1251        hasher.write_usize(self.blocks.len());
1252        for blk in &self.blocks {
1253            hasher.write_usize(blk.atom_i);
1254            hasher.write_usize(blk.atom_j);
1255            hasher.write_f64_array2(&blk.g);
1256            hasher.write_f64_array2(&blk.w);
1257        }
1258    }
1259}
1260
/// Composite penalty: sum of multiple `BetaPenaltyOp` operators.
///
/// Every trait method forwards to each component in `ops` order against the
/// same output buffer, so the components' contributions add up.
pub struct CompositePenaltyOp {
    /// Full β dimension `K`; also the side length of the matrix returned by
    /// `to_dense`.
    pub k: usize,
    /// Component operators, each contributing additively.
    pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
}
1268
1269impl BetaPenaltyOp for CompositePenaltyOp {
1270    fn dim(&self) -> usize {
1271        self.k
1272    }
1273
1274    fn matvec(&self, x: &[f64], y: &mut [f64]) {
1275        for op in &self.ops {
1276            op.matvec(x, y);
1277        }
1278    }
1279
1280    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1281        for op in &self.ops {
1282            op.gradient(beta, out);
1283        }
1284    }
1285
1286    fn diagonal(&self, diag: &mut [f64]) {
1287        for op in &self.ops {
1288            op.diagonal(diag);
1289        }
1290    }
1291
1292    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1293        for op in &self.ops {
1294            op.block(id, offsets, out);
1295        }
1296    }
1297
1298    fn to_dense(&self) -> Array2<f64> {
1299        let mut out = Array2::<f64>::zeros((self.k, self.k));
1300        for op in &self.ops {
1301            let dense = op.to_dense();
1302            out += &dense;
1303        }
1304        out
1305    }
1306
1307    fn fingerprint(&self, hasher: &mut Fingerprinter) {
1308        hasher.write_str("composite-penalty-op-v1");
1309        hasher.write_usize(self.k);
1310        hasher.write_usize(self.ops.len());
1311        for op in &self.ops {
1312            op.fingerprint(hasher);
1313        }
1314    }
1315}
1316
/// Adapts a closure-based matrix-free `H_ββ` operator (from
/// [`ArrowSchurSystem::set_shared_beta_operator`]) to the `BetaPenaltyOp` trait.
///
/// `diagonal_vec` holds the precomputed `diag(H_ββ)` supplied alongside the
/// matvec; `to_dense` falls back to probing all `K` canonical basis vectors.
pub struct MatvecDiagPenaltyOp {
    /// Full β dimension `K`.
    pub(crate) k: usize,
    /// The wrapped matvec closure, invoked as `(matvec)(x_view, &mut out)`.
    pub(crate) matvec: SharedBetaMatvec,
    /// Precomputed diagonal; [`MatvecDiagPenaltyOp::new`] asserts its length
    /// equals `k`.
    pub(crate) diagonal_vec: Array1<f64>,
}
1327
1328impl MatvecDiagPenaltyOp {
1329    pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
1330        assert_eq!(diagonal_vec.len(), k);
1331        Self {
1332            k,
1333            matvec,
1334            diagonal_vec,
1335        }
1336    }
1337}
1338
1339impl BetaPenaltyOp for MatvecDiagPenaltyOp {
1340    fn dim(&self) -> usize {
1341        self.k
1342    }
1343
1344    fn matvec(&self, x: &[f64], y: &mut [f64]) {
1345        let x_arr = Array1::from_iter(x.iter().copied());
1346        let mut out = Array1::<f64>::zeros(self.k);
1347        (self.matvec)(x_arr.view(), &mut out);
1348        for a in 0..self.k {
1349            y[a] += out[a];
1350        }
1351    }
1352
1353    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1354        let beta_arr = Array1::from_iter(beta.iter().copied());
1355        let mut hb = Array1::<f64>::zeros(self.k);
1356        (self.matvec)(beta_arr.view(), &mut hb);
1357        for a in 0..self.k {
1358            out[a] += hb[a];
1359        }
1360    }
1361
1362    fn diagonal(&self, diag: &mut [f64]) {
1363        for j in 0..self.k.min(diag.len()) {
1364            diag[j] += self.diagonal_vec[j];
1365        }
1366    }
1367
1368    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1369        // Probe each basis vector in the block range to extract the sub-block.
1370        let range = &offsets[id.0];
1371        let b = range.end - range.start;
1372        let mut probe = Array1::<f64>::zeros(self.k);
1373        for bj in 0..b {
1374            probe.fill(0.0);
1375            probe[range.start + bj] = 1.0;
1376            let mut col = Array1::<f64>::zeros(self.k);
1377            (self.matvec)(probe.view(), &mut col);
1378            for bi in 0..b {
1379                out[[bi, bj]] += col[range.start + bi];
1380            }
1381        }
1382    }
1383
1384    fn to_dense(&self) -> Array2<f64> {
1385        let k = self.k;
1386        let mut out = Array2::<f64>::zeros((k, k));
1387        let mut probe = Array1::<f64>::zeros(k);
1388        for j in 0..k {
1389            probe.fill(0.0);
1390            probe[j] = 1.0;
1391            let mut col = Array1::<f64>::zeros(k);
1392            (self.matvec)(probe.view(), &mut col);
1393            for i in 0..k {
1394                out[[i, j]] = col[i];
1395            }
1396        }
1397        out
1398    }
1399
1400    fn fingerprint(&self, hasher: &mut Fingerprinter) {
1401        // The matvec closure cannot be hashed by content; the precomputed
1402        // diagonal is the operator's stable defining proxy (it is recomputed
1403        // alongside the matvec each time the operator is installed).
1404        hasher.write_str("matvec-diag-penalty-op-v1");
1405        hasher.write_usize(self.k);
1406        for &value in self.diagonal_vec.iter() {
1407            hasher.write_f64(value);
1408        }
1409    }
1410}