use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::{FaerSvd, fast_ab, fast_abt, fast_atb};
use crate::terms::sae::manifold::SaeManifoldTerm;
/// Index bookkeeping plus optional per-atom frames used to move vectors and
/// blocks between the full coefficient layout and the factored border layout.
///
/// Every `Vec` field holds one entry per atom of the originating
/// `SaeManifoldTerm`, in atom order.
#[derive(Clone, Debug)]
pub(crate) struct FrameProjection {
/// Output dimension (`SaeManifoldTerm::output_dim()`); each basis column
/// occupies `p` consecutive entries of a full coefficient vector.
pub(crate) p: usize,
/// Start index of each atom's segment in a full coefficient vector.
pub(crate) beta_offsets: Vec<usize>,
/// Start index of each atom's segment in a factored border vector.
pub(crate) border_offsets: Vec<usize>,
/// Number of basis columns of each atom (`basis_size()`).
pub(crate) basis_sizes: Vec<usize>,
/// Border coordinates per basis column of each atom (`border_frame_rank()`).
pub(crate) ranks: Vec<usize>,
/// Owned copy of each atom's decoder frame, indexed `[output, frame column]`.
/// `None` when the atom has no decoder frame; the methods then copy
/// coordinates through unchanged.
frames: Vec<Option<Array2<f64>>>,
}
impl FrameProjection {
pub(crate) fn new(term: &SaeManifoldTerm) -> Self {
Self {
p: term.output_dim(),
beta_offsets: term.beta_offsets(),
border_offsets: term.factored_border_offsets(),
basis_sizes: term.atoms.iter().map(|atom| atom.basis_size()).collect(),
ranks: term
.atoms
.iter()
.map(|atom| atom.border_frame_rank())
.collect(),
frames: term
.atoms
.iter()
.map(|atom| {
atom.decoder_frame
.as_ref()
.map(|frame| frame.frame().to_owned())
})
.collect(),
}
}
/// Length of a full coefficient vector: the total number of basis columns
/// over all atoms times the output dimension `p`.
pub(crate) fn beta_dim(&self) -> usize {
    let basis_total: usize = self.basis_sizes.iter().copied().sum();
    basis_total * self.p
}
/// Length of a factored border vector: the sum over atoms of
/// `basis_size * rank`.
pub(crate) fn border_dim(&self) -> usize {
    let mut total: usize = 0;
    for (&basis_count, &rank) in self.basis_sizes.iter().zip(self.ranks.iter()) {
        total += basis_count * rank;
    }
    total
}
/// Expands a factored border vector into the full coefficient layout.
///
/// Allocates a zero vector of length `beta_dim()` and lets every atom write
/// its own segment through `lift_atom_vec_into`.
pub(crate) fn lift_border_vec(&self, border: ArrayView1<'_, f64>) -> Array1<f64> {
    let mut lifted = Array1::<f64>::zeros(self.beta_dim());
    let atom_count = self.basis_sizes.len();
    (0..atom_count).for_each(|atom| self.lift_atom_vec_into(atom, border, lifted.view_mut()));
    lifted
}
/// Projects a full coefficient vector onto the factored border layout.
///
/// Starts from a zero vector of length `border_dim()` and accumulates each
/// atom's contribution with unit scale via `project_atom_vec_into`.
pub(crate) fn project_border_vec(&self, beta: ArrayView1<'_, f64>) -> Array1<f64> {
    let mut projected = Array1::<f64>::zeros(self.border_dim());
    for (atom, _) in self.basis_sizes.iter().enumerate() {
        self.project_atom_vec_into(atom, beta, projected.view_mut(), 1.0);
    }
    projected
}
/// Lifts one atom's border-space block to the full coefficient layout.
///
/// With a frame `U` for the atom, every `(b1, b2)` sub-block `B` of size
/// `r x r` is replaced by the `p x p` product `U B Uᵀ`, giving an
/// `(m·p) x (m·p)` result. Without a frame the block is returned as an
/// owned copy, unchanged.
pub(crate) fn lift_block(&self, atom: usize, block: ArrayView2<'_, f64>) -> Array2<f64> {
    let basis_count = self.basis_sizes[atom];
    let rank = self.ranks[atom];
    let Some(frame) = self.frames[atom].as_ref() else {
        return block.to_owned();
    };

    let p = self.p;
    let lifted_dim = basis_count * p;
    let mut lifted = Array2::<f64>::zeros((lifted_dim, lifted_dim));
    for row_basis in 0..basis_count {
        let src_row = row_basis * rank;
        let dst_row = row_basis * p;
        for col_basis in 0..basis_count {
            let src_col = col_basis * rank;
            let dst_col = col_basis * p;
            for out_row in 0..p {
                for out_col in 0..p {
                    // Entry (out_row, out_col) of U · B · Uᵀ for this sub-block.
                    let mut sum = 0.0;
                    for jr in 0..rank {
                        for jc in 0..rank {
                            sum += frame[[out_row, jr]]
                                * block[[src_row + jr, src_col + jc]]
                                * frame[[out_col, jc]];
                        }
                    }
                    lifted[[dst_row + out_row, dst_col + out_col]] = sum;
                }
            }
        }
    }
    lifted
}
/// Projects a full-layout square block onto the border layout on both sides.
///
/// The rows are projected first (`project_rows`), then every atom applies
/// the left projection to that intermediate via `project_block_left_atom`.
pub(crate) fn project_block(&self, hbb: ArrayView2<'_, f64>) -> Array2<f64> {
    let right_projected = self.project_rows(hbb);
    let side = self.border_dim();
    let mut projected = Array2::<f64>::zeros((side, side));
    let atom_count = self.basis_sizes.len();
    for atom in 0..atom_count {
        self.project_block_left_atom(atom, right_projected.view(), projected.view_mut());
    }
    projected
}
/// Projects every row of `block` onto the border layout.
///
/// Returns a matrix with the same number of rows as `block` and
/// `border_dim()` columns; row `i` is `project_border_vec(block.row(i))`.
pub(crate) fn project_rows(&self, block: ArrayView2<'_, f64>) -> Array2<f64> {
    let mut projected = Array2::<f64>::zeros((block.nrows(), self.border_dim()));
    for (idx, source_row) in block.outer_iter().enumerate() {
        let reduced = self.project_border_vec(source_row);
        projected.row_mut(idx).assign(&reduced);
    }
    projected
}
/// Index range occupied by `atom` inside a factored border vector:
/// `basis_size * rank` entries starting at the atom's border offset.
pub(crate) fn atom_border_range(&self, atom: usize) -> std::ops::Range<usize> {
    let start = self.border_offsets[atom];
    let len = self.basis_sizes[atom] * self.ranks[atom];
    std::ops::Range {
        start,
        end: start + len,
    }
}
/// Writes the lift of a single border axis into a full-layout vector.
///
/// The target segment starts at `beta_offsets[atom] + basis_col * p`. With a
/// frame, column `frame_col` of the frame is copied into that segment;
/// without one, only entry `frame_col` of the segment is set to `1.0`.
/// Other entries of `out` are left untouched.
pub(crate) fn lift_axis_into(
    &self,
    out: &mut Array1<f64>,
    atom: usize,
    basis_col: usize,
    frame_col: usize,
) {
    let segment_start = self.beta_offsets[atom] + basis_col * self.p;
    if let Some(frame) = self.frames[atom].as_ref() {
        (0..self.p).for_each(|row| out[segment_start + row] = frame[[row, frame_col]]);
    } else {
        out[segment_start + frame_col] = 1.0;
    }
}
/// Atom-local variant of the axis lift: identical to the global one except
/// that the target segment starts at `basis_col * p`, i.e. `out` is indexed
/// relative to the atom rather than through `beta_offsets`.
pub(crate) fn lift_local_axis_into(
    &self,
    out: &mut Array1<f64>,
    atom: usize,
    basis_col: usize,
    frame_col: usize,
) {
    let segment_start = basis_col * self.p;
    if let Some(frame) = self.frames[atom].as_ref() {
        (0..self.p).for_each(|row| out[segment_start + row] = frame[[row, frame_col]]);
    } else {
        out[segment_start + frame_col] = 1.0;
    }
}
/// Accumulates the projection of one atom's segment of a full coefficient
/// vector into the atom's slot of a border vector.
///
/// For every basis column `b` and border coordinate `j` this adds
/// `scale * Σ_i U[i, j] * beta[ob + b*p + i]` to `out[oc + b*r + j]`, where
/// `ob`/`oc` are the atom's beta/border offsets. For an atom without a frame
/// it adds `scale * beta[ob + b*p + j]` instead.
///
/// This is the shared projection kernel with the source segment located via
/// `beta_offsets[atom]`; delegating keeps a single copy of the arithmetic
/// (the atom-local variant delegates the same way with base `0`).
pub(crate) fn project_atom_vec_into(
    &self,
    atom: usize,
    beta: ArrayView1<'_, f64>,
    out: ndarray::ArrayViewMut1<'_, f64>,
    scale: f64,
) {
    let beta_start = self.beta_offsets[atom];
    self.project_atom_vec_into_with_base(atom, beta, out, scale, beta_start);
}
/// Projects an atom-local coefficient vector (segment starting at index 0 of
/// `beta`) into the atom's slot of `out`, scaled by `scale`.
pub(crate) fn project_local_atom_vec_into(
    &self,
    atom: usize,
    beta: ArrayView1<'_, f64>,
    out: ndarray::ArrayViewMut1<'_, f64>,
    scale: f64,
) {
    // An atom-local vector has no leading segments to skip.
    const LOCAL_BASE: usize = 0;
    self.project_atom_vec_into_with_base(atom, beta, out, scale, LOCAL_BASE);
}
/// Projection kernel with an explicit start offset into `beta`.
///
/// For each basis column `b` the source segment begins at
/// `beta_base_offset + b*p` and the destination at `border_offsets[atom] + b*r`.
/// With a frame `U`, `out[dst + j] += scale * Σ_i U[i, j] * beta[src + i]`;
/// without one, `out[dst + j] += scale * beta[src + j]` for `j < r`.
pub(crate) fn project_atom_vec_into_with_base(
    &self,
    atom: usize,
    beta: ArrayView1<'_, f64>,
    mut out: ndarray::ArrayViewMut1<'_, f64>,
    scale: f64,
    beta_base_offset: usize,
) {
    let basis_count = self.basis_sizes[atom];
    let rank = self.ranks[atom];
    let border_start = self.border_offsets[atom];
    for basis_col in 0..basis_count {
        let src = beta_base_offset + basis_col * self.p;
        let dst = border_start + basis_col * rank;
        if let Some(frame) = &self.frames[atom] {
            for j in 0..rank {
                // Dot product of frame column `j` with the source segment.
                let dot = (0..self.p)
                    .fold(0.0_f64, |sum, i| sum + frame[[i, j]] * beta[src + i]);
                out[dst + j] += scale * dot;
            }
        } else {
            for j in 0..rank {
                out[dst + j] += scale * beta[src + j];
            }
        }
    }
}
/// Writes the lift of one atom's border coordinates into its segment of a
/// full coefficient vector (overwriting, not accumulating).
///
/// With a frame `U`, `out[ob + b*p + i] = Σ_j U[i, j] * border[oc + b*r + j]`.
/// Without a frame the `p` entries starting at `border[oc + b*r]` are copied.
///
/// NOTE(review): the unframed branch copies `p` entries per basis column
/// while the projection counterpart copies `ranks[atom]`; the two agree only
/// if an unframed atom's border rank equals `p` — confirm against the term.
pub(crate) fn lift_atom_vec_into(
    &self,
    atom: usize,
    border: ArrayView1<'_, f64>,
    mut out: ndarray::ArrayViewMut1<'_, f64>,
) {
    let basis_count = self.basis_sizes[atom];
    let rank = self.ranks[atom];
    let beta_start = self.beta_offsets[atom];
    let border_start = self.border_offsets[atom];
    for basis_col in 0..basis_count {
        let dst = beta_start + basis_col * self.p;
        let src = border_start + basis_col * rank;
        if let Some(frame) = &self.frames[atom] {
            for i in 0..self.p {
                let lifted = (0..rank)
                    .fold(0.0_f64, |sum, j| sum + frame[[i, j]] * border[src + j]);
                out[dst + i] = lifted;
            }
        } else {
            for i in 0..self.p {
                out[dst + i] = border[src + i];
            }
        }
    }
}
/// Accumulates the projection of a single output-coordinate contribution.
///
/// Without a frame, `value` is added to `out[c_base + output]`. With a frame
/// `U`, `value * U[output, j]` is added to `out[c_base + j]` for each of the
/// atom's `rank` border coordinates.
///
/// # Panics
///
/// Panics if `output` is not a valid frame row or the addressed part of
/// `out` is out of bounds.
pub(crate) fn accumulate_output_project(
    &self,
    atom: usize,
    c_base: usize,
    output: usize,
    value: f64,
    out: &mut [f64],
) {
    match &self.frames[atom] {
        None => out[c_base + output] += value,
        Some(uk) => {
            let rank = self.ranks[atom];
            // Iterate the row view directly rather than via `as_slice()`: the
            // stored frame keeps whatever memory layout its source had, and a
            // column-major frame has non-contiguous rows, for which
            // `as_slice()` returns `None`.
            let frame_row = uk.row(output);
            let slots = &mut out[c_base..c_base + rank];
            for (slot, &u) in slots.iter_mut().zip(frame_row.iter()) {
                *slot += value * u;
            }
        }
    }
}
/// Variance of output coordinate `output` implied by a border-space
/// covariance for one atom.
///
/// With a frame `U` this evaluates
/// `Σ φ[b1] φ[b2] U[output, j1] cov_c[b1*r + j1, b2*r + j2] U[output, j2]`
/// over all basis pairs and frame-column pairs, skipping basis entries that
/// are exactly zero. `cov_c` is indexed atom-locally (no border offset is
/// added). Atoms without a frame are handled by `full_output_variance`.
pub(crate) fn output_variance(
    &self,
    atom: usize,
    cov_c: ArrayView2<'_, f64>,
    basis: ArrayView1<'_, f64>,
    output: usize,
) -> f64 {
    let Some(frame) = self.frames[atom].as_ref() else {
        return self.full_output_variance(atom, cov_c, basis, output);
    };
    let basis_count = self.basis_sizes[atom];
    let rank = self.ranks[atom];

    // Basis columns with a non-zero weight; zero weights contribute nothing.
    let active: Vec<(usize, f64)> = (0..basis_count)
        .map(|b| (b, basis[b]))
        .filter(|&(_, phi)| phi != 0.0)
        .collect();

    let mut total = 0.0;
    for &(b1, phi1) in &active {
        for &(b2, phi2) in &active {
            for j1 in 0..rank {
                for j2 in 0..rank {
                    total += phi1
                        * phi2
                        * frame[[output, j1]]
                        * cov_c[[b1 * rank + j1, b2 * rank + j2]]
                        * frame[[output, j2]];
                }
            }
        }
    }
    total
}
/// Variance of output coordinate `output` from a covariance laid out in the
/// full (unframed) atom-local layout.
///
/// Evaluates `Σ φ[b1] φ[b2] cov[b1*p + output, b2*p + output]`, skipping
/// rows whose basis weight `φ[b1]` is exactly zero.
pub(crate) fn full_output_variance(
    &self,
    atom: usize,
    cov: ArrayView2<'_, f64>,
    basis: ArrayView1<'_, f64>,
    output: usize,
) -> f64 {
    let basis_count = self.basis_sizes[atom];
    let stride = self.p;
    let mut total = 0.0;
    for (b1, phi1) in (0..basis_count).map(|b| (b, basis[b])) {
        if phi1 == 0.0 {
            continue;
        }
        let row = b1 * stride + output;
        for b2 in 0..basis_count {
            total += phi1 * basis[b2] * cov[[row, b2 * stride + output]];
        }
    }
    total
}
/// Applies one atom's left projection to a row-projected matrix `t` and
/// accumulates the result into the atom's rows of `out`.
///
/// For each basis column `b`, border coordinate `j` and column `c` of `out`:
/// with a frame `U`, `out[oc + b*r + j, c] += Σ_i U[i, j] * t[ob + b*p + i, c]`;
/// without one, `out[oc + b*r + j, c] += t[ob + b*p + j, c]`.
pub(crate) fn project_block_left_atom(
    &self,
    atom: usize,
    t: ArrayView2<'_, f64>,
    mut out: ndarray::ArrayViewMut2<'_, f64>,
) {
    let basis_count = self.basis_sizes[atom];
    let rank = self.ranks[atom];
    let beta_start = self.beta_offsets[atom];
    let border_start = self.border_offsets[atom];
    for basis_col in 0..basis_count {
        let src = beta_start + basis_col * self.p;
        let dst = border_start + basis_col * rank;
        if let Some(frame) = &self.frames[atom] {
            for j in 0..rank {
                for col in 0..out.ncols() {
                    let dot = (0..self.p)
                        .fold(0.0_f64, |sum, i| sum + frame[[i, j]] * t[[src + i, col]]);
                    out[[dst + j, col]] += dot;
                }
            }
        } else {
            for j in 0..rank {
                for col in 0..out.ncols() {
                    out[[dst + j, col]] += t[[src + j, col]];
                }
            }
        }
    }
}
}
/// Rank cutoff for SAE frames (`1e-7`).
/// NOTE(review): not referenced in this part of the file; presumably a
/// singular-value threshold used when choosing a frame rank — confirm at the
/// use site.
pub(crate) const SAE_FRAME_RANK_CUTOFF: f64 = 1.0e-7;
/// Minimum output dimension (`12`) associated with automatic frame use.
/// NOTE(review): inferred from the name only; confirm at the use site.
pub(crate) const SAE_FRAME_MIN_AUTO_OUTPUT_DIM: usize = 12;
/// Activation margin (`0.25`) for SAE frames.
/// NOTE(review): the quantity this margin is compared against is not visible
/// here; confirm at the use site.
pub(crate) const SAE_FRAME_ACTIVATION_MARGIN: f64 = 0.25;
/// A frame matrix together with the singular values recorded when it was
/// built.
///
/// NOTE(review): frames produced by `polar_update` are the product of the
/// SVD factors of a cross-moment; nothing in this type itself enforces
/// orthonormal columns — confirm what callers of `from_oriented` pass in.
#[derive(Debug, Clone)]
pub struct GrassmannFrame {
/// Frame matrix: rows index output coordinates, columns index frame
/// directions (`output_dim()` x `rank()`).
frame: Array2<f64>,
/// Singular values supplied at construction; `polar_update` stores the
/// singular values of the cross-moment here.
gauge_singular_values: Array1<f64>,
}
impl GrassmannFrame {
/// Number of rows of the frame matrix (the output dimension).
pub fn output_dim(&self) -> usize {
    let (rows, _) = self.frame.dim();
    rows
}
/// Number of columns of the frame matrix (the frame rank).
pub fn rank(&self) -> usize {
    let (_, cols) = self.frame.dim();
    cols
}
pub fn gauge_singular_values(&self) -> &Array1<f64> {
&self.gauge_singular_values
}
pub fn frame(&self) -> ArrayView2<'_, f64> {
self.frame.view()
}
/// Dimension `r * (p - r)` of the Grassmann manifold of `r`-dimensional
/// subspaces of a `p`-dimensional space, with `p` the number of frame rows
/// and `r` the number of frame columns.
///
/// A frame with more columns than rows has no valid Grassmannian; the
/// subtraction saturates so such a frame reports `0` instead of underflowing
/// `usize` (which would panic in debug builds and wrap in release builds).
pub fn manifold_dimension(&self) -> usize {
    let (p, r) = self.frame.dim();
    r * p.saturating_sub(r)
}
/// Builds a frame from `frame` after fixing the sign of every column.
///
/// For each column the entry of largest absolute value is located (the first
/// one on ties); if that entry is negative the whole column is negated, so
/// the dominant entry of every column ends up non-negative. The singular
/// values are stored as given.
pub(crate) fn from_oriented(
    mut frame: Array2<f64>,
    gauge_singular_values: Array1<f64>,
) -> Self {
    let (_, col_count) = frame.dim();
    for col in 0..col_count {
        let mut dominant_abs = 0.0_f64;
        let mut dominant = 0.0_f64;
        for &entry in frame.column(col).iter() {
            let magnitude = entry.abs();
            if magnitude > dominant_abs {
                dominant_abs = magnitude;
                dominant = entry;
            }
        }
        if dominant < 0.0 {
            frame.column_mut(col).mapv_inplace(|entry| -entry);
        }
    }
    Self {
        frame,
        gauge_singular_values,
    }
}
/// Builds a frame from a `p x r` cross-moment matrix.
///
/// The cross-moment is decomposed by SVD, the left factor is multiplied by
/// the right factor (`fast_ab(&w, &vt)`), and the product is sign-oriented
/// by `from_oriented`. The singular values are kept alongside the frame.
///
/// # Errors
///
/// Returns an error when the cross-moment has a zero dimension, when
/// `r > p`, when the SVD fails, or when the SVD omits either factor.
pub fn polar_update(cross_moment: ArrayView2<'_, f64>) -> Result<Self, String> {
    let (p, r) = cross_moment.dim();
    if p == 0 || r == 0 {
        return Err(String::from(
            "GrassmannFrame::polar_update: cross-moment must be non-empty",
        ));
    }
    if r > p {
        return Err(format!(
            "GrassmannFrame::polar_update: frame rank r={r} cannot exceed output dim p={p}"
        ));
    }

    let owned = cross_moment.to_owned();
    let (left, singular_values, right_t) = match owned.svd(true, true) {
        Ok(factors) => factors,
        Err(e) => return Err(format!("GrassmannFrame::polar_update: SVD failed: {e}")),
    };
    let Some(w) = left else {
        return Err("GrassmannFrame::polar_update: thin SVD returned no left factor".to_string());
    };
    let Some(vt) = right_t else {
        return Err("GrassmannFrame::polar_update: thin SVD returned no right factor".to_string());
    };

    let polar = fast_ab(&w, &vt);
    Ok(Self::from_oriented(polar, singular_values))
}
/// Maps frame coordinates back to a decoder via `fast_abt(coords, frame)`
/// (presumably `coords · frameᵀ`, one decoder row per coordinate row).
///
/// # Errors
///
/// Returns an error when `coords` does not have exactly `rank()` columns.
pub fn reconstruct_decoder(&self, coords: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    if coords.ncols() == self.rank() {
        let owned_coords = coords.to_owned();
        Ok(fast_abt(&owned_coords, &self.frame))
    } else {
        Err(format!(
            "GrassmannFrame::reconstruct_decoder: coord cols {} must equal frame rank {}",
            coords.ncols(),
            self.rank()
        ))
    }
}
/// Expresses a decoder in frame coordinates via `fast_ab(decoder, frame)`
/// (presumably `decoder · frame`, one coordinate row per decoder row).
///
/// # Errors
///
/// Returns an error when `decoder` does not have exactly `output_dim()`
/// columns.
pub fn project_decoder(&self, decoder: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    if decoder.ncols() == self.output_dim() {
        let owned_decoder = decoder.to_owned();
        Ok(fast_ab(&owned_decoder, &self.frame))
    } else {
        Err(format!(
            "GrassmannFrame::project_decoder: decoder cols {} must equal output dim {}",
            decoder.ncols(),
            self.output_dim()
        ))
    }
}
/// Largest principal angle (radians) between this frame and `other`.
///
/// The smallest singular value of `frameᵀ · other` is taken as the cosine
/// and the largest singular value of the residual
/// `other − frame · (frameᵀ · other)` as the sine; both are clamped to
/// `[0, 1]` and combined with `atan2(sin, cos)`.
///
/// NOTE(review): reading the singular values as cosines/sines assumes both
/// matrices have orthonormal columns; this is not checked here.
///
/// # Errors
///
/// Returns an error when `other` does not have `output_dim()` rows or when
/// either SVD fails.
pub fn max_principal_angle(&self, other: ArrayView2<'_, f64>) -> Result<f64, String> {
    if other.nrows() != self.output_dim() {
        return Err(format!(
            "GrassmannFrame::max_principal_angle: other rows {} must equal output dim {}",
            other.nrows(),
            self.output_dim()
        ));
    }
    let candidate = other.to_owned();

    // Cosine side: singular values of the overlap matrix.
    let overlap = fast_atb(&self.frame, &candidate);
    let sv_cos = match overlap.svd(false, false) {
        Ok((_, values, _)) => values,
        Err(e) => {
            return Err(format!(
                "GrassmannFrame::max_principal_angle: cos-SVD failed: {e}"
            ));
        }
    };

    // Sine side: singular values of the part of `other` outside the frame's span.
    let in_span = fast_ab(&self.frame, &overlap);
    let residual = &candidate - &in_span;
    let sv_sin = match residual.svd(false, false) {
        Ok((_, values, _)) => values,
        Err(e) => {
            return Err(format!(
                "GrassmannFrame::max_principal_angle: sin-SVD failed: {e}"
            ));
        }
    };

    let mut min_cos = 1.0_f64;
    for &cosine in sv_cos.iter() {
        min_cos = min_cos.min(cosine);
    }
    let min_cos = min_cos.clamp(0.0, 1.0);

    let mut max_sin = 0.0_f64;
    for &sine in sv_sin.iter() {
        max_sin = max_sin.max(sine);
    }
    let max_sin = max_sin.clamp(0.0, 1.0);

    Ok(max_sin.atan2(min_cos))
}
}
/// Running cross-moment accumulator from which a `GrassmannFrame` can be
/// derived via `polar_frame`.
#[derive(Debug, Clone)]
pub struct GrassmannCrossMoment {
/// Accumulated moment, shaped `(output_dim, rank)`; starts at zero and grows
/// by one `fast_atb(targets, coords)` block per `accumulate` call.
moment: Array2<f64>,
}
impl GrassmannCrossMoment {
/// Creates an accumulator whose moment is an all-zero
/// `output_dim x rank` matrix.
pub fn new(output_dim: usize, rank: usize) -> Self {
    let moment = Array2::<f64>::zeros((output_dim, rank));
    Self { moment }
}
pub fn accumulate(
&mut self,
targets: ArrayView2<'_, f64>,
coords: ArrayView2<'_, f64>,
) -> Result<(), String> {
if targets.ncols() != self.moment.nrows() || coords.ncols() != self.moment.ncols() {
return Err(format!(
"GrassmannCrossMoment::accumulate: expected targets (·,{}) and coords (·,{}); \
got (·,{}) and (·,{})",
self.moment.nrows(),
self.moment.ncols(),
targets.ncols(),
coords.ncols()
));
}
if targets.nrows() != coords.nrows() {
return Err(format!(
"GrassmannCrossMoment::accumulate: targets rows {} must equal coords rows {}",
targets.nrows(),
coords.nrows()
));
}
let block = fast_atb(&targets.to_owned(), &coords.to_owned());
self.moment += █
Ok(())
}
pub fn moment(&self) -> ArrayView2<'_, f64> {
self.moment.view()
}
/// Derives a frame from the accumulated moment by handing a view of it to
/// `GrassmannFrame::polar_update`; errors from that call are passed through.
pub fn polar_frame(&self) -> Result<GrassmannFrame, String> {
    let accumulated = self.moment.view();
    GrassmannFrame::polar_update(accumulated)
}
}
/// Estimates a frame from `targets`/`coords` and reports how far its span is
/// from a planted frame.
///
/// A fresh cross-moment is accumulated from the single batch, turned into a
/// frame with `polar_frame`, and compared against `planted` with
/// `max_principal_angle`.
///
/// # Errors
///
/// Returns an error when `planted` is not shaped
/// `(targets.ncols(), coords.ncols())`, or when accumulation, the polar
/// update, or the angle computation fails.
pub fn grassmann_recover_planted_span_angle(
    targets: ArrayView2<'_, f64>,
    coords: ArrayView2<'_, f64>,
    planted: ArrayView2<'_, f64>,
) -> Result<f64, String> {
    let output_dim = targets.ncols();
    let rank = coords.ncols();
    if planted.dim() != (output_dim, rank) {
        return Err(format!(
            "grassmann_recover_planted_span_angle: planted frame must be ({p}, {r}); got {:?}",
            planted.dim(),
            p = output_dim,
            r = rank
        ));
    }

    let mut accumulator = GrassmannCrossMoment::new(output_dim, rank);
    accumulator.accumulate(targets, coords)?;
    let recovered = accumulator.polar_frame()?;
    recovered.max_principal_angle(planted)
}
/// Checks that the term's reported factored border dimension equals the sum
/// over atoms of `basis_size() * border_frame_rank()`.
///
/// # Errors
///
/// Returns a message containing both numbers when they differ.
pub fn grassmann_assert_border_dim_invariant(term: &SaeManifoldTerm) -> Result<(), String> {
    let mut from_atoms: usize = 0;
    for atom in term.atoms.iter() {
        from_atoms += atom.basis_size() * atom.border_frame_rank();
    }
    let reported = term.factored_border_dim();
    if reported == from_atoms {
        Ok(())
    } else {
        Err(format!(
            "grassmann border-dim invariant violated: factored_border_dim() = {got}, \
             expected Σ M_k·r_k = {expected}",
            got = reported,
            expected = from_atoms
        ))
    }
}