use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, s};
use faer::Side;
use gam_linalg::faer_ndarray::FaerEigh;
use crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
use crate::manifold::{SaeAtomBasisKind, SaeManifoldTerm};
/// Relative cutoff for the decoded-cloud spectrum: `decoded_cloud_spectrum` keeps only
/// eigenvalues strictly greater than this fraction of the largest eigenvalue.
const SPECTRUM_REL_FLOOR: f64 = 1.0e-9;
/// Slack, in dimensions, used by `AtomGeometryEntry::is_collapsed`: an atom counts as
/// collapsed only when its effective output dimension is more than this far below the ideal.
const DEGENERACY_DIM_TOLERANCE: f64 = 0.5;
/// Threshold used by `AtomGeometryEntry::is_effectively_linear`: a curved-topology atom whose
/// nonlinearity is below this value is reported as effectively linear.
const EFFECTIVELY_LINEAR_NONLINEARITY_FLOOR: f64 = 0.1;
/// Geometry diagnostics for a single atom.
///
/// The field descriptions below state how `atom_geometry_entry_from_parts` populates the
/// entry. Every `Option` field is `None` when the corresponding quantity could not be computed.
#[derive(Clone, Debug, PartialEq)]
pub struct AtomGeometryEntry {
/// Atom name, passed through unchanged.
pub name: String,
/// Label produced by `topology_name` for the atom's basis kind (e.g. `"periodic"`,
/// `"precomputed:<tag>"`).
pub topology: String,
/// Length of the last axis of the basis Jacobian.
pub latent_dim: usize,
/// Number of rows whose assignment mass exceeds `SAE_TRUST_ACTIVE_MASS_FLOOR`.
pub n_active: usize,
/// `exp(log_amplitude)`.
pub amplitude: f64,
/// Participation ratio `(Σλ)² / Σλ²` of the retained decoded-cloud eigenvalues.
/// `None` with fewer than two active rows or when no usable spectrum is available.
pub effective_output_dim: Option<f64>,
/// Value returned by `ideal_curve_dim` for this topology; only set when
/// `effective_output_dim` is set and the basis kind has a defined ideal.
pub ideal_curve_dim: Option<f64>,
/// `1 - effective_output_dim / ideal_curve_dim`, clamped to `[0, 1]`.
pub degeneracy: Option<f64>,
/// `1 - λ_max / Σλ` over the retained decoded-cloud eigenvalues, clamped to `[0, 1]`.
pub nonlinearity: Option<f64>,
/// Mass-weighted mean of the decoded tangent speed over active rows; `None` when the
/// active mass is not positive.
pub tangent_speed_mean: Option<f64>,
/// Mass-weighted coefficient of variation (standard deviation / mean) of the tangent
/// speed; `None` unless the mean speed is positive.
pub speed_cv: Option<f64>,
}
impl AtomGeometryEntry {
    /// Returns `true` when the effective output dimension is more than
    /// `DEGENERACY_DIM_TOLERANCE` below the ideal dimension.
    ///
    /// An entry missing either dimension is never reported as collapsed.
    pub fn is_collapsed(&self) -> bool {
        let (Some(effective), Some(ideal)) = (self.effective_output_dim, self.ideal_curve_dim)
        else {
            return false;
        };
        effective < ideal - DEGENERACY_DIM_TOLERANCE
    }

    /// Returns `true` when the atom's topology is a curved one yet its measured
    /// nonlinearity is below `EFFECTIVELY_LINEAR_NONLINEARITY_FLOOR`.
    ///
    /// An entry without a nonlinearity value is never reported as effectively linear.
    pub fn is_effectively_linear(&self) -> bool {
        if !topology_is_curved(&self.topology) {
            return false;
        }
        match self.nonlinearity {
            Some(nl) => nl < EFFECTIVELY_LINEAR_NONLINEARITY_FLOOR,
            None => false,
        }
    }
}
/// Collection of per-atom geometry diagnostics.
#[derive(Clone, Debug, PartialEq)]
pub struct AtomGeometryReport {
/// One entry per atom; `atom_geometry` fills this in the order of the model's atoms.
pub atoms: Vec<AtomGeometryEntry>,
}
impl AtomGeometryReport {
    /// Indices (positions in `atoms`) of every entry for which `is_collapsed` holds,
    /// in ascending order.
    pub fn collapsed_atoms(&self) -> Vec<usize> {
        self.atoms
            .iter()
            .enumerate()
            .filter_map(|(index, entry)| entry.is_collapsed().then_some(index))
            .collect()
    }

    /// Arithmetic mean of `effective_output_dim` over the entries that have one.
    ///
    /// Returns `None` when no entry carries an effective output dimension.
    pub fn mean_effective_output_dim(&self) -> Option<f64> {
        let mut running_sum = 0.0_f64;
        let mut measured = 0usize;
        for dim in self.atoms.iter().filter_map(|entry| entry.effective_output_dim) {
            running_sum += dim;
            measured += 1;
        }
        if measured == 0 {
            None
        } else {
            Some(running_sum / measured as f64)
        }
    }
}
/// Builds a geometry report with one entry per atom of `model`.
///
/// When `assignments_override` is given, its column `k` supplies the per-row masses for
/// atom `k`; otherwise the masses come from `model.assignment.assignments()`.
pub fn atom_geometry(
    model: &SaeManifoldTerm,
    assignments_override: Option<ArrayView2<'_, f64>>,
) -> AtomGeometryReport {
    // Declared before the branch so the model-derived matrix outlives the view taken of it.
    let model_assignments;
    let assignments = if let Some(view) = assignments_override {
        view
    } else {
        model_assignments = model.assignment.assignments();
        model_assignments.view()
    };

    let mut atoms = Vec::with_capacity(model.k_atoms());
    atoms.extend(model.atoms.iter().enumerate().map(|(column, atom)| {
        let masses = assignments.slice(s![.., column]);
        atom_geometry_entry_from_parts(
            atom.name.clone(),
            &atom.basis_kind,
            atom.log_amplitude,
            atom.basis_values.view(),
            atom.basis_jacobian.view(),
            atom.decoder_coefficients.view(),
            masses,
        )
    }));
    AtomGeometryReport { atoms }
}
/// Computes the geometry diagnostics of one atom from its raw parts.
///
/// Only rows whose mass exceeds `SAE_TRUST_ACTIVE_MASS_FLOOR` contribute; each contributes
/// with its mass as weight.
///
/// * `name` – stored verbatim in the entry.
/// * `basis_kind` – determines the topology label and the ideal dimension.
/// * `log_amplitude` – exponentiated to obtain the amplitude that scales tangent speeds.
/// * `basis_values` – one row of basis-function values per sample.
/// * `basis_jacobian` – indexed as `[sample, basis function, latent coordinate]`.
/// * `decoder` – one row per basis function, one column per output coordinate.
/// * `masses` – one assignment mass per sample.
///
/// Speed statistics are reported whenever the active mass is positive. The spectral
/// quantities additionally need at least two active rows and a usable decoded-cloud
/// spectrum; otherwise they stay `None`.
///
/// Indexing and matrix products panic if the argument shapes are inconsistent with the
/// layout above.
pub fn atom_geometry_entry_from_parts(
    name: String,
    basis_kind: &SaeAtomBasisKind,
    log_amplitude: f64,
    basis_values: ArrayView2<'_, f64>,
    basis_jacobian: ArrayView3<'_, f64>,
    decoder: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
) -> AtomGeometryEntry {
    let scale = log_amplitude.exp();
    let basis_len = basis_values.ncols();
    let n_latent = basis_jacobian.dim().2;
    // decoder · decoderᵀ: `jᵀ · gram · j` is the squared length of a decoded tangent `j`.
    let decoder_gram = decoder.dot(&decoder.t());

    // Mass-weighted accumulators over the active rows.
    let mut active_rows = 0usize;
    let mut mass_sum = 0.0_f64;
    let mut phi_sum = Array1::<f64>::zeros(basis_len);
    let mut phi_outer_sum = Array2::<f64>::zeros((basis_len, basis_len));
    let mut speed_sum = 0.0_f64;
    let mut speed_sq_sum = 0.0_f64;

    for row in 0..basis_values.nrows() {
        let mass = masses[row];
        // Negated `>` on purpose: a NaN mass is skipped together with the sub-floor ones.
        if !(mass > SAE_TRUST_ACTIVE_MASS_FLOOR) {
            continue;
        }
        active_rows += 1;
        mass_sum += mass;

        // First and second raw moments of the basis values.
        let phi = basis_values.slice(s![row, ..]);
        for (a, &phi_a) in phi.iter().enumerate() {
            let weighted = mass * phi_a;
            phi_sum[a] += weighted;
            for (b, &phi_b) in phi.iter().enumerate() {
                phi_outer_sum[[a, b]] += weighted * phi_b;
            }
        }

        // Squared decoded tangent length, summed over the latent coordinates.
        let tangent_sq = (0..n_latent).fold(0.0_f64, |acc, c| {
            let tangent = basis_jacobian.slice(s![row, .., c]);
            acc + tangent.dot(&decoder_gram.dot(&tangent))
        });
        let speed = scale * tangent_sq.max(0.0).sqrt();
        let weighted_speed = mass * speed;
        speed_sum += weighted_speed;
        speed_sq_sum += weighted_speed * speed;
    }

    let (tangent_speed_mean, speed_cv) = if mass_sum > 0.0 {
        let mean = speed_sum / mass_sum;
        let cv = (mean > 0.0).then(|| {
            // Clamp guards against a tiny negative variance from rounding.
            let variance = (speed_sq_sum / mass_sum - mean * mean).max(0.0);
            variance.sqrt() / mean
        });
        (Some(mean), cv)
    } else {
        (None, None)
    };

    let mut entry = AtomGeometryEntry {
        name,
        topology: topology_name(basis_kind),
        latent_dim: n_latent,
        n_active: active_rows,
        amplitude: scale,
        effective_output_dim: None,
        ideal_curve_dim: None,
        degeneracy: None,
        nonlinearity: None,
        tangent_speed_mean,
        speed_cv,
    };

    // A covariance needs at least two contributing rows.
    if active_rows < 2 || mass_sum <= 0.0 {
        return entry;
    }

    // Weighted covariance of the basis values: E[φφᵀ] − E[φ]E[φ]ᵀ.
    let phi_mean = &phi_sum / mass_sum;
    let phi_cov = Array2::from_shape_fn((basis_len, basis_len), |(a, b)| {
        phi_outer_sum[[a, b]] / mass_sum - phi_mean[a] * phi_mean[b]
    });

    let Some(spectrum) = decoded_cloud_spectrum(phi_cov.view(), decoder_gram.view()) else {
        return entry;
    };
    let trace: f64 = spectrum.iter().sum();
    if !(trace > 0.0) {
        return entry;
    }

    let largest = spectrum.iter().copied().fold(0.0_f64, f64::max);
    entry.nonlinearity = Some((1.0 - largest / trace).clamp(0.0, 1.0));

    let square_sum: f64 = spectrum.iter().map(|&l| l * l).sum();
    if square_sum > 0.0 {
        // Participation ratio of the spectrum.
        let eff = trace * trace / square_sum;
        entry.effective_output_dim = Some(eff);
        if let Some(ideal) = ideal_curve_dim(basis_kind, n_latent, basis_len, decoder.ncols()) {
            entry.ideal_curve_dim = Some(ideal);
            entry.degeneracy = Some((1.0 - eff / ideal).clamp(0.0, 1.0));
        }
    }
    entry
}
/// Eigenvalues of `S^{1/2} · gram · S^{1/2}`, where `S^{1/2}` is the symmetric square root
/// of `s_phi` (negative eigenvalues of `s_phi` are treated as zero).
///
/// Negative results are clipped to zero and only values strictly above
/// `SPECTRUM_REL_FLOOR` times the largest eigenvalue are returned. The result is an empty
/// vector when no eigenvalue is positive, and `None` when either eigendecomposition fails.
fn decoded_cloud_spectrum(
    s_phi: ArrayView2<'_, f64>,
    gram: ArrayView2<'_, f64>,
) -> Option<Vec<f64>> {
    let dim = s_phi.nrows();
    let (evals, evecs) = s_phi.eigh(Side::Lower).ok()?;
    let roots: Vec<f64> = evals.iter().map(|&e| e.max(0.0).sqrt()).collect();

    // Symmetric square root: V · diag(√λ) · Vᵀ.
    let sym_sqrt = Array2::from_shape_fn((dim, dim), |(a, b)| {
        (0..dim).fold(0.0_f64, |acc, r| {
            acc + evecs[[a, r]] * roots[r] * evecs[[b, r]]
        })
    });

    let sandwiched = sym_sqrt.dot(&gram).dot(&sym_sqrt);
    let (cloud_evals, _) = sandwiched.eigh(Side::Lower).ok()?;

    let peak = cloud_evals.iter().copied().fold(0.0_f64, f64::max);
    let mut retained = Vec::new();
    if peak > 0.0 {
        let threshold = SPECTRUM_REL_FLOOR * peak;
        for &lambda in cloud_evals.iter() {
            let clipped = lambda.max(0.0);
            if clipped > threshold {
                retained.push(clipped);
            }
        }
    }
    Some(retained)
}
/// Output dimension a healthy atom of the given kind is expected to span.
///
/// The per-topology value is capped by `m - 1` (but that cap is never below 1) and by `p`,
/// and the final result is never below 1. `Duchon` and `Precomputed` kinds have no ideal
/// and yield `None`.
fn ideal_curve_dim(kind: &SaeAtomBasisKind, latent_dim: usize, m: usize, p: usize) -> Option<f64> {
    let latent = latent_dim as f64;
    let uncapped = match kind {
        SaeAtomBasisKind::Duchon | SaeAtomBasisKind::Precomputed(_) => None,
        SaeAtomBasisKind::Periodic => Some(2.0),
        SaeAtomBasisKind::Torus => Some(2.0 * latent),
        SaeAtomBasisKind::Sphere => Some(latent + 1.0),
        SaeAtomBasisKind::Cylinder => Some(2.0 + (latent - 1.0).max(0.0)),
        SaeAtomBasisKind::Linear
        | SaeAtomBasisKind::EuclideanPatch
        | SaeAtomBasisKind::Poincare => Some(latent),
    }?;

    let basis_cap = (m as f64 - 1.0).max(1.0);
    let output_cap = p as f64;
    let capped = uncapped.min(basis_cap).min(output_cap);
    Some(capped.max(1.0))
}
/// Lower-case label for a basis kind; precomputed bases are rendered as
/// `precomputed:<tag>`.
fn topology_name(kind: &SaeAtomBasisKind) -> String {
    let label = match kind {
        // The only variant whose label is built at run time.
        SaeAtomBasisKind::Precomputed(tag) => return format!("precomputed:{tag}"),
        SaeAtomBasisKind::Duchon => "duchon",
        SaeAtomBasisKind::Periodic => "periodic",
        SaeAtomBasisKind::Sphere => "sphere",
        SaeAtomBasisKind::Torus => "torus",
        SaeAtomBasisKind::Cylinder => "cylinder",
        SaeAtomBasisKind::Linear => "linear",
        SaeAtomBasisKind::EuclideanPatch => "euclidean_patch",
        SaeAtomBasisKind::Poincare => "poincare",
    };
    label.to_string()
}
/// Whether a topology label names one of the curved topologies.
///
/// The comparison is exact (case- and whitespace-sensitive); any label not in the table,
/// including `precomputed:*`, is treated as not curved.
fn topology_is_curved(topology: &str) -> bool {
    const CURVED_LABELS: [&str; 6] = [
        "periodic", "sphere", "torus", "cylinder", "poincare", "duchon",
    ];
    CURVED_LABELS.iter().any(|&label| label == topology)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array3;

    /// Fixture: `n` evenly spaced samples of a circle with basis `[1, cos θ, sin θ]`, the
    /// matching Jacobian in θ, a decoder placing `cos`/`sin` on output axes 0 and 1 scaled
    /// by `radius`, and unit masses.
    fn planted_circle(
        n: usize,
        radius: f64,
    ) -> (Array2<f64>, Array3<f64>, Array2<f64>, Array1<f64>) {
        const BASIS_LEN: usize = 3;
        const OUTPUT_DIM: usize = 4;
        let angle = |row: usize| std::f64::consts::TAU * (row as f64) / (n as f64);

        let phi = Array2::from_shape_fn((n, BASIS_LEN), |(row, col)| match col {
            0 => 1.0,
            1 => angle(row).cos(),
            _ => angle(row).sin(),
        });
        let jac = Array3::from_shape_fn((n, BASIS_LEN, 1), |(row, col, _)| match col {
            0 => 0.0,
            1 => -angle(row).sin(),
            _ => angle(row).cos(),
        });

        let mut dec = Array2::<f64>::zeros((BASIS_LEN, OUTPUT_DIM));
        dec[[1, 0]] = radius;
        dec[[2, 1]] = radius;

        (phi, jac, dec, Array1::<f64>::ones(n))
    }

    /// Builds the entry of a periodic atom from owned fixture arrays.
    fn periodic_entry(
        name: &str,
        log_amplitude: f64,
        phi: &Array2<f64>,
        jac: &Array3<f64>,
        dec: &Array2<f64>,
        masses: &Array1<f64>,
    ) -> AtomGeometryEntry {
        atom_geometry_entry_from_parts(
            name.into(),
            &SaeAtomBasisKind::Periodic,
            log_amplitude,
            phi.view(),
            jac.view(),
            dec.view(),
            masses.view(),
        )
    }

    #[test]
    fn healthy_circle_reads_two_dimensional_and_curved() {
        let (phi, jac, dec, masses) = planted_circle(64, 2.0);
        let entry = periodic_entry("c", 0.0, &phi, &jac, &dec, &masses);

        let eff = entry.effective_output_dim.expect("circle has a spectrum");
        assert!(
            (eff - 2.0).abs() < 1.0e-6,
            "circle effective dim ≈ 2, got {eff}"
        );
        assert_eq!(entry.ideal_curve_dim, Some(2.0));
        assert!(
            entry.degeneracy.unwrap() < 1.0e-6,
            "healthy circle is not degenerate"
        );
        assert!(!entry.is_collapsed());

        let nl = entry.nonlinearity.unwrap();
        assert!(
            (nl - 0.5).abs() < 1.0e-6,
            "circle nonlinearity ≈ ½, got {nl}"
        );
        assert!(!entry.is_effectively_linear());

        assert!(
            entry.speed_cv.unwrap() < 1.0e-6,
            "circle arc speed is uniform"
        );
        assert!((entry.tangent_speed_mean.unwrap() - 2.0).abs() < 1.0e-6);
    }

    #[test]
    fn collapsed_circle_reads_one_dimensional_and_degenerate() {
        let (phi, jac, mut dec, masses) = planted_circle(64, 2.0);
        // Send both the cos and the sin component to output axis 0: the circle becomes a segment.
        dec.fill(0.0);
        dec[[1, 0]] = 2.0;
        dec[[2, 0]] = 2.0;
        let entry = periodic_entry("c", 0.0, &phi, &jac, &dec, &masses);

        let eff = entry.effective_output_dim.unwrap();
        assert!(eff < 1.5, "collapsed circle effective dim < 1.5, got {eff}");
        assert!(
            entry.degeneracy.unwrap() > 0.2,
            "collapse must post material degeneracy"
        );
        assert!(
            entry.is_collapsed(),
            "a circle flattened to a line is collapsed"
        );
    }

    #[test]
    fn amplitude_scales_speed_but_not_shape() {
        let (phi, jac, dec, masses) = planted_circle(64, 1.5);
        let base = periodic_entry("a", 0.0, &phi, &jac, &dec, &masses);
        let scaled = periodic_entry("a", 3.0_f64.ln(), &phi, &jac, &dec, &masses);

        let speed_gap =
            scaled.tangent_speed_mean.unwrap() - 3.0 * base.tangent_speed_mean.unwrap();
        assert!(speed_gap.abs() < 1.0e-9);

        let dim_gap = scaled.effective_output_dim.unwrap() - base.effective_output_dim.unwrap();
        assert!(dim_gap.abs() < 1.0e-9);

        let nonlinearity_gap = scaled.nonlinearity.unwrap() - base.nonlinearity.unwrap();
        assert!(nonlinearity_gap.abs() < 1.0e-9);
    }

    #[test]
    fn inactive_atom_degrades_to_none_not_zero() {
        let (phi, jac, dec, _) = planted_circle(64, 1.0);
        // All-zero masses: no row is active.
        let masses = Array1::<f64>::zeros(64);
        let entry = periodic_entry("dead", 0.0, &phi, &jac, &dec, &masses);

        assert_eq!(entry.n_active, 0);
        assert_eq!(entry.effective_output_dim, None);
        assert_eq!(entry.degeneracy, None);
        assert_eq!(entry.tangent_speed_mean, None);
        assert!(!entry.is_collapsed() && !entry.is_effectively_linear());
    }
}