Skip to main content

gam_sae/
frames.rs

1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2
3use gam_linalg::faer_ndarray::{FaerSvd, fast_ab, fast_abt, fast_atb};
4use crate::manifold::SaeManifoldTerm;
5
6#[derive(Clone, Debug)]
7pub(crate) struct FrameProjection {
8    pub(crate) p: usize,
9    pub(crate) beta_offsets: Vec<usize>,
10    pub(crate) border_offsets: Vec<usize>,
11    pub(crate) basis_sizes: Vec<usize>,
12    pub(crate) ranks: Vec<usize>,
13    frames: Vec<Option<Array2<f64>>>,
14}
15
16impl FrameProjection {
17    pub(crate) fn new(term: &SaeManifoldTerm) -> Self {
18        Self {
19            p: term.output_dim(),
20            beta_offsets: term.beta_offsets(),
21            border_offsets: term.factored_border_offsets(),
22            basis_sizes: term.atoms.iter().map(|atom| atom.basis_size()).collect(),
23            ranks: term
24                .atoms
25                .iter()
26                .map(|atom| atom.border_frame_rank())
27                .collect(),
28            frames: term
29                .atoms
30                .iter()
31                .map(|atom| {
32                    atom.decoder_frame
33                        .as_ref()
34                        .map(|frame| frame.frame().to_owned())
35                })
36                .collect(),
37        }
38    }
39
40    pub(crate) fn beta_dim(&self) -> usize {
41        self.basis_sizes.iter().sum::<usize>() * self.p
42    }
43
44    pub(crate) fn border_dim(&self) -> usize {
45        self.basis_sizes
46            .iter()
47            .zip(&self.ranks)
48            .map(|(m, r)| m * r)
49            .sum()
50    }
51
52    pub(crate) fn lift_border_vec(&self, border: ArrayView1<'_, f64>) -> Array1<f64> {
53        let mut out = Array1::<f64>::zeros(self.beta_dim());
54        for atom in 0..self.basis_sizes.len() {
55            self.lift_atom_vec_into(atom, border, out.view_mut());
56        }
57        out
58    }
59
60    pub(crate) fn project_border_vec(&self, beta: ArrayView1<'_, f64>) -> Array1<f64> {
61        let mut out = Array1::<f64>::zeros(self.border_dim());
62        for atom in 0..self.basis_sizes.len() {
63            self.project_atom_vec_into(atom, beta, out.view_mut(), 1.0);
64        }
65        out
66    }
67
68    pub(crate) fn lift_block(&self, atom: usize, block: ArrayView2<'_, f64>) -> Array2<f64> {
69        let m = self.basis_sizes[atom];
70        let r = self.ranks[atom];
71        if self.frames[atom].is_none() {
72            return block.to_owned();
73        }
74        let uk = self.frames[atom].as_ref().expect("framed atom has a frame");
75        let mut out = Array2::<f64>::zeros((m * self.p, m * self.p));
76        for b1 in 0..m {
77            for b2 in 0..m {
78                for c1 in 0..self.p {
79                    for c2 in 0..self.p {
80                        let mut acc = 0.0;
81                        for j1 in 0..r {
82                            for j2 in 0..r {
83                                acc +=
84                                    uk[[c1, j1]] * block[[b1 * r + j1, b2 * r + j2]] * uk[[c2, j2]];
85                            }
86                        }
87                        out[[b1 * self.p + c1, b2 * self.p + c2]] = acc;
88                    }
89                }
90            }
91        }
92        out
93    }
94
95    pub(crate) fn project_block(&self, hbb: ArrayView2<'_, f64>) -> Array2<f64> {
96        let t = self.project_rows(hbb);
97        let mut out = Array2::<f64>::zeros((self.border_dim(), self.border_dim()));
98        for atom in 0..self.basis_sizes.len() {
99            self.project_block_left_atom(atom, t.view(), out.view_mut());
100        }
101        out
102    }
103
104    pub(crate) fn project_rows(&self, block: ArrayView2<'_, f64>) -> Array2<f64> {
105        let mut out = Array2::<f64>::zeros((block.nrows(), self.border_dim()));
106        for row in 0..block.nrows() {
107            let projected = self.project_border_vec(block.row(row));
108            out.row_mut(row).assign(&projected);
109        }
110        out
111    }
112
113    pub(crate) fn atom_border_range(&self, atom: usize) -> std::ops::Range<usize> {
114        let start = self.border_offsets[atom];
115        start..start + self.basis_sizes[atom] * self.ranks[atom]
116    }
117
118    pub(crate) fn lift_axis_into(
119        &self,
120        out: &mut Array1<f64>,
121        atom: usize,
122        basis_col: usize,
123        frame_col: usize,
124    ) {
125        let base = self.beta_offsets[atom] + basis_col * self.p;
126        match &self.frames[atom] {
127            None => out[base + frame_col] = 1.0,
128            Some(uk) => {
129                for out_col in 0..self.p {
130                    out[base + out_col] = uk[[out_col, frame_col]];
131                }
132            }
133        }
134    }
135
136    pub(crate) fn lift_local_axis_into(
137        &self,
138        out: &mut Array1<f64>,
139        atom: usize,
140        basis_col: usize,
141        frame_col: usize,
142    ) {
143        let base = basis_col * self.p;
144        match &self.frames[atom] {
145            None => out[base + frame_col] = 1.0,
146            Some(uk) => {
147                for out_col in 0..self.p {
148                    out[base + out_col] = uk[[out_col, frame_col]];
149                }
150            }
151        }
152    }
153
154    pub(crate) fn project_atom_vec_into(
155        &self,
156        atom: usize,
157        beta: ArrayView1<'_, f64>,
158        mut out: ndarray::ArrayViewMut1<'_, f64>,
159        scale: f64,
160    ) {
161        let m = self.basis_sizes[atom];
162        let r = self.ranks[atom];
163        let ob = self.beta_offsets[atom];
164        let oc = self.border_offsets[atom];
165        for basis_col in 0..m {
166            let base_b = ob + basis_col * self.p;
167            let base_c = oc + basis_col * r;
168            match &self.frames[atom] {
169                None => {
170                    for j in 0..r {
171                        out[base_c + j] += scale * beta[base_b + j];
172                    }
173                }
174                Some(uk) => {
175                    for j in 0..r {
176                        let mut acc = 0.0;
177                        for i in 0..self.p {
178                            acc += uk[[i, j]] * beta[base_b + i];
179                        }
180                        out[base_c + j] += scale * acc;
181                    }
182                }
183            }
184        }
185    }
186
187    pub(crate) fn project_local_atom_vec_into(
188        &self,
189        atom: usize,
190        beta: ArrayView1<'_, f64>,
191        out: ndarray::ArrayViewMut1<'_, f64>,
192        scale: f64,
193    ) {
194        self.project_atom_vec_into_with_base(atom, beta, out, scale, 0);
195    }
196
197    pub(crate) fn project_atom_vec_into_with_base(
198        &self,
199        atom: usize,
200        beta: ArrayView1<'_, f64>,
201        mut out: ndarray::ArrayViewMut1<'_, f64>,
202        scale: f64,
203        beta_base_offset: usize,
204    ) {
205        let m = self.basis_sizes[atom];
206        let r = self.ranks[atom];
207        let oc = self.border_offsets[atom];
208        for basis_col in 0..m {
209            let base_b = beta_base_offset + basis_col * self.p;
210            let base_c = oc + basis_col * r;
211            match &self.frames[atom] {
212                None => {
213                    for j in 0..r {
214                        out[base_c + j] += scale * beta[base_b + j];
215                    }
216                }
217                Some(uk) => {
218                    for j in 0..r {
219                        let mut acc = 0.0;
220                        for i in 0..self.p {
221                            acc += uk[[i, j]] * beta[base_b + i];
222                        }
223                        out[base_c + j] += scale * acc;
224                    }
225                }
226            }
227        }
228    }
229
230    pub(crate) fn lift_atom_vec_into(
231        &self,
232        atom: usize,
233        border: ArrayView1<'_, f64>,
234        mut out: ndarray::ArrayViewMut1<'_, f64>,
235    ) {
236        let m = self.basis_sizes[atom];
237        let r = self.ranks[atom];
238        let ob = self.beta_offsets[atom];
239        let oc = self.border_offsets[atom];
240        for basis_col in 0..m {
241            let base_b = ob + basis_col * self.p;
242            let base_c = oc + basis_col * r;
243            match &self.frames[atom] {
244                None => {
245                    for i in 0..self.p {
246                        out[base_b + i] = border[base_c + i];
247                    }
248                }
249                Some(uk) => {
250                    for i in 0..self.p {
251                        let mut acc = 0.0;
252                        for j in 0..r {
253                            acc += uk[[i, j]] * border[base_c + j];
254                        }
255                        out[base_b + i] = acc;
256                    }
257                }
258            }
259        }
260    }
261
262    pub(crate) fn accumulate_output_project(
263        &self,
264        atom: usize,
265        c_base: usize,
266        output: usize,
267        value: f64,
268        out: &mut [f64],
269    ) {
270        match &self.frames[atom] {
271            None => out[c_base + output] += value,
272            Some(uk) => {
273                let rank = self.ranks[atom];
274                let frame_row = uk.row(output);
275                let frame_slice = frame_row.as_slice().expect("frame rows are contiguous");
276                let out_slice = &mut out[c_base..c_base + rank];
277                for (slot, &u) in out_slice.iter_mut().zip(frame_slice.iter()) {
278                    *slot += value * u;
279                }
280            }
281        }
282    }
283
284    pub(crate) fn output_variance(
285        &self,
286        atom: usize,
287        cov_c: ArrayView2<'_, f64>,
288        basis: ArrayView1<'_, f64>,
289        output: usize,
290    ) -> f64 {
291        let Some(uk) = &self.frames[atom] else {
292            return self.full_output_variance(atom, cov_c, basis, output);
293        };
294        let m = self.basis_sizes[atom];
295        let r = self.ranks[atom];
296        let mut var = 0.0;
297        for b1 in 0..m {
298            let phi1 = basis[b1];
299            if phi1 == 0.0 {
300                continue;
301            }
302            for b2 in 0..m {
303                let phi2 = basis[b2];
304                if phi2 == 0.0 {
305                    continue;
306                }
307                for j1 in 0..r {
308                    for j2 in 0..r {
309                        var += phi1
310                            * phi2
311                            * uk[[output, j1]]
312                            * cov_c[[b1 * r + j1, b2 * r + j2]]
313                            * uk[[output, j2]];
314                    }
315                }
316            }
317        }
318        var
319    }
320
321    pub(crate) fn full_output_variance(
322        &self,
323        atom: usize,
324        cov: ArrayView2<'_, f64>,
325        basis: ArrayView1<'_, f64>,
326        output: usize,
327    ) -> f64 {
328        let m = self.basis_sizes[atom];
329        let mut var = 0.0;
330        for b1 in 0..m {
331            let phi1 = basis[b1];
332            if phi1 == 0.0 {
333                continue;
334            }
335            for b2 in 0..m {
336                var += phi1 * basis[b2] * cov[[b1 * self.p + output, b2 * self.p + output]];
337            }
338        }
339        var
340    }
341
342    pub(crate) fn project_block_left_atom(
343        &self,
344        atom: usize,
345        t: ArrayView2<'_, f64>,
346        mut out: ndarray::ArrayViewMut2<'_, f64>,
347    ) {
348        let m = self.basis_sizes[atom];
349        let r = self.ranks[atom];
350        let ob = self.beta_offsets[atom];
351        let oc = self.border_offsets[atom];
352        for basis_col in 0..m {
353            let base_b = ob + basis_col * self.p;
354            let base_c = oc + basis_col * r;
355            match &self.frames[atom] {
356                None => {
357                    for j in 0..r {
358                        for c in 0..out.ncols() {
359                            out[[base_c + j, c]] += t[[base_b + j, c]];
360                        }
361                    }
362                }
363                Some(uk) => {
364                    for j in 0..r {
365                        for c in 0..out.ncols() {
366                            let mut acc = 0.0;
367                            for i in 0..self.p {
368                                acc += uk[[i, j]] * t[[base_b + i, c]];
369                            }
370                            out[[base_c + j, c]] += acc;
371                        }
372                    }
373                }
374            }
375        }
376    }
377}
378
379/// Build the frames-engaged device SAE PCG data (issue #1017/#1026): the
380/// factored-border analogue of the full-`B` `DeviceSaePcgData`. The penalty side
381/// carries the smooth `λ S_k ⊗ I_{r_k}` blocks (right-width `r_k`, at `off_c[k]`)
382/// and the data-fit `G_{ij} ⊗ W_{ij}` blocks; the reduced-Schur side carries the
383/// per-row DENSE cross-block `H_tβ^(i)` as a row-major `q_i × border_dim` slab.
384///
385/// `args.frame_blocks` are the same `(g, w)` blocks fed to
386/// `FactoredFrameKroneckerOp`, snapshotted before that op consumed them.
387/// `args.smooth_scaled_s[k]` is `λ S_k` (`M_k × M_k`). A row whose `htbeta` is
388/// not at the factored width contributes an empty slab (reduced-Schur term zero).
389pub(crate) struct FramedDeviceArgs<'a> {
390    pub p: usize,
391    pub border_dim: usize,
392    pub border_offsets: &'a [usize],
393    pub ranks: &'a [usize],
394    pub basis_sizes: &'a [usize],
395    pub smooth_scaled_s: &'a [Array2<f64>],
396    pub frame_blocks: Vec<gam_solve::arrow_schur::FactoredFrameGBlock>,
397    pub rows: &'a [gam_solve::arrow_schur::ArrowRowBlock],
398}
399
400pub(crate) fn build_framed_device_sae_data(
401    args: FramedDeviceArgs<'_>,
402) -> gam_solve::arrow_schur::DeviceSaePcgData {
403    use gam_solve::arrow_schur::{DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock};
404    let FramedDeviceArgs {
405        p,
406        border_dim,
407        border_offsets,
408        ranks,
409        basis_sizes,
410        smooth_scaled_s,
411        frame_blocks,
412        rows,
413    } = args;
414    let n_atoms = ranks.len();
415    let mut smooth_blocks = Vec::with_capacity(n_atoms);
416    let mut smooth_ranks = Vec::with_capacity(n_atoms);
417    for k in 0..n_atoms {
418        smooth_blocks.push(DeviceSaeSmoothBlock {
419            global_offset: border_offsets[k],
420            factor_a: smooth_scaled_s[k].clone(),
421        });
422        smooth_ranks.push(ranks[k]);
423    }
424    let row_htbeta: Vec<Vec<f64>> = rows
425        .iter()
426        .map(|row| {
427            let (qi, w) = row.htbeta.dim();
428            if w != border_dim {
429                return Vec::new();
430            }
431            let mut flat = vec![0.0_f64; qi * w];
432            for c in 0..qi {
433                for a in 0..w {
434                    flat[c * w + a] = row.htbeta[[c, a]];
435                }
436            }
437            flat
438        })
439        .collect();
440    DeviceSaePcgData {
441        p,
442        beta_dim: border_dim,
443        // #1033: empty shared slices — the frames path carries its cross-block
444        // through `frame.frame_blocks`, not the full-`B` `a_phi`/`local_jac`.
445        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
446        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
447        smooth_blocks,
448        sparse_g_blocks: Vec::new(),
449        frame: Some(DeviceSaeFrameData {
450            ranks: ranks.to_vec(),
451            basis_sizes: basis_sizes.to_vec(),
452            border_offsets: border_offsets.to_vec(),
453            frame_blocks,
454            smooth_ranks,
455            row_htbeta,
456        }),
457    }
458}
459
/// Relative spectral cutoff used when the Grassmann-frame factorization decides
/// the effective column rank `r` of an atom's decoder `B_k` (issue #972). A
/// singular value `σ` of `B_k` with `σ < cutoff · σ_max` is dropped from the
/// profiled frame: its direction carries less than `cutoff²` (`1e-14`) of the
/// energy `σ_max²` of the leading direction.
pub(crate) const SAE_FRAME_RANK_CUTOFF: f64 = 1.0e-7;

/// Small ambient decoders stay on the full-`B` path. Below this width the dense
/// decoder border is cheap, while auto-profiling a cold low-rank decoder changes
/// the β coordinates before the inner solve has learned the output span.
pub(crate) const SAE_FRAME_MIN_AUTO_OUTPUT_DIM: usize = 12;

/// Border-saving threshold for auto-activating the low-rank Grassmann
/// factorization (issue #972). The factored border holds `Σ_k M_k · r` instead
/// of `Σ_k M_k · p`, so factorization is beneficial only when the chosen frame
/// rank `r` is materially smaller than the ambient output dimension `p`. We
/// require `r ≤ p · (1 − margin)` (with `margin = 0.25`, `r ≤ 0.75 · p`: the
/// frame must shrink the per-atom border by at least this fraction) AND a
/// positive absolute gap `p − r ≥ 1`. A full-rank atom (`r == p`) therefore
/// never pays the polar-step / frame-storage cost for zero border saving, and
/// stays bit-for-bit on the historical full-`B` path.
pub(crate) const SAE_FRAME_ACTIVATION_MARGIN: f64 = 0.25;
480
481/// A Grassmann point: a `p × r` column-orthonormal FRAME `U` spanning an atom's
482/// decoder column space (issue #972).
483///
484/// The decoder coefficient matrix `B_k` (`M_k × p`) factors as `B_k = C_k · Uᵀ`
485/// where `C_k` (`M_k × r`) is the coordinate matrix that lives IN the
486/// arrow-Schur border and `U` (`p × r`) is this frame, profiled OUT of the
487/// border by closed-form streaming polar steps. The border then carries only
488/// `Σ_k M_k · r` coefficients rather than `Σ_k M_k · p` — the reduction that
489/// keeps the border Cholesky / evidence log-det tractable at frontier `p`.
490///
491/// **Canonical inner gauge.** `U` is only defined up to a right `r × r`
492/// orthogonal rotation `U → U R` (with the matching `C_k → C_k R`); the column
493/// span (the Grassmann point) is invariant. For deterministic serialization we
494/// pin a canonical representative: the frame is the left-singular subspace of
495/// the cross-moment, ordered by descending singular value, with each column's
496/// sign fixed so its largest-magnitude entry is non-negative. The ordering is
497/// recorded by the `gauge_singular_values` field so the same span always
498/// serializes to the same bytes (no run-to-run rotation drift).
499#[derive(Debug, Clone)]
500pub struct GrassmannFrame {
501    /// Column-orthonormal frame `U`, shape `(p, r)` with `Uᵀ U = I_r`.
502    frame: Array2<f64>,
503    /// Singular values of the most recent cross-moment used to build `U`,
504    /// descending, length `r`. The canonical ordering gauge (issue #972).
505    gauge_singular_values: Array1<f64>,
506}
507
508impl GrassmannFrame {
509    /// Ambient output dimension `p`.
510    pub fn output_dim(&self) -> usize {
511        self.frame.nrows()
512    }
513
514    /// Frame rank `r` (number of profiled column directions).
515    pub fn rank(&self) -> usize {
516        self.frame.ncols()
517    }
518
519    /// Canonical descending singular values of the cross-moment that fixed this
520    /// frame's column ordering (issue #972). Exposed so the serialization /
521    /// canonicalization path can read the recorded gauge and reproduce the same
522    /// span byte-for-byte (no run-to-run rotation drift).
523    pub fn gauge_singular_values(&self) -> &Array1<f64> {
524        &self.gauge_singular_values
525    }
526
527    /// Read-only view of the orthonormal frame `U` (`p × r`).
528    pub fn frame(&self) -> ArrayView2<'_, f64> {
529        self.frame.view()
530    }
531
532    /// Grassmann manifold dimension `r·(p − r)` of this frame — the count of
533    /// profiled-out degrees of freedom that must enter the Laplace evidence
534    /// dimension accounting (issue #972, evidence honesty). A point on the
535    /// Grassmannian `Gr(r, p)` has exactly this many intrinsic coordinates.
536    pub fn manifold_dimension(&self) -> usize {
537        let r = self.rank();
538        let p = self.output_dim();
539        r * (p - r)
540    }
541
542    /// Build the canonical-gauge frame for a `p × r` orthonormal `U` paired with
543    /// its `gauge_singular_values`. Enforces the column-sign convention
544    /// (largest-magnitude entry per column non-negative) so the span serializes
545    /// deterministically. The caller guarantees `U` is already column-orthonormal
546    /// and its columns are ordered by descending singular value.
547    pub(crate) fn from_oriented(
548        mut frame: Array2<f64>,
549        gauge_singular_values: Array1<f64>,
550    ) -> Self {
551        let (p, r) = frame.dim();
552        for col in 0..r {
553            // Sign-fix: make the largest-magnitude entry of each column
554            // non-negative so `U` and `−U` (same span) serialize identically.
555            let mut pivot_abs = 0.0_f64;
556            let mut pivot_val = 0.0_f64;
557            for row in 0..p {
558                let v = frame[[row, col]];
559                if v.abs() > pivot_abs {
560                    pivot_abs = v.abs();
561                    pivot_val = v;
562                }
563            }
564            if pivot_val < 0.0 {
565                for row in 0..p {
566                    frame[[row, col]] = -frame[[row, col]];
567                }
568            }
569        }
570        Self {
571            frame,
572            gauge_singular_values,
573        }
574    }
575
576    /// Closed-form streaming POLAR step (issue #972): given an accumulated
577    /// `p × r` cross-moment `Mcm` (a sum of decoder-target outer products that
578    /// pulls the frame toward the current column-span evidence), return the
579    /// orthogonal polar factor `U_new = polar(Mcm)`.
580    ///
581    /// `polar(M) = W Vᵀ` from the thin SVD `M = W Σ Vᵀ`: the nearest
582    /// column-orthonormal matrix to `M` in Frobenius norm, and the closed-form
583    /// MAP frame update on the Grassmannian. Runs OUTSIDE the border (an
584    /// `O(p r² )` thin SVD), so the border never carries the `p` factor.
585    /// `gauge_singular_values = Σ` records the canonical descending-σ ordering.
586    pub fn polar_update(cross_moment: ArrayView2<'_, f64>) -> Result<Self, String> {
587        let (p, r) = cross_moment.dim();
588        if p == 0 || r == 0 {
589            return Err("GrassmannFrame::polar_update: cross-moment must be non-empty".into());
590        }
591        if r > p {
592            return Err(format!(
593                "GrassmannFrame::polar_update: frame rank r={r} cannot exceed output dim p={p}"
594            ));
595        }
596        let owned = cross_moment.to_owned();
597        let (u_opt, sv, vt_opt) = owned
598            .svd(true, true)
599            .map_err(|e| format!("GrassmannFrame::polar_update: SVD failed: {e}"))?;
600        let w = u_opt.ok_or_else(|| {
601            "GrassmannFrame::polar_update: thin SVD returned no left factor".to_string()
602        })?;
603        let vt = vt_opt.ok_or_else(|| {
604            "GrassmannFrame::polar_update: thin SVD returned no right factor".to_string()
605        })?;
606        // `W` is `p × r`, `Vᵀ` is `r × r`. polar(M) = W·Vᵀ is `p × r`,
607        // column-orthonormal because both factors have orthonormal columns/rows.
608        let polar = fast_ab(&w, &vt);
609        Ok(Self::from_oriented(polar, sv))
610    }
611
612    /// Project a coordinate matrix `C_k` (`M_k × r`) back to the full decoder
613    /// `B_k = C_k · Uᵀ` (`M_k × p`) — the reconstruction used wherever the
614    /// full-`B` consumers (assembly, decode, smoothness pullback) read the
615    /// decoder. `fast_abt` computes `C_k · Uᵀ` without materializing `Uᵀ`.
616    pub fn reconstruct_decoder(&self, coords: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
617        if coords.ncols() != self.rank() {
618            return Err(format!(
619                "GrassmannFrame::reconstruct_decoder: coord cols {} must equal frame rank {}",
620                coords.ncols(),
621                self.rank()
622            ));
623        }
624        Ok(fast_abt(&coords.to_owned(), &self.frame))
625    }
626
627    /// Project a full decoder `B_k` (`M_k × p`) onto this frame, returning the
628    /// coordinate matrix `C_k = B_k · U` (`M_k × r`) that the border stores.
629    /// The frame is orthonormal so `U` is its own pseudo-inverse-from-the-right:
630    /// `C_k = B_k U` recovers the in-span coordinates exactly and discards the
631    /// component of `B_k` orthogonal to the frame (zero when `B_k`'s span lies in
632    /// `range(U)`, i.e. when the frame rank matched the decoder rank).
633    pub fn project_decoder(&self, decoder: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
634        if decoder.ncols() != self.output_dim() {
635            return Err(format!(
636                "GrassmannFrame::project_decoder: decoder cols {} must equal output dim {}",
637                decoder.ncols(),
638                self.output_dim()
639            ));
640        }
641        Ok(fast_ab(&decoder.to_owned(), &self.frame))
642    }
643
644    /// Largest principal angle (radians) between this frame's column span and
645    /// another `p × r'` orthonormal frame's span — the Grassmann geodesic
646    /// distance component used by the planted-atom recovery verifier (issue
647    /// #972).
648    ///
649    /// The naive formula `arccos(min σ_i(UᵀV))` loses half the available
650    /// precision for near-parallel spans: when `cos θ = 1 − ε` (the
651    /// `ε ~ fp64.eps` regime hit by a polar update of an already-orthonormal
652    /// frame), `arccos(1 − ε) ≈ √(2ε)` ≈ `1.49e-8`, so a planted span the
653    /// solver actually recovered to machine precision was being reported as
654    /// `O(√fp64.eps)` off. The stable form uses BOTH the cosines from
655    /// `M = UᵀV` (small-angle limit: `cos θ ≈ 1 − θ²/2`, sensitive to noise)
656    /// AND the sines from the orthogonal complement
657    /// `V_⊥ = (I − UUᵀ) V` (small-angle limit: `sin θ ≈ θ`, sensitive to the
658    /// quantity we actually want), then combines them with `atan2(sin, cos)`.
659    /// `atan2` returns a precise angle across the whole `[0, π/2]` interval
660    /// regardless of which leg is small — so an exactly-equal-frame test now
661    /// reports the genuine ~fp64.eps residual instead of an inflated
662    /// `√fp64.eps`. The pairing is exact because the singular values of
663    /// `M` and `V_⊥` are matched component-wise to the same principal
664    /// angle: `σ_r(M) = cos θ_max` and `σ_1(V_⊥) = sin θ_max`.
665    pub fn max_principal_angle(&self, other: ArrayView2<'_, f64>) -> Result<f64, String> {
666        if other.nrows() != self.output_dim() {
667            return Err(format!(
668                "GrassmannFrame::max_principal_angle: other rows {} must equal output dim {}",
669                other.nrows(),
670                self.output_dim()
671            ));
672        }
673        let other_owned = other.to_owned();
674        let overlap = fast_atb(&self.frame, &other_owned);
675        let (_u, sv_cos, _vt) = overlap
676            .svd(false, false)
677            .map_err(|e| format!("GrassmannFrame::max_principal_angle: cos-SVD failed: {e}"))?;
678        // V_⊥ = V − U·(UᵀV); its largest singular value is sin(θ_max).
679        let u_overlap = fast_ab(&self.frame, &overlap);
680        let v_perp = &other_owned - &u_overlap;
681        let (_u, sv_sin, _vt) = v_perp
682            .svd(false, false)
683            .map_err(|e| format!("GrassmannFrame::max_principal_angle: sin-SVD failed: {e}"))?;
684        // Smallest cosine and largest sine both correspond to θ_max; combine
685        // via atan2 for full precision across [0, π/2]. Clamp the SVD outputs
686        // into [0, 1] before pairing — both arise from singular values of
687        // matrices whose true norms are ≤ 1, so any drift above 1 or below
688        // 0 is pure floating-point noise.
689        let min_cos = sv_cos
690            .iter()
691            .copied()
692            .fold(1.0_f64, f64::min)
693            .clamp(0.0, 1.0);
694        let max_sin = sv_sin
695            .iter()
696            .copied()
697            .fold(0.0_f64, f64::max)
698            .clamp(0.0, 1.0);
699        Ok(max_sin.atan2(min_cos))
700    }
701}
702
703/// Streaming `p × r` cross-moment accumulator for the closed-form polar frame
704/// update (issue #972). Sums decoder-target outer products `Σ_i t_i c_iᵀ`
705/// (ambient target `t_i ∈ ℝ^p` against in-span coordinate `c_i ∈ ℝ^r`) so the
706/// frame can be re-polared from accumulated evidence WITHOUT re-touching the
707/// border. Accumulation is `O(p r)` per update and never forms a `p × p` matrix.
708#[derive(Debug, Clone)]
709pub struct GrassmannCrossMoment {
710    moment: Array2<f64>,
711}
712
713impl GrassmannCrossMoment {
714    /// Empty `p × r` accumulator.
715    pub fn new(output_dim: usize, rank: usize) -> Self {
716        Self {
717            moment: Array2::<f64>::zeros((output_dim, rank)),
718        }
719    }
720
721    /// Accumulate the full-batch cross-moment `Targetᵀ · Coords` where
722    /// `targets` is `(N × p)` ambient decoder targets and `coords` is `(N × r)`
723    /// in-span coordinates. `fast_atb` forms `Targetᵀ Coords` (`p × r`) directly.
724    pub fn accumulate(
725        &mut self,
726        targets: ArrayView2<'_, f64>,
727        coords: ArrayView2<'_, f64>,
728    ) -> Result<(), String> {
729        if targets.ncols() != self.moment.nrows() || coords.ncols() != self.moment.ncols() {
730            return Err(format!(
731                "GrassmannCrossMoment::accumulate: expected targets (·,{}) and coords (·,{}); \
732                 got (·,{}) and (·,{})",
733                self.moment.nrows(),
734                self.moment.ncols(),
735                targets.ncols(),
736                coords.ncols()
737            ));
738        }
739        if targets.nrows() != coords.nrows() {
740            return Err(format!(
741                "GrassmannCrossMoment::accumulate: targets rows {} must equal coords rows {}",
742                targets.nrows(),
743                coords.nrows()
744            ));
745        }
746        let block = fast_atb(&targets.to_owned(), &coords.to_owned());
747        self.moment += &block;
748        Ok(())
749    }
750
751    /// Read the accumulated `p × r` cross-moment.
752    pub fn moment(&self) -> ArrayView2<'_, f64> {
753        self.moment.view()
754    }
755
756    /// Re-polar the frame from the accumulated cross-moment (the streaming
757    /// closed-form step): `U_new = polar(Mcm)`.
758    pub fn polar_frame(&self) -> Result<GrassmannFrame, String> {
759        GrassmannFrame::polar_update(self.moment.view())
760    }
761}
762
763/// Verification helper (issue #972): recover the planted low-rank column span of
764/// an atom by polaring the decoder-target cross-moment and report the largest
765/// principal angle (radians) between the recovered frame and a planted
766/// orthonormal frame `planted` (`p × r`).
767///
768/// `targets` (`N × p`) are the ambient decoder targets and `coords` (`N × r`)
769/// the latent coordinates that generated them (`targets ≈ coords · plantedᵀ`).
770/// The closed-form polar of `Σ targetsᵀ coords` recovers `range(planted)`; a
771/// successful low-rank fit drives the returned angle to `0`. Used by the
772/// `planted_low_rank_frame_recovered_by_polar` test, and available to callers
773/// that want a runtime span-recovery diagnostic.
774pub fn grassmann_recover_planted_span_angle(
775    targets: ArrayView2<'_, f64>,
776    coords: ArrayView2<'_, f64>,
777    planted: ArrayView2<'_, f64>,
778) -> Result<f64, String> {
779    let p = targets.ncols();
780    let r = coords.ncols();
781    if planted.dim() != (p, r) {
782        return Err(format!(
783            "grassmann_recover_planted_span_angle: planted frame must be ({p}, {r}); got {:?}",
784            planted.dim()
785        ));
786    }
787    let mut cross = GrassmannCrossMoment::new(p, r);
788    cross.accumulate(targets, coords)?;
789    let frame = cross.polar_frame()?;
790    frame.max_principal_angle(planted)
791}
792
793/// Verification helper (issue #972): the factored arrow-Schur border dimension
794/// equals `Σ_k M_k · r_k` exactly. Returns `Ok(())` iff the invariant holds for
795/// `term`, else an explanatory error. Compiled-in so the border-size contract is
796/// checkable at runtime, not only in tests.
797pub fn grassmann_assert_border_dim_invariant(term: &SaeManifoldTerm) -> Result<(), String> {
798    let expected: usize = term
799        .atoms
800        .iter()
801        .map(|a| a.basis_size() * a.border_frame_rank())
802        .sum();
803    let got = term.factored_border_dim();
804    if got != expected {
805        return Err(format!(
806            "grassmann border-dim invariant violated: factored_border_dim() = {got}, \
807             expected Σ M_k·r_k = {expected}"
808        ));
809    }
810    Ok(())
811}