Skip to main content

gam_solve/arrow_schur/
penalty_ops.rs

//! Matrix-free penalty-side `H_ββ` operators: the [`BetaPenaltyOp`] trait and
//! every concrete operator (dense, block, Kronecker, factored-frame, composite,
//! matvec-diagonal), plus the β-coupling graph used for block-Jacobi clustering.
4
5use super::*;
6
/// One undirected coupling edge between two β blocks.
///
/// [`BetaCouplingGraph::build`] always stores the pair with `a < b`.
#[derive(Clone, Debug)]
pub(crate) struct BetaEdge {
    /// Smaller block index of the pair (as produced by `build`).
    pub(crate) a: usize,
    /// Larger block index of the pair (as produced by `build`).
    pub(crate) b: usize,
}
12
13#[derive(Debug, Clone)]
14pub(crate) struct BetaCouplingGraph {
15    pub(crate) num_blocks: usize,
16    pub(crate) edges: Vec<BetaEdge>,
17    pub(crate) adj_start: Vec<usize>,
18    pub(crate) adj_targets: Vec<usize>,
19}
20
21impl BetaCouplingGraph {
22    pub(crate) fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
23        let num_blocks = block_offsets.len();
24        if num_blocks == 0 {
25            return Self {
26                num_blocks: 0,
27                edges: Vec::new(),
28                adj_start: vec![0],
29                adj_targets: Vec::new(),
30            };
31        }
32
33        let mut edge_set = Vec::<(usize, usize)>::new();
34        for row in htbeta_rows {
35            let mut active = Vec::<usize>::new();
36            for (block, range) in block_offsets.iter().enumerate() {
37                if range
38                    .clone()
39                    .any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
40                {
41                    active.push(block);
42                }
43            }
44            for i in 0..active.len() {
45                for j in (i + 1)..active.len() {
46                    edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
47                }
48            }
49        }
50        edge_set.sort_unstable();
51        edge_set.dedup();
52
53        let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
54        let mut degree = vec![0usize; num_blocks];
55        for &BetaEdge { a, b } in &edges {
56            degree[a] += 1;
57            degree[b] += 1;
58        }
59        let mut adj_start = vec![0usize; num_blocks + 1];
60        for block in 0..num_blocks {
61            adj_start[block + 1] = adj_start[block] + degree[block];
62        }
63        let mut adj_targets = vec![0usize; adj_start[num_blocks]];
64        let mut cursor = adj_start[..num_blocks].to_vec();
65        for &BetaEdge { a, b } in &edges {
66            adj_targets[cursor[a]] = b;
67            cursor[a] += 1;
68            adj_targets[cursor[b]] = a;
69            cursor[b] += 1;
70        }
71        Self {
72            num_blocks,
73            edges,
74            adj_start,
75            adj_targets,
76        }
77    }
78
79    pub(crate) fn neighbours(&self, node: usize) -> &[usize] {
80        &self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
81    }
82
83    pub(crate) fn component_partition(&self) -> Vec<Vec<usize>> {
84        let mut parent: Vec<usize> = (0..self.num_blocks).collect();
85        let mut rank = vec![0u8; self.num_blocks];
86
87        fn find(parent: &mut [usize], mut x: usize) -> usize {
88            while parent[x] != x {
89                parent[x] = parent[parent[x]];
90                x = parent[x];
91            }
92            x
93        }
94
95        for &BetaEdge { a, b } in &self.edges {
96            let lhs = find(&mut parent, a);
97            let rhs = find(&mut parent, b);
98            if lhs != rhs {
99                if rank[lhs] < rank[rhs] {
100                    parent[lhs] = rhs;
101                } else if rank[lhs] > rank[rhs] {
102                    parent[rhs] = lhs;
103                } else {
104                    parent[rhs] = lhs;
105                    rank[lhs] += 1;
106                }
107            }
108        }
109
110        let mut label_map = vec![usize::MAX; self.num_blocks];
111        let mut parts = Vec::<Vec<usize>>::new();
112        for block in 0..self.num_blocks {
113            let root = find(&mut parent, block);
114            let label = if label_map[root] == usize::MAX {
115                label_map[root] = parts.len();
116                parts.push(Vec::new());
117                label_map[root]
118            } else {
119                label_map[root]
120            };
121            parts[label].push(block);
122        }
123        parts
124    }
125
126    pub(crate) fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
127        let mut expanded = seed.to_vec();
128        for &block in seed {
129            expanded.extend_from_slice(self.neighbours(block));
130        }
131        expanded.sort_unstable();
132        expanded.dedup();
133        expanded
134    }
135}
// ---------------------------------------------------------------------------
// BetaPenaltyOp — matrix-free penalty-side H_ββ abstraction (#296)
// ---------------------------------------------------------------------------
139
/// Index of one contiguous column block of the shared β vector.
///
/// Used by block-Jacobi Schur preconditioning (#287). `BetaBlockId(i)` names
/// the `i`-th range of [`ArrowSchurSystem::block_offsets`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BetaBlockId(pub usize);
147
148/// Matrix-free operator for the penalty side of `H_ββ`.
149///
150/// Callers must satisfy the additive convention: every method **adds** its
151/// contribution to the output buffer (i.e. `y += P x`, not `y = P x`).
152/// This matches the assembly pattern where multiple penalty terms are
153/// accumulated into the same gradient / Hessian buffers.
154pub trait BetaPenaltyOp: Send + Sync {
155    /// Full dimension `K` of the β vector.
156    fn dim(&self) -> usize;
157    /// `y += P x` — penalty Hessian-vector product (length `K`).
158    fn matvec(&self, x: &[f64], y: &mut [f64]);
159    /// Penalty gradient: `out += P β`.
160    fn gradient(&self, beta: &[f64], out: &mut [f64]);
161    /// `diag += diag(P)` — diagonal entries used by Jacobi preconditioner.
162    fn diagonal(&self, diag: &mut [f64]);
163    /// Add the `b×b` dense penalty sub-block for block `id` into `out`
164    /// (row-major, block size `b = offsets[id.0].len()`).
165    /// Used by the block-Jacobi Schur preconditioner (#287).
166    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
167    /// Materialize the full `K×K` dense penalty matrix (needed by
168    /// Direct / SqrtBA modes that form the Schur complement explicitly).
169    fn to_dense(&self) -> Array2<f64>;
170    /// Per-row absolute-value sums `out[r] = Σ_c |P[r,c]|`, the row contribution
171    /// to the operator's `∞`-norm. The default folds `to_dense()`, which costs an
172    /// `O(K²)` materialization; structured operators override this to fold their
173    /// compact factors directly so the backward-error certificate's
174    /// `arrow_operator_infinity_norm` never builds a dense `K×K` matrix on the
175    /// SAE LLM-border critical path (#1017). Overrides MUST agree bit-for-bit with
176    /// the `to_dense()` row sums (verified by the cross-check tests).
177    fn row_abs_sums(&self) -> Array1<f64> {
178        let dense = self.to_dense();
179        let k = dense.nrows();
180        let mut out = Array1::<f64>::zeros(k);
181        for r in 0..k {
182            let mut s = 0.0_f64;
183            for c in 0..dense.ncols() {
184                s += dense[[r, c]].abs();
185            }
186            out[r] = s;
187        }
188        out
189    }
190    /// Mix the operator's defining state into `hasher` for cache-validity
191    /// fingerprinting. Must change whenever `matvec` / `to_dense` would change,
192    /// so the factorization / evidence cache (`cache_matches_system`) is
193    /// invalidated when the β-block content changes. Implementations hash their
194    /// own compact defining data (e.g. Kronecker factors, block matrices)
195    /// rather than the full `K×K` dense form, which would defeat the structured
196    /// operator's storage savings.
197    fn fingerprint(&self, hasher: &mut Fingerprinter);
198
199    /// If this operator writes its `matvec` contribution into EXACTLY one
200    /// contiguous output range `[start, end)` and touches no other index,
201    /// return that range; otherwise `None` (the default — opaque / scattered
202    /// output). This lets [`CompositePenaltyOp::matvec`] fan a leading run of
203    /// mutually-disjoint operators across rayon workers, each writing its own
204    /// output sub-slice with no cross-thread aliasing. The per-atom Kronecker
205    /// smooth penalties (the SAE prologue's serial Amdahl ceiling at the K=32k
206    /// manifold border) each cover one atom's β block, so they qualify.
207    fn output_range(&self) -> Option<Range<usize>> {
208        None
209    }
210
211    /// Accumulate `matvec`'s contribution into `y_local`, where `y_local[i]`
212    /// aliases global output index `output_range().start + i` (so
213    /// `y_local.len() == output_range().len()`), reading the FULL-length input
214    /// `x`. ONLY valid when [`Self::output_range`] returns `Some`; the default
215    /// panics because a `None`-range operator has no single contiguous slice to
216    /// write into. Must be BIT-IDENTICAL to the corresponding indices of
217    /// `matvec` (same per-index accumulation order) so the composite's parallel
218    /// prefix reproduces the serial result exactly.
219    fn matvec_local(&self, x: &[f64], y_local: &mut [f64]) {
220        // SAFETY: hard contract guard, not a silent sentinel. A `None`-range
221        // `BetaPenaltyOp` exposes no single contiguous output slice, so the
222        // local matvec is undefined and MUST be routed through `matvec`. This
223        // default exists only to turn that misuse into a loud, immediate
224        // failure at the call site; the input/local extents are surfaced for
225        // triage.
226        panic!(
227            "matvec_local requires output_range() == Some; a None-range \
228             BetaPenaltyOp (input len {}, local output len {}) must be applied \
229             through matvec",
230            x.len(),
231            y_local.len()
232        );
233    }
234}
235
236/// Dense fallback: wraps the existing `K×K` `H_ββ` accumulator.
237pub struct DensePenaltyOp(pub Array2<f64>);
238
239impl BetaPenaltyOp for DensePenaltyOp {
240    fn dim(&self) -> usize {
241        self.0.nrows()
242    }
243
244    fn matvec(&self, x: &[f64], y: &mut [f64]) {
245        let k = self.0.nrows();
246        for a in 0..k {
247            let mut acc = 0.0_f64;
248            for b in 0..k {
249                acc += self.0[[a, b]] * x[b];
250            }
251            y[a] += acc;
252        }
253    }
254
255    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
256        let k = self.0.nrows();
257        for a in 0..k {
258            let mut acc = 0.0_f64;
259            for b in 0..k {
260                acc += self.0[[a, b]] * beta[b];
261            }
262            out[a] += acc;
263        }
264    }
265
266    fn diagonal(&self, diag: &mut [f64]) {
267        let k = self.0.nrows().min(diag.len());
268        for j in 0..k {
269            diag[j] += self.0[[j, j]];
270        }
271    }
272
273    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
274        let range = &offsets[id.0];
275        let b = range.end - range.start;
276        for bi in 0..b {
277            for bj in 0..b {
278                out[[bi, bj]] += self.0[[range.start + bi, range.start + bj]];
279            }
280        }
281    }
282
283    fn to_dense(&self) -> Array2<f64> {
284        self.0.clone()
285    }
286
287    fn fingerprint(&self, hasher: &mut Fingerprinter) {
288        hasher.write_str("dense-penalty-op-v1");
289        hasher.write_f64_array2(&self.0);
290    }
291}
292
293/// Block-local penalty operator: applies per-block penalty matrices
294/// (matching `ParameterBlockSpec` boundaries) without materialising a
295/// full `K×K` dense matrix.
296///
297/// Each entry is `(global_offset, local_matrix)` where `global_offset`
298/// is the start of that block in the full β vector.
299pub struct BlockPenaltyOp {
300    /// Full β dimension `K`.
301    pub k: usize,
302    /// `(global_start, local_matrix)` for each atom/block.
303    pub blocks: Vec<(usize, Array2<f64>)>,
304}
305
306impl BetaPenaltyOp for BlockPenaltyOp {
307    fn dim(&self) -> usize {
308        self.k
309    }
310
311    fn matvec(&self, x: &[f64], y: &mut [f64]) {
312        for (off, local) in &self.blocks {
313            let b = local.nrows();
314            for i in 0..b {
315                let gi = off + i;
316                let mut acc = 0.0_f64;
317                for j in 0..b {
318                    acc += local[[i, j]] * x[off + j];
319                }
320                y[gi] += acc;
321            }
322        }
323    }
324
325    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
326        for (off, local) in &self.blocks {
327            let b = local.nrows();
328            for i in 0..b {
329                let gi = off + i;
330                let mut acc = 0.0_f64;
331                for j in 0..b {
332                    acc += local[[i, j]] * beta[off + j];
333                }
334                out[gi] += acc;
335            }
336        }
337    }
338
339    fn diagonal(&self, diag: &mut [f64]) {
340        for (off, local) in &self.blocks {
341            let b = local.nrows();
342            for j in 0..b {
343                diag[off + j] += local[[j, j]];
344            }
345        }
346    }
347
348    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
349        let range = &offsets[id.0];
350        let b_out = range.end - range.start;
351        for (off, local) in &self.blocks {
352            let b = local.nrows();
353            let block_end = off + b;
354            if block_end <= range.start || *off >= range.end {
355                continue;
356            }
357            for bi in 0..b_out {
358                let gi = range.start + bi;
359                if gi < *off || gi >= block_end {
360                    continue;
361                }
362                let li = gi - off;
363                for bj in 0..b_out {
364                    let gj = range.start + bj;
365                    if gj < *off || gj >= block_end {
366                        continue;
367                    }
368                    let lj = gj - off;
369                    out[[bi, bj]] += local[[li, lj]];
370                }
371            }
372        }
373    }
374
375    fn to_dense(&self) -> Array2<f64> {
376        let mut out = Array2::<f64>::zeros((self.k, self.k));
377        for (off, local) in &self.blocks {
378            let b = local.nrows();
379            for i in 0..b {
380                for j in 0..b {
381                    out[[off + i, off + j]] += local[[i, j]];
382                }
383            }
384        }
385        out
386    }
387
388    fn fingerprint(&self, hasher: &mut Fingerprinter) {
389        hasher.write_str("block-penalty-op-v1");
390        hasher.write_usize(self.k);
391        hasher.write_usize(self.blocks.len());
392        for (off, local) in &self.blocks {
393            hasher.write_usize(*off);
394            hasher.write_f64_array2(local);
395        }
396    }
397}
398
399/// Kronecker-product penalty: `P = A ⊗ B` applied without materialising
400/// the full `(p_a·p_b)×(p_a·p_b)` matrix.
401pub struct KroneckerPenaltyOp {
402    /// Left factor `A`, shape `(p_a, p_a)`.
403    pub factor_a: Array2<f64>,
404    /// Right factor `B`, shape `(p_b, p_b)`.
405    pub factor_b: Array2<f64>,
406    /// Global offset into the β vector where this block starts.
407    pub global_offset: usize,
408    /// Full β dimension `K`.
409    pub k: usize,
410}
411
412impl BetaPenaltyOp for KroneckerPenaltyOp {
413    fn dim(&self) -> usize {
414        self.k
415    }
416
417    fn matvec(&self, x: &[f64], y: &mut [f64]) {
418        let p_a = self.factor_a.nrows();
419        let p_b = self.factor_b.nrows();
420        let off = self.global_offset;
421        // (A ⊗ B) vec(V) where V is (p_b, p_a) with Fortran/vec ordering.
422        for i_a in 0..p_a {
423            for i_b in 0..p_b {
424                let gi = off + i_a * p_b + i_b;
425                let mut acc = 0.0_f64;
426                for j_a in 0..p_a {
427                    let a_ij = self.factor_a[[i_a, j_a]];
428                    if a_ij == 0.0 {
429                        continue;
430                    }
431                    for j_b in 0..p_b {
432                        acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
433                    }
434                }
435                y[gi] += acc;
436            }
437        }
438    }
439
440    fn output_range(&self) -> Option<Range<usize>> {
441        let off = self.global_offset;
442        Some(off..off + self.factor_a.nrows() * self.factor_b.nrows())
443    }
444
445    fn matvec_local(&self, x: &[f64], y_local: &mut [f64]) {
446        // Byte-for-byte the `matvec` arithmetic with the output written at the
447        // LOCAL index `i_a·p_b + i_b` (== global `gi - off`), so the composite
448        // can apply this block into its own `y[off..off+p_a·p_b]` sub-slice in
449        // parallel. Per-index accumulation order is unchanged ⇒ bit-identical.
450        let p_a = self.factor_a.nrows();
451        let p_b = self.factor_b.nrows();
452        let off = self.global_offset;
453        for i_a in 0..p_a {
454            for i_b in 0..p_b {
455                let li = i_a * p_b + i_b;
456                let mut acc = 0.0_f64;
457                for j_a in 0..p_a {
458                    let a_ij = self.factor_a[[i_a, j_a]];
459                    if a_ij == 0.0 {
460                        continue;
461                    }
462                    for j_b in 0..p_b {
463                        acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
464                    }
465                }
466                y_local[li] += acc;
467            }
468        }
469    }
470
471    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
472        let p_a = self.factor_a.nrows();
473        let p_b = self.factor_b.nrows();
474        let off = self.global_offset;
475        for i_a in 0..p_a {
476            for i_b in 0..p_b {
477                let gi = off + i_a * p_b + i_b;
478                let mut acc = 0.0_f64;
479                for j_a in 0..p_a {
480                    let a_ij = self.factor_a[[i_a, j_a]];
481                    if a_ij == 0.0 {
482                        continue;
483                    }
484                    for j_b in 0..p_b {
485                        acc += a_ij * self.factor_b[[i_b, j_b]] * beta[off + j_a * p_b + j_b];
486                    }
487                }
488                out[gi] += acc;
489            }
490        }
491    }
492
493    fn diagonal(&self, diag: &mut [f64]) {
494        let p_a = self.factor_a.nrows();
495        let p_b = self.factor_b.nrows();
496        let off = self.global_offset;
497        for i_a in 0..p_a {
498            for i_b in 0..p_b {
499                diag[off + i_a * p_b + i_b] +=
500                    self.factor_a[[i_a, i_a]] * self.factor_b[[i_b, i_b]];
501            }
502        }
503    }
504
505    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
506        let range = &offsets[id.0];
507        let b = range.end - range.start;
508        let p_a = self.factor_a.nrows();
509        let p_b = self.factor_b.nrows();
510        let off = self.global_offset;
511        let block_end = off + p_a * p_b;
512        if block_end <= range.start || off >= range.end {
513            return;
514        }
515        for bi in 0..b {
516            let gi = range.start + bi;
517            if gi < off || gi >= block_end {
518                continue;
519            }
520            let li = gi - off;
521            let i_a = li / p_b;
522            let i_b = li % p_b;
523            for bj in 0..b {
524                let gj = range.start + bj;
525                if gj < off || gj >= block_end {
526                    continue;
527                }
528                let lj = gj - off;
529                let j_a = lj / p_b;
530                let j_b = lj % p_b;
531                out[[bi, bj]] += self.factor_a[[i_a, j_a]] * self.factor_b[[i_b, j_b]];
532            }
533        }
534    }
535
536    fn to_dense(&self) -> Array2<f64> {
537        let p_a = self.factor_a.nrows();
538        let p_b = self.factor_b.nrows();
539        let off = self.global_offset;
540        let mut out = Array2::<f64>::zeros((self.k, self.k));
541        for i_a in 0..p_a {
542            for i_b in 0..p_b {
543                let gi = off + i_a * p_b + i_b;
544                for j_a in 0..p_a {
545                    let a_ij = self.factor_a[[i_a, j_a]];
546                    if a_ij == 0.0 {
547                        continue;
548                    }
549                    for j_b in 0..p_b {
550                        let gj = off + j_a * p_b + j_b;
551                        out[[gi, gj]] += a_ij * self.factor_b[[i_b, j_b]];
552                    }
553                }
554            }
555        }
556        out
557    }
558
559    fn fingerprint(&self, hasher: &mut Fingerprinter) {
560        hasher.write_str("kronecker-penalty-op-v1");
561        hasher.write_usize(self.global_offset);
562        hasher.write_usize(self.k);
563        hasher.write_f64_array2(&self.factor_a);
564        hasher.write_f64_array2(&self.factor_b);
565    }
566}
567
568/// Kronecker-product penalty with an identity right factor:
569/// `P = A ⊗ I_p`.
570///
571/// This is the hot SAE smoothness case. Storing `I_p` as a dense matrix costs
572/// `O(p²)` memory per atom and makes every matvec pay an unnecessary right-factor
573/// loop. This operator stores only the identity dimension and keeps the same
574/// layout as [`KroneckerPenaltyOp`]: local index `i_a * p + i_b`.
575pub struct IdentityRightKroneckerPenaltyOp {
576    /// Left factor `A`, shape `(p_a, p_a)`.
577    pub factor_a: Array2<f64>,
578    /// Identity right-factor dimension `p`.
579    pub p: usize,
580    /// Global offset into the β vector where this block starts.
581    pub global_offset: usize,
582    /// Full β dimension `K`.
583    pub k: usize,
584}
585
586impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
587    fn dim(&self) -> usize {
588        self.k
589    }
590
591    fn matvec(&self, x: &[f64], y: &mut [f64]) {
592        let p_a = self.factor_a.nrows();
593        let p = self.p;
594        let off = self.global_offset;
595        for i_a in 0..p_a {
596            for i_b in 0..p {
597                let gi = off + i_a * p + i_b;
598                let mut acc = 0.0_f64;
599                for j_a in 0..p_a {
600                    let a_ij = self.factor_a[[i_a, j_a]];
601                    if a_ij == 0.0 {
602                        continue;
603                    }
604                    acc += a_ij * x[off + j_a * p + i_b];
605                }
606                y[gi] += acc;
607            }
608        }
609    }
610
611    fn output_range(&self) -> Option<Range<usize>> {
612        let off = self.global_offset;
613        Some(off..off + self.factor_a.nrows() * self.p)
614    }
615
616    fn matvec_local(&self, x: &[f64], y_local: &mut [f64]) {
617        // Byte-for-byte the `matvec` inner arithmetic, but the output writes to
618        // the LOCAL index `i_a·p + i_b` (== global `gi - off`) so the composite
619        // can hand this operator its own `y[off..off+p_a·p]` sub-slice. The
620        // per-index accumulation order over `j_a` is unchanged, so the result is
621        // bit-identical to `matvec`.
622        let p_a = self.factor_a.nrows();
623        let p = self.p;
624        let off = self.global_offset;
625        for i_a in 0..p_a {
626            for i_b in 0..p {
627                let li = i_a * p + i_b;
628                let mut acc = 0.0_f64;
629                for j_a in 0..p_a {
630                    let a_ij = self.factor_a[[i_a, j_a]];
631                    if a_ij == 0.0 {
632                        continue;
633                    }
634                    acc += a_ij * x[off + j_a * p + i_b];
635                }
636                y_local[li] += acc;
637            }
638        }
639    }
640
641    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
642        self.matvec(beta, out);
643    }
644
645    fn diagonal(&self, diag: &mut [f64]) {
646        let p_a = self.factor_a.nrows();
647        let p = self.p;
648        let off = self.global_offset;
649        for i_a in 0..p_a {
650            let a_ii = self.factor_a[[i_a, i_a]];
651            for i_b in 0..p {
652                diag[off + i_a * p + i_b] += a_ii;
653            }
654        }
655    }
656
657    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
658        let range = &offsets[id.0];
659        let b = range.end - range.start;
660        let p_a = self.factor_a.nrows();
661        let p = self.p;
662        let off = self.global_offset;
663        let block_end = off + p_a * p;
664        if block_end <= range.start || off >= range.end {
665            return;
666        }
667        for bi in 0..b {
668            let gi = range.start + bi;
669            if gi < off || gi >= block_end {
670                continue;
671            }
672            let li = gi - off;
673            let i_a = li / p;
674            let i_b = li % p;
675            for bj in 0..b {
676                let gj = range.start + bj;
677                if gj < off || gj >= block_end {
678                    continue;
679                }
680                let lj = gj - off;
681                let j_a = lj / p;
682                let j_b = lj % p;
683                if i_b == j_b {
684                    out[[bi, bj]] += self.factor_a[[i_a, j_a]];
685                }
686            }
687        }
688    }
689
690    fn to_dense(&self) -> Array2<f64> {
691        let p_a = self.factor_a.nrows();
692        let p = self.p;
693        let off = self.global_offset;
694        let mut out = Array2::<f64>::zeros((self.k, self.k));
695        for i_a in 0..p_a {
696            for j_a in 0..p_a {
697                let a_ij = self.factor_a[[i_a, j_a]];
698                if a_ij == 0.0 {
699                    continue;
700                }
701                for i_b in 0..p {
702                    let gi = off + i_a * p + i_b;
703                    let gj = off + j_a * p + i_b;
704                    out[[gi, gj]] += a_ij;
705                }
706            }
707        }
708        out
709    }
710
711    fn fingerprint(&self, hasher: &mut Fingerprinter) {
712        hasher.write_str("identity-right-kronecker-penalty-op-v1");
713        hasher.write_usize(self.global_offset);
714        hasher.write_usize(self.k);
715        hasher.write_usize(self.p);
716        hasher.write_f64_array2(&self.factor_a);
717    }
718}
719
720/// One co-occurring atom-pair block of a block-sparse left factor `A`.
721///
722/// `data` is the dense `(m_i × m_j)` coupling between the basis columns of
723/// atom `i` (rows, starting at left-factor offset `row_off`) and atom `j`
724/// (columns, starting at `col_off`). Both offsets are in *left-factor* (`A`)
725/// coordinates, i.e. `μ`-space, not β-space.
726#[derive(Debug, Clone)]
727pub struct SparseGBlock {
728    /// Left-factor (`μ`-space) row offset = `beta_offset[atom_i] / p`.
729    pub row_off: usize,
730    /// Left-factor (`μ`-space) column offset = `beta_offset[atom_j] / p`.
731    pub col_off: usize,
732    /// Dense `(m_i × m_j)` coupling block.
733    pub data: Array2<f64>,
734}
735
736/// Block-sparse Kronecker penalty `P = A ⊗ I_p` where the left factor `A`
737/// (dimension `dim_a × dim_a` in `μ`-space) is stored only on its non-empty
738/// co-occurring atom-pair blocks rather than as a dense `(dim_a × dim_a)`
739/// matrix.
740///
741/// This is the sparse-atom (`K = 100K`) replacement for wrapping the dense
742/// data-fit Gauss-Newton Gram `G` (`m_total × m_total`) in a
743/// [`KroneckerPenaltyOp`]: with per-row active sets of size `k_active ≪ K`,
744/// only the `(atom, atom')` pairs that co-occur in some row contribute a
745/// non-zero `(m_i × m_j)` block, so the storage and every matvec/diagonal
746/// pass cost `O(Σ_pairs m_i m_j · p)` instead of `O((m_total · p)²)`.
747///
748/// The β index of left-factor coordinate `μ` and output channel `oc` is
749/// `μ · p + oc` (the same `μ`-major / `oc`-minor layout the dense
750/// `KroneckerPenaltyOp { factor_b: I_p }` uses), so this op is a drop-in
751/// structured replacement: with the full dense pair set it reproduces the
752/// dense operator exactly.
753pub struct SparseBlockKroneckerPenaltyOp {
754    /// Right-factor identity dimension `p` (number of decoder output channels).
755    pub p: usize,
756    /// Left-factor dimension `dim_a` in `μ`-space (= `m_total`).
757    pub dim_a: usize,
758    /// Full β dimension `K = dim_a · p`.
759    pub k: usize,
760    /// Non-empty `(atom_i, atom_j)` coupling blocks of `A`.
761    pub blocks: Vec<SparseGBlock>,
762}
763
764#[derive(Debug, Clone)]
765pub struct DeviceSaeSmoothBlock {
766    pub global_offset: usize,
767    pub factor_a: Array2<f64>,
768}
769
770/// Frame-factored extension of [`DeviceSaePcgData`] (issue #1017/#1026,
771/// frames-engaged device PCG). Present only when at least one atom is genuinely
772/// frame-reduced (`ranks[k] < p`); absent (`None`) on the full-`B` path, where
773/// the legacy `G ⊗ I_p` channel-identical kernel applies byte-for-byte.
774///
775/// On the frames path the β border is the FACTORED coordinate space `C` of width
776/// `Σ_k M_k·r_k`, the data-fit β-Hessian is `G_{ij} ⊗ W_{ij}` (`W_{ij}=U_iᵀU_j`,
777/// carried on `frame_blocks`), the smooth penalty is `λ S_k ⊗ I_{r_k}`
778/// (`smooth_blocks`, reused — width `r_k` instead of `p`), and the per-row
779/// reduced-Schur cross-block `H_tβ^(i)` is the DENSE `(q_i × border_dim)` slab
780/// `row_htbeta[i]` (row-major) rather than the full-`B` factored `L_i · J_β`
781/// gather (so `a_phi`/`local_jac` are unused on this path).
782#[derive(Debug, Clone)]
783pub struct DeviceSaeFrameData {
784    /// Per-atom frame rank `r_k` (factored output width); `r_k == p` for an
785    /// un-framed atom riding the identity special case.
786    pub ranks: Vec<usize>,
787    /// Per-atom basis size `M_k`.
788    pub basis_sizes: Vec<usize>,
789    /// Per-atom factored-border offset `off_C[k]` (prefix sum of `M_k·r_k`),
790    /// length `n_atoms`. Atom `k`'s `C_k` block is `[off_C[k] .. +M_k·r_k)`.
791    pub border_offsets: Vec<usize>,
792    /// Co-occurring `(atom_i, atom_j)` data-fit blocks `g ⊗ w` (`w = U_iᵀU_j`).
793    pub frame_blocks: Vec<FactoredFrameGBlock>,
794    /// Right-factor width (`r_k`) of each entry of the top-level
795    /// `DeviceSaePcgData::smooth_blocks`, in the SAME order. On the frames path
796    /// the smooth penalty is `λ S_k ⊗ I_{r_k}` so the block at
797    /// `smooth_blocks[i].global_offset` has identity width `smooth_ranks[i]`
798    /// (which equals `ranks[atom]`), NOT the ambient `p`.
799    pub smooth_ranks: Vec<usize>,
800    /// Per-row dense cross-block `H_tβ^(i)` as a row-major `q_i × border_dim`
801    /// buffer (`q_i = row_dims[i]`). Empty inner `Vec` for a 0-dim row.
802    pub row_htbeta: Vec<Vec<f64>>,
803}
804
/// Inputs consumed by the SAE PCG kernel: the per-row active-atom support and
/// local Jacobians, the smooth-penalty and data-fit Gram blocks, and (on the
/// frames-engaged path) the frame-factored metadata.
#[derive(Debug, Clone)]
pub struct DeviceSaePcgData {
    /// Ambient output dimension `p` (the `⊗ I_p` width of the legacy
    /// full-`B` path).
    pub p: usize,
    /// Width of the β border handed to the kernel.
    /// NOTE(review): presumably equals the factored width `Σ_k M_k·r_k` when
    /// `frame` is `Some` — confirm against the assembler.
    pub beta_dim: usize,
    // #1033 large-n: the per-row support `a_phi` and local Jacobians `local_jac`
    // are ALSO held by the host matrix-free row operator (`SaeKroneckerRows`) for
    // the lifetime of the inner solve. Storing them as `Arc<[…]>` lets the
    // assembler hand BOTH consumers the SAME backing allocation instead of a
    // second full `O(n·q·p)` clone (`device_rows = (a_phi.clone(), kron_jac.clone())`
    // was the dominant always-resident duplication on the CPU non-frames path at
    // the LLM shape p≈5120). Indexing/`.len()`/iteration are identical to `Vec`.
    /// Per-row active-atom support as `(index, weight)` pairs (shared, see above).
    pub a_phi: Arc<[Vec<(usize, f64)>]>,
    /// Per-row local Jacobians, one flat buffer per row (shared, see above).
    pub local_jac: Arc<[Vec<f64>]>,
    /// Smooth-penalty blocks. On the frames path each is `λ S_k ⊗ I_{r_k}` and
    /// its identity width is given by `frame.smooth_ranks`, not `p`.
    pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
    /// Sparse data-fit Gram blocks.
    /// NOTE(review): presumably only read on the legacy `G ⊗ I_p` path (the
    /// frames path consumes `frame.frame_blocks`) — confirm in the kernel.
    pub sparse_g_blocks: Vec<SparseGBlock>,
    /// Frame-factored metadata. `None` ⇒ legacy full-`B` `G ⊗ I_p` path
    /// (byte-identical to before this field existed). `Some` ⇒ frames-engaged
    /// path: the kernel consumes `frame.frame_blocks`/`smooth_blocks` (now
    /// rank-`r_k` wide) and `frame.row_htbeta` instead of the `⊗ I_p` gather.
    pub frame: Option<DeviceSaeFrameData>,
}
826
827impl DeviceSaePcgData {
828    /// Snapshot the per-row active-atom support `a_phi` into a shared `Arc<[…]>`
829    /// for the CPU residency operator ([`SaeResidentReducedSchur`]). Cloned once
830    /// per CG-solve build (cost `O(Σ_i m_i)`, dwarfed by the per-row factor solves
831    /// in the same build), so the resident matvec borrows the index lists without
832    /// re-cloning them on every CG iteration.
833    pub(crate) fn a_phi_shared(&self) -> Arc<[Vec<(usize, f64)>]> {
834        // #1033: `a_phi` is already an `Arc<[…]>`; hand back a refcount bump
835        // (`O(1)`) rather than re-cloning every `(idx, weight)` pair per CG build.
836        Arc::clone(&self.a_phi)
837    }
838
839    /// Share the per-row local Jacobians `local_jac` with the CPU residency
840    /// operator ([`SaeResidentReducedSchur`]) as an `O(1)` refcount bump. The
841    /// staged row factor used to hold a verbatim row-major copy of each
842    /// `local_jac[row]`; sharing the slab removes that second full `O(n·di·p)`
843    /// copy with byte-for-byte identical reads (#1033).
844    pub(crate) fn local_jac_shared(&self) -> Arc<[Vec<f64>]> {
845        Arc::clone(&self.local_jac)
846    }
847}
848
849impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
850    fn dim(&self) -> usize {
851        self.k
852    }
853
854    fn matvec(&self, x: &[f64], y: &mut [f64]) {
855        let p = self.p;
856        for blk in &self.blocks {
857            let (m_i, m_j) = blk.data.dim();
858            for li in 0..m_i {
859                let gi_base = (blk.row_off + li) * p;
860                for lj in 0..m_j {
861                    let a_ij = blk.data[[li, lj]];
862                    if a_ij == 0.0 {
863                        continue;
864                    }
865                    let gj_base = (blk.col_off + lj) * p;
866                    for oc in 0..p {
867                        y[gi_base + oc] += a_ij * x[gj_base + oc];
868                    }
869                }
870            }
871        }
872    }
873
874    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
875        self.matvec(beta, out);
876    }
877
878    fn diagonal(&self, diag: &mut [f64]) {
879        let p = self.p;
880        for blk in &self.blocks {
881            // Only on-diagonal `A` blocks (row_off == col_off) carry diagonal
882            // mass; their `(li, li)` entries map to `(row_off+li)·p + oc`.
883            if blk.row_off != blk.col_off {
884                continue;
885            }
886            let (m_i, m_j) = blk.data.dim();
887            let m = m_i.min(m_j);
888            for li in 0..m {
889                let a_ii = blk.data[[li, li]];
890                let gi_base = (blk.row_off + li) * p;
891                for oc in 0..p {
892                    diag[gi_base + oc] += a_ii;
893                }
894            }
895        }
896    }
897
898    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
899        let range = &offsets[id.0];
900        let b = range.end - range.start;
901        let p = self.p;
902        for blk in &self.blocks {
903            let (m_i, m_j) = blk.data.dim();
904            let row_start = blk.row_off * p;
905            let row_end = (blk.row_off + m_i) * p;
906            let col_start = blk.col_off * p;
907            let col_end = (blk.col_off + m_j) * p;
908            if row_end <= range.start
909                || row_start >= range.end
910                || col_end <= range.start
911                || col_start >= range.end
912            {
913                continue;
914            }
915            for bi in 0..b {
916                let gi = range.start + bi;
917                if gi < row_start || gi >= row_end {
918                    continue;
919                }
920                let li = (gi - row_start) / p;
921                let oc_i = (gi - row_start) % p;
922                for bj in 0..b {
923                    let gj = range.start + bj;
924                    if gj < col_start || gj >= col_end {
925                        continue;
926                    }
927                    let oc_j = (gj - col_start) % p;
928                    if oc_i != oc_j {
929                        continue;
930                    }
931                    let lj = (gj - col_start) / p;
932                    out[[bi, bj]] += blk.data[[li, lj]];
933                }
934            }
935        }
936    }
937
938    fn to_dense(&self) -> Array2<f64> {
939        let p = self.p;
940        let mut out = Array2::<f64>::zeros((self.k, self.k));
941        for blk in &self.blocks {
942            let (m_i, m_j) = blk.data.dim();
943            for li in 0..m_i {
944                let gi_base = (blk.row_off + li) * p;
945                for lj in 0..m_j {
946                    let a_ij = blk.data[[li, lj]];
947                    if a_ij == 0.0 {
948                        continue;
949                    }
950                    let gj_base = (blk.col_off + lj) * p;
951                    for oc in 0..p {
952                        out[[gi_base + oc, gj_base + oc]] += a_ij;
953                    }
954                }
955            }
956        }
957        out
958    }
959
960    fn row_abs_sums(&self) -> Array1<f64> {
961        // Mirror `to_dense`: entry `(gi_base+oc, gj_base+oc) += a_ij`. Each
962        // `(li, lj, oc)` lands in a DISTINCT column (`gj_base+oc` is injective in
963        // `(lj, oc)` for a fixed block, and blocks with the same `row_off` have
964        // disjoint `col_off`), so the row's `Σ_c|P[r,c]|` is just the sum of
965        // `|a_ij|` over the contributing `(block, lj)` pairs — no dense matrix.
966        let p = self.p;
967        let mut out = Array1::<f64>::zeros(self.k);
968        for blk in &self.blocks {
969            let (m_i, m_j) = blk.data.dim();
970            for li in 0..m_i {
971                let gi_base = (blk.row_off + li) * p;
972                let mut row_abs = 0.0_f64;
973                for lj in 0..m_j {
974                    row_abs += blk.data[[li, lj]].abs();
975                }
976                for oc in 0..p {
977                    out[gi_base + oc] += row_abs;
978                }
979            }
980        }
981        out
982    }
983
984    fn fingerprint(&self, hasher: &mut Fingerprinter) {
985        hasher.write_str("sparse-block-kronecker-penalty-op-v1");
986        hasher.write_usize(self.p);
987        hasher.write_usize(self.dim_a);
988        hasher.write_usize(self.k);
989        hasher.write_usize(self.blocks.len());
990        for blk in &self.blocks {
991            hasher.write_usize(blk.row_off);
992            hasher.write_usize(blk.col_off);
993            hasher.write_f64_array2(&blk.data);
994        }
995    }
996}
997
/// One co-occurring `(atom_i, atom_j)` block of the **frame-factored** data-fit
/// Gauss–Newton β-Hessian (issue #972 / #977 T1). Carries the basis-space Gram
/// `g` (`m_i × m_j`) AND the per-pair frame output factor `w = U_iᵀ U_j`
/// (`r_i × r_j`); the contributed Hessian sub-block is the Kronecker product
/// `g ⊗ w`.
///
/// Shapes are checked against the atoms' `(M, r)` by
/// `FactoredFrameKroneckerOp::new`; nothing is validated at construction time
/// of this struct itself.
#[derive(Debug, Clone)]
pub struct FactoredFrameGBlock {
    /// Atom index of the row factor (selects rank `r_i` and β offset).
    pub atom_i: usize,
    /// Atom index of the column factor (selects rank `r_j` and β offset).
    pub atom_j: usize,
    /// Basis-space coupling `G_{ij}` (`m_i × m_j`).
    pub g: Array2<f64>,
    /// Frame output factor `U_iᵀ U_j` (`r_i × r_j`). For `i == j` with an
    /// orthonormal frame this is `I_{r_i}` (the clean within-atom `g ⊗ I_r`
    /// collapse); across atoms it is the dense principal-angle cosine matrix
    /// between the two frames.
    pub w: Array2<f64>,
}
1017
/// Frame-factored data-fit Gauss–Newton β-Hessian operator (#972 / #977 T1):
/// the `Σ_k M_k·r_k` reduced-border analogue of [`SparseBlockKroneckerPenaltyOp`].
///
/// When every atom's decoder `B_k = C_k U_kᵀ` is profiled onto a Grassmann
/// frame `U_k ∈ St(p, r_k)`, the border carries only the shape coefficients
/// `C_k` (`M_k · r_k` entries) instead of the full `B_k` (`M_k · p`). The data
/// Gram in this reduced space is, for the isotropic likelihood,
/// `H[(i,li,a),(j,lj,b)] = G_{ij}[li,lj] · (U_iᵀ U_j)[a,b]` — within an atom the
/// orthonormal frame gives `U_iᵀU_i = I_{r_i}` and the block is the clean
/// `G ⊗ I_r` collapse; across co-active atoms the frames do not share a basis
/// so the output factor is the dense `U_iᵀU_j`.
///
/// The β layout is `μ`-major / frame-minor with a **variable** per-atom width
/// `r_k`: the index of (atom `k`, basis `li`, frame coord `a`) is
/// `offset[k] + li·r_k + a`, where `offset` is the prefix sum of `M_k · r_k`.
/// With every `r_k = p` and `U_k = I_p` this reproduces
/// [`SparseBlockKroneckerPenaltyOp`] exactly (a unit test pins the reduction),
/// so it is a strict generalization, not a separate code path.
///
/// The fields are `pub`; building the struct literally (instead of through
/// `FactoredFrameKroneckerOp::new`) skips the offset computation and the
/// per-block shape validation, and the caller must keep them consistent.
pub struct FactoredFrameKroneckerOp {
    /// Per-atom frame rank `r_k` (the factored output width).
    pub ranks: Vec<usize>,
    /// Per-atom basis size `M_k`.
    pub basis_sizes: Vec<usize>,
    /// Per-atom β offset (prefix sum of `M_k · r_k`); `offsets[k]` is the start
    /// of atom `k`'s `C_k` block, `offsets[n_atoms]` the total dim.
    pub offsets: Vec<usize>,
    /// Total reduced β dimension `Σ_k M_k · r_k`.
    pub dim: usize,
    /// Non-empty co-occurring `(atom_i, atom_j)` blocks.
    pub blocks: Vec<FactoredFrameGBlock>,
}
1049
1050/// Frame output Gram `U_iᵀ U_j` (`r_i × r_j`) between two per-atom output
1051/// frames (each `p × r`). This is the dense principal-angle cosine matrix that
1052/// becomes the `w` factor of a [`FactoredFrameGBlock`]; for `i == j` with an
1053/// orthonormal frame it is `I_{r_i}`. Shared with
1054/// [`gam_terms::sae::manifold`], which builds the same factors when
1055/// profiling decoders onto Grassmann frames.
1056pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
1057    let (p_i, r_i) = u_i.dim();
1058    let (p_j, r_j) = u_j.dim();
1059    assert_eq!(
1060        p_i, p_j,
1061        "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
1062    );
1063    let mut w = Array2::<f64>::zeros((r_i, r_j));
1064    for a in 0..r_i {
1065        for b in 0..r_j {
1066            let mut acc = 0.0;
1067            for c in 0..p_i {
1068                acc += u_i[[c, a]] * u_j[[c, b]];
1069            }
1070            w[[a, b]] = acc;
1071        }
1072    }
1073    w
1074}
1075
1076impl FactoredFrameKroneckerOp {
1077    /// Build from per-atom ranks + basis sizes and the co-occurring blocks.
1078    /// Computes the β offsets (prefix sum of `M_k·r_k`) and validates that each
1079    /// block's `g`/`w` shapes match the atoms' `(M, r)`.
1080    pub fn new(
1081        ranks: Vec<usize>,
1082        basis_sizes: Vec<usize>,
1083        blocks: Vec<FactoredFrameGBlock>,
1084    ) -> Result<Self, String> {
1085        if ranks.len() != basis_sizes.len() {
1086            return Err(format!(
1087                "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
1088                ranks.len(),
1089                basis_sizes.len()
1090            ));
1091        }
1092        let n_atoms = ranks.len();
1093        let mut offsets = Vec::with_capacity(n_atoms + 1);
1094        let mut acc = 0usize;
1095        for k in 0..n_atoms {
1096            offsets.push(acc);
1097            acc += basis_sizes[k] * ranks[k];
1098        }
1099        offsets.push(acc);
1100        let dim = acc;
1101        for blk in &blocks {
1102            if blk.atom_i >= n_atoms || blk.atom_j >= n_atoms {
1103                return Err(format!(
1104                    "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
1105                    blk.atom_i, blk.atom_j
1106                ));
1107            }
1108            if blk.g.dim() != (basis_sizes[blk.atom_i], basis_sizes[blk.atom_j]) {
1109                return Err(format!(
1110                    "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
1111                    blk.atom_i,
1112                    blk.atom_j,
1113                    blk.g.dim(),
1114                    basis_sizes[blk.atom_i],
1115                    basis_sizes[blk.atom_j]
1116                ));
1117            }
1118            if blk.w.dim() != (ranks[blk.atom_i], ranks[blk.atom_j]) {
1119                return Err(format!(
1120                    "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
1121                    blk.atom_i,
1122                    blk.atom_j,
1123                    blk.w.dim(),
1124                    ranks[blk.atom_i],
1125                    ranks[blk.atom_j]
1126                ));
1127            }
1128        }
1129        Ok(Self {
1130            ranks,
1131            basis_sizes,
1132            offsets,
1133            dim,
1134            blocks,
1135        })
1136    }
1137
1138    /// Convenience constructor that builds the operator directly from per-atom
1139    /// output frames + the basis-space Gram block map, computing the per-pair
1140    /// frame factors `W_ij = U_iᵀ U_j` itself.
1141    ///
1142    /// `frames[k]` is either `Some(U_k)` — a `p × r_k` (`r_k ≤ p`) output frame
1143    /// (a Grassmann representative `St(p, r_k)` need not be orthonormal here; the
1144    /// `W` factor carries whatever frame is supplied) — or `None`, meaning atom
1145    /// `k` keeps the full ambient output (`U_k = I_p`, so `r_k = p`). For each
1146    /// non-empty Gram block `(atom_i, atom_j)` the factor `W` is
1147    /// `U_iᵀ U_j` (`r_i × r_j`), with the `None` frame standing in for `I_p`:
1148    /// a framed×unframed cross gives `W = U_iᵀ` (`r_i × p`) and an unframed
1149    /// diagonal gives `W = I_p` — exactly reproducing the `g ⊗ I_p` full-`B`
1150    /// block. The resulting blocks are handed to [`Self::new`], which validates
1151    /// the `(M, r)` shapes and computes the β offsets.
1152    pub fn from_frames_and_blocks(
1153        frames: &[Option<Array2<f64>>],
1154        basis_sizes: &[usize],
1155        p: usize,
1156        g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
1157    ) -> Result<Self, String> {
1158        if frames.len() != basis_sizes.len() {
1159            return Err(format!(
1160                "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
1161                frames.len(),
1162                basis_sizes.len()
1163            ));
1164        }
1165        let n_atoms = frames.len();
1166        // Per-atom rank: ncols of a supplied frame, else the ambient dim p.
1167        let mut ranks = Vec::with_capacity(n_atoms);
1168        for (k, frame) in frames.iter().enumerate() {
1169            match frame {
1170                Some(u) => {
1171                    let (pr, r) = u.dim();
1172                    if pr != p {
1173                        return Err(format!(
1174                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
1175                        ));
1176                    }
1177                    if r > p {
1178                        return Err(format!(
1179                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
1180                        ));
1181                    }
1182                    ranks.push(r);
1183                }
1184                None => ranks.push(p),
1185            }
1186        }
1187        // Materialize each atom's frame as a `p × r_k` view source: the supplied
1188        // `U_k`, or `I_p` for the unframed atoms.
1189        let identity = Array2::<f64>::eye(p);
1190        let frame_or_ident = |k: usize| -> ArrayView2<f64> {
1191            match &frames[k] {
1192                Some(u) => u.view(),
1193                None => identity.view(),
1194            }
1195        };
1196        let mut blocks = Vec::with_capacity(g_blocks.len());
1197        for (&(atom_i, atom_j), g) in g_blocks {
1198            if atom_i >= n_atoms || atom_j >= n_atoms {
1199                return Err(format!(
1200                    "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
1201                ));
1202            }
1203            let w = frame_output_gram(frame_or_ident(atom_i), frame_or_ident(atom_j));
1204            blocks.push(FactoredFrameGBlock {
1205                atom_i,
1206                atom_j,
1207                g: g.clone(),
1208                w,
1209            });
1210        }
1211        Self::new(ranks, basis_sizes.to_vec(), blocks)
1212    }
1213}
1214
1215impl BetaPenaltyOp for FactoredFrameKroneckerOp {
1216    fn dim(&self) -> usize {
1217        self.dim
1218    }
1219
1220    fn matvec(&self, x: &[f64], y: &mut [f64]) {
1221        for blk in &self.blocks {
1222            let r_i = self.ranks[blk.atom_i];
1223            let r_j = self.ranks[blk.atom_j];
1224            let off_i = self.offsets[blk.atom_i];
1225            let off_j = self.offsets[blk.atom_j];
1226            let (m_i, m_j) = blk.g.dim();
1227            for li in 0..m_i {
1228                let yi_base = off_i + li * r_i;
1229                for lj in 0..m_j {
1230                    let g = blk.g[[li, lj]];
1231                    if g == 0.0 {
1232                        continue;
1233                    }
1234                    let xj_base = off_j + lj * r_j;
1235                    // y_block[li, a] += g · Σ_b w[a, b] · x_block[lj, b]
1236                    for a in 0..r_i {
1237                        let mut acc = 0.0;
1238                        for b in 0..r_j {
1239                            acc += blk.w[[a, b]] * x[xj_base + b];
1240                        }
1241                        y[yi_base + a] += g * acc;
1242                    }
1243                }
1244            }
1245        }
1246    }
1247
1248    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1249        self.matvec(beta, out);
1250    }
1251
1252    fn diagonal(&self, diag: &mut [f64]) {
1253        for blk in &self.blocks {
1254            // Only on-diagonal atom blocks carry diagonal mass; the entry at
1255            // (atom k, basis li, coord a) is g[li,li]·w[a,a].
1256            if blk.atom_i != blk.atom_j {
1257                continue;
1258            }
1259            let r = self.ranks[blk.atom_i];
1260            let off = self.offsets[blk.atom_i];
1261            let (m_i, m_j) = blk.g.dim();
1262            let m = m_i.min(m_j);
1263            for li in 0..m {
1264                let gii = blk.g[[li, li]];
1265                let base = off + li * r;
1266                for a in 0..r {
1267                    diag[base + a] += gii * blk.w[[a, a]];
1268                }
1269            }
1270        }
1271    }
1272
1273    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1274        // Dense sub-block over the β index range `offsets[id.0]`. Mirror the
1275        // global (i,a) ↔ (j,b) coupling, keeping only indices inside the range.
1276        let range = &offsets[id.0];
1277        let b_dim = range.end - range.start;
1278        for blk in &self.blocks {
1279            let r_i = self.ranks[blk.atom_i];
1280            let r_j = self.ranks[blk.atom_j];
1281            let off_i = self.offsets[blk.atom_i];
1282            let off_j = self.offsets[blk.atom_j];
1283            let (m_i, m_j) = blk.g.dim();
1284            for li in 0..m_i {
1285                for a in 0..r_i {
1286                    let gi = off_i + li * r_i + a;
1287                    if gi < range.start || gi >= range.end {
1288                        continue;
1289                    }
1290                    let bi = gi - range.start;
1291                    for lj in 0..m_j {
1292                        let g = blk.g[[li, lj]];
1293                        if g == 0.0 {
1294                            continue;
1295                        }
1296                        for b in 0..r_j {
1297                            let gj = off_j + lj * r_j + b;
1298                            if gj < range.start || gj >= range.end {
1299                                continue;
1300                            }
1301                            let bj = gj - range.start;
1302                            if bi < b_dim && bj < b_dim {
1303                                out[[bi, bj]] += g * blk.w[[a, b]];
1304                            }
1305                        }
1306                    }
1307                }
1308            }
1309        }
1310    }
1311
1312    fn to_dense(&self) -> Array2<f64> {
1313        let mut out = Array2::<f64>::zeros((self.dim, self.dim));
1314        for blk in &self.blocks {
1315            let r_i = self.ranks[blk.atom_i];
1316            let r_j = self.ranks[blk.atom_j];
1317            let off_i = self.offsets[blk.atom_i];
1318            let off_j = self.offsets[blk.atom_j];
1319            let (m_i, m_j) = blk.g.dim();
1320            for li in 0..m_i {
1321                for lj in 0..m_j {
1322                    let g = blk.g[[li, lj]];
1323                    if g == 0.0 {
1324                        continue;
1325                    }
1326                    for a in 0..r_i {
1327                        let gi = off_i + li * r_i + a;
1328                        for b in 0..r_j {
1329                            let gj = off_j + lj * r_j + b;
1330                            out[[gi, gj]] += g * blk.w[[a, b]];
1331                        }
1332                    }
1333                }
1334            }
1335        }
1336        out
1337    }
1338
1339    fn fingerprint(&self, hasher: &mut Fingerprinter) {
1340        hasher.write_str("factored-frame-kronecker-op-v1");
1341        hasher.write_usize(self.dim);
1342        for &r in &self.ranks {
1343            hasher.write_usize(r);
1344        }
1345        for &m in &self.basis_sizes {
1346            hasher.write_usize(m);
1347        }
1348        hasher.write_usize(self.blocks.len());
1349        for blk in &self.blocks {
1350            hasher.write_usize(blk.atom_i);
1351            hasher.write_usize(blk.atom_j);
1352            hasher.write_f64_array2(&blk.g);
1353            hasher.write_f64_array2(&blk.w);
1354        }
1355    }
1356}
1357
1358/// Composite penalty: sum of multiple `BetaPenaltyOp` operators.
1359pub struct CompositePenaltyOp {
1360    /// Full β dimension `K`.
1361    pub k: usize,
1362    /// Component operators, each contributing additively.
1363    pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
1364}
1365
1366impl BetaPenaltyOp for CompositePenaltyOp {
1367    fn dim(&self) -> usize {
1368        self.k
1369    }
1370
1371    fn matvec(&self, x: &[f64], y: &mut [f64]) {
1372        // The reduced-Schur PCG matvec applies this composite ONCE PER CG
1373        // ITERATION as the penalty prologue `y += (H_ββ) x`. At the K=32k
1374        // manifold-SAE border the composite is a leading run of per-atom
1375        // Kronecker smooth penalties (`λ S_k ⊗ I_{r_k}`, one per atom, over
1376        // DISJOINT β blocks) followed by the cross-atom data-fit op and any
1377        // dense analytic tail — and this whole sum ran SERIALLY while the
1378        // point-elimination row term already fanned across all cores, so it was
1379        // the prologue's Amdahl ceiling on the wide border.
1380        //
1381        // Fan the leading run of mutually-disjoint, sorted, contiguous-or-gapped
1382        // output-range operators across rayon workers: each writes ONLY its own
1383        // `y[start..end]` sub-slice (no cross-thread aliasing), then the
1384        // remaining (`None`-range / overlapping) operators run SERIALLY in
1385        // original order. Because every prefix index is touched by exactly one
1386        // prefix operator and all prefix work happens-before the serial tail,
1387        // each output index accumulates in the SAME order as the fully-serial
1388        // loop — the result is BIT-IDENTICAL, not merely deterministic. Stay
1389        // serial when already inside a rayon worker (the topology race / nested
1390        // matvec) to avoid oversubscription — the same guard the row loop uses.
1391        let mut prefix_len = 0usize;
1392        let mut prev_end = 0usize;
1393        if rayon::current_thread_index().is_none() {
1394            for op in &self.ops {
1395                match op.output_range() {
1396                    Some(r) if r.start >= prev_end && r.end > r.start && r.end <= y.len() => {
1397                        prev_end = r.end;
1398                        prefix_len += 1;
1399                    }
1400                    _ => break,
1401                }
1402            }
1403        }
1404        // Only worth the fan-out when there is real disjoint work: at least two
1405        // blocks and a covered width past the same border threshold the dense
1406        // prologue uses. Otherwise fall through to the plain serial sum.
1407        if prefix_len >= 2 && prev_end >= SCHUR_PROLOGUE_PARALLEL_K_MIN {
1408            use rayon::prelude::*;
1409            // Carve `y` into one mutable sub-slice per prefix operator, skipping
1410            // any gaps between ranges. Sorted, non-overlapping ranges make this
1411            // a single left-to-right walk of `split_at_mut`.
1412            let mut subslices: Vec<&mut [f64]> = Vec::with_capacity(prefix_len);
1413            {
1414                let mut consumed = 0usize;
1415                let mut rest: &mut [f64] = y;
1416                for op in &self.ops[..prefix_len] {
1417                    let r = op.output_range().expect("prefix op has an output range");
1418                    let (_, after_gap) = rest.split_at_mut(r.start - consumed);
1419                    let (block, tail) = after_gap.split_at_mut(r.end - r.start);
1420                    subslices.push(block);
1421                    rest = tail;
1422                    consumed = r.end;
1423                }
1424            }
1425            self.ops[..prefix_len]
1426                .par_iter()
1427                .zip(subslices.par_iter_mut())
1428                .for_each(|(op, y_local)| op.matvec_local(x, y_local));
1429            for op in &self.ops[prefix_len..] {
1430                op.matvec(x, y);
1431            }
1432        } else {
1433            for op in &self.ops {
1434                op.matvec(x, y);
1435            }
1436        }
1437    }
1438
1439    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1440        for op in &self.ops {
1441            op.gradient(beta, out);
1442        }
1443    }
1444
1445    fn diagonal(&self, diag: &mut [f64]) {
1446        for op in &self.ops {
1447            op.diagonal(diag);
1448        }
1449    }
1450
1451    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1452        for op in &self.ops {
1453            op.block(id, offsets, out);
1454        }
1455    }
1456
1457    fn to_dense(&self) -> Array2<f64> {
1458        let mut out = Array2::<f64>::zeros((self.k, self.k));
1459        for op in &self.ops {
1460            let dense = op.to_dense();
1461            out += &dense;
1462        }
1463        out
1464    }
1465
1466    fn fingerprint(&self, hasher: &mut Fingerprinter) {
1467        hasher.write_str("composite-penalty-op-v1");
1468        hasher.write_usize(self.k);
1469        hasher.write_usize(self.ops.len());
1470        for op in &self.ops {
1471            op.fingerprint(hasher);
1472        }
1473    }
1474}
1475
1476/// Adapts a closure-based matrix-free `H_ββ` operator (from
1477/// [`ArrowSchurSystem::set_shared_beta_operator`]) to the `BetaPenaltyOp` trait.
1478///
1479/// `diagonal` holds the precomputed `diag(H_ββ)` supplied alongside the matvec;
1480/// `to_dense` falls back to probing all `K` canonical basis vectors.
1481pub struct MatvecDiagPenaltyOp {
1482    pub(crate) k: usize,
1483    pub(crate) matvec: SharedBetaMatvec,
1484    pub(crate) diagonal_vec: Array1<f64>,
1485}
1486
1487impl MatvecDiagPenaltyOp {
1488    pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
1489        assert_eq!(diagonal_vec.len(), k);
1490        Self {
1491            k,
1492            matvec,
1493            diagonal_vec,
1494        }
1495    }
1496}
1497
1498impl BetaPenaltyOp for MatvecDiagPenaltyOp {
1499    fn dim(&self) -> usize {
1500        self.k
1501    }
1502
1503    fn matvec(&self, x: &[f64], y: &mut [f64]) {
1504        let x_arr = Array1::from_iter(x.iter().copied());
1505        let mut out = Array1::<f64>::zeros(self.k);
1506        (self.matvec)(x_arr.view(), &mut out);
1507        for a in 0..self.k {
1508            y[a] += out[a];
1509        }
1510    }
1511
1512    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
1513        let beta_arr = Array1::from_iter(beta.iter().copied());
1514        let mut hb = Array1::<f64>::zeros(self.k);
1515        (self.matvec)(beta_arr.view(), &mut hb);
1516        for a in 0..self.k {
1517            out[a] += hb[a];
1518        }
1519    }
1520
1521    fn diagonal(&self, diag: &mut [f64]) {
1522        for j in 0..self.k.min(diag.len()) {
1523            diag[j] += self.diagonal_vec[j];
1524        }
1525    }
1526
1527    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
1528        // Probe each basis vector in the block range to extract the sub-block.
1529        let range = &offsets[id.0];
1530        let b = range.end - range.start;
1531        let mut probe = Array1::<f64>::zeros(self.k);
1532        for bj in 0..b {
1533            probe.fill(0.0);
1534            probe[range.start + bj] = 1.0;
1535            let mut col = Array1::<f64>::zeros(self.k);
1536            (self.matvec)(probe.view(), &mut col);
1537            for bi in 0..b {
1538                out[[bi, bj]] += col[range.start + bi];
1539            }
1540        }
1541    }
1542
1543    fn to_dense(&self) -> Array2<f64> {
1544        let k = self.k;
1545        let mut out = Array2::<f64>::zeros((k, k));
1546        let mut probe = Array1::<f64>::zeros(k);
1547        for j in 0..k {
1548            probe.fill(0.0);
1549            probe[j] = 1.0;
1550            let mut col = Array1::<f64>::zeros(k);
1551            (self.matvec)(probe.view(), &mut col);
1552            for i in 0..k {
1553                out[[i, j]] = col[i];
1554            }
1555        }
1556        out
1557    }
1558
1559    fn fingerprint(&self, hasher: &mut Fingerprinter) {
1560        // The matvec closure cannot be hashed by content; the precomputed
1561        // diagonal is the operator's stable defining proxy (it is recomputed
1562        // alongside the matvec each time the operator is installed).
1563        hasher.write_str("matvec-diag-penalty-op-v1");
1564        hasher.write_usize(self.k);
1565        for &value in self.diagonal_vec.iter() {
1566            hasher.write_f64(value);
1567        }
1568    }
1569}