use super::*;
/// Basis family of an SAE manifold atom.
///
/// The variant decides which latent manifold `latent_manifold` reports and
/// whether `projection_seed_grid` can produce a seed grid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeAtomBasisKind {
/// Mapped to a Euclidean latent manifold; no projection seed grid.
Duchon,
/// Unit-period circle (or a product of them when `latent_dim != 1`); seeded from the torus grid.
Periodic,
/// Interval `lo = -pi/2, hi = pi/2` times a circle of period `2*pi`; seed grid only when `latent_dim == 2`.
Sphere,
/// Same latent manifold and seed grid as `Periodic` in the visible code.
Torus,
/// Unit-period circle times a Euclidean axis; seed grid only when `latent_dim == 2`.
Cylinder,
/// Mapped to a Euclidean latent manifold; no projection seed grid.
Linear,
/// Mapped to a Euclidean latent manifold; no projection seed grid.
EuclideanPatch,
/// Mapped to a Euclidean latent manifold; no projection seed grid.
Poincare,
/// Mapped to a Euclidean latent manifold; no projection seed grid.
/// NOTE(review): the `String` payload is presumably an identifier for an
/// externally supplied basis; its meaning is not visible here, so confirm.
Precomputed(String),
}
impl SaeAtomBasisKind {
    /// Returns the latent manifold associated with this basis kind.
    ///
    /// * `Periodic` / `Torus`: a single unit-period circle when `latent_dim == 1`,
    ///   otherwise a product of `latent_dim` unit-period circles.
    /// * `Sphere`: an interval with `lo = -pi/2`, `hi = pi/2` times a circle of
    ///   period `2*pi` (independent of `latent_dim`).
    /// * `Cylinder`: a unit-period circle times a Euclidean axis (independent of
    ///   `latent_dim`).
    /// * Every other kind: Euclidean.
    pub(crate) fn latent_manifold(&self, latent_dim: usize) -> LatentManifold {
        use std::f64::consts::{FRAC_PI_2, TAU};

        let unit_circle = || LatentManifold::Circle { period: 1.0 };
        match self {
            // Both periodic families share the same circle / product-of-circles layout.
            Self::Periodic | Self::Torus => match latent_dim {
                1 => unit_circle(),
                dims => LatentManifold::Product((0..dims).map(|_| unit_circle()).collect()),
            },
            Self::Sphere => LatentManifold::Product(vec![
                LatentManifold::Interval {
                    lo: -FRAC_PI_2,
                    hi: FRAC_PI_2,
                },
                LatentManifold::Circle { period: TAU },
            ]),
            Self::Cylinder => {
                LatentManifold::Product(vec![unit_circle(), LatentManifold::Euclidean])
            }
            Self::Linear
            | Self::Duchon
            | Self::EuclideanPatch
            | Self::Poincare
            | Self::Precomputed(_) => LatentManifold::Euclidean,
        }
    }

    /// Returns a grid of latent seed coordinates for this basis kind, if one exists.
    ///
    /// `Periodic` and `Torus` delegate to `torus_projection_seed_grid` for any
    /// `latent_dim`; `Sphere` and `Cylinder` provide a grid only when
    /// `latent_dim == 2`; all other kinds return `None`.
    pub(crate) fn projection_seed_grid(
        &self,
        latent_dim: usize,
        resolution: usize,
    ) -> Option<Array2<f64>> {
        match (self, latent_dim) {
            (Self::Periodic | Self::Torus, _) => {
                torus_projection_seed_grid(latent_dim, resolution)
            }
            (Self::Sphere, 2) => sphere_projection_seed_grid(resolution),
            (Self::Cylinder, 2) => cylinder_projection_seed_grid(resolution),
            (Self::Sphere | Self::Cylinder, _) => None,
            (
                Self::Linear
                | Self::Duchon
                | Self::EuclideanPatch
                | Self::Poincare
                | Self::Precomputed(_),
                _,
            ) => None,
        }
    }
}
/// Builds a latitude/longitude seed grid with `r * r` rows, `r = max(resolution, 2)`.
///
/// Column 0 holds the latitude, sampled at cell centres
/// `-pi/2 + pi * (i + 0.5) / r`; column 1 holds the longitude
/// `-pi + 2*pi * j / r`. Rows are ordered latitude-major. Always returns `Some`.
pub(crate) fn sphere_projection_seed_grid(resolution: usize) -> Option<Array2<f64>> {
    use std::f64::consts::PI;

    let side = resolution.max(2);
    let mut seeds = Array2::<f64>::zeros((side * side, 2));
    let mut flat = 0usize;
    for lat_idx in 0..side {
        let latitude = -PI / 2.0 + PI * (lat_idx as f64 + 0.5) / side as f64;
        for lon_idx in 0..side {
            let longitude = -PI + 2.0 * PI * (lon_idx as f64) / side as f64;
            seeds[[flat, 0]] = latitude;
            seeds[[flat, 1]] = longitude;
            flat += 1;
        }
    }
    Some(seeds)
}
/// Builds a seed grid of `r = max(resolution, 2)` rows for the cylinder chart.
///
/// Column 0 takes the values `i / r` for `i = 0..r`; column 1 is `0.0` in every
/// row. Always returns `Some`.
pub(crate) fn cylinder_projection_seed_grid(resolution: usize) -> Option<Array2<f64>> {
    let r = resolution.max(2);
    let seeds = Array2::<f64>::from_shape_fn((r, 2), |(i, column)| {
        if column == 0 {
            i as f64 / r as f64
        } else {
            0.0
        }
    });
    Some(seeds)
}
/// Builds a regular seed grid on the unit torus with at most 4096 points.
///
/// Each of the `latent_dim` axes is sampled at `k / per_axis` for
/// `k = 0..per_axis`, where `per_axis` is the largest value not exceeding
/// `max(resolution, 2)` such that `per_axis^latent_dim <= 4096`. Rows enumerate
/// the Cartesian product with the last axis varying fastest.
///
/// Returns `None` when `latent_dim == 0`, when `latent_dim >= usize::BITS`, or
/// when even two points per axis would exceed the point cap.
pub(crate) fn torus_projection_seed_grid(
    latent_dim: usize,
    resolution: usize,
) -> Option<Array2<f64>> {
    const MAX_GRID_POINTS: usize = 4096;

    if latent_dim == 0 || latent_dim >= usize::BITS as usize {
        return None;
    }
    // The shift cannot overflow because `latent_dim < usize::BITS` was checked above.
    if (1usize << latent_dim) > MAX_GRID_POINTS {
        return None;
    }
    let exponent = latent_dim as u32;

    // Any admissible `per_axis` satisfies `per_axis <= per_axis^latent_dim <= MAX_GRID_POINTS`,
    // so starting the search at the cap yields the same answer as starting at
    // `resolution` while bounding the loop to a few thousand steps even for a
    // huge `resolution`.
    let mut per_axis = resolution.max(2).min(MAX_GRID_POINTS);
    // Terminates at `per_axis >= 2`: `2^latent_dim <= MAX_GRID_POINTS` was checked above.
    while per_axis.saturating_pow(exponent) > MAX_GRID_POINTS {
        per_axis -= 1;
    }

    // Bounded by MAX_GRID_POINTS, so the power cannot overflow.
    let total = per_axis.pow(exponent);
    let mut grid = Array2::<f64>::zeros((total, latent_dim));
    // Mixed-radix counter over the axes; the last axis is the least significant digit.
    let mut idx = vec![0usize; latent_dim];
    for flat in 0..total {
        for (axis, &digit) in idx.iter().enumerate() {
            grid[[flat, axis]] = digit as f64 / per_axis as f64;
        }
        for digit in idx.iter_mut().rev() {
            *digit += 1;
            if *digit < per_axis {
                break;
            }
            *digit = 0;
        }
    }
    Some(grid)
}
/// Value, first and second derivative, and squared-distance equivalent of a
/// single-axis prior term, as produced by [`ArdAxisPrior::eval`].
#[derive(Clone, Copy, Debug)]
pub(crate) struct ArdAxisPrior {
    /// Prior value at `t`.
    pub(crate) value: f64,
    /// Derivative of `value` with respect to `t`.
    pub(crate) grad: f64,
    /// Second derivative of `value` with respect to `t`.
    pub(crate) hess: f64,
    /// Quantity `s` such that `value == 0.5 * alpha * s` (up to rounding).
    pub(crate) sq_equiv: f64,
}

impl ArdAxisPrior {
    /// Evaluates the prior term for precision `alpha` at coordinate `t`.
    ///
    /// * `period == None`: the quadratic `0.5 * alpha * t^2`.
    /// * `period == Some(p)`: with `kappa = 2*pi / p`, the periodic form
    ///   `(alpha / kappa^2) * (1 - cos(kappa * t))`, whose derivative is
    ///   `(alpha / kappa) * sin(kappa * t)` and second derivative
    ///   `alpha * cos(kappa * t)`.
    pub(crate) fn eval(alpha: f64, t: f64, period: Option<f64>) -> Self {
        let Some(p) = period else {
            return Self {
                value: 0.5 * alpha * t * t,
                grad: alpha * t,
                hess: alpha,
                sq_equiv: t * t,
            };
        };

        let kappa = std::f64::consts::TAU / p;
        let kappa_sq = kappa * kappa;
        let (sine, cosine) = (kappa * t).sin_cos();
        let versine = 1.0 - cosine;
        Self {
            value: (alpha / kappa_sq) * versine,
            grad: (alpha / kappa) * sine,
            hess: alpha * cosine,
            sq_equiv: (2.0 / kappa_sq) * versine,
        }
    }
}
/// Evaluates the degree-8 polynomial in `y = 3.75 / ax` that `bessel_i0`
/// multiplies by `exp(ax) / sqrt(ax)` in its large-argument branch.
///
/// The coefficients match the classic Abramowitz & Stegun polynomial
/// approximation for `sqrt(x) * exp(-x) * I0(x)`. The caller is expected to
/// pass `ax >= 3.75`; no range check is performed here.
pub(crate) fn bessel_i0_scaled_poly(ax: f64) -> f64 {
    // Coefficients in ascending powers of `y`.
    const COEFFS: [f64; 9] = [
        0.39894228,
        0.01328592,
        0.00225319,
        -0.00157565,
        0.00916281,
        -0.02057706,
        0.02635537,
        -0.01647633,
        0.00392377,
    ];
    let y = 3.75 / ax;
    // Horner evaluation from the highest-order coefficient downwards.
    COEFFS[..8]
        .iter()
        .rev()
        .fold(COEFFS[8], |acc, &c| c + y * acc)
}
/// Evaluates the degree-8 polynomial in `y = 3.75 / ax` that `bessel_i1`
/// multiplies by `exp(ax) / sqrt(ax)` in its large-argument branch.
///
/// The coefficients match the classic Abramowitz & Stegun polynomial
/// approximation for `sqrt(x) * exp(-x) * I1(x)`. The caller is expected to
/// pass `ax >= 3.75`; no range check is performed here.
pub(crate) fn bessel_i1_scaled_poly(ax: f64) -> f64 {
    // Coefficients in ascending powers of `y`.
    const COEFFS: [f64; 9] = [
        0.39894228,
        -0.03988024,
        -0.00362018,
        0.00163801,
        -0.01031555,
        0.02282967,
        -0.02895312,
        0.01787654,
        -0.00420059,
    ];
    let y = 3.75 / ax;
    // Horner evaluation from the highest-order coefficient downwards.
    COEFFS[..8]
        .iter()
        .rev()
        .fold(COEFFS[8], |acc, &c| c + y * acc)
}
/// Polynomial approximation of the modified Bessel function `I0(x)`.
///
/// For `|x| < 3.75` a degree-6 polynomial in `(x / 3.75)^2` is used; otherwise
/// the result is `exp(|x|) / sqrt(|x|)` times `bessel_i0_scaled_poly(|x|)`.
/// The large-argument branch can overflow to infinity for very large `|x|`.
pub(crate) fn bessel_i0(x: f64) -> f64 {
    // Coefficients in ascending powers of `(x / 3.75)^2`.
    const SERIES: [f64; 7] = [
        1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.0360768, 0.0045813,
    ];
    let ax = x.abs();
    if ax < 3.75 {
        let scaled = x / 3.75;
        let t2 = scaled * scaled;
        SERIES[..6]
            .iter()
            .rev()
            .fold(SERIES[6], |acc, &c| c + t2 * acc)
    } else {
        let envelope = ax.exp() / ax.sqrt();
        envelope * bessel_i0_scaled_poly(ax)
    }
}
/// Polynomial approximation of the modified Bessel function `I1(x)`.
///
/// The magnitude is computed from `|x|` (a degree-6 polynomial in
/// `(x / 3.75)^2` times `|x|` below 3.75, the scaled asymptotic polynomial
/// otherwise) and negated when `x < 0.0`, so the result is odd in `x`.
pub(crate) fn bessel_i1(x: f64) -> f64 {
    // Coefficients in ascending powers of `(x / 3.75)^2`.
    const SERIES: [f64; 7] = [
        0.5, 0.87890594, 0.51498869, 0.15084934, 0.02658733, 0.00301532, 0.00032411,
    ];
    let ax = x.abs();
    let magnitude = if ax < 3.75 {
        let scaled = x / 3.75;
        let t2 = scaled * scaled;
        ax * SERIES[..6]
            .iter()
            .rev()
            .fold(SERIES[6], |acc, &c| c + t2 * acc)
    } else {
        let envelope = ax.exp() / ax.sqrt();
        envelope * bessel_i1_scaled_poly(ax)
    };
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
/// Returns `(ln I0(|eta|), I1(|eta|) / I0(|eta|))`.
///
/// Both outputs are evaluated at `|eta|`, so the sign of `eta` is ignored.
/// For `|eta| >= 3.75` the shared `exp(|eta|) / sqrt(|eta|)` factor is handled
/// analytically (it cancels in the ratio and enters the logarithm as
/// `|eta| - 0.5 * ln|eta|`), so the exponential is never formed.
pub(crate) fn bessel_i0_log_and_ratio(eta: f64) -> (f64, f64) {
    let magnitude = eta.abs();
    if magnitude < 3.75 {
        let (i0, i1) = (bessel_i0(magnitude), bessel_i1(magnitude));
        return (i0.ln(), i1 / i0);
    }
    let scaled_i0 = bessel_i0_scaled_poly(magnitude);
    let scaled_i1 = bessel_i1_scaled_poly(magnitude);
    (
        magnitude - 0.5 * magnitude.ln() + scaled_i0.ln(),
        scaled_i1 / scaled_i0,
    )
}
/// One manifold atom: tabulated basis values and Jacobian, decoder
/// coefficients, smoothness penalty, and optional evaluators / decoder frame.
///
/// `SaeManifoldAtom::new` enforces, with `n = basis_values.nrows()`,
/// `m = basis_values.ncols()` and `p = decoder_coefficients.ncols()`:
/// `basis_jacobian` is `(n, m, latent_dim)`, `decoder_coefficients` is `(m, p)`
/// with `p > 0`, and `smooth_penalty` is `(m, m)`. All fields are public, so
/// later mutation is not re-validated.
#[derive(Debug, Clone)]
pub struct SaeManifoldAtom {
// Atom name supplied to `new`.
pub name: String,
// Basis family of this atom.
pub basis_kind: SaeAtomBasisKind,
// Length of the last axis of `basis_jacobian`.
pub latent_dim: usize,
// Basis values, shape `(n, m)`; replaced by `refresh_basis`.
pub basis_values: Array2<f64>,
// Basis derivatives, shape `(n, m, latent_dim)`; replaced by `refresh_basis`.
pub basis_jacobian: Array3<f64>,
// Decoder coefficients, shape `(m, p)`.
pub decoder_coefficients: Array2<f64>,
// Penalty currently in effect, shape `(m, m)`; rewritten from
// `smooth_penalty_raw` by `refresh_intrinsic_smooth_penalty`.
pub smooth_penalty: Array2<f64>,
// Penalty as supplied to `new` (or as projected by `reduce_basis_to_subspace`).
pub smooth_penalty_raw: Array2<f64>,
// Result of `smooth_penalty_nullity` on the raw penalty.
pub smooth_penalty_order: usize,
// Optional evaluator used by `refresh_basis`; `None` makes that call a no-op.
pub basis_evaluator: Option<Arc<dyn SaeBasisEvaluator>>,
// Optional second-jet evaluator; required by `reduce_basis_to_subspace`.
pub basis_second_jet: Option<Arc<dyn SaeBasisSecondJet>>,
// Optional decoder frame; `None` means no frame is active.
pub decoder_frame: Option<GrassmannFrame>,
// Initialised to 1.0 by `new`; when it is not exactly 1.0, `refresh_basis`
// calls `evaluate_phi_eta` instead of `evaluate`.
pub homotopy_eta: f64,
// Initialised to `false` by `new`; not read or written elsewhere in the visible code.
pub chart_canonicalized: bool,
}
impl SaeManifoldAtom {
/// Builds an atom from tabulated basis data after validating shapes.
///
/// With `n = basis_values.nrows()`, `m = basis_values.ncols()` and
/// `p = decoder_coefficients.ncols()`, the new atom starts with no
/// evaluators, no decoder frame, `homotopy_eta = 1.0` and
/// `chart_canonicalized = false`; the intrinsic smoothness penalty is
/// refreshed once before returning.
///
/// # Errors
///
/// Returns an error, checked in this order, when `basis_jacobian` is not
/// `(n, m, latent_dim)`, `decoder_coefficients` does not have `m` rows,
/// `smooth_penalty` is not `(m, m)`, `p == 0`, or `smooth_penalty_nullity`
/// fails.
#[must_use = "build error must be handled"]
pub fn new(
    name: impl Into<String>,
    basis_kind: SaeAtomBasisKind,
    latent_dim: usize,
    basis_values: Array2<f64>,
    basis_jacobian: Array3<f64>,
    decoder_coefficients: Array2<f64>,
    smooth_penalty: Array2<f64>,
) -> Result<Self, String> {
    let (n, m) = basis_values.dim();
    let (decoder_rows, p) = decoder_coefficients.dim();

    let jacobian_shape = basis_jacobian.dim();
    if jacobian_shape != (n, m, latent_dim) {
        return Err(format!(
            "SaeManifoldAtom::new: basis_jacobian must be ({n}, {m}, {latent_dim}); got {:?}",
            jacobian_shape
        ));
    }
    if decoder_rows != m {
        return Err(format!(
            "SaeManifoldAtom::new: decoder rows {} must equal basis size {m}",
            decoder_rows
        ));
    }
    let penalty_shape = smooth_penalty.dim();
    if penalty_shape != (m, m) {
        return Err(format!(
            "SaeManifoldAtom::new: smooth penalty must be ({m}, {m}); got {:?}",
            penalty_shape
        ));
    }
    if p == 0 {
        return Err("SaeManifoldAtom::new: decoder output dimension must be positive".into());
    }

    let smooth_penalty_order = smooth_penalty_nullity(&smooth_penalty)?;
    // Keep an untouched copy; `smooth_penalty` is overwritten by the refresh below.
    let smooth_penalty_raw = smooth_penalty.clone();
    let mut atom = Self {
        name: name.into(),
        basis_kind,
        latent_dim,
        basis_values,
        basis_jacobian,
        decoder_coefficients,
        smooth_penalty,
        smooth_penalty_raw,
        smooth_penalty_order,
        basis_evaluator: None,
        basis_second_jet: None,
        decoder_frame: None,
        homotopy_eta: 1.0,
        chart_canonicalized: false,
    };
    atom.refresh_intrinsic_smooth_penalty();
    Ok(atom)
}
/// Installs `evaluator` as the basis evaluator and clears any second-jet
/// evaluator, returning the updated atom.
pub fn with_basis_evaluator(mut self, evaluator: Arc<dyn SaeBasisEvaluator>) -> Self {
    drop(self.basis_evaluator.replace(evaluator));
    // A previously attached second jet would no longer correspond to the evaluator.
    self.basis_second_jet = None;
    self
}
/// Installs `evaluator` both as the second-jet evaluator and, through the
/// same shared pointer, as the plain basis evaluator.
pub fn with_basis_second_jet(mut self, evaluator: Arc<dyn SaeBasisSecondJet>) -> Self {
    // Same allocation viewed through the base evaluator trait object.
    let as_evaluator: Arc<dyn SaeBasisEvaluator> = evaluator.clone();
    self.basis_evaluator = Some(as_evaluator);
    self.basis_second_jet = Some(evaluator);
    self
}
/// Replaces the basis by its image under the column map `q` (`m x r`).
///
/// Basis values and each Jacobian slice are right-multiplied by `q`, the
/// decoder becomes `q^T * decoder`, and the raw penalty becomes
/// `q^T * S_raw * q`. The evaluators are replaced by a
/// `SubspaceReducedEvaluator` wrapping the current second-jet evaluator, the
/// decoder frame is cleared, and the intrinsic penalty is refreshed. All
/// fallible steps run before any field is modified.
///
/// # Errors
///
/// Fails when `q` does not have `m` rows, when `r == 0` or `r > m`, when no
/// second-jet evaluator is attached, when the reduced decoder has an
/// unexpected shape, or when `smooth_penalty_nullity` /
/// `SubspaceReducedEvaluator::new` fail.
pub fn reduce_basis_to_subspace(&mut self, q: &Array2<f64>) -> Result<(), String> {
    let m = self.basis_size();
    if q.nrows() != m {
        return Err(format!(
            "SaeManifoldAtom::reduce_basis_to_subspace: column map has {} rows, basis width {m}",
            q.nrows()
        ));
    }
    let r = q.ncols();
    if r == 0 || r > m {
        return Err(format!(
            "SaeManifoldAtom::reduce_basis_to_subspace: invalid retained rank {r} (basis width {m})"
        ));
    }
    let second_jet = match self.basis_second_jet.clone() {
        Some(jet) => jet,
        None => {
            return Err(
                "SaeManifoldAtom::reduce_basis_to_subspace: requires an analytic second-jet \
                 evaluator to compose the reduced jets"
                    .to_string(),
            );
        }
    };

    let p = self.output_dim();
    let latent = self.latent_dim;
    let reduced_values = self.basis_values.dot(q);
    let n = self.n_obs();

    // Project every latent-axis slice of the Jacobian through `q`.
    let mut reduced_jacobian = Array3::<f64>::zeros((n, r, latent));
    for axis in 0..latent {
        let axis_slab = self.basis_jacobian.slice(s![.., .., axis]).to_owned();
        let projected = axis_slab.dot(q);
        for row in 0..n {
            for col in 0..r {
                reduced_jacobian[[row, col, axis]] = projected[[row, col]];
            }
        }
    }

    let reduced_decoder = q.t().dot(&self.decoder_coefficients);
    if reduced_decoder.dim() != (r, p) {
        return Err(format!(
            "SaeManifoldAtom::reduce_basis_to_subspace: reduced decoder dim {:?} != ({r}, {p})",
            reduced_decoder.dim()
        ));
    }

    let reduced_penalty = q.t().dot(&self.smooth_penalty_raw).dot(q);
    let order = smooth_penalty_nullity(&reduced_penalty)?;

    let reduced_jet: Arc<dyn SaeBasisSecondJet> =
        Arc::new(SubspaceReducedEvaluator::new(second_jet, q.clone())?);
    let reduced_evaluator: Arc<dyn SaeBasisEvaluator> = reduced_jet.clone();

    // Nothing below can fail, so the atom is updated all-or-nothing.
    self.basis_values = reduced_values;
    self.basis_jacobian = reduced_jacobian;
    self.decoder_coefficients = reduced_decoder;
    self.smooth_penalty = reduced_penalty.clone();
    self.smooth_penalty_raw = reduced_penalty;
    self.smooth_penalty_order = order;
    self.basis_evaluator = Some(reduced_evaluator);
    self.basis_second_jet = Some(reduced_jet);
    self.decoder_frame = None;
    self.refresh_intrinsic_smooth_penalty();
    Ok(())
}
/// Re-evaluates the basis values and Jacobian at `coords`.
///
/// Does nothing (and returns `Ok`) when no evaluator is attached. When
/// `homotopy_eta` is exactly `1.0` the evaluator's `evaluate` is used,
/// otherwise `evaluate_phi_eta` with the stored `homotopy_eta`.
///
/// # Errors
///
/// Propagates evaluator errors and rejects results whose shapes differ from
/// the currently stored `basis_values` / `basis_jacobian`.
pub fn refresh_basis(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
    let evaluator = match self.basis_evaluator.as_ref() {
        Some(evaluator) => evaluator,
        None => return Ok(()),
    };

    let (phi, jet) = if self.homotopy_eta == 1.0 {
        evaluator.evaluate(coords)?
    } else {
        let blended = evaluator.evaluate_phi_eta(coords, self.homotopy_eta)?;
        (blended.phi, blended.jet)
    };

    let expected_phi = self.basis_values.dim();
    if phi.dim() != expected_phi {
        return Err(format!(
            "SaeManifoldAtom::refresh_basis: evaluator returned Phi {:?}, expected {:?}",
            phi.dim(),
            expected_phi
        ));
    }
    let expected_jet = self.basis_jacobian.dim();
    if jet.dim() != expected_jet {
        return Err(format!(
            "SaeManifoldAtom::refresh_basis: evaluator returned jet {:?}, expected {:?}",
            jet.dim(),
            expected_jet
        ));
    }

    self.basis_values = phi;
    self.basis_jacobian = jet;
    Ok(())
}
/// Number of rows of `basis_values`.
pub fn n_obs(&self) -> usize {
    let (rows, _) = self.basis_values.dim();
    rows
}
/// Number of columns of `basis_values`.
pub fn basis_size(&self) -> usize {
    let (_, cols) = self.basis_values.dim();
    cols
}
/// Number of columns of `decoder_coefficients`.
pub fn output_dim(&self) -> usize {
    let (_, cols) = self.decoder_coefficients.dim();
    cols
}
/// Rank of the active decoder frame, or `output_dim()` when no frame is active.
pub fn border_frame_rank(&self) -> usize {
    self.decoder_frame
        .as_ref()
        .map_or_else(|| self.output_dim(), |frame| frame.rank())
}
/// Product of `basis_size()` and `border_frame_rank()`.
pub fn border_coeff_count(&self) -> usize {
    let basis_width = self.basis_size();
    let frame_rank = self.border_frame_rank();
    basis_width * frame_rank
}
/// Manifold dimension reported by the active decoder frame, or `0` when no
/// frame is active.
pub fn frame_manifold_dimension(&self) -> usize {
    self.decoder_frame
        .as_ref()
        .map_or(0, |frame| frame.manifold_dimension())
}
/// Counts the singular values of `decoder_coefficients` that exceed
/// `SAE_FRAME_RANK_CUTOFF` times the largest singular value.
///
/// Returns `0` for an empty decoder or when no singular value is positive.
///
/// # Errors
///
/// Returns an error when the SVD fails.
pub fn decoder_numerical_rank(&self) -> Result<usize, String> {
    if self.output_dim() == 0 || self.basis_size() == 0 {
        return Ok(0);
    }
    let (_left, singular_values, _right) = self
        .decoder_coefficients
        .svd(false, false)
        .map_err(|e| format!("SaeManifoldAtom::decoder_numerical_rank: SVD failed: {e}"))?;

    let largest = singular_values
        .iter()
        .fold(0.0_f64, |acc, &value| acc.max(value));
    if !(largest > 0.0) {
        return Ok(0);
    }
    let threshold = SAE_FRAME_RANK_CUTOFF * largest;
    let rank = singular_values
        .iter()
        .filter(|&&value| value > threshold)
        .count();
    Ok(rank)
}
/// Decides whether a reduced decoder frame is worthwhile and, if so, at what rank.
///
/// Returns `None` when the decoder is empty, when the output dimension `p` is
/// below `SAE_FRAME_MIN_AUTO_OUTPUT_DIM`, or when the numerical rank `r`
/// (clamped to `1..=p`) does not satisfy both `r < p` and
/// `r <= p * (1 - SAE_FRAME_ACTIVATION_MARGIN)`.
///
/// # Errors
///
/// Propagates failures from `decoder_numerical_rank`.
pub fn decoder_frame_activation_rank(&self) -> Result<Option<usize>, String> {
    let p = self.output_dim();
    let not_eligible = p == 0 || self.basis_size() == 0 || p < SAE_FRAME_MIN_AUTO_OUTPUT_DIM;
    if not_eligible {
        return Ok(None);
    }
    let r = self.decoder_numerical_rank()?.max(1).min(p);
    let shrink_ok = (r as f64) <= (p as f64) * (1.0 - SAE_FRAME_ACTIVATION_MARGIN);
    Ok((shrink_ok && r < p).then_some(r))
}
/// Activates a decoder frame when `decoder_frame_activation_rank` recommends one.
///
/// On activation the frame columns are the leading right singular vectors of
/// `decoder_coefficients`, the gauge holds the matching singular values, and
/// the decoder is replaced by its projection `C * U * U^T` onto the frame.
/// Returns the effective rank, or `None` (with the frame cleared) when no
/// frame is activated.
///
/// # Errors
///
/// Propagates rank-detection failures and reports SVD failures or a missing
/// right factor.
pub fn maybe_activate_decoder_frame(&mut self) -> Result<Option<usize>, String> {
    let r = match self.decoder_frame_activation_rank()? {
        Some(rank) => rank,
        None => {
            self.decoder_frame = None;
            return Ok(None);
        }
    };
    let p = self.output_dim();
    let (_left, sv, vt_opt) = self.decoder_coefficients.svd(false, true).map_err(|e| {
        format!("SaeManifoldAtom::maybe_activate_decoder_frame: SVD failed: {e}")
    })?;
    let vt = vt_opt.ok_or_else(|| {
        "SaeManifoldAtom::maybe_activate_decoder_frame: SVD returned no right factor"
            .to_string()
    })?;

    // Never request more directions than the SVD actually produced.
    let r_eff = r.min(vt.nrows());
    if r_eff == 0 || p <= r_eff {
        self.decoder_frame = None;
        return Ok(None);
    }

    // Frame columns are the first `r_eff` rows of V^T, transposed.
    let mut columns = Array2::<f64>::zeros((p, r_eff));
    for col in 0..r_eff {
        for row in 0..p {
            columns[[row, col]] = vt[[col, row]];
        }
    }
    let mut gauge = Array1::<f64>::zeros(r_eff);
    for (i, slot) in gauge.iter_mut().enumerate() {
        *slot = sv.get(i).copied().unwrap_or(0.0);
    }

    // Bind the frame locally so its basis can be read without unwrapping the
    // freshly stored `Option`.
    let frame = GrassmannFrame::from_oriented(columns, gauge);
    let u_proj = frame.frame().to_owned();
    self.decoder_frame = Some(frame);

    let c_proj = self.decoder_coefficients.dot(&u_proj);
    self.decoder_coefficients = c_proj.dot(&u_proj.t());
    Ok(Some(r_eff))
}
/// Clears the decoder frame; the decoder coefficients are left untouched.
pub fn deactivate_decoder_frame(&mut self) {
    drop(self.decoder_frame.take());
}
/// Projects the decoder coefficients through the active frame.
///
/// Returns `Ok(None)` when no frame is active.
///
/// # Errors
///
/// Propagates errors from the frame's `project_decoder`.
pub fn factored_coordinates(&self) -> Result<Option<Array2<f64>>, String> {
    let Some(frame) = self.decoder_frame.as_ref() else {
        return Ok(None);
    };
    let coords = frame.project_decoder(self.decoder_coefficients.view())?;
    Ok(Some(coords))
}
/// Reconstructs decoder coefficients from factored `coords` using the active frame.
///
/// # Errors
///
/// Fails when no frame is active, or when the frame's `reconstruct_decoder` fails.
pub fn reconstruct_decoder_coefficients(
    &self,
    coords: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    match self.decoder_frame.as_ref() {
        Some(frame) => frame.reconstruct_decoder(coords),
        None => Err(
            "SaeManifoldAtom::reconstruct_decoder_coefficients: no active frame".to_string(),
        ),
    }
}
/// Overwrites the decoder coefficients with the reconstruction of `coords`
/// through the active frame.
///
/// # Errors
///
/// Fails when reconstruction fails or when the reconstructed matrix does not
/// have the shape of the current decoder; the decoder is unchanged on error.
pub fn set_factored_coordinates(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
    let rebuilt = self.reconstruct_decoder_coefficients(coords)?;
    let expected = self.decoder_coefficients.dim();
    if rebuilt.dim() == expected {
        self.decoder_coefficients = rebuilt;
        return Ok(());
    }
    Err(format!(
        "SaeManifoldAtom::set_factored_coordinates: reconstructed decoder {:?} \
         must match {:?}",
        rebuilt.dim(),
        expected
    ))
}
/// Replaces the active frame by `GrassmannFrame::polar_update(cross_moment)`
/// and re-expresses the decoder in the new frame (project, then reconstruct).
///
/// # Errors
///
/// Fails when no frame is active, when the polar update fails, when the new
/// frame's output dimension differs from the decoder's, or when projection /
/// reconstruction fail. The atom is only modified after every step succeeded.
pub fn refresh_frame_from_cross_moment(
    &mut self,
    cross_moment: ArrayView2<'_, f64>,
) -> Result<(), String> {
    if self.decoder_frame.is_none() {
        return Err("SaeManifoldAtom::refresh_frame_from_cross_moment: no active frame".into());
    }

    let updated = GrassmannFrame::polar_update(cross_moment)?;
    let (frame_dim, decoder_dim) = (updated.output_dim(), self.output_dim());
    if frame_dim != decoder_dim {
        return Err(format!(
            "SaeManifoldAtom::refresh_frame_from_cross_moment: frame output dim {} \
             must equal decoder output dim {}",
            frame_dim, decoder_dim
        ));
    }

    let factored = updated.project_decoder(self.decoder_coefficients.view())?;
    let rebuilt = updated.reconstruct_decoder(factored.view())?;
    self.decoder_coefficients = rebuilt;
    self.decoder_frame = Some(updated);
    Ok(())
}
/// Returns the decoded output for observation `row`, i.e. row `row` of
/// `basis_values * decoder_coefficients`.
///
/// # Panics
///
/// Panics when `row` is out of bounds.
pub fn decoded_row(&self, row: usize) -> Array1<f64> {
    let mut decoded = Array1::<f64>::zeros(self.output_dim());
    {
        let buffer = decoded.as_slice_mut().expect("contiguous");
        self.fill_decoded_row(row, buffer);
    }
    decoded
}
/// Writes the decoded output for observation `row` into `out`.
///
/// `out` is zeroed and then accumulates `phi * decoder_row` for every basis
/// column; columns whose basis value is exactly `0.0` are skipped.
///
/// # Panics
///
/// Panics when `out.len() != output_dim()` or `row` is out of bounds.
pub fn fill_decoded_row(&self, row: usize, out: &mut [f64]) {
    assert_eq!(out.len(), self.output_dim());
    out.fill(0.0);
    for basis_col in 0..self.basis_size() {
        let phi = self.basis_values[[row, basis_col]];
        if phi == 0.0 {
            continue;
        }
        for (out_col, slot) in out.iter_mut().enumerate() {
            *slot += phi * self.decoder_coefficients[[basis_col, out_col]];
        }
    }
}
/// Returns the derivative of the decoded output for observation `row` with
/// respect to latent axis `latent_axis`.
///
/// # Panics
///
/// Panics when `row` or `latent_axis` is out of bounds.
pub fn decoded_derivative_row(&self, row: usize, latent_axis: usize) -> Array1<f64> {
    let mut derivative = Array1::<f64>::zeros(self.output_dim());
    {
        let buffer = derivative.as_slice_mut().expect("contiguous");
        self.fill_decoded_derivative_row(row, latent_axis, buffer);
    }
    derivative
}
/// Writes the derivative of the decoded output for observation `row` along
/// `latent_axis` into `out`.
///
/// `out` is zeroed and then accumulates `dphi * decoder_row` for every basis
/// column; columns whose Jacobian entry is exactly `0.0` are skipped.
///
/// # Panics
///
/// Panics when `out.len() != output_dim()` or an index is out of bounds.
pub fn fill_decoded_derivative_row(&self, row: usize, latent_axis: usize, out: &mut [f64]) {
    assert_eq!(out.len(), self.output_dim());
    out.fill(0.0);
    for basis_col in 0..self.basis_size() {
        let dphi = self.basis_jacobian[[row, basis_col, latent_axis]];
        if dphi == 0.0 {
            continue;
        }
        for (out_col, slot) in out.iter_mut().enumerate() {
            *slot += dphi * self.decoder_coefficients[[basis_col, out_col]];
        }
    }
}
/// Recomputes `smooth_penalty` from `smooth_penalty_raw` using per-column
/// speed weights.
///
/// The raw penalty is copied unchanged when the basis is empty, when
/// `smooth_penalty_order == 0`, when `latent_dim != 1`, or when no column has
/// a positive finite speed. Otherwise:
///
/// 1. For every observation the squared norm of the decoded derivative along
///    latent axis 0 is computed. For the `Poincare` kind with at least two
///    basis columns it is divided by `lambda^2`, `lambda = 2 * cosh(t)^2`,
///    where `t` is the value in basis column 1.
/// 2. Each column's speed is the `phi^2`-weighted mean of those squared norms.
/// 3. Speeds are divided by their geometric mean (over positive finite
///    entries), clamped to `[1e-6, 1e6]`, and raised to
///    `0.5 * (0.5 - smooth_penalty_order)` to give a root weight `w`.
/// 4. `smooth_penalty[i, j] = w[i] * smooth_penalty_raw[i, j] * w[j]`.
pub fn refresh_intrinsic_smooth_penalty(&mut self) {
    const RELATIVE_SPEED_FLOOR: f64 = 1.0e-6;
    const RELATIVE_SPEED_CEIL: f64 = 1.0e6;

    let m = self.basis_size();
    let reweighting_applies = m != 0 && self.smooth_penalty_order != 0 && self.latent_dim == 1;
    if !reweighting_applies {
        self.smooth_penalty.assign(&self.smooth_penalty_raw);
        return;
    }

    let n_rows = self.n_obs();
    let exponent = 0.5 - self.smooth_penalty_order as f64;
    let is_hyperbolic = matches!(self.basis_kind, SaeAtomBasisKind::Poincare);
    // Column whose value feeds the conformal factor, when one applies.
    let conformal_col = (is_hyperbolic && m >= 2).then_some(1usize);

    let mut weight_sum = vec![0.0_f64; m];
    let mut weighted_speed = vec![0.0_f64; m];
    let mut velocity = vec![0.0_f64; self.output_dim()];
    for row in 0..n_rows {
        self.fill_decoded_derivative_row(row, 0, &mut velocity);
        let mut speed_sq = velocity.iter().fold(0.0_f64, |acc, &d| acc + d * d);
        if let Some(col) = conformal_col {
            let t = self.basis_values[[row, col]];
            let lambda = 2.0 * t.cosh() * t.cosh();
            if lambda.is_finite() && lambda > 0.0 {
                speed_sq /= lambda * lambda;
            }
        }
        for (col, (activation, numerator)) in weight_sum
            .iter_mut()
            .zip(weighted_speed.iter_mut())
            .enumerate()
        {
            let phi = self.basis_values[[row, col]];
            let w = phi * phi;
            if w == 0.0 {
                continue;
            }
            *activation += w;
            *numerator += w * speed_sq;
        }
    }

    // Per-column activation-weighted mean speed; inactive columns get zero.
    let speeds: Vec<f64> = weight_sum
        .iter()
        .zip(weighted_speed.iter())
        .map(|(&activation, &numerator)| {
            if activation > 0.0 {
                numerator / activation
            } else {
                0.0
            }
        })
        .collect();

    // Geometric mean over the usable speeds.
    let (log_total, log_count) = speeds
        .iter()
        .filter(|s| **s > 0.0 && s.is_finite())
        .fold((0.0_f64, 0usize), |(total, count), s| {
            (total + s.ln(), count + 1)
        });
    let center = if log_count > 0 {
        (log_total / log_count as f64).exp()
    } else {
        0.0
    };
    if !(center > 0.0 && center.is_finite()) {
        self.smooth_penalty.assign(&self.smooth_penalty_raw);
        return;
    }

    let root_w: Vec<f64> = speeds
        .iter()
        .map(|&speed| {
            let relative = speed / center;
            let bounded = if relative.is_finite() {
                relative.clamp(RELATIVE_SPEED_FLOOR, RELATIVE_SPEED_CEIL)
            } else {
                RELATIVE_SPEED_CEIL
            };
            bounded.powf(0.5 * exponent)
        })
        .collect();

    // Symmetric diagonal rescaling of the raw penalty.
    for (i, &left) in root_w.iter().enumerate() {
        for (j, &right) in root_w.iter().enumerate() {
            self.smooth_penalty[[i, j]] = left * self.smooth_penalty_raw[[i, j]] * right;
        }
    }
}
}
/// Counts the eigenvalues of the symmetrised penalty `0.5 * (s + s^T)` that
/// are at or below `SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF` times the largest
/// eigenvalue.
///
/// Returns `0` when `s` has no columns or when no eigenvalue is positive.
///
/// # Errors
///
/// Returns an error when `s` is not square (previously this either panicked on
/// an out-of-bounds index or silently used only the leading block), or when
/// the symmetric eigendecomposition fails.
pub(crate) fn smooth_penalty_nullity(s: &Array2<f64>) -> Result<usize, String> {
    let (rows, m) = s.dim();
    if m == 0 {
        return Ok(0);
    }
    if rows != m {
        return Err(format!(
            "smooth_penalty_nullity: penalty must be square; got {:?}",
            s.dim()
        ));
    }

    // Symmetrise so the eigensolver sees an exactly symmetric matrix.
    let mut sym = Array2::<f64>::zeros((m, m));
    for i in 0..m {
        for j in 0..m {
            sym[[i, j]] = 0.5 * (s[[i, j]] + s[[j, i]]);
        }
    }
    let (evals, _evecs) = sym
        .eigh(Side::Lower)
        .map_err(|e| format!("smooth_penalty_nullity: eigh failed: {e}"))?;

    let max_eig = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
    if !(max_eig > 0.0) {
        return Ok(0);
    }
    let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * max_eig;
    Ok(evals.iter().filter(|&&v| v <= tol).count())
}