use ndarray::{Array1, Array2, Array3, Array4, ArrayView1, ArrayView2, ArrayView3, ArrayView4, s};
use std::sync::Arc;
use crate::solver::arrow_schur::{
ArrowRowBlock, ArrowSchurError, ArrowSchurSystem, ArrowSolveOptions, BetaPenaltyOp,
CompositePenaltyOp, DensePenaltyOp, KroneckerPenaltyOp, SparseBlockKroneckerPenaltyOp,
SparseGBlock, StreamingArrowSchur, solve_streaming_reduced_beta,
};
use crate::terms::analytic_penalties::{
ARDPenalty, AnalyticPenalty, AnalyticPenaltyKind, AnalyticPenaltyRegistry,
IBPAssignmentPenalty, IsometryPenalty, MechanismSparsityPenalty, PenaltyTier, PsiSlice,
SoftmaxAssignmentSparsityPenalty,
};
use crate::terms::latent_coord::{LatentCoordValues, LatentIdMode, LatentManifold};
// NOTE(review): by its name this is the Armijo sufficient-decrease coefficient (c1) for the
// SAE manifold line search; the code that consumes it is not visible in this part of the file.
const SAE_MANIFOLD_ARMIJO_C1: f64 = 1.0e-4;
// NOTE(review): by its name this caps the number of step halvings a line search may try;
// the consuming code is not visible in this part of the file.
const SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS: usize = 12;
/// Decay law a [`GumbelTemperatureSchedule`] applies to its temperature.
#[derive(Debug, Clone)]
pub enum ScheduleKind {
    /// `tau_start * rate^iter`; validation requires `rate` strictly inside `(0, 1)`.
    Geometric { rate: f64 },
    /// Straight-line interpolation from `tau_start` to `tau_min` over `steps` iterations.
    Linear { steps: usize },
    /// `tau_start / (1 + iter)`.
    ReciprocalIter,
}

/// Temperature schedule that decays from `tau_start` and never drops below `tau_min`.
#[derive(Debug, Clone)]
pub struct GumbelTemperatureSchedule {
    /// Temperature at iteration 0.
    pub tau_start: f64,
    /// Floor applied to every temperature the schedule returns.
    pub tau_min: f64,
    /// Decay law used between `tau_start` and `tau_min`.
    pub decay: ScheduleKind,
    /// Number of times [`GumbelTemperatureSchedule::step`] has been called.
    pub iter_count: usize,
}

impl GumbelTemperatureSchedule {
    /// Builds a schedule starting at iteration 0.
    ///
    /// # Errors
    /// Returns the message produced by [`Self::validate`] when the parameters are rejected.
    #[must_use = "build error must be handled"]
    pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
        let candidate = Self {
            tau_start,
            tau_min,
            decay,
            iter_count: 0,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Checks that both temperatures are finite and positive, that `tau_min <= tau_start`,
    /// and that the decay parameters are usable.
    ///
    /// # Errors
    /// Returns a descriptive message naming the first violated constraint.
    pub fn validate(&self) -> Result<(), String> {
        let is_positive_finite = |value: f64| value.is_finite() && value > 0.0;

        if !is_positive_finite(self.tau_start) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
                self.tau_start
            ));
        }
        if !is_positive_finite(self.tau_min) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
                self.tau_min
            ));
        }
        if self.tau_min > self.tau_start {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
                self.tau_min, self.tau_start
            ));
        }

        match &self.decay {
            ScheduleKind::Geometric { rate } => {
                let rate = *rate;
                if !rate.is_finite() || rate <= 0.0 || rate >= 1.0 {
                    return Err(format!(
                        "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
                    ));
                }
                Ok(())
            }
            ScheduleKind::Linear { steps } if *steps == 0 => {
                Err("GumbelTemperatureSchedule::Linear: steps must be positive".into())
            }
            ScheduleKind::Linear { .. } | ScheduleKind::ReciprocalIter => Ok(()),
        }
    }

    /// Temperature for iteration `iter`, clamped from below by `tau_min`.
    pub fn current_tau(&self, iter: usize) -> f64 {
        let position = iter as f64;
        let unclamped = match &self.decay {
            ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(position),
            // Past the last linear step the schedule sits on the floor.
            ScheduleKind::Linear { steps } if iter >= *steps => self.tau_min,
            ScheduleKind::Linear { steps } => {
                let progress = position / *steps as f64;
                self.tau_start + progress * (self.tau_min - self.tau_start)
            }
            ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + position),
        };
        unclamped.max(self.tau_min)
    }

    /// Returns the temperature for the current iteration, then advances `iter_count` by one.
    pub fn step(&mut self) -> f64 {
        let emitted = self.current_tau(self.iter_count);
        self.iter_count += 1;
        emitted
    }
}
/// How a scalar setting is searched: held fixed or swept over explicit values.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchStrategy {
    /// No sweep.
    Fixed,
    /// Sweep over the listed candidate values.
    ExponentialSweep { values: Vec<f64> },
}

impl SearchStrategy {
    /// `true` only for [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn is_fixed(&self) -> bool {
        match self {
            Self::Fixed => true,
            Self::ExponentialSweep { .. } => false,
        }
    }

    /// Candidate values of a sweep, or `None` for [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn sweep_values(&self) -> Option<&[f64]> {
        if let Self::ExponentialSweep { values } = self {
            Some(values.as_slice())
        } else {
            None
        }
    }
}
/// Basis family used by an SAE atom.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeAtomBasisKind {
    Duchon,
    Periodic,
    Sphere,
    Torus,
    EuclideanPatch,
    Precomputed(String),
}

impl SaeAtomBasisKind {
    /// Maps the basis family to the manifold its latent coordinates live on.
    ///
    /// `Periodic` and `Torus` yield one unit-period circle per latent axis (a bare circle when
    /// `latent_dim == 1`). `Sphere` always yields latitude-interval x longitude-circle and
    /// ignores `latent_dim`. Every other kind is Euclidean.
    fn latent_manifold(&self, latent_dim: usize) -> LatentManifold {
        use std::f64::consts::{FRAC_PI_2, TAU};

        // Product of `count` unit-period circles.
        fn unit_circle_product(count: usize) -> LatentManifold {
            LatentManifold::Product(
                (0..count)
                    .map(|_| LatentManifold::Circle { period: 1.0 })
                    .collect(),
            )
        }

        match self {
            Self::Periodic | Self::Torus if latent_dim == 1 => {
                LatentManifold::Circle { period: 1.0 }
            }
            Self::Periodic | Self::Torus => unit_circle_product(latent_dim),
            Self::Sphere => LatentManifold::Product(vec![
                LatentManifold::Interval {
                    lo: -FRAC_PI_2,
                    hi: FRAC_PI_2,
                },
                LatentManifold::Circle { period: TAU },
            ]),
            Self::Duchon | Self::EuclideanPatch | Self::Precomputed(_) => LatentManifold::Euclidean,
        }
    }
}
/// Evaluates an atom basis at a batch of latent coordinates (one row of `coords` per point).
pub trait SaeBasisEvaluator: Send + Sync + std::fmt::Debug {
/// Returns a pair `(values, first_derivatives)` for `coords`, or an error message.
/// The implementors in this file return shapes `(n, m)` and `(n, m, latent_dim)` with
/// `n = coords.nrows()` and `m` the basis size.
fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String>;
/// Optional second-derivative hook reachable through a trait object.
/// The default reports that none is available (`None`); implementors in this file that also
/// implement `SaeBasisSecondJet` override it to forward to `second_jet`.
fn second_jet_dyn(&self, _coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
None
}
}
/// Extension of `SaeBasisEvaluator` for bases that can also produce second derivatives.
pub trait SaeBasisSecondJet: SaeBasisEvaluator {
/// Returns second derivatives of the basis at `coords`, or an error message.
/// The implementors in this file return shape `(n, m, latent_dim, latent_dim)`.
fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String>;
}
/// One-dimensional harmonic basis with `num_basis` columns.
#[derive(Debug, Clone)]
pub struct PeriodicHarmonicEvaluator {
    /// Number of basis columns; the constructor only accepts odd values.
    pub num_basis: usize,
}

impl PeriodicHarmonicEvaluator {
    /// Creates an evaluator with `num_basis` columns.
    ///
    /// # Errors
    /// Returns a message when `num_basis` is zero or even.
    pub fn new(num_basis: usize) -> Result<Self, String> {
        // Zero is even, so one parity test rejects both zero and even sizes.
        match num_basis % 2 {
            1 => Ok(Self { num_basis }),
            _ => Err(format!(
                "PeriodicHarmonicEvaluator requires odd num_basis >= 1; got {num_basis}"
            )),
        }
    }
}
impl SaeBasisEvaluator for PeriodicHarmonicEvaluator {
    /// Forwards to the [`SaeBasisSecondJet`] implementation so trait objects can reach it.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    /// Evaluates the harmonic basis and its first derivative at each row of `coords`.
    ///
    /// Column 0 is the constant `1`; for harmonic `h >= 1`, column `2h - 1` holds
    /// `sin(2*pi*h*t)` and column `2h` holds `cos(2*pi*h*t)`. Returns `(phi, jet)` with
    /// shapes `(n, num_basis)` and `(n, num_basis, 1)`.
    ///
    /// # Errors
    /// Returns a message when `coords` does not have exactly one column, or when
    /// `num_basis` is zero.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let n = coords.nrows();
        let d = coords.ncols();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator: expected latent_dim == 1, got {d}"
            ));
        }
        let m = self.num_basis;
        // `num_basis` is a public field, so the constructor's check can be bypassed. A zero
        // would underflow `m - 1` (debug panic, release wrap followed by out-of-bounds writes).
        if m == 0 {
            return Err("PeriodicHarmonicEvaluator: num_basis must be at least 1; got 0".to_string());
        }
        let num_harmonics = (m - 1) / 2;
        let two_pi = 2.0 * std::f64::consts::PI;
        let mut phi = Array2::<f64>::zeros((n, m));
        let mut jet = Array3::<f64>::zeros((n, m, 1));
        for row in 0..n {
            let t = coords[[row, 0]];
            phi[[row, 0]] = 1.0;
            for h in 1..=num_harmonics {
                // Angular frequency of harmonic `h`, shared by the value and the derivative.
                let freq = two_pi * (h as f64);
                let angle = freq * t;
                let s = angle.sin();
                let c = angle.cos();
                let s_idx = 2 * h - 1;
                let c_idx = 2 * h;
                phi[[row, s_idx]] = s;
                phi[[row, c_idx]] = c;
                jet[[row, s_idx, 0]] = freq * c;
                jet[[row, c_idx, 0]] = -freq * s;
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for PeriodicHarmonicEvaluator {
    /// Second derivative of every basis column at each row of `coords`.
    ///
    /// The constant column stays zero; the sine/cosine columns of harmonic `k` get
    /// `-(2*pi*k)^2 * sin` and `-(2*pi*k)^2 * cos`. Returns shape `(n, num_basis, 1, 1)`.
    ///
    /// # Errors
    /// Returns a message when `coords` does not have exactly one column, or when
    /// `num_basis` is zero.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let n = coords.nrows();
        let d = coords.ncols();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator::second_jet: expected latent_dim == 1, got {d}"
            ));
        }
        let m = self.num_basis;
        // `num_basis` is a public field, so the constructor's check can be bypassed. A zero
        // would underflow `m - 1` (debug panic, release wrap followed by out-of-bounds writes).
        if m == 0 {
            return Err(
                "PeriodicHarmonicEvaluator::second_jet: num_basis must be at least 1; got 0"
                    .to_string(),
            );
        }
        let num_harmonics = (m - 1) / 2;
        let two_pi = 2.0 * std::f64::consts::PI;
        let mut h = Array4::<f64>::zeros((n, m, 1, 1));
        for row in 0..n {
            let t = coords[[row, 0]];
            for k in 1..=num_harmonics {
                let freq = two_pi * (k as f64);
                let freq2 = freq * freq;
                let angle = freq * t;
                let s = angle.sin();
                let c = angle.cos();
                let s_idx = 2 * k - 1;
                let c_idx = 2 * k;
                h[[row, s_idx, 0, 0]] = -freq2 * s;
                h[[row, c_idx, 0, 0]] = -freq2 * c;
            }
        }
        Ok(h)
    }
}
/// Two-column `(cos t, sin t)` basis driven by the first latent coordinate.
#[derive(Debug, Clone)]
pub struct RawPeriodicCircleEvaluator {
    /// Number of latent coordinate columns the evaluator expects.
    pub latent_dim: usize,
}

impl RawPeriodicCircleEvaluator {
    /// Creates an evaluator for `latent_dim` coordinate columns.
    ///
    /// # Errors
    /// Returns a message when `latent_dim` is zero.
    pub fn new(latent_dim: usize) -> Result<Self, String> {
        if latent_dim >= 1 {
            Ok(Self { latent_dim })
        } else {
            Err(String::from(
                "RawPeriodicCircleEvaluator requires latent_dim >= 1",
            ))
        }
    }
}
impl SaeBasisEvaluator for RawPeriodicCircleEvaluator {
    /// Evaluates `(cos t, sin t)` and its derivative, where `t` is column 0 of `coords`.
    ///
    /// Returns `(phi, jet)` with shapes `(n, 2)` and `(n, 2, latent_dim)`; only the slice of
    /// `jet` for latent axis 0 is written, the rest stays zero.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `latent_dim`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let expected = self.latent_dim;
        let found = coords.ncols();
        if found != expected {
            return Err(format!(
                "RawPeriodicCircleEvaluator: expected latent_dim {expected}, got {found}"
            ));
        }

        let point_count = coords.nrows();
        let mut phi = Array2::<f64>::zeros((point_count, 2));
        let mut jet = Array3::<f64>::zeros((point_count, 2, expected));
        for row in 0..point_count {
            let angle = coords[[row, 0]];
            let sine = angle.sin();
            let cosine = angle.cos();
            phi[[row, 0]] = cosine;
            phi[[row, 1]] = sine;
            jet[[row, 0, 0]] = -sine;
            jet[[row, 1, 0]] = cosine;
        }
        Ok((phi, jet))
    }
}
/// Stateless evaluator for a seven-column basis in (latitude, longitude) coordinates.
/// Its trait implementations in this file build the columns from `x = cos(lat) cos(lon)`,
/// `y = cos(lat) sin(lon)` and `z = sin(lat)`.
#[derive(Debug, Clone)]
pub struct SphereChartEvaluator;
impl SaeBasisEvaluator for SphereChartEvaluator {
    /// Forwards to the [`SaeBasisSecondJet`] implementation so trait objects can reach it.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessian = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessian)
    }

    /// Evaluates the columns `[1, x, y, z, xy, yz, xz]` and their derivatives with respect to
    /// `(lat, lon)`, where `x = cos(lat) cos(lon)`, `y = cos(lat) sin(lon)`, `z = sin(lat)`.
    ///
    /// Latitude is clamped to `[-pi/2, pi/2]`. When the raw latitude is not strictly inside
    /// that interval, every latitude derivative is multiplied by zero.
    ///
    /// # Errors
    /// Returns a message when `coords` does not have exactly two columns.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        use std::f64::consts::FRAC_PI_2;

        let width = coords.ncols();
        if width != 2 {
            return Err(format!(
                "SphereChartEvaluator expects latent_dim == 2, got {width}"
            ));
        }

        let n = coords.nrows();
        let mut phi = Array2::<f64>::zeros((n, 7));
        let mut jet = Array3::<f64>::zeros((n, 7, 2));
        for row in 0..n {
            let lat_in = coords[[row, 0]];
            let lon = coords[[row, 1]];
            let lat = lat_in.clamp(-FRAC_PI_2, FRAC_PI_2);
            // 1.0 strictly inside the chart, 0.0 on or beyond the latitude bounds.
            let gate = if lat_in > -FRAC_PI_2 && lat_in < FRAC_PI_2 {
                1.0
            } else {
                0.0
            };

            let cos_lat = lat.cos();
            let sin_lat = lat.sin();
            let cos_lon = lon.cos();
            let sin_lon = lon.sin();

            let x = cos_lat * cos_lon;
            let y = cos_lat * sin_lon;
            let z = sin_lat;

            let x_lat = -sin_lat * cos_lon * gate;
            let x_lon = -cos_lat * sin_lon;
            let y_lat = -sin_lat * sin_lon * gate;
            let y_lon = cos_lat * cos_lon;
            let z_lat = cos_lat * gate;

            let values = [1.0, x, y, z, x * y, y * z, x * z];
            // One `[d/dlat, d/dlon]` pair per column; `z` has no longitude dependence.
            let gradients = [
                [0.0, 0.0],
                [x_lat, x_lon],
                [y_lat, y_lon],
                [z_lat, 0.0],
                [x_lat * y + x * y_lat, x_lon * y + x * y_lon],
                [y_lat * z + y * z_lat, y_lon * z],
                [x_lat * z + x * z_lat, x_lon * z],
            ];
            for (col, (value, gradient)) in values.iter().zip(gradients.iter()).enumerate() {
                phi[[row, col]] = *value;
                jet[[row, col, 0]] = gradient[0];
                jet[[row, col, 1]] = gradient[1];
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for SphereChartEvaluator {
/// Second derivatives of the seven basis columns with respect to `(lat, lon)`.
/// Returns shape `(n, 7, 2, 2)`; the constant column 0 is left at zero.
/// Errors when `coords` does not have exactly two columns.
fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
if coords.ncols() != 2 {
return Err(format!(
"SphereChartEvaluator::second_jet expects latent_dim == 2, got {}",
coords.ncols()
));
}
let n = coords.nrows();
let mut h = Array4::<f64>::zeros((n, 7, 2, 2));
for row in 0..n {
let raw_lat = coords[[row, 0]];
// Latitude is clamped to [-pi/2, pi/2] before any trigonometry.
let lat = raw_lat.clamp(-std::f64::consts::FRAC_PI_2, std::f64::consts::FRAC_PI_2);
let lat_active =
raw_lat > -std::f64::consts::FRAC_PI_2 && raw_lat < std::f64::consts::FRAC_PI_2;
// `a` is 1.0 strictly inside the latitude interval and 0.0 otherwise; it multiplies
// every term that involves a latitude derivative.
let a = if lat_active { 1.0 } else { 0.0 };
let lon = coords[[row, 1]];
let clat = lat.cos();
let slat = lat.sin();
let clon = lon.cos();
let slon = lon.sin();
let x = clat * clon;
let y = clat * slon;
let z = slat;
// First derivatives as [d/dlat, d/dlon].
let dx = [-slat * clon * a, -clat * slon];
let dy = [-slat * slon * a, clat * clon];
let dz = [clat * a, 0.0];
// 2x2 second-derivative tables indexed [axis_a][axis_b] with axis 0 = lat, 1 = lon.
// The pure-longitude entries ([1][1]) are not multiplied by `a`.
let hx = [[-x * a, slat * slon * a], [slat * slon * a, -x]];
let hy = [[-y * a, -slat * clon * a], [-slat * clon * a, -y]];
let hz = [[-z * a, 0.0], [0.0, 0.0]];
for axis_a in 0..2 {
for axis_b in 0..2 {
h[[row, 1, axis_a, axis_b]] = hx[axis_a][axis_b];
h[[row, 2, axis_a, axis_b]] = hy[axis_a][axis_b];
h[[row, 3, axis_a, axis_b]] = hz[axis_a][axis_b];
}
}
// Second derivative of a product f*g:
// H(fg) = H(f) g + df dg^T + dg df^T + f H(g).
let pair = |hf: [[f64; 2]; 2],
df: [f64; 2],
f: f64,
hg: [[f64; 2]; 2],
dg: [f64; 2],
g: f64|
-> [[f64; 2]; 2] {
let mut out = [[0.0; 2]; 2];
for axis_a in 0..2 {
for axis_b in 0..2 {
out[axis_a][axis_b] = hf[axis_a][axis_b] * g
+ df[axis_a] * dg[axis_b]
+ df[axis_b] * dg[axis_a]
+ f * hg[axis_a][axis_b];
}
}
out
};
// Columns 4, 5, 6 are the products x*y, y*z, x*z.
let hxy = pair(hx, dx, x, hy, dy, y);
let hyz = pair(hy, dy, y, hz, dz, z);
let hxz = pair(hx, dx, x, hz, dz, z);
for axis_a in 0..2 {
for axis_b in 0..2 {
h[[row, 4, axis_a, axis_b]] = hxy[axis_a][axis_b];
h[[row, 5, axis_a, axis_b]] = hyz[axis_a][axis_b];
h[[row, 6, axis_a, axis_b]] = hxz[axis_a][axis_b];
}
}
}
Ok(h)
}
}
/// Tensor-product harmonic basis over `latent_dim` axes.
#[derive(Debug, Clone)]
pub struct TorusHarmonicEvaluator {
    /// Number of latent axes.
    pub latent_dim: usize,
    /// Number of sine/cosine harmonics per axis.
    pub num_harmonics: usize,
}

impl TorusHarmonicEvaluator {
    /// Creates an evaluator.
    ///
    /// # Errors
    /// Returns a message when `latent_dim` or `num_harmonics` is zero (checked in that order).
    pub fn new(latent_dim: usize, num_harmonics: usize) -> Result<Self, String> {
        match (latent_dim, num_harmonics) {
            (0, _) => Err("TorusHarmonicEvaluator requires latent_dim >= 1".to_string()),
            (_, 0) => Err("TorusHarmonicEvaluator requires num_harmonics >= 1".to_string()),
            _ => Ok(Self {
                latent_dim,
                num_harmonics,
            }),
        }
    }

    /// Columns contributed by a single axis: the constant plus a sine and cosine per harmonic.
    pub fn axis_basis_size(&self) -> usize {
        self.num_harmonics * 2 + 1
    }

    /// Total number of tensor-product columns, `axis_basis_size()^latent_dim`.
    ///
    /// # Panics
    /// Panics when the product does not fit in `usize`.
    pub fn basis_size(&self) -> usize {
        let per_axis = self.axis_basis_size();
        (0..self.latent_dim)
            .try_fold(1usize, |running, _| running.checked_mul(per_axis))
            .expect("TorusHarmonicEvaluator: basis size overflowed usize")
    }
}
impl SaeBasisEvaluator for TorusHarmonicEvaluator {
    /// Forwards to the [`SaeBasisSecondJet`] implementation so trait objects can reach it.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessian = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessian)
    }

    /// Evaluates the tensor-product harmonic basis and its gradient.
    ///
    /// Per axis the columns are `[1, sin(2*pi*t), cos(2*pi*t), sin(4*pi*t), ...]`. A flat
    /// column is the product of one per-axis column for every axis, with the last axis varying
    /// fastest. Returns `(phi, jet)` with shapes `(n, m)` and `(n, m, latent_dim)`.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `latent_dim`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator: expected latent_dim {d}, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let axis_m = self.axis_basis_size();
        let m = self.basis_size();
        let two_pi = 2.0 * std::f64::consts::PI;

        let mut phi = Array2::<f64>::zeros((n, m));
        let mut jet = Array3::<f64>::zeros((n, m, d));
        // Per-axis tables, refilled for every row: values and first derivatives.
        let mut value_tab = vec![vec![0.0_f64; axis_m]; d];
        let mut slope_tab = vec![vec![0.0_f64; axis_m]; d];

        for row in 0..n {
            for (axis, (values, slopes)) in
                value_tab.iter_mut().zip(slope_tab.iter_mut()).enumerate()
            {
                let t = coords[[row, axis]];
                values[0] = 1.0;
                slopes[0] = 0.0;
                for harmonic in 1..=self.num_harmonics {
                    let freq = two_pi * (harmonic as f64);
                    let angle = freq * t;
                    let sine = angle.sin();
                    let cosine = angle.cos();
                    let sin_col = 2 * harmonic - 1;
                    let cos_col = 2 * harmonic;
                    values[sin_col] = sine;
                    values[cos_col] = cosine;
                    slopes[sin_col] = freq * cosine;
                    slopes[cos_col] = -freq * sine;
                }
            }

            // Multi-index into the per-axis tables for the current flat column.
            let mut multi = vec![0usize; d];
            for flat in 0..m {
                phi[[row, flat]] =
                    (0..d).fold(1.0_f64, |acc, axis| acc * value_tab[axis][multi[axis]]);

                for target in 0..d {
                    // Differentiate only the factor that belongs to `target`.
                    jet[[row, flat, target]] = (0..d).fold(1.0_f64, |acc, axis| {
                        let table = if axis == target { &slope_tab } else { &value_tab };
                        acc * table[axis][multi[axis]]
                    });
                }

                // Odometer increment, last axis fastest.
                for axis in (0..d).rev() {
                    multi[axis] += 1;
                    if multi[axis] < axis_m {
                        break;
                    }
                    multi[axis] = 0;
                }
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for TorusHarmonicEvaluator {
    /// Second derivatives of the tensor-product harmonic basis.
    ///
    /// For the pair `(axis_a, axis_b)` each axis factor is differentiated once per match with
    /// `axis_a` and `axis_b`: twice on a diagonal entry, once for each of the two axes on an
    /// off-diagonal entry, and not at all otherwise. Returns shape
    /// `(n, m, latent_dim, latent_dim)`.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `latent_dim`.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator::second_jet expects latent_dim == {d}, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let axis_m = self.axis_basis_size();
        let m = self.basis_size();
        let two_pi = 2.0 * std::f64::consts::PI;
        let mut hess = Array4::<f64>::zeros((n, m, d, d));

        // `jets[order][axis][column]`: per-axis derivative of the given order (0, 1 or 2).
        let mut jets = [
            vec![vec![0.0_f64; axis_m]; d],
            vec![vec![0.0_f64; axis_m]; d],
            vec![vec![0.0_f64; axis_m]; d],
        ];

        for row in 0..n {
            for axis in 0..d {
                let t = coords[[row, axis]];
                jets[0][axis][0] = 1.0;
                jets[1][axis][0] = 0.0;
                jets[2][axis][0] = 0.0;
                for harmonic in 1..=self.num_harmonics {
                    let freq = two_pi * (harmonic as f64);
                    let freq_sq = freq * freq;
                    let angle = freq * t;
                    let sine = angle.sin();
                    let cosine = angle.cos();
                    let sin_col = 2 * harmonic - 1;
                    let cos_col = 2 * harmonic;
                    jets[0][axis][sin_col] = sine;
                    jets[0][axis][cos_col] = cosine;
                    jets[1][axis][sin_col] = freq * cosine;
                    jets[1][axis][cos_col] = -freq * sine;
                    jets[2][axis][sin_col] = -freq_sq * sine;
                    jets[2][axis][cos_col] = -freq_sq * cosine;
                }
            }

            // Multi-index into the per-axis tables for the current flat column.
            let mut multi = vec![0usize; d];
            for flat in 0..m {
                for axis_a in 0..d {
                    for axis_b in 0..d {
                        hess[[row, flat, axis_a, axis_b]] = (0..d).fold(1.0_f64, |acc, axis| {
                            // Derivative order = number of times this axis is differentiated.
                            let order = usize::from(axis == axis_a) + usize::from(axis == axis_b);
                            acc * jets[order][axis][multi[axis]]
                        });
                    }
                }

                // Odometer increment, last axis fastest.
                for axis in (0..d).rev() {
                    multi[axis] += 1;
                    if multi[axis] < axis_m {
                        break;
                    }
                    multi[axis] = 0;
                }
            }
        }
        Ok(hess)
    }
}
/// Affine basis `[1, t_0, ..., t_{d-1}]` over `latent_dim` coordinates.
#[derive(Debug, Clone)]
pub struct AffineCoordinateEvaluator {
    /// Number of latent coordinate columns the evaluator expects.
    pub latent_dim: usize,
}

impl AffineCoordinateEvaluator {
    /// Creates an evaluator for `latent_dim` coordinates; no validation is performed.
    pub fn new(latent_dim: usize) -> Self {
        AffineCoordinateEvaluator { latent_dim }
    }
}
impl SaeBasisEvaluator for AffineCoordinateEvaluator {
    /// Forwards to the [`SaeBasisSecondJet`] implementation so trait objects can reach it.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessian = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessian)
    }

    /// Evaluates `[1, t_0, ..., t_{d-1}]` and its constant gradient.
    ///
    /// Returns `(phi, jet)` with shapes `(n, d + 1)` and `(n, d + 1, d)`; `jet` holds a `1`
    /// at `[row, axis + 1, axis]` and zeros elsewhere.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `latent_dim`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let d = self.latent_dim;
        let found = coords.ncols();
        if found != d {
            return Err(format!(
                "AffineCoordinateEvaluator: expected latent_dim {d}, got {found}"
            ));
        }

        let n = coords.nrows();
        let columns = d + 1;
        let mut phi = Array2::<f64>::zeros((n, columns));
        let mut jet = Array3::<f64>::zeros((n, columns, d));

        // Intercept column.
        phi.column_mut(0).fill(1.0);
        for (row, point) in coords.outer_iter().enumerate() {
            for (axis, &coordinate) in point.iter().enumerate() {
                phi[[row, axis + 1]] = coordinate;
                jet[[row, axis + 1, axis]] = 1.0;
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for AffineCoordinateEvaluator {
    /// Second derivatives of the affine basis: an all-zero array of shape `(n, d + 1, d, d)`.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `latent_dim`.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let d = self.latent_dim;
        let found = coords.ncols();
        if found != d {
            return Err(format!(
                "AffineCoordinateEvaluator::second_jet: expected latent_dim {d}, got {found}"
            ));
        }
        let shape = (coords.nrows(), d + 1, d, d);
        Ok(Array4::<f64>::zeros(shape))
    }
}
#[derive(Debug, Clone)]
pub struct DuchonCoordinateEvaluator {
pub centers: Array2<f64>,
pub order: crate::basis::DuchonNullspaceOrder,
}
impl DuchonCoordinateEvaluator {
pub fn new(centers: Array2<f64>, m: usize) -> Result<Self, String> {
if centers.ncols() == 0 {
return Err("DuchonCoordinateEvaluator: centers must have at least one column".into());
}
if m == 0 {
return Err("DuchonCoordinateEvaluator: Duchon m must be at least 1".into());
}
let order = match m {
1 => crate::basis::DuchonNullspaceOrder::Zero,
2 => crate::basis::DuchonNullspaceOrder::Linear,
other => crate::basis::DuchonNullspaceOrder::Degree(other - 1),
};
Ok(Self { centers, order })
}
}
impl SaeBasisEvaluator for DuchonCoordinateEvaluator {
    /// Forwards to the [`SaeBasisSecondJet`] implementation so trait objects can reach it.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessian = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessian)
    }

    /// Delegates to `crate::basis::duchon_sae_atom_basis_with_jet` after checking the width
    /// of `coords` against the centers.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `centers.ncols()`, or the
    /// stringified error of the underlying basis routine.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let expected = self.centers.ncols();
        let found = coords.ncols();
        if found != expected {
            return Err(format!(
                "DuchonCoordinateEvaluator: expected latent_dim {expected}, got {found}"
            ));
        }
        match crate::basis::duchon_sae_atom_basis_with_jet(coords, self.centers.view(), self.order)
        {
            Ok(evaluated) => Ok(evaluated),
            Err(err) => Err(err.to_string()),
        }
    }
}
impl SaeBasisSecondJet for DuchonCoordinateEvaluator {
    /// Delegates to `crate::basis::duchon_sae_atom_second_jet` after checking the width of
    /// `coords` against the centers.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `centers.ncols()`, or the
    /// stringified error of the underlying basis routine.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let expected = self.centers.ncols();
        let found = coords.ncols();
        if found != expected {
            return Err(format!(
                "DuchonCoordinateEvaluator::second_jet: expected latent_dim {expected}, got {found}"
            ));
        }
        match crate::basis::duchon_sae_atom_second_jet(coords, self.centers.view(), self.order) {
            Ok(hessian) => Ok(hessian),
            Err(err) => Err(err.to_string()),
        }
    }
}
#[derive(Debug, Clone)]
pub struct EuclideanPatchEvaluator {
pub latent_dim: usize,
pub max_degree: usize,
}
impl EuclideanPatchEvaluator {
pub fn new(latent_dim: usize, max_degree: usize) -> Result<Self, String> {
if latent_dim == 0 {
return Err("EuclideanPatchEvaluator: latent_dim must be positive".into());
}
Ok(Self {
latent_dim,
max_degree,
})
}
fn order(&self) -> crate::basis::DuchonNullspaceOrder {
match self.max_degree {
0 => crate::basis::DuchonNullspaceOrder::Zero,
1 => crate::basis::DuchonNullspaceOrder::Linear,
k => crate::basis::DuchonNullspaceOrder::Degree(k),
}
}
}
impl SaeBasisEvaluator for EuclideanPatchEvaluator {
    /// Forwards to the [`SaeBasisSecondJet`] implementation so trait objects can reach it.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessian = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessian)
    }

    /// Evaluates every monomial listed by `crate::basis::monomial_exponents` and takes the
    /// first derivatives from `crate::basis::duchon_polynomial_first_derivative_nd`.
    ///
    /// # Errors
    /// Returns a message when `coords.ncols()` differs from `latent_dim`, or when the
    /// derivative array's shape is not `(n, m, latent_dim)`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "EuclideanPatchEvaluator: expected latent_dim {}, got {}",
                d,
                coords.ncols()
            ));
        }

        let exponents = crate::basis::monomial_exponents(d, self.max_degree);
        let n = coords.nrows();
        let m = exponents.len();

        let mut phi = Array2::<f64>::zeros((n, m));
        for (col, alpha) in exponents.iter().enumerate() {
            for row in 0..n {
                let mut monomial = 1.0_f64;
                for (axis, &exp) in alpha.iter().enumerate() {
                    // A zero exponent contributes a factor of one; skip the `powi` call.
                    if exp == 0 {
                        continue;
                    }
                    monomial *= coords[[row, axis]].powi(exp as i32);
                }
                phi[[row, col]] = monomial;
            }
        }

        let jet = crate::basis::duchon_polynomial_first_derivative_nd(coords, self.order());
        // Guard against the value and derivative routines disagreeing on column count.
        if jet.shape() != [n, m, self.latent_dim] {
            return Err(format!(
                "EuclideanPatchEvaluator: monomial jet shape {:?} disagrees with ({n}, {m}, {})",
                jet.shape(),
                self.latent_dim
            ));
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for EuclideanPatchEvaluator {
/// Second derivatives of every monomial column, shape `(n, m, latent_dim, latent_dim)`.
/// Entries whose derivative is identically zero are skipped and keep their initial zero.
/// Errors when `coords.ncols()` differs from `latent_dim`.
fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
if coords.ncols() != self.latent_dim {
return Err(format!(
"EuclideanPatchEvaluator::second_jet: expected latent_dim {}, got {}",
self.latent_dim,
coords.ncols()
));
}
let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
let n = coords.nrows();
let m = exponents.len();
let d = self.latent_dim;
let mut hess = Array4::<f64>::zeros((n, m, d, d));
for (col, alpha) in exponents.iter().enumerate() {
for a in 0..d {
// Differentiating along an axis with exponent 0 gives zero.
if alpha[a] == 0 {
continue;
}
for c in 0..d {
if a != c && alpha[c] == 0 {
continue;
}
// Leading coefficient: alpha_a * (alpha_a - 1) on the diagonal,
// alpha_a * alpha_c for a mixed derivative.
let lead = if a == c {
(alpha[a] as f64) * (alpha[a].saturating_sub(1) as f64)
} else {
(alpha[a] as f64) * (alpha[c] as f64)
};
if lead == 0.0 {
continue;
}
for row in 0..n {
let mut value = lead;
for axis in 0..d {
// Lower the exponent once per differentiation; when a == c both
// branches fire and the exponent drops by two.
let mut exp = alpha[axis];
if axis == a {
exp = exp.saturating_sub(1);
}
if axis == c {
exp = exp.saturating_sub(1);
}
if exp != 0 {
value *= coords[[row, axis]].powi(exp as i32);
}
}
hess[[row, col, a, c]] = value;
}
}
}
}
Ok(hess)
}
}
/// One SAE atom: a basis evaluated at `n` observations plus the decoder that maps the
/// basis to the output space. Shapes below are the ones `SaeManifoldAtom::new` enforces.
#[derive(Debug, Clone)]
pub struct SaeManifoldAtom {
// Label used in error messages.
pub name: String,
pub basis_kind: SaeAtomBasisKind,
// Size of the last axis of `basis_jacobian`.
pub latent_dim: usize,
// Basis values, shape (n, m).
pub basis_values: Array2<f64>,
// First derivatives of the basis, shape (n, m, latent_dim).
pub basis_jacobian: Array3<f64>,
// Decoder matrix, shape (m, p) with p > 0.
pub decoder_coefficients: Array2<f64>,
// Square penalty matrix, shape (m, m).
pub smooth_penalty: Array2<f64>,
// Evaluator used by `refresh_basis`; `None` until one is attached.
pub basis_evaluator: Option<Arc<dyn SaeBasisEvaluator>>,
// Second-derivative evaluator; set by `with_basis_second_jet`, cleared by `with_basis_evaluator`.
pub basis_second_jet: Option<Arc<dyn SaeBasisSecondJet>>,
}
impl SaeManifoldAtom {
    /// Builds an atom after checking that all array shapes agree.
    ///
    /// With `n = basis_values.nrows()`, `m = basis_values.ncols()` and
    /// `p = decoder_coefficients.ncols()`, the Jacobian must be `(n, m, latent_dim)`, the
    /// decoder must have `m` rows, the smooth penalty must be `(m, m)`, and `p` must be
    /// positive. No evaluator is attached.
    ///
    /// # Errors
    /// Returns a message describing the first shape violation found.
    #[must_use = "build error must be handled"]
    pub fn new(
        name: impl Into<String>,
        basis_kind: SaeAtomBasisKind,
        latent_dim: usize,
        basis_values: Array2<f64>,
        basis_jacobian: Array3<f64>,
        decoder_coefficients: Array2<f64>,
        smooth_penalty: Array2<f64>,
    ) -> Result<Self, String> {
        let (n, m) = (basis_values.nrows(), basis_values.ncols());
        let p = decoder_coefficients.ncols();

        if basis_jacobian.dim() != (n, m, latent_dim) {
            return Err(format!(
                "SaeManifoldAtom::new: basis_jacobian must be ({n}, {m}, {latent_dim}); got {:?}",
                basis_jacobian.dim()
            ));
        }
        let decoder_rows = decoder_coefficients.nrows();
        if decoder_rows != m {
            return Err(format!(
                "SaeManifoldAtom::new: decoder rows {decoder_rows} must equal basis size {m}"
            ));
        }
        if smooth_penalty.dim() != (m, m) {
            return Err(format!(
                "SaeManifoldAtom::new: smooth penalty must be ({m}, {m}); got {:?}",
                smooth_penalty.dim()
            ));
        }
        if p == 0 {
            return Err("SaeManifoldAtom::new: decoder output dimension must be positive".into());
        }

        Ok(Self {
            name: name.into(),
            basis_kind,
            latent_dim,
            basis_values,
            basis_jacobian,
            decoder_coefficients,
            smooth_penalty,
            basis_evaluator: None,
            basis_second_jet: None,
        })
    }

    /// Attaches a first-order evaluator and clears any second-jet evaluator.
    pub fn with_basis_evaluator(self, evaluator: Arc<dyn SaeBasisEvaluator>) -> Self {
        Self {
            basis_evaluator: Some(evaluator),
            basis_second_jet: None,
            ..self
        }
    }

    /// Attaches one evaluator as both the first-order and the second-jet evaluator.
    pub fn with_basis_second_jet(self, evaluator: Arc<dyn SaeBasisSecondJet>) -> Self {
        let first_order: Arc<dyn SaeBasisEvaluator> = evaluator.clone();
        Self {
            basis_evaluator: Some(first_order),
            basis_second_jet: Some(evaluator),
            ..self
        }
    }

    /// Re-evaluates the basis and its Jacobian at `coords` with the attached evaluator.
    ///
    /// # Errors
    /// Returns a message when no evaluator is attached, when the evaluator fails, or when
    /// the returned arrays do not have the shapes currently stored. The stored arrays are
    /// left untouched on error.
    pub fn refresh_basis(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
        let evaluator = match self.basis_evaluator.as_ref() {
            Some(evaluator) => evaluator,
            None => {
                return Err(format!(
                    "SaeManifoldAtom {} has no basis evaluator; caller must rebuild the term after each coordinate step",
                    self.name
                ));
            }
        };
        let (fresh_values, fresh_jacobian) = evaluator.evaluate(coords)?;

        let expected_values = self.basis_values.dim();
        if fresh_values.dim() != expected_values {
            return Err(format!(
                "SaeManifoldAtom::refresh_basis: evaluator returned Phi {:?}, expected {:?}",
                fresh_values.dim(),
                expected_values
            ));
        }
        let expected_jacobian = self.basis_jacobian.dim();
        if fresh_jacobian.dim() != expected_jacobian {
            return Err(format!(
                "SaeManifoldAtom::refresh_basis: evaluator returned jet {:?}, expected {:?}",
                fresh_jacobian.dim(),
                expected_jacobian
            ));
        }

        self.basis_values = fresh_values;
        self.basis_jacobian = fresh_jacobian;
        Ok(())
    }

    /// Number of observations (rows of `basis_values`).
    pub fn n_obs(&self) -> usize {
        self.basis_values.nrows()
    }

    /// Number of basis columns (columns of `basis_values`).
    pub fn basis_size(&self) -> usize {
        self.basis_values.ncols()
    }

    /// Decoder output dimension (columns of `decoder_coefficients`).
    pub fn output_dim(&self) -> usize {
        self.decoder_coefficients.ncols()
    }

    /// Decoded output for observation `row`: `basis_values[row, :] * decoder_coefficients`.
    pub fn decoded_row(&self, row: usize) -> Array1<f64> {
        let mut buffer = vec![0.0_f64; self.output_dim()];
        self.fill_decoded_row(row, &mut buffer);
        Array1::from_vec(buffer)
    }

    /// Writes the decoded output for observation `row` into `out`, overwriting its contents.
    ///
    /// # Panics
    /// Panics when `out.len()` differs from [`Self::output_dim`].
    pub fn fill_decoded_row(&self, row: usize, out: &mut [f64]) {
        self.accumulate_weighted_decoder_rows(out, |basis_col| self.basis_values[[row, basis_col]]);
    }

    /// Derivative of the decoded output for observation `row` along `latent_axis`.
    pub fn decoded_derivative_row(&self, row: usize, latent_axis: usize) -> Array1<f64> {
        let mut buffer = vec![0.0_f64; self.output_dim()];
        self.fill_decoded_derivative_row(row, latent_axis, &mut buffer);
        Array1::from_vec(buffer)
    }

    /// Writes the decoded derivative for observation `row` along `latent_axis` into `out`,
    /// overwriting its contents.
    ///
    /// # Panics
    /// Panics when `out.len()` differs from [`Self::output_dim`].
    pub fn fill_decoded_derivative_row(&self, row: usize, latent_axis: usize, out: &mut [f64]) {
        self.accumulate_weighted_decoder_rows(out, |basis_col| {
            self.basis_jacobian[[row, basis_col, latent_axis]]
        });
    }

    /// Zeroes `out`, then adds `weight_of(col) * decoder_coefficients[col, :]` for every basis
    /// column whose weight is not exactly zero.
    fn accumulate_weighted_decoder_rows(&self, out: &mut [f64], weight_of: impl Fn(usize) -> f64) {
        let p = self.output_dim();
        let m = self.basis_size();
        assert_eq!(out.len(), p);
        out.fill(0.0);
        for basis_col in 0..m {
            let weight = weight_of(basis_col);
            if weight == 0.0 {
                continue;
            }
            for (out_col, slot) in out.iter_mut().enumerate() {
                *slot += weight * self.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
}
/// Rule used to turn assignment logits into assignment weights, with its hyper-parameters.
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
    Softmax { temperature: f64, sparsity: f64 },
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    JumpReLU { temperature: f64, threshold: f64 },
}

impl AssignmentMode {
    /// Softmax mode with `sparsity` fixed to `1.0`.
    #[must_use]
    pub fn softmax(temperature: f64) -> Self {
        let sparsity = 1.0;
        Self::Softmax {
            temperature,
            sparsity,
        }
    }

    /// IBP-MAP mode with the given concentration `alpha`.
    #[must_use]
    pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
        Self::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        }
    }

    /// JumpReLU mode with the given `threshold`.
    #[must_use]
    pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
        Self::JumpReLU {
            temperature,
            threshold,
        }
    }

    /// Temperature shared by every variant.
    pub fn temperature(&self) -> f64 {
        match self {
            Self::Softmax { temperature, .. }
            | Self::IBPMap { temperature, .. }
            | Self::JumpReLU { temperature, .. } => *temperature,
        }
    }

    /// Replaces the temperature, leaving the mode untouched when the new value is rejected.
    ///
    /// # Errors
    /// Returns a message when `new_temperature` is not finite and positive.
    fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
        let acceptable = new_temperature.is_finite() && new_temperature > 0.0;
        if !acceptable {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
            ));
        }
        let slot = match self {
            Self::Softmax { temperature, .. }
            | Self::IBPMap { temperature, .. }
            | Self::JumpReLU { temperature, .. } => temperature,
        };
        *slot = new_temperature;
        Ok(())
    }

    /// Checks the temperature first, then the variant-specific parameter.
    ///
    /// # Errors
    /// Returns a message when the temperature, `sparsity` or `alpha` is not finite and
    /// positive, or when the JumpReLU `threshold` is not finite.
    fn validate(&self) -> Result<(), String> {
        let positive_finite = |value: f64| value.is_finite() && value > 0.0;

        let temperature = self.temperature();
        if !positive_finite(temperature) {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {temperature}"
            ));
        }

        match *self {
            Self::Softmax { sparsity, .. } if !positive_finite(sparsity) => Err(format!(
                "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
            )),
            Self::IBPMap { alpha, .. } if !positive_finite(alpha) => Err(format!(
                "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
            )),
            Self::JumpReLU { threshold, .. } if !threshold.is_finite() => Err(format!(
                "AssignmentMode::JumpReLU: threshold must be finite; got {threshold}"
            )),
            Self::Softmax { .. } | Self::IBPMap { .. } | Self::JumpReLU { .. } => Ok(()),
        }
    }
}
/// Per-observation assignment state: logits over atoms plus each atom's latent coordinates.
/// The shape relations below are the ones `SaeAssignment::with_mode` enforces.
#[derive(Debug, Clone)]
pub struct SaeAssignment {
// Assignment logits, shape (n_obs, K) with K the number of atoms.
pub logits: Array2<f64>,
// One coordinate block per atom (length K); each must report n_obs observations.
pub coords: Vec<LatentCoordValues>,
// Rule and hyper-parameters used to turn logits into assignments.
pub mode: AssignmentMode,
}
impl SaeAssignment {
    /// Builds a softmax assignment with the given temperature.
    ///
    /// # Errors
    /// Same as [`Self::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn new(
        logits: Array2<f64>,
        coords: Vec<LatentCoordValues>,
        temperature: f64,
    ) -> Result<Self, String> {
        let mode = AssignmentMode::softmax(temperature);
        Self::with_mode(logits, coords, mode)
    }

    /// Builds an assignment with an explicit mode.
    ///
    /// # Errors
    /// Returns a message when the mode fails validation, when `coords.len()` differs from the
    /// number of logit columns, or when any coordinate block's `n_obs()` differs from the
    /// number of logit rows.
    #[must_use = "build error must be handled"]
    pub fn with_mode(
        logits: Array2<f64>,
        coords: Vec<LatentCoordValues>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        mode.validate()?;

        let (n, k) = (logits.nrows(), logits.ncols());
        if coords.len() != k {
            return Err(format!(
                "SaeAssignment::new: coords length {} must equal K={k}",
                coords.len()
            ));
        }
        let mismatch = coords
            .iter()
            .enumerate()
            .find(|(_, coord)| coord.n_obs() != n);
        if let Some((atom, coord)) = mismatch {
            return Err(format!(
                "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
                coord.n_obs()
            ));
        }

        Ok(Self {
            logits,
            coords,
            mode,
        })
    }

    /// Number of observations (rows of `logits`).
    pub fn n_obs(&self) -> usize {
        self.logits.nrows()
    }

    /// Number of atoms (columns of `logits`).
    pub fn k_atoms(&self) -> usize {
        self.logits.ncols()
    }

    /// Sum of the latent dimensions of all coordinate blocks.
    pub fn total_coord_dim(&self) -> usize {
        self.coords.iter().map(|coord| coord.latent_dim()).sum()
    }

    /// Width of one observation's block in the flattened layout: logits plus all coordinates.
    pub fn row_block_dim(&self) -> usize {
        self.k_atoms() + self.total_coord_dim()
    }

    /// Start offset of every atom's coordinates inside a row block; the first block starts
    /// right after the `K` logits.
    pub fn coord_offsets(&self) -> Vec<usize> {
        let mut next_start = self.k_atoms();
        self.coords
            .iter()
            .map(|coord| {
                let start = next_start;
                next_start += coord.latent_dim();
                start
            })
            .collect()
    }

    /// Assignment weights for every observation, shape `(n_obs, K)`.
    ///
    /// # Panics
    /// Panics when [`Self::try_assignments`] would return an error.
    pub fn assignments(&self) -> Array2<f64> {
        self.try_assignments()
            .expect("assignment logits must be finite")
    }

    /// Fallible variant of [`Self::assignments`].
    ///
    /// # Errors
    /// Propagates the first error from [`Self::try_assignments_row`].
    pub fn try_assignments(&self) -> Result<Array2<f64>, String> {
        let (n, k) = (self.n_obs(), self.k_atoms());
        let mut table = Array2::<f64>::zeros((n, k));
        for row in 0..n {
            let weights = self.try_assignments_row(row)?;
            for (atom, slot) in table.row_mut(row).iter_mut().enumerate() {
                *slot = weights[atom];
            }
        }
        Ok(table)
    }

    /// Assignment weights for one observation.
    ///
    /// # Panics
    /// Panics when [`Self::try_assignments_row`] would return an error.
    pub fn assignments_row(&self, row: usize) -> Array1<f64> {
        match self.try_assignments_row(row) {
            Ok(weights) => weights,
            Err(err) => panic!("assignment logits must be finite: {err:?}"),
        }
    }

    /// Assignment weights for one observation under the configured mode.
    ///
    /// With a single atom the result is always `[1.0]`, whatever the mode.
    ///
    /// # Errors
    /// Returns the error from `validate_finite_logits` for this row.
    pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
        validate_finite_logits(self.logits.row(row), row)?;
        if self.k_atoms() == 1 {
            return Ok(Array1::from_vec(vec![1.0]));
        }
        let logits_row = self.logits.row(row);
        let weights = match self.mode {
            AssignmentMode::Softmax { temperature, .. } => softmax_row(logits_row, temperature),
            AssignmentMode::IBPMap {
                temperature, alpha, ..
            } => ibp_map_row(logits_row, temperature, alpha),
            AssignmentMode::JumpReLU {
                temperature,
                threshold,
            } => jumprelu_row(logits_row, temperature, threshold),
        };
        Ok(weights)
    }

    /// Flattens logits and coordinates into one vector of length `n_obs * row_block_dim()`.
    ///
    /// Each observation owns a contiguous block: its `K` logits first, then every atom's
    /// coordinates at the offsets given by [`Self::coord_offsets`].
    pub fn flatten_ext_coords(&self) -> Array1<f64> {
        let n = self.n_obs();
        let block = self.row_block_dim();
        let k = self.k_atoms();
        let offsets = self.coord_offsets();

        let mut flat = Array1::<f64>::zeros(n * block);
        for row in 0..n {
            let base = row * block;
            for (atom, &logit) in self.logits.row(row).iter().enumerate() {
                flat[base + atom] = logit;
            }
            for atom in 0..k {
                let coord = &self.coords[atom];
                let width = coord.latent_dim();
                let values = coord.row(row);
                for axis in 0..width {
                    flat[base + offsets[atom] + axis] = values[axis];
                }
            }
        }
        flat
    }

    /// Softmax assignment whose coordinate blocks are wrapped with `LatentIdMode::None`.
    ///
    /// # Errors
    /// Same as [`Self::new`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_no_gauge(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        temperature: f64,
    ) -> Result<Self, String> {
        let mut coords = Vec::with_capacity(coord_blocks.len());
        for block in &coord_blocks {
            coords.push(LatentCoordValues::from_matrix(
                block.view(),
                LatentIdMode::None,
            ));
        }
        Self::new(logits, coords, temperature)
    }

    /// Assignment with an explicit mode whose coordinate blocks are wrapped with
    /// `LatentIdMode::None`.
    ///
    /// # Errors
    /// Same as [`Self::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_mode(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        let mut coords = Vec::with_capacity(coord_blocks.len());
        for block in &coord_blocks {
            coords.push(LatentCoordValues::from_matrix(
                block.view(),
                LatentIdMode::None,
            ));
        }
        Self::with_mode(logits, coords, mode)
    }

    /// Like [`Self::from_blocks_with_mode`], pairing every coordinate block with a manifold.
    ///
    /// # Errors
    /// Returns a message when the two lists differ in length; otherwise same as
    /// [`Self::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_mode_and_manifolds(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        manifolds: Vec<LatentManifold>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        let (block_count, manifold_count) = (coord_blocks.len(), manifolds.len());
        if block_count != manifold_count {
            return Err(format!(
                "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {block_count} != manifold length {manifold_count}"
            ));
        }
        let mut coords = Vec::with_capacity(block_count);
        for (block, manifold) in coord_blocks.iter().zip(manifolds) {
            coords.push(LatentCoordValues::from_matrix_with_manifold(
                block.view(),
                LatentIdMode::None,
                manifold,
            ));
        }
        Self::with_mode(logits, coords, mode)
    }
}
/// Log-scale hyperparameters for the SAE-manifold term.
///
/// Every value is stored as a natural logarithm and exponentiated at the
/// point of use (`lambda_sparse`, `lambda_smooth`, and `exp()` on the ARD
/// entries).
#[derive(Debug, Clone)]
pub struct SaeManifoldRho {
/// Natural log of the sparsity strength returned by `lambda_sparse`.
pub log_lambda_sparse: f64,
/// Natural log of the decoder smoothness strength returned by `lambda_smooth`.
pub log_lambda_smooth: f64,
/// Log ARD precisions: one vector per atom, one entry per latent axis of
/// that atom (the term checks both lengths before using them).
pub log_ard: Vec<Array1<f64>>,
}
impl SaeManifoldRho {
#[must_use]
pub fn new(log_lambda_sparse: f64, log_lambda_smooth: f64, log_ard: Vec<Array1<f64>>) -> Self {
Self {
log_lambda_sparse,
log_lambda_smooth,
log_ard,
}
}
pub fn lambda_sparse(&self) -> f64 {
self.log_lambda_sparse.exp()
}
pub fn lambda_smooth(&self) -> f64 {
self.log_lambda_smooth.exp()
}
}
/// Per-row linear operators used to couple one observation row to the shared
/// beta vector and to its local block.
///
/// The semantics below describe the in-file implementation
/// (`SaeKroneckerRows`); other implementors are presumably expected to match
/// them — TODO confirm.
pub trait SaeKroneckerRow {
/// Applies the row's beta-side operator to `x_beta`, writing into `u_out`
/// (the in-file implementation overwrites `u_out`).
fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]);
/// Transposed counterpart of `apply_jbeta`: scatters `u` into `y_beta`
/// (the in-file implementation accumulates into `y_beta`).
fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]);
/// Applies the row's local Jacobian to `u`, writing into `w_out`.
fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]);
/// Transposed counterpart of `apply_l` (the in-file implementation
/// accumulates into `u_out` without clearing it first).
fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]);
}
/// Row-indexed storage backing the `SaeKroneckerRow` operators.
#[derive(Debug, Clone)]
pub struct SaeKroneckerRows {
/// Width of every contiguous beta slice and of every local-Jacobian row.
p: usize,
/// Per observation row: `(beta_base, weight)` pairs; each pair addresses the
/// `p` consecutive beta entries starting at `beta_base`.
a_phi: Vec<Vec<(usize, f64)>>,
/// Per observation row: local Jacobian flattened row-major with `p` columns
/// (entry `(c, j)` lives at `c * p + j`).
local_jac: Vec<Vec<f64>>,
}
impl SaeKroneckerRows {
pub fn new(p: usize, a_phi: Vec<Vec<(usize, f64)>>, local_jac: Vec<Vec<f64>>) -> Self {
assert_eq!(
a_phi.len(),
local_jac.len(),
"SaeKroneckerRows: a_phi rows ({}) != local_jac rows ({})",
a_phi.len(),
local_jac.len(),
);
Self {
p,
a_phi,
local_jac,
}
}
}
impl SaeKroneckerRow for SaeKroneckerRows {
    /// Overwrites `u_out` with the sum over the row's `(beta_base, weight)`
    /// pairs of `weight * x_beta[beta_base .. beta_base + p]`.
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]) {
        u_out.fill(0.0);
        let width = self.p;
        for &(base, weight) in self.a_phi[row].iter().filter(|entry| entry.1 != 0.0) {
            (0..width).for_each(|col| u_out[col] += weight * x_beta[base + col]);
        }
    }

    /// Adds `weight * u[..p]` into `y_beta[beta_base .. beta_base + p]` for
    /// every pair of the row; `y_beta` is accumulated into, not cleared.
    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]) {
        let width = self.p;
        for &(base, weight) in self.a_phi[row].iter().filter(|entry| entry.1 != 0.0) {
            (0..width).for_each(|col| y_beta[base + col] += weight * u[col]);
        }
    }

    /// Writes the product of the row's local Jacobian (row-major, `p`
    /// columns) with `u` into the leading entries of `w_out`.
    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]) {
        let jac = &self.local_jac[row];
        let width = self.p;
        let local_dim = jac.len() / width;
        for c in 0..local_dim {
            let jac_row = c * width;
            w_out[c] = (0..width).fold(0.0_f64, |acc, j| acc + jac[jac_row + j] * u[j]);
        }
    }

    /// Adds the transposed local Jacobian applied to `v` into `u_out`;
    /// `u_out` is accumulated into, not cleared.
    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]) {
        let jac = &self.local_jac[row];
        let width = self.p;
        let local_dim = jac.len() / width;
        for c in 0..local_dim {
            let scale = v[c];
            if scale != 0.0 {
                let jac_row = c * width;
                (0..width).for_each(|j| u_out[j] += jac[jac_row + j] * scale);
            }
        }
    }
}
/// Breakdown of the SAE-manifold objective into its four additive parts, as
/// produced by `SaeManifoldTerm::loss_scaled`.
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLoss {
/// Half the sum of squared residuals between target and fitted values.
pub data_fit: f64,
/// Value of the assignment prior.
pub assignment_sparsity: f64,
/// Decoder smoothness penalty (already multiplied by the penalty scale).
pub smoothness: f64,
/// ARD term over the latent coordinates.
pub ard: f64,
}
impl SaeManifoldLoss {
    /// Sum of all four loss components.
    pub const fn total(&self) -> f64 {
        let Self {
            data_fit,
            assignment_sparsity,
            smoothness,
            ard,
        } = *self;
        data_fit + assignment_sparsity + smoothness + ard
    }

    /// Negated total loss.
    pub const fn evidence_proxy(&self) -> f64 {
        let total = self.total();
        -total
    }
}
/// Compact per-row layout: which atoms each row keeps and where their
/// entries sit in the compact and in the full row vectors.
#[derive(Debug, Clone)]
pub struct SaeRowLayout {
/// For each row, the indices of the atoms kept in its compact block.
pub active_atoms: Vec<Vec<usize>>,
/// For each row and each kept atom (same order as `active_atoms`), the
/// start of that atom's coordinates inside the compact row vector. The
/// compact vector begins with one slot per kept atom, then the coordinates.
pub coord_starts: Vec<Vec<usize>>,
/// For each atom, the start of its coordinates inside the full row vector.
pub coord_offsets_full: Vec<usize>,
/// For each atom, the number of latent coordinates it owns.
pub coord_dims: Vec<usize>,
}
impl SaeRowLayout {
fn from_jumprelu(
n: usize,
k_atoms: usize,
threshold: f64,
logits: &Array2<f64>,
coord_dims: Vec<usize>,
coord_offsets_full: Vec<usize>,
) -> Self {
let mut per_row = Vec::with_capacity(n);
for row in 0..n {
let row_logits = logits.row(row);
let active: Vec<usize> = (0..k_atoms)
.filter(|&k| row_logits[k] > threshold)
.collect();
per_row.push(active);
}
Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
}
fn from_dense_weights(
assignments: &[Array1<f64>],
k_active_cap: usize,
cutoff: f64,
coord_dims: Vec<usize>,
coord_offsets_full: Vec<usize>,
) -> Self {
let cap = k_active_cap.max(1);
let mut per_row = Vec::with_capacity(assignments.len());
for a in assignments {
let k = a.len();
let mut idx: Vec<usize> = (0..k).collect();
idx.sort_by(|&i, &j| {
a[j].abs()
.partial_cmp(&a[i].abs())
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut active: Vec<usize> = idx
.iter()
.copied()
.take(cap)
.filter(|&k_idx| a[k_idx].abs() > cutoff)
.collect();
if active.is_empty() {
if let Some(&top) = idx.first() {
active.push(top);
}
}
active.sort_unstable();
per_row.push(active);
}
Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
}
fn from_active_atoms(
active_atoms: Vec<Vec<usize>>,
coord_dims: Vec<usize>,
coord_offsets_full: Vec<usize>,
) -> Self {
let mut coord_starts_all = Vec::with_capacity(active_atoms.len());
for active in &active_atoms {
let mut starts = Vec::with_capacity(active.len());
let mut cursor = active.len();
for &k in active {
starts.push(cursor);
cursor += coord_dims[k];
}
coord_starts_all.push(starts);
}
Self {
active_atoms,
coord_starts: coord_starts_all,
coord_offsets_full,
coord_dims,
}
}
pub fn row_q_active(&self, row: usize) -> usize {
let active = &self.active_atoms[row];
let coord_sum: usize = active.iter().map(|&k| self.coord_dims[k]).sum();
active.len() + coord_sum
}
pub fn expand_row(&self, row: usize, delta_t_row: &[f64], out: &mut [f64]) {
for v in out.iter_mut() {
*v = 0.0;
}
let active = &self.active_atoms[row];
let starts = &self.coord_starts[row];
for (j, &k) in active.iter().enumerate() {
out[k] = delta_t_row[j];
let d = self.coord_dims[k];
let full_off = self.coord_offsets_full[k];
for axis in 0..d {
out[full_off + axis] = delta_t_row[starts[j] + axis];
}
}
}
}
/// SAE-manifold model term: a set of decoder atoms plus the per-row
/// assignment (logits and latent coordinates) that mixes them.
#[derive(Debug, Clone)]
pub struct SaeManifoldTerm {
/// Decoder atoms; the constructor requires at least one and checks that all
/// share the same observation count and output dimension.
pub atoms: Vec<SaeManifoldAtom>,
/// Per-row assignment state (logits, latent coordinates, assignment mode).
pub assignment: SaeAssignment,
/// Optional temperature schedule; when set, advancing it pushes the new
/// temperature into the assignment mode.
temperature_schedule: Option<GumbelTemperatureSchedule>,
/// Row layout recorded by the most recent arrow-Schur assembly (`None` when
/// that assembly used the full layout or none has run yet).
last_row_layout: Option<SaeRowLayout>,
}
/// Owned copy of the parts of a term that change during optimization.
// NOTE(review): not used anywhere in this part of the file; presumably a
// snapshot for save/restore around trial steps — confirm against the callers.
#[derive(Debug)]
struct SaeManifoldMutableState {
/// One tuple of arrays per atom. NOTE(review): the meaning of the three
/// arrays is not visible here — confirm where the snapshot is built.
atoms: Vec<(Array2<f64>, Array3<f64>, Array2<f64>)>,
/// Assignment logits.
logits: Array2<f64>,
/// Latent coordinate values, presumably one entry per atom.
coords: Vec<LatentCoordValues>,
}
impl SaeManifoldTerm {
/// Builds a term after checking that the atoms and the assignment agree.
///
/// # Errors
/// Returns an error when `atoms` is empty, when the assignment's shape does
/// not match `(n_obs, atoms.len())`, or when any atom disagrees with atom 0
/// on `n_obs` / `output_dim` or with its assignment coordinate block on the
/// latent dimension.
#[must_use = "build error must be handled"]
pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
    let Some(reference) = atoms.first() else {
        return Err(String::from(
            "SaeManifoldTerm::new: at least one atom required",
        ));
    };
    // Atom 0 defines the shape every other atom must share.
    let n = reference.n_obs();
    let p = reference.output_dim();
    let shape_matches = assignment.n_obs() == n && assignment.k_atoms() == atoms.len();
    if !shape_matches {
        return Err(format!(
            "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
            assignment.n_obs(),
            assignment.k_atoms(),
            atoms.len()
        ));
    }
    for (k, atom) in atoms.iter().enumerate() {
        let atom_n = atom.n_obs();
        if atom_n != n {
            return Err(format!(
                "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
                atom_n
            ));
        }
        let atom_p = atom.output_dim();
        if atom_p != p {
            return Err(format!(
                "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
                atom_p
            ));
        }
        let coord_dim = assignment.coords[k].latent_dim();
        if atom.latent_dim != coord_dim {
            return Err(format!(
                "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
                atom.latent_dim, coord_dim
            ));
        }
    }
    Ok(Self {
        atoms,
        assignment,
        temperature_schedule: None,
        last_row_layout: None,
    })
}
/// Installs a temperature schedule and immediately applies its temperature
/// at the schedule's current `iter_count` to the assignment mode.
///
/// # Errors
/// Fails when the schedule does not validate or the assignment mode rejects
/// the temperature; in both cases the stored schedule is left untouched.
pub fn set_temperature_schedule(
    &mut self,
    sched: GumbelTemperatureSchedule,
) -> Result<(), String> {
    sched.validate()?;
    let initial_tau = sched.current_tau(sched.iter_count);
    self.assignment.mode.set_temperature(initial_tau)?;
    self.temperature_schedule = Some(sched);
    Ok(())
}
/// Steps the temperature schedule, if one is installed, and pushes the new
/// temperature into the assignment mode.
///
/// Returns `Ok(None)` when no schedule is installed, otherwise the
/// temperature that was applied.
///
/// # Errors
/// Fails when the schedule does not validate or the assignment mode rejects
/// the temperature.
fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
    match self.temperature_schedule.as_mut() {
        None => Ok(None),
        Some(schedule) => {
            schedule.validate()?;
            let tau = schedule.step();
            self.assignment.mode.set_temperature(tau)?;
            Ok(Some(tau))
        }
    }
}
/// Number of observation rows, as reported by the assignment.
pub fn n_obs(&self) -> usize {
self.assignment.n_obs()
}
/// Number of decoder atoms held by this term.
pub fn k_atoms(&self) -> usize {
self.atoms.len()
}
/// Checks that every Psi-tier penalty in `registry` can be expressed in this
/// term's row layout.
///
/// Penalties of other tiers are ignored. The logit-targeting kinds
/// (`IBPAssignment`, `SoftmaxAssignmentSparsity`) are always accepted. Any
/// other Psi-tier penalty must be row-block supported, and if at least one
/// such penalty is present all atoms must share one coordinate latent
/// dimension.
///
/// # Errors
/// Returns a descriptive message for the first unsupported penalty, or for a
/// latent-dimension mismatch between atoms.
fn validate_analytic_penalty_registry(
    &self,
    registry: &AnalyticPenaltyRegistry,
) -> Result<(), String> {
    let mut row_block_penalty_present = false;
    // Borrow the penalty list; the original text carried a mis-encoded `&`
    // here (`®istry`), which is not valid Rust.
    for penalty in &registry.penalties {
        if penalty.tier() != PenaltyTier::Psi {
            continue;
        }
        let is_logit = matches!(
            penalty,
            AnalyticPenaltyKind::IBPAssignment(_)
                | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
        );
        if is_logit {
            continue;
        }
        if !sae_penalty_is_row_block_supported(penalty) {
            return Err(format!(
                "SAE-manifold term refuses analytic penalty {:?}: this kind \
                 has cross-row structure and cannot be expressed in the \
                 arrow-Schur row layout. Use only row-block-supported \
                 coord penalties (ARD, BlockOrthogonality, \
                 Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
                 ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
                 coord latent block, or move the penalty to a non-SAE \
                 term",
                penalty.name()
            ));
        }
        row_block_penalty_present = true;
    }
    if row_block_penalty_present {
        // Row-block penalties are dispatched per atom against one declared
        // latent dimension, so mixed dimensions are rejected up front.
        let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
        if let Some(first) = dims.next() {
            if let Some(mismatch) = dims.find(|d| *d != first) {
                return Err(format!(
                    "SAE-manifold term refuses row-block analytic penalty: \
                     atoms have heterogeneous coord latent dims (saw {first} \
                     and {mismatch}). Row-block penalties (ARD, \
                     BlockOrthogonality, ...) target the unified \"t\" \
                     latent block whose declared `d` matches one shape; \
                     per-atom dispatch with mixed `d_k` would silently \
                     truncate or expand axes. Configure all atoms with the \
                     same `atom_dim`, or split the row-block penalty into \
                     per-atom descriptors keyed to per-atom latent blocks"
                ));
            }
        }
    }
    Ok(())
}
/// Output dimension of the term, read from atom 0.
///
/// Indexes `atoms[0]` directly, so it panics if `atoms` is empty (the
/// constructor rejects an empty atom list, but `atoms` is a public field).
pub fn output_dim(&self) -> usize {
self.atoms[0].output_dim()
}
/// Total length of the flattened decoder coefficient vector: the sum over
/// atoms of `basis_size * output_dim`.
pub fn beta_dim(&self) -> usize {
    let p = self.output_dim();
    let mut total = 0usize;
    for atom in &self.atoms {
        total += atom.basis_size() * p;
    }
    total
}
/// Start position of every atom's coefficients inside the flattened beta
/// vector, in atom order.
pub fn beta_offsets(&self) -> Vec<usize> {
    let p = self.output_dim();
    self.atoms
        .iter()
        .scan(0usize, |cursor, atom| {
            let start = *cursor;
            *cursor += atom.basis_size() * p;
            Some(start)
        })
        .collect()
}
/// Half-open index range of every atom's coefficients inside the flattened
/// beta vector, in atom order, shared behind an `Arc`.
pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
    let p = self.output_dim();
    let ranges: Vec<std::ops::Range<usize>> = self
        .atoms
        .iter()
        .scan(0usize, |cursor, atom| {
            let start = *cursor;
            *cursor = start + atom.basis_size() * p;
            Some(start..*cursor)
        })
        .collect();
    Arc::from(ranges)
}
/// Decides whether dense assignment modes should use a compact per-row
/// layout, and with which parameters.
///
/// Returns `None` (stay dense) when there is at most one atom, when the
/// extended coordinate manifold is not Euclidean, when a dense basis Gram
/// matrix of `m_total × m_total` doubles fits the memory budget, or when the
/// output dimension is zero. Otherwise returns `(k_active_cap, cutoff)`: the
/// maximum number of atoms kept per row and the relative weight cutoff.
fn sparse_active_plan(&self) -> Option<(usize, f64)> {
    const BYTES_PER_F64: usize = 8;
    const HOST_GRAM_BYTES: usize = 2 * 1024 * 1024 * 1024;
    const RELATIVE_CUTOFF: f64 = 1.0e-3;

    let k_atoms = self.k_atoms();
    if k_atoms <= 1 || !self.ext_coord_manifold().is_euclidean() {
        return None;
    }
    let p = self.output_dim();
    let m_total = self
        .atoms
        .iter()
        .fold(0usize, |acc, atom| acc + atom.basis_size());
    let dense_gram_bytes = m_total
        .saturating_mul(m_total)
        .saturating_mul(BYTES_PER_F64);
    // A quarter of the GPU budget when a runtime exists, else a fixed host cap.
    let budget = crate::gpu::runtime::GpuRuntime::global()
        .map_or(HOST_GRAM_BYTES, |rt| rt.memory_budget_bytes / 4);
    if dense_gram_bytes <= budget {
        return None;
    }
    // Largest number of average-sized atoms whose Gram matrix fits the budget.
    let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
    let max_active_basis = ((budget as f64 / BYTES_PER_F64 as f64).sqrt() / m_atom).floor();
    let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
    (p != 0).then_some((k_active_cap, RELATIVE_CUTOFF))
}
/// Flattens all decoder coefficients into one vector.
///
/// Atom blocks follow each other in atom order; within a block the entry for
/// `(basis_col, out_col)` sits at `basis_col * output_dim + out_col`.
pub fn flatten_beta(&self) -> Array1<f64> {
    let p = self.output_dim();
    let offsets = self.beta_offsets();
    let mut flat = Array1::<f64>::zeros(self.beta_dim());
    for (atom, &start) in self.atoms.iter().zip(offsets.iter()) {
        let coeffs = &atom.decoder_coefficients;
        for basis_col in 0..atom.basis_size() {
            let row_start = start + basis_col * p;
            for out_col in 0..p {
                flat[row_start + out_col] = coeffs[[basis_col, out_col]];
            }
        }
    }
    flat
}
/// Writes a flattened coefficient vector (same layout as `flatten_beta`)
/// back into the atoms' decoder coefficient matrices.
///
/// # Errors
/// Returns an error, leaving the atoms untouched, when `beta` does not have
/// exactly `beta_dim()` entries.
pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
    let expected = self.beta_dim();
    if beta.len() != expected {
        return Err(format!(
            "set_flat_beta: beta length {} != expected {}",
            beta.len(),
            expected
        ));
    }
    let p = self.output_dim();
    let offsets = self.beta_offsets();
    for (atom, &start) in self.atoms.iter_mut().zip(offsets.iter()) {
        for basis_col in 0..atom.basis_size() {
            let row_start = start + basis_col * p;
            for out_col in 0..p {
                atom.decoder_coefficients[[basis_col, out_col]] = beta[row_start + out_col];
            }
        }
    }
    Ok(())
}
/// Fitted values for every row; the panicking counterpart of `try_fitted`.
///
/// # Panics
/// Panics if `try_fitted` returns an error.
pub fn fitted(&self) -> Array2<f64> {
self.try_fitted().expect("assignment logits must be finite")
}
/// Computes the fitted `(n_obs, output_dim)` matrix: for every row, the
/// assignment-weighted sum of each atom's decoded output.
///
/// # Errors
/// Propagates the error from `try_assignments_row` for the first row whose
/// assignment weights cannot be computed.
pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
    let n = self.n_obs();
    let p = self.output_dim();
    let mut out = Array2::<f64>::zeros((n, p));
    // Scratch buffer reused for every decoded atom row.
    let mut g_buf = vec![0.0_f64; p];
    for row in 0..n {
        let weights = self.assignment.try_assignments_row(row)?;
        let mut out_row = out.row_mut(row);
        for (atom_idx, atom) in self.atoms.iter().enumerate() {
            atom.fill_decoded_row(row, &mut g_buf);
            let a_k = weights[atom_idx];
            for (dst, &g) in out_row.iter_mut().zip(&g_buf[..p]) {
                *dst += a_k * g;
            }
        }
    }
    Ok(out)
}
/// Evaluates the loss breakdown with a penalty scale of `1.0`; see
/// `loss_scaled` for the components and error conditions.
pub fn loss(
&self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
) -> Result<SaeManifoldLoss, String> {
self.loss_scaled(target, rho, 1.0)
}
/// Evaluates the loss breakdown for `target` under `rho`.
///
/// `data_fit` is half the sum of squared residuals against the fitted
/// values. Only the decoder smoothness component is multiplied by
/// `penalty_scale`; the assignment prior and ARD values are not.
///
/// # Errors
/// Fails when `penalty_scale` is not finite and positive, when `target` is
/// not `(n_obs, output_dim)`, or when the fitted values or the ARD term
/// cannot be computed.
pub fn loss_scaled(
    &self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    penalty_scale: f64,
) -> Result<SaeManifoldLoss, String> {
    let scale_ok = penalty_scale.is_finite() && penalty_scale > 0.0;
    if !scale_ok {
        return Err(format!(
            "SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
        ));
    }
    let expected_shape = (self.n_obs(), self.output_dim());
    if target.dim() != expected_shape {
        return Err(format!(
            "SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
            expected_shape.0,
            expected_shape.1,
            target.dim()
        ));
    }
    let fitted = self.try_fitted()?;
    // Row-major accumulation over every (row, column) entry of the target.
    let mut data_fit = 0.0_f64;
    for ((row, out_col), &z) in target.indexed_iter() {
        let r = z - fitted[[row, out_col]];
        data_fit += 0.5 * r * r;
    }
    let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
    let smoothness = penalty_scale * self.decoder_smoothness_value(rho.lambda_smooth());
    let ard = self.ard_value(rho)?;
    Ok(SaeManifoldLoss {
        data_fit,
        assignment_sparsity,
        smoothness,
        ard,
    })
}
/// Sum over atoms of `0.5 * lambda_smooth * Σ (B ∘ (S·B))`, where `B` is the
/// atom's decoder coefficient matrix and `S` its smoothness penalty matrix.
fn decoder_smoothness_value(&self, lambda_smooth: f64) -> f64 {
    self.atoms.iter().fold(0.0, |total, atom| {
        let coeffs = &atom.decoder_coefficients;
        let penalized = atom.smooth_penalty.dot(coeffs);
        total + 0.5 * lambda_smooth * (coeffs * &penalized).sum()
    })
}
/// ARD term over the latent coordinates: for every atom and axis,
/// `0.5 * α * Σ_rows t² − 0.5 * n * log α` with `α = exp(log_ard)`.
///
/// # Errors
/// Fails when `rho.log_ard` does not have one vector per atom, or when an
/// atom's vector length differs from that atom's latent dimension.
fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
    let k_atoms = self.k_atoms();
    if rho.log_ard.len() != k_atoms {
        return Err(format!(
            "ARD rho has {} atoms but term has {}",
            rho.log_ard.len(),
            k_atoms
        ));
    }
    let n = self.n_obs();
    let half_n = 0.5 * (n as f64);
    let mut total = 0.0;
    for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
        let d = coord.latent_dim();
        let log_alphas = &rho.log_ard[atom_idx];
        if log_alphas.len() != d {
            return Err(format!(
                "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
                log_alphas.len()
            ));
        }
        for (axis, &log_alpha) in log_alphas.iter().enumerate() {
            // Squared norm of this axis across all rows.
            let sq = (0..n).fold(0.0_f64, |acc, row| {
                let v = coord.row(row)[axis];
                acc + v * v
            });
            total += 0.5 * log_alpha.exp() * sq - half_n * log_alpha;
        }
    }
    Ok(total)
}
/// Assembles the arrow-Schur system with a penalty scale of `1.0`; see
/// `assemble_arrow_schur_scaled` for details and error conditions.
pub fn assemble_arrow_schur(
&mut self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
) -> Result<ArrowSchurSystem, String> {
self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, 1.0)
}
/// Assembles the arrow-structured system for the current term state.
///
/// For every observation row this builds a row block holding the gradient
/// (`gt`) and a `J Jᵀ`-style curvature block (`htt`) over the row's logits and
/// latent coordinates, plus assignment-prior and ARD contributions. The
/// shared beta gradient (`gb`) receives the smoothness and data-fit
/// gradients. The row↔beta coupling is installed as a matrix-free operator
/// backed by `SaeKroneckerRows`, and the beta-side curvature as a composite
/// penalty operator (smoothness Kronecker ops, sparse data Gram blocks, and
/// optionally the dense `hbb` written by analytic beta penalties).
///
/// JumpReLU assignments always use a compact per-row layout; Softmax / IBP
/// assignments use one only when `sparse_active_plan` asks for it. The layout
/// used (or `None`) is stored in `self.last_row_layout` on success.
///
/// # Errors
/// Fails when `penalty_scale` is not finite and positive, when `target` is
/// not `(n_obs, output_dim)`, when `rho.log_ard` has the wrong shape, when
/// assignment weights or the assignment prior cannot be computed, or when the
/// analytic penalty registry is rejected or its contributions fail.
pub fn assemble_arrow_schur_scaled(
&mut self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
penalty_scale: f64,
) -> Result<ArrowSchurSystem, String> {
// --- Input validation -------------------------------------------------
if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
));
}
if target.dim() != (self.n_obs(), self.output_dim()) {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
self.n_obs(),
self.output_dim(),
target.dim()
));
}
if rho.log_ard.len() != self.k_atoms() {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
rho.log_ard.len(),
self.k_atoms()
));
}
let n = self.n_obs();
let p = self.output_dim();
let k_atoms = self.k_atoms();
let q = self.assignment.row_block_dim();
let beta_dim = self.beta_dim();
let beta_offsets = self.beta_offsets();
let coord_offsets = self.assignment.coord_offsets();
// Only the smoothness strength is multiplied by the penalty scale here.
let lambda_smooth = rho.lambda_smooth() * penalty_scale;
let (assignment_grad, assignment_hdiag) =
assignment_prior_grad_hdiag(&self.assignment, rho)?;
// --- Decoder smoothness: gradient and Kronecker penalty operators -----
let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
for (atom_idx, atom) in self.atoms.iter().enumerate() {
let m = atom.basis_size();
let off = beta_offsets[atom_idx];
// Symmetrized smoothness matrix scaled by lambda_smooth.
let mut scaled_s = Array2::<f64>::zeros((m, m));
for i in 0..m {
for j in 0..m {
let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
scaled_s[[i, j]] = lambda_smooth * s_ij;
}
}
// Smoothness gradient (scaled S times the coefficients), scattered into
// the flattened beta layout (basis index major, output column minor).
let sb = scaled_s.dot(&atom.decoder_coefficients);
for out_col in 0..p {
for i in 0..m {
let beta_i = off + i * p + out_col;
smooth_grad_gb[beta_i] += sb[[i, out_col]];
}
}
let identity_p = Array2::<f64>::eye(p);
smooth_ops.push(Arc::new(KroneckerPenaltyOp {
factor_a: scaled_s,
factor_b: identity_p,
global_offset: off,
k: beta_dim,
}));
}
// --- Row layout selection ---------------------------------------------
let coord_dims: Vec<usize> = self
.assignment
.coords
.iter()
.map(|c| c.latent_dim())
.collect();
// JumpReLU always compacts rows to the atoms above threshold; the dense
// modes compact only when sparse_active_plan() returns a plan.
let row_layout: Option<SaeRowLayout> = match self.assignment.mode {
AssignmentMode::JumpReLU { threshold, .. } => Some(SaeRowLayout::from_jumprelu(
n,
k_atoms,
threshold,
&self.assignment.logits,
coord_dims.clone(),
self.assignment.coord_offsets(),
)),
AssignmentMode::Softmax { .. } | AssignmentMode::IBPMap { .. } => {
match self.sparse_active_plan() {
Some((k_active_cap, relative_cutoff)) => {
let mut assignments_all = Vec::with_capacity(n);
for row in 0..n {
assignments_all.push(self.assignment.try_assignments_row(row)?);
}
// The cutoff is relative to the largest absolute weight over all rows.
let peak = assignments_all
.iter()
.flat_map(|a| a.iter())
.fold(0.0_f64, |m, &v| m.max(v.abs()));
let cutoff = relative_cutoff * peak;
Some(SaeRowLayout::from_dense_weights(
&assignments_all,
k_active_cap,
cutoff,
coord_dims.clone(),
self.assignment.coord_offsets(),
))
}
None => None,
}
}
};
// Row blocks have per-row sizes under a compact layout, else the full q.
let mut sys = if let Some(ref layout) = row_layout {
let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
ArrowSchurSystem::new_with_per_row_dims(per_row_dims, beta_dim)
} else {
ArrowSchurSystem::new(n, q, beta_dim)
};
for (i, g) in smooth_grad_gb.iter().enumerate() {
sys.gb[i] += g;
}
// --- Scratch buffers reused across rows -------------------------------
let mut decoded = Array2::<f64>::zeros((k_atoms, p));
let mut dg_buf = vec![0.0_f64; p];
let mut fitted = Array1::<f64>::zeros(p);
let mut error = Array1::<f64>::zeros(p);
let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
// Atom offsets measured in basis functions rather than beta entries.
let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
// Data Gram blocks keyed by (atom_i, atom_j), accumulated over rows.
let mut g_blocks: std::collections::BTreeMap<(usize, usize), Array2<f64>> =
std::collections::BTreeMap::new();
let ibp_prior_vec = match self.assignment.mode {
AssignmentMode::IBPMap { alpha, .. } => {
Some(ibp_stick_breaking_prior(k_atoms, alpha).to_vec())
}
_ => None,
};
let ibp_prior_slice = ibp_prior_vec.as_deref();
let mut decoded_scratch = vec![0.0_f64; p];
let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
// --- Per-row assembly ---------------------------------------------------
for row in 0..n {
let assignments = self.assignment.try_assignments_row(row)?;
fitted.fill(0.0);
let row_active_owned: Option<&[usize]> =
row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
// Decode the contributing atoms and accumulate the fitted row. Under a
// compact layout only the active atoms are decoded; `decoded` rows of
// inactive atoms keep values from earlier rows and are not read below.
match row_active_owned {
Some(active) => {
for &atom_idx in active {
let a_k = assignments[atom_idx];
self.atoms[atom_idx].fill_decoded_row(row, &mut decoded_scratch);
for out_col in 0..p {
decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
fitted[out_col] += a_k * decoded_scratch[out_col];
}
}
}
None => {
for atom_idx in 0..k_atoms {
let a_k = assignments[atom_idx];
self.atoms[atom_idx].fill_decoded_row(row, &mut decoded_scratch);
for out_col in 0..p {
decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
fitted[out_col] += a_k * decoded_scratch[out_col];
}
}
}
}
// Residual with the sign convention fitted − target.
for out_col in 0..p {
error[out_col] = fitted[out_col] - target[[row, out_col]];
}
// Local Jacobian of the fitted row with respect to the row's own
// variables: logit rows first, then one row per latent coordinate axis.
let (q_row, local_jac_row) = if let Some(ref layout) = row_layout {
let active = &layout.active_atoms[row];
let starts = &layout.coord_starts[row];
let q_active = layout.row_q_active(row);
let mut jac_compact = Array2::<f64>::zeros((q_active, p));
let logits_row = self.assignment.logits.row(row);
for (j, &k) in active.iter().enumerate() {
fill_active_atom_logit_jvp(
self.assignment.mode,
k,
logits_row[k],
assignments[k],
decoded.row(k),
fitted.view(),
ibp_prior_slice,
&mut jac_compact,
j,
);
}
for (j, &k) in active.iter().enumerate() {
let d = self.atoms[k].latent_dim;
let a_k = assignments[k];
let coord_start = starts[j];
for axis in 0..d {
self.atoms[k].fill_decoded_derivative_row(row, axis, &mut dg_buf);
for out_col in 0..p {
jac_compact[[coord_start + axis, out_col]] = a_k * dg_buf[out_col];
}
}
}
(q_active, jac_compact)
} else {
let mut jac_row = Array2::<f64>::zeros((q, p));
fill_assignment_logit_jvp_rows(
self.assignment.mode,
self.assignment.logits.row(row),
assignments.view(),
decoded.view(),
fitted.view(),
ibp_prior_slice,
&mut jac_row,
);
for atom_idx in 0..k_atoms {
let d = self.atoms[atom_idx].latent_dim;
let off = coord_offsets[atom_idx];
let a_k = assignments[atom_idx];
for axis in 0..d {
self.atoms[atom_idx].fill_decoded_derivative_row(row, axis, &mut dg_buf);
for out_col in 0..p {
jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
}
}
}
(q, jac_row)
};
// Data-fit gradient (J · error) and curvature (J · Jᵀ) for the row block.
let mut block = ArrowRowBlock::new(q_row, beta_dim);
for a in 0..q_row {
let mut g = 0.0;
for out_col in 0..p {
g += local_jac_row[[a, out_col]] * error[out_col];
}
block.gt[a] += g;
for b in 0..q_row {
let mut h = 0.0;
for out_col in 0..p {
h += local_jac_row[[a, out_col]] * local_jac_row[[b, out_col]];
}
block.htt[[a, b]] += h;
}
}
// Assignment prior: gradient and diagonal curvature on the logit slots.
// The prior arrays are flat with k_atoms entries per row.
let assignment_base = row * k_atoms;
if let Some(ref layout) = row_layout {
let active = &layout.active_atoms[row];
for (j, &k) in active.iter().enumerate() {
block.gt[j] += assignment_grad[assignment_base + k];
block.htt[[j, j]] += assignment_hdiag[assignment_base + k];
}
} else {
for atom_idx in 0..k_atoms {
block.gt[atom_idx] += assignment_grad[assignment_base + atom_idx];
block.htt[[atom_idx, atom_idx]] += assignment_hdiag[assignment_base + atom_idx];
}
}
// ARD on the latent coordinates: gradient α·t and diagonal curvature α,
// with α = exp(log_ard).
if let Some(ref layout) = row_layout {
let active = &layout.active_atoms[row];
let starts = &layout.coord_starts[row];
for (j, &k) in active.iter().enumerate() {
let coord = &self.assignment.coords[k];
let d = coord.latent_dim();
if rho.log_ard[k].len() != d {
return Err(format!(
"ARD rho atom {k} has len {} but atom dim is {d}",
rho.log_ard[k].len()
));
}
let row_t = coord.row(row);
for axis in 0..d {
let alpha = rho.log_ard[k][axis].exp();
block.gt[starts[j] + axis] += alpha * row_t[axis];
block.htt[[starts[j] + axis, starts[j] + axis]] += alpha;
}
}
} else {
for atom_idx in 0..k_atoms {
let coord = &self.assignment.coords[atom_idx];
let d = coord.latent_dim();
if rho.log_ard[atom_idx].len() != d {
return Err(format!(
"ARD rho atom {atom_idx} has len {} but atom dim is {d}",
rho.log_ard[atom_idx].len()
));
}
let off = coord_offsets[atom_idx];
let row_t = coord.row(row);
for axis in 0..d {
let alpha = rho.log_ard[atom_idx][axis].exp();
block.gt[off + axis] += alpha * row_t[axis];
block.htt[[off + axis, off + axis]] += alpha;
}
}
}
// Beta side: assignment-weighted basis values for the contributing atoms.
let row_active: &[usize] = match row_layout {
Some(ref layout) => layout.active_atoms[row].as_slice(),
None => &all_atoms_index,
};
let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
let mut weighted_phi: Vec<(usize, Vec<f64>)> = Vec::with_capacity(row_active.len());
for &atom_idx in row_active {
let atom = &self.atoms[atom_idx];
let atom_beta_off = beta_offsets[atom_idx];
let m = atom.basis_size();
let a_k = assignments[atom_idx];
let mut wphi = Vec::with_capacity(m);
for basis_col in 0..m {
let phi = atom.basis_values[[row, basis_col]];
let w = a_k * phi;
a_phi.push((atom_beta_off + basis_col * p, w));
wphi.push(w);
}
weighted_phi.push((atom_idx, wphi));
}
// Data-fit gradient with respect to beta.
for &(beta_base_i, j_beta_i) in a_phi.iter() {
if j_beta_i == 0.0 {
continue;
}
for out_col in 0..p {
sys.gb[beta_base_i + out_col] += j_beta_i * error[out_col];
}
}
// Accumulate the outer products of weighted basis values into the
// per-atom-pair Gram blocks.
for ai in 0..weighted_phi.len() {
let (atom_i, ref wphi_i) = weighted_phi[ai];
let m_i = wphi_i.len();
for aj in 0..weighted_phi.len() {
let (atom_j, ref wphi_j) = weighted_phi[aj];
let m_j = wphi_j.len();
let blk = g_blocks
.entry((atom_i, atom_j))
.or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
for li in 0..m_i {
let wi = wphi_i[li];
if wi == 0.0 {
continue;
}
for lj in 0..m_j {
blk[[li, lj]] += wi * wphi_j[lj];
}
}
}
}
// Keep the row's operator data for the matrix-free coupling below; the
// local Jacobian is stored row-major with p columns.
kron_a_phi.push(a_phi);
let mut jac_flat = vec![0.0_f64; q_row * p];
for c in 0..q_row {
for j in 0..p {
jac_flat[c * p + j] = local_jac_row[[c, j]];
}
}
kron_jac.push(jac_flat);
sys.rows[row] = block;
}
// --- Riemannian geometry (full layout only) ----------------------------
// NOTE(review): this step is skipped whenever a compact layout is active.
// sparse_active_plan() only compacts Euclidean coordinates, but the
// JumpReLU branch above does not check the manifold — confirm intended.
if row_layout.is_none() {
self.apply_sae_riemannian_geometry(&mut sys);
let manifold = self.ext_coord_manifold();
if !manifold.is_euclidean() {
// Project every column of each stored local Jacobian onto the tangent
// space at that row's extended coordinates.
let ext = self.ext_coord_matrix();
let mut t_buf = vec![0.0_f64; q];
let mut col_buf = Array1::<f64>::zeros(q);
for row_idx in 0..n {
let ext_row = ext.row(row_idx);
for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
*slot = v;
}
let t_i = ArrayView1::from(t_buf.as_slice());
let jac_flat = &mut kron_jac[row_idx];
let q_row = jac_flat.len() / p;
for j in 0..p {
for c in 0..q_row {
col_buf[c] = jac_flat[c * p + j];
}
let projected_col =
manifold.project_to_tangent(t_i, col_buf.slice(ndarray::s![..q_row]));
for c in 0..q_row {
jac_flat[c * p + j] = projected_col[c];
}
}
}
}
}
// --- Matrix-free row↔beta coupling operator ----------------------------
{
let kron = Arc::new(SaeKroneckerRows::new(p, kron_a_phi, kron_jac));
let kron_t = Arc::clone(&kron);
sys.set_row_htbeta_operator(
move |row_idx, x, out| {
let out_slice = out.as_slice_mut().expect("out is always standard-layout");
// Non-contiguous inputs are copied into a temporary vector first.
if let Some(xs) = x.as_slice() {
kron.apply_jbeta(row_idx, xs, out_slice);
} else {
let x_vec: Vec<f64> = x.iter().copied().collect();
kron.apply_jbeta(row_idx, &x_vec, out_slice);
}
},
move |row_idx, v, out| {
let out_slice = out.as_slice_mut().expect("out is always standard-layout");
if let Some(vs) = v.as_slice() {
kron_t.scatter_jbeta_t(row_idx, vs, out_slice);
} else {
let v_vec: Vec<f64> = v.iter().copied().collect();
kron_t.scatter_jbeta_t(row_idx, &v_vec, out_slice);
}
},
);
}
// --- Analytic penalties --------------------------------------------------
// NOTE(review): the helpers reached from here index row blocks with
// full-layout positions (atom index, coord_offsets). Confirm this is safe
// when a compact row_layout is active, where row blocks are smaller.
let mut beta_penalty_written = false;
if let Some(registry) = analytic_penalties {
self.validate_analytic_penalty_registry(registry)
.map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
beta_penalty_written = self
.add_sae_analytic_penalty_contributions(&mut sys, registry, penalty_scale)
.map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
}
sys.set_block_offsets(self.beta_block_offsets());
// --- Composite beta penalty operator ------------------------------------
{
// Drop Gram blocks that are entirely zero.
let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
.into_iter()
.filter_map(|((atom_i, atom_j), data)| {
if data.iter().all(|&v| v == 0.0) {
None
} else {
Some(SparseGBlock {
row_off: mu_offsets[atom_i],
col_off: mu_offsets[atom_j],
data,
})
}
})
.collect();
let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
p,
dim_a: m_total,
k: beta_dim,
blocks: g_sparse_blocks,
}));
// The dense hbb is included only when an analytic beta penalty was applied.
if beta_penalty_written {
ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
}
sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
}
self.last_row_layout = row_layout;
Ok(sys)
}
/// Reshapes the flattened extended coordinates into an
/// `(n_obs, row_block_dim)` matrix, one row per observation.
fn ext_coord_matrix(&self) -> Array2<f64> {
    let n = self.n_obs();
    let q = self.assignment.row_block_dim();
    let flat = self.assignment.flatten_ext_coords();
    Array2::from_shape_fn((n, q), |(row, col)| flat[row * q + col])
}
/// Manifold of one extended coordinate row.
///
/// Returns plain `Euclidean` when every coordinate block is Euclidean.
/// Otherwise returns a `Product` whose parts are: one Euclidean part per
/// atom logit, then per coordinate block either one Euclidean part per axis
/// (Euclidean blocks) or the block's own manifold as a single part.
fn ext_coord_manifold(&self) -> LatentManifold {
    let euclidean_parts = |count: usize| {
        std::iter::repeat_with(|| LatentManifold::Euclidean).take(count)
    };
    let mut parts = Vec::with_capacity(self.assignment.row_block_dim());
    parts.extend(euclidean_parts(self.k_atoms()));
    let mut any_constrained = false;
    for coord in &self.assignment.coords {
        let manifold = coord.manifold();
        if manifold.is_euclidean() {
            parts.extend(euclidean_parts(coord.latent_dim()));
        } else {
            any_constrained = true;
            parts.push(manifold.clone());
        }
    }
    match any_constrained {
        true => LatentManifold::Product(parts),
        false => LatentManifold::Euclidean,
    }
}
/// Applies the Riemannian latent geometry of the extended coordinates to
/// `sys`; does nothing when the extended manifold is Euclidean.
fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
    let manifold = self.ext_coord_manifold();
    if !manifold.is_euclidean() {
        let ext = self.ext_coord_matrix();
        let latent = LatentCoordValues::from_matrix_with_manifold(
            ext.view(),
            LatentIdMode::None,
            manifold,
        );
        sys.apply_riemannian_latent_geometry(&latent);
    }
}
/// Re-estimates every ARD log-precision in place as `ln(n / ‖t‖²)`, clamped
/// to `[-8, 16]`, where `‖t‖²` is the squared norm of that axis over all
/// rows.
///
/// Axes whose squared norm is below `1e-10` keep their current value and a
/// warning is logged instead.
///
/// # Errors
/// Fails when `rho.log_ard` does not have one vector per atom, or when an
/// atom's vector length differs from that atom's latent dimension. Entries
/// of earlier atoms may already have been updated when a later atom fails.
pub fn update_ard_reml(&self, rho: &mut SaeManifoldRho) -> Result<(), String> {
    let k_atoms = self.k_atoms();
    if rho.log_ard.len() != k_atoms {
        return Err(format!(
            "SaeManifoldTerm::update_ard_reml: log_ard length {} != K {}",
            rho.log_ard.len(),
            k_atoms
        ));
    }
    let n = self.n_obs() as f64;
    for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
        let d = coord.latent_dim();
        let log_ard_atom = &mut rho.log_ard[atom_idx];
        if log_ard_atom.len() != d {
            return Err(format!(
                "SaeManifoldTerm::update_ard_reml: atom {atom_idx} log_ard length {} != dim {d}",
                log_ard_atom.len()
            ));
        }
        let n_rows = coord.n_obs();
        for axis in 0..d {
            let sq = (0..n_rows)
                .map(|row| coord.row(row)[axis])
                .fold(0.0_f64, |acc, v| acc + v * v);
            if sq < 1.0e-10 {
                log::warn!(
                    "[SAE-ARD] update_ard_reml: atom {atom_idx} axis {axis} coordinate \
                     variance ‖t‖²={sq:.3e} below 1e-10; preserving prior log_ard={} rather \
                     than letting α=n/‖t‖² saturate the clamp ceiling",
                    log_ard_atom[axis],
                );
                continue;
            }
            log_ard_atom[axis] = (n / sq).ln().clamp(-8.0, 16.0);
        }
    }
    Ok(())
}
/// Adds the contributions of every penalty in `registry` to `sys`.
///
/// Psi-tier logit penalties (`IBPAssignment`, `SoftmaxAssignmentSparsity`)
/// go to the logit slots of each row block; other Psi-tier penalties are
/// applied per atom to that atom's coordinate slots; Beta-tier penalties go
/// through `add_sae_beta_penalty`; Rho-tier penalties are ignored here.
/// All penalties are evaluated with their rho parameters set to zero.
///
/// Returns `true` when at least one Beta-tier penalty was applied.
///
/// # Errors
/// Returns `ArrowSchurError::SchurFactorFailed` when the Isometry caches
/// cannot be refreshed from an atom, when an evaluator's second jet has an
/// unexpected shape or fails, or when an atom has no analytic second jet.
///
/// # Panics
/// Panics if a non-logit Psi-tier penalty is not row-block supported.
fn add_sae_analytic_penalty_contributions(
&self,
sys: &mut ArrowSchurSystem,
registry: &AnalyticPenaltyRegistry,
penalty_scale: f64,
) -> Result<bool, ArrowSchurError> {
// All-zero rho vector sized for the whole registry; each penalty sees its
// own slice of it.
let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
let layout = registry.rho_layout();
let logits_flat = flat_logits(self.assignment.logits.view());
let beta = self.flatten_beta();
let mut beta_penalty_written = false;
for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
let rho_local = rho_global.slice(s![rho_slice.clone()]);
match tier {
PenaltyTier::Psi => {
if matches!(
penalty,
AnalyticPenaltyKind::IBPAssignment(_)
| AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
) {
self.add_sae_logit_penalty(sys, penalty, logits_flat.view(), rho_local);
} else {
assert!(
sae_penalty_is_row_block_supported(penalty),
"validate_analytic_penalty_registry should have \
refused non-row-block Psi-tier penalty {:?} \
(registry layout name {name:?})",
penalty.name()
);
let offsets = self.assignment.coord_offsets();
for atom_idx in 0..self.k_atoms() {
let off = offsets[atom_idx];
let coord = &self.assignment.coords[atom_idx];
if let AnalyticPenaltyKind::Isometry(iso) = penalty {
// Isometry needs per-atom state: work on a clone whose output
// dimension matches this atom's decoder.
let atom = &self.atoms[atom_idx];
let p = atom.decoder_coefficients.ncols();
let mut corrected: IsometryPenalty = (**iso).clone();
corrected.p_out = p;
let coords_mat = coord.as_matrix();
let second_jet_installed = refresh_isometry_caches_from_atom(
&corrected,
atom,
coords_mat.view(),
)
.map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
// Fallback: build the second-derivative cache from the atom's
// basis evaluator when the refresh did not install one.
if !second_jet_installed {
match atom
.basis_evaluator
.as_ref()
.and_then(|e| e.second_jet_dyn(coords_mat.view()))
{
Some(Ok(hess)) => {
let n_obs = coords_mat.nrows();
let d = atom.latent_dim;
let m = atom.basis_size();
if hess.dim() != (n_obs, m, d, d) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"SAE Isometry atom '{}': second_jet_dyn \
returned shape {:?}, expected \
({n_obs}, {m}, {d}, {d})",
atom.name,
hess.dim()
),
});
}
// Contract the basis second derivatives with the decoder
// coefficients; column index is (i * d + a) * d + c.
let b = &atom.decoder_coefficients;
let mut jac2 = Array2::<f64>::zeros((n_obs, p * d * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
for c in 0..d {
let mut acc = 0.0;
for mm in 0..m {
acc += hess[[n, mm, a, c]]
* b[[mm, i]];
}
jac2[[n, (i * d + a) * d + c]] = acc;
}
}
}
}
corrected
.set_jacobian_second_cache(Some(Arc::new(jac2)));
}
Some(Err(reason)) => {
return Err(ArrowSchurError::SchurFactorFailed {
reason,
});
}
None => {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"IsometryPenalty requested for SAE atom \
'{}' (basis kind {:?}) but this evaluator \
does not expose an analytic second jet; \
use AffineCoordinateEvaluator, \
SphereChartEvaluator, \
PeriodicHarmonicEvaluator, or \
TorusHarmonicEvaluator for SAE-Isometry",
atom.name, atom.basis_kind
),
});
}
}
}
let corrected_kind =
AnalyticPenaltyKind::Isometry(Arc::new(corrected));
self.add_sae_coord_penalty(
sys,
off,
coord,
&corrected_kind,
rho_local,
);
} else {
self.add_sae_coord_penalty(sys, off, coord, penalty, rho_local);
}
}
}
}
PenaltyTier::Beta => {
self.add_sae_beta_penalty(sys, penalty, beta.view(), rho_local, penalty_scale);
beta_penalty_written = true;
}
// Rho-tier penalties contribute nothing to this system.
PenaltyTier::Rho => {}
}
}
Ok(beta_penalty_written)
}
/// Adds a logit-tier penalty's gradient and (when available) diagonal
/// Hessian to the per-row blocks of `sys`.
///
/// The penalty is evaluated on the flat `target` vector, which is read
/// with the layout `row * k_atoms + atom`. Entry `(row, atom)` is added to
/// slot `atom` of that row's `gt` and to the `(atom, atom)` diagonal of
/// that row's `htt`. When the penalty offers no diagonal Hessian only the
/// gradient is written.
fn add_sae_logit_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let n_rows = self.n_obs();
    let n_atoms = self.k_atoms();

    let gradient = penalty.grad_target(target, rho_local);
    for row_idx in 0..n_rows {
        let flat_base = row_idx * n_atoms;
        for atom_idx in 0..n_atoms {
            sys.rows[row_idx].gt[atom_idx] += gradient[flat_base + atom_idx];
        }
    }

    let curvature = match penalty.hessian_diag(target, rho_local) {
        Some(curvature) => curvature,
        None => return,
    };
    for row_idx in 0..n_rows {
        let flat_base = row_idx * n_atoms;
        for atom_idx in 0..n_atoms {
            sys.rows[row_idx].htt[[atom_idx, atom_idx]] += curvature[flat_base + atom_idx];
        }
    }
}
/// Adds a coordinate-tier penalty for one atom's latent coordinates to the
/// per-row blocks of `sys`.
///
/// `off` is the offset of this atom's coordinate slots inside each row
/// block. The penalty is evaluated on the flat coordinate vector of
/// `coord`, read with the layout `row * latent_dim + axis`.
///
/// Curvature is taken from `hessian_diag` when the penalty provides one.
/// Otherwise one Hessian-vector product is issued per axis, with a probe
/// that is 1 on that axis in every row, and the result is written into
/// column `off + axis` of each row's `htt` block.
fn add_sae_coord_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    off: usize,
    coord: &LatentCoordValues,
    penalty: &AnalyticPenaltyKind,
    rho_local: ArrayView1<'_, f64>,
) {
    let n_rows = coord.n_obs();
    let dim = coord.latent_dim();
    let target = coord.as_flat().view();

    let gradient = penalty.grad_target(target, rho_local);
    for row in 0..n_rows {
        let flat_base = row * dim;
        for axis in 0..dim {
            sys.rows[row].gt[off + axis] += gradient[flat_base + axis];
        }
    }

    match penalty.hessian_diag(target, rho_local) {
        Some(curvature) => {
            for row in 0..n_rows {
                let flat_base = row * dim;
                for axis in 0..dim {
                    sys.rows[row].htt[[off + axis, off + axis]] += curvature[flat_base + axis];
                }
            }
        }
        None => {
            // No diagonal available: probe the Hessian one axis at a time.
            let mut probe = Array1::<f64>::zeros(n_rows * dim);
            for axis in 0..dim {
                probe.fill(0.0);
                for row in 0..n_rows {
                    probe[row * dim + axis] = 1.0;
                }
                let response = penalty.hvp(target, rho_local, probe.view());
                for row in 0..n_rows {
                    let flat_base = row * dim;
                    for other in 0..dim {
                        sys.rows[row].htt[[off + other, off + axis]] +=
                            response[flat_base + other];
                    }
                }
            }
        }
    }
}
/// Adds a beta-tier penalty, multiplied by `penalty_scale`, to the shared
/// `gb` / `hbb` blocks of `sys`.
///
/// `MechanismSparsity` penalties are applied atom by atom: the base
/// penalty is cloned, retargeted at that atom's slice of the flat beta
/// vector, and handed to `add_sae_mech_sparsity_atom`.
///
/// Every other penalty is evaluated on the whole of `target_beta`. Its
/// curvature comes from `hessian_diag` when available, otherwise from one
/// Hessian-vector product per unit vector, filling `hbb` column by column.
fn add_sae_beta_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    penalty: &AnalyticPenaltyKind,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    penalty_scale: f64,
) {
    if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
        let offsets = self.beta_offsets();
        let out_dim = self.output_dim();
        for (atom_idx, atom) in self.atoms.iter().enumerate() {
            let basis_len = atom.basis_size();
            let block_start = offsets[atom_idx];
            let block_end = block_start + basis_len * out_dim;
            let mut scoped: MechanismSparsityPenalty = (**base).clone();
            scoped.target = PsiSlice {
                range: block_start..block_end,
                latent_dim: Some(basis_len),
            };
            self.add_sae_mech_sparsity_atom(
                sys,
                &scoped,
                target_beta,
                rho_local,
                block_start,
                block_end,
                penalty_scale,
            );
        }
        return;
    }

    let dim = self.beta_dim();
    let gradient = penalty.grad_target(target_beta, rho_local);
    for idx in 0..dim {
        sys.gb[idx] += penalty_scale * gradient[idx];
    }

    match penalty.hessian_diag(target_beta, rho_local) {
        Some(curvature) => {
            for idx in 0..dim {
                sys.hbb[[idx, idx]] += penalty_scale * curvature[idx];
            }
        }
        None => {
            // Recover the Hessian one column at a time from unit-vector probes.
            let mut unit = Array1::<f64>::zeros(dim);
            for col in 0..dim {
                unit.fill(0.0);
                unit[col] = 1.0;
                let column = penalty.hvp(target_beta, rho_local, unit.view());
                for row in 0..dim {
                    sys.hbb[[row, col]] += penalty_scale * column[row];
                }
            }
        }
    }
}
/// Adds one atom's mechanism-sparsity contribution to `sys.gb` and
/// `sys.hbb`, restricted to the beta index range `start..end`.
///
/// The gradient entries in `start..end` are scaled by `penalty_scale` and
/// added to `gb`. The Hessian is probed with one unit vector per index in
/// the range, and only the `start..end` entries of each product are added
/// to `hbb`.
fn add_sae_mech_sparsity_atom(
    &self,
    sys: &mut ArrowSchurSystem,
    per_atom: &MechanismSparsityPenalty,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    start: usize,
    end: usize,
    penalty_scale: f64,
) {
    let gradient = per_atom.grad_target(target_beta, rho_local);
    for idx in start..end {
        sys.gb[idx] += penalty_scale * gradient[idx];
    }

    let mut unit = Array1::<f64>::zeros(self.beta_dim());
    for col in start..end {
        unit.fill(0.0);
        unit[col] = 1.0;
        let column = per_atom.hvp(target_beta, rho_local, unit.view());
        for row in start..end {
            sys.hbb[[row, col]] += penalty_scale * column[row];
        }
    }
}
/// Assembles the Arrow-Schur system for the current state and solves it
/// for a Newton direction.
///
/// Returns the pair `(delta_t, delta_beta)` produced by
/// `solve_with_lm_escalation`; the solver's third output is discarded.
///
/// # Errors
///
/// An assembly failure is reported as `ArrowSchurError::SchurFactorFailed`
/// carrying the assembly error as its reason. Solver errors are returned
/// unchanged.
pub fn solve_newton_step(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
    let sys = match self.assemble_arrow_schur(target, rho, analytic_penalties) {
        Ok(assembled) => assembled,
        Err(reason) => return Err(ArrowSchurError::SchurFactorFailed { reason }),
    };
    match sys.solve_with_lm_escalation(ridge_ext_coord, ridge_beta) {
        Ok((delta_t, delta_beta, _diag)) => Ok((delta_t, delta_beta)),
        Err(err) => Err(err),
    }
}
/// Applies a Newton step scaled by `step_size` and asks the shared
/// implementation to refresh each atom's basis at the moved coordinates.
///
/// # Errors
///
/// Forwards any error from `apply_newton_step_impl`.
pub fn apply_newton_step(
    &mut self,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    step_size: f64,
) -> Result<(), String> {
    let refresh_basis = true;
    self.apply_newton_step_impl(delta_ext_coord, delta_beta, step_size, refresh_basis)
}
/// Applies a Newton step scaled by `step_size` without refreshing the
/// atoms' bases; the shared implementation is called with the refresh
/// flag off.
///
/// # Errors
///
/// Forwards any error from `apply_newton_step_impl`.
pub fn apply_newton_step_external_basis_refresh(
    &mut self,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    step_size: f64,
) -> Result<(), String> {
    let refresh_basis = false;
    self.apply_newton_step_impl(delta_ext_coord, delta_beta, step_size, refresh_basis)
}
/// Captures a deep copy of the state restored by `restore_mutable_state`:
/// each atom's basis values, basis Jacobian and decoder coefficients,
/// plus the assignment logits and latent coordinates.
fn snapshot_mutable_state(&self) -> SaeManifoldMutableState {
    SaeManifoldMutableState {
        atoms: self
            .atoms
            .iter()
            .map(|atom| {
                let values = atom.basis_values.clone();
                let jacobian = atom.basis_jacobian.clone();
                let decoder = atom.decoder_coefficients.clone();
                (values, jacobian, decoder)
            })
            .collect(),
        logits: self.assignment.logits.clone(),
        coords: self.assignment.coords.clone(),
    }
}
/// Writes a snapshot taken by `snapshot_mutable_state` back into `self`.
///
/// Atoms are paired with snapshot entries positionally. Array contents
/// are copied into the existing arrays rather than replacing them.
fn restore_mutable_state(&mut self, snapshot: &SaeManifoldMutableState) {
    for (atom, saved) in self.atoms.iter_mut().zip(snapshot.atoms.iter()) {
        let (saved_values, saved_jacobian, saved_decoder) = saved;
        atom.basis_values.assign(saved_values);
        atom.basis_jacobian.assign(saved_jacobian);
        atom.decoder_coefficients.assign(saved_decoder);
    }
    self.assignment.coords.clone_from(&snapshot.coords);
    self.assignment.logits.assign(&snapshot.logits);
}
/// Moves the logits, latent coordinates and decoder coefficients along a
/// Newton direction scaled by `step_size`.
///
/// `delta_ext_coord` holds the per-row (logit + coordinate) part of the
/// direction. When `self.last_row_layout` is set, it is read as a compact
/// vector containing `row_q_active(row)` entries per row and is expanded
/// to the full `row_block_dim()` width with `expand_row`. Otherwise it
/// must already have length `n_obs * row_block_dim()`. Inside each full
/// row block the first `k_atoms` slots are logit deltas and the slots at
/// `coord_offsets()[atom]` are that atom's coordinate deltas.
///
/// Logits are updated additively. Coordinates are updated through
/// `retract_flat_delta`. `delta_beta` is added to the flattened beta
/// vector and written back with `set_flat_beta`. When `refresh_basis` is
/// true each atom's basis is refreshed at its moved coordinates.
///
/// Every component is scaled by `step_size` in both layouts.
///
/// # Errors
///
/// Returns an error when `step_size` is not finite and positive, when
/// either delta has an unexpected length, or when `refresh_basis` /
/// `set_flat_beta` fails. State already modified before a late failure is
/// not rolled back here.
fn apply_newton_step_impl(
    &mut self,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    step_size: f64,
    refresh_basis: bool,
) -> Result<(), String> {
    if !(step_size.is_finite() && step_size > 0.0) {
        return Err(format!(
            "SaeManifoldTerm::apply_newton_step: step_size must be finite and positive; got {step_size}"
        ));
    }
    let n = self.n_obs();
    let q = self.assignment.row_block_dim();
    let k_atoms = self.k_atoms();
    if delta_beta.len() != self.beta_dim() {
        return Err(format!(
            "SaeManifoldTerm::apply_newton_step: delta_beta length {} != expected {}",
            delta_beta.len(),
            self.beta_dim()
        ));
    }
    // The layout is cloned so that `self` can be mutated while it is in use.
    if let Some(ref layout) = self.last_row_layout.clone() {
        let total_len: usize = (0..n).map(|row| layout.row_q_active(row)).sum();
        if delta_ext_coord.len() != total_len {
            return Err(format!(
                "SaeManifoldTerm::apply_newton_step: compact delta_ext_coord length {} != expected {}",
                delta_ext_coord.len(),
                total_len
            ));
        }
        // Expand the compact per-row deltas to the full row-block width.
        let mut full_delta = vec![0.0_f64; n * q];
        let mut compact_off = 0usize;
        for row in 0..n {
            let q_active = layout.row_q_active(row);
            let compact_row: Vec<f64> = delta_ext_coord
                .slice(ndarray::s![compact_off..compact_off + q_active])
                .iter()
                .copied()
                .collect();
            layout.expand_row(row, &compact_row, &mut full_delta[row * q..(row + 1) * q]);
            compact_off += q_active;
        }
        for row in 0..n {
            let row_base = row * q;
            for atom_idx in 0..k_atoms {
                self.assignment.logits[[row, atom_idx]] +=
                    step_size * full_delta[row_base + atom_idx];
            }
        }
        let coord_offsets = self.assignment.coord_offsets();
        for atom_idx in 0..k_atoms {
            let d = self.assignment.coords[atom_idx].latent_dim();
            let mut delta_coord = Array1::<f64>::zeros(n * d);
            for row in 0..n {
                let row_base = row * q + coord_offsets[atom_idx];
                for axis in 0..d {
                    // Scale by `step_size`, matching the logits above and the
                    // dense branch below; without it a backtracking line
                    // search cannot shorten the coordinate move.
                    delta_coord[row * d + axis] = step_size * full_delta[row_base + axis];
                }
            }
            self.assignment.coords[atom_idx].retract_flat_delta(delta_coord.view());
            if refresh_basis {
                let coords = self.assignment.coords[atom_idx].as_matrix();
                self.atoms[atom_idx].refresh_basis(coords.view())?;
            }
        }
    } else {
        if delta_ext_coord.len() != n * q {
            return Err(format!(
                "SaeManifoldTerm::apply_newton_step: delta_ext_coord length {} != expected {}",
                delta_ext_coord.len(),
                n * q
            ));
        }
        let coord_offsets = self.assignment.coord_offsets();
        for row in 0..n {
            let row_base = row * q;
            for atom_idx in 0..k_atoms {
                self.assignment.logits[[row, atom_idx]] +=
                    step_size * delta_ext_coord[row_base + atom_idx];
            }
        }
        for atom_idx in 0..k_atoms {
            let d = self.assignment.coords[atom_idx].latent_dim();
            let mut delta_coord = Array1::<f64>::zeros(n * d);
            for row in 0..n {
                let row_base = row * q + coord_offsets[atom_idx];
                for axis in 0..d {
                    delta_coord[row * d + axis] = step_size * delta_ext_coord[row_base + axis];
                }
            }
            self.assignment.coords[atom_idx].retract_flat_delta(delta_coord.view());
            if refresh_basis {
                let coords = self.assignment.coords[atom_idx].as_matrix();
                self.atoms[atom_idx].refresh_basis(coords.view())?;
            }
        }
    }
    let mut beta = self.flatten_beta();
    for idx in 0..beta.len() {
        beta[idx] += step_size * delta_beta[idx];
    }
    self.set_flat_beta(beta.view())
}
/// Runs up to `max_iter` damped Newton iterations on the joint
/// (logit, coordinate, beta) system using the Arrow-Schur solver.
///
/// Before iterating, a decoder identifiability audit is run on the
/// current state. Each iteration then:
///
/// 1. advances the temperature schedule and updates the ARD terms in `rho`;
/// 2. evaluates the pre-step loss, assembles the Arrow-Schur system and
///    solves it with `solve_with_lm_escalation`;
/// 3. stops if the predicted decrease `-g·δ` is non-finite, non-positive,
///    or not above `1e-14 * |g| * |δ|`;
/// 4. otherwise backtracks from `step_size`, halving up to
///    `SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS` times, until the Armijo
///    condition with constant `SAE_MANIFOLD_ARMIJO_C1` holds. A failed
///    trial is undone by restoring a snapshot; if no trial is accepted the
///    snapshot is restored and iteration stops.
///
/// After the loop the ARD terms are updated once more and the loss at the
/// final state is returned.
///
/// # Errors
///
/// Returns an error when `step_size` is not finite and positive, or when
/// the audit, schedule advance, ARD update, loss evaluation, assembly or
/// solve fails. Errors raised inside a line-search trial are not returned;
/// they count as a rejected trial.
pub fn run_joint_fit_arrow_schur(
&mut self,
target: ArrayView2<'_, f64>,
rho: &mut SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
max_iter: usize,
step_size: f64,
ridge_ext_coord: f64,
ridge_beta: f64,
) -> Result<SaeManifoldLoss, String> {
if !(step_size.is_finite() && step_size > 0.0) {
return Err(format!(
"SaeManifoldTerm::run_joint_fit_arrow_schur: step_size must be finite and positive; got {step_size}"
));
}
// Pre-fit audit of each atom's weighted decoder design.
{
let mut grams = self.empty_decoder_gram_accumulator();
self.accumulate_decoder_gram(&mut grams);
self.finalize_decoder_identifiability_audit(&grams, self.n_obs())?;
}
for _ in 0..max_iter {
self.advance_temperature_schedule()?;
self.update_ard_reml(rho)?;
let pre_step_loss = self.loss(target, rho)?;
let pre_step_total = pre_step_loss.total();
let sys = self
.assemble_arrow_schur(target, rho, analytic_penalties)
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
let (delta_ext_coord, delta_beta, _diag) = sys
.solve_with_lm_escalation(ridge_ext_coord, ridge_beta)
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
// Predicted first-order decrease, i.e. minus the gradient-step inner product.
let directional_decrease = sae_manifold_newton_directional_decrease(
&sys,
delta_ext_coord.view(),
delta_beta.view(),
);
// Squared gradient norm over every row block plus the beta block.
let mut grad_norm_sq = 0.0;
for (row_idx, row) in sys.rows.iter().enumerate() {
let di = sys.row_dims[row_idx];
for axis in 0..di {
grad_norm_sq += row.gt[axis] * row.gt[axis];
}
}
for idx in 0..sys.k {
grad_norm_sq += sys.gb[idx] * sys.gb[idx];
}
let mut step_norm_sq = 0.0;
for &v in delta_ext_coord.iter() {
step_norm_sq += v * v;
}
for &v in delta_beta.iter() {
step_norm_sq += v * v;
}
// Relative floor: a decrease this small compared with |g|·|δ| is treated as no descent.
let directional_decrease_floor = 1.0e-14 * grad_norm_sq.sqrt() * step_norm_sq.sqrt();
let snapshot = self.snapshot_mutable_state();
if !(pre_step_total.is_finite()
&& directional_decrease.is_finite()
&& directional_decrease > 0.0
&& directional_decrease > directional_decrease_floor)
{
self.restore_mutable_state(&snapshot);
break;
}
// Backtracking line search starting from the caller's step size.
let mut trial_step_size = step_size;
let mut accepted = false;
for halving in 0..=SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS {
// Undo the previous rejected trial before trying a shorter step.
if halving > 0 {
self.restore_mutable_state(&snapshot);
}
let trial_result = self
.apply_newton_step(delta_ext_coord.view(), delta_beta.view(), trial_step_size)
.and_then(|()| self.loss(target, rho));
// An error while applying the step or evaluating the loss is treated as a rejection.
if let Ok(post_step_loss) = trial_result {
let post_step_total = post_step_loss.total();
let armijo_bound = pre_step_total
- SAE_MANIFOLD_ARMIJO_C1 * trial_step_size * directional_decrease;
if post_step_total.is_finite() && post_step_total <= armijo_bound {
accepted = true;
break;
}
}
trial_step_size *= 0.5;
}
if !accepted {
self.restore_mutable_state(&snapshot);
break;
}
}
self.update_ard_reml(rho)?;
self.loss(target, rho)
}
/// Builds one zeroed square matrix per atom, each sized by that atom's
/// `basis_size()`, for accumulating decoder Gram matrices.
fn empty_decoder_gram_accumulator(&self) -> Vec<Array2<f64>> {
    let mut grams = Vec::with_capacity(self.atoms.len());
    for atom in self.atoms.iter() {
        let width = atom.basis_size();
        grams.push(Array2::<f64>::zeros((width, width)));
    }
    grams
}
/// Adds this term's rows to the per-atom Gram accumulators.
///
/// For each atom and each row, the basis row is multiplied by that row's
/// assignment weight for the atom, and the outer product of the weighted
/// row with itself is added to `grams[atom]`. Atoms with an empty basis,
/// rows with an exactly-zero weight, and exactly-zero weighted entries
/// are skipped.
fn accumulate_decoder_gram(&self, grams: &mut [Array2<f64>]) {
    let n_rows = self.n_obs();
    let assignments = self.assignment.assignments();
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let width = atom.basis_size();
        if width == 0 {
            continue;
        }
        let weights = assignments.column(atom_idx);
        let gram = &mut grams[atom_idx];
        // Scratch buffer reused for every row of this atom.
        let mut scaled_row = vec![0.0_f64; width];
        for row in 0..n_rows {
            let weight = weights[row];
            if weight == 0.0 {
                continue;
            }
            for (col, slot) in scaled_row.iter_mut().enumerate() {
                *slot = weight * atom.basis_values[[row, col]];
            }
            for (i, &left) in scaled_row.iter().enumerate() {
                if left == 0.0 {
                    continue;
                }
                for (j, &right) in scaled_row.iter().enumerate() {
                    gram[[i, j]] += left * right;
                }
            }
        }
    }
}
/// Checks the rank of each atom's accumulated Gram matrix.
///
/// Atoms with an empty basis are skipped. A full-rank Gram passes
/// silently. A rank between 1 and `basis_size() - 1` is logged at info
/// level. A rank of 0 is an error.
///
/// # Errors
///
/// Returns an error when `rank_of_gram` fails or when an atom's Gram has
/// rank 0.
fn finalize_decoder_identifiability_audit(
    &self,
    grams: &[Array2<f64>],
    n_total: usize,
) -> Result<(), String> {
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let m = atom.basis_size();
        if m == 0 {
            continue;
        }
        let rank = match crate::solver::identifiability_audit::rank_of_gram(
            &grams[atom_idx],
            n_total,
        ) {
            Ok(rank) => rank,
            Err(e) => {
                return Err(format!(
                    "SaeManifoldTerm: pre-fit decoder audit (atom '{}'): \
                     Gram eigendecomposition failed: {e}",
                    atom.name,
                ));
            }
        };
        if rank >= m {
            continue;
        }
        if rank == 0 {
            return Err(format!(
                "SaeManifoldTerm: pre-fit identifiability audit: decoder atom '{}' has \
                 rank-0 weighted design (n={n_total}, M_k={m}); all assignment weights \
                 vanish or the basis is degenerate, so the Arrow-Schur Newton system for \
                 this block is singular",
                atom.name,
            ));
        }
        let dropped = m - rank;
        log::info!(
            "[SAE-AUDIT] decoder atom '{}' weighted design is rank-deficient \
             (rank={rank}/{m}, {dropped} weakly-identified column(s), n={n_total}); the \
             Arrow-Schur ridge will regularise the deficient directions",
            atom.name,
        );
    }
    Ok(())
}
/// Builds a standalone `SaeManifoldTerm` for one chunk of rows.
///
/// The chunk term reuses this term's decoder coefficients, smoothness
/// penalties, basis evaluators, cached second jets, assignment mode and
/// temperature schedule. Each atom's basis values and Jacobian are
/// re-evaluated at the supplied `chunk_coords`.
///
/// # Errors
///
/// Returns an error when the logit or coordinate shapes do not match this
/// term, when an atom has no basis evaluator, when an evaluator fails or
/// returns arrays of unexpected shape, or when constructing the atoms,
/// the assignment or the term fails.
pub fn materialize_chunk(
    &self,
    chunk_logits: Array2<f64>,
    chunk_coords: Vec<Array2<f64>>,
) -> Result<SaeManifoldTerm, String> {
    let k_atoms = self.k_atoms();
    let logit_cols = chunk_logits.ncols();
    if logit_cols != k_atoms {
        return Err(format!(
            "SaeManifoldTerm::materialize_chunk: chunk_logits has {} cols but K={k_atoms}",
            logit_cols
        ));
    }
    let coord_blocks = chunk_coords.len();
    if coord_blocks != k_atoms {
        return Err(format!(
            "SaeManifoldTerm::materialize_chunk: chunk_coords has {} atoms but K={k_atoms}",
            coord_blocks
        ));
    }

    let n_chunk = chunk_logits.nrows();
    let mut atoms = Vec::with_capacity(k_atoms);
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let coords = &chunk_coords[atom_idx];
        if coords.dim() != (n_chunk, atom.latent_dim) {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom {atom_idx} coords shape {:?} != ({n_chunk}, {})",
                coords.dim(),
                atom.latent_dim
            ));
        }
        let evaluator = match atom.basis_evaluator.as_ref() {
            Some(evaluator) => evaluator,
            None => {
                return Err(format!(
                    "SaeManifoldTerm::materialize_chunk: atom '{}' has no basis evaluator; a \
                     streaming fit must re-evaluate Φ(t) at each chunk's coordinates",
                    atom.name
                ));
            }
        };
        let (phi, jet) = evaluator.evaluate(coords.view())?;
        let m = atom.basis_size();
        if phi.dim() != (n_chunk, m) {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom '{}' evaluator returned Φ {:?}, expected ({n_chunk}, {m})",
                atom.name,
                phi.dim()
            ));
        }
        if jet.dim() != (n_chunk, m, atom.latent_dim) {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom '{}' evaluator returned jet {:?}, expected ({n_chunk}, {m}, {})",
                atom.name,
                jet.dim(),
                atom.latent_dim
            ));
        }
        let mut rebuilt = SaeManifoldAtom::new(
            atom.name.clone(),
            atom.basis_kind.clone(),
            atom.latent_dim,
            phi,
            jet,
            atom.decoder_coefficients.clone(),
            atom.smooth_penalty.clone(),
        )?;
        rebuilt.basis_evaluator = atom.basis_evaluator.clone();
        rebuilt.basis_second_jet = atom.basis_second_jet.clone();
        atoms.push(rebuilt);
    }

    // Wrap each coordinate matrix with the manifold of the matching source atom.
    let mut coord_values: Vec<LatentCoordValues> = Vec::with_capacity(chunk_coords.len());
    for (matrix, source) in chunk_coords.iter().zip(self.assignment.coords.iter()) {
        coord_values.push(LatentCoordValues::from_matrix_with_manifold(
            matrix.view(),
            LatentIdMode::None,
            source.manifold().clone(),
        ));
    }

    let assignment = SaeAssignment::with_mode(chunk_logits, coord_values, self.assignment.mode)?;
    let mut term = SaeManifoldTerm::new(atoms, assignment)?;
    term.temperature_schedule = self.temperature_schedule.clone();
    Ok(term)
}
/// Chunked variant of the joint fit. It updates only the shared beta
/// block of `self`; per-row state comes from `chunk_init` on every pass.
///
/// `chunk_init(start, end)` is called for each row range `[start, end)`
/// and returns `(logits, coords, z_chunk)`. The first two are passed to
/// `materialize_chunk`; `z_chunk` is the target slice and must have shape
/// `(end - start, output_dim())`.
///
/// A pre-fit decoder audit is run over all chunks. Each of the up to
/// `max_iter` iterations then:
///
/// 1. advances the temperature schedule;
/// 2. walks every chunk, accumulating the pre-step loss, the per-axis sum
///    of squared coordinates, the beta gradient, and the chunk's reduced
///    Schur contribution. Chunk penalties are scaled by
///    `n_chunk / n_total`;
/// 3. adds `ridge_beta` to the reduced diagonal, subtracts the accumulated
///    beta gradient from the right-hand side, and solves for `delta_beta`;
/// 4. backtracks on beta from `step_size`, halving up to
///    `SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS` times, and accepts the first
///    trial meeting the Armijo bound. Each trial re-evaluates the loss
///    over all chunks;
/// 5. updates the ARD terms in `rho` from the accumulated squared sums.
///
/// Iteration stops early when the predicted decrease is not finite and
/// positive, or when no trial step is accepted (beta is then restored).
///
/// Returns the loss of the last evaluated state. When `max_iter` is 0 the
/// returned loss has all components equal to zero.
///
/// # Errors
///
/// Returns an error when `step_size` is not finite and positive, when
/// `chunk_size` or `n_total` is 0, when a target slice has the wrong
/// shape, or when `chunk_init`, chunk materialisation, the audit, loss
/// evaluation, assembly, Schur accumulation, the reduced solve or
/// `set_flat_beta` fails.
pub fn run_joint_fit_arrow_schur_streaming<F>(
&mut self,
n_total: usize,
chunk_size: usize,
rho: &mut SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
max_iter: usize,
step_size: f64,
ridge_ext_coord: f64,
ridge_beta: f64,
mut chunk_init: F,
) -> Result<SaeManifoldLoss, String>
where
F: FnMut(usize, usize) -> Result<(Array2<f64>, Vec<Array2<f64>>, Array2<f64>), String>,
{
if !(step_size.is_finite() && step_size > 0.0) {
return Err(format!(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: step_size must be finite and positive; got {step_size}"
));
}
if chunk_size == 0 {
return Err(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: chunk_size must be positive"
.to_string(),
);
}
if n_total == 0 {
return Err(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: n_total must be positive"
.to_string(),
);
}
let beta_dim = self.beta_dim();
// Pre-fit audit: accumulate decoder Gram matrices over every chunk, then check ranks.
{
let mut grams = self.empty_decoder_gram_accumulator();
let mut start = 0usize;
while start < n_total {
let end = (start + chunk_size).min(n_total);
let (logits, coords, _z_chunk) = chunk_init(start, end)?;
let chunk = self.materialize_chunk(logits, coords)?;
chunk.accumulate_decoder_gram(&mut grams);
start = end;
}
self.finalize_decoder_identifiability_audit(&grams, n_total)?;
}
// Returned unchanged when the iteration loop below never runs.
let mut last_loss = SaeManifoldLoss {
data_fit: 0.0,
assignment_sparsity: 0.0,
smoothness: 0.0,
ard: 0.0,
};
for _ in 0..max_iter {
self.advance_temperature_schedule()?;
let options = ArrowSolveOptions::automatic(beta_dim);
// Accumulators summed over chunks: reduced Schur matrix, its right-hand side, beta gradient.
let mut s_acc = Array2::<f64>::zeros((beta_dim, beta_dim));
let mut rhs_acc = Array1::<f64>::zeros(beta_dim);
let mut gb_acc = Array1::<f64>::zeros(beta_dim);
// Per-atom, per-axis sums of squared coordinates used for the ARD update.
let mut ard_sumsq: Vec<Array1<f64>> = self
.assignment
.coords
.iter()
.map(|c| Array1::<f64>::zeros(c.latent_dim()))
.collect();
let mut pre_step_total = 0.0_f64;
let mut chunk_ranges: Vec<(usize, usize)> = Vec::new();
let mut start = 0usize;
while start < n_total {
let end = (start + chunk_size).min(n_total);
let n_chunk = end - start;
// Fraction of all rows in this chunk; passed as the chunk's penalty scale.
let penalty_scale = n_chunk as f64 / n_total as f64;
let (logits, coords, z_chunk) = chunk_init(start, end)?;
if z_chunk.dim() != (n_chunk, self.output_dim()) {
return Err(format!(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: chunk [{start}, {end}) \
Z slice shape {:?} != ({n_chunk}, {})",
z_chunk.dim(),
self.output_dim()
));
}
let mut chunk = self.materialize_chunk(logits, coords)?;
chunk_ranges.push((start, end));
for (atom_idx, coord) in chunk.assignment.coords.iter().enumerate() {
let d = coord.latent_dim();
for row in 0..coord.n_obs() {
let row_t = coord.row(row);
for axis in 0..d {
ard_sumsq[atom_idx][axis] += row_t[axis] * row_t[axis];
}
}
}
pre_step_total += chunk
.loss_scaled(z_chunk.view(), rho, penalty_scale)?
.total();
let sys = chunk
.assemble_arrow_schur_scaled(
z_chunk.view(),
rho,
analytic_penalties,
penalty_scale,
)
.map_err(|err| {
format!("SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: {err}")
})?;
for j in 0..beta_dim {
gb_acc[j] += sys.gb[j];
}
Self::accumulate_chunk_reduced_schur(
&sys,
ridge_ext_coord,
&options,
&mut s_acc,
&mut rhs_acc,
)?;
start = end;
}
// Apply the beta ridge once, and fold the accumulated beta gradient into the right-hand side.
for j in 0..beta_dim {
s_acc[[j, j]] += ridge_beta;
rhs_acc[j] -= gb_acc[j];
}
let delta_beta =
solve_streaming_reduced_beta(&s_acc, &rhs_acc, &options).map_err(|err| {
format!("SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: {err}")
})?;
let beta0 = self.flatten_beta();
// Predicted decrease used by the Armijo test: inner product of the reduced rhs with the step.
let mut directional_decrease = 0.0_f64;
for j in 0..beta_dim {
directional_decrease += rhs_acc[j] * delta_beta[j];
}
if !(pre_step_total.is_finite()
&& directional_decrease.is_finite()
&& directional_decrease > 0.0)
{
self.update_ard_reml_from_sumsq(rho, &ard_sumsq, n_total);
last_loss = self.streaming_loss(&chunk_ranges, rho, n_total, &mut chunk_init)?;
break;
}
// Backtracking line search on beta; every trial starts again from `beta0`.
let mut trial_step = step_size;
let mut accepted_loss: Option<SaeManifoldLoss> = None;
for _ in 0..=SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS {
let mut trial_beta = beta0.clone();
for j in 0..beta_dim {
trial_beta[j] += trial_step * delta_beta[j];
}
self.set_flat_beta(trial_beta.view())?;
let trial_loss =
self.streaming_loss(&chunk_ranges, rho, n_total, &mut chunk_init)?;
let trial_total = trial_loss.total();
let armijo_bound =
pre_step_total - SAE_MANIFOLD_ARMIJO_C1 * trial_step * directional_decrease;
if trial_total.is_finite() && trial_total <= armijo_bound {
accepted_loss = Some(trial_loss);
break;
}
trial_step *= 0.5;
}
match accepted_loss {
Some(loss) => {
self.update_ard_reml_from_sumsq(rho, &ard_sumsq, n_total);
last_loss = loss;
}
None => {
// No acceptable step: put beta back, refresh ARD, report the loss there, and stop.
self.set_flat_beta(beta0.view())?;
self.update_ard_reml_from_sumsq(rho, &ard_sumsq, n_total);
last_loss =
self.streaming_loss(&chunk_ranges, rho, n_total, &mut chunk_init)?;
break;
}
}
}
Ok(last_loss)
}
/// Reduces one chunk's Arrow-Schur system and adds the result to the
/// running accumulators.
///
/// A `StreamingArrowSchur` is built from `sys`, reset with `0.0`, and
/// asked to accumulate rows `0..sys.rows.len()` using `ridge_ext_coord`
/// and `options.mode`. The matrix and right-hand side it hands back are
/// added entry-wise to `s_acc` and `rhs_acc`.
///
/// # Errors
///
/// Returns the stringified error from the reset or accumulate call.
fn accumulate_chunk_reduced_schur(
    sys: &ArrowSchurSystem,
    ridge_ext_coord: f64,
    options: &ArrowSolveOptions,
    s_acc: &mut Array2<f64>,
    rhs_acc: &mut Array1<f64>,
) -> Result<(), String> {
    let beta_dim = sys.k;
    let n_rows = sys.rows.len();
    let mut reducer = StreamingArrowSchur::from_system(sys, n_rows.max(1));
    if let Err(e) = reducer.reset_accumulator(0.0) {
        return Err(e.to_string());
    }
    if let Err(e) = reducer.accumulate_chunk(0, n_rows, ridge_ext_coord, options.mode) {
        return Err(e.to_string());
    }
    let (schur_part, rhs_part) = reducer.take_accumulators();
    for row in 0..beta_dim {
        rhs_acc[row] += rhs_part[row];
        for col in 0..beta_dim {
            s_acc[[row, col]] += schur_part[[row, col]];
        }
    }
    Ok(())
}
/// Evaluates the loss over the given chunk ranges and sums it component
/// by component.
///
/// Each range is re-materialised from `chunk_init` and evaluated with
/// `loss_scaled`, using `n_chunk / n_total` as the penalty scale.
///
/// # Errors
///
/// Forwards errors from `chunk_init`, `materialize_chunk` and
/// `loss_scaled`.
fn streaming_loss<F>(
    &self,
    chunk_ranges: &[(usize, usize)],
    rho: &SaeManifoldRho,
    n_total: usize,
    chunk_init: &mut F,
) -> Result<SaeManifoldLoss, String>
where
    F: FnMut(usize, usize) -> Result<(Array2<f64>, Vec<Array2<f64>>, Array2<f64>), String>,
{
    let mut total = SaeManifoldLoss {
        data_fit: 0.0_f64,
        assignment_sparsity: 0.0_f64,
        smoothness: 0.0_f64,
        ard: 0.0_f64,
    };
    for &(start, end) in chunk_ranges {
        let rows_in_chunk = end - start;
        let penalty_scale = rows_in_chunk as f64 / n_total as f64;
        let (logits, coords, z_chunk) = chunk_init(start, end)?;
        let chunk = self.materialize_chunk(logits, coords)?;
        let part = chunk.loss_scaled(z_chunk.view(), rho, penalty_scale)?;
        total.data_fit += part.data_fit;
        total.assignment_sparsity += part.assignment_sparsity;
        total.smoothness += part.smoothness;
        total.ard += part.ard;
    }
    Ok(total)
}
/// Resets `rho.log_ard` from pre-accumulated per-axis sums of squares.
///
/// For every atom and axis the entry becomes `ln(n_total / sumsq)`
/// clamped to `[-8, 16]`. Atoms with no entry in `sumsq`, atoms whose
/// `log_ard` length differs from the latent dimension, and axes whose sum
/// of squares is below `1e-10` are left unchanged.
fn update_ard_reml_from_sumsq(
    &self,
    rho: &mut SaeManifoldRho,
    sumsq: &[Array1<f64>],
    n_total: usize,
) {
    let row_count = n_total as f64;
    for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
        let dim = coord.latent_dim();
        let usable = atom_idx < sumsq.len() && rho.log_ard[atom_idx].len() == dim;
        if !usable {
            continue;
        }
        for axis in 0..dim {
            let energy = sumsq[atom_idx][axis];
            if energy < 1.0e-10 {
                continue;
            }
            let precision = row_count / energy;
            rho.log_ard[atom_idx][axis] = precision.ln().clamp(-8.0, 16.0);
        }
    }
}
/// Performs one Newton step without refreshing the atoms' bases.
///
/// The temperature schedule is advanced and the ARD terms updated, the
/// loss is evaluated, a Newton direction is solved for and applied with
/// `step_size` via `apply_newton_step_external_basis_refresh`, and the
/// ARD terms are updated again.
///
/// The returned loss is the one evaluated *before* the step was applied.
///
/// # Errors
///
/// Forwards errors from the schedule, ARD update, loss evaluation and
/// step application; solver errors are prefixed with this method's name.
pub fn run_single_external_basis_refresh_step_arrow_schur(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &mut SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    step_size: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<SaeManifoldLoss, String> {
    self.advance_temperature_schedule()?;
    self.update_ard_reml(rho)?;
    let pre_step_loss = self.loss(target, rho)?;

    let solved =
        self.solve_newton_step(target, rho, analytic_penalties, ridge_ext_coord, ridge_beta);
    let (delta_ext_coord, delta_beta) = match solved {
        Ok(direction) => direction,
        Err(err) => {
            return Err(format!(
                "SaeManifoldTerm::run_single_external_basis_refresh_step_arrow_schur: {err}"
            ));
        }
    };

    self.apply_newton_step_external_basis_refresh(
        delta_ext_coord.view(),
        delta_beta.view(),
        step_size,
    )?;
    self.update_ard_reml(rho)?;
    Ok(pre_step_loss)
}
/// Describes this term's assignment prior and per-atom ARD priors as
/// analytic-penalty objects.
///
/// The first element is the assignment penalty for the current mode. For
/// `IBPMap` it carries this term's temperature schedule when one is set.
/// The second element holds one `ARDPenalty` per atom, spanning that
/// atom's full coordinate block.
///
/// # Panics
///
/// Panics when the assignment mode is `JumpReLU`, which has no
/// `AnalyticPenaltyKind` descriptor.
pub fn analytic_penalty_descriptors(&self) -> (AnalyticPenaltyKind, Vec<ARDPenalty>) {
    let n_atoms = self.k_atoms();
    let assignment = match self.assignment.mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(n_atoms, temperature);
            AnalyticPenaltyKind::SoftmaxAssignmentSparsity(Arc::new(penalty))
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let base = IBPAssignmentPenalty::new(n_atoms, alpha, temperature, learnable_alpha);
            let penalty = if let Some(schedule) = self.temperature_schedule.clone() {
                base.with_temperature_schedule(schedule)
            } else {
                base
            };
            AnalyticPenaltyKind::IBPAssignment(Arc::new(penalty))
        }
        AssignmentMode::JumpReLU { .. } => {
            panic!(
                "JumpReLU assignment mode uses the built-in gated L1 assignment prior and has no AnalyticPenaltyKind descriptor"
            )
        }
    };

    let ard: Vec<ARDPenalty> = self
        .assignment
        .coords
        .iter()
        .map(|coord| {
            let slice = PsiSlice::full(coord.len(), Some(coord.latent_dim()));
            ARDPenalty::new(slice, coord.latent_dim())
        })
        .collect();

    (assignment, ard)
}
}
/// Returns minus the inner product of the system gradient with a step.
///
/// The gradient is the concatenation of every row block's `gt` (the
/// first `row_dims[row]` entries, located at `row_offsets[row]` in
/// `delta_ext_coord`) and the shared `gb` block.
///
/// # Panics
///
/// Panics when `delta_ext_coord` does not have length
/// `row_offsets[rows.len()]` or `delta_beta` does not have length `sys.k`.
fn sae_manifold_newton_directional_decrease(
    sys: &ArrowSchurSystem,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_ext_coord.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);

    let mut inner = 0.0;
    for (row_idx, block) in sys.rows.iter().enumerate() {
        let offset = sys.row_offsets[row_idx];
        let width = sys.row_dims[row_idx];
        for slot in 0..width {
            inner += block.gt[slot] * delta_ext_coord[offset + slot];
        }
    }
    for slot in 0..sys.k {
        inner += sys.gb[slot] * delta_beta[slot];
    }
    -inner
}
/// Computes the tempered softmax of one row of logits.
///
/// The maximum logit is subtracted before exponentiating, so the largest
/// exponent is 0.
///
/// # Panics
///
/// Panics when the sum of exponentials is not finite and positive, which
/// includes the case of an empty `logits`.
fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
    let inv_tau = 1.0 / temperature;
    let max_logit = logits
        .iter()
        .fold(f64::NEG_INFINITY, |best, &value| best.max(value));

    let mut out = Array1::<f64>::zeros(logits.len());
    let mut sum = 0.0;
    for (slot, &logit) in out.iter_mut().zip(logits.iter()) {
        let weight = ((logit - max_logit) * inv_tau).exp();
        *slot = weight;
        sum += weight;
    }
    assert!(sum.is_finite() && sum > 0.0);

    for slot in out.iter_mut() {
        *slot /= sum;
    }
    out
}
/// Checks that every logit in a row is finite.
///
/// # Errors
///
/// Returns a message naming `row`, the atom index and the value of the
/// first non-finite entry.
fn validate_finite_logits(logits: ArrayView1<'_, f64>, row: usize) -> Result<(), String> {
    let first_bad = logits
        .iter()
        .enumerate()
        .find(|&(_, value)| !value.is_finite());
    match first_bad {
        Some((col, &v)) => Err(format!(
            "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
        )),
        None => Ok(()),
    }
}
/// Returns the geometric weight sequence `1, r, r², …` of length
/// `k_atoms`, where `r = alpha / (alpha + 1)`.
fn ibp_stick_breaking_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
    let ratio = alpha / (alpha + 1.0);
    let mut weights = Array1::<f64>::zeros(k_atoms);
    let mut running = 1.0;
    for slot in weights.iter_mut() {
        *slot = running;
        running *= ratio;
    }
    weights
}
/// Computes one row of IBP-MAP assignment weights: the logistic of
/// `logit / temperature` multiplied by the matching entry of
/// `ibp_stick_breaking_prior`.
fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
    let n_atoms = logits.len();
    let prior = ibp_stick_breaking_prior(n_atoms, alpha);
    let mut weights = Array1::<f64>::zeros(n_atoms);
    for (idx, slot) in weights.iter_mut().enumerate() {
        let gate = crate::linalg::utils::stable_logistic(logits[idx] / temperature);
        *slot = gate * prior[idx];
    }
    weights
}
/// Computes one row of IBP-MAP assignment weights together with the
/// derivative of each weight with respect to its own logit.
///
/// With `s = logistic(logit / temperature)` and `π` the matching entry of
/// `ibp_stick_breaking_prior`, the value is `s · π` and the derivative is
/// `s · (1 − s) · π / temperature`.
#[must_use]
pub fn ibp_map_row_value_grad(
    logits: ArrayView1<'_, f64>,
    temperature: f64,
    alpha: f64,
) -> (Array1<f64>, Array1<f64>) {
    let n_atoms = logits.len();
    let prior = ibp_stick_breaking_prior(n_atoms, alpha);
    let inv_tau = 1.0 / temperature;

    let mut value = Array1::<f64>::zeros(n_atoms);
    let mut grad = Array1::<f64>::zeros(n_atoms);
    for (idx, &logit) in logits.iter().enumerate() {
        let gate = crate::linalg::utils::stable_logistic(logit * inv_tau);
        let weight = prior[idx];
        value[idx] = gate * weight;
        grad[idx] = gate * (1.0 - gate) * inv_tau * weight;
    }
    (value, grad)
}
/// Computes one row of JumpReLU gate activations.
///
/// Entries whose logit is strictly above `threshold` get
/// `logistic(logit / temperature)`; all other entries stay at 0.
fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
    let mut gates = Array1::<f64>::zeros(logits.len());
    for (slot, &logit) in gates.iter_mut().zip(logits.iter()) {
        if logit > threshold {
            *slot = crate::linalg::utils::stable_logistic(logit / temperature);
        }
    }
    gates
}
/// Writes row `j` of `jac_compact`: the derivative of the fitted output
/// with respect to atom `k`'s logit, for a single observation.
///
/// * `Softmax`: `a_k · (decoded_k − fitted) / temperature`.
/// * `IBPMap`: `s · (1 − s) · π_k / temperature · decoded_k`, where
///   `π_k = ibp_prior[k]` and `s = a_k / π_k` (taken as 0 when `π_k` is
///   not positive).
/// * `JumpReLU`: `σ · (1 − σ) / temperature · decoded_k` with
///   `σ = logistic(logit_k / temperature)`. When `logit_k <= threshold`
///   the row is left untouched.
///
/// # Panics
///
/// Panics in `IBPMap` mode when `ibp_prior` is `None`.
#[allow(clippy::too_many_arguments)]
fn fill_active_atom_logit_jvp(
    mode: AssignmentMode,
    k: usize,
    logit_k: f64,
    a_k: f64,
    decoded_k: ArrayView1<'_, f64>,
    fitted: ArrayView1<'_, f64>,
    ibp_prior: Option<&[f64]>,
    jac_compact: &mut Array2<f64>,
    j: usize,
) {
    let n_out = fitted.len();
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            for col in 0..n_out {
                jac_compact[[j, col]] = a_k * (decoded_k[col] - fitted[col]) * inv_tau;
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior =
                ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
            let pi_k = prior[k];
            // Recover the logistic gate from the stored weight a_k = gate * pi_k.
            let gate = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
            let slope = gate * (1.0 - gate) * inv_tau * pi_k;
            for col in 0..n_out {
                jac_compact[[j, col]] = slope * decoded_k[col];
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            if logit_k <= threshold {
                return;
            }
            let inv_tau = 1.0 / temperature;
            let gate = crate::linalg::utils::stable_logistic(logit_k * inv_tau);
            let slope = gate * (1.0 - gate) * inv_tau;
            for col in 0..n_out {
                jac_compact[[j, col]] = slope * decoded_k[col];
            }
        }
    }
}
/// Fills `local_jac` with the derivative of the fitted output with
/// respect to every atom's logit, for a single observation; row `k`
/// belongs to atom `k`.
///
/// With exactly one atom the single row is set to zero and the mode is
/// ignored. Otherwise the per-mode formulas are:
///
/// * `Softmax`: `a_k · (decoded_k − fitted) / temperature`.
/// * `IBPMap`: `s · (1 − s) · π_k / temperature · decoded_k`, where
///   `s = a_k / π_k` (taken as 0 when `π_k` is not positive).
/// * `JumpReLU`: `σ · (1 − σ) / temperature · decoded_k`. Rows whose
///   logit is `<= threshold` are left untouched.
///
/// # Panics
///
/// Panics in `IBPMap` mode when `ibp_prior` is `None`.
fn fill_assignment_logit_jvp_rows(
    mode: AssignmentMode,
    logits: ArrayView1<'_, f64>,
    assignments: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    fitted: ArrayView1<'_, f64>,
    ibp_prior: Option<&[f64]>,
    local_jac: &mut Array2<f64>,
) {
    let n_atoms = assignments.len();
    let n_out = fitted.len();

    if n_atoms == 1 {
        for atom in 0..n_atoms {
            for col in 0..n_out {
                local_jac[[atom, col]] = 0.0;
            }
        }
        return;
    }

    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            for atom in 0..n_atoms {
                for col in 0..n_out {
                    local_jac[[atom, col]] =
                        assignments[atom] * (decoded[[atom, col]] - fitted[col]) * inv_tau;
                }
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior = ibp_prior
                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
            for atom in 0..n_atoms {
                let pi_k = prior[atom];
                let a_k = assignments[atom];
                // Recover the logistic gate from the stored weight a_k = gate * pi_k.
                let gate = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
                let slope = gate * (1.0 - gate) * inv_tau * pi_k;
                for col in 0..n_out {
                    local_jac[[atom, col]] = slope * decoded[[atom, col]];
                }
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let inv_tau = 1.0 / temperature;
            for atom in 0..n_atoms {
                if logits[atom] <= threshold {
                    continue;
                }
                let gate = crate::linalg::utils::stable_logistic(logits[atom] * inv_tau);
                let slope = gate * (1.0 - gate) * inv_tau;
                for col in 0..n_out {
                    local_jac[[atom, col]] = slope * decoded[[atom, col]];
                }
            }
        }
    }
}
/// Copies a 2-D logits view into a flat vector in row-major order, so
/// entry `(row, col)` lands at index `row * ncols + col`.
fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
    // `iter()` visits elements in logical order (last index fastest),
    // whatever the view's memory layout.
    Array1::from_iter(logits.iter().copied())
}
/// Evaluates the assignment prior for the current logits.
///
/// Returns 0 when there is a single atom. Otherwise:
///
/// * `Softmax`: value of `SoftmaxAssignmentSparsityPenalty` with rho
///   `log_lambda_sparse + ln(sparsity)`.
/// * `IBPMap`: value of `IBPAssignmentPenalty`, with rho
///   `[log_lambda_sparse]` when `learnable_alpha` is set and empty
///   otherwise.
/// * `JumpReLU`: `exp(log_lambda_sparse)` times the sum of
///   `logistic(logit / temperature)` over logits strictly above
///   `threshold`.
///
/// # Panics
///
/// Panics when any logit is not finite.
fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
    let n_atoms = assignment.k_atoms();
    if n_atoms == 1 {
        return 0.0;
    }
    for row in 0..assignment.n_obs() {
        validate_finite_logits(assignment.logits.row(row), row)
            .expect("assignment logits must be finite");
    }
    let flat = flat_logits(assignment.logits.view());

    match assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature)
                .value(flat.view(), rho_view.view())
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                Array1::zeros(0)
            };
            IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha)
                .value(flat.view(), rho_view.view())
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.log_lambda_sparse.exp();
            let gate_sum = flat.iter().fold(0.0, |acc, &logit| {
                if logit > threshold {
                    acc + crate::linalg::utils::stable_logistic(logit / temperature)
                } else {
                    acc
                }
            });
            sparsity_strength * gate_sum
        }
    }
}
/// Returns the gradient and a diagonal curvature term of the assignment
/// prior with respect to the flattened logits.
///
/// With a single atom both vectors are zero and have length `n_obs`.
/// `Softmax` and `IBPMap` delegate to the matching analytic penalty's
/// `grad_target` and `hessian_diag`. For `JumpReLU`, logits strictly
/// above `threshold` get gradient `λ · σ(1 − σ) / τ` and diagonal
/// `λ · (σ(1 − σ))² / τ²`, with `λ = exp(log_lambda_sparse)` and
/// `σ = logistic(logit / τ)`; other entries are 0.
///
/// # Errors
///
/// Returns an error when a logit is not finite or when the delegated
/// penalty provides no diagonal Hessian.
fn assignment_prior_grad_hdiag(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<(Array1<f64>, Array1<f64>), String> {
    if assignment.k_atoms() == 1 {
        let n_obs = assignment.n_obs();
        return Ok((Array1::zeros(n_obs), Array1::zeros(n_obs)));
    }
    for row in 0..assignment.n_obs() {
        validate_finite_logits(assignment.logits.row(row), row)?;
    }
    let flat = flat_logits(assignment.logits.view());

    match assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            let grad = penalty.grad_target(flat.view(), rho_view.view());
            match penalty.hessian_diag(flat.view(), rho_view.view()) {
                Some(diag) => Ok((grad, diag)),
                None => Err("softmax assignment hessian diag unavailable".to_string()),
            }
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                Array1::zeros(0)
            };
            let grad = penalty.grad_target(flat.view(), rho_view.view());
            match penalty.hessian_diag(flat.view(), rho_view.view()) {
                Some(diag) => Ok((grad, diag)),
                None => Err("IBP assignment hessian diag unavailable".to_string()),
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.log_lambda_sparse.exp();
            let inv_tau = 1.0 / temperature;
            let inv_tau2 = inv_tau * inv_tau;
            let mut grad = Array1::<f64>::zeros(flat.len());
            let mut diag = Array1::<f64>::zeros(flat.len());
            for (idx, &logit) in flat.iter().enumerate() {
                if logit <= threshold {
                    continue;
                }
                let gate = crate::linalg::utils::stable_logistic(logit * inv_tau);
                let slope = gate * (1.0 - gate);
                grad[idx] = sparsity_strength * slope * inv_tau;
                diag[idx] = sparsity_strength * slope * slope * inv_tau2;
            }
            Ok((grad, diag))
        }
    }
}
/// Returns `true` exactly for the penalty variants listed below; every other
/// kind yields `false`.
fn sae_penalty_is_row_block_supported(penalty: &AnalyticPenaltyKind) -> bool {
    match penalty {
        AnalyticPenaltyKind::Ard(_)
        | AnalyticPenaltyKind::TopKActivation(_)
        | AnalyticPenaltyKind::JumpReLU(_)
        | AnalyticPenaltyKind::Sparsity(_)
        | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
        | AnalyticPenaltyKind::IBPAssignment(_)
        | AnalyticPenaltyKind::RowPrecisionPrior(_)
        | AnalyticPenaltyKind::ParametricRowPrecisionPrior(_)
        | AnalyticPenaltyKind::ScadMcp(_)
        | AnalyticPenaltyKind::BlockOrthogonality(_)
        | AnalyticPenaltyKind::Isometry(_) => true,
        _ => false,
    }
}
/// Names of the penalty kinds reported as row-block capable, in a fixed order.
pub fn sae_row_block_penalty_kinds() -> &'static [&'static str] {
    const ROW_BLOCK_KINDS: &[&str] = &[
        "ard",
        "top_k_activation",
        "jumprelu",
        "sparsity",
        "softmax_assignment_sparsity",
        "ibp_assignment",
        "row_precision_prior",
        "parametric_row_precision_prior",
        "scad_mcp",
        "block_orthogonality",
        "isometry",
    ];
    ROW_BLOCK_KINDS
}
/// Builds an [`SaeManifoldTerm`] from padded per-atom blocks.
///
/// The number of atoms `K` is `basis_sizes.len()`. For atom `k`, with
/// `m = basis_sizes[k]` and `d = latent_dims[k]`, the leading sub-blocks
/// `basis_values[k, ..n_obs, ..m]`, `basis_jacobian[k, ..n_obs, ..m, ..d]`,
/// `decoder_coefficients[k, ..m, ..p_out]` and `smooth_penalties[k, ..m, ..m]`
/// are copied out and handed to `SaeManifoldAtom::new`. `evaluators` may be
/// empty or hold one optional evaluator per atom.
///
/// # Errors
///
/// Returns an error when the per-atom metadata lengths disagree, when
/// `evaluators` has a non-empty length different from `K`, when `logits` is not
/// `(n_obs, K)`, when any padded array is too small to contain the requested
/// sub-block (previously this panicked inside the slice), or when atom,
/// assignment, or term construction fails.
#[must_use = "build error must be handled"]
pub fn term_from_padded_blocks_with_mode(
    n_obs: usize,
    p_out: usize,
    basis_kinds: &[SaeAtomBasisKind],
    basis_values: ArrayView3<'_, f64>,
    basis_jacobian: ArrayView4<'_, f64>,
    basis_sizes: &[usize],
    latent_dims: &[usize],
    decoder_coefficients: ArrayView3<'_, f64>,
    smooth_penalties: ArrayView3<'_, f64>,
    logits: ArrayView2<'_, f64>,
    coords: &[Array2<f64>],
    mode: AssignmentMode,
    evaluators: &[Option<Arc<dyn SaeBasisEvaluator>>],
) -> Result<SaeManifoldTerm, String> {
    let k_atoms = basis_sizes.len();
    if latent_dims.len() != k_atoms || basis_kinds.len() != k_atoms || coords.len() != k_atoms {
        return Err("term_from_padded_blocks: K-length metadata mismatch".into());
    }
    if !evaluators.is_empty() && evaluators.len() != k_atoms {
        return Err(format!(
            "term_from_padded_blocks: evaluators length {} must equal K={k_atoms} or be empty",
            evaluators.len()
        ));
    }
    if logits.dim() != (n_obs, k_atoms) {
        return Err(format!(
            "term_from_padded_blocks: logits must be ({n_obs}, {k_atoms}); got {:?}",
            logits.dim()
        ));
    }
    let values_dim = basis_values.dim();
    let jacobian_dim = basis_jacobian.dim();
    let decoder_dim = decoder_coefficients.dim();
    let penalty_dim = smooth_penalties.dim();
    let mut atoms = Vec::with_capacity(k_atoms);
    for k in 0..k_atoms {
        let m = basis_sizes[k];
        let d = latent_dims[k];
        // Slicing past a padded extent would panic, so report it as an error instead.
        if k >= values_dim.0 || n_obs > values_dim.1 || m > values_dim.2 {
            return Err(format!(
                "term_from_padded_blocks: basis_values has shape {values_dim:?}; atom {k} needs at least ({}, {n_obs}, {m})",
                k + 1
            ));
        }
        if k >= jacobian_dim.0
            || n_obs > jacobian_dim.1
            || m > jacobian_dim.2
            || d > jacobian_dim.3
        {
            return Err(format!(
                "term_from_padded_blocks: basis_jacobian has shape {jacobian_dim:?}; atom {k} needs at least ({}, {n_obs}, {m}, {d})",
                k + 1
            ));
        }
        if k >= decoder_dim.0 || m > decoder_dim.1 || p_out > decoder_dim.2 {
            return Err(format!(
                "term_from_padded_blocks: decoder_coefficients has shape {decoder_dim:?}; atom {k} needs at least ({}, {m}, {p_out})",
                k + 1
            ));
        }
        if k >= penalty_dim.0 || m > penalty_dim.1 || m > penalty_dim.2 {
            return Err(format!(
                "term_from_padded_blocks: smooth_penalties has shape {penalty_dim:?}; atom {k} needs at least ({}, {m}, {m})",
                k + 1
            ));
        }
        let phi = basis_values.slice(s![k, 0..n_obs, 0..m]).to_owned();
        let jet = basis_jacobian.slice(s![k, 0..n_obs, 0..m, 0..d]).to_owned();
        let b = decoder_coefficients.slice(s![k, 0..m, 0..p_out]).to_owned();
        let s = smooth_penalties.slice(s![k, 0..m, 0..m]).to_owned();
        let atom = SaeManifoldAtom::new(
            format!("atom_{k}"),
            basis_kinds[k].clone(),
            d,
            phi,
            jet,
            b,
            s,
        )?;
        // An empty `evaluators` slice makes `get` return `None` for every atom.
        let atom = match evaluators.get(k).and_then(|slot| slot.clone()) {
            Some(evaluator) => atom.with_basis_evaluator(evaluator),
            None => atom,
        };
        atoms.push(atom);
    }
    let manifolds = basis_kinds
        .iter()
        .zip(latent_dims.iter().copied())
        .map(|(kind, d)| kind.latent_manifold(d))
        .collect();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.to_owned(),
        coords.to_vec(),
        manifolds,
        mode,
    )?;
    SaeManifoldTerm::new(atoms, assignment)
}
/// Recomputes the decoded Jacobian caches of `penalty` from `atom` at `coords`.
///
/// The atom's basis evaluator is evaluated at `coords`; its first jet is
/// contracted with the decoder coefficients into an `(n_obs, p * d)` matrix
/// whose column `i * d + a` holds output `i`, latent axis `a`. When the atom
/// also carries a second-jet evaluator, the analogous `(n_obs, p * d * d)`
/// matrix (column `(i * d + a) * d + c`) is built. Both are passed to
/// `penalty.refresh_caches`.
///
/// Returns `Ok(true)` when the second-order cache was installed as well.
///
/// # Errors
///
/// Fails when the atom has no basis evaluator, when evaluation fails, when
/// `penalty.p_out` differs from the decoder column count, or when a jet has an
/// unexpected shape.
pub fn refresh_isometry_caches_from_atom(
    penalty: &IsometryPenalty,
    atom: &SaeManifoldAtom,
    coords: ArrayView2<'_, f64>,
) -> Result<bool, String> {
    let Some(evaluator) = atom.basis_evaluator.as_ref() else {
        return Err(format!(
            "refresh_isometry_caches_from_atom: atom {} has no basis evaluator",
            atom.name
        ));
    };
    let (_basis_values, first_jet) = evaluator.evaluate(coords)?;
    let n_obs = coords.nrows();
    let d = atom.latent_dim;
    let m = atom.basis_size();
    let p = atom.decoder_coefficients.ncols();
    if penalty.p_out != p {
        return Err(format!(
            "refresh_isometry_caches_from_atom: penalty.p_out={} but atom.decoder.cols={p}",
            penalty.p_out
        ));
    }
    if first_jet.dim() != (n_obs, m, d) {
        return Err(format!(
            "refresh_isometry_caches_from_atom: evaluator first jet has shape {:?}, expected ({n_obs}, {m}, {d})",
            first_jet.dim()
        ));
    }
    let decoder = &atom.decoder_coefficients;

    // First-order cache: contract the basis index of the jet with the decoder.
    let mut first_cache = Array2::<f64>::zeros((n_obs, p * d));
    for row in 0..n_obs {
        for out in 0..p {
            for axis in 0..d {
                first_cache[[row, out * d + axis]] = (0..m).fold(0.0, |acc, basis| {
                    acc + first_jet[[row, basis, axis]] * decoder[[basis, out]]
                });
            }
        }
    }

    // Second-order cache, only when the atom exposes a second-jet evaluator.
    let second_cache = match atom.basis_second_jet.as_ref() {
        None => None,
        Some(second_eval) => {
            let second_jet = second_eval.second_jet(coords)?;
            if second_jet.dim() != (n_obs, m, d, d) {
                return Err(format!(
                    "refresh_isometry_caches_from_atom: evaluator second jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d})",
                    second_jet.dim()
                ));
            }
            let mut cache = Array2::<f64>::zeros((n_obs, p * d * d));
            for row in 0..n_obs {
                for out in 0..p {
                    for axis_a in 0..d {
                        for axis_c in 0..d {
                            cache[[row, (out * d + axis_a) * d + axis_c]] =
                                (0..m).fold(0.0, |acc, basis| {
                                    acc + second_jet[[row, basis, axis_a, axis_c]]
                                        * decoder[[basis, out]]
                                });
                        }
                    }
                }
            }
            Some(Arc::new(cache))
        }
    };
    let installed = second_cache.is_some();
    penalty.refresh_caches(Some(Arc::new(first_cache)), second_cache);
    Ok(installed)
}
/// Refreshes the caches of every isometry penalty in `registry` that can be
/// paired with an atom of `term`.
///
/// Penalties are paired by the signature `(latent_dim, p_out)`: the n-th
/// isometry penalty with a given signature takes the n-th atom that has the
/// same latent dimension and decoder column count and carries a basis
/// evaluator. Non-isometry entries, penalties without a target latent
/// dimension, and penalties with no atom left to pair are skipped.
///
/// Returns how many refreshes also installed the second-order cache.
///
/// # Errors
///
/// Fails when `coords_per_atom` does not have one entry per atom, or when a
/// per-atom refresh fails.
pub fn refresh_isometry_caches_from_term(
    registry: &AnalyticPenaltyRegistry,
    term: &SaeManifoldTerm,
    coords_per_atom: &[Array2<f64>],
) -> Result<usize, String> {
    if coords_per_atom.len() != term.atoms.len() {
        return Err(format!(
            "refresh_isometry_caches_from_term: coords_per_atom length {} != number of atoms {}",
            coords_per_atom.len(),
            term.atoms.len()
        ));
    }
    let mut with_second_order = 0usize;
    // Number of atoms already handed out for each (latent_dim, p_out) signature.
    let mut handed_out: std::collections::HashMap<(usize, usize), usize> =
        std::collections::HashMap::new();
    for entry in registry.penalties.iter() {
        let AnalyticPenaltyKind::Isometry(isometry) = entry else {
            continue;
        };
        let Some(wanted_dim) = isometry.target.latent_dim else {
            continue;
        };
        let slot = handed_out.entry((wanted_dim, isometry.p_out)).or_insert(0);
        // Pick the first compatible atom not yet taken by an earlier penalty.
        let next_free = term
            .atoms
            .iter()
            .enumerate()
            .filter(|(_, atom)| {
                atom.latent_dim == wanted_dim
                    && atom.decoder_coefficients.ncols() == isometry.p_out
                    && atom.basis_evaluator.is_some()
            })
            .map(|(atom_idx, _)| atom_idx)
            .nth(*slot);
        let Some(atom_idx) = next_free else {
            continue;
        };
        *slot += 1;
        let installed_second = refresh_isometry_caches_from_atom(
            isometry,
            &term.atoms[atom_idx],
            coords_per_atom[atom_idx].view(),
        )?;
        if installed_second {
            with_second_order += 1;
        }
    }
    Ok(with_second_order)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
/// `Fixed` reports itself as fixed; an exponential sweep does not and exposes
/// its values as a slice.
#[test]
fn search_strategy_exposes_fixed_and_sweep_values() {
    assert!(SearchStrategy::Fixed.is_fixed());

    let sweep_points = vec![0.1, 1.0, 10.0];
    let strategy = SearchStrategy::ExponentialSweep {
        values: sweep_points,
    };
    assert!(!strategy.is_fixed());

    let expected = [0.1, 1.0, 10.0];
    assert_eq!(strategy.sweep_values(), Some(expected.as_slice()));
}
/// Evaluates the three-column test basis `[1, sin(2πx), cos(2πx)]`, where `x`
/// is the first coordinate column wrapped into `[0, 1)`, along with the
/// derivative of each column with respect to `x`.
///
/// Returns `(phi, jet)` with shapes `(n, 3)` and `(n, 3, 1)`.
fn periodic_basis(coords: &Array2<f64>) -> (Array2<f64>, Array3<f64>) {
    let two_pi = 2.0 * std::f64::consts::PI;
    let n_rows = coords.nrows();
    let mut values = Array2::<f64>::zeros((n_rows, 3));
    let mut derivs = Array3::<f64>::zeros((n_rows, 3, 1));
    for row in 0..n_rows {
        let wrapped = coords[[row, 0]].rem_euclid(1.0);
        let angle = two_pi * wrapped;
        let sin_a = angle.sin();
        let cos_a = angle.cos();
        values[[row, 0]] = 1.0;
        values[[row, 1]] = sin_a;
        values[[row, 2]] = cos_a;
        // The constant column's derivative stays at the zero it was initialised to.
        derivs[[row, 1, 0]] = two_pi * cos_a;
        derivs[[row, 2, 0]] = -two_pi * sin_a;
    }
    (values, derivs)
}
/// Takes a snapshot, perturbs the term with a Newton step, restores the
/// snapshot, and checks that basis, jet, decoder, logits and coordinates are
/// exactly what they were before the step.
#[test]
fn snapshot_restore_round_trips_mutated_state() {
    let initial_coords = array![[0.05], [0.20], [0.55], [0.80]];
    let (basis, basis_jet) = periodic_basis(&initial_coords);
    let decoder = array![[0.2], [-0.3], [0.4]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((4, 1)),
        vec![initial_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Capture both the snapshot under test and independent copies to compare against.
    let snapshot = term.snapshot_mutable_state();
    let basis_before = term.atoms[0].basis_values.clone();
    let jet_before = term.atoms[0].basis_jacobian.clone();
    let decoder_before = term.atoms[0].decoder_coefficients.clone();
    let logits_before = term.assignment.logits.clone();
    let coords_before = term.assignment.coords[0].as_matrix();

    let row_dim = term.assignment.row_block_dim();
    let step_ext = Array1::<f64>::from_elem(4 * row_dim, 0.3);
    let step_beta = Array1::<f64>::from_elem(term.beta_dim(), -0.4);
    term.apply_newton_step(step_ext.view(), step_beta.view(), 1.0)
        .unwrap();

    // The step must actually change something, otherwise the restore check is vacuous.
    let basis_shift = (&term.atoms[0].basis_values - &basis_before)
        .mapv(f64::abs)
        .sum();
    let decoder_shift = (&term.atoms[0].decoder_coefficients - &decoder_before)
        .mapv(f64::abs)
        .sum();
    assert!(
        basis_shift > 1e-9 || decoder_shift > 1e-9,
        "apply_newton_step did not perturb the snapshotted state"
    );

    term.restore_mutable_state(&snapshot);
    assert_eq!(term.atoms[0].basis_values, basis_before);
    assert_eq!(term.atoms[0].basis_jacobian, jet_before);
    assert_eq!(term.atoms[0].decoder_coefficients, decoder_before);
    assert_eq!(term.assignment.logits, logits_before);
    assert_eq!(term.assignment.coords[0].as_matrix(), coords_before);
}
/// Runs two joint-fit iterations in IBP-MAP mode and checks that the loss does
/// not increase, all state stays finite, and the basis values were re-evaluated.
#[test]
fn ibp_path_refreshes_periodic_basis_for_two_newton_iterations() {
    let initial_coords = array![[0.05], [0.20], [0.55], [0.80]];
    let (basis, basis_jet) = periodic_basis(&initial_coords);
    let decoder = array![[0.2], [-0.3], [0.4]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((4, 1)),
        vec![initial_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.10], [0.05], [-0.15], [0.20]];
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);
    let initial_loss = term.loss(target.view(), &rho).unwrap().total();
    let initial_basis = term.atoms[0].basis_values.clone();

    let fitted_loss = term
        .run_joint_fit_arrow_schur(target.view(), &mut rho, None, 2, 0.05, 1.0e-3, 1.0e-3)
        .unwrap();
    let final_total = fitted_loss.total();
    assert!(final_total.is_finite());
    assert!(final_total <= initial_loss + 1.0e-8);

    let coords_finite = term.assignment.coords[0]
        .as_flat()
        .iter()
        .all(|v| v.is_finite());
    assert!(coords_finite);
    assert!(term.assignment.assignments().iter().all(|v| v.is_finite()));

    // Coordinates moved, so the cached basis must differ from the initial one.
    let basis_shift = (&term.atoms[0].basis_values - &initial_basis)
        .mapv(f64::abs)
        .sum();
    assert!(basis_shift > 1.0e-10);
}
/// With a zero smoothing penalty and a tiny ridge, the joint fit must still
/// return `Ok` rather than fail on the row block.
#[test]
fn run_joint_fit_arrow_schur_escalates_ridge_on_non_pd_row_block() {
    let coords = array![[0.1], [0.4], [0.7]];
    let (basis, basis_jet) = periodic_basis(&coords);
    let decoder = array![[0.05], [-0.05], [0.05]];
    let no_smoothing = Array2::<f64>::zeros((3, 3));
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        decoder,
        no_smoothing,
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.20], [-0.10], [0.45]];
    let mut rho = SaeManifoldRho::new(0.0, -20.0, vec![Array1::<f64>::zeros(1)]);
    let tiny_ridge = 1.0e-6;
    let result = term.run_joint_fit_arrow_schur(
        target.view(),
        &mut rho,
        None,
        1,
        1.0,
        tiny_ridge,
        tiny_ridge,
    );
    assert!(
        result.is_ok(),
        "run_joint_fit_arrow_schur should recover from degenerate H_tt via LM ridge escalation; got: {result:?}",
    );
}
/// Same degenerate setup as the joint-fit variant, exercised through a single
/// `solve_newton_step` call, which must return `Ok`.
#[test]
fn solve_newton_step_escalates_ridge_on_non_pd_row_block() {
    let coords = array![[0.1], [0.4], [0.7]];
    let (basis, basis_jet) = periodic_basis(&coords);
    let decoder = array![[0.05], [-0.05], [0.05]];
    let no_smoothing = Array2::<f64>::zeros((3, 3));
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        decoder,
        no_smoothing,
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.20], [-0.10], [0.45]];
    let rho = SaeManifoldRho::new(0.0, -20.0, vec![Array1::<f64>::zeros(1)]);
    let tiny_ridge = 1.0e-6;
    let result = term.solve_newton_step(target.view(), &rho, None, tiny_ridge, tiny_ridge);
    assert!(
        result.is_ok(),
        "solve_newton_step should recover from degenerate H_tt via LM ridge escalation; got: {result:?}",
    );
}
/// Checks that the assembled beta gradient and penalty operator predict the
/// actual loss change for a small step along the negative gradient.
///
/// The step is `epsilon` times the unit vector along `-gb`; the prediction is
/// `gb·δ + ½ δᵀ(Hδ)` with `Hδ` obtained from `effective_penalty_op().matvec`.
#[test]
fn sae_arrow_schur_beta_quadratic_model_matches_penalized_loss_change() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        array![[0.65], [-0.45], [0.25]],
        array![[3.0, 0.4, -0.2], [0.1, 2.5, 0.3], [-0.5, 0.2, 1.8]],
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.20], [-0.10], [0.45]];
    let rho = SaeManifoldRho::new(0.0, 1.3_f64.ln(), vec![array![0.9_f64.ln()]]);
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let beta0 = term.flatten_beta();
    let loss0 = term.loss(target.view(), &rho).unwrap().total();

    // Unit-length steepest-descent direction in beta.
    let mut direction = sys.gb.mapv(|v| -v);
    let direction_norm = direction.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(direction_norm > 1.0e-12);
    for value in direction.iter_mut() {
        *value /= direction_norm;
    }
    let epsilon = 1.0e-3;
    let delta = direction.mapv(|v| epsilon * v);

    // `delta` is borrowed here because it is reused below for the model terms.
    let beta_trial = beta0 + &delta;
    term.set_flat_beta(beta_trial.view()).unwrap();
    let actual = term.loss(target.view(), &rho).unwrap().total() - loss0;

    let linear = sys.gb.dot(&delta);
    let mut hbb_delta = Array1::<f64>::zeros(delta.len());
    {
        let op = sys.effective_penalty_op();
        let d_slice = delta.as_slice().expect("delta is contiguous");
        let hd_slice = hbb_delta.as_slice_mut().expect("hbb_delta is contiguous");
        op.matvec(d_slice, hd_slice);
    }
    let quadratic = 0.5 * delta.dot(&hbb_delta);
    let predicted = linear + quadratic;
    let error = (actual - predicted).abs();
    assert!(
        error <= 1.0e-4,
        "actual={actual:.12e}, predicted={predicted:.12e}, error={error:.12e}"
    );
}
/// Builds a row layout with top-k = 2 and cutoff 0.05, then checks the active
/// atoms, the compact row widths, and how a compact row is scattered back into
/// the full-width row.
#[test]
fn sae_row_layout_from_dense_weights_top_k_and_cutoff() {
    let coord_dims = vec![2usize, 1, 2];
    let coord_offsets_full = vec![3usize, 5, 6];
    let weights = vec![
        Array1::from_vec(vec![0.7, 0.01, 0.29]),
        Array1::from_vec(vec![0.001, 0.002, 0.0005]),
    ];
    let layout =
        SaeRowLayout::from_dense_weights(&weights, 2, 0.05, coord_dims, coord_offsets_full);

    assert_eq!(layout.active_atoms[0], vec![0, 2]);
    assert_eq!(layout.active_atoms[1], vec![1]);
    assert_eq!(layout.row_q_active(0), 6);
    assert_eq!(layout.row_q_active(1), 2);

    let compact_row = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut full_row = vec![0.0_f64; 8];
    layout.expand_row(0, &compact_row, &mut full_row);
    // Slots 1 and 5 are expected to remain at their initial zero.
    let expected_row = vec![1.0_f64, 0.0, 2.0, 3.0, 4.0, 0.0, 5.0, 6.0];
    assert_eq!(full_row, expected_row);
}
/// Registers a `MechanismSparsityPenalty` over the full decoder and checks that
/// assembling the arrow-Schur system yields a finite, non-trivial beta gradient
/// of length `m * p`, with the entry at (basis 1, feature 0) matching the
/// closed-form value computed in the test.
#[test]
fn sae_mechsparsity_beta_block_routes_through_arrow_schur_gb() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![
        [0.7, -0.2, 0.05, 0.4],
        [-0.5, 0.6, -0.1, 0.3],
        [0.2, 0.0, -0.4, -0.6],
    ];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let m = 3usize;
    let p = 4usize;
    let slice = PsiSlice::full(m * p, Some(m));
    let penalty = MechanismSparsityPenalty::new(
        slice,
        vec![vec![0, 1], vec![2, 3]],
        1.0,
        1.0e-6,
        term.n_obs() as f64,
        false,
    )
    .unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    registry.push(AnalyticPenaltyKind::MechanismSparsity(Arc::new(penalty)));
    let target = array![
        [0.20, 0.10, -0.05, 0.25],
        [-0.10, 0.30, 0.15, -0.20],
        [0.45, -0.05, 0.10, 0.30],
    ];
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, Some(&registry))
        .unwrap();
    assert_eq!(sys.gb.len(), m * p, "gb should match flatten_beta length");
    let mut absmax = 0.0_f64;
    for v in sys.gb.iter().copied() {
        assert!(v.is_finite());
        if v.abs() > absmax {
            absmax = v.abs();
        }
    }
    assert!(
        absmax > 1.0e-6,
        "MechSparsity must inject a non-trivial gradient into the SAE arrow-Schur gb; absmax={absmax:.3e}"
    );
    let beta = term.flatten_beta();
    // Closed form for decoder row 1 = [-0.5, 0.6, ...] in the group {0, 1}.
    let expected = {
        let s = (0.5_f64.powi(2) + 0.6_f64.powi(2) + 1.0e-12).sqrt();
        (2.0_f64).sqrt() * (-0.5_f64) / s
    };
    // Flat index of (basis = 1, feature = 0) in the m-by-p layout.
    let flat_idx = p;
    let observed = sys.gb[flat_idx];
    assert!(
        (observed - expected).abs() <= 1.0e-6,
        "expected MechSparsity gb entry at (basis=1, feat=0) ≈ {expected:.6e}, got {observed:.6e} (beta entry = {})",
        beta[flat_idx]
    );
}
/// Test-only evaluator that forwards to [`periodic_basis`].
#[derive(Debug)]
struct TestPeriodicEvaluator;

impl SaeBasisEvaluator for TestPeriodicEvaluator {
    /// Copies the view into an owned matrix and evaluates the periodic test basis on it.
    fn evaluate(
        &self,
        coords: ArrayView2<'_, f64>,
    ) -> Result<(Array2<f64>, Array3<f64>), String> {
        let owned_coords = coords.to_owned();
        let (values, derivs) = periodic_basis(&owned_coords);
        Ok((values, derivs))
    }
}
/// Asserts that the evaluator's analytic first jet agrees with a central
/// finite difference of its basis values, entry by entry, within `tolerance`.
///
/// Each coordinate entry is perturbed by ±1e-6 in turn. Panics on the first
/// mismatch, on a jet shape mismatch, or if evaluation fails.
fn assert_jacobian_matches_central_difference<E: SaeBasisEvaluator>(
    evaluator: &E,
    coords: Array2<f64>,
    tolerance: f64,
) {
    let step = 1.0e-6;
    let span = 2.0 * step;
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let (n_rows, n_basis) = phi.dim();
    let latent_dim = coords.ncols();
    assert_eq!(jet.dim(), (n_rows, n_basis, latent_dim));
    for row in 0..n_rows {
        for axis in 0..latent_dim {
            let mut forward = coords.clone();
            forward[[row, axis]] += step;
            let mut backward = coords.clone();
            backward[[row, axis]] -= step;
            let (phi_forward, jet_forward) = evaluator.evaluate(forward.view()).unwrap();
            let (phi_backward, jet_backward) = evaluator.evaluate(backward.view()).unwrap();
            assert_eq!(jet_forward.dim(), jet.dim());
            assert_eq!(jet_backward.dim(), jet.dim());
            for basis in 0..n_basis {
                let finite_difference =
                    (phi_forward[[row, basis]] - phi_backward[[row, basis]]) / span;
                let analytic = jet[[row, basis, axis]];
                let error = (analytic - finite_difference).abs();
                assert!(
                    error <= tolerance,
                    "row={row} basis={basis} axis={axis}: analytic={analytic:.12e}, \
                     finite_difference={finite_difference:.12e}, error={error:.12e}, \
                     tolerance={tolerance:.12e}"
                );
            }
        }
    }
}
// Finite-difference check of the first jet for each built-in basis evaluator,
// plus a few closed-form spot checks on the sphere and torus bases.
#[test]
fn sae_basis_evaluator_jacobians_match_central_differences() {
// Periodic harmonic basis on a 1-D coordinate, including negative inputs.
assert_jacobian_matches_central_difference(
&PeriodicHarmonicEvaluator::new(7).unwrap(),
array![[-0.37], [0.0], [0.125], [0.41]],
1.0e-6,
);
// Raw periodic circle basis with three coordinate columns.
assert_jacobian_matches_central_difference(
&RawPeriodicCircleEvaluator::new(3).unwrap(),
array![[-1.2, 0.3, 2.0], [0.0, -0.4, 0.8], [2.4, 1.1, -0.7]],
1.0e-6,
);
// Sphere chart: columns are read below as (lat, lon).
let sphere_coords = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
assert_jacobian_matches_central_difference(
&SphereChartEvaluator,
sphere_coords.clone(),
1.0e-6,
);
let (sphere_phi, sphere_jet) = SphereChartEvaluator.evaluate(sphere_coords.view()).unwrap();
// The sphere basis is expected to have 7 columns over 2 latent axes.
assert_eq!(sphere_phi.dim(), (sphere_coords.nrows(), 7));
assert_eq!(sphere_jet.dim(), (sphere_coords.nrows(), 7, 2));
for row in 0..sphere_coords.nrows() {
let lat = sphere_coords[[row, 0]];
let lon = sphere_coords[[row, 1]];
let clat = lat.cos();
let slat = lat.sin();
let clon = lon.cos();
let slon = lon.sin();
let z = slat;
// Longitude derivatives of x = cos(lat)cos(lon) and y = cos(lat)sin(lon).
let dx_dlon = -clat * slon;
let dy_dlon = clat * clon;
// Column 3 must have an exactly zero longitude derivative; columns 5 and 6
// must match dy/dlon * z and dx/dlon * z respectively.
assert_eq!(sphere_jet[[row, 3, 1]], 0.0);
assert!((sphere_jet[[row, 5, 1]] - dy_dlon * z).abs() <= 1.0e-12);
assert!((sphere_jet[[row, 6, 1]] - dx_dlon * z).abs() <= 1.0e-12);
}
// Affine coordinate basis in three dimensions.
assert_jacobian_matches_central_difference(
&AffineCoordinateEvaluator::new(3),
array![[0.0, -1.0, 2.0], [3.5, 0.25, -0.75]],
1.0e-6,
);
// Torus harmonic basis built with arguments (2, 3).
let torus_coords = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
assert_jacobian_matches_central_difference(
&TorusHarmonicEvaluator::new(2, 3).unwrap(),
torus_coords.clone(),
1.0e-6,
);
let (torus_phi, torus_jet) = TorusHarmonicEvaluator::new(2, 3)
.unwrap()
.evaluate(torus_coords.view())
.unwrap();
// That configuration is expected to produce 49 basis columns.
assert_eq!(torus_phi.dim(), (torus_coords.nrows(), 49));
assert_eq!(torus_jet.dim(), (torus_coords.nrows(), 49, 2));
// Column 0 must be the constant 1 with vanishing derivatives on both axes.
for row in 0..torus_coords.nrows() {
assert!((torus_phi[[row, 0]] - 1.0).abs() <= 1.0e-12);
assert!(torus_jet[[row, 0, 0]].abs() <= 1.0e-12);
assert!(torus_jet[[row, 0, 1]].abs() <= 1.0e-12);
}
}
/// Asserts that the evaluator's analytic second jet agrees with a central
/// finite difference of its first jet within `tolerance`, and that the second
/// jet is symmetric in its two latent axes to 1e-12.
///
/// Returns an error if the unperturbed `second_jet` or `evaluate` call fails;
/// panics on a mismatch or if a perturbed evaluation fails.
fn assert_second_jet_matches_central_difference<E: SaeBasisSecondJet>(
    evaluator: &E,
    coords: Array2<f64>,
    tolerance: f64,
) -> Result<(), String> {
    let step = 1.0e-4;
    let span = 2.0 * step;
    let second = evaluator.second_jet(coords.view())?;
    let (_basis_values, first) = evaluator.evaluate(coords.view())?;
    let (n_rows, n_basis, latent_dim, latent_dim_b) = second.dim();
    assert_eq!(latent_dim, latent_dim_b);
    assert_eq!((n_rows, n_basis, latent_dim), first.dim());

    // Differentiate the first jet along axis_c and compare to second[.., axis_a, axis_c].
    for row in 0..n_rows {
        for axis_c in 0..latent_dim {
            let mut forward = coords.clone();
            forward[[row, axis_c]] += step;
            let mut backward = coords.clone();
            backward[[row, axis_c]] -= step;
            let (_, jet_forward) = evaluator.evaluate(forward.view()).unwrap();
            let (_, jet_backward) = evaluator.evaluate(backward.view()).unwrap();
            for basis in 0..n_basis {
                for axis_a in 0..latent_dim {
                    let fd = (jet_forward[[row, basis, axis_a]]
                        - jet_backward[[row, basis, axis_a]])
                        / span;
                    let analytic = second[[row, basis, axis_a, axis_c]];
                    let error = (analytic - fd).abs();
                    assert!(
                        error <= tolerance,
                        "row={row} basis={basis} axis_a={axis_a} axis_c={axis_c}: \
                         analytic={analytic:.12e}, fd={fd:.12e}, error={error:.12e}, \
                         tol={tolerance:.12e}"
                    );
                }
            }
        }
    }

    // Symmetry of the second jet in its latent axes.
    for row in 0..n_rows {
        for basis in 0..n_basis {
            for axis_a in 0..latent_dim {
                for axis_b in 0..latent_dim {
                    let h_ab = second[[row, basis, axis_a, axis_b]];
                    let h_ba = second[[row, basis, axis_b, axis_a]];
                    assert!(
                        (h_ab - h_ba).abs() <= 1.0e-12,
                        "second_jet not symmetric: row={row} basis={basis} \
                         ({axis_a},{axis_b})={h_ab:.6e} vs ({axis_b},{axis_a})={h_ba:.6e}"
                    );
                }
            }
        }
    }
    Ok(())
}
/// Second-jet finite-difference check for the periodic harmonic evaluator.
#[test]
fn isometry_periodic_second_jet_matches_fd() -> Result<(), String> {
    let evaluator = PeriodicHarmonicEvaluator::new(7).unwrap();
    let coords = array![[-0.37], [0.0], [0.125], [0.41]];
    assert_second_jet_matches_central_difference(&evaluator, coords, 1.0e-5)
}
/// Second-jet finite-difference check for the sphere chart evaluator.
#[test]
fn isometry_sphere_second_jet_matches_fd() -> Result<(), String> {
    let coords = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    let tolerance = 1.0e-5;
    assert_second_jet_matches_central_difference(&SphereChartEvaluator, coords, tolerance)
}
/// Second-jet finite-difference check for the torus harmonic evaluator, after
/// confirming that the basis is non-empty.
#[test]
fn isometry_torus_second_jet_matches_fd() -> Result<(), String> {
    let coords = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    let torus = TorusHarmonicEvaluator::new(2, 3).unwrap();
    assert!(torus.basis_size() > 0);
    assert_second_jet_matches_central_difference(&torus, coords, 1.0e-5)
}
/// For 1-D and 2-D centre sets, the Duchon evaluator's basis values and first
/// jet must agree on the number of basis columns, and the jet must have one
/// slice per row and per latent axis.
#[test]
fn duchon_coordinate_evaluator_phi_and_jet_share_column_count() {
    let centers_1d = array![[-1.0], [-0.4], [0.1], [0.6], [1.2], [1.9]];
    let centers_2d = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let cases = [(1usize, centers_1d), (2usize, centers_2d)];
    for (d, centers) in cases {
        let evaluator = DuchonCoordinateEvaluator::new(centers, 2).unwrap();
        let coords = if d == 1 {
            array![[-0.5], [0.0], [0.3], [0.8]]
        } else {
            array![[-0.5, 0.2], [0.0, -0.3], [0.3, 0.7], [0.8, -0.1]]
        };
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let jet_shape = jet.shape();
        assert_eq!(
            phi.ncols(),
            jet_shape[1],
            "Duchon d={d}: Phi has {} columns but jet has {}",
            phi.ncols(),
            jet_shape[1]
        );
        assert_eq!(jet_shape[0], coords.nrows());
        assert_eq!(jet_shape[2], d);
    }
}
/// First-jet finite-difference check for the Duchon evaluator on 2-D centres.
#[test]
fn duchon_coordinate_evaluator_jacobian_matches_fd() {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    let probes = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let tolerance = 1.0e-4;
    assert_jacobian_matches_central_difference(&duchon, probes, tolerance);
}
/// Second-jet finite-difference check for the Duchon evaluator on 2-D centres.
#[test]
fn duchon_coordinate_evaluator_second_jet_matches_fd() -> Result<(), String> {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    let probes = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    assert_second_jet_matches_central_difference(&duchon, probes, 1.0e-4)
}
/// First- and second-jet finite-difference checks for the Euclidean patch
/// evaluator built with `(2, 2)`, which must expose six basis columns.
#[test]
fn euclidean_patch_evaluator_jets_match_fd() -> Result<(), String> {
    let patch = EuclideanPatchEvaluator::new(2, 2).unwrap();
    let probes = array![[0.0, -1.0], [3.5, 0.25], [-0.75, 1.2], [0.4, 0.9]];
    assert_jacobian_matches_central_difference(&patch, probes.clone(), 1.0e-6);
    assert_second_jet_matches_central_difference(&patch, probes, 1.0e-5)?;

    let origin = array![[0.0, 0.0]];
    let (phi_at_origin, _jet) = patch.evaluate(origin.view())?;
    assert_eq!(phi_at_origin.ncols(), 6);
    Ok(())
}
/// Fits a single torus atom, started at the true coordinates with a zero
/// decoder, to a four-channel signal built from two angles, and requires an
/// in-sample R² of at least 0.5 after up to ten one-step joint fits.
#[test]
fn sae_torus_atom_recovers_two_frequency_synthetic() {
    let n = 96usize;
    let p = 4usize;
    let h = 3usize;
    let d = 2usize;
    let evaluator = TorusHarmonicEvaluator::new(d, h).unwrap();
    let m = evaluator.basis_size();

    // Synthetic coordinates on the unit torus and the signal they generate.
    let two_pi = 2.0 * std::f64::consts::PI;
    let mut true_coords = Array2::<f64>::zeros((n, d));
    let mut z = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let u = ((i as f64) * 0.137).rem_euclid(1.0);
        let v = ((i as f64) * 0.241 + 0.13).rem_euclid(1.0);
        true_coords[[i, 0]] = u;
        true_coords[[i, 1]] = v;
        let t1 = two_pi * u;
        let t2 = two_pi * v;
        z[[i, 0]] = t1.sin() + 0.3 * t2.cos();
        z[[i, 1]] = t1.cos() + 0.2 * (t1 + t2).sin();
        z[[i, 2]] = t2.sin();
        z[[i, 3]] = 0.5 * (t1 - t2).cos();
    }
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let (phi0, jet0) = evaluator.evaluate(true_coords.view()).unwrap();
    let smoothing = Array2::<f64>::eye(m) * 1.0e-4;
    let atom = SaeManifoldAtom::new(
        "torus_atom",
        SaeAtomBasisKind::Torus,
        d,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        smoothing,
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TorusHarmonicEvaluator::new(d, h).unwrap()));
    let torus_manifold = LatentManifold::Product(vec![
        LatentManifold::Circle { period: 1.0 },
        LatentManifold::Circle { period: 1.0 },
    ]);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![true_coords],
        vec![torus_manifold],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -4.0, vec![Array1::<f64>::zeros(1)]);

    // Up to ten single-iteration fits; stop early on a non-finite loss.
    let ridge = 1.0e-6;
    for _ in 0..10 {
        let step_loss = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
            .unwrap();
        if !step_loss.total().is_finite() {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let sse = fitted
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), value)| {
            let residual = value - z[[row, col]];
            acc + residual * residual
        });
    let r2 = 1.0 - sse / sst.max(1.0e-12);
    assert!(
        r2 >= 0.5,
        "torus atom R² too low: {r2:.4} (sst={sst:.4}, sse={sse:.4})"
    );
}
/// Fits a single sphere-chart atom, started at the true (lat, lon) coordinates
/// with a zero decoder, to the Cartesian embedding of those coordinates, and
/// requires an in-sample R² of at least 0.5 after up to ten one-step fits.
#[test]
fn sae_sphere_atom_recovers_synthetic_signal() {
    let n = 96usize;
    let p = 3usize;
    let d = 2usize;

    // Latitude sweeps [-0.5, 0.5), longitude sweeps [-π, π); the target is (x, y, z).
    let mut true_coords = Array2::<f64>::zeros((n, d));
    let mut z = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let t = (i as f64) / (n as f64);
        let lat = -0.5 + 1.0 * t;
        let lon = -std::f64::consts::PI + 2.0 * std::f64::consts::PI * t;
        true_coords[[i, 0]] = lat;
        true_coords[[i, 1]] = lon;
        z[[i, 0]] = lat.cos() * lon.cos();
        z[[i, 1]] = lat.cos() * lon.sin();
        z[[i, 2]] = lat.sin();
    }
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let (phi0, jet0) = SphereChartEvaluator.evaluate(true_coords.view()).unwrap();
    let m = phi0.ncols();
    let smoothing = Array2::<f64>::eye(m) * 1.0e-4;
    let atom = SaeManifoldAtom::new(
        "sphere_atom",
        SaeAtomBasisKind::Sphere,
        d,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        smoothing,
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(SphereChartEvaluator));
    let chart_manifold = LatentManifold::Product(vec![
        LatentManifold::Interval {
            lo: -std::f64::consts::FRAC_PI_2,
            hi: std::f64::consts::FRAC_PI_2,
        },
        LatentManifold::Circle {
            period: std::f64::consts::TAU,
        },
    ]);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![true_coords],
        vec![chart_manifold],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -4.0, vec![Array1::<f64>::zeros(1)]);

    // Up to ten single-iteration fits; stop early on a non-finite loss.
    let ridge = 1.0e-6;
    for _ in 0..10 {
        let step_loss = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
            .unwrap();
        if !step_loss.total().is_finite() {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let sse = fitted
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), value)| {
            let residual = value - z[[row, col]];
            acc + residual * residual
        });
    let r2 = 1.0 - sse / sst.max(1.0e-12);
    assert!(
        r2 >= 0.5,
        "sphere atom R² too low: {r2:.4} (sst={sst:.4}, sse={sse:.4})"
    );
}
/// Fits one periodic atom to a single-harmonic signal, starting from
/// coordinates shifted by a quarter period and a zero decoder, and requires an
/// in-sample R² of at least 0.95 within ten one-step joint fits.
#[test]
fn sae_manifold_fit_10_steps_one_harmonic_reaches_high_r2() {
    let n = 64usize;
    let m = 3usize;
    let p = 1usize;
    let true_t: Vec<f64> = (0..n).map(|i| (i as f64) / (n as f64)).collect();

    let z = Array2::<f64>::from_shape_fn((n, p), |(i, _)| {
        let angle = 2.0 * std::f64::consts::PI * true_t[i];
        0.7 * angle.sin() + 0.3 * angle.cos()
    });
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    // Start a quarter period away from the generating coordinates.
    let shifted_coords =
        Array2::<f64>::from_shape_fn((n, 1), |(i, _)| (true_t[i] + 0.25).rem_euclid(1.0));
    let (phi0, jet0) = evaluator.evaluate(shifted_coords.view()).unwrap();
    let atom = SaeManifoldAtom::new(
        "periodic_atom",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![shifted_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);

    let max_iter = 10usize;
    let learning_rate = 1.0;
    let ridge = 1.0e-6;
    let mut last_total = f64::INFINITY;
    for _ in 0..max_iter {
        let total = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, learning_rate, ridge, ridge)
            .unwrap()
            .total();
        if !total.is_finite() {
            break;
        }
        // Relative change against the previous total decides convergence.
        let relative_change = (last_total - total).abs() / last_total.abs().max(1.0e-12);
        last_total = total;
        if relative_change < 1.0e-6 {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let ssr = (0..n).fold(0.0, |acc, i| {
        let residual = z[[i, 0]] - fitted[[i, 0]];
        acc + residual * residual
    });
    let r2 = 1.0 - ssr / sst.max(1.0e-12);
    assert!(
        r2 >= 0.95,
        "10-step in-sample R² = {r2:.4} (ssr={ssr:.6}, sst={sst:.6}) should be >= 0.95"
    );
}
/// With two atoms in softmax mode, the assignment prior must return a finite
/// gradient and Hessian diagonal, each of length `n * k`.
#[test]
fn softmax_assignment_hessian_diag_is_available_for_k2() {
    let n = 4usize;
    let k = 2usize;
    let logits = Array2::<f64>::from_shape_fn((n, k), |(row, atom)| {
        0.1 * (row as f64) - 0.2 * (atom as f64)
    });
    let coords: Vec<Array2<f64>> = vec![Array2::<f64>::zeros((n, 1)); k];
    let manifolds = vec![LatentManifold::Circle { period: 1.0 }; k];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coords,
        manifolds,
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    let (grad, diag) = assignment_prior_grad_hdiag(&assignment, &rho)
        .expect("softmax assignment Hessian diagonal must be available");
    let expected_len = n * k;
    assert_eq!(grad.len(), expected_len);
    assert_eq!(diag.len(), expected_len);
    assert!(grad.iter().all(|v| v.is_finite()));
    assert!(diag.iter().all(|v| v.is_finite()));
}
/// Sweeps logits on both sides of the JumpReLU threshold and checks that every
/// Hessian-diagonal entry is finite, non-negative, and equal to the gated
/// squared-slope formula recomputed here (zero at or below the threshold).
#[test]
fn jumprelu_assignment_prior_hessian_diag_is_psd_over_logit_sweep() {
    let n = 6usize;
    let k = 2usize;
    let temperature = 0.35_f64;
    let threshold = 0.1_f64;
    let logit_grid = vec![
        -2.0, -0.2, 0.0, 0.05, 0.1, 0.15, 0.4, 0.9, 1.5, 2.5, 4.0, 6.0,
    ];
    let logits = Array2::<f64>::from_shape_vec((n, k), logit_grid).expect("valid logit grid");
    let coords: Vec<Array2<f64>> = vec![Array2::<f64>::zeros((n, 1)); k];
    let manifolds = vec![LatentManifold::Circle { period: 1.0 }; k];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.clone(),
        coords,
        manifolds,
        AssignmentMode::jumprelu(temperature, threshold),
    )
    .expect("valid JumpReLU assignment");
    let rho = SaeManifoldRho::new(0.7_f64.ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);

    let (grad, diag) = assignment_prior_grad_hdiag(&assignment, &rho)
        .expect("JumpReLU assignment prior hessian diag");
    assert_eq!(grad.len(), n * k);
    assert_eq!(diag.len(), n * k);

    let inv_tau = 1.0 / temperature;
    let inv_tau2 = inv_tau * inv_tau;
    let sparsity_strength = rho.log_lambda_sparse.exp();
    // `logits.iter()` walks row-major, matching the flattened index of `diag`.
    for (idx, (&entry, &logit)) in diag.iter().zip(logits.iter()).enumerate() {
        let expected = if logit > threshold {
            let activation = crate::linalg::utils::stable_logistic(logit * inv_tau);
            let slope = activation * (1.0 - activation);
            sparsity_strength * slope * slope * inv_tau2
        } else {
            0.0
        };
        assert!(
            entry.is_finite() && entry >= 0.0,
            "JumpReLU gated hessian_diag majorizer must be finite and PSD at index {idx}; entry={entry}"
        );
        assert_abs_diff_eq!(entry, expected, epsilon = 1e-12);
    }
}
/// End-to-end IBP-MAP fit with two periodic atoms on a synthetic signal built
/// from two circle angles mixed linearly into `p` output columns.
///
/// The decoders are seeded from a jittered least-squares solve on the stacked,
/// half-weighted bases; the joint fit then runs for at most 30 outer steps.
/// The test requires R² > 0.5 and a mean per-row assignment mass above 0.2.
#[test]
fn ibp_map_k2_periodic_torus_recovers_signal_with_lsq_init() {
    use crate::linalg::faer_ndarray::{FaerCholesky, fast_ata, fast_atb};
    use faer::Side as FaerSide;

    let n = 200usize;
    let p = 8usize;
    let k = 2usize;
    let m = 5usize;
    let two_pi = 2.0 * std::f64::consts::PI;

    // Two latent angles per row, each wrapped into [0, 1).
    let theta = Array2::<f64>::from_shape_fn((n, 2), |(i, axis)| {
        let t = i as f64;
        if axis == 0 {
            (t * 0.07) % 1.0
        } else {
            (t * 0.13 + 0.31) % 1.0
        }
    });
    // Columns are (cos a1, sin a1, cos a2, sin a2).
    let raw = Array2::<f64>::from_shape_fn((n, 4), |(i, c)| {
        let angle = two_pi * theta[[i, c / 2]];
        if c % 2 == 0 { angle.cos() } else { angle.sin() }
    });
    let mix = Array2::<f64>::from_shape_fn((4, p), |(r, j)| {
        ((r as f64 + 1.0) * 0.37 + (j as f64) * 0.21).sin()
    });
    // z = raw * mix, accumulated left-to-right over the four raw features.
    let mut z = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        (0..4).fold(0.0_f64, |acc, r| acc + raw[[i, r]] * mix[[r, j]])
    });

    // Centre every output column.
    let col_mean = Array1::<f64>::from_shape_fn(p, |j| {
        (0..n).fold(0.0_f64, |acc, i| acc + z[[i, j]]) / n as f64
    });
    for ((_, j), value) in z.indexed_iter_mut() {
        *value -= col_mean[j];
    }

    // Initial coordinates: the true angles shifted by a small per-atom offset.
    let phase_shifts: [f64; 2] = [0.05, 0.07];
    let coords_k: Vec<Array2<f64>> = (0..k)
        .map(|axis| {
            let shift = phase_shifts[axis];
            Array2::<f64>::from_shape_fn((n, 1), |(i, _)| {
                (theta[[i, axis]] + shift).rem_euclid(1.0)
            })
        })
        .collect();

    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    let (phi_k, jet_k): (Vec<_>, Vec<_>) = coords_k
        .iter()
        .map(|coords| evaluator.evaluate(coords.view()).unwrap())
        .unzip();

    // Stacked design matrix: atom blocks side by side, each scaled by 0.5.
    let m_total = k * m;
    let x = Array2::<f64>::from_shape_fn((n, m_total), |(i, c)| 0.5 * phi_k[c / m][[i, c % m]]);

    // Normal equations with a trace-scaled diagonal jitter before the Cholesky solve.
    let mut xtx = fast_ata(&x);
    let trace = (0..m_total).fold(0.0_f64, |acc, d| acc + xtx[[d, d]]);
    let jitter = (trace / m_total as f64).max(1.0) * 1.0e-8;
    for d in 0..m_total {
        xtx[[d, d]] += jitter;
    }
    let xtz = fast_atb(&x, &z);
    let b_joint = xtx
        .cholesky(FaerSide::Lower)
        .expect("LSQ Cholesky")
        .solve_mat(&xtz);

    // Each atom takes its own block of `m` rows from the joint solution.
    let atoms: Vec<_> = phi_k
        .iter()
        .zip(jet_k.iter())
        .enumerate()
        .map(|(atom_idx, (phi, jet))| {
            let row0 = atom_idx * m;
            let b = Array2::<f64>::from_shape_fn((m, p), |(col, j)| b_joint[[row0 + col, j]]);
            SaeManifoldAtom::new(
                format!("torus_atom_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi.clone(),
                jet.clone(),
                b,
                Array2::<f64>::eye(m),
            )
            .unwrap()
            .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()))
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, k)),
        coords_k,
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::ibp_map(0.7, 1.0, false),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    // Outer loop: stop on a non-finite loss or a relative change below 1e-6.
    // On the first pass the ratio is inf/inf = NaN, which never triggers the break.
    let mut last_total = f64::INFINITY;
    for _ in 0..30 {
        let total = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
            .unwrap()
            .total();
        if !total.is_finite() {
            break;
        }
        let rel_change = (last_total - total).abs() / last_total.abs().max(1.0e-12);
        last_total = total;
        if rel_change < 1.0e-6 {
            break;
        }
    }

    // In-sample R² against the centred targets (row-major accumulation).
    let fitted = term.fitted();
    let (mut ssr, mut sst) = (0.0_f64, 0.0_f64);
    for ((i, j), &target) in z.indexed_iter() {
        let resid = target - fitted[[i, j]];
        ssr += resid * resid;
        sst += target * target;
    }
    let r2 = 1.0 - ssr / sst.max(1.0e-12);
    assert!(
        r2 > 0.5,
        "K=2 periodic torus IBP-MAP R² = {r2:.4} (ssr={ssr:.4}, sst={sst:.4}) should be > 0.5 with LSQ-seeded decoder"
    );

    let assignments = term.assignment.assignments();
    let mean_active: f64 = assignments.iter().copied().sum::<f64>() / (n as f64);
    assert!(
        mean_active > 0.2,
        "mean active mass across rows = {mean_active:.4} should exceed 0.2; assignment did not collapse"
    );
}
/// Smoke test: with two periodic atoms under a softmax assignment, two
/// consecutive `run_joint_fit_arrow_schur` steps must both succeed and report
/// a finite total loss.
#[test]
fn softmax_k2_periodic_completes_joint_fit_step() {
    let n = 64usize;
    let p = 4usize;
    let k = 2usize;
    let m = 3usize;

    // Targets: first and second harmonics of an angle sweeping one full turn.
    let z = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        let a = 2.0 * std::f64::consts::PI * (i as f64) / (n as f64);
        match j {
            0 => a.sin(),
            1 => a.cos(),
            2 => (2.0 * a).sin(),
            _ => (2.0 * a).cos(),
        }
    });

    // Atom 0 sweeps the circle once, atom 1 twice (wrapped into [0, 1)).
    let coords_k = vec![
        Array2::<f64>::from_shape_fn((n, 1), |(i, _)| (i as f64) / (n as f64)),
        Array2::<f64>::from_shape_fn((n, 1), |(i, _)| {
            ((i as f64) * 2.0 / (n as f64)).rem_euclid(1.0)
        }),
    ];

    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    let atoms: Vec<_> = coords_k
        .iter()
        .enumerate()
        .map(|(atom_idx, coords)| {
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            // Every atom starts from the same small deterministic decoder.
            let b = Array2::<f64>::from_shape_fn((m, p), |(r, c)| {
                0.1 * ((r as f64 + 1.0) * (c as f64 + 1.0)).sin()
            });
            SaeManifoldAtom::new(
                format!("a_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                b,
                Array2::<f64>::eye(m),
            )
            .unwrap()
            .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()))
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, k)),
        coords_k,
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    let first = term
        .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
        .expect("softmax K=2 must complete first joint-fit step");
    assert!(first.total().is_finite());

    let second = term
        .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
        .expect("softmax K=2 must complete second joint-fit step");
    assert!(second.total().is_finite());
}
/// Shared oracle for the isometry-penalty wiring tests.
///
/// Builds a single atom from `evaluator` at `coords` with a deterministic
/// three-column decoder, then checks in order:
/// 1. before any cache refresh the penalty returns
///    `IsometryPenalty::DEFAULT_VALUE_ON_MISSING_CACHE` and an all-zero gradient;
/// 2. after `refresh_isometry_caches_from_atom` the value and the gradient are
///    both non-trivial (> 1e-6);
/// 3. the analytic gradient entry for coordinate `[0, 0]` agrees with a central
///    finite difference of the value (caches are refreshed at each probe point).
///
/// The order of refresh and evaluation calls matters, because the penalty reads
/// whatever cache was installed last.
fn assert_isometry_wiring_matches_fd(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    coords: Array2<f64>,
) {
    let (n_obs, latent_dim) = coords.dim();
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let m = phi.ncols();
    let p: usize = 3;
    let decoder = Array2::<f64>::from_shape_fn((m, p), |(i, j)| {
        let x = (i as f64) * 0.371 + (j as f64) * 0.193 + 0.5;
        (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos()
    });
    let atom = SaeManifoldAtom::new(
        "iso_wire_test",
        SaeAtomBasisKind::Periodic,
        latent_dim,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_second_jet(evaluator);

    let penalty = IsometryPenalty::new_euclidean(
        PsiSlice::full(n_obs * latent_dim, Some(latent_dim)),
        p,
    );
    let rho = Array1::<f64>::zeros(1);
    // Row-major flattening of a coordinate block into the penalty's target vector.
    let flatten = |block: &Array2<f64>| -> Array1<f64> { block.iter().copied().collect() };
    let target_flat = flatten(&coords);

    // Step 1: no cache has been installed yet.
    let v0 = penalty.value(target_flat.view(), rho.view());
    assert_eq!(v0, IsometryPenalty::DEFAULT_VALUE_ON_MISSING_CACHE);
    let g0 = penalty.grad_target(target_flat.view(), rho.view());
    assert!(
        g0.iter().all(|x| *x == 0.0),
        "grad_target without cache must be all zeros, got {g0:?}"
    );

    // Step 2: install caches at the base coordinates.
    let installed_second =
        refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed_second,
        "evaluator must implement second_jet for this oracle to run"
    );
    let value = penalty.value(target_flat.view(), rho.view());
    assert!(
        value > 1.0e-6,
        "expected non-trivial isometry loss after cache refresh, got {value}"
    );
    let grad = penalty.grad_target(target_flat.view(), rho.view());
    assert_eq!(grad.len(), target_flat.len());
    let max_abs = grad.iter().map(|g| g.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_abs > 1.0e-6,
        "expected non-zero isometry gradient on at least one component, max |grad|={max_abs}"
    );

    // Step 3: central finite difference on coordinate [0, 0] (flat index 0).
    let h_fd = 1.0e-5;
    let probe_idx = 0usize;
    let value_at_shift = |delta: f64| -> f64 {
        let mut shifted = coords.clone();
        shifted[[0, 0]] += delta;
        refresh_isometry_caches_from_atom(&penalty, &atom, shifted.view()).unwrap();
        penalty.value(flatten(&shifted).view(), rho.view())
    };
    let v_plus = value_at_shift(h_fd);
    let v_minus = value_at_shift(-h_fd);

    // Restore the base-point caches before reading the analytic gradient.
    refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    let grad_base = penalty.grad_target(target_flat.view(), rho.view());
    let fd = (v_plus - v_minus) / (2.0 * h_fd);
    let analytic = grad_base[probe_idx];
    let tolerance = 1.0e-3 + 1.0e-4 * analytic.abs().max(fd.abs());
    assert!(
        (analytic - fd).abs() <= tolerance,
        "isometry grad/FD mismatch at coord 0: analytic={analytic:.6e}, fd={fd:.6e}"
    );
}
/// Runs the isometry wiring oracle with a 5-harmonic periodic evaluator on
/// four 1-D coordinates.
#[test]
fn isometry_wiring_periodic_matches_fd() {
    let evaluator: Arc<dyn SaeBasisSecondJet> =
        Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords: Array2<f64> = array![[0.12], [0.37], [0.58], [0.81]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Runs the isometry wiring oracle with the sphere chart evaluator on three
/// 2-D coordinates.
#[test]
fn isometry_wiring_sphere_matches_fd() {
    let evaluator: Arc<dyn SaeBasisSecondJet> = Arc::new(SphereChartEvaluator);
    let coords: Array2<f64> = array![[-0.5, 0.3], [0.2, -1.1], [0.7, 0.9]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Runs the isometry wiring oracle with a `(2, 2)` torus harmonic evaluator on
/// three 2-D coordinates.
#[test]
fn isometry_wiring_torus_matches_fd() {
    let evaluator: Arc<dyn SaeBasisSecondJet> =
        Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap());
    let coords: Array2<f64> = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Regression test: `refresh_isometry_caches_from_term` must refresh the i-th
/// isometry penalty in the registry against the i-th atom of the term.
///
/// Two "control" penalties are refreshed one atom at a time with
/// `refresh_isometry_caches_from_atom` to get the expected Jacobian caches.
/// The same two atoms are then wrapped in a term, two fresh penalties are
/// registered, and the term-level refresh must reproduce the per-atom caches
/// rather than pairing both penalties with atom 0.
#[test]
fn refresh_isometry_caches_pairs_each_penalty_to_its_own_atom() {
    let latent_dim = 1usize;
    let p_out = 3usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords0 = array![[0.05], [0.20], [0.55], [0.80]];
    let coords1 = array![[0.13], [0.41], [0.62], [0.91]];

    // Builds one periodic atom; `seed` shifts the decoder so the two atoms differ.
    let build_atom = |name: &str, coords: &Array2<f64>, seed: f64| {
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        let mut decoder = Array2::<f64>::zeros((m, p_out));
        for i in 0..m {
            for j in 0..p_out {
                let x = (i as f64) * 0.371 + (j as f64) * 0.193 + seed;
                decoder[[i, j]] = (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos();
            }
        }
        let smooth = Array2::<f64>::eye(m);
        SaeManifoldAtom::new(
            name,
            SaeAtomBasisKind::Periodic,
            latent_dim,
            phi,
            jet,
            decoder,
            smooth,
        )
        .unwrap()
        .with_basis_second_jet(evaluator.clone() as Arc<dyn SaeBasisSecondJet>)
    };
    let atom0 = build_atom("atom0", &coords0, 0.5);
    let atom1 = build_atom("atom1", &coords1, 1.7);

    // Control caches: each penalty refreshed directly against its own atom.
    let slice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let control0 = IsometryPenalty::new_euclidean(slice0, p_out);
    refresh_isometry_caches_from_atom(&control0, &atom0, coords0.view()).unwrap();
    let expected0 = control0
        .jacobian_cache()
        .expect("control penalty 0 must have a Jacobian cache");
    let slice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    let control1 = IsometryPenalty::new_euclidean(slice1, p_out);
    refresh_isometry_caches_from_atom(&control1, &atom1, coords1.view()).unwrap();
    let expected1 = control1
        .jacobian_cache()
        .expect("control penalty 1 must have a Jacobian cache");
    // The test is only meaningful if the two atoms are distinguishable.
    assert_ne!(
        *expected0, *expected1,
        "atom 0 and atom 1 must produce distinct Jacobian caches"
    );

    // Wrap both atoms in a term and register one fresh penalty per atom.
    let logits = Array2::<f64>::zeros((coords0.nrows(), 2));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords0.clone(), coords1.clone()],
        vec![
            LatentManifold::Circle { period: 1.0 },
            LatentManifold::Circle { period: 1.0 },
        ],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    let pslice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let pslice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice0, p_out),
    )));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice1, p_out),
    )));

    // Term-level refresh under test.
    let coords_per_atom = vec![coords0.clone(), coords1.clone()];
    let refreshed =
        refresh_isometry_caches_from_term(&registry, &term, &coords_per_atom).unwrap();
    assert_eq!(refreshed, 2, "both penalties should install second caches");

    let cache0 = match &registry.penalties[0] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 0 cache must be populated"),
        _ => panic!("expected isometry penalty at index 0"),
    };
    let cache1 = match &registry.penalties[1] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 1 cache must be populated"),
        _ => panic!("expected isometry penalty at index 1"),
    };
    assert_eq!(
        *cache0, *expected0,
        "penalty 0 must be refreshed against atom 0"
    );
    assert_eq!(
        *cache1, *expected1,
        "penalty 1 must be refreshed against atom 1 (regression: old find() paired it to atom 0)"
    );
    assert_ne!(
        *cache0, *cache1,
        "the two penalties must not collapse onto the same atom"
    );
}
}