1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2
3use gam_linalg::faer_ndarray::{FaerSvd, fast_ab, fast_abt, fast_atb};
4use crate::manifold::SaeManifoldTerm;
5
6#[derive(Clone, Debug)]
7pub(crate) struct FrameProjection {
8 pub(crate) p: usize,
9 pub(crate) beta_offsets: Vec<usize>,
10 pub(crate) border_offsets: Vec<usize>,
11 pub(crate) basis_sizes: Vec<usize>,
12 pub(crate) ranks: Vec<usize>,
13 frames: Vec<Option<Array2<f64>>>,
14}
15
16impl FrameProjection {
17 pub(crate) fn new(term: &SaeManifoldTerm) -> Self {
18 Self {
19 p: term.output_dim(),
20 beta_offsets: term.beta_offsets(),
21 border_offsets: term.factored_border_offsets(),
22 basis_sizes: term.atoms.iter().map(|atom| atom.basis_size()).collect(),
23 ranks: term
24 .atoms
25 .iter()
26 .map(|atom| atom.border_frame_rank())
27 .collect(),
28 frames: term
29 .atoms
30 .iter()
31 .map(|atom| {
32 atom.decoder_frame
33 .as_ref()
34 .map(|frame| frame.frame().to_owned())
35 })
36 .collect(),
37 }
38 }
39
40 pub(crate) fn beta_dim(&self) -> usize {
41 self.basis_sizes.iter().sum::<usize>() * self.p
42 }
43
44 pub(crate) fn border_dim(&self) -> usize {
45 self.basis_sizes
46 .iter()
47 .zip(&self.ranks)
48 .map(|(m, r)| m * r)
49 .sum()
50 }
51
52 pub(crate) fn lift_border_vec(&self, border: ArrayView1<'_, f64>) -> Array1<f64> {
53 let mut out = Array1::<f64>::zeros(self.beta_dim());
54 for atom in 0..self.basis_sizes.len() {
55 self.lift_atom_vec_into(atom, border, out.view_mut());
56 }
57 out
58 }
59
60 pub(crate) fn project_border_vec(&self, beta: ArrayView1<'_, f64>) -> Array1<f64> {
61 let mut out = Array1::<f64>::zeros(self.border_dim());
62 for atom in 0..self.basis_sizes.len() {
63 self.project_atom_vec_into(atom, beta, out.view_mut(), 1.0);
64 }
65 out
66 }
67
68 pub(crate) fn lift_block(&self, atom: usize, block: ArrayView2<'_, f64>) -> Array2<f64> {
69 let m = self.basis_sizes[atom];
70 let r = self.ranks[atom];
71 if self.frames[atom].is_none() {
72 return block.to_owned();
73 }
74 let uk = self.frames[atom].as_ref().expect("framed atom has a frame");
75 let mut out = Array2::<f64>::zeros((m * self.p, m * self.p));
76 for b1 in 0..m {
77 for b2 in 0..m {
78 for c1 in 0..self.p {
79 for c2 in 0..self.p {
80 let mut acc = 0.0;
81 for j1 in 0..r {
82 for j2 in 0..r {
83 acc +=
84 uk[[c1, j1]] * block[[b1 * r + j1, b2 * r + j2]] * uk[[c2, j2]];
85 }
86 }
87 out[[b1 * self.p + c1, b2 * self.p + c2]] = acc;
88 }
89 }
90 }
91 }
92 out
93 }
94
95 pub(crate) fn project_block(&self, hbb: ArrayView2<'_, f64>) -> Array2<f64> {
96 let t = self.project_rows(hbb);
97 let mut out = Array2::<f64>::zeros((self.border_dim(), self.border_dim()));
98 for atom in 0..self.basis_sizes.len() {
99 self.project_block_left_atom(atom, t.view(), out.view_mut());
100 }
101 out
102 }
103
104 pub(crate) fn project_rows(&self, block: ArrayView2<'_, f64>) -> Array2<f64> {
105 let mut out = Array2::<f64>::zeros((block.nrows(), self.border_dim()));
106 for row in 0..block.nrows() {
107 let projected = self.project_border_vec(block.row(row));
108 out.row_mut(row).assign(&projected);
109 }
110 out
111 }
112
113 pub(crate) fn atom_border_range(&self, atom: usize) -> std::ops::Range<usize> {
114 let start = self.border_offsets[atom];
115 start..start + self.basis_sizes[atom] * self.ranks[atom]
116 }
117
118 pub(crate) fn lift_axis_into(
119 &self,
120 out: &mut Array1<f64>,
121 atom: usize,
122 basis_col: usize,
123 frame_col: usize,
124 ) {
125 let base = self.beta_offsets[atom] + basis_col * self.p;
126 match &self.frames[atom] {
127 None => out[base + frame_col] = 1.0,
128 Some(uk) => {
129 for out_col in 0..self.p {
130 out[base + out_col] = uk[[out_col, frame_col]];
131 }
132 }
133 }
134 }
135
136 pub(crate) fn lift_local_axis_into(
137 &self,
138 out: &mut Array1<f64>,
139 atom: usize,
140 basis_col: usize,
141 frame_col: usize,
142 ) {
143 let base = basis_col * self.p;
144 match &self.frames[atom] {
145 None => out[base + frame_col] = 1.0,
146 Some(uk) => {
147 for out_col in 0..self.p {
148 out[base + out_col] = uk[[out_col, frame_col]];
149 }
150 }
151 }
152 }
153
154 pub(crate) fn project_atom_vec_into(
155 &self,
156 atom: usize,
157 beta: ArrayView1<'_, f64>,
158 mut out: ndarray::ArrayViewMut1<'_, f64>,
159 scale: f64,
160 ) {
161 let m = self.basis_sizes[atom];
162 let r = self.ranks[atom];
163 let ob = self.beta_offsets[atom];
164 let oc = self.border_offsets[atom];
165 for basis_col in 0..m {
166 let base_b = ob + basis_col * self.p;
167 let base_c = oc + basis_col * r;
168 match &self.frames[atom] {
169 None => {
170 for j in 0..r {
171 out[base_c + j] += scale * beta[base_b + j];
172 }
173 }
174 Some(uk) => {
175 for j in 0..r {
176 let mut acc = 0.0;
177 for i in 0..self.p {
178 acc += uk[[i, j]] * beta[base_b + i];
179 }
180 out[base_c + j] += scale * acc;
181 }
182 }
183 }
184 }
185 }
186
187 pub(crate) fn project_local_atom_vec_into(
188 &self,
189 atom: usize,
190 beta: ArrayView1<'_, f64>,
191 out: ndarray::ArrayViewMut1<'_, f64>,
192 scale: f64,
193 ) {
194 self.project_atom_vec_into_with_base(atom, beta, out, scale, 0);
195 }
196
197 pub(crate) fn project_atom_vec_into_with_base(
198 &self,
199 atom: usize,
200 beta: ArrayView1<'_, f64>,
201 mut out: ndarray::ArrayViewMut1<'_, f64>,
202 scale: f64,
203 beta_base_offset: usize,
204 ) {
205 let m = self.basis_sizes[atom];
206 let r = self.ranks[atom];
207 let oc = self.border_offsets[atom];
208 for basis_col in 0..m {
209 let base_b = beta_base_offset + basis_col * self.p;
210 let base_c = oc + basis_col * r;
211 match &self.frames[atom] {
212 None => {
213 for j in 0..r {
214 out[base_c + j] += scale * beta[base_b + j];
215 }
216 }
217 Some(uk) => {
218 for j in 0..r {
219 let mut acc = 0.0;
220 for i in 0..self.p {
221 acc += uk[[i, j]] * beta[base_b + i];
222 }
223 out[base_c + j] += scale * acc;
224 }
225 }
226 }
227 }
228 }
229
230 pub(crate) fn lift_atom_vec_into(
231 &self,
232 atom: usize,
233 border: ArrayView1<'_, f64>,
234 mut out: ndarray::ArrayViewMut1<'_, f64>,
235 ) {
236 let m = self.basis_sizes[atom];
237 let r = self.ranks[atom];
238 let ob = self.beta_offsets[atom];
239 let oc = self.border_offsets[atom];
240 for basis_col in 0..m {
241 let base_b = ob + basis_col * self.p;
242 let base_c = oc + basis_col * r;
243 match &self.frames[atom] {
244 None => {
245 for i in 0..self.p {
246 out[base_b + i] = border[base_c + i];
247 }
248 }
249 Some(uk) => {
250 for i in 0..self.p {
251 let mut acc = 0.0;
252 for j in 0..r {
253 acc += uk[[i, j]] * border[base_c + j];
254 }
255 out[base_b + i] = acc;
256 }
257 }
258 }
259 }
260 }
261
262 pub(crate) fn accumulate_output_project(
263 &self,
264 atom: usize,
265 c_base: usize,
266 output: usize,
267 value: f64,
268 out: &mut [f64],
269 ) {
270 match &self.frames[atom] {
271 None => out[c_base + output] += value,
272 Some(uk) => {
273 let rank = self.ranks[atom];
274 let frame_row = uk.row(output);
275 let frame_slice = frame_row.as_slice().expect("frame rows are contiguous");
276 let out_slice = &mut out[c_base..c_base + rank];
277 for (slot, &u) in out_slice.iter_mut().zip(frame_slice.iter()) {
278 *slot += value * u;
279 }
280 }
281 }
282 }
283
284 pub(crate) fn output_variance(
285 &self,
286 atom: usize,
287 cov_c: ArrayView2<'_, f64>,
288 basis: ArrayView1<'_, f64>,
289 output: usize,
290 ) -> f64 {
291 let Some(uk) = &self.frames[atom] else {
292 return self.full_output_variance(atom, cov_c, basis, output);
293 };
294 let m = self.basis_sizes[atom];
295 let r = self.ranks[atom];
296 let mut var = 0.0;
297 for b1 in 0..m {
298 let phi1 = basis[b1];
299 if phi1 == 0.0 {
300 continue;
301 }
302 for b2 in 0..m {
303 let phi2 = basis[b2];
304 if phi2 == 0.0 {
305 continue;
306 }
307 for j1 in 0..r {
308 for j2 in 0..r {
309 var += phi1
310 * phi2
311 * uk[[output, j1]]
312 * cov_c[[b1 * r + j1, b2 * r + j2]]
313 * uk[[output, j2]];
314 }
315 }
316 }
317 }
318 var
319 }
320
321 pub(crate) fn full_output_variance(
322 &self,
323 atom: usize,
324 cov: ArrayView2<'_, f64>,
325 basis: ArrayView1<'_, f64>,
326 output: usize,
327 ) -> f64 {
328 let m = self.basis_sizes[atom];
329 let mut var = 0.0;
330 for b1 in 0..m {
331 let phi1 = basis[b1];
332 if phi1 == 0.0 {
333 continue;
334 }
335 for b2 in 0..m {
336 var += phi1 * basis[b2] * cov[[b1 * self.p + output, b2 * self.p + output]];
337 }
338 }
339 var
340 }
341
342 pub(crate) fn project_block_left_atom(
343 &self,
344 atom: usize,
345 t: ArrayView2<'_, f64>,
346 mut out: ndarray::ArrayViewMut2<'_, f64>,
347 ) {
348 let m = self.basis_sizes[atom];
349 let r = self.ranks[atom];
350 let ob = self.beta_offsets[atom];
351 let oc = self.border_offsets[atom];
352 for basis_col in 0..m {
353 let base_b = ob + basis_col * self.p;
354 let base_c = oc + basis_col * r;
355 match &self.frames[atom] {
356 None => {
357 for j in 0..r {
358 for c in 0..out.ncols() {
359 out[[base_c + j, c]] += t[[base_b + j, c]];
360 }
361 }
362 }
363 Some(uk) => {
364 for j in 0..r {
365 for c in 0..out.ncols() {
366 let mut acc = 0.0;
367 for i in 0..self.p {
368 acc += uk[[i, j]] * t[[base_b + i, c]];
369 }
370 out[[base_c + j, c]] += acc;
371 }
372 }
373 }
374 }
375 }
376 }
377}
378
379pub(crate) struct FramedDeviceArgs<'a> {
390 pub p: usize,
391 pub border_dim: usize,
392 pub border_offsets: &'a [usize],
393 pub ranks: &'a [usize],
394 pub basis_sizes: &'a [usize],
395 pub smooth_scaled_s: &'a [Array2<f64>],
396 pub frame_blocks: Vec<gam_solve::arrow_schur::FactoredFrameGBlock>,
397 pub rows: &'a [gam_solve::arrow_schur::ArrowRowBlock],
398}
399
400pub(crate) fn build_framed_device_sae_data(
401 args: FramedDeviceArgs<'_>,
402) -> gam_solve::arrow_schur::DeviceSaePcgData {
403 use gam_solve::arrow_schur::{DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock};
404 let FramedDeviceArgs {
405 p,
406 border_dim,
407 border_offsets,
408 ranks,
409 basis_sizes,
410 smooth_scaled_s,
411 frame_blocks,
412 rows,
413 } = args;
414 let n_atoms = ranks.len();
415 let mut smooth_blocks = Vec::with_capacity(n_atoms);
416 let mut smooth_ranks = Vec::with_capacity(n_atoms);
417 for k in 0..n_atoms {
418 smooth_blocks.push(DeviceSaeSmoothBlock {
419 global_offset: border_offsets[k],
420 factor_a: smooth_scaled_s[k].clone(),
421 });
422 smooth_ranks.push(ranks[k]);
423 }
424 let row_htbeta: Vec<Vec<f64>> = rows
425 .iter()
426 .map(|row| {
427 let (qi, w) = row.htbeta.dim();
428 if w != border_dim {
429 return Vec::new();
430 }
431 let mut flat = vec![0.0_f64; qi * w];
432 for c in 0..qi {
433 for a in 0..w {
434 flat[c * w + a] = row.htbeta[[c, a]];
435 }
436 }
437 flat
438 })
439 .collect();
440 DeviceSaePcgData {
441 p,
442 beta_dim: border_dim,
443 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
446 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
447 smooth_blocks,
448 sparse_g_blocks: Vec::new(),
449 frame: Some(DeviceSaeFrameData {
450 ranks: ranks.to_vec(),
451 basis_sizes: basis_sizes.to_vec(),
452 border_offsets: border_offsets.to_vec(),
453 frame_blocks,
454 smooth_ranks,
455 row_htbeta,
456 }),
457 }
458}
459
/// Rank cutoff for SAE frames.
// NOTE(review): not referenced in the visible part of this file; presumably a
// singular-value threshold — confirm against the call sites.
pub(crate) const SAE_FRAME_RANK_CUTOFF: f64 = 1.0e-7;

/// Minimum output dimension for automatic frame use.
// NOTE(review): meaning inferred from the name only — confirm against the
// call sites.
pub(crate) const SAE_FRAME_MIN_AUTO_OUTPUT_DIM: usize = 12;

/// Activation margin for SAE frames.
// NOTE(review): meaning inferred from the name only — confirm against the
// call sites.
pub(crate) const SAE_FRAME_ACTIVATION_MARGIN: f64 = 0.25;
480
481#[derive(Debug, Clone)]
500pub struct GrassmannFrame {
501 frame: Array2<f64>,
503 gauge_singular_values: Array1<f64>,
506}
507
508impl GrassmannFrame {
509 pub fn output_dim(&self) -> usize {
511 self.frame.nrows()
512 }
513
514 pub fn rank(&self) -> usize {
516 self.frame.ncols()
517 }
518
519 pub fn gauge_singular_values(&self) -> &Array1<f64> {
524 &self.gauge_singular_values
525 }
526
527 pub fn frame(&self) -> ArrayView2<'_, f64> {
529 self.frame.view()
530 }
531
532 pub fn manifold_dimension(&self) -> usize {
537 let r = self.rank();
538 let p = self.output_dim();
539 r * (p - r)
540 }
541
542 pub(crate) fn from_oriented(
548 mut frame: Array2<f64>,
549 gauge_singular_values: Array1<f64>,
550 ) -> Self {
551 let (p, r) = frame.dim();
552 for col in 0..r {
553 let mut pivot_abs = 0.0_f64;
556 let mut pivot_val = 0.0_f64;
557 for row in 0..p {
558 let v = frame[[row, col]];
559 if v.abs() > pivot_abs {
560 pivot_abs = v.abs();
561 pivot_val = v;
562 }
563 }
564 if pivot_val < 0.0 {
565 for row in 0..p {
566 frame[[row, col]] = -frame[[row, col]];
567 }
568 }
569 }
570 Self {
571 frame,
572 gauge_singular_values,
573 }
574 }
575
576 pub fn polar_update(cross_moment: ArrayView2<'_, f64>) -> Result<Self, String> {
587 let (p, r) = cross_moment.dim();
588 if p == 0 || r == 0 {
589 return Err("GrassmannFrame::polar_update: cross-moment must be non-empty".into());
590 }
591 if r > p {
592 return Err(format!(
593 "GrassmannFrame::polar_update: frame rank r={r} cannot exceed output dim p={p}"
594 ));
595 }
596 let owned = cross_moment.to_owned();
597 let (u_opt, sv, vt_opt) = owned
598 .svd(true, true)
599 .map_err(|e| format!("GrassmannFrame::polar_update: SVD failed: {e}"))?;
600 let w = u_opt.ok_or_else(|| {
601 "GrassmannFrame::polar_update: thin SVD returned no left factor".to_string()
602 })?;
603 let vt = vt_opt.ok_or_else(|| {
604 "GrassmannFrame::polar_update: thin SVD returned no right factor".to_string()
605 })?;
606 let polar = fast_ab(&w, &vt);
609 Ok(Self::from_oriented(polar, sv))
610 }
611
612 pub fn reconstruct_decoder(&self, coords: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
617 if coords.ncols() != self.rank() {
618 return Err(format!(
619 "GrassmannFrame::reconstruct_decoder: coord cols {} must equal frame rank {}",
620 coords.ncols(),
621 self.rank()
622 ));
623 }
624 Ok(fast_abt(&coords.to_owned(), &self.frame))
625 }
626
627 pub fn project_decoder(&self, decoder: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
634 if decoder.ncols() != self.output_dim() {
635 return Err(format!(
636 "GrassmannFrame::project_decoder: decoder cols {} must equal output dim {}",
637 decoder.ncols(),
638 self.output_dim()
639 ));
640 }
641 Ok(fast_ab(&decoder.to_owned(), &self.frame))
642 }
643
644 pub fn max_principal_angle(&self, other: ArrayView2<'_, f64>) -> Result<f64, String> {
666 if other.nrows() != self.output_dim() {
667 return Err(format!(
668 "GrassmannFrame::max_principal_angle: other rows {} must equal output dim {}",
669 other.nrows(),
670 self.output_dim()
671 ));
672 }
673 let other_owned = other.to_owned();
674 let overlap = fast_atb(&self.frame, &other_owned);
675 let (_u, sv_cos, _vt) = overlap
676 .svd(false, false)
677 .map_err(|e| format!("GrassmannFrame::max_principal_angle: cos-SVD failed: {e}"))?;
678 let u_overlap = fast_ab(&self.frame, &overlap);
680 let v_perp = &other_owned - &u_overlap;
681 let (_u, sv_sin, _vt) = v_perp
682 .svd(false, false)
683 .map_err(|e| format!("GrassmannFrame::max_principal_angle: sin-SVD failed: {e}"))?;
684 let min_cos = sv_cos
690 .iter()
691 .copied()
692 .fold(1.0_f64, f64::min)
693 .clamp(0.0, 1.0);
694 let max_sin = sv_sin
695 .iter()
696 .copied()
697 .fold(0.0_f64, f64::max)
698 .clamp(0.0, 1.0);
699 Ok(max_sin.atan2(min_cos))
700 }
701}
702
703#[derive(Debug, Clone)]
709pub struct GrassmannCrossMoment {
710 moment: Array2<f64>,
711}
712
713impl GrassmannCrossMoment {
714 pub fn new(output_dim: usize, rank: usize) -> Self {
716 Self {
717 moment: Array2::<f64>::zeros((output_dim, rank)),
718 }
719 }
720
721 pub fn accumulate(
725 &mut self,
726 targets: ArrayView2<'_, f64>,
727 coords: ArrayView2<'_, f64>,
728 ) -> Result<(), String> {
729 if targets.ncols() != self.moment.nrows() || coords.ncols() != self.moment.ncols() {
730 return Err(format!(
731 "GrassmannCrossMoment::accumulate: expected targets (·,{}) and coords (·,{}); \
732 got (·,{}) and (·,{})",
733 self.moment.nrows(),
734 self.moment.ncols(),
735 targets.ncols(),
736 coords.ncols()
737 ));
738 }
739 if targets.nrows() != coords.nrows() {
740 return Err(format!(
741 "GrassmannCrossMoment::accumulate: targets rows {} must equal coords rows {}",
742 targets.nrows(),
743 coords.nrows()
744 ));
745 }
746 let block = fast_atb(&targets.to_owned(), &coords.to_owned());
747 self.moment += █
748 Ok(())
749 }
750
751 pub fn moment(&self) -> ArrayView2<'_, f64> {
753 self.moment.view()
754 }
755
756 pub fn polar_frame(&self) -> Result<GrassmannFrame, String> {
759 GrassmannFrame::polar_update(self.moment.view())
760 }
761}
762
763pub fn grassmann_recover_planted_span_angle(
775 targets: ArrayView2<'_, f64>,
776 coords: ArrayView2<'_, f64>,
777 planted: ArrayView2<'_, f64>,
778) -> Result<f64, String> {
779 let p = targets.ncols();
780 let r = coords.ncols();
781 if planted.dim() != (p, r) {
782 return Err(format!(
783 "grassmann_recover_planted_span_angle: planted frame must be ({p}, {r}); got {:?}",
784 planted.dim()
785 ));
786 }
787 let mut cross = GrassmannCrossMoment::new(p, r);
788 cross.accumulate(targets, coords)?;
789 let frame = cross.polar_frame()?;
790 frame.max_principal_angle(planted)
791}
792
793pub fn grassmann_assert_border_dim_invariant(term: &SaeManifoldTerm) -> Result<(), String> {
798 let expected: usize = term
799 .atoms
800 .iter()
801 .map(|a| a.basis_size() * a.border_frame_rank())
802 .sum();
803 let got = term.factored_border_dim();
804 if got != expected {
805 return Err(format!(
806 "grassmann border-dim invariant violated: factored_border_dim() = {got}, \
807 expected Σ M_k·r_k = {expected}"
808 ));
809 }
810 Ok(())
811}